Is Your Model's Attention Drifting? RUC and Tsinghua University Introduce LeaF: Pruning Distracting Tokens for Focused Learning

LeaF Framework Diagram

This paper proposes the LeaF framework, which integrates a causality-based interference identification mechanism into the knowledge distillation process, guiding the student model to focus on causally critical features during inference and thereby improving reasoning accuracy and generalization.

Paper Information

Paper Title: Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

Authors' Affiliation: Gaoling School of Artificial Intelligence, Renmin University of China; Department of Computer Science and Technology, Tsinghua University

Paper Link: https://arxiv.org/pdf/2506.07851

Code Link: https://github.com/RUCBM/LeaF

Problem Background

Although Large Language Models (LLMs) demonstrate powerful contextual understanding and language generation capabilities in natural language processing tasks, they still exhibit significant shortcomings in long-text reasoning and complex instruction following, particularly in their ability to focus on key information. This attention dispersion severely limits the models' reasoning accuracy and generation quality.

To systematically investigate this phenomenon, this work first identifies distracting patterns in the input by comparing the gradient sensitivity of teacher and student models, and evaluates student model performance on the NuminaMath-CoT and AceCode-87K datasets.

As shown in Figure 1 and Figure 2, simply pruning this distracting information significantly improves the average accuracy, by over 20% on the math training set and over 10% on the code training set.

Furthermore, in more complex tasks like AMC_AIME, the model's performance improvement even surpasses that on GSM8K, indicating that complex reasoning tasks often contain more misleading factors that interfere with the model's effective judgment.

▲ Figure 1: Code Task Accuracy Improvement

▲ Figure 2: Math Task Accuracy Improvement

These findings indicate that eliminating distracting information and enhancing the model's ability to focus autonomously on key information are crucial to improving the reasoning performance of large language models.

To this end, the authors propose the LeaF framework. From a causal perspective, it uses gradient guidance to identify and eliminate distracting factors in the input, guiding the student model to learn to focus on crucial information regions during distillation, thereby improving the model's reasoning performance.

Experimental results show that LeaF achieves significant performance improvements across multiple downstream tasks, including mathematical reasoning and code generation. On datasets such as GSM8K, MATH, and OlympiadBench, the average accuracy increased by 2.41%; on code tasks such as HumanEval+, LeetCode, and LiveCodeBench, the average improvement reached 2.48%.

Additionally, the model's attention distribution during inference becomes more concentrated and consistent, and attention visualization results further validate the interpretability of the method.

LeaF Framework Overview

LeaF: Two-Stage Modeling to Enhance Causal Attention in Models

To alleviate the problem of models being easily misled by distracting information and struggling to focus on key information during inference, the authors propose a causality-driven attention transfer method—the LeaF (Learning to Focus) framework. This framework consists of two core stages:

Interference Identification: Using Gradients to Characterize Model Attention Bias

The first stage aims to identify tokens in the input that mislead the student model but are not necessary for the reasoning itself, referred to as confounding tokens.

Specifically, from samples where the student predicts incorrectly but the teacher predicts correctly, the gradient sensitivity of both models to each input token is compared. Tokens that the student model focuses on (larger gradient values) but the teacher model does not (smaller gradient values) are identified as potential interfering factors.

Furthermore, if both the student and teacher models still produce correct predictions after these tokens are removed, the tokens are confirmed as confounder tokens: information that misleads the student's reasoning but is not essential for deriving the correct answer.
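
This identification step can be sketched as follows, assuming a Hugging Face-style causal LM interface in PyTorch; the function names, the per-token gradient-norm scoring, the normalization, and the threshold value are illustrative assumptions rather than the authors' exact implementation.

```python
import torch

def token_grad_sensitivity(model, input_ids, labels):
    """Per-token sensitivity: L2 norm of the loss gradient w.r.t. each input embedding."""
    embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_(True)
    loss = model(inputs_embeds=embeds, labels=labels).loss
    loss.backward()
    sens = embeds.grad.norm(dim=-1).squeeze(0)   # shape: (seq_len,)
    return sens / (sens.max() + 1e-8)            # normalize to [0, 1]

def find_confounder_candidates(teacher, student, input_ids, labels, threshold=0.10):
    """Flag tokens the student is highly sensitive to but the teacher largely ignores.
    Applied to samples where the student answers incorrectly and the teacher correctly;
    candidates are confirmed as confounder tokens only if both models still answer
    correctly after the flagged tokens are pruned."""
    s_sens = token_grad_sensitivity(student, input_ids, labels)
    t_sens = token_grad_sensitivity(teacher, input_ids, labels)
    gap = s_sens - t_sens
    return (gap > threshold).nonzero(as_tuple=True)[0].tolist()
```

The `threshold` used here corresponds to the sensitivity threshold analyzed in Section 4.3.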

▲ Figure 3. LeaF Framework: Optimizing Reasoning Capability via Gradient-Driven Interference Identification and Causal Distillation

After identifying confounding tokens, LeaF compares two methods for constructing counterfactual input samples:

Collective Pruning: Directly removing all identified confounding tokens at once;

Span Pruning: A more refined approach, where only one continuous interference span is removed at a time, preserving more semantic context.

Preliminary experiments showed that Span Pruning is more stable and is therefore the preferred choice.
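
The two strategies can be illustrated with a short sketch over token index lists; the helper names and the grouping of confounder indices into contiguous spans are assumptions for illustration.

```python
def collective_pruning(tokens, confounder_ids):
    """Remove all identified confounder tokens at once, yielding one counterfactual."""
    drop = set(confounder_ids)
    return [[t for i, t in enumerate(tokens) if i not in drop]]

def span_pruning(tokens, confounder_ids):
    """Remove one contiguous confounder span at a time, yielding several
    counterfactuals that each preserve more of the surrounding context."""
    spans, current = [], []
    for i in sorted(confounder_ids):
        if current and i == current[-1] + 1:
            current.append(i)              # extend the current contiguous span
        else:
            if current:
                spans.append(current)
            current = [i]                  # start a new span
    if current:
        spans.append(current)
    return [[t for i, t in enumerate(tokens) if i not in set(span)] for span in spans]
```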

▲ Figure 4. Pruning Strategy Diagram

Causal Distillation: Learning Focusing Strategies from Counterfactual Contrasts

After constructing the original and counterfactual samples, LeaF designs a hybrid distillation objective that integrates two supervision signals, guiding the student model to learn more robust attention patterns:

Standard Distillation: Keeping the student model aligned with the teacher on the original input;

Counterfactual Distillation: Encouraging the student to remain consistent with the teacher on inputs after distracting information has been removed.

This dual distillation mechanism not only prompts the student model to align with the teacher model's output behavior but also strengthens its causal judgment ability regarding key tokens in the input. By simultaneously modeling semantic information and causal dependencies, LeaF effectively prevents the student model from merely mimicking superficial patterns and neglecting crucial causal relationships, thereby enhancing reasoning robustness and generalization ability.
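
A minimal sketch of the hybrid objective is given below, assuming token-level KL divergence between the teacher's and student's output distributions and a fixed mixing weight; the loss form, the temperature, and the weight `alpha` are illustrative assumptions rather than the paper's exact formulation.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temperature=1.0):
    """Token-level KL divergence between teacher and student next-token distributions."""
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    student_logp = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(student_logp, teacher_probs, reduction="batchmean") * temperature ** 2

def leaf_hybrid_loss(student, teacher, original_batch, counterfactual_batch, alpha=0.5):
    """Standard distillation on the original input plus counterfactual distillation
    on the input with confounder tokens pruned away."""
    with torch.no_grad():
        t_orig = teacher(**original_batch).logits
        t_cf = teacher(**counterfactual_batch).logits
    s_orig = student(**original_batch).logits
    s_cf = student(**counterfactual_batch).logits
    return (1 - alpha) * kd_loss(s_orig, t_orig) + alpha * kd_loss(s_cf, t_cf)
```

In practice a cross-entropy term on the ground-truth tokens would typically be mixed in as well; it is omitted here to keep the two distillation signals in focus.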

Furthermore, LeaF extends instruction-level processing (Instruction-level Pruning), which originally applied only to the input side, to response-level processing (Response-level Pruning).

Specifically, in addition to identifying and removing distracting tokens in the input instructions, LeaF also treats the model's generated historical responses as contextual input, dynamically identifying and deleting tokens within them that might mislead subsequent reasoning.

This strategy helps to continuously eliminate interference during the generation process, further enhancing the model's ability to focus on key information, thereby producing more accurate and focused content.

▲ Figure 5 Instruction-level Processing Extended to Response-level Processing

Main Experiment Results

Experimental Evaluation: Focusing Key Attention to Improve Reasoning Performance

The authors systematically evaluated the effectiveness of the LeaF framework on two major tasks: mathematical reasoning and code generation. The experiments covered two mainstream model architectures, Llama and Qwen, and six evaluation benchmarks, validating LeaF's role in enhancing the model's reasoning capabilities.

Experiments show that LeaF yields performance improvements across all mathematical and coding benchmark tasks, with average accuracy increasing by 2.41% and 2.48% respectively compared to standard distillation methods. Notably, improvements on the high-difficulty OlympiadBench benchmark were particularly significant, indicating LeaF's effectiveness in handling attention interference in complex reasoning.

▲ Figure 6 Main Experiment Results

Furthermore, extending the processing scope of confounding tokens from input instructions (Instruction-level) to the model generation process (Response-level) significantly improved model performance. This indicates that the generation phase also contains distracting information that affects reasoning, and a segmented processing strategy helps the model maintain focus on key information.

Analysis Experiments: How LeaF Precisely Identifies and Avoids Reasoning Misguidance

To systematically evaluate the effectiveness of the LeaF framework in identifying and removing distracting tokens, the authors conducted an in-depth analysis from four perspectives: masking strategies, response processing methods, threshold sensitivity, and case studies, comprehensively verifying its performance in improving reasoning robustness and model focusing capabilities.

4.1 Gradient Masking Strategy Analysis: How Does LeaF Precisely Identify Distracting Information?

To systematically evaluate the effectiveness of LeaF's gradient masking strategy, the authors compared it with two common masking methods: random masking and perplexity (PPL) masking. Experiments were conducted on GSM8K, MATH, and OlympiadBench, covering mathematical tasks from basic to complex scenarios.

▲ Figure 7: LeaF Gradient Masking Strategy Analysis Experiment Results

Experimental Observations:

Gradient Masking Significantly Outperforms Other Strategies

Achieves optimal performance on complex reasoning tasks like MATH and OlympiadBench, validating that LeaF's gradient guidance mechanism can effectively locate distracting tokens.

Random Masking Strategy Shows Unstable Performance

On GSM8K and OlympiadBench, it even led to performance degradation, indicating that blindly pruning tokens without semantic guidance can destroy distillation signals, and further emphasizing that data augmentation alone is insufficient to enhance a model's reasoning capabilities.

Perplexity Masking Only Provides Slight Improvement in Simple Tasks

In complex tasks (e.g., OlympiadBench), its effect is close to random masking. This suggests that the student model's own attention to tokens might be biased, making it difficult to accurately determine which information is truly important, highlighting the necessity of introducing a teacher model for comparative guidance.

Conclusion: In complex reasoning tasks, the gradient difference-based masking strategy can more accurately identify confounder tokens, validating the effectiveness and rationality of the 'teacher-student gradient comparison mechanism' in the LeaF framework.

4.2 Response-level Processing Strategy: Distracting Information in the Generation Process Cannot Be Ignored

LeaF not only identifies distracting tokens in the input instructions (Instruction-level) but also extends the scope of interference detection to the model's generated content (Response-level), covering the full chain of attention bias during the reasoning process.

To this end, the authors designed three processing strategies for comparison (a minimal sketch of the segmented variant follows the list):

Instruction-level only: Only identifies and removes distracting tokens in the input text, without processing the model's generated content.

Response-level two-segment processing (2 segments): Divides the generated content into two segments (front and back), detecting and removing distracting tokens in each segment separately.

Response-level multi-segment processing (3 segments): Divides the generated content into three continuous segments, independently detecting and processing interference in each segment.
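
The segmented variant referenced above can be sketched as follows, assuming the response is split into contiguous, roughly equal-length segments and the gradient-based identification (e.g. the `find_confounder_candidates` helper sketched earlier) is reused inside each segment; the segmentation rule and the `identify_fn` interface are assumptions for illustration.

```python
def segment_response(response_tokens, num_segments=2):
    """Split the generated response into contiguous, roughly equal-length segments."""
    n = len(response_tokens)
    size = max(1, -(-n // num_segments))  # ceiling division
    return [response_tokens[i:i + size] for i in range(0, n, size)]

def response_level_confounders(instruction_tokens, response_tokens, identify_fn, num_segments=2):
    """Run confounder identification inside each response segment, conditioning on the
    instruction plus the part of the response generated before that segment."""
    confounders, consumed = [], 0
    for segment in segment_response(response_tokens, num_segments):
        context = instruction_tokens + response_tokens[:consumed]
        local_ids = identify_fn(context, segment)   # indices local to this segment
        confounders.extend(len(instruction_tokens) + consumed + i for i in local_ids)
        consumed += len(segment)
    return confounders
```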

▲ Figure 8: LeaF Response-level Processing Strategy Experiment Results

Experimental Observations:

Introducing response-level processing significantly improves model performance: Compared to only processing the input, further identifying and removing distracting elements during the generation process effectively enhances the model's reasoning accuracy, indicating that subsequent generated content is also susceptible to attention bias.

Two-segment and three-segment processing yield similar effects: The more fine-grained three-segment processing did not bring significant benefits, indicating that two segments are sufficient for the model to identify and learn the interference patterns in the response; excessive segmentation might increase the risk of overfitting.

Conclusion: Confounder tokens are not only present in the input instructions but also often hidden in the model's generation path. Extending the interference identification mechanism to the generation stage, and reasonably controlling the granularity of segmentation, helps to improve the model's attention focusing ability and overall performance in long reasoning tasks.

4.3 Threshold Sensitivity Analysis: Smaller Models Are More Vulnerable to Interference, Requiring More Aggressive Filtering

To investigate the model's sensitivity to distracting tokens, the authors systematically analyzed the impact of the threshold used to identify confounder tokens within the LeaF framework on final reasoning performance.

Experiments were conducted at two levels (Instruction-level and Response-level) under different model scales (LLaMA3.2-1B and LLaMA3.2-3B).

▲ Figure 9: Instruction-level Threshold Sensitivity Analysis (MathBench)

▲ Figure 10: Step-level Threshold Sensitivity Analysis (MathBench)

Experimental Observations:

● Instruction-level:

LLaMA3.2-1B performed best with a threshold of 0.10;

LLaMA3.2-3B achieved optimal performance with a threshold of 0.05.

● Response-level:

LLaMA3.2-1B performed best with a threshold of 0.15;

LLaMA3.2-3B achieved the best results with a threshold of 0.10.

Analysis and Interpretation:

Regardless of whether it's at the instruction level or the generation level, smaller models (1B) perform better with higher thresholds. This indicates that they are more sensitive to distracting tokens in the original input and thus rely more on aggressive filtering strategies to ensure robustness.

Higher thresholds can more effectively identify and filter out these misleading tokens, leading to better learning outcomes. In contrast, larger models (3B) possess stronger representation and anti-interference capabilities, thus achieving ideal performance even with lower thresholds.

Conclusion: Model scale influences its tolerance for distracting tokens. Smaller models are more easily misled and are suitable for more aggressive interference filtering using higher thresholds.

4.4 Interpretability Case Study: Has the Model Truly Learned to "Focus on the Key Elements"?

To verify whether LeaF truly guides the model to learn more causal attention patterns, the authors constructed a representative reasoning case in a mathematical task, comparing the attention differences between LeaF and a standard Knowledge Distillation (KD) model within the reasoning chain.

Case Task: Determine whether all roots of the equation are real numbers.

▲ Figure 11 Case Study

LeaF Model Performance:

The model successfully focused on key information such as "real number," "all," and "are real";

Clearly understood the constraint that "all roots must be real numbers," and subsequently adopted a sound reasoning strategy:

Identified x = -1 as an obvious real root;

Applied the discriminant condition to ensure the quadratic factor also yields real solutions.

The entire reasoning process was logically clear and sound, successfully yielding the correct answer.

KD Model Performance:

Ignored the core condition that "all roots must be real numbers";

Incorrectly used the AM–GM inequality (potentially introducing negative numbers) without considering variable signs, leading to an incorrect final solution.

Analysis Summary:

This case intuitively demonstrates LeaF's ability to help models identify key information and construct sound reasoning paths, thereby effectively avoiding "superficial matching" reasoning errors. It also proves that LeaF not only improves accuracy but also enhances the interpretability and rationality of model behavior.

Future Outlook

This work validates the effectiveness of the LeaF framework in enhancing the causal attention and reasoning robustness of large language models, providing a new path to understanding and mitigating attention bias. By introducing a gradient difference analysis and counterfactual distillation mechanism between teacher and student models, LeaF can guide models to effectively identify and avoid distracting tokens, thus learning to focus on truly critical information regions.

Several directions are worth further exploration. For instance, the current experiments focus primarily on mathematical and code reasoning tasks; extending LeaF to broader scenarios such as language understanding, question answering, and multi-hop reasoning would help verify its generality and cross-task robustness.
