盤點一下!大型模型訓練的時間都花在哪了?

隨著模型規模越來越大,大型模型時代的工程能力和研究能力逐漸變得同等重要。還記得幾年前做科研經常看到一些論文,只要修改幾行MATLAB或Python程式碼,就能取得顯著的效能提升。然而,在當前,我估計這已變得相當困難。

現今大型模型的訓練、複雜的程式碼巢狀結構以及各種工程挑戰,我認為對於過去從事學術研究的人來說都不是那麼友善。我個人也深有體會,接觸Megatron-LM之後,整個人都變得更加沉靜了。

恰好最近看到一個部落格,詳細地介紹了訓練大型模型的各種技術。我看完之後,感覺深受啟發,因此寫了一版自己對其的理解與反思,權當作記錄,也樂於分享,若能對讀者有所幫助,也感到十分榮幸。

https://jax-ml.github.io/scaling-book/

訓練大型語言模型(LLM)常常讓人感覺像是在搞煉金術,但其實理解和最佳化模型效能並沒有那麼神秘。這個部落格的目標就是幫助你揭開擴展語言模型背後的科學原理:TPU(還有GPU)到底是如何運作的,它們之間是如何通訊的,LLM在真實硬體上是如何執行的,以及在訓練和推論過程中,如何將模型合理地平行化,讓它能在大規模情境下高效執行。如果你曾經有過這些疑問:「訓練一個LLM到底要花多少錢?」、「我要自己部署這個模型,需要多大的記憶體?」、「AllGather到底是啥玩意?」那這些內容,應該會對你有幫助。

當我們在硬體上執行演算法時,效能主要受到三個方面的限制:

  • • 計算速度 — 也就是機器進行數學運算的能力,例如每秒能處理多少次操作(OPs/second)

  • • 頻寬 — 資料在記憶體、快取、晶片之間搬運的速度(bytes/second)

  • • 總記憶體容量 — 也就是設備最多能儲存多少資料(bytes)

這些限制就像為計算繪製了一條「Roofline」(屋頂線):上方頂著計算能力,下方則承載著記憶體和頻寬瓶頸。透過這套Roofline模型,我們可以大致估算出某段計算最快能跑多快(上限)和最慢會卡在哪裡(下限)。

簡單來說:如果演算法資料太多、記憶體太小,那它就容易被「撐爆」;如果演算法計算量很大但資料搬運太慢,那就等著「搬磚」;如果都不是問題,那就看計算核心有多麼強大。這三點加起來,就決定了你的模型在硬體上到底能跑多快、擴展多大。後面我們會用這個模型來分析如何高效地將模型擴展到更大的硬體,例如TPU。

訓練模型,時間都花在哪了?

我們執行一個模型,耗費時間的地方主要有三類:

1. 計算(Computation)

深度學習模型本質上就是一堆矩陣乘法,每個矩陣乘法由大量浮點加法和乘法組成,也就是所謂的FLOPs(浮點運算)。我們使用的加速卡(GPU或TPU)處理FLOPs的速度決定了計算所需的時間。這裡我們定義計算時間t_comp為

t_comp = FLOPs / FLOPs/s

舉個例子直觀理解一下,NVIDIA H100每秒可以完成大約9.89e14個bfloat16的FLOPs,TPU v6e也差不多是9.1e14。如果你有一個模型需要做1e12個FLOPs,那在H100上大概只需要1e12 / 9.89e14 = 1.01毫秒,在TPU v6e上大概是1.1毫秒。

這告訴我們,只計算數學運算的話,其實模型是可以很快的。

2. 晶片內部的通訊(Communication within a chip)

模型執行過程中,張量需要在晶片的記憶體(例如HBM)和計算核心之間傳輸。這個傳輸速度叫做HBM頻寬。對於H100來說,這大約是3.35TB/s;而對TPU v6e,大約是1.6TB/s。

這部分資料搬運的時間也需要算進去,特別是當資料量比較大時。

3. 晶片之間的通訊(Communication between chips)

當你的模型太大,需要多個加速卡(例如多個GPU或TPU)協作時,張量還需要在不同的晶片之間傳輸。不同的連接方式(例如ICI、DCN、PCIe)速度不一樣,單位和晶片內部通訊速度一樣,仍然是bytes/second。到這裡,我們也給通訊時間t_comm做一個定義:

t_comm = Bytes / Bytes/s

估算總時間的方式

無論是計算還是通訊,我們都可以粗略估計所需時間。並且:

  • • 理論下限(Lower bound):就是計算和通訊中耗時更長的那一個;

  • • 理論上限(Upper bound):是兩者耗時之和。

大多數時候,我們可以透過「通訊和計算平行」來讓實際耗時更接近下限。因此,最佳化的目標通常是讓通訊時間和計算時間盡可能重疊。而即使最壞的情況,也只是差一個因子2(也就是最長不會慢兩倍)。另外,如果計算時間遠大於通訊時間,那就說明硬體一直在計算、沒有等待資料,說明我們主要受限於計算能力,處於計算瓶頸(compute-bound)狀態;如果反過來是通訊時間大於計算時間,那就是我們大部分時間都在「等待資料」,說明被通訊瓶頸(communication-bound)限制了,FLOPs的潛力沒有完全發揮出來,有浪費。那麼我們怎麼判斷一個操作是計算瓶頸還是通訊瓶頸呢?

算術強度

看一個關鍵指標:算術強度(Arithmetic Intensity),也叫運算強度(Operational Intensity)。

I_a = FLOPs / Bytes

注意這裡的「通訊」既包括晶片內部(如HBM)也包括晶片之間(如GPU-GPU)。這個值就表示:「每搬運一個位元組,能做多少次浮點運算」。直覺上來理解,如果算術強度高:說明每搬運一次資料能做很多計算,計算時間就佔主導,是計算瓶頸;如果算術強度低,說明大部分時間都在搬運資料,計算單元在空轉,是通訊瓶頸;而這兩者「誰主誰次」的轉捩點,就叫peak arithmetic intensity(峰值算術強度),也就是:

peak FLOPs/s ÷ 頻寬(bytes/s)

峰值算術強度是硬體的一個特性。我們以TPU v5e為例,它最大計算能力是1.97e14 FLOPs/s,頻寬是8.2e11 bytes/s,所以它的峰值算術強度大約是1.97e14 / 8.2e11 ≈ 240 FLOPs/byte,這就意味著,如果一個演算法的算術強度低於240,它會變成通訊瓶頸。以點積為例,計算兩個向量x、y,假設它們的長度都是N,每個元素是bfloat16:

  • • 要從記憶體讀出x和y,每個都是2N位元組,一共4N位元組(注意:bf16每個元素是2位元組,所以這裡乘以2);

  • • 做N次乘法,N-1次加法,共2N-1個FLOPs;

  • • 最後再寫回2位元組。

當N趨近於無限大時,整體的算術強度(2N-1) FLOPs / (4N+2) bytes近似為0.5 FLOPs/byte,這個值遠遠小於TPU的240 FLOPs/byte,所以這個操作是通訊瓶頸,說明你即使硬體很強,也沒辦法發揮它的全部計算力。

總結一句話:算術強度就是告訴你:你搬運的每一份資料,能不能被「吃乾榨淨」。如果搬運得多,計算得少,硬體再強也浪費;如果計算得多,搬運得少,那你就能用滿算力。這就是判斷計算 vs. 通訊瓶瓶頸的核心標準。

Roofline

我們可以利用一種叫做Roofline圖的方式,將計算能力和記憶體頻寬的權衡形象地描繪出來。這個圖是一個對數座標圖,橫軸是算術強度(FLOPs per byte),縱軸是你在特定硬體上能達到的最大吞吐量(FLOPs/s)。

圖片

圖上會看到幾條線:

  • • 紅色區域:演算法的算術強度太低,無論你的頻寬是BW1還是BW2,都會被頻寬限制住,硬體的FLOPs沒有被完全利用。

  • • 黃色區域:演算法在低頻寬(BW1)下受限,但如果你換個高頻寬(BW2)就能跑得更快。

  • • 綠色區域:算術強度足夠高,不再受記憶體頻寬影響,這時瓶頸變成了計算能力,硬體已經達到滿載。

圖中也有兩個演算法:

  • • Algo 1(左邊):算術強度低,受限於記憶體頻寬,只用到硬體一小部分計算力,是通訊瓶頸。

  • • Algo 2(右邊):算術強度高,達到硬體的峰值FLOPs/s,是計算瓶頸,充分利用了硬體。

這張圖告訴我們:如果一個演算法在紅區,你可以透過提高算術強度(例如增加計算量、減少記憶體存取)或者提高記憶體頻寬來讓它更快;如果一個演算法已經在綠區,那再提高頻寬或者強度就沒什麼意義了,因為已經達到計算瓶頸。

矩陣乘法

講了這麼多,我們來看一個實際的應用。我們考慮最常見的演算法之一:矩陣乘法(matrix multiplication,簡稱matmul)。

假設你有兩個矩陣:X,大小bf16[B, D],Y,大小bf16[D, F],得到結果矩陣Z bf16[B, F]。為了計算這個matmul,你得從記憶體裡讀取2DF + 2BD位元組的資料,執行2BDF個FLOPs(每次乘法+加法);然後把Z寫回記憶體,要寫2BF位元組。

如果我們假設B(批次大小batch size)遠小於D和F,也就是token數量相對embedding size和頭數小得多(這在Transformer中很常見),那麼算術強度大約就是:

2BDF / (2DF + 2BD + 2BF) ≈ F / (F + B)

換句話說,這個強度與批次大小成正比。所以,在TPU v5e上,如果你使用的是bfloat16類型,那麼只要你的批次大小超過240個token,你就可以做到計算受限(compute-bound)。

當然,這些都是單卡內的記憶體頻寬限制,也就是看顯存頻寬能不能滿足計算需求。但這其實只是最簡單的一類roofline,現實中我們更常遇到的瓶頸是多卡之間的通訊頻寬,尤其是在多個TPU/GPU上做分散式矩陣乘法時。來舉個例子,還是原來的X和Y,你把它們在維度D上一分為二,分別放在兩張卡(例如2個TPU)。現在你要做矩陣乘法,做法如下:

  • • 在TPU 0上:計算前一半:A = X[:, :D//2] @ Y[:D//2, :]

  • • 在TPU 1上:計算後一半:B = X[:, D//2:] @ Y[D//2:, :]

  • • 然後兩張卡交換A和B的結果,把它們加起來得到最終輸出。

對於這一次,每張卡只做一半的工作,所以計算時間是原來的一半:

t_comp = BDF / FLOPs/s

而通訊時間是兩個卡交換結果的時間,也就是:

t_comm = 2BF / Bandwidth

我們要找出什麼時候通訊時間小於計算時間(也就是還能跑滿卡,不受通訊限制),解這個不等式:

BDF / FLOPs/s > 2BF / Bandwidth

D > 2 * (FLOPs/s) / Bandwidth

也就是說:只要你的embedding size D > 8755,通訊就不是瓶頸,你是計算受限(compute-bound);反之就是通訊受限(communication-bound)。在這個場景下,決定是否計算受限的,不是批次大小B,而是特徵維度D。

多卡分散式計算下的roofline限制關鍵在於通訊頻寬,是否能充分利用硬體取決於通訊 vs. 計算的比例 — 而這個比例依賴於模型維度,而不是批次大小。理解這個規律,才能知道你什麼時候該切分模型、怎麼切分才能不踩通訊瓶頸。

主標籤:大型語言模型

次標籤:效能最佳化機器學習工程硬體限制分散式訓練


上一篇:回顧Qwen3廢棄的混合推理模式

下一篇:超越人類標註,Meta 提出 CoT-Self-Instruct:如何用「推理式自進化」重塑大型語言模型訓練

分享短網址