permutations — 362,880.
If we sample from these permutations, pick one, say 612934578, and apply the following rule: take each number in the sequence 612934578, one at a time, and compute the layer 1 vector for that position using only the word vectors that precede it in the sequence (i.e. words 6 and 1 for computing 2, words 6, 1, 2, and 9 for computing 3, etc.). For illustration, if token 5 (“times”) is the one to be predicted, we use all tokens preceding it in the permutation sequence above, i.e. words 6, 1, 2, 9, 3, and 4, to compute the vector for “times”.
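The rule above can be sketched in a few lines of Python. This is a minimal sketch that assumes tokens are identified by their 1-based position in the sentence “Alaska is about 12 times larger than New York”. The helper name `permutation_contexts` is hypothetical and not taken from the XLNet code. For each token, it lists the tokens that precede it in the sampled permutation.

```python
from typing import Dict, List, Set

def permutation_contexts(order: List[int]) -> Dict[int, Set[int]]:
    """For each token, return the tokens that precede it in the permutation."""
    contexts: Dict[int, Set[int]] = {}
    seen: Set[int] = set()
    for token in order:
        contexts[token] = set(seen)  # only tokens earlier in the permutation
        seen.add(token)
    return contexts

ctx = permutation_contexts([6, 1, 2, 9, 3, 4, 5, 7, 8])
print(sorted(ctx[2]))  # [1, 6]
print(sorted(ctx[3]))  # [1, 2, 6, 9]
print(sorted(ctx[5]))  # [1, 2, 3, 4, 6, 9]: the context for "times"
```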
Now apply the same procedure to compute the layer 2 vectors. For instance, the layer 1 outputs for tokens 6, 1, 2, 9, 3, and 4, together with the layer 1 hidden state for “times”, are used to compute the layer 2 vector for “times”.
Under this rule, a word is never visible to the words that precede it in the permutation sequence. Note that this sequence can include words from both sides of the predicted word in the sentence. A word also uses only the words preceding it in the sequence for its prediction. Together, these guarantee that the predicted word can never be seen indirectly, regardless of the number of layers in the model.
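The claim that the predicted word can never leak in, however many layers are stacked, can be checked with a small dependency-tracking sketch. It follows the simplified rule described here: a position combines its own vector from the layer below with the vectors of the tokens preceding it in the permutation. It tracks only which input words each vector depends on, not actual values. `input_dependencies` is a hypothetical name.

```python
from typing import Dict, List, Set

def input_dependencies(order: List[int], num_layers: int) -> Dict[int, Set[int]]:
    """Return, for each token, the input words its top-layer vector depends on.

    Layer 1: a token sees only the word vectors of tokens preceding it in `order`.
    Layer m > 1: a token combines its own layer m-1 vector with the layer m-1
    vectors of the tokens preceding it in `order`.
    """
    rank = {tok: r for r, tok in enumerate(order)}
    preceding = {tok: [t for t in order if rank[t] < rank[tok]] for tok in order}
    deps = {tok: set(preceding[tok]) for tok in order}
    for _ in range(num_layers - 1):
        # The right-hand side is built entirely from the previous layer's sets.
        deps = {tok: deps[tok].union(*(deps[t] for t in preceding[tok])) for tok in order}
    return deps

deps = input_dependencies([6, 1, 2, 9, 3, 4, 5, 7, 8], num_layers=12)
print(sorted(deps[5]))  # [1, 2, 3, 4, 6, 9]: still only the tokens before "times"
print(all(tok not in deps[tok] for tok in deps))  # True: no token ever sees itself
```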
A key point to note is that the permutation sequence is only used to decide which tokens are involved in computing a vector.
The order of words in the sentence is untouched: every word’s position is fixed by a relative positional encoding scheme (we will come back to this later).
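One way to picture this is to keep the sentence in its original order and express the permutation purely as an attention mask. The XLNet paper describes achieving the factorization order through attention masks while keeping the original sequence order and positional encodings. The function below is only a sketch, not XLNet’s actual code, and `permutation_mask` is a hypothetical name. It builds such a mask for the example permutation, with rows and columns in sentence order.

```python
from typing import List

def permutation_mask(order: List[int]) -> List[List[int]]:
    """mask[i][j] == 1 if sentence position j+1 may be used when computing position i+1.

    Rows and columns follow the original sentence order; the permutation only
    decides which entries are switched on, so positions are never reordered.
    The diagonal is 0 because a position may not use itself under this rule.
    """
    n = len(order)
    rank = {tok: r for r, tok in enumerate(order)}
    return [[1 if rank[j] < rank[i] else 0 for j in range(1, n + 1)]
            for i in range(1, n + 1)]

mask = permutation_mask([6, 1, 2, 9, 3, 4, 5, 7, 8])
print(mask[4])  # row for position 5 ("times"): [1, 1, 1, 1, 0, 1, 0, 0, 1]
```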
In summary, we predict a word at a position from a subset of its bidirectional context and use the permutation rule above to ensure the word itself is never seen. This prevents a word from seeing itself in a multilayered model.
This scheme, however, requires sampling enough permutations to make full use of the bidirectional context.
This is still computationally expensive, so prediction is done only for a subset of the words in a sentence, controlled by a hyperparameter K in the model. In the paper, about 1/6 (~16%) of the words in a sentence are predicted.
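Partial prediction can be sketched as follows. The sketch assumes, as the XLNet paper describes, that the targets are the last tokens of each sampled permutation, which gives every target a reasonably long context. The integer rounding and the helper names `sample_permutation` and `split_targets` are illustrative choices of our own.

```python
import random
from typing import List, Tuple

def sample_permutation(n: int, rng: random.Random) -> List[int]:
    """Sample a factorization order over sentence positions 1..n."""
    return rng.sample(range(1, n + 1), n)

def split_targets(order: List[int], K: int) -> Tuple[List[int], List[int]]:
    """Split a permutation into context-only tokens and about 1/K prediction targets.

    Assumption: the targets are the last tokens of the permutation.
    At least one token is always predicted.
    """
    num_targets = max(1, len(order) // K)
    cutoff = len(order) - num_targets
    return order[:cutoff], order[cutoff:]

context, targets = split_targets([6, 1, 2, 9, 3, 4, 5, 7, 8], K=6)
print(targets)  # [8]: only "New" is predicted for this permutation

long_order = sample_permutation(60, random.Random(0))
print(len(split_targets(long_order, K=6)[1]))  # 10 of 60 positions, about 1/6
```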
Interestingly, BERT also predicts roughly the same fraction of the words in a sentence.
BERT, however, skirts the “seeing itself” problem by masking/corrupting the words it predicts.
Computationally, predicting all the masked words in a sentence takes a small, constant number of matrix multiplies per layer in BERT.
XLNet, in comparison, has to perform these multiplications for each permutation it samples and for each position in the permutation sequence.
BERT corrupts/replaces ~15% of tokens with another token, [mask].
In the illustration, “times” is replaced by “[mask]”.
The prediction of “times” cannot see it either directly or indirectly, since it has been replaced.
The hidden states adjacent to the yellow-shaded vector in layer 1 see all the word vectors in the input, including the masked one; this is not explicitly shown to avoid clutter.

There are a couple of downsides to BERT’s masking approach, however, that are not present in XLNet.
One is that if “New” and “York” were both masked in the sentence above, the prediction of “New” would be independent of “York” and vice versa.
While this may seem like a contrived example, XLNet does not make such an independence assumption, which could in part explain why it performs better than BERT on some tasks. It is perhaps one of the reasons XLNet performs better on Q&A tasks where the answer is a phrase.
The other is that the [mask] token is artificially introduced into the input. It is present only during pre-training and not during fine-tuning, so there is a discrepancy between the two.
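The first downside can be made concrete by listing what each prediction conditions on. This is a toy sketch with hypothetical helper names. It assumes that “New” and “York” are the two prediction targets in both models. For XLNet, it uses a permutation that happens to place “York” before “New”; another permutation would reverse the dependency.

```python
from typing import Dict, List, Set

def bert_contexts(n: int, masked: Set[int]) -> Dict[int, Set[int]]:
    """Each masked position is predicted from the unmasked positions only."""
    visible = {p for p in range(1, n + 1) if p not in masked}
    return {p: set(visible) for p in masked}

def xlnet_contexts(order: List[int], targets: Set[int]) -> Dict[int, Set[int]]:
    """Each target is predicted from the tokens preceding it in the permutation."""
    rank = {tok: r for r, tok in enumerate(order)}
    return {t: {p for p in order if rank[p] < rank[t]} for t in targets}

# Positions in "Alaska is about 12 times larger than New York": 8 = "New", 9 = "York".
bert = bert_contexts(9, masked={8, 9})
xlnet = xlnet_contexts([6, 1, 2, 3, 4, 5, 7, 9, 8], targets={8, 9})
print(9 in bert[8], 8 in bert[9])    # False False: each is predicted without the other
print(9 in xlnet[8], 8 in xlnet[9])  # True False: "New" is predicted given "York"
```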
The permutation-based dependency rule for capturing bidirectional context is perhaps the most distinctive aspect of XLNet.
XLNet also builds on prior work, Transformer-XL, to handle long sentences by transferring state across fixed-length segments. A minor modification ensures that the positional information is relative rather than absolute.
XLNet caches the layer states from each segment for reuse in subsequent segments, which in principle enables it to handle arbitrarily long sentences.
BERT, in comparison, can only handle fixed-length segments.
Even though one could choose an arbitrarily long fixed-length segment in BERT, this becomes impractical due to GPU/TPU memory requirements.
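The following is a rough sketch of segment-level recurrence. Tokens stand in for the cached per-layer hidden states that Transformer-XL actually reuses, and Transformer-XL also stops gradients from flowing into that cache. The function name and the `seg_len`/`mem_len` parameters are illustrative assumptions.

```python
from typing import List, Tuple

def segment_with_memory(tokens: List[str], seg_len: int,
                        mem_len: int) -> List[Tuple[List[str], List[str]]]:
    """Split a long input into fixed-length segments, pairing each segment with a
    fixed-size cache of the most recent positions from earlier segments."""
    memory: List[str] = []
    result: List[Tuple[List[str], List[str]]] = []
    for start in range(0, len(tokens), seg_len):
        segment = tokens[start:start + seg_len]
        # The segment is processed with the cached memory prepended as extra context.
        result.append((list(memory), segment))
        # Update the cache; guard mem_len == 0 because list[-0:] would keep everything.
        memory = (memory + segment)[-mem_len:] if mem_len > 0 else []
    return result

tokens = "Alaska is about 12 times larger than New York".split()
for memory, segment in segment_with_memory(tokens, seg_len=3, mem_len=3):
    print(memory, "->", segment)
```

With `seg_len=3` and `mem_len=3`, the third segment (“than New York”) is processed with the cached “12 times larger” as extra context, even though the two never share a segment.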
Lastly, an implementation detail of the permutation-based dependency rule: to predict a word at a position from the tokens preceding it in the permutation sequence, the positional information of the target position must be factored into the prediction. Otherwise, predictions at two different positions that use the same set of neighbors could not be told apart.
Also, while the hidden state used to predict a position cannot include the word vector at that position, we still need a hidden state that does include that position for predicting other positions.
This requires the model to compute two vectors for each position, h and g, as shown below.
h is a function of word vectors and other hidden state vectors in the layer below.
g is a function of hidden state vectors and positional information of the predicted word.
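The following dependency-tracking sketch shows the two streams, using the permutation 3,2,4,1 from the XLNet paper’s figures. It follows the update rules in the paper:
- h attends to the h vectors at or before its position in the permutation.
- g attends to its own previous g and to the h vectors strictly before it.
- Initially, h is the word vector and g is the position vector w.

It tracks only which input words each vector depends on, not real attention outputs. The function name is hypothetical.

```python
from typing import Dict, List, Set, Tuple

def two_stream_dependencies(order: List[int], num_layers: int
                            ) -> Tuple[Dict[int, Set[int]], Dict[int, Set[int]]]:
    """Track which input words each h (content) and g (query) vector depends on."""
    rank = {tok: r for r, tok in enumerate(order)}
    h: Dict[int, Set[int]] = {tok: {tok} for tok in order}  # layer 0: word vector x
    g: Dict[int, Set[int]] = {tok: set() for tok in order}  # layer 0: position vector w
    for _ in range(num_layers):
        new_h: Dict[int, Set[int]] = {}
        new_g: Dict[int, Set[int]] = {}
        for tok in order:
            before = [t for t in order if rank[t] < rank[tok]]
            # Content stream: its own h plus the h of strictly earlier tokens.
            new_h[tok] = set(h[tok]).union(*(h[t] for t in before))
            # Query stream: its own g plus the h of strictly earlier tokens (never its own h).
            new_g[tok] = set(g[tok]).union(*(h[t] for t in before))
        h, g = new_h, new_g
    return h, g

h, g = two_stream_dependencies([3, 2, 4, 1], num_layers=2)
print(sorted(g[1]))  # [2, 3, 4]: g1 never depends on x1
print(sorted(h[1]))  # [1, 2, 3, 4]: h1 includes x1, for use by later tokens
```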
The figure above, from the XLNet paper, shows the computation of the hidden states for a word that looks at itself in addition to a subset of its neighbors.
Note that these vectors are not used to make predictions.
They are used in computing the hidden vectors g, taking care not to include hidden state vectors that see the word being predicted.
This is accomplished using the permutation sequence rule described earlier.

The figure above, from the XLNet paper, shows the computation of the hidden vectors g that are used to make predictions.
Note that g1 in layer 1 uses its position information (w) but not its word vector x1.
Also, g1 in layer 2 depends on the hidden states h2, h3, and h4, none of which saw x1 because of the permutation order 3,2,4,1.
Token 1 cannot be seen by 3, 2, or 4 because they precede it in the sequence 3,2,4,1.

Trying out the model

This may be a bit challenging for many of us, given the compute and memory resources (GPUs/TPUs with sufficient on-board memory) required for fine-tuning the model on the evaluation tasks. BERT is different: we can fine-tune it for an NER task on a single-GPU machine in about an hour or two on average.
There is a notebook to test the model on a classification task. I could only run it with a batch size of 2 and got an accuracy of 90%; they report 92% with a batch size of 8.
There is also a recent PyTorch version that would be useful for understanding how the model works.
Oddly, there are no benchmark results for any sequence tagging tasks like NER (BERT is the current state of the art for NER); perhaps we will see some results for tagging tasks soon.
The section below can be skipped.
Why can’t a word see itself in the prediction?

Consider the maximum likelihood objective for a language model that uses only the left-side context to predict the next word (e.g. predicting “times” given the words “Alaska is about 12”).
Equation from XLNet paper.
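For reference, here is the objective as written in the XLNet paper. h_θ(x_{1:t-1}) is the context representation produced by the network, and e(x) is the embedding of word x:

```latex
\max_{\theta}\; \log p_{\theta}(\mathbf{x})
  = \sum_{t=1}^{T} \log p_{\theta}(x_t \mid \mathbf{x}_{<t})
  = \sum_{t=1}^{T} \log
    \frac{\exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x_t)\right)}
         {\sum_{x'} \exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x')\right)}
```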
This essentially says: find the parameters theta that maximize the probability of the word “times” given “Alaska is about 12”.
The key point is that h is computed from all the words preceding “times”, but not from “times” itself.
If we included “times” in the computation of h as well, the learning problem would become trivial, since e(x_t) is the embedding for “times”: the model could simply copy that information into h.
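Here is a toy illustration of why this is trivial. It assumes one-hot embeddings so that the arithmetic is exact; real embeddings are learned and dense. If h is allowed to be a scaled copy of e(x_t), the softmax puts almost all of its mass on the correct word without the model learning anything about language.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["Alaska", "is", "about", "12", "times", "larger", "than", "New", "York"]
# Assumption for illustration: one-hot embeddings, so e(x)^T e(x') is 1 only when x == x'.
emb = {w: [1.0 if i == j else 0.0 for j in range(len(vocab))] for i, w in enumerate(vocab)}

def predict(h: List[float]) -> List[float]:
    """p(x | context) = softmax over the vocabulary of h^T e(x)."""
    return softmax([sum(a * b for a, b in zip(h, emb[w])) for w in vocab])

# A "cheating" h that is allowed to contain e(x_t): just scale the target's embedding.
h_cheat = [10.0 * v for v in emb["times"]]
probs = predict(h_cheat)
print(round(probs[vocab.index("times")], 4))  # 0.9996
```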
The masked objective of BERT does not include the word itself, only the mask token that replaced it.
It then predicts the actual word “times” given the corrupted sentence “Alaska is about 12 [mask] larger than New York” (assuming only one token is masked).
Equation from XLNet paper. The BERT objective.
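For reference, here is the BERT objective as written in the XLNet paper. x̂ is the corrupted sentence, x̄ is the set of masked tokens, and H_θ(x̂)_t is the hidden vector at position t:

```latex
\max_{\theta}\; \log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}})
  \approx \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \hat{\mathbf{x}})
  = \sum_{t=1}^{T} m_t \log
    \frac{\exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x_t)\right)}
         {\sum_{x'} \exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x')\right)}
```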
m_t is 1 only for the masked tokens in a sentence, which make up about 15% of the words.
Note that H includes all words, but the predicted word is replaced by the mask token.
e(x_t) is the embedding for the word “times”.
The goal is to choose theta so that H(x̂) is close to the embedding for “times”, given all the words in the corrupted sentence.
BERT, in essence, skirts the “seeing itself” problem by replacing the predicted words with the mask token and reconstructing them during training.
However, the computation of H includes a mask token that is only present during pre-training.
XLNet uses a subset of the bidirectional context each time it predicts a word. It avoids the “seeing itself” problem by ensuring that the computation of g only includes tokens that do not see the word being predicted; g is a function of a subset of the tokens around it and of the predicted word’s position.
Again, it is easy to see that the learning problem would become trivial if g were a function of the word being predicted, since e(x) is the embedding of the token to be predicted.
Equation from XLNet paper.
XLNet’s prediction of the token at a position.
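For reference, here is the prediction as written in the XLNet paper. z is the sampled permutation, z_t is the position being predicted, and x_{z<t} are the tokens preceding it in the permutation:

```latex
p_{\theta}(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}})
  = \frac{\exp\left(e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}
         {\sum_{x'} \exp\left(e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}
```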
g in the equation above is computed as follows.
Equation from XLNet paper.
The highlighted state vector should have a strict inequality in its subscript; this must be a typo (unless I am missing something very basic).

References
XLNet: Generalized Autoregressive Pretraining for Language Understanding, Zhilin Yang et al., June 2019
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context, Zihang Dai et al., Jan 2019
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al., Oct 2018
A review of BERT based models
Deconstructing BERT

Imported this article manually from Quora https://qr.ae/TWtVmo (automatic import failed for some reason).