摘要
論文提出了 SparseDiT,一種透過空間(模型結構)和時間維度(時間步)的 token 稀疏化來提升 Diffusion Transformer(DiT)運算效率的新框架。
在空間維度上,SparseDiT 採用三段式架構:底層使用 Poolingformer 進行高效全球特徵提取,中層利用 Sparse-dense generation token model(SDTM)平衡全球與局部的特徵,頂層則採用密集 token 來提煉高頻細節。
在時間維度上,SparseDiT 動態調整去噪階段中的 token 數量,隨著時間步(Timestep)推進逐漸增加 token,以在實現高效性的同時更好地捕捉細節。這種空間和時間策略提升了運算效率,同時維持生成品質。
實驗證明 SparseDiT 在圖片生成、影片生成和文生圖等多個生成任務上取得優異表現,例如在 512×512 解析度上的圖片生成任務中,在減少 FLOPs(55%)和提升推理速度(175%)的同時,仍能維持接近的生成品質。
論文標題:
SparseDiT: Token Sparsification for Efficient Diffusion Transformer
論文地址:
https://arxiv.org/pdf/2412.06028
程式碼地址:
https://github.com/changsn/SparseDiT
現存問題及挑戰介紹
儘管 Diffusion Transformer(DiT)在生成效能上表現出色,但其在自注意力及取樣步驟上的高運算複雜度限制了其在實際應用中的廣泛性。
大多數現有方法透過加速取樣過程來降低複雜度,但忽略了 DiT 自身結構上的效率問題。
與 U-Net 相比,DiT 在 token 級別的自注意力引入了更多運算開銷,因此需要設計針對 DiT 的創新方法來實現運算效率與生成品質之間的平衡。
工作價值以及方法介紹
工作價值
目前工作透過 SparseDiT 方法有效解決了 DiT 模型中的運算效率問題。SparseDiT 透過動態 token 稀疏化策略,不僅減少了模型運算複雜度,同時在多個生成任務上維持了高品質的生成效能。
此外,該方法在多個實驗驗證中均顯示出減少的運算量(如 FLOPs),並顯著提升了推理速度,這對於大規模應用和實際部署尤為重要。因此,該工作為高品質、高效率的擴散模型提供了可擴展的解決方案。
方法介紹
SparseDiT 的核心創新在於透過空間和時間維度對 token 進行稀疏化,以提升擴散模型的運算效率,同時維持生成品質。
該方法的设计可以分為兩個主要部分:空間上的 token 密度管理和時間上的 timestep-wise 動態 token 管理策略。
• 空間維度:三段式架構
1. 底層架構:Poolingformer
在底層 Transformer,SparseDiT 使用 Poolingformer 結構取代傳統自注意力機制以捕捉全球特徵。
實驗發現,底層自注意力的複雜運算並不能帶來額外資訊,反而可以透過全球平均 pooling 實現效率的提升。Poolingformer 去除了 key 和 value 的運算,直接對 token 進行全球平均池化,整合到輸入 token 中,從而減少運算開銷。
上圖實驗顯示,不經過任何 finetuning 的情況下,直接將注意力層取代為全球 pooling,不會對畫面造成很大的影響,說明底層的注意力層產生的效果有效。
2. 中層架構:Sparse-dense token module
中層結構採用 Sparse-dense token module(SDTM)技術,將表示過程分為全球結構提取和局部細節強化。
Sparse token 負責全球結構資訊的捕捉,有效降低運算成本,而 dense token 則用於細節強化,穩定訓練過程。
SDTM 透過交互注意力層實現 sparse token 與 dense token 的相互轉換,其中 sparse transformer 處理 sparse token,dense transformer 處理由 sparse token 恢復而來的 dense token,同時實現了資訊的保留與算力的節省。
3. 頂層架構:標準 Transformer
在頂層,SparseDiT 繼續使用標準 transformer 層,以 dense token 處理模式專注於高頻細節的提煉,確保生成品質。
• 時間維度:Time-wise pruning rate
動態 Time-wise pruning rate 是 SparseDiT 的另一個關鍵創新,旨在隨著去噪的進行而動態調整 token 密度。具体來說:
• 早期階段:
在早期去噪階段,由於以低頻結構為主,SparseDiT 應用較高的剪枝率來節省運算資源。此時,使用更少的 sparse token,運算複雜度較低,節省了不必要的 token 操作。
• 後期階段:
隨著去噪階段的推進,漸進式減少剪枝率以增加 token 密度,確保高頻細節能夠被準確捕捉。此時,token 需求逐步增加,反應了對細節需求的成長。
透過這種時空雙重適應性策略,SparseDiT 在維持生成細節的同時,大幅度提升了運算效率,表現為透過減少 FLOPs 和加速推理速度。
實驗結果
在論文中,SparseDiT 在 Class-conditional image generation、Class-conditional video generation 和 Text-to-image generation 三個生成任務中進行了實驗並取得了顯著效果。
1. Class-conditional image generation
在 256×256 解析度下,SparseDiT-XL 實現了 43% 的 FLOPs 減少,與 87% 的推理速度提升,同時 FID 分數僅增加了 0.11。這表明,即使只使用約 25% 的 tokens,依然能維持類似的效能。
在 512×512 解析度條件下,SparseDiT 在高剪枝率情況下表現出更優質的效能-效率 trade-off,透過剪枝超過 90% 的 tokens,得到 55% 的 FLOPs 減少及 175% 的速度提升,FID 分數僅增加了 0.09。
這些結果證明 SparseDiT 解決了 DiT 架構中的運算負擔問題,並在維持效能品質同時帶來了顯著的運算效率提升。
2. Class-conditional video generation
在 FaceForensics、SkyTimelapse、UCF101 和 Taichi-HD 等四個公開資料集上進行,解析度為 256×256。
SparseDiT 在影片資料的額外時間維度上應用了更高的剪枝率,達到了 FLOPs 減少 56% 的效果,同時維持了競爭性的 FVD 評分,證明其在影片生成任務上的有效性。
3. Text-to-image generation 實驗設定:
使用 PixArt-α 模型為基礎模型,採用 SAM 資料集進行訓練與評估,進行文字到影像生成。
SparseDiT 在該任務上達到了與原始 PixArt-α 模型更小的 FID 分數,同時顯著加快了生成速度,顯示出方案在文字到影像生成任務中的有效性。