GRPO=高度な拒否サンプリング?強化学習の解明の時:負のサンプル「選択と集中」が鍵!

图片

MLNLPコミュニティは、国内外の機械学習および自然言語処理の研究者によって共同設立された民間の学術コミュニティであり、現在では国内外でよく知られた機械学習および自然言語処理コミュニティに発展しています。その目的は、機械学習、自然言語処理の学術界、産業界、および広範な愛好家間の交流と進歩を促進することです。

コミュニティのビジョンは、国内外の自然言語処理、機械学習の学術界、産業界、そして広範な愛好家、特に初心者の学生たちの間の交流と進歩を促進することです。

出典 | PaperWeekly

現在、大規模言語モデル(LLM)を強化学習(RL)で訓練し、数学問題のような比較的複雑な推論タスクを実行することが広く行われています。PPOのようなアルゴリズムは主流ですが、追加のネットワーク(critic network)が必要で、その設定は複雑で手間がかかります。

一方で、GRPOのようなアルゴリズムは実用上非常に良い結果を出しています(DeepSeek-R1の訓練など)が、なぜこれほど効果的なのか、あるいはよりシンプルな手法よりも本当に優れているのかは、あまり明確ではありません。

その一方で、RAFT(拒否サンプリング)のような非常にシンプルな手法、つまりモデルが正解したサンプルのみを用いてファインチューニングする方法も、良い結果を出しているようです。これが研究者たちの疑問を呼びました:

1. これらの複雑なRLアルゴリズム(PPO、GRPOなど)と、シンプルなSFT系手法(RAFTなど)の差はどこにあるのか?本当にそこまで複雑にする必要があるのか?

2. GRPOが効果的なのは、アルゴリズム自体の設計(報酬の正規化など)によるものなのか、それともサンプル使用における特定の戦略(不正解サンプルをどう扱うかなど)によるものなのか?

3. LLMのように出力がテキストシーケンスであり、環境が比較的確定しているシナリオにおいて、より簡潔で適したRLアルゴリズムは利用できないのでしょうか?

したがって、この論文の出発点は、代表的なRL手法(特にGRPO、極めてシンプルなRAFT、および基本的なReinforce)を再検討・比較し、その成功の鍵となる要因、特に負のサンプル(モデルが間違った例)の適切な利用方法、そしてシンプルかつ効果的なRL訓練方法を見つけることにあります。

論文タイトル:

A Minimalist Approach to LLM Reasoning: from Rejection Sampling to Reinforce

論文URL:

https://arxiv.org/abs/2504.11343

GitHub URL:

https://github.com/rlhflow/minimal-rl

この論文の主な発見と貢献は以下の通りです:

極めてシンプルな手法の有効性を検証:研究者たちは、RAFTという非常にシンプルな拒否サンプリング手法(正解したサンプルのみで訓練)が、現在の一般的なGRPO手法とほぼ同等の効果を発揮し、訓練初期にはより速く収束することを発見しました。

これは、この種のタスクにおいて、シンプルな「良いものだけを学ぶ」戦略自体が非常に強力なベースラインとなることを示しています。しかし、RAFTは正のサンプルのみを使用するため、訓練後期にはモデルの探索性が低下し(エントロピーが急速に減少)、GRPOに追い越される可能性も指摘されています。

GRPOの優位性の源泉を解明:詳細な比較実験(アブレーション研究)を通じて、GRPOが標準的なReinforceアルゴリズムに対して優位性を持つ主な理由は、サンプル処理において「全ての回答が間違っていた」プロンプトを暗黙的にフィルタリングしていることにあると彼らは発見しました。

つまり、完全に失敗した例から学習することを避けることが、GRPOの効果向上における鍵です。対照的に、GRPOで用いられる、同じ質問に対する異なる回答の良し悪しに基づいて報酬を正規化する技術の影響は小さいです。これは、全ての負のサンプルが有用なわけではなく、一部の負のサンプル(全て間違っているもの)はむしろ足を引っ張る可能性があることを示しています。

新しい簡略化されたRLアルゴリズムを提案:上記の発見に基づき、彼らはReinforce-Rejという新しい手法を提案しました。この手法は、基本的なReinforceアルゴリズムに対する小さな改善であり、その核心思想は、「全ての回答が正しい」質問(簡単すぎる可能性がある)からも、「全ての回答が間違っている」質問(有害である可能性がある)からも学習せず、ただ「良いものと悪いものが混在する」質問からのみ学習するというものです。

実験の結果、このReinforce-Rej手法は最終的な性能がGRPOとほぼ同等であるものの、KL効率(モデル更新の度合いを測る尺度)が高く、訓練がより安定していることが示されました。

実践的なガイダンスを提供:総じて、この研究は、LLMを報酬でファインチューニングする際、「どの複雑なRLアルゴリズムを用いるか」よりも、「訓練サンプル(特に負のサンプル)をどのように選択し使用するか」がより重要である可能性を強調しています。彼らは、RAFTをシンプルで信頼できるベースラインとして推奨し、将来の研究では、負のサンプルを無差別に使用するのではなく、より深く原理的にその利用方法を設計すべきであると提言しています。

既存手法の詳細分析

まず、LLMの事後訓練(post-training)に用いられる代表的なアルゴリズムをいくつか概観します。

1.1 RAFT(拒否サンプリングファインチューニング)

RAFTという手法は、文献では拒否サンプリングファインチューニングとも呼ばれています。その操作手順は非常にシンプルで、主に以下の3つのステップに分かれます:

1. データ収集:一連のプロンプト x を用いて、参照モデル(例えば現在のモデル自身)で各プロンプトに対して n 個の回答を生成します。

2. データフィルタリング(拒否サンプリング):報酬関数 r(x, a) を用いて各回答にスコアを付け、最もスコアが高いもの(通常は報酬が1、つまり正しい回答)のみを保持します。これらのフィルタリングされた「良い」サンプルをデータセット D としてまとめます。

3. モデルファインチューニング:この良いサンプルのみを含むデータセット D を用いて現在のモデル π をファインチューニングします。目標は、これらの良いサンプルに対するモデルの対数尤度を最大化することです。

1.2 方策勾配(Policy Gradient)とReinforce

これは強化学習における古典的な手法です。核心思想は、目標関数 J(θ) を最適化することです。この関数は、モデルが可能な全てのプロンプト x の下で、回答 a を生成し、報酬 r(x, a) を獲得する期待値を表します:

图片

目標は、J(θ) を最大化するモデルパラメータ θ を見つけることです。通常、勾配上昇を用いてパラメータを更新します:

图片

ここでの ∇θ J(θ) は方策勾配であり、その計算式は以下の通りです:

图片

訓練をより安定させ、新旧モデルの差が大きくなりすぎて重要度サンプリングの重みが爆発するのを防ぐため、研究者たちはPPOアルゴリズムにおけるクリッピング(clipping)技術を参考にしました。最終的に、Reinforceアルゴリズムの損失関数(ここでは負の目標関数を最小化)は次のように書くことができます:

图片

LLMは自己回帰的であるため(トークンごとに生成されるため)、通常、上記の損失関数はトークンレベルで適用されます:

图片

1.3 GRPO

GRPOの損失関数形式は、上記のReinforceのトークンレベル損失とよく似ています。主な違いは、原始の報酬 r(x, a) を使用せず、代わりに各トークンについて計算されるアドバンテージ関数(Advantage Function)を使用することです。

具体的な計算方法は次の通りです:各プロンプト x に対して、 n 個の回答をサンプリングし、対応する報酬 を得ます。次に、これらの報酬の平均値 mean と標準偏差 std を計算します。 i 番目の回答における t 番目のトークンのアドバンテージ値は次のように計算されます:

图片

ここでの mean(r_1, ..., rn) は強化学習においてベースラインと呼ばれ、勾配推定の分散を減らし、訓練をより安定させる役割があります。

1.4(Iterative)DPO(直接選好最適化)

DPOは異なる手法であり、報酬スコアを直接使用せず、ペアの比較データに依存します。データセットには、 (x, a+, a-) のようなサンプルが含まれます。これは、プロンプト x に対して、回答 a+ が a- よりも優れていることを示します。

DPOが最適化する目標は、コントラスト損失(contrastive loss)です:

图片

ここで、 σ はシグモイド関数、 β はハイパーパラメータ(0より大きい)、 π_ref は通常、初期モデルまたは固定された参照モデルです。

オリジナルのDPOはオフラインデータで訓練されます。しかし、その後の研究で、反復的に行うことが可能であることが判明しました:訓練中のモデルを使用して新しい回答を生成し、何らかの方法(モデル自身による評価や人間によるアノテーションなど)で新しい選好ペア (a+, a-) を取得し、これらの新しいオンラインデータでモデルの訓練を続行します。この反復的な方法は、モデル性能を大幅に向上させることができます。

1.5 RAFT++

研究者たちは、RAFTが各イテレーションで収集されたデータ(リプレイバッファ)を用いて多段階の勾配更新を行う場合、オフポリシーアルゴリズムと見なすこともできると注目しました。

このアイデアに基づき、彼らはRAFT++を提案しました。これは、Reinforceの重要度サンプリングとクリッピング技術をRAFTにも適用したものです。その損失関数形式はReinforceに似ていますが、重要な違いがあります:それは、最高のサンプル(最も報酬が高い、つまり正のサンプル)のみで訓練を行うことです。これは指示関数 I を用いて実現されます:

ここで は指示関数であり、現在の回答 a が n 個全ての回答の中で最も報酬が高い場合に I は1となり、そうでない場合は0となります。これにより、正のサンプルのみが損失に貢献することが保証されます。

実験結果と興味深い発見のまとめ

以下は、提供された実験部分の解釈に基づき、まとめられた主な結果と興味深い発見です:

图片

シンプルな手法が驚くべき性能を発揮:

RAFTおよびその改良版であるRAFT++は、比較的シンプルな「拒否サンプリング」に基づく手法(良いサンプルのみを使用)であるにもかかわらず、数学的推論タスクにおいて驚くほど良い性能を示しました。

それらの効果は、PPO、GRPOといったより複雑な深層強化学習手法と同等であり、iterative DPOを上回りました。

特にQwenモデルにおいて、RAFT++(52.5%)の平均精度は、当時最も効果的だったGRPO(53.9%)に非常に近いものでした。

图片

RAFT++の改善は有効:

RAFTに重要度サンプリング(データ分布の偏り修正)とクリッピング(更新幅の制限)技術を加えたRAFT++は、確かにオリジナル版のRAFTよりも収束が速く、最終的な精度も向上しました。

実験の結果、クリッピングステップが非常に重要であることが証明されました。重要度サンプリングのみでクリッピングを行わない場合、効果が逆に悪化することが示され、無制限の更新が訓練の安定性を損なう可能性があることを意味します。

图片

学習ダイナミクス比較:速い初期vs持続的な向上:

RAFT++は訓練初期にGRPOよりも速く学習しました。

* しかし、RAFT++の性能向上は訓練の中後期で明らかに減速し、最終的にGRPOに追い抜かれました。

图片

負のサンプルは「諸刃の剣」か?

RAFT++(正のサンプルのみを使用)の性能向上が減速するのは、その方策エントロピー(モデルの探索性/回答の多様性)の急速な低下に関連しています。エントロピーが低すぎると、モデルは新しい推論パスを探索しにくくなります。

GRPOは負のサンプルも考慮しているため、方策エントロピーの低下が遅く、より長い期間探索能力を維持できるため、後期でも向上を続けることができます。これは、負のサンプルが探索の維持に役立つ可能性を示唆しています。

しかし、シンプルなReinforceアルゴリズム(これも負のサンプルを使用)は、LLaMAモデルにおいて正のサンプルのみを使用するRAFT++よりも効果が劣るという結果になりました。これは、負のサンプルをどのように定義し使用するかが非常に重要であり、単に最終的な正誤に基づくだけでは粗すぎる可能性があり、必ずしも常に利益をもたらすとは限らないことを示唆しています。

图片

▲ GRPOおよび強化学習タイプアルゴリズムの各コンポーネントで実施されたアブレーション研究。GRPOを他の強化学習ベースのバリアントと比較し、誤ったサンプル、正しいサンプル、および標準化の適用を除去する影響を分離しました。誤ったサンプル(「Remove all wrong」)を除去することが最大の報酬ゲインをもたらし、それらの有害な影響を浮き彫りにしました。対照的に、正しいサンプルを除去してもゲインはありませんでした。平均値ゼロ化標準化はKL損失を増加させ、訓練を不安定にしました。標準偏差による標準化はほとんど追加の利点をもたらしませんでした。「Reinforce + Remove both」バリアントは、報酬、KL安定性、エントロピー正則化の間で良好なバランスを達成しました。

GRPOの強力な核心は「サンプル除去」にある:

詳細なアブレーション実験でReinforceの様々なバリアントを比較したところ、GRPOが優れた性能を示す鍵は、「生成された回答が全て間違っているサンプル」(「Remove all wrong」)を除去することにあることが判明しました。これらの完全に間違ったサンプルは、訓練に対する最大の干渉となります。

対照的に、報酬正規化(平均を引いたり標準偏差で割ったりするなど)は性能向上に大きな影響を与えず、単純な平均値正規化でさえ訓練を不安定にする可能性があります。

全ての回答が正しいサンプル(「Remove all correct」)を除去しても、あまり助けにはなりませんでした。

「全て正しい」サンプルと「全て間違っている」サンプルを同時に除去する戦略(「Reinforce-Rej」と呼ばれる)は、性能、安定性、探索性の維持の間で良好なバランスを達成しました。

いくつかの考察

新しい簡略化されたベースラインの提案:

上記の発見に基づき、研究者たちはRAFT++とReinforce-Rej(全て正しいおよび全て間違っているサンプルを除去したReinforce)が有効かつよりシンプルなベースラインアルゴリズムであり、今後の研究で参照する価値があると考えています。

負のサンプルの役割に関する新たな考察:

研究結果は、強化学習に基づく大規模モデルの訓練において、負のサンプルの役割が想像以上に微妙であることを示しています。全ての負のサンプルを直接使用することが必ずしも最善ではなく、将来的には異なる品質のサンプルをより精緻な方法で選別し利用する必要があるかもしれません。

技術交流グループ招待状

画像

△長押しでアシスタントを追加

QRコードをスキャンしてWeChatでアシスタントを追加してください

備考:氏名-学校/会社-研究分野

(例:田中-東京大学-対話システム)

これにより、自然言語処理/PyTorchなどの技術交流グループに参加申請できます

私たちについて

MLNLPコミュニティは、国内外の機械学習および自然言語処理の研究者によって共同設立された民間の学術コミュニティであり、現在では国内外でよく知られた機械学習および自然言語処理コミュニティに発展しています。その目的は、機械学習、自然言語処理の学術界、産業界、および広範な愛好家間の交流と進歩を促進することです。

コミュニティは、関連する専門家のために、さらなる研究、就職、研究などの面でオープンな交流プラットフォームを提供することができます。皆様の関心と参加を歓迎します。

画像

メインタグ:強化学習

サブタグ:大規模言語モデルネガティブサンプルサンプル選択拒否サンプリング


前の記事:大規模なまとめ!推論モデルにおける強化学習の実装経路

次の記事:混合思考フレームワークMoT:モデルが「人間らしい思考」を学ぶことを可能に

短いURLをシェア