However, as LLMs grow in size and new capabilities emerge, we have a problem: hallucinations. The authors of the paper DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models propose a method to mitigate this problem.
This notebook has been automatically translated to make it accessible to more people; please let me know if you see any typos.
They propose a contrastive decoding approach, where the output probability of the next word is obtained from the difference in logits between an upper and a lower layer. By emphasizing the knowledge of the higher layers and de-emphasizing that of the lower layers, we can make LLMs more factual and thus reduce hallucinations.
The following figure shows this idea. While "Seattle" maintains a high probability in all layers, the probability of the correct answer, "Olympia", increases after the upper layers inject more factual knowledge. Contrasting the differences between the layers can therefore reveal the correct answer in this case.
An LLM consists of an embedding layer, several sequential transformer blocks, and an output layer. What the authors propose is to measure the output of each transformer block using the Jensen-Shannon divergence (JSD) against the output of the final layer.
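To make this measurement concrete, here is a minimal sketch of how such per-layer JSDs could be computed with torch. It assumes a Llama-style model (so that model.model.norm and model.lm_head exist) and reuses the final LM head as an early exit for the intermediate hidden states, which is a simplification of what the paper actually implements.

import torch
import torch.nn.functional as F

def layerwise_jsd(model, input_ids):
    # Forward pass keeping the hidden state after every layer
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # embeddings + one entry per transformer block
    # Next-token distribution of the final (mature) layer
    p_final = F.softmax(outputs.logits[:, -1, :].float(), dim=-1)
    jsds = []
    for h in hidden_states[1:-1]:  # intermediate layers only
        # Early exit: apply the final norm and the LM head to an intermediate state
        logits = model.lm_head(model.model.norm(h[:, -1, :]))
        q = F.softmax(logits.float(), dim=-1)
        m = 0.5 * (p_final + q)
        # JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)
        jsd = 0.5 * (F.kl_div(m.log(), p_final, reduction="batchmean")
                     + F.kl_div(m.log(), q, reduction="batchmean"))
        jsds.append(jsd.item())
    return jsds  # jsds[i] corresponds to layer i + 1

Computing these values for every token of a sentence would give a heatmap like the one in the figure below.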
The following figure shows this measurement at the output of each transformer block for an input sentence. Each column corresponds to one token of the phrase.
Two patterns can be observed:

- The first occurs when predicting named entities or important dates, such as "Wole Soyinka" and "1986", which require factual knowledge. The computed JSD remains extremely high in the upper layers. This pattern indicates that the model keeps changing its predictions in the later layers, potentially injecting more factual knowledge into them.
- The second occurs when predicting function words, such as "was", "the", "to", "in", and tokens copied from the input question, such as "first Nigerian" and "Nobel Prize". When these "easy" tokens are predicted, the JSD becomes very small from the intermediate layers onward. This indicates that the model has already decided which token to generate in the intermediate layers and keeps the output distributions almost unchanged in the upper layers. This finding is also consistent with the assumptions of early-exit LLMs (Schuster et al., 2022).
When the prediction of the next word requires factual knowledge, the LLM appears to change the predictions in the upper layers. Contrasting the layers before and after a sudden change may therefore amplify the knowledge emerging from the upper layers and make the model rely more on its internal factual knowledge. Moreover, this evolution of information seems to vary from token to token.
Their method requires accurately selecting the premature layer, the one containing plausible but less factual information, which is not always the same early layer. Therefore, they propose to select the premature layer dynamically, as shown in the following image.
![DoLa-figure3](https://maximofn.com/wp-content/uploads/2024/07/DoLa-figure3.webp)
To select the premature layer, they calculate the Jensen-Shannon divergence (JSD) between each intermediate layer and the final layer. The premature layer is the one with the highest JSD.
However, as this process can be a bit slow, they group several layers into buckets so that fewer calculations are needed.
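As a rough sketch of that dynamic selection, assuming we already have the per-layer JSDs from the function above, we can simply take the candidate layer with the highest divergence. The candidate bucket below is just an illustrative choice, not the exact partition used in the paper or in transformers.

def select_premature_layer(jsds, candidate_layers):
    # jsds[i] is the JSD between layer i + 1 and the final layer
    # The premature layer is the candidate that diverges most from the final layer
    return max(candidate_layers, key=lambda layer: jsds[layer - 1])

candidate_layers = range(16, 30, 2)  # e.g. a bucket of higher layers in a 32-layer model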
Now that we have the final layer (the mature layer) and the premature layer, we can contrast the predictions of both. To do this, they compute the log-probability of the next token in the mature layer and in the premature layer, and then subtract the premature layer's log-probability from the mature layer's, thus giving more weight to the knowledge of the mature layer.
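A minimal sketch of this contrast step, given the next-token logits of the mature and premature layers; the alpha cut-off is the adaptive plausibility constraint that the paper borrows from contrastive decoding, and 0.1 is just an illustrative value.

import torch.nn.functional as F

def dola_next_token(mature_logits, premature_logits, alpha=0.1):
    log_p_mature = F.log_softmax(mature_logits, dim=-1)
    log_p_premature = F.log_softmax(premature_logits, dim=-1)
    # Contrast: give more weight to what the mature layer knows over the premature one
    scores = log_p_mature - log_p_premature
    # Only keep tokens the mature layer already considers reasonably likely,
    # so that implausible tokens are not promoted by the subtraction
    p_mature = log_p_mature.exp()
    implausible = p_mature < alpha * p_mature.max(dim=-1, keepdim=True).values
    scores = scores.masked_fill(implausible, float("-inf"))
    return scores.argmax(dim=-1)  # greedy choice over the contrasted scores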
DoLa’s motivation is to de-emphasize lower-layer linguistic knowledge and amplify factual knowledge of the real world. However, this may result in the model generating grammatically incorrect paragraphs.
Empirically they have not observed such a problem, but they have found that the resulting DoLa distribution sometimes tends to repeat previously generated sentences, especially when generating long chain-of-thought reasoning sequences.
So they include the repetition penalty introduced in Keskar et al. (2019) with θ = 1.2 during decoding.
Let's see how to implement DoLa with Hugging Face's transformers library. For more information on how to implement DoLa with the transformers library, you can consult the following link.
First we log in to the Hub. We are going to use Llama 3 8B, and to use it you have to request access from Meta, so you need to be logged in when downloading it so that Hugging Face knows who is downloading it.
from huggingface_hub import notebook_login
notebook_login()
Now we instantiate the tokenizer and the model
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16  # use bfloat16 if the GPU supports it, otherwise float16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoints = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoints)
tokenizer.pad_token = tokenizer.eos_token  # the model has no pad token, so reuse the EOS token
model = AutoModelForCausalLM.from_pretrained(checkpoints, torch_dtype=compute_dtype, device_map="auto")
model.config.pad_token_id = model.config.eos_token_id
We assign a fixed seed value for the reproducibility of the example
set_seed(42)
Generate LLM input tokens
question = 'What does Darth Vader say to Luke in "The Empire Strikes Back"?'
text = f"Answer with a short answer.\n\nQuestion: {question}\n\nAnswer: "
inputs = tokenizer(text, return_tensors="pt").to(model.device)
We now generate the vanilla output, i.e., without applying DoLa
generate_kwargs={
"do_sample": False,
"max_new_tokens": 50,
"top_p": None,
"temperature": None
}
vanilla_output = model.generate(**inputs, **generate_kwargs)
print(tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0])
We see that the model knows there is a famous misquote, but it fails to produce the actual line
Now we apply DoLa
dola_high_output = model.generate(**inputs, **generate_kwargs, dola_layers='high', repetition_penalty=1.2)
print(tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0])
Now it does manage to give the correct line and mention the famous misquote
Let's do another test with a different example. I restart the notebook and use another model
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import torch
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoints = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(checkpoints)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoints, torch_dtype=compute_dtype, device_map="auto")
model.config.pad_token_id = model.config.eos_token_id
We assign a fixed seed value for the reproducibility of the example
set_seed(42)
We write a new question
text = "On what date was the Declaration of Independence officially signed?"
inputs = tokenizer(text, return_tensors="pt").to(device)
We generate the vanilla output
generate_kwargs={
"do_sample": False,
"max_new_tokens": 50,
"top_p": None,
"temperature": None
}
vanilla_output = model.generate(**inputs, **generate_kwargs)
print(tokenizer.batch_decode(vanilla_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0])
As we can see, it generates the wrong output: although the Declaration is celebrated on July 4, it was actually signed on August 2.
Let’s try now with DoLa
dola_high_output = model.generate(**inputs, **generate_kwargs, dola_layers='high', repetition_penalty=1.2)
print(tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0])
It still doesn’t generate a correct output, so let’s tell it to only contrast the final layer with layers 28 and 30.
dola_high_output = model.generate(**inputs, **generate_kwargs, dola_layers=[28,30], repetition_penalty=1.2)
print(tokenizer.batch_decode(dola_high_output[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0])
Now it does generate the correct answer