字節跳動Seed新作DeltaFormer:下一世代模型架構的嘗試

图片

MLNLP 社群是國內外知名的機器學習與自然語言處理社群,受眾涵蓋國內外 NLP 碩博士生、大學老師以及企業研究人員。

社群的願景是促進國內外自然語言處理、機器學習學術界、產業界和廣大愛好者之間的交流與進步,特別是初學者同學們的進步。

來源 | 知乎

作者|小明同學

簡單介紹一下最近在 Seed 的工作,希望能拋磚引玉。

https://arxiv.org/pdf/2505.19488

Deltaformer 與 Transformer 核心元件差異一覽。簡單來說,對標準 attention 中 q、k、v 的 v 元件進行了修正,而是使用 q、k、u 進行 attention。而 u 則是使用 w、k、u 進行 attention,再與 v 組合得到。

Deltaformer 與 Transformer 核心元件差異一覽。簡單來說,對標準 attention 中 q、k、v 的 v 元件進行了修正,而是使用 q、k、u 進行 attention。而 u 則是使用 w、k、u 進行 attention,再與 v 組合得到。

動機

表達性與不可並行性之間有本質矛盾

從高層次來看,表達性與可並行性之間存在著不可調和的矛盾。某些問題的正確結果輸出,客觀上需要一定的深度。通俗地說,解題時有些步驟可以並行處理,但有些關鍵步驟必須循序漸進。如果這些關鍵步驟的最大長度低於某個下界,那麼就不可能得到正確答案。為此,上個世紀研究計算複雜度的科學家們,也開始關注並行複雜度。他們在 P 類問題中,根據單一節點允許的操作類型、單一節點允許的扇入以及整個計算圖的關鍵路徑長度,將 P 問題劃分為若干類,例如   。

關於不同複雜度類別的示意圖。值得注意的是,這張圖上關於真包含的關係不是很嚴謹,目前已被證明的是 AC^0 != TC^0,其他層次之間是否真包含尚未有嚴謹證明,但我們普遍認為 NC^1 != TC^0。此外,在 NC^1$ 和 $NC^2 之間還有許多小類,例如 SL、NL 等。而對數精度的 Transformer 模型則被證明在 TC^0 中。

關於不同複雜度類別的示意圖。值得注意的是,這張圖上關於真包含的關係不是很嚴謹,目前已被證明的是 AC^0 != TC^0,其他層次之間是否真包含尚未有嚴謹證明,但我們普遍認為 NC^1 != TC^0。此外,在 NC^1$ 和 $NC^2 之間還有許多小類,例如 SL、NL 等。而對數精度的 Transformer 模型則被證明在 TC^0 中。

LSTM 與 Transformer 之間或許有廣闊的發展空間

上個世紀末開始流行的 LSTM 是一種本質上不可並行的 P 模型。但近十年來,GPU 重新定義了環境,使得高並行的 Transformer 模型成為現今大模型領域中最受歡迎的骨幹架構。同時,並行性與表現力之間的根本性矛盾也導致大模型的缺陷,例如計數能力的不足,必須依賴「思維鏈」(Chain-of-thought)才能解決複雜問題。

那麼,難道沒有一種仍然可以高度並行,只是並行程度比 Transformer 稍差,但表現力更高的架構嗎?前人告訴我們在   和   之間還存在著大量的複雜度類別,這給我們帶來了想像空間。或許真的存在也能在 GPU 上高效並行實現,且比 Transformer 更具表現力的模型。

複雜度模型的回歸

相較於 Transformer 和線性注意力(Linear attention)不顧及先前狀態,直接寫入或附加鍵值對的做法,Delta 規則在每次寫入時會考量先前的狀態進行修改。這件事情在上個世紀已有不少研究,包括 Schmidhuber[1]、Sutton[2] 以及 Hinton[3]。儘管當時的名稱是快速權重程式設計(fast weight programming),但核心概念是一致的。2021 年,Schimidhuber[4] 還重新提及了一次。但在 GPU 時代下,難以在 GPU 上高度並行實現的方式,是無法成為下一代模型的;否則,直接退回 LSTM 模型,不斷加大隱藏層尺寸就好了。2024 年,Songlin Yang[5] 等人發現了 Delta 規則的並行潛力,將 DeltaNet 在 GPU 上並行化,這也使得 Delta 規則重新受到關注。而這個模型是能夠達到  複雜度的模型,因此能在狀態追蹤(State tracking)相關任務上表現良好。

Transformer + Delta 規則 = Deltaformer

DeltaNet 受限於有限的狀態空間,其最基本的長文資訊檢索能力有限,而 Transformer 的長文資訊檢索能力則相當優異。將兩者有機融合,尋求一個完全超越 Transformer 架構的模型,便是我們這項工作的目的。

方法

Deltaformer = Delta 規則 + 核函數技巧

核函數技巧(Kernel trick)也是一個古老的方法,從 SVM 時代起,核函數 SVM 便佔有一席之地。這種將特徵隱式地擴展到無限維度的方法,或許是增加記憶容量的好方法。

引入核函數:  ,其中  是一個從有限維度映射到無限維度的映射,我們一般不將其顯式地寫出來。那麼我們將 Delta 規則重寫如下:

Delta 規則 + 核函數技巧的版本

Delta 規則 + 核函數技巧的版本

最大的問題在於這裡面的   和 S 都是無限維度的,無法在電腦上運算。

好在經過一些推導,可以將   和 S 這些涉及無限維度的東西都消除了,只保留了   。

寫入方式為:

讀取方式為:

當然也可以在其中進行一些其他操作,例如上方和下方的   採用不同的方式,以及添加一些可學習的參數等。

我們使用 softmax 作為   ,那麼我們就得到了 Transformer 的 Delta 規則升級版。

接下來則要回答兩個問題:

• 1) 如何在 GPU 上高效實現

• 2)如何證明這種表達式能執行   的任務

分塊演算法(Chunk-wise algorithm)

困難的部分在於   的計算,而   的計算則可正常使用 Flash attention。

直接使用   在解碼階段可以這麼做,但模型訓練時,這種遞歸計算還不如非線性 RNN。

但我們也可以寫成更緊湊的形式:  ,其中

那麼有   ,直接進行這麼大的矩陣求逆,雖然並行度高,但 I/O 無法支援。

c 作為下標表示當前區塊(current chunk)的對應變數,p 作為下標表示先前的變數(previous variable),因此有:

故而:

利用這種方法可以逐區塊地計算   。如果序列長度是   ,區塊大小是   ,頭維度(head dim)是   ,使用的是前代法求逆,那麼總的浮點運算次數(Flops)是  

能追蹤 n 個元素的交換

我們理論證明這個模型架構的上限是能夠達到   的。我們研究了追蹤 n 個元素交換這項任務,這是一個   。我們採用了建構性的方式進行了證明,具體證明可以參閱原論文。結論是能夠以   的頭維度來追蹤   個物體的交換。

關於證明 Deltaformer 能夠追蹤 n 個元素交換的定理。

關於證明 Deltaformer 能夠追蹤 n 個元素交換的定理。

實驗

Deltaformer 可以追蹤交換,但 Transformer 難以做到

图片

例如,我們可以發現 Transformer 想要追蹤 5 個元素的交換還是挺困難的。但核函數的選擇對於執行   還是挺重要的。

Deltaformer 可以進行有向無環圖的連通性判斷

图片

這也挺合理的,因為 Deltaformer 裡面的求逆操作   ,如果它編碼了 i 節點和 j 節點是否相鄰,那麼也就編碼了   節點和   節點是否是 k 步可達的。那麼它就編碼了 i 節點和 j 節點是否連通的資訊。(從另一個角度來說,也是因為求逆這個遠遠超出   的操作拓展了 Transformer 的表現力。)

更多玩具模型(toy model)的實驗和有趣的現象可以參考我們的原論文。

結論

我們提出了 Deltaformer 這個模型,它擁有了 Transformer 模型的記憶力以及能在 GPU 上高效訓練的特性,同時還突破了 Transformer 的   表現力限制。希望能為以後設計更高表現力的模型拋磚引玉。

引用連結

[1] Schmidhuber: https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=2f0becffd2f44b198d28074d01722e4c7905dae2

[2] Sutton: https://web.cs.umass.edu/publication/docs/1980/UM-CS-1980-018.pdf

[3] Hinton: https://www.cs.toronto.edu/~fritz/absps/fastweights.pdf

[4] Schimidbuber: https://proceedings.neurips.cc/paper_files/paper/2021/file/3f9e3767ef3b10a0de4c256d7ef9805d-Paper.pdf

[5] Songlin Yang: https://arxiv.org/pdf/2406.06484

技術交流群邀請函

圖片

△長按添加小助手

掃描 QR Code 加入小助手微信

請備註:姓名-學校/公司-研究方向

(如:小張-哈工大-對話系統)

即可申請加入自然語言處理/Pytorch 等技術交流群

關於我們

MLNLP 社群是由國內外機器學習與自然語言處理學者聯合建立的民間學術社群,目前已發展為國內外知名的機器學習與自然語言處理社群,旨在促進機器學習、自然語言處理學術界、產業界和廣大愛好者之間的進步。

社群可以為相關從業者在深造、就業及研究等方面提供開放交流平台。歡迎大家關注並加入我們。

圖片

主標籤:人工智慧

次標籤:神經網路機器學習計算複雜度Transformer模型


上一篇:越髒越安全?哈佛團隊最新研究:10%毒性訓練讓大型模型百毒不侵

下一篇:全球程式設計師譁然!黃仁勳於倫敦放話:程式語言的未來是「Human」

分享短網址