How to sample from language models

Ben Mann, May 24

Causal language models like GPT-2 are trained to predict the probability of the next word given some context.
For example, given “I ate a delicious hot ___”, the model may predict “dog” with 80% probability, “pancake” 5% probability, etc.
The cool thing about this structure is that these models can be used to generate sequences of arbitrary length.
I can give the model “I ate,” sample a token from the resulting distribution to get “I ate a”, then put that through the model again to get another distribution and resulting token.
Repeat as long as we like.
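The loop described above can be sketched in a few lines. This is only an illustration under loud assumptions: `TOY_MODEL`, `next_token_distribution` and `generate` are hypothetical names, and the lookup table stands in for a real model such as GPT-2, which would condition on the whole context rather than just the last token.

```python
import random

# Hypothetical toy "model": maps the previous token to a next-token distribution.
# A real causal LM conditions on the full context; this table is illustration only.
TOY_MODEL = {
    "I": {"ate": 1.0},
    "ate": {"a": 0.9, "the": 0.1},
    "a": {"delicious": 0.6, "hot": 0.4},
    "the": {"delicious": 0.5, "hot": 0.5},
    "delicious": {"hot": 0.7, "pancake": 0.3},
    "hot": {"dog": 0.8, "pancake": 0.2},
    "dog": {"<eos>": 1.0},
    "pancake": {"<eos>": 1.0},
}


def next_token_distribution(context):
    """Return {token: probability} for the next token (toy stand-in for a model call)."""
    return TOY_MODEL[context[-1]]


def generate(context, max_new_tokens=10, rng=random):
    """Repeatedly sample a token and feed the extended context back in."""
    tokens = list(context)
    for _ in range(max_new_tokens):
        dist = next_token_distribution(tokens)
        choices, weights = zip(*dist.items())
        token = rng.choices(choices, weights=weights, k=1)[0]
        if token == "<eos>":  # assumed end-of-sequence marker
            break
        tokens.append(token)
    return tokens


print(" ".join(generate(["I", "ate"], rng=random.Random(0))))
```

Passing a seeded `random.Random` makes a run repeatable; swapping the sampling line for an argmax would give the always-pick-the-most-likely-word behavior discussed next.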
It turns out that this generation often either gets stuck in repetitive loops or forgets the subject and goes off topic.
Why is this happening, and how might we better sample to generate more human-like text?

This post is a summary and exploration of The Curious Case of Neural Text Degeneration.
I found it one of the most thorough and readable papers I’ve read, so please check it out if this post piques your interest!

If we always sample the most likely word, the standard language model training objective causes us to get stuck in loops like “I don’t know. I don’t know. I don’t know.”
This is unnatural, but most of the model’s attention in modern language models is on only the most recent few tokens.
Instead, popular sampling methods for generation are based on sampling from the distribution.
But sampling also runs into a problem: if we have 50K possible choices, even if the bottom 25K tokens are each extremely unlikely, in aggregate they might have for example 30% of the probability mass.
This means with each sample, we have roughly a 1 in 3 chance of completely derailing our “train of thought.”
Because of the short context mentioned earlier, this can cause an unrecoverable error cascade, as each next word depends heavily on this recent wrong word.
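To see how quickly that tail risk compounds, here is a small worked calculation. It assumes, purely for illustration, that every step has the same 30% tail mass and that steps are independent; real distributions vary from token to token. The function name `prob_no_tail_sample` is hypothetical.

```python
tail_mass = 0.30  # the example figure from the text: aggregate mass of the unlikely tokens


def prob_no_tail_sample(n_tokens, tail_mass):
    """Chance that none of n_tokens samples lands in the tail (independence assumed)."""
    return (1.0 - tail_mass) ** n_tokens


for n in (1, 5, 10, 20):
    print(n, round(prob_no_tail_sample(n, tail_mass), 4))
```

Under these assumptions, the chance of staying out of the tail shrinks geometrically with sequence length, which is why even a modest tail mass matters for long generations.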
To combat sampling from the tail, the most popular methods are temperature and top k sampling.
Temperature sampling is inspired by statistical thermodynamics, where a higher temperature makes less favorable, higher-energy states more likely to be encountered.
In probability models, logits play the role of (negative) energy, and we can implement temperature sampling by dividing logits by the temperature before feeding them into softmax and obtaining our sampling probabilities.
For example:

>>> import torch
>>> import torch.nn.functional as F
>>> a = torch.… ])

Lower temperatures make the model increasingly confident in its top choices, while temperatures greater than 1 decrease confidence.
A temperature of 0 is equivalent to argmax/max likelihood, while infinite temperature corresponds to uniform sampling.
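As a self-contained illustration of the same idea, here is a minimal sketch in plain Python (no torch, so it runs anywhere). The logits `[1.0, 2.0, 3.0, 4.0]` and the function name `softmax_with_temperature` are assumptions chosen for the example, not values from the original session.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive; use argmax for the T=0 limit")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 3.0, 4.0]  # hypothetical logits for four tokens
for t in (0.5, 1.0, 1.5, 100.0):
    print(t, [round(p, 4) for p in softmax_with_temperature(logits, t)])
```

Very small temperatures push nearly all the mass onto the largest logit, and very large ones flatten the distribution toward uniform, matching the two limits described above.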
Top k sampling means sorting by probability and zero-ing out the probabilities for anything below the k’th token.
It appears to improve quality by removing the tail and making it less likely to go off topic.
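A minimal sketch of top k filtering on a plain list of probabilities follows. The renormalization step is my assumption: the text only mentions zeroing out, but the surviving probabilities have to sum to 1 again before we can sample from them. `top_k_filter` is a hypothetical helper name.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens, zero out the rest, and renormalize."""
    if k < 1:
        raise ValueError("k must be at least 1")
    # Indices sorted from most to least probable (ties keep their original order).
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(ranked[:k])
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


print(top_k_filter([0.5, 0.3, 0.15, 0.05], k=2))
```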
But in some cases, there really are many words we could sample from reasonably (broad distribution below), and in some cases there aren’t (narrow distribution below).
A single fixed k is too small for the first case and too large for the second.

[Figure: a broad and a narrow next-token distribution. Source: Holtzman et al. 2019]

To address this problem, the authors propose top p sampling, aka nucleus sampling, in which we compute the cumulative distribution and cut off as soon as the CDF exceeds p.
In the broad distribution example above, it may take the top 100 tokens to exceed top_p = 0.9.
In the narrow distribution, we may already exceed top_p = 0.9 with just “hot” and “warm” in our sample distribution.
In this way, we still avoid sampling egregiously wrong tokens, but preserve variety when the highest scoring tokens have low confidence.
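Here is a minimal sketch of that cutoff rule on plain Python lists. The two example distributions are made up to mimic the broad and narrow cases, `top_p_filter` is a hypothetical helper name, and the renormalization at the end is my assumption so that the kept tokens form a valid distribution.

```python
def top_p_filter(probs, top_p):
    """Keep the smallest set of top tokens whose cumulative mass exceeds top_p."""
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1]")
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in ranked:
        keep.add(i)
        cumulative += probs[i]
        if cumulative > top_p:  # cut off as soon as the CDF exceeds top_p
            break
    return [probs[i] / cumulative if i in keep else 0.0 for i in range(len(probs))]


# Hypothetical distributions: one broad, one narrow.
broad = [0.15, 0.14, 0.13, 0.12, 0.11, 0.10, 0.09, 0.08, 0.05, 0.03]
narrow = [0.70, 0.25, 0.03, 0.02]

for name, dist in (("broad", broad), ("narrow", narrow)):
    kept = sum(1 for p in top_p_filter(dist, 0.9) if p > 0)
    print(name, "keeps", kept, "tokens")
```

The same threshold keeps many candidates when the model is unsure and only a couple when it is confident, which is the adaptive behavior a fixed k cannot give.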
Why doesn’t maximum likelihood sampling work?

In the training process, there’s never a chance to see compounding errors.
The model is trained to predict the next token based on a human-generated context.
If it gets one token wrong by generating a bad distribution, the next token uses the “correct” human generated context independent of the last prediction.
During generation it is forced to complete its own automatically-generated context, a setting it has not considered during training.
Qualitative results

Here are samples using top_k=40 and context “I ate a delicious”:

[Generated samples for top_k=40]

And here are samples using top_p=0.9 and the same “I ate a delicious” context:

[Generated samples for top_p=0.9]

Try it yourself here!
You can enable GPU in Runtime > Change runtime type and get big batches for no additional runtime.
Beyond the paper: choosing p and k automatically

I found it challenging to determine which of these samples is more human-like.
For this reason I designed an experiment to determine top_k and top_p empirically.
Our goal is to use top_k and top_p to maximize the probability of choosing the actual next word we’ve held out.
The optimal k and p values are actually easy to determine analytically for a given sample.
For k, we find the sorted index where the “golden” token occurred.
For p, we find the CDF of the golden token.
For example, if the context is “I ate a delicious hot” and the actual word is “dog”, but the model’s predicted distribution had “pancake” as most likely, we’d search down the probabilities until we found “dog” at index 3.
At index 1, the CDF might be 62%.
At index 3, the CDF might be something like 86%, so we’ll record that as our optimal p.
Across many examples, we can compute a histogram of optimal p and k values and compute summary statistics on them.
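The per-example search can be sketched as follows. This is not the notebook's code: `optimal_k_and_p` is a hypothetical helper, the vocabulary and probabilities are invented so that they reproduce the 62% and 86% figures from the example, and I am assuming the "index" in the text is a 1-based rank.

```python
import statistics


def optimal_k_and_p(probs, golden_index):
    """Return (rank of the golden token, CDF up to and including it)."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    cumulative = 0.0
    for rank, i in enumerate(ranked, start=1):  # 1-based rank (assumption)
        cumulative += probs[i]
        if i == golden_index:
            return rank, cumulative
    raise ValueError("golden_index is not a valid token index")


# Hypothetical distribution for the context "I ate a delicious hot".
vocab = ["pancake", "potato", "dog", "sauce", "tea"]
probs = [0.62, 0.14, 0.10, 0.08, 0.06]
k, p = optimal_k_and_p(probs, vocab.index("dog"))
print(k, round(p, 2))

# Aggregating over several (distribution, golden index) pairs, again made up.
examples = [(probs, 2), (probs, 0), ([0.5, 0.3, 0.2], 1)]
ks, ps = zip(*(optimal_k_and_p(pr, g) for pr, g in examples))
print("median k:", statistics.median(ks), "max k:", max(ks))
print("median p:", round(statistics.median(ps), 2), "max p:", round(max(ps), 2))
```

In a real run, each `probs` list would come from the model's softmax output for one held-out position, and the summary statistics would be taken over many such positions.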
I tested on a random section of Wikipedia with a context length of 15.
This is much shorter than what the model was trained on (1024), but common for settings like https://duet.li or chat bots.
===== ks =====
max 29094.00
===== ps =====
max 1.00

Feel free to try it yourself in my colab notebook.
If the model were being evaluated on its training set, it would be optimal to choose top_k = 1.
But since the model is slightly out of domain, the actual next token sometimes appears further down the list.
In addition, we have a 50K token vocabulary.
In many datasets, we’ll never see all the tokens, but the model isn’t sure of that.
By zeroing out much of the probability mass using top_p or top_k, we incorporate our prior that these never-seen-even-in-training tokens should never be chosen.
That said, this search for k and p is still in the context of the model’s view of the world and as such it’s only a bandaid.
What we really want is to fix training.
Fixing training

I’ve also started to think about changing the training objective to better match the generation task.
For example, could we train some kind of discriminator to punish the model when it generates whole sequences that don’t look human?
It’s not straightforward how to apply a GAN architecture to non-continuous domains.
I came upon Adversarial Text Generation without Reinforcement Learning and an RL-based idea, but it seems these have not yet become mainstream.
I think it’d be interesting to apply these ideas to the big transformers that have swept state of the art in the last few months.