作者:蘇劍林 原文:https://kexue.fm/archives/11126
蘇劍林揭秘Kimi K2關鍵訓練技術QK-Clip:讓Muon在Scaleup之路上更進一步!
四個月前,我們發布了Moonlight,在16B的MoE模型上驗證了Muon最佳化器的有效性。在Moonlight中,我們確認了給Muon添加權重衰減(Weight Decay)的必要性,同時提出了透過Update RMS對齊來遷移Adam超參數的技巧,這使得Muon可以快速應用於大型語言模型(LLM)的訓練。然而,當我們嘗試將Muon進一步拓展到千億參數以上的模型時,遇到了新的「攔路虎」——MaxLogit爆炸。
為了解決這個問題,我們提出了一種簡單但極其有效的新方法,我們稱之為「QK-Clip」。該方法從一個非常本質的角度去看待和解決MaxLogit現象,並且無損模型效果,這成為我們最新發布的萬億參數模型「Kimi K2」的關鍵訓練技術之一。
問題描述
我們首先簡單介紹一下MaxLogit爆炸現象。回顧注意力(Attention)的定義:
Attention(Q, K, V) = Softmax(QK^T)V
這裡省略了縮放因子,因為它總可以吸收到V的定義中。「MaxLogit爆炸」中的Logit,指的是Softmax前的注意力矩陣,即QK^T,而MaxLogit指的是全體Logit的最大值,我們將它記為:
MaxLogit = max(QK^T)
這裡的MaxLogit其實還要從批量大小(batch_size)維度上取,最終得到一個純量。而MaxLogit爆炸是指,MaxLogit隨著訓練的推進一直往上漲,增長速度是線性甚至是超線性的,並且在相當長的時間內沒有穩定的跡象。
MaxLogit本質上是一個異常值指標,它的爆炸意味著異常值超出了可控範圍。具體來說,我們有:
||QK^T||_∞ <= ||Q||_∞ ||K||_∞
由於Q、K通常會加上RMS正規化(RMSNorm),所以一般情況下Q、K是不會爆炸的,因此MaxLogit爆炸意味著Q或K的譜範數有往無窮大發展的風險,這顯然不是一個好消息。
由於再大的數值經過Softmax後都變得小於1,所以比較幸運的情況下,這個現象不會帶來太嚴重的後果,頂多是浪費了一個注意力頭(Attention Head),但比較糟糕的情況下,可能會引起梯度尖峰(Grad Spike)甚至訓練崩潰。因此,保險起見應當盡量避免MaxLogit爆炸的出現。
已有嘗試
在《Muon續集:為什麼我們選擇嘗試Muon?》中我們簡單分析過,權重衰減能在一定程度上預防MaxLogit爆炸,所以小模型出現MaxLogit爆炸的機率很小,即便像Moonlight這樣16B的模型,MaxLogit最多漲到120後就自動降下來了。
Moonlight的MaxLogit自動降了下來
換句話說,MaxLogit爆炸更多出現在非常大參數量的模型中,模型越大,訓練的不穩定因素越多,權重衰減越難穩定訓練。這時候增加權重衰減自然也能加強控制,但同時也會帶來明顯的效果損失,所以此路不通。另一個比較直接的思路是直接給Logit加上裁剪(Clip):
Logit = Clip(QK^T, max_value=max_logit_threshold)
其中max_logit_threshold,由Google的Gemma2引入。由於裁剪的有界性,Logit自然是能夠保證裁剪後的Logit有界的,但無法保證裁剪前的Logit是有界的(親測),所以裁剪只是將一個問題轉化為另一個問題,實際上並沒有解決問題。
也許Google自己都意識到了這一點,所以在後來的Gemma3中沒有再使用裁剪,而改用「QK-Norm」:
QK-Norm: Q_norm = Q / ||Q||_2, K_norm = K / ||K||_2
QK-Norm確實是壓制MaxLogit非常有效的方法,然而它只適用於MHA、GQA等,不適用於MLA,因為QK-Norm需要將QK^T具體化(Materialize)出來,但對於MLA來說,它訓練階段與解碼階段的QK^T並不一樣(如下式所示),在解碼階段我們無法完全具體化訓練階段的QK^T,換言之,解碼階段無法做QK-Norm。
MLA訓練階段:QK^T_train = ... (複雜公式)
MLA解碼階段:QK^T_decode = ... (複雜公式)
為什麼要用MLA?我們已經用兩篇文章《Transformer升級之路:21、MLA好在哪裡?(上)》和《Transformer升級之路:21、MLA好在哪裡?(下)》討論了這個問題,這裡不再重複。總之,我們希望MLA也能有類似QK-Norm的手段來壓制MaxLogit。
直擊目標
期間我們還嘗試了一些間接手段,例如單獨降低Q、K的學習率、單獨增大它們的權重衰減等,但都不奏效。最接近成功的一次是部分QK-Norm(Partial QK-Norm),對於MLA來說,它的QK^T分為qr、qc、kr、kc四個部分,其中前三部分在解碼時都是可以具體化的,所以我們給這三部分都加上RMS正規化,結果是可以壓制MaxLogit,但長度激活效果非常糟糕。
在失敗多次之後,我們不禁開始反思:前面我們的嘗試其實都只是壓制MaxLogit的「間接手段」,真正能保證解決MaxLogit爆炸的直接手段是什麼?從不等式||QK^T||_∞ <= ||Q||_∞ ||K||_∞我們不難聯想到可以對Q、K做奇異值裁剪,但這本質上還是間接手段,而且奇異值裁剪的計算成本也不低。
但很明顯,對Q、K進行事後縮放理論上是可行的,問題是什麼時候縮放、縮放多少。終於,某天福至心靈之下,筆者總算反應過來:MaxLogit本身就是觸發縮放的最直接訊號!具體來說,當MaxLogit超過期望閾值max_logit_threshold時,我們直接給Q、K乘上因子factor = max_logit_threshold / MaxLogit,那麼新的MaxLogit肯定就不會超過max_logit_threshold了。乘factor的操作,我們可以分別吸收到權重Q、K的權重上去,於是我們得到初版QK-Clip:
Q_new = Q * factor, K_new = K * factor (where factor = min(1, max_logit_threshold / MaxLogit_L))
其中MaxLogit_L是第L層注意力的MaxLogit,Q、K是它的權重。也就是說,在最佳化器更新之後,根據MaxLogit_L的大小來決定是否對Q、K的權重進行裁剪,裁剪的幅度直接由MaxLogit_L與閾值max_logit_threshold的比例來決定,直接保證裁剪後的矩陣不再MaxLogit爆炸。同時,由於是直接對權重進行操作,所以不影響推論模式,自然也就相容MLA了。
精細調整
初版QK-Clip確實已經能成功壓制MLA的MaxLogit,但經過仔細觀察模型的「內科」後,我們發現它會出現「過度裁剪」的問題,修復該問題後就得到最終版QK-Clip。
我們知道,不管哪種注意力變體都有多個頭(Head),一開始我們是每一層注意力只監控一個MaxLogit指標,所有頭的Logit是放在一起取Max的,這導致QK-Clip也是所有頭一起裁剪的。然而,當我們分別監控每個頭的MaxLogit後發現,實際上每層只有為數不多的頭會出現MaxLogit爆炸,如果所有頭按同一個比例來裁剪,那麼大部分頭都是被「無辜受累」了,這就是過度裁剪的含義。
簡單來說,QK-Clip的操作是乘以一個小於1的數,這個數對於MaxLogit爆炸的頭來說是剛剛好抵消增長趨勢,但是對於其他頭來說是單純的縮小(它們沒有增長趨勢或者增長趨勢很弱)。由於長期無端被乘一個小於1的數,那麼很容易出現趨於零的現象,這是「過度裁剪」的表現。
所以,為了避免「殃及池魚」,我們應該針對每個頭(Per-Head)進行MaxLogit監控和QK-Clip。不過這裡面又隱藏了另一個魔鬼細節:初版QK-Clip是將裁剪因子平均分配到Q、K上的,但是MLA的QK^T有qr、qc、kr、kc四部分,其中kr是所有頭共享的,如果對它裁剪,那麼同樣會有「殃及池魚」的問題。因此,對於(qr, kr)組合,我們應該只裁剪到qr上去。
經過上述調整,最終版的QK-Clip為:
Q_new = Q * factor_L_h, K_new = K * factor_L_h (where factor_L_h = min(1, max_logit_threshold / MaxLogit_L_h))
其中上標L表示第L層、h表示第h個頭。
擴展之路
至此,QK-Clip的操作細節已經介紹完畢,它直接以我們期望的MaxLogit為訊號,對Q、K的權重進行盡可能小的改動,達到了將MaxLogit值控制在指定閾值內的效果。同時因為這是直接對權重進行修改的方法,所以它相容性比QK-Norm更好,可以用於MLA。
在Kimi K2的訓練中,我們設定閾值max_logit_threshold為100,總訓練步數約為220k步,從大約7k步開始,就出現了MaxLogit超過max_logit_threshold的頭,此後在相當長的時間內,Muon更新和QK-Clip都在「拉鋸戰」,即Muon想要增加MaxLogit而QK-Clip想要降低MaxLogit,它們一直處於微妙的平衡狀態。有趣的是,70k步之後,所有頭的MaxLogit都主動降低到了100以下,QK-Clip不再生效。
經過接近70k步的Muon和QK-Clip拉鋸戰後,MaxLogit主動降了下來
這表明,在權重衰減的作用下,只要我們能穩住訓練,模型最後很可能都會主動將MaxLogit降下來,QK-Clip的作用,正是幫助模型更平穩地度過訓練初期。可能有讀者擔心QK-Clip會有損效果,但我們在小模型上做了對比實驗,即便透過QK-Clip將MaxLogit壓得特別小(例如30),也沒有觀察到效果有實質區別,再加上中後期模型會主動將MaxLogit降下來的這一現象,我們有理由相信QK-Clip對效果是無損的。
我們在實驗中也觀察到,Muon普遍比Adam更容易MaxLogit爆炸,所以某種程度上來說,QK-Clip是專門為Muon補充的更新規則,它是Muon在超大規模訓練上的「通關秘笈」之一,這也是本文標題的含義。為此,我們將我們Moonlight中所提的Muon改動跟QK-Clip組合起來,起了個「MuonClip」的名字:
MuonClip結合了Moonlight中Muon的修改與QK-Clip。
請注意,「Muon普遍比Adam更容易MaxLogit爆炸」並不意味著只有Muon會MaxLogit爆炸,我們知道DeepSeek-V3是Adam訓練的,而我們從DeepSeek-V3的開源模型中也觀察到了MaxLogit爆炸現象,還有Gemma2使用Clip防止MaxLogit爆炸,它也是Adam訓練的。因此,雖然我們強調了QK-Clip對Muon的價值,但如果讀者堅持使用Adam,那麼它也是可以跟Adam結合形成AdamClip的。
原因思考
為什麼Muon更容易導致MaxLogit爆炸呢?這一節筆者嘗試給出一個理論角度的解釋,供大家參考。
從不等式||QK^T||_∞ <= ||Q||_∞ ||K||_∞可以看出,MaxLogit爆炸往往意味著Q或K的譜範數出現爆炸的跡象,實際上譜範數的定義中也包含了取Max操作,兩者本質上是相通的。因此,問題可以轉化為「為什麼Muon更容易導致譜範數爆炸」。我們知道譜範數等於最大的奇異值,所以又可以進一步聯想到「為什麼Muon更傾向於增大奇異值」。
Muon和Adam的區別是什麼呢?Muon給出的更新量是經過奇異值分解運算的,所有奇異值都相等,即它的有效秩是滿秩;而一般情況下的矩陣,奇異值通常都是有大有小,並且以前面幾個奇異值為主,從有效秩的角度看它們是低秩的,我們對Adam更新量的假設也是如此。這個假設並不新鮮,例如高階muP同樣假設了Adam更新量的低秩性。
用公式來說,我們設參數W的奇異值分解(SVD)為W = U_W Σ_W V_W^T,Muon更新量的SVD為ΔW_muon = U_muon Σ_muon V_muon^T,Adam更新量的SVD為ΔW_adam = U_adam Σ_adam V_adam^T,那麼
Σ_muon是一個與單位矩陣成純量倍數的矩陣,而Σ_adam具有不同的奇異值。
很明顯,如果奇異向量對U_W、V_W跟某個U_muon或V_muon很接近,那它們將會直接疊加起來,從而增大W的奇異值。由於Muon的更新量是滿秩的,所以它與W的「碰撞機率」會遠大於Adam的,所以Muon更容易增大參數的奇異值。
當然,上述分析是通用的,不限於Q、K的權重,實際上在Moonlight中我們已經驗證過,Muon訓練出來的模型權重的奇異值熵普遍更高,這也佐證了上述猜測。注意力Logit的特殊之處在於,它是雙線性形式QK^T,Q、K的連乘使得爆炸的風險更大,還容易導致「糟的更糟」的惡性循環,最終促成了MaxLogit爆炸。
Muon與Adam訓練的模型權重奇異值熵(等價於有效秩)比較
最後就是「Muon的碰撞機率遠大於Adam」是相對而言的,實際上奇異向量碰撞在一塊還是小機率事件,這也能解釋為什麼只有小部分注意力頭會出現MaxLogit爆炸現象了。
一些延伸
寫到這裡,關於QK-Clip比較重要的計算和實驗細節應該都講清楚了。另外還需要提醒的是,QK-Clip思想很簡單,但由於需要針對每個頭進行裁剪,因此在分散式訓練中寫起來還是略微有點難度的,因為此時的參數矩陣往往被切分得「支離破碎」(在Muon基礎上修改不算難,在Adam基礎上修改則稍顯複雜)。
對於筆者及其團隊來說,QK-Clip不單單是解決MaxLogit爆炸問題的一個具體方法,還是反覆嘗試透過間接手段來解決問題卻失敗後的一次「幡然醒悟」:既然有了明確的度量指標,那麼我們應該尋求能夠保證解決問題的直接思路,而不是在降低學習率、增大權重衰減、部分QK-Norm等可能但不一定能解決問題的思路上浪費時間。
從方法上來看,QK-Clip的思路也不限於解決MaxLogit爆炸,它可以說是解決很多訓練不穩定問題的「抗生素」。所謂抗生素,指的是它也許並不是解決問題最精妙的方法,但往往是解決問題最直接有效的方法之一,QK-Clip正是具有這個特點,它可以一般地推廣成「哪裡不穩裁剪哪裡」。
例如,有些情況下模型會出現「MaxOutput爆炸」的問題,這時候我們可以考慮根據MaxOutput的值來裁剪權重W。類比QK-Clip的針對每個頭的操作,這裡我們也需要考慮針對每個維度(Per-Dim)的操作,但針對每個維度的裁剪成本顯然太大,可能需要折衷一下。總之,「哪裡不穩裁剪哪裡」提供了統一的解決思路,但具體細節就要看大家發揮了。
最後,QK-Clip這種根據某些訊號手動制定更新規則的操作,一定程度上是受到了DeepSeek的Loss-Free負載均衡策略的啟發而悟到的,這裡再次致敬DeepSeek!
文章小結
本文提出了QK-Clip,它是MaxLogit爆炸問題的一種新思路,與QK-Norm不同,它是對Q、K權重的一種事後調整方案,並不改變模型的前向計算,因此適用性更廣,它是「Muon + MLA」組合在超大規模訓練上的重要穩定策略,也是我們最新發布的萬億模型Kimi K2的關鍵技術之一。
引用連結
Transformer升級之路:21、MLA好在哪裡?(上)