Table of Contents
Adam mini
GrokFast
MobileLLM
JEST: Data curation via joint example selection accelerates learning
Adam Mini: Use Fewer Learning Rates To Gain More
If you have ever tried to train or fine-tune models, one of the key decisions is the choice of optimiser. Optimisers dictate how the weights change in response to the loss on the current mini-batch. Two optimisers dominate in practice. The first is Stochastic Gradient Descent (SGD), which is basically the default: it takes a step in the direction opposite to the gradient. Each parameter only needs to keep track of its own gradient, so the extra memory equals the size of the parameters. Then comes Adam, which takes it a step further (pun intended). It sets each parameter's step size from running averages of that parameter's past gradients (the first moment, or momentum) and of its past squared gradients (the second moment, loosely called the variance). So each parameter has to keep track of its gradient, momentum and variance. That is 3 values per parameter, so the memory requirement is 3x the number of parameters, which is huge for big models. Because each parameter's step is divided by the square root of its own variance estimate, this is equivalent to saying that each parameter has its own learning rate.
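To make the memory difference concrete, here is a minimal pure-Python sketch of one SGD step and one Adam step on a flat list of parameters. It follows the standard Adam update with bias correction. The function names and default hyperparameters are our own choices and are not taken from any library. Adam has to carry `m` and `v` (one value each per parameter) from step to step, while SGD carries nothing.

```python
import math


def sgd_step(params, grads, lr=0.1):
    """Plain SGD: step against the gradient. No optimiser state is kept between steps."""
    return [p - lr * g for p, g in zip(params, grads)]


def adam_step(params, grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step; m and v hold one running value per parameter (t counts from 1)."""
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g        # first moment (momentum)
        v_i = beta2 * v_i + (1 - beta2) * g * g    # second moment ("variance")
        m_hat = m_i / (1 - beta1 ** t)             # bias correction
        v_hat = v_i / (1 - beta2 ** t)
        # Each parameter's step is scaled by its own 1 / sqrt(v_hat): its own learning rate.
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# Usage: two parameters with very different gradient magnitudes.
params, m, v = adam_step([1.0, -2.0], [0.5, -4.0], [0.0, 0.0], [0.0, 0.0], t=1)
print(params)
print(sgd_step([1.0, -2.0], [0.5, -4.0]))
```

At `t=1`, bias correction makes `m_hat = g` and `v_hat = g**2`. Each parameter therefore moves by roughly `lr` against its gradient, whatever the gradient's magnitude. SGD instead moves each parameter in proportion to its gradient.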
Now this raises the question: can't there be a middle ground? Surely there is some sweet spot between 1x and 3x the parameter count, between SGD and Adam. One intuitive idea is to stop giving every parameter its own learning rate and instead find groups of parameters that can get away with sharing one learning rate (aka step size factor). But how do you decide which parameters share a step size factor? Mathematically, one thing you can do is look at the Hessian. The Hessian is the matrix of second-order derivatives of the loss, which describes its curvature at a given point. Its off-diagonal entries tell you how the gradient with respect to one parameter is influenced by changes in other parameters.
An off-diagonal Hessian entry of zero means the gradient with respect to one parameter is unaffected by a change in another parameter. As far as the gradient is concerned the two are decoupled, so nothing forces them to share a learning rate. But the Hessian does have structure with non-zero entries. Empirically, the Hessian of a transformer looks close to block-diagonal. Parameters within the same block (module) of a transformer layer have non-zero Hessian entries with each other, while entries across blocks are close to zero. This suggests we could use a single value per block to express its variance estimate. In this view, Adam uses a diagonal preconditioner (a separate learning rate for every parameter), while SGD uses a single learning rate for all parameters. A block-wise learning rate sits between the two.
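A toy example helps show what block-diagonal means. The hypothetical loss below has two groups of parameters, (x0, x1) and (x2, x3), that never appear in the same term. Its Hessian, estimated here with central finite differences, therefore has non-zero entries only inside two 2x2 blocks. This only illustrates the structure: a real transformer's Hessian is at best approximately block-diagonal.

```python
def toy_loss(x):
    """Hypothetical loss: (x0, x1) interact, (x2, x3) interact, no cross-group terms."""
    return (x[0] + 2 * x[1]) ** 2 + x[0] ** 2 + 3 * (x[2] - x[3]) ** 2 + x[3] ** 2


def hessian(f, x, h=1e-3):
    """Central finite-difference estimate of the Hessian of f at x."""
    n = len(x)
    H = [[0.0] * n for _ in range(n)]

    def f_shifted(i, di, j, dj):
        # Evaluate f with coordinate i shifted by di and coordinate j shifted by dj.
        y = list(x)
        y[i] += di
        y[j] += dj
        return f(y)

    for i in range(n):
        for j in range(n):
            H[i][j] = (f_shifted(i, h, j, h) - f_shifted(i, h, j, -h)
                       - f_shifted(i, -h, j, h) + f_shifted(i, -h, j, -h)) / (4 * h * h)
    return H


H = hessian(toy_loss, [0.5, -1.0, 2.0, 0.3])
for row in H:
    print([round(value, 3) for value in row])
```

The printed matrix contains the blocks [[4, 4], [4, 8]] and [[6, -6], [-6, 8]], with entries between the two blocks that are zero up to rounding.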
The paper experiments with exactly this idea: a single learning rate per block. How do you decide on the blocks? In a transformer, natural candidates are layers or, more finely, modules (the individual weight matrices within a layer). But what do we gain from all this? The definite gain is memory. How much, you ask? We store one variance value per block instead of one per parameter, while the momentum stays per-parameter. The optimiser state (momentum plus variance) therefore shrinks by roughly 50% compared to Adam. Here are the numbers for different model sizes.
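A back-of-the-envelope count shows where the saving comes from. The sketch below counts optimiser-state entries, excluding gradients. Adam keeps `m` and `v` for every parameter. The block-wise scheme keeps `m` per parameter and one `v` per block, except for the embedding-type blocks, which keep a per-parameter `v`. The block names and sizes are hypothetical.

```python
def optimizer_state_entries(block_sizes, embd_blocks=()):
    """Count optimiser-state entries (gradients excluded).

    Adam: m and v for every parameter.
    Block-wise scheme: m for every parameter, one v per block, except embd_blocks,
    which keep a per-parameter v.
    """
    total = sum(block_sizes.values())
    adam = 2 * total
    v_entries = sum(size if name in embd_blocks else 1 for name, size in block_sizes.items())
    blockwise = total + v_entries
    return adam, blockwise


# Hypothetical block sizes, for illustration only.
sizes = {"embed": 1_000, "layer0": 10_000, "layer1": 10_000, "lm_head": 1_000}
adam, blockwise = optimizer_state_entries(sizes, embd_blocks=("embed", "lm_head"))
print(adam, blockwise, f"saving: {1 - blockwise / adam:.1%}")
```

With these made-up sizes, the block-wise scheme needs 24,002 entries instead of 44,000, a saving of about 45%. The larger the non-embedding blocks are relative to the embeddings, the closer the saving gets to 50%.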
If your optimiser uses less memory, you can fit a larger batch in memory and hence shave off training time, assuming memory is the bottleneck. embd_blocks refers to the embedding layer and the lm_head (output) layer. The gradients of those matrices are non-zero only in the rows of tokens that appear in the current batch, so a single shared variance value would hurt performance, and these layers are kept out of the single-variance scheme. Algorithms 2 and 3 in the paper describe how to partition the model into blocks. And if you're curious how the single variance value for a block is chosen, it is basically the mean of the variance values that Adam would keep for the parameters in that block.
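Here is a minimal sketch of one block-wise step in the spirit of Adam mini. It is written in plain Python for readability and is not taken from the paper's PyTorch code. Momentum stays per-parameter, but each block keeps a single second-moment scalar: an EMA of the block's mean squared gradient. Both the mean and the EMA are linear, so this equals the mean of the per-parameter `v` values Adam would keep for that block. The sketch leaves out embedding-type blocks. The function name, the dict-of-lists layout and the hyperparameter defaults are our assumptions.

```python
import math


def blockwise_adam_step(blocks, grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam-mini-style step (sketch).

    blocks, grads, m: dict block_name -> list of floats (per-parameter values)
    v: dict block_name -> float (a single second-moment value per block)
    t: step counter starting at 1
    """
    new_blocks, new_m, new_v = {}, {}, {}
    for name, params in blocks.items():
        g = grads[name]
        # Momentum stays per-parameter, exactly as in Adam.
        m_b = [beta1 * m_i + (1 - beta1) * g_i for m_i, g_i in zip(m[name], g)]
        # One scalar per block: EMA of the block's mean squared gradient.
        mean_sq = sum(g_i * g_i for g_i in g) / len(g)
        v_b = beta2 * v[name] + (1 - beta2) * mean_sq
        v_hat = v_b / (1 - beta2 ** t)
        step = lr / (math.sqrt(v_hat) + eps)  # shared by every parameter in the block
        new_blocks[name] = [p - step * m_i / (1 - beta1 ** t) for p, m_i in zip(params, m_b)]
        new_m[name] = m_b
        new_v[name] = v_b
    return new_blocks, new_m, new_v


blocks, m, v = blockwise_adam_step(
    {"layer0": [1.0, 1.0]}, {"layer0": [3.0, 4.0]}, {"layer0": [0.0, 0.0]}, {"layer0": 0.0}, t=1
)
print(blocks, v)
```

Both parameters in `layer0` share the effective learning rate `lr / sqrt(v_hat)`. With gradients `[3.0, 4.0]`, the block's mean squared gradient is 12.5, so the shared step size is about `0.001 / 3.54`.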
Performance-wise, Adam mini is on par with AdamW and sometimes even outperforms it. Here is the code to the whole thing in PyTorch. This is exciting work: AdamW is the default that everyone uses for transformers, and any improvement on it lets us eke out more performance from a given hardware budget.
GrokFast: Accelerated Grokking by Amplifying Slow Gradients
Grokking is a phenomenon where a machine learning model begins to generalize effectively long after it has already overfitted to the training data. In simpler terms, the model first memorises the training data and only later learns to make accurate predictions on new, unseen data. If you're curious about what grokking is and how it affects Transformers, check out our recent blog about Grokking Transformers. Grokking traditionally demands extensive computational resources because of the long training required after overfitting. This makes it impractical for many machine learning practitioners with limited resources. The primary objective of the paper is to accelerate grokking.
The approach views training as updating the model parameters θ over iterations. At each iteration t, the parameter update u(t) is driven by the gradient g(t) of the loss function L with respect to θ:
\(\theta(t+1) = \theta(t) + u(g(t), t), \qquad g(t) = \frac{\partial L(\theta(t))}{\partial \theta(t)}\)
where u(g(t), t) is the update to θ based on the gradient and the iteration t. Over multiple iterations, the cumulative effect of these updates can be expressed as follows, where θ(0) is the initial value of θ:
\(\theta(T) = \theta(0) + \sum_{t=0}^{T-1} u(g(t), t)\)
To see how these updates shape learning over time, the approach uses the Discrete-Time Fourier Transform (DTFT). The DTFT converts the discrete signal u(t), the sequence of parameter updates over iterations, into the frequency domain:
\(U(\omega) = \sum_{t=0}^{T-1} u(t)\, e^{-i\omega t}\)
Here, U(ω) is the frequency domain representation of the parameter updates u(t).
The update u(t) is closely tied to the gradient g(t). For plain SGD with learning rate η, it is simply proportional to it: u(t) = −η g(t). The paper hypothesizes that the slow generalization that happens after overfitting is driven by the low-frequency components of u(t), and hence also of the gradient signal g(t). These low-frequency components correspond to gradual, consistent changes in the parameters that lead to better generalization on unseen data. The high-frequency components correspond to fast fluctuations that drive the quick fit to the training data.
Amplifying the low-frequency component of G(ω), the DTFT of g(t), accelerates generalization under the grokking phenomenon. A straightforward way to do this is to add a low-pass filtered version of the gradient to the gradient itself. The filtered gradient is given by:
\(\hat{g}(t) = g(t) + h(t) \ast g(t)\)
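where h(t) is a low-pass filter and ∗ denotes convolution. Recall that convolution in the time domain is equivalent to multiplication in the frequency domain. The modified gradient can therefore be written in the frequency domain as:
\(\hat{G}(\omega) = G(\omega) + H(\omega)\,G(\omega) = \big(1 + H(\omega)\big)\,G(\omega)\)
where H(ω) is the frequency response of h(t).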
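The frequency-domain identity can be checked numerically on a short, made-up gradient history. The sketch evaluates the DTFT of a finite sequence directly from its definition. It then builds ĝ = g + h ∗ g with a full linear convolution, so the identity holds exactly for finite sequences, and compares both sides at a few frequencies. The 2-tap filter `h = [0.5, 0.5]` is a toy low-pass filter chosen for illustration.

```python
import cmath
import math


def dtft(x, omega):
    """DTFT of a finite sequence x[0..N-1] at angular frequency omega (radians per step)."""
    return sum(x_t * cmath.exp(-1j * omega * t) for t, x_t in enumerate(x))


def convolve(h, g):
    """Full linear convolution: (h * g)(t) = sum over tau of h(tau) * g(t - tau)."""
    out = [0.0] * (len(h) + len(g) - 1)
    for tau, h_tau in enumerate(h):
        for t, g_t in enumerate(g):
            out[tau + t] += h_tau * g_t
    return out


def filtered_gradient(g, h):
    """g_hat = g + h * g, with g zero-padded to the length of the full convolution."""
    hg = convolve(h, g)
    g_padded = g + [0.0] * (len(hg) - len(g))
    return [a + b for a, b in zip(g_padded, hg)]


g = [1.0, -0.5, 2.0, 0.25, -1.0]  # made-up gradient history of one parameter
h = [0.5, 0.5]                    # toy 2-tap low-pass filter
g_hat = filtered_gradient(g, h)

for omega in (0.0, 1.0, math.pi):
    lhs = dtft(g_hat, omega)                     # G_hat(omega)
    rhs = (1 + dtft(h, omega)) * dtft(g, omega)  # (1 + H(omega)) G(omega)
    gain = abs(1 + dtft(h, omega))
    print(f"omega={omega:.3f}  |difference|={abs(lhs - rhs):.2e}  gain={gain:.3f}")
```

For this `h`, the gain |1 + H(ω)| is 2 at ω = 0 and 1 at ω = π. The slowest component of the gradient is doubled, while the fastest oscillation passes through unchanged.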
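For simplicity, the authors start with a windowed Moving Average (MA) as the low-pass filter:
\(h(t) = \begin{cases} \lambda / w, & 0 \le t < w \\ 0, & \text{otherwise} \end{cases}\)
This is a rectangular window, i.e. a scaled and shifted Heaviside Pi function Π, where λ is a scalar factor and w is the window size. The filter averages the last w gradients with equal weight and scales the result by λ, so only the recent trend of the gradient signal is amplified.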
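Here is a minimal sketch of the windowed MA filter for a single gradient vector, assuming gradients arrive as plain lists of floats. It keeps the last `w` gradients in a `deque`, which is where the MA filter's memory cost comes from. It returns g(t) + λ times the mean of the window, matching h(t) = λ/w on 0 ≤ t < w. The class name and the warm-up behaviour (passing gradients through unchanged until the window is full) are our own choices.

```python
from collections import deque


class MovingAverageGradientFilter:
    """Windowed moving-average gradient filter (sketch): g_hat = g + lam * mean(last w gradients)."""

    def __init__(self, window, lam):
        self.window = window
        self.lam = lam
        # Holds up to `window` full copies of the gradient: this is the MA filter's memory cost.
        self.queue = deque(maxlen=window)

    def __call__(self, grad):
        self.queue.append(list(grad))
        if len(self.queue) < self.window:
            return list(grad)  # warm-up (our choice): pass through until the window is full
        avg = [sum(column) / self.window for column in zip(*self.queue)]
        return [g + self.lam * a for g, a in zip(grad, avg)]


ma = MovingAverageGradientFilter(window=2, lam=1.0)
for grad in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]):
    print(ma(grad))
```

With `window=2` and `lam=1.0`, the three calls return `[1.0, 2.0]` (warm-up), `[5.0, 7.0]` and `[9.0, 11.0]`. The current gradient is part of the window, which corresponds to h(0) = λ/w.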
The authors also propose an improved version of their gradient filter that replaces the windowed MA filter with an Exponential Moving Average (EMA) filter. The EMA filter has a smaller memory footprint because it no longer needs to queue past gradients, which makes it more practical for real-world applications. Its impulse response is:
\(h(t) = \lambda (1 - \alpha) \sum_{\tau=0}^{t} \alpha^{\tau}\, \delta(t - \tau)\)
where δ(t) is the discrete unit impulse (zero everywhere except at the origin) and α ∈ [0, 1) is the EMA decay rate. With a moving average, we must store the last w gradients in a queue and average them, which costs w copies of the gradient. The EMA is updated iteratively instead, so it only needs extra memory equivalent to the size of the model itself.
The update formula in line 7 of Algorithm 2 in the paper resembles the momentum term found in many optimizers. However, there are notable differences.
The smoothed gradient (EMA) is used as a residual: it is added to the current gradient g(t) before being fed into the optimizer. This is closer to Nesterov's momentum, with the crucial difference that the filtering (smoothing) is applied before the optimizer processes the gradients. Also, the filtering step (lines 7-8) is independent of the underlying optimizer. It can therefore be combined with any first-order, gradient-based optimizer. In effect, the EMA is a pre-processing step that modifies the gradients before the optimizer uses them.
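The EMA variant is even simpler to sketch. The filter below computes μ ← αμ + (1 − α)g and then ĝ ← g + λμ. Unrolled, this gives the EMA impulse response h(t) = λ(1 − α) Σ α^τ δ(t − τ). The filter is wrapped around a plain SGD step to show that it is just a gradient pre-processing stage. The class name, the zero initialisation of μ and the toy quadratic loss are our own assumptions, not the authors' code.

```python
class EMAGradientFilter:
    """EMA gradient filter (sketch): mu <- alpha*mu + (1 - alpha)*g, then g_hat = g + lam*mu.

    Only one extra buffer (mu) the size of the gradient is stored.
    """

    def __init__(self, alpha, lam):
        self.alpha = alpha
        self.lam = lam
        self.mu = None

    def __call__(self, grad):
        if self.mu is None:
            self.mu = [0.0] * len(grad)  # zero start, matching the sum from tau = 0 (assumption)
        self.mu = [self.alpha * m + (1 - self.alpha) * g for m, g in zip(self.mu, grad)]
        return [g + self.lam * m for g, m in zip(grad, self.mu)]


def sgd_step(params, grads, lr):
    """Any first-order optimiser could sit here; plain SGD keeps the example short."""
    return [p - lr * g for p, g in zip(params, grads)]


# Toy problem: minimise p**2, whose gradient is 2p. The filter runs before the optimiser.
ema = EMAGradientFilter(alpha=0.98, lam=2.0)
params = [5.0]
for _ in range(100):
    grads = [2 * params[0]]
    params = sgd_step(params, ema(grads), lr=0.05)
print(params)
```

Fed a constant gradient, the filtered gradient approaches (1 + λ) times that gradient. A gradient whose sign flips every step is amplified much less. This is the low-pass behaviour at work.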
The GROKFAST algorithm was also applied to MNIST classification with a three-layer ReLU MLP, to a graph convolutional neural network (GCNN) trained on the QM9 molecule dataset, and to a 2-layer LSTM for sentiment analysis on the IMDb dataset. The goal was to show that GROKFAST generalizes beyond simple algorithmic tasks to more conventional machine learning problems. The results were clear: models trained with GROKFAST-EMA converged faster and reached better validation loss and accuracy.
MobileLLM: Optimizing Sub-billion Parameter Language Models for On-Device Use Cases
Large language models (LLMs) are increasingly woven into daily life, enriching communication, work, and entertainment. Examples like ChatGPT and Perplexity AI typically run on powerful cloud servers, yet scaling these models for widespread use presents substantial environmental and computational challenges. For instance, running GPT-4 at high throughput would require approximately 100 million H100 GPUs. That is equivalent to the computing power of 160 large companies, leading to significant energy consumption and carbon emissions.
There is a growing need to adapt LLMs for mobile deployment, where memory and energy are constrained. Current mobile devices have only 6-12 GB of DRAM, so models must be small enough to fit within these memory limits while remaining efficient. For example, a 7-billion-parameter model consumes so much energy that it drains a device in less than two hours, whereas a 350-million-parameter model can run for an entire day. Decoding speed also improves significantly. A small model can reach 50 tokens/s, compared to 3∼6 tokens/s for the state-of-the-art iPhone app MLC Chat running the LLaMA 7B model. Recently, Apple introduced a 3B on-device model with 100s of adapters, each fine-tuned for a specific task.
This paper proposes and evaluates LLMs with fewer than 1 billion parameters. Key contributions include:
FFN Enhancement: Replacing traditional FFNs with SwiGLU improves average performance across zero-shot reasoning tasks.
Depth vs. Width: For sub-billion LLMs, depth (number of layers) matters more than width (number of neurons per layer). Experiments show that deeper, thinner models consistently outperform shallower, wider configurations across a range of tasks.
Parameter Efficiency: In sub-billion parameter models, the input and output embedding layers make up a significant portion of the total parameters. Input-output embedding sharing uses the same weight matrix for both the input embedding and the output fully connected layer. This reduces the parameter count with little to no loss in accuracy, which benefits compact models intended for mobile devices with constrained memory. MiniCPM also used a similar strategy for their 2B models.
Computational Efficiency: Conventional small LLMs typically use as many key-value heads as query heads. This paper instead uses grouped query attention (GQA), where several query heads share one key-value head, to reduce redundancy in the key-value heads. This lowers computational overhead while preserving competitive accuracy.
Memory Efficiency: To deepen the model without increasing storage needs, the paper uses immediate block-wise layer sharing. Adjacent transformer blocks share their weights instead of each storing a separate copy, so each block is computed twice in a row. This improves performance at no extra parameter cost, which is especially valuable in memory-constrained environments such as mobile devices.
Performance Benchmarking: Introduces MobileLLM models of different sizes, achieving state-of-the-art results in zero-shot tasks and downstream applications, comparable to larger models like LLaMA-v2 7B in certain tasks.
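Several of these design choices show up directly in a parameter count. The sketch below counts the weights of a decoder-only transformer with a SwiGLU FFN (three matrices), optional grouped query attention and optional input-output embedding sharing, ignoring biases and norms. The configuration in the usage lines is hypothetical. It is only meant to land in the ~125M range and is not taken from the paper. The last function lists the order in which weight blocks run under immediate block-wise sharing, where each block executes twice in a row.

```python
def transformer_params(vocab, dim, n_layers, n_heads, n_kv_heads, ffn_hidden, tie_embeddings):
    """Approximate weight count of a decoder-only transformer (biases and norms ignored)."""
    head_dim = dim // n_heads
    attention = (
        dim * dim                            # query projection
        + 2 * dim * n_kv_heads * head_dim    # key and value projections (GQA if n_kv_heads < n_heads)
        + dim * dim                          # output projection
    )
    ffn = 3 * dim * ffn_hidden               # SwiGLU: gate, up and down projections
    embeddings = vocab * dim * (1 if tie_embeddings else 2)  # tied: one matrix for input and output
    return embeddings + n_layers * (attention + ffn)


def immediate_blockwise_schedule(n_unique_blocks, repeats=2):
    """Order in which weight blocks execute when each block is repeated immediately."""
    return [block for block in range(n_unique_blocks) for _ in range(repeats)]


# Hypothetical small configuration, for illustration only.
config = dict(vocab=32_000, dim=576, n_layers=30, n_heads=9, ffn_hidden=1536)
print(transformer_params(n_kv_heads=3, tie_embeddings=True, **config))   # GQA + tied
print(transformer_params(n_kv_heads=3, tie_embeddings=False, **config))  # GQA, untied
print(transformer_params(n_kv_heads=9, tie_embeddings=True, **config))   # full multi-head, tied
print(immediate_blockwise_schedule(3))
```

In this made-up configuration, sharing embeddings saves 18.4M parameters, about 13% of the untied model. Using 3 key-value heads instead of 9 saves another 13.3M. Repeating each block doubles the executed depth from 30 to 60 layers with no extra parameters.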
MobileLLM-125M outperforms previous sub-billion-parameter models on zero-shot common sense reasoning tasks. Taking it further, MobileLLM-LS-125M adds layer sharing, which significantly boosts performance and rivals larger models. These advancements make MobileLLM models particularly effective in practical applications such as chat and API calling, with results comparable to much larger counterparts.
At 125M parameters, the weights would take 250MB of RAM at 16-bit precision. Even at 32-bit, they would take only 500MB, which is much less than what Chrome consumes on a phone :) For reference, at I/O 2019 Google announced that they had shrunk the speech recognition model to 0.5GB, enabling Google Assistant to run locally on device. So it is quite plausible that something like MobileLLM could be always on. On that note, I'd highly urge you to check out the MiniCPM 2.0 blog, which goes into detail on building small LMs and the science behind them. It is one of the most detailed write-ups I've seen of late.
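The memory arithmetic for MobileLLM's weights is just parameters × bits / 8. The tiny helper below computes it for the weights only, ignoring activations and the KV cache, with MB meaning 10^6 bytes. It also shows why a 7B model at 16-bit does not fit in a phone's 6-12 GB of DRAM.

```python
def weight_memory_mb(n_params, bits):
    """Memory for the weights alone, in MB (10**6 bytes); activations and KV cache ignored."""
    return n_params * bits / 8 / 1e6


for bits in (32, 16, 8, 4):
    print(f"125M params @ {bits}-bit: {weight_memory_mb(125e6, bits):.1f} MB")
print(f"7B params @ 16-bit: {weight_memory_mb(7e9, 16):.0f} MB")
```

A 7B model at 16-bit needs 14,000 MB (14 GB) for its weights alone, while the 125M model needs 250 MB.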
Data curation via joint example selection further accelerates multimodal learning
The quality of data is crucial for the success of large-scale pretraining in fields like language, vision, and multimodal modeling. Well-curated datasets can achieve stellar performance with less data, but manual data curation is hard and costly to scale. So, what's the solution? You guessed it: back to the model. Model-based data curation uses the model's own features to pick high-quality data, potentially boosting pretraining efficiency. Traditional methods score individual data points, but the composition of a batch (how data points are grouped) also matters significantly. Batches with “hard negatives” (similar data points with different labels) offer a stronger learning signal than easier examples. The authors propose extending model-based criteria to select entire batches, hypothesizing that this will speed up learning more than selecting individual examples.
Imagine you have a mountain of data and want to select the best pieces to train your model, the "learner." Instead of picking data randomly, you can use smart techniques to choose the most useful pieces. Here are some strategies:
Hard Learner Method: This method targets the most challenging data, selecting batches that the learner struggles with (high loss). It works well for small, clean datasets but can backfire on larger, noisier datasets by focusing too much on noisy or irrelevant data. Basically, the higher the learner's loss on the data, the higher its score:
\(s^{hard}(\mathcal{B}|\theta) = l(\mathcal{B}|\theta)\)
Easy Reference Method: This method does the opposite, choosing data that a pretrained reference model (already trained on similar data) finds easy. The intuition is that data that is easy for a reference model is probably learnable for other models as well. However, it may not always align with the learner's needs, and it can be computationally heavy for large datasets. Here θ* denotes the parameters of the reference model:
\(s^{easy}(\mathcal{B}|\theta^*) = - l(\mathcal{B}|\theta^*)\)
Learnability Scoring: This method strikes a balance by combining the two approaches. It selects data that is still hard for the learner but easy for the reference model, i.e. both challenging and learnable. This approach has been shown to speed up learning even on large datasets:
\(s^{learn}(\mathcal{B}|\theta,\theta^*) = s^{hard}(\mathcal{B}|\theta) + s^{easy}(\mathcal{B}|\theta^*) = l(\mathcal{B}|\theta) - l(\mathcal{B}|\theta^*)\)
In multimodal learning, datasets include various data types, like images paired with text, which is common for models that need to understand and generate both text and images. Contrastive Learning trains models to align paired examples and distinguish them from unpaired ones.
So how do we sample data with JEST? Given a large super-batch of candidate examples, we sample examples in proportion to the learnability score above. The more learnable the data, the more likely it is to be picked, so little time is wasted on unlearnable data. The fraction of the super-batch that is discarded as not useful for learning is called the filtering ratio.
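Here is a minimal sketch of learnability-based filtering for independent examples. Learnability is the learner's loss minus the reference model's loss. A (1 − filtering ratio) fraction of the super-batch is kept by sampling without replacement. Learnability can be negative, so this sketch turns scores into probabilities with a softmax. That choice, the function names and the per-example (rather than per-batch) scoring are our assumptions for illustration.

```python
import math
import random


def learnability(learner_losses, reference_losses):
    """s_learn = l(x | theta) - l(x | theta*), one score per candidate example."""
    return [l_learn - l_ref for l_learn, l_ref in zip(learner_losses, reference_losses)]


def sample_by_learnability(scores, filtering_ratio, rng):
    """Keep round((1 - filtering_ratio) * N) examples, sampled without replacement.

    Assumption: probabilities are a softmax over scores, since raw scores can be negative.
    """
    n_keep = round(len(scores) * (1 - filtering_ratio))
    remaining = list(range(len(scores)))
    chosen = []
    for _ in range(n_keep):
        top = max(scores[i] for i in remaining)  # subtract the max for numerical stability
        weights = [math.exp(scores[i] - top) for i in remaining]
        pick = rng.choices(remaining, weights=weights, k=1)[0]
        chosen.append(pick)
        remaining.remove(pick)
    return chosen


# Made-up losses for a super-batch of 6 examples.
scores = learnability([2.0, 0.1, 3.0, 0.2, 2.5, 0.3], [0.5, 0.1, 0.4, 2.0, 0.3, 0.2])
print(sample_by_learnability(scores, filtering_ratio=0.5, rng=random.Random(0)))
```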
JEST’s method of sampling sub-matrices based on summed learnability significantly improves the learnability of batches with fewer iterations compared to brute-force Gibbs sampling. It efficiently selects highly learnable sub-batches, accelerating training in multimodal learning. For instance, using filtering ratios of 50%, 80%, and 90%, JEST achieves the final performance of a 3B-example baseline after only 2B, 1B, and 0.67B examples, respectively. JEST also improves final performance by up to 6% when filtering 90% of data, outperforming independent example selection methods.
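JEST selects sub-batches rather than single examples. Assuming a sigmoid-style contrastive loss, the batch loss is a sum over all image-text pairs. A sub-batch's learnability is then the sum of the corresponding sub-matrix of a pairwise learnability matrix. The simplified sketch below grows the sub-batch chunk by chunk. It scores each candidate by how much it would add to that sum, given the examples already chosen. It illustrates the idea only: it is not the paper's exact sampler, the softmax sampling is our assumption, and the pairwise scores in the usage example are made up.

```python
import math
import random


def subbatch_score(pair_scores, idx):
    """Summed learnability of the sub-matrix of pair_scores indexed by idx."""
    return sum(pair_scores[i][j] for i in idx for j in idx)


def joint_select(pair_scores, batch_size, chunk_size, rng):
    """Grow a sub-batch chunk by chunk (simplified sketch of joint example selection).

    pair_scores[i][j]: learner loss minus reference loss for pair (image i, text j).
    Within a chunk, candidates are sampled by their gain given earlier chunks.
    Assumes batch_size <= len(pair_scores).
    """
    selected = []
    remaining = list(range(len(pair_scores)))
    while len(selected) < batch_size:
        # Gain of adding c: its own diagonal term plus its interactions with the selection.
        gains = {
            c: pair_scores[c][c] + sum(pair_scores[c][j] + pair_scores[j][c] for j in selected)
            for c in remaining
        }
        for _ in range(min(chunk_size, batch_size - len(selected))):
            top = max(gains[c] for c in remaining)
            weights = [math.exp(gains[c] - top) for c in remaining]  # softmax (assumption)
            pick = rng.choices(remaining, weights=weights, k=1)[0]
            selected.append(pick)
            remaining.remove(pick)
    return selected


# Made-up pairwise scores: examples 0 and 1 are mutually learnable, 2 and 3 are not.
pairs = [[10.0 if i < 2 and j < 2 else -10.0 for j in range(4)] for i in range(4)]
chosen = joint_select(pairs, batch_size=2, chunk_size=1, rng=random.Random(0))
print(sorted(chosen), subbatch_score(pairs, chosen))
```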
Flexi-JEST, a compute-efficient variant, uses multi-resolution training to cut computational overhead. It scores half of the data points in a batch with a larger patch size of 32 (low resolution) and the other half with a smaller patch size of 16 (high resolution). Flexi-JEST slightly reduces per-iteration performance compared to full-resolution JEST. However, it is much faster and more FLOP-efficient, achieving similar average performance with 9.9× fewer FLOPs.
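Why is a patch size of 32 cheaper? A ViT sees (image size / patch size)² tokens, so doubling the patch size cuts the token count by 4x. The sketch below assumes 224×224 inputs, which is our assumption and is not stated above. It compares token counts, a rough proxy for compute that ignores attention's quadratic cost.

```python
def num_patch_tokens(image_size, patch_size):
    """Number of patch tokens a ViT produces for a square image."""
    return (image_size // patch_size) ** 2


IMAGE_SIZE = 224  # assumption for illustration
high = num_patch_tokens(IMAGE_SIZE, 16)  # high resolution
low = num_patch_tokens(IMAGE_SIZE, 32)   # low resolution
print(high, low, (high + low) / 2)       # half the batch at each resolution
```

Under this assumption, patch size 16 gives 196 tokens per image and patch size 32 gives 49. A half-and-half batch therefore averages 122.5 tokens per image.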
This work mainly targets multimodal data, where two different types of data are brought into a shared representation, as in vision-language models. Even so, it has tremendous implications. Data quality matters a lot, and automating the assessment and filtering of data can go a long way, especially given speedups of around an order of magnitude (about 10x).
Applied LLMs: Learnings
This blog by the awesome folks details what to do when you're trying to build with LLMs for your use case. It starts with prompting and a range of practical insights about it, then gradually moves on to when to fine-tune an LLM. It also provides references to excellent write-ups by other people in the same space. Please carve out some time to read it through if you want to avoid common pitfalls.