AI Unplugged 11: LoRA vs FFT, Multi Token Prediction, LinkedIn's AI assistant
Insights over Information
Table of Contents
LoRA learns less and forgets less
Multi Token prediction improves LLMs
Building AI product ft LinkedIn
LlaMA3 implemented from scratch
Mistral v3 and Phi3 Medium released
LoRA Learns Less and Forgets Less
LoRA, or Low Rank Adaptation of LLMs, has been a godsend for a lot of people trying out LLMs and adapting them to their specific needs. It allows one to fine tune a model with very little compute and memory overhead. It has gained immense popularity, both in the world of LLMs and in the world of image generators like Stable Diffusion.
But one question remained unanswered: how does LoRA compare to full fine tuning? Sure, the original LoRA paper does compare the method to full fine tuning, but most of its comparisons used models like RoBERTa, DeBERTa and GPT-2. Those models are ancient by today’s standards and pretty much no one uses them anymore. With larger models being mainstream these days, it’s fair to re-evaluate the stance. People have often found that LoRA has a ceiling when it comes to how much you can squeeze out of it.
The whole backbone of LoRA is the hypothesis that, upon fine tuning, the change in weights is low rank. So why not force the update to be low rank, i.e. train only a small number of additional parameters? There were many parameter efficient fine tuning methods before LoRA, but where LoRA shines is that you can always use the base model separately and, if needed, merge the adapters into it for faster inference, which offers tremendous flexibility. All well and good, but what if the updates aren’t really low rank? Then the whole theory falls flat. This is exactly what they find.
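To make the merge-or-keep-separate flexibility concrete, here is a minimal NumPy sketch of a LoRA-adapted linear layer. It assumes the common parameterisation W + (alpha / r) * B @ A with B zero-initialised; the sizes, `alpha` and the helper names `lora_forward` / `merge_lora` are illustrative, not taken from any particular library.

```python
import numpy as np

def lora_forward(x, W, A, B, scale):
    """Unmerged LoRA: frozen base path plus a low-rank update path.
    x: (batch, d_in), W: (d_out, d_in), A: (r, d_in), B: (d_out, r)."""
    return x @ W.T + scale * (x @ A.T) @ B.T

def merge_lora(W, A, B, scale):
    """Fold the adapter into one dense matrix so inference costs the same as the base model."""
    return W + scale * (B @ A)

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 64, 64, 8, 16
scale = alpha / r

W = rng.standard_normal((d_out, d_in))      # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01   # trainable
B = np.zeros((d_out, r))                    # trainable, zero-init so training starts at the base model
x = rng.standard_normal((3, d_in))

print(np.allclose(lora_forward(x, W, A, B, scale), x @ W.T))  # True: untrained adapter is a no-op

# Pretend training has updated B, then check both inference paths agree.
B = rng.standard_normal((d_out, r))
print(np.allclose(lora_forward(x, W, A, B, scale), x @ merge_lora(W, A, B, scale).T))  # True

# Trainable parameters for a 4096 x 4096 matrix at rank 8, versus full fine tuning.
print(8 * (4096 + 4096), 4096 * 4096)  # 65536 16777216 (about 0.4%)
```

Keeping the layer unmerged lets you swap adapters on one shared base model; merging trades that flexibility for zero extra inference cost.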
In fact, a few months ago, I was experimenting with the same thing: checking whether the low intrinsic dimension hypothesis really holds after all. Here’s the thread.
Similar to my findings, the authors of the paper look at the delta between the fine tuned weights and the base model weights, and find that it is far from low rank. For the delta, you need roughly 2000 of the 4096 singular directions (i.e. a rank of about 2000) to explain 90% of the variance, which is in no way low rank. But one good thing arising from LoRA is that, because the update is forced to be low rank, it acts as a good regulariser.
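The “rank needed to explain 90% of the variance” measurement can be sketched as follows. It is an assumption on my part that this matches the paper’s exact procedure: here “variance” is taken as the normalised squared singular values of the weight delta, and both matrices are synthetic stand-ins rather than real fine-tuned weights.

```python
import numpy as np

def rank_for_variance(delta, frac=0.9):
    """Smallest k such that the top-k singular directions of `delta`
    carry at least `frac` of its squared-singular-value energy."""
    s = np.linalg.svd(delta, compute_uv=False)   # singular values, descending
    energy = np.cumsum(s**2) / np.sum(s**2)      # cumulative explained fraction
    k = int(np.searchsorted(energy, frac)) + 1   # first index reaching frac, as a count
    return min(k, len(s))

rng = np.random.default_rng(0)
n = 256
low_rank_delta = rng.standard_normal((n, 8)) @ rng.standard_normal((8, n))  # a "LoRA-like" update
dense_delta = rng.standard_normal((n, n))                                    # an unstructured update

print(rank_for_variance(low_rank_delta))  # at most 8
print(rank_for_variance(dense_delta))     # a sizeable fraction of 256
```

A delta that needs a rank near half its dimension, like the ~2000/4096 above, looks much more like the second case than the first.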
The authors take a couple of datasets in the fields of Math and Code and evaluate two training regimes: Continued Pre-training and Instruction Fine Tuning. The difference is basically the size of the data used: billions of tokens for the former and on the order of millions for the latter.
As you can see in the images above, LoRA tends to preserve accuracy on HellaSwag, ARC and WinoGrande (used as a reference for the pre-existing knowledge) but underperforms full fine tuning on the target metrics. Full fine tuning, on the other hand, regresses below the source baselines (dotted lines) on these benchmarks.
People have already observed empirically that LoRA is mostly useful when you want to change the style of generation or other aesthetic aspects, and the general advice was to avoid LoRA if you want to impart new knowledge to a model. In the image above, you can also see that token diversity stays largely intact with LoRA. Full fine tuning exploits the distribution and tends to produce the same samples with higher frequency, while LoRA is a good middle ground between the base model (roughly equal likelihood for almost all generations) and full fine tuning (which concentrates its predictions on a few unique samples).
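One simple way to quantify the diversity of generations is to sample several completions for the same prompt and look at how many are unique and how spread out they are (entropy). This is an illustrative metric assumed for the sketch, not necessarily the one the paper uses, and the sample lists are made up to mimic the three regimes described above.

```python
from collections import Counter
import math

def diversity_stats(samples):
    """Return (number of unique samples, Shannon entropy in nats) of a list of generations."""
    counts = Counter(samples)
    total = len(samples)
    entropy = -sum((c / total) * math.log(c / total) for c in counts.values())
    return len(counts), entropy

# Made-up generations for one prompt.
base_like = ["a", "b", "c", "d", "e", "f", "g", "h"]     # spread across many outputs
lora_like = ["a", "a", "b", "b", "c", "c", "d", "e"]     # somewhat concentrated
full_ft_like = ["a", "a", "a", "a", "a", "a", "b", "b"]  # collapsed onto a few outputs

for name, s in [("base", base_like), ("LoRA", lora_like), ("full FT", full_ft_like)]:
    print(name, diversity_stats(s))
```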
Daniel from Unsloth shared some great takes on the paper, along with scope for improvement in how it evaluates LoRA. Note that DoRA, VeRA etc. are improvements on LoRA.
Better LLMs via Multi token Prediction
Ever wondered why we only predict one token at a time? Outputting one token at a time does make sense, but why not at least try predicting multiple tokens at once?
The challenge for a long time was to condense all the information required to predict multiple tokens into a single network. Even then, there is another problem: how do you define an order on the set of generated tokens? But in image generation, you don’t generate one pixel at a time; you often generate the whole image, or patches of it, at once. An image of a meagre 224x224x3 is 150,528 numbers, whereas a token with a hidden dimension of 4096 is just 4096 numbers. So it should make sense to be able to try predicting multiple tokens at once, emphasis on try.
Now how do you approach this problem? The trunk (as they call it) of the model, which is basically the model excluding the embeddings and lm_head, already has the knowledge to predict the next token given a series of previous tokens. So it makes sense to use that as the starting point: equip the trunk with more heads and it can output multiple tokens. Remember Ravana: one body, ten heads. Same concept.
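Here is a toy NumPy sketch of the “one trunk, many heads” layout. Assumptions: in the paper each head is a full transformer layer, while here each head is a single linear map followed by tanh; the unembedding (lm_head) is assumed to be shared across heads; all sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model, n_future = 100, 32, 4  # n_future = number of tokens predicted per position

# Hypothetical stand-ins for the per-head layers (a transformer layer each in the paper).
head_weights = [rng.standard_normal((d_model, d_model)) * 0.1 for _ in range(n_future)]
unembed = rng.standard_normal((d_model, vocab)) * 0.1  # assumed shared lm_head

def multi_token_logits(trunk_hidden):
    """trunk_hidden: (seq_len, d_model) output of the shared trunk.
    Returns (n_future, seq_len, vocab); head i at position t scores token t + 1 + i."""
    per_head = []
    for W_head in head_weights:
        h = np.tanh(trunk_hidden @ W_head)  # each head reads the same trunk state independently
        per_head.append(h @ unembed)        # shared unembedding to vocabulary logits
    return np.stack(per_head)

trunk_hidden = rng.standard_normal((10, d_model))  # pretend trunk output for 10 positions
logits = multi_token_logits(trunk_hidden)
print(logits.shape)  # (4, 10, 100)
```

Only the heads are duplicated, so the extra cost over a standard model is small compared with running the whole trunk n times.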
Fine, now the heads output different tokens; assume we define an ordering on them (head i predicts the token i positions ahead). How do we even go about calculating the loss? We have a set of vectors output by the model and another set of ground-truth tokens to compare them against. The easiest way is to treat each head separately and add up their losses.
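Written out, the objective being described looks like this (my reconstruction of the paper’s formulation, with x_1, …, x_t the input tokens and n the number of predicted tokens). The second line relies on treating the n future tokens as conditionally independent given the trunk’s view of x_1 … x_t:

```latex
\begin{aligned}
L_n &= -\sum_t \log P_\theta\left(x_{t+n}, \dots, x_{t+1} \mid x_t, \dots, x_1\right) \\
    &= -\sum_t \sum_{i=1}^{n} \log P_\theta\left(x_{t+i} \mid x_t, \dots, x_1\right)
\end{aligned}
```

With n = 1 this reduces to the standard next-token loss.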
So the first term is the probability of generating tokens t+1 to t+n given tokens 1 to t, i.e. generating n tokens given t tokens. All n heads read the same trunk representation of tokens 1 to t, and no head sees another head’s output, so the n future tokens are treated as conditionally independent given that representation. Hence the joint probability can be expressed as a product of the individual probabilities, and the log of that product becomes a sum of per-head losses.
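The same sum of per-head losses as a NumPy sketch. It assumes logits shaped (heads, positions, vocab), with head i at position t scoring token t + 1 + i, and simply skips positions whose target would fall past the end of the sequence; `multi_token_loss` is a hypothetical helper name.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def multi_token_loss(logits, tokens):
    """logits: (n, T, V), tokens: (T,) ground-truth ids.
    Returns the sum over heads of the cross-entropy of head i against tokens[t + 1 + i]."""
    n, T, V = logits.shape
    total = 0.0
    for i in range(n):
        valid = T - 1 - i  # positions that still have a target i + 1 steps ahead
        if valid <= 0:
            continue
        lp = log_softmax(logits[i, :valid])          # (valid, V)
        targets = tokens[1 + i : 1 + i + valid]      # (valid,)
        total += -lp[np.arange(valid), targets].sum()
    return total

tokens = np.array([3, 1, 4, 1, 5])
uniform = np.zeros((2, len(tokens), 10))  # 2 heads, vocab of 10, all-zero logits
print(multi_token_loss(uniform, tokens))  # (4 + 3) * log(10)
```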
So all well and good, but what do we gain from predicting multiple tokens at once? Surprisingly, the answer is better performance on downstream benchmarks. Quite fascinating. Maybe the model learns to predict each token while keeping its future response in mind, a kind of mental note, which nudges it towards the correct response. Also, during training the model is pretty much incentivised to think short term and predict only the next token: even if a prediction diverges at some token, the following steps use the ground truth as input rather than the model’s previous prediction (teacher forcing). This contrasts with inference, where once we predict a wrong token, we have to live with it. Multi token prediction, in a way, incentivises the model to look further ahead.
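The gap between training and inference described above is easy to see with a toy predictor. The “model” here is a hypothetical hand-written rule that makes exactly one mistake; under teacher forcing the mistake stays isolated, while in free-running generation it feeds into every later step.

```python
def toy_model(prefix):
    """Hypothetical predictor for the sequence 0, 1, 2, ...: it outputs last + 1,
    except that it wrongly outputs 7 whenever the last token is 2."""
    last = prefix[-1]
    return 7 if last == 2 else last + 1

def teacher_forced(model, seq):
    """Training-style: each step conditions on the ground-truth prefix."""
    return [model(seq[: t + 1]) for t in range(len(seq) - 1)]

def free_running(model, seq):
    """Inference-style: each step conditions on the model's own previous outputs."""
    out = [seq[0]]
    for _ in range(len(seq) - 1):
        out.append(model(out))
    return out[1:]

truth = [0, 1, 2, 3, 4, 5]
targets = truth[1:]
for name, preds in [("teacher forced", teacher_forced(toy_model, truth)),
                    ("free running", free_running(toy_model, truth))]:
    errors = sum(p != t for p, t in zip(preds, targets))
    print(name, preds, "errors:", errors)
```

The teacher-forced run makes 1 error while the free-running one makes 3, even though it is the same “model”.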
Note that the improvement grows with model size. That is pretty much expected, since small models probably don’t have the capacity to store all this information (needed to predict multiple tokens) in their weights. This is quite possibly also why prior work almost always stuck to predicting one token at a time, as I alluded to in the beginning.
Multi token prediction also has an orthogonal advantage. If you can generate multiple tokens in a single go, you can perform speculative decoding without needing a separate draft or assistant model. They observe that, on average, about 2.5 of the 3 tokens suggested by the extra heads are accepted (on code). That is a pretty high number, and it translates to roughly a 2.7-3x speedup in generation.
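A sketch of the greedy acceptance rule behind this kind of self-speculative decoding. Assumptions: greedy decoding only (sampling needs a probabilistic acceptance rule); `draft` holds the tokens proposed by the extra heads; `verified` holds the next-token head’s greedy predictions at those same positions, obtained from one forward pass over the drafted tokens, plus one more for the position after the last draft. `verify_greedy` is a hypothetical helper name.

```python
def verify_greedy(draft, verified):
    """Return the tokens to append: the agreeing prefix of `draft`,
    followed by one token taken from the verifier."""
    assert len(verified) == len(draft) + 1
    out = []
    for d, v in zip(draft, verified):
        if d != v:
            out.append(v)      # first disagreement: keep the verifier's token and stop
            return out
        out.append(d)          # draft matches what the model would have produced anyway
    out.append(verified[-1])   # every draft accepted: the pass yields one bonus token
    return out

print(verify_greedy([5, 6, 7], [5, 6, 7, 8]))  # [5, 6, 7, 8]
print(verify_greedy([5, 9, 7], [5, 6, 7, 8]))  # [5, 6]
```

Each verification pass yields between 1 and len(draft) + 1 tokens, so when most drafts are accepted, most passes produce several tokens instead of one.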
Another interesting finding is that this generalises to non-text vocabularies. You might be wondering what a non-text vocab is. Well, some works like ByteGPT (which we covered in one of our previous editions) and MambaByte explore modelling raw bytes instead of text tokens. They also compared different values of n, the number of tokens predicted at a single time step, and 4 seems to be the best. There are some other interesting experiments comparing multi token prediction with single token prediction, with pretty good reasoning as to why multi token prediction is behind that performance gap.
So all in all, an interesting work. Speeding up inference with no overhead in training time, while also improving performance on downstream tasks, is basically a free lunch. They experimented with predicting 1, 2 and 4 tokens at once. It’d be interesting to extend this further and try something beyond 8, though that might be overkill. Compressing information worth 8 tokens into the weights can be tricky. In English, words like “the”, prepositions and verb forms occur frequently and aren’t really hard to predict. But once you get to 8 or maybe 16 tokens, you are in the territory of a full sentence or more, where you have to predict more nouns and adjectives while staying coherent and consistent, which needs a lot more prowess.
LinkedIn: Musings of building GenAI product
All the research in AI is amazing, but if it doesn’t turn into a product, it doesn’t add much value to the end user, right? So it helps to know the struggles and journey of integrating an LLM into one’s operations and products. LinkedIn shared their story of building a GenAI product to enhance user experience.
LinkedIn’s data is forever changing. You can’t just train a model with a knowledge cut-off and expect it to work, right? So what do you do in that case? You augment the existing model with context from the current data. But that data has to be fetched dynamically and semantically, i.e. using natural language search. The best way to achieve this is to use tools. For example, when you ask GPT-4, Bing or HuggingChat a question, they have the ability to search for things, but it’s up to the model to decide what to query and how to interpret the response.
Search is one example of a tool. GPT-4 (on ChatGPT Plus) also has access to a code execution environment, so it has to know when to run code. The user wouldn’t explicitly prompt for that; they ask something like “Draw me a histogram” or “Perform data analysis on this data”. The model has to decide whether to use a tool for the job and, if so, which tool to use, with what parameters, etc.
Once the user asks a question, there’s a step that asks for more info if needed, and then comes the crucial part: tool selection and usage. This is basically Retrieval Augmented Generation (RAG), where retrieval is from sources like your APIs or tools, followed by generation.
The approach LinkedIn took was to freeze a 3-step pipeline. It uses a small model to retrieve stuff (basically embedding and similarity search where needed) and a bigger model for generation, just like any other RAG app. They mention injecting response examples into the prompt, so few-shot prompting. For evaluation, they evaluate the individual steps of the pipeline separately.
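The retrieval half of such a pipeline, plus few-shot example injection, can be sketched as below. Everything here is a hypothetical stand-in: the embeddings are toy vectors rather than outputs of a real embedding model, and `build_prompt` is an illustrative prompt layout, not LinkedIn’s.

```python
import numpy as np

def top_k_cosine(query_vec, doc_vecs, k=2):
    """Indices and cosine scores of the k documents most similar to the query."""
    q = query_vec / np.linalg.norm(query_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    scores = d @ q
    order = np.argsort(-scores)[:k]
    return order, scores[order]

def build_prompt(question, docs, examples):
    """Retrieved context + few-shot response examples + the user question."""
    context = "\n".join(f"- {doc}" for doc in docs)
    shots = "\n\n".join(f"Q: {q}\nA: {a}" for q, a in examples)
    return f"Context:\n{context}\n\nExamples:\n{shots}\n\nQ: {question}\nA:"

docs = ["Job: ML engineer, Bangalore", "Post: tips for interviews", "Job: data scientist, remote"]
doc_vecs = np.array([[0.9, 0.1, 0.0], [0.0, 0.2, 0.9], [0.8, 0.3, 0.1]])  # toy embeddings
query_vec = np.array([1.0, 0.2, 0.0])                                      # toy query embedding

idx, scores = top_k_cosine(query_vec, doc_vecs, k=2)
prompt = build_prompt(
    "Which ML jobs match my profile?",
    [docs[i] for i in idx],
    [("Am I a good fit for this role?", "You match 3 of the 4 listed requirements.")],
)
print(prompt)
```

Splitting retrieval and generation like this is also what makes it possible to evaluate each step on its own.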
Now the hard part: evaluation. The model has to respond in a helpful yet empathetic way. To understand model performance in real time, they monitor and evaluate up to 500 conversations daily for hallucination, style, coherence etc. They’re now working on automated evaluation using models (probably something big like GPT-4 or GPT-5).
Another task is to properly expose the API specs so that the model understands what is available to it for querying and uses it as needed. So they had to write schemas with human/AI-readable descriptions and clear input and output definitions. They mention using YAML for API parameter generation, as it is less verbose and hence uses fewer tokens than JSON. YAML is also a lot easier for LLMs to handle, since it involves less fiddling with quotes (") and expressions. To tackle cases where the model generates invalid YAML (~10% of the time), they built an in-house YAML parser that corrects the mistakes in code, bringing the error rate down to an incredible ~0.01%. Prompting the LLM to fix the output would also work, but would add significant delay to response times.
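A quick way to see the verbosity difference: serialise the same (hypothetical) tool-call parameters both ways and compare. Character count is only a rough proxy for token count, and the tiny `to_yaml` emitter below is an illustrative sketch for nested dicts of plain scalars; it does not quote strings that would need quoting in real YAML.

```python
import json

def to_yaml(obj, indent=0):
    """Minimal block-style YAML emitter for nested dicts of simple scalars (illustrative only)."""
    pad = "  " * indent
    lines = []
    for key, value in obj.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.append(to_yaml(value, indent + 1))
        elif isinstance(value, bool):
            lines.append(f"{pad}{key}: {'true' if value else 'false'}")
        else:
            lines.append(f"{pad}{key}: {value}")
    return "\n".join(lines)

# Hypothetical parameters an LLM might generate for a job-search tool.
params = {"tool": "search_jobs", "args": {"title": "ML engineer", "location": "Bangalore", "remote": True}}

as_json = json.dumps(params)
as_yaml = to_yaml(params)
print(as_yaml)
print(len(as_json), as_json.count('"'))  # 98 16
print(len(as_yaml), as_yaml.count('"'))  # 81 0
```

Besides being shorter, the YAML version has no quotes or braces to balance, which is one plausible reason it is easier for a model to emit correctly.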
Currently, they’re fine tuning the models to better fit their data and use case. This is sometimes necessary to improve your RAG once you hit a ceiling with prompting and out-of-the-box tool usage.
Fittingly, they asked their LLM to summarise the blog and provide takeaways.
LlaMA3 re-implemented from scratch
This is an excellent repo with an awesome readme that walks through each of the steps involved in a llama decode step. No transformers library, no fancy stuff. All the layers, and all the computation that goes into a layer, are split according to their logical constructs, down to splitting the attention heads and RoPE embeddings. So if you are just getting into LLMs, have read the paper Attention Is All You Need and are wondering how that translates to current-day models, this is an excellent resource. You can play around with each operation, see how things change and get a great understanding of everything.
Note that the repo uses the original LlaMA-3 weights, which is great news since we have a baseline to compare things against. Do take a look, dive in and tweak lol.
Mistral v3 and Phi3 Medium released
So yet again, mistral.ai casually dropped updates to their models. It started with v0.1, then came v0.2 of the instruct model, and now there’s v0.3 of both the base and instruct models.
There doesn’t seem to be a big change in architecture. The only difference I see is that the vocabulary went up from 32000 to 32768 tokens. The extended vocab now includes special tokens like [INST] and [TOOL_CALLS], so maybe the models were further trained on tool usage. After these 5-6 special tokens, the remaining ~760 of the 768 new entries are basically [control_xyz] tokens, which are probably placeholders.
Also, we talked about the Phi-3 family of models a couple of weeks ago. Back then, only Phi-3-mini was released. Now they’ve released Phi-3-medium and Phi-3-small. I mentioned there that Phi-3-Medium would come in at around 13.6B params. As it turns out, the released config differs from mine in a couple of minor ways: intermediate_size is 17920 instead of 16384, there are 10 kv_heads (1 for every 4 attention heads, my bad) and vocab_size is only about 32k, so the released model comes in at 13.96B. Similarly, for Phi-3-Small, intermediate_size is 14336 instead of 16384, taking the total param count to 7.39B.
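The 13.96B figure can be reproduced with a back-of-the-envelope count for a Llama-style decoder. Assumptions beyond this post (taken from the released Hugging Face config as I remember it): hidden_size 5120, 40 layers, 40 attention heads, vocab_size 32064, untied input/output embeddings, no biases, RMSNorm weights only, and a gated MLP with three d_model x d_ff projections. `llama_style_params` is a hypothetical helper.

```python
def llama_style_params(vocab, d_model, n_layers, n_heads, n_kv_heads, d_ff, tied_embeddings=False):
    """Approximate parameter count of a Llama-style decoder-only transformer."""
    head_dim = d_model // n_heads
    attn = 2 * d_model * d_model                  # q_proj and o_proj
    attn += 2 * d_model * n_kv_heads * head_dim   # k_proj and v_proj (grouped-query attention)
    mlp = 3 * d_model * d_ff                      # gate, up and down projections
    norms = 2 * d_model                           # two RMSNorm weight vectors per layer
    embed = vocab * d_model
    lm_head = 0 if tied_embeddings else vocab * d_model
    final_norm = d_model
    return n_layers * (attn + mlp + norms) + embed + lm_head + final_norm

phi3_medium = llama_style_params(vocab=32064, d_model=5120, n_layers=40,
                                 n_heads=40, n_kv_heads=10, d_ff=17920)
print(f"{phi3_medium:,}")  # 13,960,238,080
```

Under these assumptions, most of the count sits in the MLP (about 275M of the ~341M parameters per layer), which is why a change to intermediate_size moves the total noticeably.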