大規模モデルのトレーニング時間、一体どこに費やされているのか?徹底解説

モデルの規模が拡大するにつれて、大規模モデル時代のエンジニアリング能力と研究能力は同等に重要になってきています。数年前、研究をする中で、MATLABやPythonのコードを数行変更するだけで、効果的なパフォーマンスが得られる論文をよく見かけたことを覚えています。しかし、現在では、それは非常に困難になっていると思います。

現在の、大規模モデルのトレーニング、複雑なコードの入れ子構造、そして様々なエンジニアリング上の課題は、かつて学術研究を行っていた人々にとってもあまり友好的ではないと感じています。私自身も深く実感しており、Megatron-LMに触れて以来、すっかり落ち着きを取り戻しました。

ちょうど最近、大規模モデルのトレーニングに関する様々な技術を詳細に解説したブログを見つけました。それを読んで、私は非常に感銘を受け、その理解と反省を自分なりにまとめて記録することにしました。これを共有できることは嬉しいことであり、読者の皆様にも役立つことができれば、大変光栄です。

https://jax-ml.github.io/scaling-book/

LLMのトレーニングはしばしば錬金術のように感じられますが、モデルの性能を理解し最適化することは、それほど神秘的なことではありません。このブログの目的は、言語モデルをスケールさせる背後にある科学的な原理を明らかにすることです。TPU(およびGPU)がどのように機能するのか、それらがどのように通信するのか、LLMが実際のハードウェア上でどのように動作するのか、そしてトレーニングと推論の過程で、大規模なシナリオで効率的に動作するようにモデルを適切に並列化する方法についてです。「LLMのトレーニングにどれくらい費用がかかるのか?」「このモデルを自分でデプロイするには、どれくらいのメモリが必要か?」「AllGatherって一体何?」といった疑問を抱いたことがあるなら、この内容はあなたにとって役立つはずです。

ハードウェア上でアルゴリズムを実行する際、パフォーマンスは主に3つの側面によって制限されます。

  • • 計算速度 — つまり、マシンが数学演算を行う能力、例えば1秒あたりに処理できる操作数(OPs/second)

  • • 帯域幅 — メモリ、キャッシュ、チップ間でデータを転送する速度(bytes/second)

  • • 総メモリ容量 — つまり、デバイスが最大で保持できるデータ量(bytes)

これらの制限は、計算に対して「ルーフライン」を描くようなものです。上には計算能力が、下にはメモリと帯域幅のボトルネックがあります。このルーフラインモデルを通じて、ある計算が最速でどれくらい速く実行できるか(上限)と、最も遅くどこで詰まるか(下限)を大まかに見積もることができます。

簡単に言えば、アルゴリズムのデータが多すぎてメモリが小さい場合、それは「パンク」しやすいです。アルゴリズムの計算量が非常に大きいのにデータ転送が遅い場合、それは「荷物運び」を待つことになります。どちらも問題ない場合は、計算コアの能力次第です。これら3つの要素が組み合わさって、モデルがハードウェア上でどれだけ速く、どれだけ大規模に動作できるかが決まります。後ほど、このモデルを使って、TPUのようなより大規模なハードウェアにモデルを効率的に拡張する方法を分析します。

モデルのトレーニングにはどこに時間がかかるのか?

モデルを実行する際、時間がかかる場所は主に3種類あります。

1. 計算(Computation)

深層学習モデルは本質的に多数の行列乗算であり、各行列乗算は大量の浮動小数点加算と乗算、いわゆるFLOPs(floating point operations)で構成されます。私たちが使用するアクセラレータ(GPUまたはTPU)がFLOPsを処理する速度が、計算に必要な時間を決定します。ここで計算時間t_compを定義します。

t_comp = FLOPs / FLOPs/s

具体例で直感的に理解してみましょう。NVIDIA H100は1秒あたり約9.89e14個のbfloat16 FLOPsを完了でき、TPU v6eも同様に約9.1e14です。もし1e12個のFLOPsを実行する必要があるモデルがある場合、H100上では約1e12 / 9.89e14 = 1.01msしかかからず、TPU v6e上では約1.1msです。

このことから、数学演算だけを考えれば、モデルは非常に高速に動作できることがわかります。

2. チップ内部の通信(Communication within a chip)

モデルの実行中、テンソルはチップのメモリ(HBMなど)と計算コアの間で転送される必要があります。この転送速度はHBM帯域幅と呼ばれます。H100の場合、これは約3.35TB/sです。一方、TPU v6eの場合、約1.6TB/sです。

このデータ転送の時間も考慮に入れる必要があります。特にデータ量が大きい場合です。

3. チップ間の通信(Communication between chips)

モデルが大きすぎて複数のアクセラレータ(例えば複数のGPUやTPU)が連携する必要がある場合、テンソルは異なるチップ間でも転送される必要があります。異なる接続方式(例えばICI、DCN、PCIe)では速度が異なり、単位はチップ内部通信速度と同じくbytes/secondです。ここで、通信時間t_commを定義します。

t_comm = Bytes / Bytes/s

総時間の推定方法

計算であろうと通信であろうと、私たちは必要な時間を概算することができます。そして:

  • • 理論的下限(Lower bound):計算と通信のうち、より時間がかかる方です。

  • • 理論的上限(Upper bound):両方の時間の合計です。

ほとんどの場合、「通信と計算の並列化」によって、実際の所要時間を下限に近づけることができます。そのため、最適化の目標は通常、通信時間と計算時間をできるだけ重複させることです。最悪の場合でも、せいぜい2倍の差しかありません(つまり、最も遅くても2倍以上遅くなることはありません)。また、計算時間が通信時間よりはるかに長い場合、それはハードウェアが常に計算しており、データを待っていないことを意味し、我々が主に計算能力に制限されている、計算律速(compute-bound)の状態であることを示します。逆に通信時間が計算時間より長い場合、それはほとんどの時間を「データを待って」いることを意味し、通信ボトルネック(communication-bound)によって制限されており、FLOPsの潜在能力が完全に発揮されていない、つまり無駄があることを示します。では、ある操作が計算律速か通信律速かをどのように判断するのでしょうか?

算術強度

重要な指標である「算術強度(Arithmetic Intensity)」、または「運用強度(Operational Intensity)」を見てみましょう。

I_a = FLOPs / Bytes

ここで「通信」とは、チップ内部(HBMなど)とチップ間(GPU-GPUなど)の両方を含みます。この値は、「1バイトを転送するごとに、どれだけの浮動小数点演算を実行できるか」を示します。直感的に理解すると、算術強度が高い場合:データを一度転送するごとに多くの計算ができることを意味し、計算時間が支配的であり、計算ボトルネックです。算術強度が低い場合、ほとんどの時間がデータ転送に費やされ、計算ユニットがアイドル状態であることを意味し、通信ボトルネックです。そして、これら両者の「どちらが主か」の転換点を、「ピーク算術強度(peak arithmetic intensity)」と呼びます。これは、

peak FLOPs/s ÷ 帯域幅(bytes/s)

ピーク算術強度はハードウェアの特性です。TPU v5eを例にとると、最大計算能力は1.97e14 FLOPs/s、帯域幅は8.2e11 bytes/sなので、そのピーク算術強度は約1.97e14 / 8.2e11 ≈ 240 FLOPs/byteです。これは、もしあるアルゴリズムの算術強度が240を下回る場合、それは通信ボトルネックになることを意味します。ドット積を例に、2つのベクトルx、yを計算する場合を考えます。それらの長さがNで、各要素がbfloat16だと仮定します。

  • • メモリからxとyを読み出す必要があり、それぞれ2Nバイト、合計で4Nバイトです(注意:bf16は各要素が2バイトなので、ここで2を乗じています)。

  • • N回の乗算、N-1回の加算を行い、合計2N-1個のFLOPsを実行します。

  • • 最後に2バイトを書き戻します。

Nが無限大に近づくとき、全体の算術強度 (2N-1) FLOPs / (4N+2) bytes は約0.5 FLOPs/byteに近似されます。この値はTPUの240 FLOPs/byteよりはるかに小さいため、この操作は通信ボトルネックであり、ハードウェアが非常に強力であっても、その計算能力を最大限に活用できないことを意味します。

一言でまとめると、算術強度は、あなたが転送するデータ一つ一つを「骨の髄までしゃぶり尽くせる」かどうかを教えてくれます。転送が多くて計算が少ない場合、どんなに強力なハードウェアも無駄になります。計算が多くて転送が少ない場合、あなたは計算能力を最大限に活用できます。これが計算 vs 通信ボトルネックを判断する核心的な基準です。

ルーフライン(Roofline)

私たちは「ルーフライン(Roofline)図」という方法を使って、計算能力とメモリ帯域幅のトレードオフを視覚的に表現することができます。この図は対数座標図であり、横軸は算術強度(FLOPs per byte)、縦軸は特定のハードウェアで達成できる最大スループット(FLOPs/s)を示します。

画像

図上にはいくつかの線が表示されます。

  • • 赤色領域:アルゴリズムの算術強度が低すぎるため、帯域幅がBW1であろうとBW2であろうと、帯域幅に制限され、ハードウェアのFLOPsが十分に活用されていません。

  • • 黄色領域:アルゴリズムが低帯域幅(BW1)で制限されていますが、高帯域幅(BW2)に切り替えると高速に動作します。

  • • 緑色領域:算術強度が十分に高いため、メモリ帯域幅の影響を受けなくなり、この時点でのボトルネックは計算能力となり、ハードウェアはフル稼働しています。

図には2つのアルゴリズムも示されています。

  • • Algo 1(左側):算術強度が低く、メモリ帯域幅に制限されており、ハードウェアの計算能力のほんの一部しか使用しておらず、通信ボトルネックです。

  • • Algo 2(右側):算術強度が高く、ハードウェアのピークFLOPs/sに達しており、計算ボトルネックであり、ハードウェアを十分に活用しています。

この図が教えてくれるのは、あるアルゴリズムが赤色領域にある場合、算術強度を高めること(例えば、計算量を増やす、メモリアクセスを減らす)またはメモリ帯域幅を向上させることで、それを高速化できるということです。もしアルゴリズムがすでに緑色領域にある場合、すでに計算ボトルネックであるため、帯域幅や強度をさらに向上させても意味がありません。

行列乗算

ここまで色々と説明しましたが、実際の応用例を見てみましょう。最も一般的なアルゴリズムの一つである行列乗算(matrix multiplication、略してmatmul)を考えてみます。

行列X(サイズbf16[B, D])と行列Y(サイズbf16[D, F])があり、結果行列Z(bf16[B, F])を得るとします。この行列乗算を計算するには、メモリから2DF + 2BDバイトのデータを読み込み、2BDF個のFLOPs(各乗算+加算)を実行する必要があります。その後、Zをメモリに書き戻すには2BFバイトを書き込む必要があります。

もしB(バッチサイズ)がDとFに比べて非常に小さいと仮定する、つまりトークン数が埋め込みサイズやヘッド数に比べてはるかに小さい場合(これはTransformerでよく見られる)、算術強度は概ね以下のようになります。

2BDF / (2DF + 2BD + 2BF) ≈ F / (F + B)

言い換えれば、この強度はバッチサイズに比例します。したがって、TPU v5eでbfloat16タイプを使用している場合、バッチサイズが240トークンを超えていれば、計算律速(compute-bound)にすることができます。

もちろん、これらは単一カード内のメモリ帯域幅の制限であり、つまりVRAM帯域幅が計算を十分に供給できるかどうかの問題です。しかし、これはルーフラインの最も単純なタイプに過ぎません。現実には、特に複数のTPU/GPUで分散行列乗算を行う場合、より頻繁に遭遇するボトルネックは複数カード間の通信帯域幅です。例を挙げましょう。元のXとYを、次元Dで2つに分割し、それぞれ2枚のカード(例えば2つのTPU)に配置します。ここで行列乗算を行う手順は以下の通りです。

  • • TPU 0で:前半を計算:A = X[:, :D//2] @ Y[:D//2, :]

  • • TPU 1で:後半を計算:B = X[:, D//2:] @ Y[D//2:, :]

  • • そして2枚のカードがAとBの結果を交換し、それらを合計して最終出力を得ます。

この場合、各カードは作業の半分しか行わないため、計算時間は元の半分になります。

t_comp = BDF / FLOPs/s

一方、通信時間は2枚のカードが結果を交換する時間であり、以下のようになります。

t_comm = 2BF / Bandwidth

通信時間が計算時間よりも短くなる時期(つまり、まだカードをフル稼働でき、通信の制約を受けない時期)を見つけるために、この不等式を解きます。

BDF / FLOPs/s > 2BF / Bandwidth

D > 2 * (FLOPs/s) / Bandwidth

つまり、埋め込みサイズDが8755より大きければ、通信はボトルネックではなく、計算律速(compute-bound)になります。そうでなければ通信律速(communication-bound)です。このシナリオでは、計算律速かどうかを決定するのはバッチサイズBではなく、特徴次元Dです。

多カード分散計算におけるルーフラインの制約は、通信帯域幅が重要であり、ハードウェアを最大限に活用できるかどうかは、通信と計算の比率に依存します。そして、この比率はモデルの次元に依存し、バッチサイズではありません。この法則を理解することで、いつモデルを分割すべきか、どのように分割すれば通信ボトルネックに陥らないかを把握できます。

メインタグ:大規模言語モデル

サブタグ:性能最適化機械学習エンジニアリングハードウェア制約分散学習


前の記事:Qwen3が廃止した混合推論モードを振り返る

次の記事:人間によるアノテーションを超えて:MetaがCoT-Self-Instructを発表 – 「推論的自己進化」でLLMトレーニングを再構築する方法

短いURLをシェア