XLNet is a generalized autoregressive (AR) language model that learns unsupervised representations of text sequences. It brings modelling techniques from autoencoder (AE) models such as BERT into AR models while avoiding the limitations of AE models.
This paper was presented by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, and Quoc V. Le (researchers from the Google AI Brain Team and Carnegie Mellon University) on 19th June 2019. An update was made on 2nd Jan 2020.
Let’s dive into the details before using this model in a classification task.
XLNet
XLNet is an extension of the Transformer-XL model. It learns bidirectional contexts using an autoregressive method. To better understand the XLNet architecture, let’s first look at how BERT learns from data and where it falls short.
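We randomly mask some tokens in the input sequence $x$, which produces a corrupted sequence $\hat{x}$, and train the model to predict the masked tokens $\bar{x}$. $m_t$ is a mask that is 1 when the token at position $t$ is masked and 0 elsewhere. Pretraining on this task helps the model learn patterns in the data, and the learned parameters perform well on downstream tasks. The masked language modelling objective of BERT can be represented mathematically as

$$\max_\theta \; \log p_\theta(\bar{x} \mid \hat{x}) \approx \sum_{t=1}^{T} m_t \log p_\theta(x_t \mid \hat{x}) = \sum_{t=1}^{T} m_t \log \frac{\exp\left(H_\theta(\hat{x})_t^\top e(x_t)\right)}{\sum_{x'} \exp\left(H_\theta(\hat{x})_t^\top e(x')\right)}$$

where $H_\theta(\hat{x})_t$ is the Transformer’s hidden state at position $t$ and $e(x)$ is the embedding of token $x$.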
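To make the objective concrete, here is a minimal NumPy sketch of the masked language modelling log-likelihood. It assumes the model has already produced one logit vector per position from the corrupted input. The logits, vocabulary size and token ids are hypothetical toy values. Only positions with $m_t = 1$ contribute, and each one is scored separately given the corrupted input, which is where the independence assumption comes from.

```python
import numpy as np

def log_softmax(logits):
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def masked_lm_log_likelihood(logits, targets, mask):
    """Sum over t of m_t * log p(x_t | corrupted input).

    logits:  (T, V) scores per position, computed from the corrupted input
    targets: (T,)   original token ids x_t
    mask:    (T,)   m_t = 1 where the token was masked, 0 elsewhere
    """
    log_probs = log_softmax(logits)                                 # (T, V)
    token_log_probs = log_probs[np.arange(len(targets)), targets]  # (T,)
    return float((mask * token_log_probs).sum())

# Toy example: 4 positions, a 5-token vocabulary, positions 2 and 3 masked.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))
targets = np.array([0, 1, 2, 3])
mask = np.array([0, 0, 1, 1])
print(masked_lm_log_likelihood(logits, targets, mask))
```

Training maximizes this quantity, or equivalently minimizes its negative as a loss.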
Drawbacks of BERT
There are two problems with this objective:
Independence Assumption
BERT assumes that each masked token depends on all the unmasked input tokens but is independent of the other masked tokens. Let’s see an example of why this is problematic.
Dogs love _____ _____
The model ranks ['eating', 'playing'] as the most probable words for the first blank (position 3).
It also ranks ['fetch', 'meat'] as the most probable words for the second blank (position 4). Because of the independence assumption, each blank is filled with a highly probable word on its own, so the sentence can become
Dogs love eating fetch
instead of
“Dogs love eating meat” or “Dogs love playing fetch”
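The failure can be reproduced with a toy example. The probability tables below are invented for illustration and are not produced by any model. Each blank gets its own marginal distribution, as under BERT’s independence assumption, while the joint distribution knows which word pairs belong together.

```python
from itertools import product

# Hypothetical marginal distributions for each blank in "Dogs love ___ ___".
blank_1 = {"eating": 0.5, "playing": 0.5}
blank_2 = {"fetch": 0.5, "meat": 0.5}

# Hypothetical joint distribution: only sensible pairs get probability mass.
# Its marginals match blank_1 and blank_2 exactly.
joint = {
    ("eating", "meat"): 0.5,
    ("playing", "fetch"): 0.5,
}

def independent_prob(w1, w2):
    # Under the independence assumption, the joint is the product of marginals.
    return blank_1[w1] * blank_2[w2]

for w1, w2 in product(blank_1, blank_2):
    print(f"Dogs love {w1} {w2}: independent={independent_prob(w1, w2):.2f}, "
          f"joint={joint.get((w1, w2), 0.0):.2f}")
```

The independent model assigns 0.25 to every combination, including “Dogs love eating fetch”. The joint distribution gives that sentence zero probability.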
Pretrain-Finetune Discrepancy
During pretraining, BERT’s input contains artificial <mask> tokens, but the real data used for finetuning never contains them. This creates a discrepancy between pretraining and finetuning.
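These problems do not exist in AR modelling, where the objective is to predict each word given the words before it (forward direction) or after it (backward direction). For the forward direction, the mathematical formulation is

$$\max_\theta \; \log p_\theta(x) = \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t}) = \sum_{t=1}^{T} \log \frac{\exp\left(h_\theta(x_{1:t-1})^\top e(x_t)\right)}{\sum_{x'} \exp\left(h_\theta(x_{1:t-1})^\top e(x')\right)}$$

where $h_\theta(x_{1:t-1})$ is the context representation produced by the network and $e(x)$ is the embedding of token $x$.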
Because the product rule factorizes the joint probability exactly into conditional probabilities, an AR model makes no independence assumption and uses no artificial <mask> tokens.
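The exactness of this factorization can be checked numerically. The sketch below builds a hypothetical joint distribution over all length-3 sequences from a three-word vocabulary, using random but normalized weights as a stand-in for a real model. It derives each conditional $p(x_t \mid x_{<t})$ from that joint and shows that summing the log-conditionals recovers the log-probability of the whole sequence.

```python
import math
import random
from itertools import product

vocab = ["dogs", "love", "meat"]

# Hypothetical joint distribution over all length-3 sequences (random, normalized).
random.seed(0)
weights = {seq: random.random() for seq in product(vocab, repeat=3)}
total = sum(weights.values())
joint = {seq: w / total for seq, w in weights.items()}

def prefix_prob(prefix):
    # p(x_1..x_k): marginalize the joint over all continuations of the prefix.
    return sum(p for seq, p in joint.items() if seq[:len(prefix)] == prefix)

def conditional(token, prefix):
    # p(x_t = token | x_<t = prefix), derived from the joint.
    return prefix_prob(prefix + (token,)) / prefix_prob(prefix)

def ar_log_likelihood(seq):
    # Forward AR factorization: log p(x) = sum_t log p(x_t | x_<t).
    return sum(math.log(conditional(seq[t], seq[:t])) for t in range(len(seq)))

x = ("dogs", "love", "meat")
print(ar_log_likelihood(x), math.log(joint[x]))
```

The two printed numbers agree up to floating-point rounding. No assumption about the conditionals was needed.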
XLNet uses this kind of AR modelling. But a plain AR model has a major problem: it only uses unidirectional context, i.e. either forward or backward. The following example illustrates this.
Dogs love ____ meat. A forward model may predict “playing” instead of “eating” because it has not yet seen the token “meat”. AE models don’t have this problem, as they see all the non-masked words.
Permutations of the Factorization Order
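To capture bidirectional context, XLNet trains over all permutations of the factorization order. The tokens keep their original positions, and only the order in which they are predicted is permuted.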
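All T! permutations of the positions $1, \dots, T$ form the set $\mathcal{Z}_T$. A single permutation is denoted $z$, with $z_t$ its $t$-th element and $z_{<t}$ its first $t-1$ elements. Mathematically, the objective is

$$\max_\theta \; \mathbb{E}_{z \sim \mathcal{Z}_T}\left[\sum_{t=1}^{T} \log p_\theta(x_{z_t} \mid x_{z_{<t}})\right]$$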
The following visualization by the authors makes this clear.
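The log-likelihood of the sequence is computed under each of these factorization orders, and its expected value over all orders is maximized. Because T! grows very quickly, in practice a single factorization order is sampled for each sequence during training.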
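A small sketch makes the idea concrete. For a given factorization order, it lists which positions each target may condition on. Positions never move; only the prediction order changes. The helper name `contexts_for_order` is hypothetical. Enumerating all T! orders is feasible only for tiny T, so the sketch also samples one order, as is done in training.

```python
import itertools
import random

tokens = ["Dogs", "love", "eating", "meat"]

def contexts_for_order(order):
    """For a factorization order z (a permutation of positions), map each
    target position z_t to the sorted positions z_<t it may condition on."""
    return {pos: sorted(order[:t]) for t, pos in enumerate(order)}

# All T! factorization orders for T = 4.
all_orders = list(itertools.permutations(range(len(tokens))))
print(len(all_orders))  # 24

# During training, one order is sampled per sequence instead of enumerating all.
random.seed(0)
z = random.choice(all_orders)
for target, ctx in contexts_for_order(z).items():
    print(f"predict {tokens[target]!r} at position {target} from {[tokens[i] for i in ctx]}")
```

Under an order such as (3, 0, 1, 2), the word at position 2 (“eating”) is predicted with “meat”, which sits to its right, in its context. This is how bidirectional context enters a purely autoregressive objective.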
We have now captured bidirectional context while keeping an autoregressive language model that makes no troublesome independence assumption. But there is still one problem left to tackle.
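With the standard softmax parameterization, the representation computed from the preceding context does not depend on the position of the word being predicted. Two factorization orders that share the same context but predict different positions therefore receive exactly the same distribution. Let’s see an example.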
Dogs love _____ . Both “eating” and “meat” are highly probable, since the only context we have is “Dogs love”. Ideally, we want “eating” to be more probable than “meat” in the 3rd position. Without this kind of position awareness, the model would assign the same distribution to every target position, much like a simple bag-of-words model.
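A toy computation shows why the target position has to enter the prediction. Everything below is hypothetical. Random vectors stand in for word embeddings, position embeddings and an encoding of the context “Dogs love”. The target-aware representation is simply the context vector plus the target position’s embedding, a deliberately simplified stand-in rather than XLNet’s actual mechanism.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["eating", "meat", "playing", "fetch"]
d = 8
word_emb = rng.normal(size=(len(vocab), d))   # e(x) for each candidate word
pos_emb = rng.normal(size=(10, d))            # hypothetical position embeddings
context = rng.normal(size=d)                  # stands in for an encoding of "Dogs love"

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def predict_without_position(context):
    # Standard parameterization: the scores depend on the context only.
    return softmax(word_emb @ context)

def predict_with_position(context, target_pos):
    # Target-aware parameterization: the representation also depends on z_t.
    g = context + pos_emb[target_pos]
    return softmax(word_emb @ g)

print(predict_without_position(context))
print(predict_with_position(context, 3))
print(predict_with_position(context, 4))
```

The position-free prediction cannot differ between positions 3 and 4 because it never receives the position as an input. The target-aware version yields a different distribution for each position.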
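A new representation $g$ is introduced that encodes this positional awareness along with the context of the previous words. It takes the position $z_t$ of the target word as an extra input. The new conditional probability is

$$p_\theta(X_{z_t} = x \mid x_{z_{<t}}) = \frac{\exp\left(e(x)^\top g_\theta(x_{z_{<t}}, z_t)\right)}{\sum_{x'} \exp\left(e(x')^\top g_\theta(x_{z_{<t}}, z_t)\right)}$$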
Two Stream Self Attention
How do we calculate g? This is done using a two-stream self-attention mechanism.
The hidden states $h$ of Transformers like BERT encode both the content of a word and its position, and they are computed with self-attention. We cannot use these states directly as $g$, because the state at the target position already contains the content of the word we need to predict. Instead, XLNet adds a second attention stream on top of the hidden states. In this stream, the query vector carries the position information of the target together with the context gathered so far, and the keys and values are the content hidden states of the tokens that come before the target in the factorization order. The result is two attention streams, a content stream ($h$) and a query stream ($g$), that share the same parameters.
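The two streams and their masks can be sketched in NumPy for a single layer, a single head and one sampled factorization order. This is a simplified illustration under stated assumptions: random weights, hypothetical absolute position embeddings standing in for XLNet’s relative positional encodings, and no feed-forward layers, layer normalization or memory. The masks encode the rule from the XLNet paper. The content stream at $z_t$ may attend to $z_{\le t}$, including itself, while the query stream may attend only to $z_{<t}$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 4, 8
order = np.array([2, 3, 0, 1])   # hypothetical sampled factorization order z

# rank[i] = step at which position i is predicted under z.
rank = np.empty(T, dtype=int)
rank[order] = np.arange(T)

# Content stream: position i may attend to j if j comes no later than i in z.
content_mask = rank[None, :] <= rank[:, None]
# Query stream: strictly earlier only, so a position never sees its own content.
query_mask = rank[None, :] < rank[:, None]

# One set of projection matrices, shared by both streams.
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

def attention(queries, keys_values, mask):
    q, k, v = queries @ W_q, keys_values @ W_k, keys_values @ W_v
    scores = np.where(mask, q @ k.T / np.sqrt(d), -1e9)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True)) * mask
    denom = weights.sum(axis=1, keepdims=True)
    # A row with no visible keys (first token in z, query stream) gets zeros.
    weights = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return weights @ v

pos_emb = rng.normal(size=(T, d))    # hypothetical absolute position embeddings
word_emb = rng.normal(size=(T, d))   # embeddings e(x_i) of the input tokens
w = rng.normal(size=d)               # trainable initial vector of the query stream

h = word_emb + pos_emb               # content stream: content + position
g = w + pos_emb                      # query stream: position only, no content

h_next = attention(h, h, content_mask)   # keys/values from z_<=t
g_next = attention(g, h, query_mask)     # keys/values from z_<t only
```

Because `g_next` at a position never reads that position’s own content state, it can serve as the representation used to predict the token there.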
That is the backbone structure of XLNet. This alone seems like a very good solution, but there is one more critical problem: the length of the sequences we can model is fixed, as it is for BERT. The usual workaround is to split a long sequence into segments, but then the information in the previous segment is lost when making predictions on the current segment. XLNet solves this with the segment-level recurrence mechanism of Transformer-XL, which acts as a cache memory. The hidden states computed for one segment are saved and reused as extra keys and values when the model processes the next segment.
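$$h^{(m)}_{z_t} \leftarrow \text{Attention}\left(Q = h^{(m-1)}_{z_t},\; KV = \left[\tilde{h}^{(m-1)}, h^{(m-1)}_{z_{\le t}}\right]; \theta\right)$$

Here $\tilde{h}^{(m-1)}$ is the hidden state of layer $m-1$ saved from the previous segment, and $[\cdot, \cdot]$ denotes concatenation along the sequence dimension.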
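The caching step can be sketched in NumPy as follows. The sketch assumes a single layer and a single head with random weights. It omits attention masks, permutations and relative positions. NumPy has no automatic differentiation, so the stop-gradient that Transformer-XL applies to the cache is only noted in a comment.

```python
import numpy as np

rng = np.random.default_rng(0)
d, seg_len = 8, 4
W_q, W_k, W_v = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))

def attend_with_memory(h, memory):
    """One attention step whose keys/values are [memory, current segment].

    h:      (L, d) layer-(m-1) states of the current segment (queries come from here)
    memory: (M, d) cached layer-(m-1) states of the previous segment (M may be 0)
    """
    kv = np.concatenate([memory, h], axis=0)   # prepend the cache to keys/values
    q, k, v = h @ W_q, kv @ W_k, kv @ W_v
    scores = q @ k.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v

sequence = rng.normal(size=(3 * seg_len, d))   # a long input split into 3 segments
memory = np.zeros((0, d))                      # no cache before the first segment
for start in range(0, len(sequence), seg_len):
    segment = sequence[start:start + seg_len]
    out = attend_with_memory(segment, memory)
    # Cache this segment's states for the next one. In a real framework this
    # copy is wrapped in a stop-gradient so no gradients flow into it.
    memory = segment.copy()
```

Queries come only from the current segment, so the cost per segment stays fixed while information from the previous segment remains reachable through the keys and values.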
XLNet uses several other clever tricks like these to improve the performance and efficiency of its representations. Let’s see how to use this model in a classification task.
Usage
Let’s use the XLNet-Base model for classification.
Unfortunately, XLNet isn’t available on TensorFlow Hub yet, but we can still clone the official implementation from GitHub and work with it. The model is large, so it requires a GPU with plenty of memory. Use TensorFlow 1.x, as the official implementation may not work with TensorFlow 2.0.
Clone the GitHub repository:
! git clone https://github.com/zihangdai/xlnet.git
Download and unzip the pretrained XLNet-Base (cased) weights:
! wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
! unzip cased_L-12_H-768_A-12.zip
Let’s get the IMDb reviews dataset using wget
! wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
! tar zxf aclImdb_v1.tar.gz
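Before training, it can be useful to sanity-check the extracted data. This hypothetical helper assumes the standard layout of the archive: aclImdb/train/pos, aclImdb/train/neg and the same for test, with one review per .txt file. It loads a split into texts and labels using only the standard library and ignores the unlabelled train/unsup folder.

```python
from pathlib import Path

def load_imdb_split(root, split):
    """Return (texts, labels) for split 'train' or 'test'; label 1 = pos, 0 = neg."""
    texts, labels = [], []
    for label_name, label in (("neg", 0), ("pos", 1)):
        # Sorted so the order is deterministic across runs.
        for path in sorted((Path(root) / split / label_name).glob("*.txt")):
            texts.append(path.read_text(encoding="utf-8"))
            labels.append(label)
    return texts, labels
```

For example, `texts, labels = load_imdb_split("aclImdb", "train")` should return 25,000 reviews, half of them positive.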
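The repository already provides a classifier head on top of XLNet, so all you need to do is run run_classifier.py to fine-tune the model on the IMDb data. Set DATA_DIR to the extracted aclImdb folder and PRETRAINED_MODEL_DIR to the folder containing the unzipped pretrained weights.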
! python xlnet/run_classifier.py --do_train=True --do_eval=True --eval_all_ckpt=True --task_name=imdb --data_dir=$DATA_DIR --output_dir=$OUTPUT_DIR --model_dir=$CHECKPOINT_DIR --uncased=False --spiece_model_file=$PRETRAINED_MODEL_DIR/spiece.model --model_config_path=$PRETRAINED_MODEL_DIR/xlnet_config.json --init_checkpoint=$PRETRAINED_MODEL_DIR/xlnet_model.ckpt --max_seq_length=128 --train_batch_size=8 --eval_batch_size=8 --num_hosts=1 --num_core_per_host=1 --learning_rate=2e-5 --train_steps=4000 --warmup_steps=500 --save_steps=500 --iterations=500
The preprocessed data is written to OUTPUT_DIR and the fine-tuned checkpoints to CHECKPOINT_DIR.
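You may prefer to launch fine-tuning from a Python cell instead of a shell command. In that case, building the argument list in Python avoids quoting problems. This is a hypothetical helper. The directory values are placeholders, including the assumed name of the unzipped weights folder, so adjust them to your setup. The flags and their values are the ones from the run_classifier.py command above.

```python
import subprocess

# Hypothetical locations; adjust to where you extracted the data and weights.
DATA_DIR = "aclImdb"
OUTPUT_DIR = "proc_data/imdb"
CHECKPOINT_DIR = "exp/imdb"
PRETRAINED_MODEL_DIR = "xlnet_cased_L-12_H-768_A-12"

def build_command():
    flags = {
        "do_train": True, "do_eval": True, "eval_all_ckpt": True,
        "task_name": "imdb",
        "data_dir": DATA_DIR,
        "output_dir": OUTPUT_DIR,
        "model_dir": CHECKPOINT_DIR,
        "uncased": False,
        "spiece_model_file": f"{PRETRAINED_MODEL_DIR}/spiece.model",
        "model_config_path": f"{PRETRAINED_MODEL_DIR}/xlnet_config.json",
        "init_checkpoint": f"{PRETRAINED_MODEL_DIR}/xlnet_model.ckpt",
        "max_seq_length": 128, "train_batch_size": 8, "eval_batch_size": 8,
        "num_hosts": 1, "num_core_per_host": 1,
        "learning_rate": 2e-5, "train_steps": 4000, "warmup_steps": 500,
        "save_steps": 500, "iterations": 500,
    }
    # One "--name=value" argument per flag; no shell quoting is involved.
    return ["python", "xlnet/run_classifier.py"] + [f"--{k}={v}" for k, v in flags.items()]

def run_finetuning():
    # Launches the official script; needs TensorFlow 1.x and the files above.
    subprocess.run(build_command(), check=True)
```

Passing a list to `subprocess.run` bypasses the shell, so paths containing spaces need no extra quoting.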
A validation accuracy of 90% with such a simple classifier on top of XLNet is very good, and it shows the power of XLNet’s representations.
Conclusion
XLNet’s performance gains come mainly from its novel language modelling objective. This suggests that pretrained models can still be improved by designing better language modelling objectives for pretraining.