著者:蘇剣林 原文:https://kexue.fm/archives/11126
蘇剣林がKimi K2の主要な訓練技術QK-Clipを解明:Muonをスケールアップの道でさらに前進させる!
4ヶ月前、私たちはMoonlightを発表し、16BのMoEモデルでMuonオプティマイザの有効性を検証しました。Moonlightでは、MuonにWeight Decayを追加することの必要性を確認し、Update RMSアラインメントを介してAdamのハイパーパラメータを移行する手法を提案しました。これにより、MuonはLLMの訓練に迅速に適用できるようになりました。しかし、Muonを数百億パラメータ以上のモデルにさらに拡張しようとした際、新たな「障害」であるMaxLogit爆発に遭遇しました。
この問題を解決するために、私たちはシンプルだが非常に効果的な新手法「QK-Clip」を提案しました。この手法は、MaxLogit現象を非常に本質的な観点から捉え、解決するもので、モデルの性能を損なうことなく、最近発表された兆パラメータモデル「Kimi K2」の主要な訓練技術の一つとなりました。
問題の記述
まず、MaxLogit爆発現象について簡単に説明します。アテンション(Attention)の定義を振り返ります。
Attention(Q, K, V) = Softmax(QK^T)V
ここではスケール因子が省略されていますが、それは常にVの定義に吸収できるためです。「MaxLogit爆発」におけるLogitとは、Softmaxの前のAttention行列、すなわちQK^Tを指し、MaxLogitとは、全てのLogitの最大値を指します。これを以下のように表記します。
MaxLogit = max(QK^T)
このMaxLogitは実際にはbatch_size次元でも取られ、最終的にスカラー値となります。MaxLogit爆発とは、訓練の進行とともにMaxLogitが線形あるいは超線形に上昇し続け、かなりの期間にわたって安定する兆候が見られない状態を指します。
MaxLogitは本質的に異常値の指標であり、その爆発は異常値が制御可能な範囲を超えたことを意味します。具体的には、次の関係があります。
||QK^T||_∞ <= ||Q||_∞ ||K||_∞
QとKは通常RMSNormが適用されるため、一般的にQとKが爆発することはありません。したがって、MaxLogit爆発はQまたはKのスペクトルノルムが無限大に発散するリスクがあることを意味し、これは明らかに良いニュースではありません。
Softmaxを通すとどんなに大きな値でも1未満になるため、幸運なケースではこの現象が深刻な結果をもたらすことはありません。せいぜいAttention Headが一つ無駄になる程度です。しかし、最悪の場合、Grad Spikeや訓練の崩壊を引き起こす可能性があります。そのため、安全策としてMaxLogit爆発の発生を極力避けるべきです。
これまでの試み
「Muon続編:なぜMuonを試すことを選んだのか?」の中で、Weight DecayがMaxLogit爆発をある程度防ぐことができると簡単に分析しました。そのため、小さなモデルではMaxLogit爆発が発生する確率は非常に低く、Moonlightのような16Bモデルでさえ、MaxLogitは最大120まで上昇した後に自動的に低下しました。
MoonlightのMaxLogitは自動的に低下しました。
言い換えれば、MaxLogit爆発は非常に多数のパラメータを持つモデルでより頻繁に発生し、モデルが大きくなるほど、訓練の不安定要因が増え、Weight Decayでは訓練を安定させることが難しくなります。この時、Weight Decayを増やすことは確かに制御を強化できますが、同時に明らかな性能低下を招くため、この方法は使えません。もう一つの比較的直接的なアプローチは、Logitに直接クリップ(Clip)を追加することです。
Logit = Clip(QK^T, max_value=max_logit_threshold)
ここでmax_logit_thresholdは、GoogleのGemma2によって導入されました。Clipの有界性により、Clip後のLogitが有界であることは保証されますが、Clip前のLogitが有界であることは保証できません(実測済み)。したがって、Clipは問題を別の問題に変換しただけであり、実際には問題を解決していません。
Google自身もこの点に気づいていたのかもしれません。そのため、後のGemma3ではClipを使わず、「QK-Norm」を採用しました。
QK-Norm: Q_norm = Q / ||Q||_2, K_norm = K / ||K||_2
QK-NormはMaxLogitを抑制する非常に効果的な方法ですが、MHAやGQAなどには適用できても、MLAには適用できません。なぜなら、QK-NormはQK^Tを具体化(Materialize)する必要があるからです。しかし、MLAの場合、訓練段階とデコード段階のQK^Tは異なります(下式参照)。デコード段階では、訓練段階のQK^Tを完全に具体化できないため、言い換えれば、デコード段階ではQK-Normを実行できません。
MLA訓練: QK^T_train = ... (複雑な数式)
MLAデコード: QK^T_decode = ... (複雑な数式)
なぜMLAを使うのか?私たちは既に2つの記事、「Transformerアップグレードの道:21、MLAの何が良いのか?(上)」と「Transformerアップグレードの道:21、MLAの何が良いのか?(下)」でこの問題について議論しているので、ここでは繰り返しません。要するに、MLAでもQK-NormのようなMaxLogitを抑制する手段が欲しいのです。
直接的な目標
この間、私たちはQやKの学習率を個別に下げたり、Weight Decayを個別に増やしたりするなど、いくつかの間接的な手段も試しましたが、どれも効果がありませんでした。最も成功に近づいたのはPartial QK-Normでした。MLAの場合、そのQK^Tはqr、qc、kr、kcの4つの部分に分かれており、このうち最初の3つの部分はデコード時に具体化できるため、これら3つの部分全てにRMSNormを追加しました。結果としてMaxLogitを抑制できましたが、長さ活性化の効果が非常に悪かったです。
何度も失敗した後、私たちは反省せずにはいられませんでした。これまでの試みは、MaxLogitの爆発を抑制する「間接的な手段」に過ぎなかったのではないか? MaxLogitの爆発を確実に解決できる直接的な手段とは何だろうか? 不等式 ||QK^T||_∞ <= ||Q||_∞ ||K||_∞ から、QやKに特異値クリッピングを行うことを連想できますが、これも本質的には間接的な手段であり、特異値クリッピングの計算コストも安くありません。
しかし、明らかにQとKに対して事後スケーリングを行うことは理論的に可能です。問題は、いつスケーリングし、どれくらいスケーリングするか、です。ついに、ある日、閃きが訪れ、筆者は気づきました。MaxLogitそのものがスケーリングをトリガーする最も直接的なシグナルだと! 具体的には、MaxLogitが期待される閾値max_logit_thresholdを超えた場合、QとKに直接 factor = max_logit_threshold / MaxLogit を乗算します。そうすれば、新しいMaxLogitは確実にmax_logit_thresholdを超えなくなります。factorを乗算する操作は、それぞれQとKの重みに吸収できるため、初版のQK-Clipが得られます。
Q_new = Q * factor, K_new = K * factor (where factor = min(1, max_logit_threshold / MaxLogit_L))
ここでMaxLogit_LはL層目のAttentionのMaxLogitであり、QとKはその重みです。つまり、オプティマイザ更新後、MaxLogit_Lの大きさに応じてQとKの重みをクリッピングするかどうかを決定し、クリッピングの程度はMaxLogit_Lと閾値max_logit_thresholdの比率によって直接決定されます。これにより、クリッピング後の行列がMaxLogit爆発を起こさなくなることが直接保証されます。同時に、重みに直接操作を行うため、推論モードに影響を与えず、当然ながらMLAとも互換性があります。
精密な調整
初版のQK-ClipはMLAのMaxLogitを正常に抑制できましたが、モデルの「内部」を注意深く観察した結果、「過剰なクリッピング」の問題が発生することを発見しました。この問題を修正した後、最終版のQK-Clipが得られました。
ご存知のように、どのアテンションバリアントにも複数のヘッドがあります。当初、私たちは各アテンション層で1つのMaxLogit指標のみを監視し、すべてのヘッドのLogitをまとめて最大値を取得していました。これにより、QK-Clipもすべてのヘッドをまとめてクリッピングしていました。しかし、各ヘッドのMaxLogitを個別に監視したところ、実際には各層でMaxLogit爆発が発生するヘッドはごく少数であることがわかりました。もし全てのヘッドを同じ比率でクリッピングすると、ほとんどのヘッドが「不当に影響を受ける」ことになります。これが過剰なクリッピングの意味です。
簡単に言えば、QK-Clipの操作は1未満の数を乗じることです。この数は、MaxLogitが爆発しているヘッドにとっては成長傾向を打ち消すのにちょうど良いのですが、他のヘッドにとっては単なる縮小になります(それらは成長傾向がないか、非常に弱い)。長期間にわたって不必要に1未満の数を乗じられると、値がゼロに近づく現象が容易に発生します。これが「過剰なクリッピング」の現れです。
したがって、「巻き添え」を避けるために、MaxLogitの監視とQK-Clipはヘッドごと(Per-Head)に行うべきです。しかし、ここにはもう一つの厄介な詳細が隠されています。初版のQK-Clipはクリッピング因子をQとKに均等に分配していましたが、MLAのQK^Tにはqr、qc、kr、kcの4つの部分があり、そのうちkrは全てのヘッドで共有されています。もしkrをクリッピングすると、同様に「巻き添え」の問題が発生します。そのため、(qr, kr)の組については、qrにのみクリッピングを適用すべきです。
上記の調整を経て、最終版のQK-Clipは以下のようになりました。
Q_new = Q * factor_L_h, K_new = K * factor_L_h (where factor_L_h = min(1, max_logit_threshold / MaxLogit_L_h))
ここで上付き文字LはL層目を、hはh番目のヘッドを表します。
拡張の道
ここまでに、QK-Clipの操作の詳細について説明しました。これは、期待されるMaxLogitを信号として直接QとKの重みに可能な限り小さな変更を加え、MaxLogit値を指定された閾値内に制御する効果を達成します。同時に、これは重みに直接変更を加える方法であるため、QK-Normよりも互換性が高く、MLAにも使用できます。
Kimi K2の訓練では、閾値max_logit_thresholdを100に設定しました。総訓練ステップ数は約220kステップで、およそ7kステップからMaxLogitがmax_logit_thresholdを超えるヘッドが出現し始めました。その後、かなりの期間にわたってMuon UpdateとQK-Clipは「綱引き」状態、すなわちMuonはMaxLogitを増やそうとし、QK-ClipはMaxLogitを減らそうとするという微妙なバランスが続いていました。興味深いことに、70kステップ後には、すべてのヘッドのMaxLogitが自発的に100以下に低下し、QK-Clipは無効になりました。
約70kステップにわたるMuonとQK-Clipの綱引きの後、MaxLogitは自発的に低下しました。
これは、Weight Decayの作用下では、訓練を安定させられれば、モデルは最終的に自発的にMaxLogitを低下させる可能性が高いことを示しています。QK-Clipの役割は、まさにモデルが訓練初期をより安定して乗り越えるのを助けることです。読者の中にはQK-Clipが性能を損なうのではないかと心配する方もいるかもしれませんが、私たちは小さなモデルで比較実験を行いました。たとえQK-ClipによってMaxLogitを特に小さく(例えば30)抑制しても、性能に実質的な違いは観察されませんでした。さらに、中後期にはモデルが自発的にMaxLogitを低下させる現象も相まって、QK-Clipが性能を損なわないと信じるに足る理由があります。
実験では、MuonがAdamよりも一般的にMaxLogit爆発を起こしやすいことも観察されました。そのため、ある意味で、QK-ClipはMuonのために特別に補完された更新ルールであり、超大規模訓練におけるMuonの「切り札」の一つであり、これが本記事のタイトルの意味するところでもあります。この目的のため、私たちはMoonlightで提案したMuonの変更とQK-Clipを組み合わせ、「MuonClip」という名前を付けました。
MuonClipは、MoonlightにおけるMuonの変更とQK-Clipを組み合わせたものです。
注意すべきは、「MuonがAdamよりも一般的にMaxLogit爆発を起こしやすい」という点が、MuonだけがMaxLogit爆発を起こすという意味ではないことです。DeepSeek-V3はAdamで訓練されていますが、DeepSeek-V3のオープンソースモデルでもMaxLogit爆発現象が観察されましたし、Gemma2はClipを用いてMaxLogit爆発を防いでいますが、これもAdamで訓練されています。したがって、QK-ClipのMuonに対する価値を強調しましたが、読者がAdamの使用にこだわる場合でも、Adamと組み合わせてAdamClipを形成することが可能です。
原因考察
なぜMuonはMaxLogit爆発を引き起こしやすいのでしょうか?このセクションでは、筆者が理論的な観点からの説明を試みますので、参考にしてください。
不等式 ||QK^T||_∞ <= ||Q||_∞ ||K||_∞ から、MaxLogit爆発はQまたはKのスペクトルノルムが爆発する兆候を示すことが多いことがわかります。実際、スペクトルノルムの定義にもMax操作が含まれており、両者は本質的に共通しています。したがって、問題は「なぜMuonはスペクトルノルム爆発を引き起こしやすいのか」に変換できます。スペクトルノルムが最大の特異値に等しいことを知っているので、さらに「なぜMuonは特異値を増大させる傾向があるのか」と連想できます。
MuonとAdamの違いは何でしょうか? Muonが提供する更新量は特異値分解(SVD)演算を経ており、全ての特異値が等しく、つまりその有効ランクはフルランクです。一方、一般的な行列の特異値は通常大小様々であり、前のいくつかの特異値が支配的です。有効ランクの観点から見ると、これらは低ランクであり、Adamの更新量についても同様の仮定をしています。この仮定は新しいものではなく、例えば高階muPもAdamの更新量の低ランク性を仮定しています。
式で言うと、パラメータWのSVDを W = U_W Σ_W V_W^T、Muon更新量のSVDを ΔW_muon = U_muon Σ_muon V_muon^T、Adam更新量のSVDを ΔW_adam = U_adam Σ_adam V_adam^T とします。
Σ_muonは単位行列のスカラー倍であるのに対し、Σ_adamは異なる特異値を持つ。
明らかに、特異ベクトル対U_W、V_Wが特定のU_muonまたはV_muonに非常に近い場合、それらは直接重ね合わされ、その結果Wの特異値が増大します。Muonの更新量がフルランクであるため、Wとの「衝突確率」はAdamよりもはるかに高く、そのためMuonはパラメータの特異値を増大させやすいのです。
もちろん、上記の分析は一般的なものであり、QやKの重みに限定されるものではありません。実際、Moonlightでは、Muonで訓練されたモデルの重みの特異値エントロピーが一般的に高いことが検証されており、これも上記の推測を裏付けています。アテンションLogitの特殊性は、それが双線形形式QK^Tである点にあります。QとKの連乗は爆発のリスクを増大させ、さらに「悪いことがさらに悪くなる」という悪循環を引き起こしやすく、最終的にMaxLogit爆発を促進します。
MuonとAdamで訓練されたモデルの重みの特異値エントロピー(有効ランクと同等)の比較。
最後に、「Muonの衝突確率がAdamよりもはるかに大きい」というのは相対的な話であり、実際には特異ベクトルが衝突するのは依然として低い確率の事象です。これが、ごく一部のAttention HeadでしかMaxLogit爆発現象が発生しない理由を説明しています。
いくつかの拡張
ここまでで、QK-Clipに関する重要な計算と実験の詳細について説明を終えました。補足として、QK-Clipの考え方は非常にシンプルですが、ヘッドごとのクリップが必要なため、分散訓練で実装するにはやや難易度があります。これは、その際パラメータ行列が「バラバラ」に分割されることが多いためです(Muonをベースに修正するのは難しくありませんが、Adamをベースに修正する場合はやや複雑になります)。
筆者とそのチームにとって、QK-ClipはMaxLogit爆発問題を解決する具体的な方法であるだけでなく、間接的な手段による問題解決の試みが繰り返し失敗した後の「目覚め」でもありました。明確な測定指標があるならば、問題解決を保証できる直接的なアプローチを追求すべきであり、学習率の低下、Weight Decayの増大、部分的なQK-Normなど、可能性はあるものの必ずしも問題を解決しない考え方に時間を浪費すべきではありません。
手法から見ると、QK-Clipの考え方はMaxLogit爆発の解決に限定されません。多くの訓練不安定問題を解決する「抗生物質」と言えるでしょう。抗生物質とは、問題を解決する上で最も洗練された方法ではないかもしれませんが、多くの場合、最も直接的で効果的な方法の一つであり、QK-Clipはこの特徴を持っています。「不安定な場所をクリップする」という一般的な考え方に拡張できます。
例えば、場合によってはモデルで「MaxOutput爆発」の問題が発生することがあります。このとき、MaxOutputの値に基づいて重みWをクリップすることを検討できます。QK-Clipのヘッドごとの操作と同様に、ここでは次元ごとの操作を考慮する必要がありますが、次元ごとのクリップは明らかにコストが大きすぎるため、妥協が必要かもしれません。要するに、「不安定な場所をクリップする」は統一された解決策を提供しますが、具体的な詳細は皆さんの工夫次第です。
最後に、QK-Clipのような特定の信号に基づいて手動で更新ルールを策定する操作は、DeepSeekのLoss-Freeロードバランス戦略に触発されて思いついたものであり、ここに改めてDeepSeekに敬意を表します!
記事のまとめ
本稿ではQK-Clipを提案しました。これはMaxLogit爆発問題に対する新しいアプローチであり、QK-Normとは異なり、Q、K重みに対する事後調整スキームであり、モデルの順方向計算を変更しません。そのため適用範囲が広く、「Muon + MLA」の組み合わせにおける超大規模訓練の重要な安定化戦略であり、我々が最近発表した兆パラメータモデルKimi K2の主要な技術の一つです。
参考文献
Transformerアップグレードの道:21、MLAの何が良いのか?(上)