As models keep growing, engineering skills are becoming just as important as research skills in the era of large models. I remember that a few years ago, you would often see papers get good results by changing just a few lines of MATLAB or Python code. Today, I suspect that is much harder.
Nowadays, large-model training, deeply nested codebases, and all kinds of engineering challenges are not very friendly to people coming from academic research. I have felt this deeply myself: after working with Megatron-LM, I became much calmer and more focused.
Recently, I came across a blog that explains many techniques for training large models in detail. It inspired me a lot, so I wrote down my own interpretation and reflections, partly as a record and partly to share. If it also helps readers, I would be very honored.
https://jax-ml.github.io/scaling-book/
Training LLMs often feels like alchemy, but understanding and optimizing model performance isn't that mysterious. The goal of this blog is to help you uncover the scientific principles behind scaling language models: how TPUs (and GPUs) actually work, how they communicate with each other, how LLMs run on real hardware, and how to parallelize models so they run efficiently at scale during training and inference. If you've ever wondered "How much does it cost to train an LLM?", "How much memory do I need to deploy this model myself?", or "What exactly is an AllGather?", this content should be helpful to you.
When we run algorithms on hardware, performance is primarily limited by three factors:
• Computation speed — that is, the machine's ability to perform mathematical operations, such as how many operations it can process per second (OPs/second)
• Bandwidth — the speed at which data is moved between memory, cache, and chips (bytes/second)
• Total memory capacity — that is, how much data a device can hold at most (bytes)
Together, these limits form a "roofline" for computation. Compute capability caps performance from above, and bandwidth and memory capacity impose the other constraints. With this roofline model, we can roughly bound how long a computation will take: the best case (how fast it can possibly run) and the worst case (where it gets stuck).
Put simply: if an algorithm needs more data than the memory can hold, it does not fit. If it has a lot of computation but data moves too slowly, it spends its time waiting for data to arrive. If neither is an issue, performance comes down to how powerful the compute cores are. Together, these three factors determine how fast your model can run on a given piece of hardware and how large it can be. Later, we will use this model to analyze how to scale models efficiently across more hardware, such as many TPUs.
Where Does Model Training Time Go?
When we run a model, time is primarily spent in three categories:
1. Computation
Deep learning models are essentially a collection of matrix multiplications, each composed of a large number of floating-point additions and multiplications, also known as FLOPs (floating point operations). The speed at which our accelerators (GPUs or TPUs) process FLOPs determines the time required for computation. Here we define computation time t_comp as
t_comp = FLOPs / (accelerator FLOPs/s)
To give an intuitive example, an NVIDIA H100 can complete approximately 9.89e14 bfloat16 FLOPs per second, and a TPU v6e is roughly 9.1e14. If you have a model that needs to perform 1e12 FLOPs, it would only take about 1e12 / 9.89e14 = 1.01ms on an H100, and about 1.1ms on a TPU v6e.
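The following minimal Python sketch of the t_comp formula checks the arithmetic above. It assumes the chip runs at its quoted peak bf16 throughput. Real kernels rarely reach that, so treat the result as a best case.

```python
# Rough compute-time estimate: t_comp = FLOPs / (accelerator FLOPs/s).
# Peak bf16 throughput figures are the ones quoted in the text.
PEAK_BF16_FLOPS_PER_S = {
    "H100": 9.89e14,
    "TPU v6e": 9.1e14,
}

def compute_time_s(flops: float, peak_flops_per_s: float) -> float:
    """Best-case compute time in seconds, assuming peak throughput is sustained."""
    return flops / peak_flops_per_s

model_flops = 1e12  # the 1e12-FLOP workload from the example
for chip, peak in PEAK_BF16_FLOPS_PER_S.items():
    t = compute_time_s(model_flops, peak)
    print(f"{chip}: {t * 1e3:.2f} ms")
```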
This tells us that if we only consider mathematical operations, models can actually be very fast.
2. Communication within a chip
During model execution, tensors need to be transferred between the chip's memory (e.g., HBM) and the computing cores. This transfer speed is called HBM bandwidth. For H100, this is approximately 3.35TB/s; for TPU v6e, it's about 1.6TB/s.
The time spent moving this data also needs to be accounted for, especially when the data volume is large.
3. Communication between chips
When your model is too large and requires multiple accelerators (such as multiple GPUs or TPUs) to work together, tensors also need to be transferred between chips. Different interconnects (e.g., ICI, DCN, PCIe) have different speeds. Their bandwidth is measured in the same unit as on-chip bandwidth: bytes/second. As before, we define communication time t_comm as:
t_comm = Bytes / (bandwidth in Bytes/s)
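The same back-of-envelope estimate works for data movement. The sketch below applies t_comm to the HBM bandwidths quoted earlier. The 1 GB tensor is a made-up size chosen only for illustration, and the sketch assumes the link runs at its full rated bandwidth.

```python
# Rough communication-time estimate: t_comm = bytes / (bytes/s).
# HBM bandwidths are the figures quoted in the text.
HBM_BYTES_PER_S = {"H100": 3.35e12, "TPU v6e": 1.6e12}

def comm_time_s(num_bytes: float, bytes_per_s: float) -> float:
    """Best-case transfer time in seconds, assuming full bandwidth is sustained."""
    return num_bytes / bytes_per_s

tensor_bytes = 1e9  # hypothetical: a 1 GB tensor (e.g. 500M bf16 values)
for chip, bw in HBM_BYTES_PER_S.items():
    print(f"{chip}: {comm_time_s(tensor_bytes, bw) * 1e3:.3f} ms")
```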
Method for Estimating Total Time
For both computation and communication, we can roughly estimate the time required. From these two estimates:
• Theoretical lower bound: the longer of the computation and communication times;
• Theoretical upper bound: the sum of the two times.
Most of the time, we can get the actual runtime close to the lower bound by overlapping communication with computation. The usual optimization goal is therefore to overlap the two as much as possible. Even in the worst case, with no overlap at all, we are off by at most a factor of 2, because the sum of two times is never more than twice the larger one.

If computation time is much greater than communication time, the hardware is always computing rather than waiting for data. We are then limited mainly by compute, i.e., compute-bound. If, conversely, communication time is greater than computation time, we spend most of our time waiting for data. We are then communication-bound: part of the available FLOPs goes unused and is wasted. So how do we tell whether an operation is compute-bound or communication-bound?
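The two bounds and the compute-vs-communication test fit in a few lines. This is only a sketch. The 1.0 ms / 0.3 ms inputs are hypothetical, and real overlap lands somewhere between the two bounds.

```python
def time_bounds(t_comp: float, t_comm: float) -> tuple[float, float]:
    """Return (lower, upper) bounds on total time in seconds.

    lower: communication perfectly overlapped with computation -> max of the two
    upper: no overlap at all -> sum, which is never more than 2x the lower bound
    """
    return max(t_comp, t_comm), t_comp + t_comm

def bottleneck(t_comp: float, t_comm: float) -> str:
    # Ties are reported as compute-bound; either label is reasonable there.
    return "compute-bound" if t_comp >= t_comm else "communication-bound"

t_comp, t_comm = 1.0e-3, 0.3e-3  # hypothetical: 1.0 ms compute, 0.3 ms data movement
lower, upper = time_bounds(t_comp, t_comm)
print(f"{lower * 1e3:.1f} ms <= t_total <= {upper * 1e3:.1f} ms ({bottleneck(t_comp, t_comm)})")
```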
Arithmetic Intensity
Look at a key metric: Arithmetic Intensity, also called Operational Intensity.
I_a = FLOPs / Bytes
Note that "communication" here covers both on-chip (e.g., HBM) and inter-chip (e.g., GPU-to-GPU) data movement. This value measures how many floating-point operations are performed per byte moved.

Intuitively, if the arithmetic intensity is high, each transfer feeds many computations, so computation time dominates and we are compute-bound. If it is low, most of the time is spent moving data while the compute units sit idle, so we are communication-bound. The crossover point between the two regimes is called the peak arithmetic intensity:
I_peak = peak FLOPs/s ÷ bandwidth (bytes/s)
Peak arithmetic intensity is a property of the hardware. Take TPU v5e as an example: its peak compute is 1.97e14 FLOPs/s and its HBM bandwidth is 8.2e11 bytes/s, so its peak arithmetic intensity is about 1.97e14 / 8.2e11 ≈ 240 FLOPs/byte. Any algorithm with an arithmetic intensity below 240 will therefore be communication-bound. As an example, consider the dot product of two bfloat16 vectors x and y, each of length N:
• To read x and y from memory, each is 2N bytes, totaling 4N bytes (Note: bf16 elements are 2 bytes, so multiply by 2 here);
• Perform N multiplications and N-1 additions, for a total of 2N-1 FLOPs;
• Finally, write the scalar result back to memory: 2 bytes.
As N approaches infinity, the overall arithmetic intensity (2N-1) FLOPs / (4N+2) bytes approaches 0.5 FLOPs/byte, far below the TPU's 240 FLOPs/byte. This operation is therefore communication-bound: no matter how strong the hardware, you cannot make full use of its compute.
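To make the dot-product numbers concrete, the sketch below computes the TPU v5e peak intensity from the two figures quoted above. It then compares that peak with the dot product's intensity for a few vector lengths. It assumes bf16 (2 bytes per element) and counts only the reads of x and y plus the single write of the result, as in the list above.

```python
BF16_BYTES = 2

def peak_intensity(peak_flops_per_s: float, bandwidth_bytes_per_s: float) -> float:
    """Crossover intensity (FLOPs/byte) between bandwidth-bound and compute-bound."""
    return peak_flops_per_s / bandwidth_bytes_per_s

def dot_product_intensity(n: int) -> float:
    flops = 2 * n - 1                               # n multiplies + (n - 1) adds
    bytes_moved = 2 * n * BF16_BYTES + BF16_BYTES   # read x and y, write one scalar
    return flops / bytes_moved

tpu_v5e_peak = peak_intensity(1.97e14, 8.2e11)      # ~240 FLOPs/byte
for n in (16, 1024, 1_000_000):
    ai = dot_product_intensity(n)
    regime = "compute" if ai >= tpu_v5e_peak else "communication"
    print(f"N={n}: {ai:.4f} FLOPs/byte -> {regime}-bound (peak ~{tpu_v5e_peak:.0f})")
```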
In summary: Arithmetic intensity tells you whether every piece of data you move can be "fully utilized." If you move a lot and compute little, powerful hardware is wasted; if you compute a lot and move little, you can fully utilize your computing power. This is the core criterion for judging compute vs. communication bottlenecks.
Roofline
We can use a roofline plot to visualize the trade-off between compute capability and memory bandwidth. It is a log-log plot: the x-axis is arithmetic intensity (FLOPs per byte), and the y-axis is the maximum throughput (FLOPs/s) achievable on a specific piece of hardware.
The plot is divided into several regions:
• Red area: The algorithm's arithmetic intensity is too low, and regardless of whether your bandwidth is BW1 or BW2, it will be limited by bandwidth. The hardware's FLOPs are not fully utilized.
• Yellow area: The algorithm is limited by low bandwidth (BW1), but if you switch to high bandwidth (BW2), it can run faster.
• Green area: The arithmetic intensity is high enough, no longer affected by memory bandwidth. At this point, the bottleneck becomes computational capability, and the hardware is fully utilized.
The graph also shows two algorithms:
• Algo 1 (left): Low arithmetic intensity, limited by memory bandwidth, only using a small portion of the hardware's computational power, indicating a communication bottleneck.
• Algo 2 (right): High arithmetic intensity, reaching the hardware's peak FLOPs/s, indicating a compute bottleneck, fully utilizing the hardware.
The takeaway from this graph: if an algorithm is in the red zone, you can make it faster by increasing its arithmetic intensity (doing more computation per byte, or reducing memory traffic) or by increasing memory bandwidth. If an algorithm is already in the green zone, further increases in bandwidth or intensity bring no significant benefit, because it is already compute-bound.
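The roofline itself is just the minimum of two lines: the flat compute roof and the sloped bandwidth line. This minimal sketch uses the TPU v5e figures from above. BW2 is a hypothetical bandwidth twice the HBM figure, and the intensities 150 and 500 are illustrative values chosen to land in the yellow and green regions.

```python
PEAK_FLOPS_PER_S = 1.97e14  # TPU v5e bf16 peak (from the text)
BW1 = 8.2e11                # TPU v5e HBM bandwidth (from the text)
BW2 = 2 * BW1               # hypothetical faster memory, for illustration only

def attainable_flops_per_s(intensity: float, peak_flops_per_s: float,
                           bandwidth_bytes_per_s: float) -> float:
    """Roofline: throughput is capped by min(compute roof, intensity * bandwidth)."""
    return min(peak_flops_per_s, intensity * bandwidth_bytes_per_s)

cases = [
    ("dot product (red)", 0.5),         # bandwidth-bound under BW1 and BW2
    ("hypothetical (yellow)", 150.0),   # bandwidth-bound under BW1, hits the roof under BW2
    ("hypothetical (green)", 500.0),    # compute-bound under both
]
for name, ai in cases:
    r1 = attainable_flops_per_s(ai, PEAK_FLOPS_PER_S, BW1) / PEAK_FLOPS_PER_S
    r2 = attainable_flops_per_s(ai, PEAK_FLOPS_PER_S, BW2) / PEAK_FLOPS_PER_S
    print(f"{name:>22}: BW1 {r1:6.1%} of peak, BW2 {r2:6.1%} of peak")
```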
Matrix Multiplication
Having discussed so much, let's look at a practical application. We consider one of the most common algorithms: matrix multiplication (matmul).
Suppose you have two matrices, X of shape bf16[B, D] and Y of shape bf16[D, F], producing Z of shape bf16[B, F]. To compute this matmul, you read 2BD + 2DF bytes from memory and perform 2BDF FLOPs (one multiply and one add per term). You then write Z back to memory, which takes 2BF bytes.
If we assume B (batch size) is much smaller than D and F, i.e., the number of tokens is much smaller than the model dimension D and the hidden dimension F (common in Transformers), then the denominator is dominated by the 2DF term and the arithmetic intensity is approximately:
2BDF / (2BD + 2DF + 2BF) ≈ 2BDF / 2DF = B
In other words, the intensity is roughly equal to the batch size B. So on a TPU v5e (peak arithmetic intensity ≈ 240 FLOPs/byte) with bfloat16 inputs, you become compute-bound once the batch exceeds about 240 tokens.
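The approximation can be checked against the exact byte count. In this sketch, D = 8192 and F = 32768 are hypothetical layer sizes, not values from the text. The peak intensity is the TPU v5e figure used earlier.

```python
PEAK_V5E = 1.97e14 / 8.2e11  # TPU v5e peak arithmetic intensity, ~240 FLOPs/byte

def matmul_intensity(b: int, d: int, f: int, bytes_per_elem: int = 2) -> float:
    """Exact intensity of bf16[b, d] @ bf16[d, f]: FLOPs / (bytes read + bytes written)."""
    flops = 2 * b * d * f
    bytes_moved = bytes_per_elem * (b * d + d * f + b * f)
    return flops / bytes_moved

D, F = 8192, 32768  # hypothetical sizes, chosen so that B << D, F
for b in (64, 240, 256, 1024):
    ai = matmul_intensity(b, D, F)
    regime = "compute-bound" if ai > PEAK_V5E else "memory-bound"
    print(f"B={b:5d}: intensity {ai:7.1f} FLOPs/byte -> {regime}")
```

With these sizes, the exact intensity sits a little below B (about 63 at B = 64 and about 232 at B = 240). The real crossover is therefore slightly above 240 tokens, and B = 256 already clears it.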
Of course, this only covers the single-card memory-bandwidth limit, i.e., whether HBM bandwidth can keep the compute units fed. That is just the simplest kind of roofline. In practice, we more often hit limits on inter-card communication bandwidth, especially when distributing a matrix multiplication across multiple TPUs/GPUs.

For example, take the same X and Y, split both along dimension D into two halves, and place one half of each on each of two cards (e.g., 2 TPUs). The matrix multiplication then proceeds as follows:
• On TPU 0: compute the first half: Z0 = X[:, :D//2] @ Y[:D//2, :]
• On TPU 1: compute the second half: Z1 = X[:, D//2:] @ Y[D//2:, :]
• The two cards then exchange Z0 and Z1 and add them together to get the final output Z.
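The split-and-sum scheme can be checked on a single machine with NumPy. This is only a numerical simulation, assuming NumPy is available. The two "cards" are just array slices, the partial results are called Z0 and Z1 here, and the final addition stands in for the exchange between chips (an all-reduce in a real system). The sizes are tiny hypothetical values.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, F = 4, 8, 6                      # tiny hypothetical sizes
X = rng.standard_normal((B, D))
Y = rng.standard_normal((D, F))

# Card 0 holds the first half of the D dimension, card 1 the second half.
Z0 = X[:, : D // 2] @ Y[: D // 2, :]   # partial result on card 0, shape [B, F]
Z1 = X[:, D // 2 :] @ Y[D // 2 :, :]   # partial result on card 1, shape [B, F]

# Exchanging and summing the partial results recovers the full product.
Z = Z0 + Z1
print(np.allclose(Z, X @ Y))

# In bf16, each partial result would take 2 bytes per element, i.e. 2BF bytes.
partial_bytes_bf16 = Z0.size * 2
print(partial_bytes_bf16, 2 * B * F)
```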
In this setup, each card does only half the work (2B(D/2)F = BDF FLOPs), so the computation time is half of the original:
t_comp = BDF / (FLOPs/s)
The communication time is the time for the two cards to exchange their partial results. Each partial result is a bf16[B, F] tensor of 2BF bytes, so:
t_comm = 2BF / (inter-chip bandwidth in bytes/s)
We want to know when communication time is less than computation time, i.e., when the cards can still be fully utilized rather than being limited by communication. Solving the inequality:
BDF / (FLOPs/s) > 2BF / Bandwidth
and dividing both sides by BF gives:
D > 2 × (FLOPs/s) / Bandwidth
On a TPU v5e (1.97e14 FLOPs/s, with an inter-chip bandwidth of about 4.5e10 bytes/s), communication stops being the bottleneck once the embedding size D exceeds roughly 8755, and you are compute-bound. Below that, you are communication-bound. In this scenario, what decides whether you are compute-bound is not the batch size B but the feature dimension D.
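The threshold can be reproduced directly. In the sketch below, the inter-chip bandwidth of 4.5e10 bytes/s is an assumed TPU v5e figure. It is the value that yields D ≈ 8755 when combined with 1.97e14 FLOPs/s. The F value is hypothetical. The batch size B cancels out of the comparison, so the verdict does not change with B.

```python
PEAK_FLOPS_PER_S = 1.97e14  # TPU v5e bf16 peak, as quoted earlier
ICI_BYTES_PER_S = 4.5e10    # assumed inter-chip bandwidth (see lead-in)

def min_d_for_compute_bound(peak_flops_per_s: float, bytes_per_s: float) -> float:
    """D threshold from BDF / peak > 2BF / bandwidth, i.e. D > 2 * peak / bandwidth."""
    return 2 * peak_flops_per_s / bytes_per_s

def is_compute_bound(b: int, d: int, f: int) -> bool:
    t_comp = b * d * f / PEAK_FLOPS_PER_S  # each card does BDF FLOPs
    t_comm = 2 * b * f / ICI_BYTES_PER_S   # exchange one bf16[B, F] partial result
    return t_comp > t_comm

print(f"D threshold ~ {min_d_for_compute_bound(PEAK_FLOPS_PER_S, ICI_BYTES_PER_S):.0f}")
F = 16384  # hypothetical hidden dimension
for d in (4096, 16384):
    print(d, [is_compute_bound(b, d, F) for b in (1, 128, 4096)])
```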
In multi-card distributed computing, the roofline depends critically on inter-chip communication bandwidth. Whether the hardware can be fully utilized depends on the ratio of communication to computation, and in this example that ratio depends on a model dimension, not on the batch size. Understanding this rule helps you decide when and how to split your model without hitting communication bottlenecks.