Go smol or go home
Why we should train smaller LLMs on more tokens
If you have access to a big compute cluster and are planning to train a Large Language Model (LLM), you will need to make a decision on how to allocate your compute budget. This involves selecting the number of model parameters $N$ and the number of training tokens $D$. By applying the scaling laws, you can get guidance on how to reach the best model performance for your given compute budget, and find the optimal distribution of compute $C$ between the parameters $N_{opt}$ and training tokens $D_{opt}$.
However, for most use cases you should not train a compute-optimal LLM but instead spend some extra compute to obtain a smaller model. Smaller models not only make inference faster and cheaper, they are also much easier to use for developers and researchers with limited GPU resources. Although many LLM practitioners train their models on more tokens than the Chinchilla scaling laws suggest, not everyone is aware that the scaling laws can also tell us how much smaller a model we can train and how much additional compute that requires.
In this blogpost, I’ll show how to derive the trade-off between model size and compute overhead and reveal that there is significant room to reduce the compute-optimal model size with minimal compute overhead. However, there comes a point where spending more compute leads to diminishing returns because you’ve hit the critical model size. The critical model size is essentially the minimum LLM capacity required to attain a specific loss level, and shrinking the model beyond this point becomes near-impossible. My analysis suggests that the critical model size is around 30% of the Chinchilla-optimal model and comes with roughly a 100% compute overhead. Notably, recent models such as LLaMA-7B, which is trained on 1T tokens, are far from reaching the critical model size, indicating that there is ample room to train “smaller” LLMs for longer.
Recap of Chinchilla scaling laws
In Chinchilla’s third approach to estimating the scaling laws, the authors argue that the loss can be modelled as a function of the parameter count and the number of seen tokens: $$L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$ The authors fitted these parameters to a series of experiments with varying model sizes and numbers of training tokens and found the following estimates: $$E=1.69, A=406.4, B=410.7, \alpha=0.34, \beta=0.28.$$
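To make this concrete, here is a minimal Python sketch of the parametric loss with these fitted constants plugged in. It is my own illustration rather than code from the Chinchilla paper, and the function and variable names are mine:

```python
# Minimal sketch of Chinchilla's parametric loss (approach 3),
# using the fitted constants quoted above.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28

def loss(n_params: float, n_tokens: float) -> float:
    """Predicted pre-training loss for a model with n_params parameters
    trained on n_tokens tokens."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA

# For example, roughly Chinchilla's configuration: 70B parameters, 1.4T tokens.
print(loss(70e9, 1.4e12))
```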
By optimizing this loss function $L$ under the constraint that the compute budget $C = 6ND$, you can show that the compute-optimal number of parameters $N_{opt}$ and compute-optimal number of tokens $D_{opt}$ follow a power law: $$N_{opt}(C) = G\left(\frac{C}{6}\right)^{\frac{\beta}{\alpha+\beta}}, D_{opt}(C) = G^{-1}\left(\frac{C}{6}\right)^{\frac{\alpha}{\alpha+\beta}}, G = \left(\frac{\alpha A}{\beta B}\right)^{\frac{1}{\alpha+\beta}}$$
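A small sketch of this compute-optimal allocation might look as follows; the helper name and the example budget of 1e24 FLOPs are my own choices, and the constants come from the previous snippet:

```python
# Sketch of the compute-optimal allocation for a budget c (in FLOPs),
# reusing A, B, ALPHA, BETA from the previous snippet.
G = (ALPHA * A / (BETA * B)) ** (1 / (ALPHA + BETA))

def optimal_allocation(c: float) -> tuple[float, float]:
    """Split a compute budget c into compute-optimal (N_opt, D_opt)."""
    n_opt = G * (c / 6) ** (BETA / (ALPHA + BETA))
    d_opt = (1 / G) * (c / 6) ** (ALPHA / (ALPHA + BETA))
    return n_opt, d_opt

# An arbitrary example budget of 1e24 FLOPs.
n_opt, d_opt = optimal_allocation(1e24)
```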
Model size vs compute overhead
Suppose we reduce the optimal model size $N_{opt}$ by half. How much do we need to increase the training tokens to obtain the same model loss? Keeping the same compute budget would mean doubling the number of training tokens $D_{opt}$, but to match the loss we should expect some compute overhead and train for longer than that.
We can return to Chinchilla’s parametric loss function to answer this question. We are looking to scale the parameters by $k_N$ and the training tokens by $k_D$ while reaching the same loss as $L(N_{opt}, D_{opt})$. More precisely, we are looking to satisfy the following equation: $$L(N_{opt}, D_{opt}) = L(k_N N_{opt}, k_D D_{opt})$$ $$E + \frac{A}{N_{opt}^{\alpha}} + \frac{B}{D_{opt}^{\beta}} = E + \frac{A}{\left(k_N N_{opt}\right)^{\alpha}} + \frac{B}{\left(k_D D_{opt}\right)^{\beta}}$$
Subtracting $E$ from both sides and solving for $k_D$, you find that:
$$k_D = \left(1 - \left(k_N^{-\alpha} - 1\right) \frac{A N_{opt}^{-\alpha}}{B D_{opt}^{-\beta}}\right)^{-\frac{1}{\beta}}$$
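As a sketch, this equation translates directly into code; the helper name is mine, the constants are repeated for convenience, and $N_{opt}$, $D_{opt}$ are taken from the earlier snippet:

```python
# Sketch of the data scaling factor k_D for a given model scaling factor k_N,
# implementing the equation above (constants repeated for convenience).
A, B, ALPHA, BETA = 406.4, 410.7, 0.34, 0.28

def data_scaling_factor(k_n: float, n_opt: float, d_opt: float) -> float:
    """How much longer we need to train when the model size is scaled by k_n."""
    ratio = (A * n_opt**-ALPHA) / (B * d_opt**-BETA)
    return (1 - (k_n**-ALPHA - 1) * ratio) ** (-1 / BETA)

# Halving the compute-optimal model, with n_opt, d_opt from the previous snippet:
k_d = data_scaling_factor(0.5, n_opt, d_opt)  # roughly 2.4x more training tokens
```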
Once we have found the data scaling factor $k_D$, we can determine the new compute budget $$C_{new} = 6 (k_N N_{opt}) (k_D D_{opt})$$ as well as the compute overhead $$C_{overhead} = \frac{C_{new} - C}{C} \times 100.$$
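Continuing the sketch, the overhead follows in a couple of lines, assuming the helpers and values defined above:

```python
# Sketch of the new compute budget and the resulting overhead,
# chaining the helpers defined in the previous snippets.
def compute_overhead(k_n: float, k_d: float, n_opt: float, d_opt: float) -> float:
    """Extra compute, in percent, relative to the compute-optimal run."""
    c = 6 * n_opt * d_opt
    c_new = 6 * (k_n * n_opt) * (k_d * d_opt)
    return (c_new - c) / c * 100

# A model at 50% of N_opt trained on ~2.4x D_opt costs roughly 20% extra compute.
print(compute_overhead(0.5, k_d, n_opt, d_opt))
```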
Interestingly, as I’ll show below, the data scaling factor $k_D$ is independent of the compute budget $C$. The resulting model-size vs compute overhead trade-off is therefore identical across all compute budgets.
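As a quick numeric sanity check (again assuming the sketches above), the same $k_N$ yields the same $k_D$ across wildly different budgets:

```python
# Quick numeric sanity check: the same k_N gives the same k_D no matter
# how large the compute budget is.
for budget in (1e21, 1e23, 1e25):
    n_opt, d_opt = optimal_allocation(budget)
    print(f"C={budget:.0e}: k_D={data_scaling_factor(0.5, n_opt, d_opt):.3f}")
```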