Abstract
The paper proposes SparseDiT, a new framework that improves the computational efficiency of the Diffusion Transformer (DiT) through token sparsification along the spatial (model structure) and temporal (timestep) dimensions.
In the spatial dimension, SparseDiT employs a three-stage architecture: the bottom layers use Poolingformer for efficient global feature extraction, the middle layers use a Sparse-Dense Token Module (SDTM) to balance global and local features, and the top layers adopt dense tokens to refine high-frequency details.
In the temporal dimension, SparseDiT dynamically adjusts the number of tokens across the denoising steps, gradually increasing the token count as denoising progresses so that fine details are captured late in the process while early steps remain cheap. Together, these spatial and temporal strategies improve computational efficiency while maintaining generation quality.
Experiments demonstrate that SparseDiT excels across multiple generation tasks, including image generation, video generation, and text-to-image generation. For example, in 512×512 image generation it reduces FLOPs by 55% and increases inference speed by 175% while maintaining comparable generation quality.
Paper Title:
SparseDiT: Token Sparsification for Efficient Diffusion Transformer
Paper Link:
https://arxiv.org/pdf/2412.06028
Code Link:
https://github.com/changsn/SparseDiT
Existing Problems and Challenges
Although the Diffusion Transformer (DiT) delivers excellent generation quality, the high computational cost of its self-attention and its many sampling steps limit its widespread use in practical applications.
Most existing methods reduce complexity by accelerating the sampling process but overlook efficiency issues in the DiT structure itself.
Compared to U-Net, DiT introduces more computational overhead at the token level through self-attention, necessitating innovative DiT-specific methods to balance computational efficiency and generation quality.
Value of the Work and Method Introduction
Value of the Work
This work addresses the computational efficiency issues of DiT models through the SparseDiT method, which reduces computational complexity via dynamic token sparsification strategies while maintaining high-quality generation across multiple tasks.
The method also reduces computation (e.g., FLOPs) and significantly improves inference speed across multiple experiments, which is crucial for large-scale applications and deployments. This work thus provides a scalable solution for high-quality, efficient diffusion models.
Method Introduction
The core innovation of SparseDiT lies in token sparsification across spatial and temporal dimensions to enhance diffusion model efficiency while preserving generation quality.
The design is divided into two main parts: spatial token density management and timestep-wise dynamic token management strategies.
• Spatial Dimension: Three-Stage Architecture
1. Bottom Layer: Poolingformer
In the bottom Transformer layers, SparseDiT replaces traditional self-attention with the Poolingformer structure to capture global features.
Experiments reveal that the expensive self-attention computation in the bottom layers yields little additional information, and that global average pooling is a more efficient substitute. Poolingformer eliminates the key and value computations entirely: it applies global average pooling to the tokens and adds the pooled global feature back into the input tokens, reducing overhead.
An ablation shows that directly replacing the bottom attention layers with global pooling, without any fine-tuning, has minimal impact on the generated image, indicating that these attention layers are largely redundant.
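To make this concrete, here is a minimal PyTorch sketch of such a pooling block (the class and layer names are our illustration, not the paper's code): the attention sublayer is replaced by a global average pool whose projected result is added back to every token, so no keys or values are ever computed.

```python
import torch
import torch.nn as nn

class PoolingBlock(nn.Module):
    """Minimal sketch of the bottom-layer idea: self-attention is
    replaced by global average pooling over tokens. Names and sizes
    are illustrative, not the paper's implementation."""

    def __init__(self, dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)  # projects the pooled global feature
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_tokens, dim). Pooling stands in for attention:
        # every token receives the same global summary, and no key/value
        # matrices are ever computed.
        g = self.norm1(x).mean(dim=1, keepdim=True)  # (B, 1, D)
        x = x + self.proj(g)                         # broadcast add to all tokens
        x = x + self.mlp(self.norm2(x))              # standard MLP sublayer
        return x
```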
2. Middle Layer: Sparse-Dense Token Module
The middle layers use the Sparse-Dense Token Module (SDTM), which divides the representation into global structure extraction and local detail enhancement.
Sparse tokens handle global structure capture, effectively reducing costs, while dense tokens enhance details and stabilize training.
SDTM converts between sparse and dense tokens via interaction attention layers: a sparse transformer processes the sparse tokens, while a dense transformer handles the dense tokens recovered from the sparse ones, preserving information while saving compute.
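Below is an illustrative PyTorch sketch of this sparse-dense round trip, assuming a simple 2× token pooling for sparsification and standard attention layers for recovery; the paper's actual SDTM may prune and recover tokens differently.

```python
import torch
import torch.nn as nn

class SparseDenseBlock(nn.Module):
    """Illustrative sparse-dense token module. Assumptions: tokens are
    sparsified by averaging adjacent pairs, and dense tokens are
    recovered through cross-attention; the paper's SDTM design is the
    authoritative reference."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.sparse_block = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, batch_first=True)
        # Interaction attention: dense tokens query the sparse tokens
        # to recover a dense representation carrying global structure.
        self.recover = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.dense_block = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, D) dense tokens; N assumed even for the 2x pooling.
        b, n, d = x.shape
        sparse = x.view(b, n // 2, 2, d).mean(dim=2)   # (B, N/2, D)
        sparse = self.sparse_block(sparse)             # cheap global pass
        recovered, _ = self.recover(query=x, key=sparse, value=sparse)
        x = x + recovered                              # reinject global structure
        return self.dense_block(x)                     # refine dense tokens
```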
3. Top Layer: Standard Transformer
In the top layers, SparseDiT keeps standard transformer layers operating on dense tokens, focusing on refining high-frequency details to ensure generation quality.
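Putting the three stages together, a hypothetical backbone could be composed as below (reusing the two sketches above; the stage depths are assumptions, not the paper's configuration):

```python
import torch.nn as nn

def build_sparsedit_backbone(dim: int, depths=(4, 20, 4)) -> nn.Sequential:
    """Hypothetical three-stage composition: pooling blocks at the
    bottom, sparse-dense blocks in the middle, and standard dense
    transformer layers at the top. Depths are illustrative."""
    bottom = [PoolingBlock(dim) for _ in range(depths[0])]
    middle = [SparseDenseBlock(dim) for _ in range(depths[1])]
    top = [nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
           for _ in range(depths[2])]
    return nn.Sequential(*bottom, *middle, *top)
```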
• Temporal Dimension: Time-wise Pruning Rate
Dynamic time-wise pruning is another key innovation: the pruning rate is adjusted as denoising progresses (a schematic schedule is sketched at the end of this subsection). Specifically:
• Early Stages:
In early denoising steps, where low-frequency structure dominates, SparseDiT applies a high pruning rate, processing only a small set of sparse tokens to save computation.
• Later Stages:
As denoising advances, the pruning rate decreases progressively, increasing token density so that high-frequency details are captured accurately as more tokens are needed.
This spatiotemporal adaptive strategy dramatically improves efficiency by reducing FLOPs and accelerating inference while preserving details.
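As an illustration of the principle, the sketch below implements a simple linear schedule; the paper defines its own rates, so the linear decay and the 0.9 ceiling here are assumptions.

```python
def pruning_rate(t: int, num_steps: int,
                 max_rate: float = 0.9, min_rate: float = 0.0) -> float:
    """Hypothetical time-wise pruning schedule: prune aggressively in
    early (noisy) denoising steps, keep more tokens near the end.
    t counts completed denoising steps: t=0 is the first, noisiest step.
    """
    progress = t / max(num_steps - 1, 1)
    return max_rate - (max_rate - min_rate) * progress

# Example: over 50 steps the rate decays from 0.9 to 0.0, so the early
# steps run on ~10% of the tokens and the final steps use all of them.
rates = [pruning_rate(t, 50) for t in range(50)]
```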
Experimental Results
In the paper, SparseDiT is evaluated on class-conditional image generation, class-conditional video generation, and text-to-image generation, with strong results on all three tasks.
1. Class-conditional Image Generation
At 256×256 resolution, SparseDiT-XL achieves a 43% FLOPs reduction and an 87% inference speedup while FID increases by only 0.11, maintaining performance with roughly 25% of the tokens.
At 512×512, SparseDiT offers a superior performance-efficiency trade-off under aggressive pruning: removing over 90% of tokens yields a 55% FLOPs reduction and a 175% inference speedup, with FID rising by only 0.09.
These results show that SparseDiT alleviates DiT's computational burden, delivering efficiency gains without quality loss.
2. Class-conditional Video Generation
SparseDiT is tested on the public datasets FaceForensics, SkyTimelapse, UCF101, and Taichi-HD at 256×256 resolution.
It applies a higher pruning rate along the additional temporal dimension of video, achieving a 56% FLOPs reduction while maintaining competitive FVD scores, validating its effectiveness for video generation.
3. Text-to-Image Generation
Built on the PixArt-α model, SparseDiT is trained and evaluated on the SAM dataset for text-to-image generation.
SparseDiT achieves a lower FID than the original PixArt-α while significantly speeding up generation, demonstrating its effectiveness in this setting.