本文提出 LeaF 框架,在知識蒸餾過程中融入基於因果分析的干擾識別機制,引導學生模型在推理時聚焦因果關鍵特徵,從而提升推理準確性與泛化能力。
論文標題:
Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning
作者單位:
中國人民大學高瓴人工智慧學院,清華大學計算機系
論文連結:https://arxiv.org/pdf/2506.07851
程式碼連結:
問題背景
儘管大型語言模型(LLMs)在自然語言處理任務中展現出強大的上下文理解與語言生成能力,但在長文本推理和複雜指令任務中仍存在明顯不足,特別是在聚焦關鍵資訊方面的能力較弱。這種注意力分散的現象嚴重制約了模型的推理準確性與生成品質。
為系統性研究這一現象,本研究首先透過教師模型與學生模型的梯度敏感度對比,識別輸入中的干擾模式(distracting patterns),並在 NuminaMath-CoT 與 AceCode-87K 資料集上評估學生模型性能。
如圖 1 和圖 2 所示,僅僅透過剪除這些干擾資訊,平均準確率即可顯著提升——在數學訓練集上提升超過 20%,程式碼訓練集上提升超過 10%。
此外,在處理 AMC_AIME 等更具複雜性的任務中,模型表現出的性能提升甚至高於 GSM8K,表明複雜推理任務中往往包含更多誤導性因素,干擾模型做出有效判斷。
▲ 圖1:程式碼任務準確率提升
▲ 圖2:數學任務準確率提升
這些發現表明,消除干擾資訊、提升模型對關鍵資訊的自主關注能力,是提升大型語言模型推理性能的關鍵路徑。
為此,作者提出 LeaF 框架,從因果視角出發,利用梯度引導識別並剔除輸入中的干擾因素,引導學生模型在蒸餾過程中學習關注關鍵的資訊區域,從而提升模型的推理表現。
實驗結果表明,LeaF 在數學推理與程式碼生成等多個下游任務中均取得了顯著性能提升。在 GSM8K、MATH 和 OlympiadBench 等資料集上,平均準確率提高了 2.41%;在 HumanEval+、LeetCode 和 LivecodeBench 等程式碼任務中,平均提升達到 2.48%。
此外,模型在推理過程中的注意力分佈更加集中、一致性更強,注意力視覺化結果也進一步驗證了方法的可解釋性。
LeaF:兩階段建模,提升模型注意力因果性
為緩解模型在推理過程中容易受干擾資訊誤導、難以聚焦關鍵資訊的問題,作者提出了一種因果驅動的注意力遷移方法——LeaF(Learning to Focus)框架。該框架由兩個核心階段構成:
干擾資訊識別:用梯度刻畫模型關注偏差
第一階段旨在識別輸入中對學生模型產生誤導但對推理本身並非必要的Token,稱為干擾型Token(confounding tokens)。
具體而言,從學生預測錯誤而教師預測正確的樣本中,計算兩者對各輸入Token的梯度敏感度對比,篩選出學生模型推理時關注(梯度值較大)而教師模型推理時不關注(梯度值較小)的Token,作為潛在干擾因素。
進一步而言,若在刪除這些Token後,學生模型與教師模型都能給出正確預測,則可將其判定為干擾型Token(confounder tokens)。即對學生推理產生誤導、但對最終得出正確答案並非必要的資訊。
▲ 圖3. LeaF 框架:透過梯度驅動的干擾識別與因果蒸餾優化推理能力
在識別干擾型Token(confounding tokens)後,LeaF 對比了兩種建構反事實輸入樣本的處理方式:
集體移除(Collective Pruning):直接將所有識別出來的干擾型Token一次性刪除;
連續片段移除(Span Pruning):以更精細的方式,每次僅刪除一個連續干擾片段,保留更多語義上下文。
透過預實驗證明,Span Pruning 更具穩定性,是更優選擇。
▲ 圖4. 移除策略示意圖
因果蒸餾:從反事實對比中學習聚焦策略
為了有效引導學生模型學習更穩健的注意力模式,在建構好原始樣本與反事實樣本後,LeaF 設計了一個混合蒸餾目標,將兩種監督訊號融合:
標準蒸餾(Standard Distillation):保持學生模型在原始輸入上與教師模型對齊;
反事實蒸餾(Counterfactual Distillation):鼓勵學生在干擾資訊被刪除後的輸入上依然與教師模型保持一致。
這種雙重蒸餾機制不僅促使學生模型對齊教師模型的輸出行為,更強化了其對輸入中關鍵Token的因果判斷能力。LeaF 透過同時建模語義資訊與因果依賴,有效避免學生模型僅模仿表面模式、忽略關鍵因果關係,從而提升推理穩健性與泛化能力。
此外,LeaF 進一步將原本僅作用於輸入端的指令層級處理(Instruction-level Pruning)拓展至響應層級處理(Response-level Pruning)。
具體而言,除了在輸入指令中識別並移除干擾型Token外,LeaF 還將模型生成的歷史響應視為上下文輸入,動態識別其中對後續推理可能產生誤導的Token並進行刪除。
該策略有助於在生成過程中持續消除干擾,進一步提升模型關注關鍵資訊的能力,從而生成更準確、聚焦的內容。
▲ 圖5 指令級處理擴展至響應級處理
實驗評估:聚焦關鍵注意力,提升推理表現
作者在數學推理與程式碼生成兩大任務上系統評估了 LeaF 框架的有效性,相關實驗涵蓋 Llama 和 Qwen 兩大主流模型架構與 6 個評估基準,驗證了 LeaF 在增強模型推理能力方面的作用。
主實驗結果
實驗表明,LeaF 在所有數學與程式碼基準任務中均帶來性能提升,平均準確率分別較標準蒸餾方法提升 2.41% 與 2.48%。其中,在高難度基準 OlympiadBench 上的改進尤為顯著,表明 LeaF 能有效應對複雜推理中的注意力干擾問題。
▲ 圖6 主實驗結果
此外,將干擾型Token(confounding token)的處理範圍從輸入指令(Instruction-level)拓展到模型生成過程(Response-level),顯著提升了模型性能,表明生成階段同樣存在影響推理的干擾資訊,分段處理策略有助於模型保持對關鍵資訊的關注。
分析實驗:LeaF 如何精準識別並規避推理誤導
為系統評估 LeaF 框架在識別並剔除干擾型Token方面的有效性,作者從四個角度展開深入分析,包括遮蔽策略、響應處理方式、閾值敏感度及案例研究,全面驗證其在提升推理穩健性與模型聚焦能力方面的表現。
4.1 梯度遮蔽策略分析:LeaF 如何精確識別干擾資訊?
為系統評估 LeaF 所採用的梯度遮蔽策略的有效性,作者將其與兩種常見遮蔽方法進行了對比實驗:隨機遮蔽與困惑度(PPL)遮蔽,實驗在 GSM8K、MATH 和 OlympiadBench 上展開,涵蓋從基礎到複雜的數學任務場景。
▲ 圖7:LeaF 梯度遮蔽策略分析實驗結果
實驗觀察:
梯度遮蔽顯著優於其他策略
在 MATH 和 OlympiadBench 等複雜推理任務上取得最優表現,驗證了 LeaF 的梯度引導機制能夠有效定位干擾性Token。
隨機遮蔽策略效果不穩定
在 GSM8K 和 OlympiadBench 上甚至導致性能下降,說明在缺乏語義指導的前提下,盲目刪減Token會破壞蒸餾訊號,也進一步強調了僅僅透過資料增強並不足以提升模型的推理能力。
困惑度遮蔽僅在簡單任務中略有提升
在複雜任務(如 OlympiadBench)中效果接近隨機遮蔽。這表明學生模型自身對Token的關注可能存在偏差,難以準確判斷哪些資訊真正重要,凸顯了引入教師模型進行對比指導的必要性。
結論:在複雜推理任務中,基於梯度差異的遮蔽策略能更精準地識別干擾型Token(confounder token),驗證了 LeaF 框架中「教師-學生梯度對比機制」的有效性與合理性。
4.2 響應層級處理策略:生成過程中的干擾資訊同樣不可忽視
LeaF 不僅在輸入指令中識別干擾型Token(Instruct-level),還進一步將干擾檢測範圍擴展到模型的生成內容中(Response-level),以涵蓋推理過程中的全鏈注意力偏差。
為此,作者設計了三種處理策略進行對比:
僅處理指令層級內容:只在輸入文字中識別和移除干擾型Token,不處理模型生成內容。
響應層級雙段處理(2段):將生成內容劃分為前後兩段,在每段中分別檢測並去除干擾型Token。
響應層級多段處理(3段):將生成內容劃分為三個連續片段,對每段獨立進行干擾檢測與處理。
▲ 圖8:LeaF 響應級處理策略實驗結果
實驗觀察:
引入響應層級處理顯著提升模型表現:相比僅處理輸入,進一步在生成過程中識別並去除干擾項,能有效增強模型的推理準確性,說明後續生成內容同樣容易受到注意力偏差的干擾。
2 段與 3 段處理效果接近:更細粒度的三段處理未帶來明顯收益,說明兩段已足以讓模型識別並學習到響應中的干擾模式;過度切分可能導致過擬合風險上升。
結論:干擾型Token(Confounder tokens)不僅存在於輸入指令中,也常常隱藏在模型生成路徑中。將干擾識別機制擴展至生成階段,並合理控制切分粒度,有助於提升模型在長推理任務中的注意力聚焦能力與整體表現。
4.3 閾值敏感度分析:小型模型對干擾更脆弱,需更積極過濾
為了探究模型對干擾型Token的敏感程度,作者在 LeaF 框架中系統分析了用於識別干擾型Token(confounder tokens)的閾值(threshold)對最終推理性能的影響。
實驗分別在不同模型規模(LLaMA3.2-1B 與 LLaMA3.2-3B)下,在兩個層級(Instruct-level 與 Response-level)進行測試。
▲ 圖9:指令級閾值敏感性分析(MathBench)
▲ 圖10:步驟級閾值敏感性分析(MathBench)
實驗觀察:
● 指令層級(Instruct-level):
LLaMA3.2-1B 在閾值為 0.10 時表現最佳;
LLaMA3.2-3B 在閾值為 0.05 時達到最優性能。
● 響應層級(Response-level):
LLaMA3.2-1B 在閾值為 0.15 時表現最佳;
LLaMA3.2-3B 則在 0.10 閾值下取得最佳效果。
分析解讀:
無論是在指令層級還是生成層級,較小模型(1B)在更高閾值下效果更佳,說明其在原始輸入中對干擾型Token更為敏感,因而更依賴積極的過濾策略以確保穩健性。
較高閾值能夠更有效地識別和過濾掉這些具有誤導性的Token,從而帶來更好的學習效果。而大型模型(3B)自身具備更強的表示與抗干擾能力,因此在更低閾值下即可獲得理想表現。
結論:模型規模影響其對干擾型Token的容忍程度。較小型模型更容易被誤導,適合採用更高的閾值進行更積極的干擾過濾。
4.4 可解釋性案例分析:模型真的學會了「聚焦關鍵」了嗎?
為了驗證 LeaF 是否真正引導模型學習到更具因果性的關注模式,作者在數學任務中建構了一個具有代表性的推理案例,比較 LeaF 與標準知識蒸餾(KD)模型在推理鏈中的注意力差異。
案例任務:判斷所有方程式的根是否為實數。
▲ 圖11 案例分析
LeaF 模型的表現:
模型成功關注到如「real number」、「all」、「are real」等關鍵資訊;
明確理解「所有根需為實數」這一限制,進而採取合理的推理策略:
識別出 x = -1 為顯然的實根;
運用判別式(Discriminant)條件來確保二次因子同樣產生實數解。
整個推理過程邏輯清晰、判斷合理,成功得出正確答案。
KD 模型的表現:
忽略了「所有根需為實數」的核心條件;
在不考慮變量符號的情況下,錯誤使用 AM–GM 不等式(可能引入負數),導致最終解答錯誤。
分析總結:
該案例直觀展示了 LeaF 幫助模型識別關鍵資訊並建構合理推理路徑的能力,從而有效規避「表層匹配式」推理誤判。同時也證明 LeaF 不只是提升準確率,更能提升模型行為的可解釋性與合理性。
未來展望
本研究驗證了 LeaF 框架在提升大型語言模型因果關注能力與推理穩健性方面的有效性,為理解和緩解注意力偏差提供了新路徑。透過引入教師-學生間的梯度差異分析與反事實蒸餾機制,LeaF 能夠引導模型有效識別並規避干擾型Token,從而學會聚焦真正關鍵的資訊區域。
未來,仍有多個值得深入探索的方向。例如,當前實驗主要聚焦數學與程式碼推理任務,進一步拓展至語言理解、問答、多跳推理等更廣泛的任務場景,以驗證其通用性與跨任務穩健性,也是未來值得研究的方向。