Guide to XLNet for Language Understanding

XLNet is a generalized autoregressive language model that learns unsupervised representations of text sequences. It incorporates modelling ideas from autoencoding (AE) models such as BERT into autoregressive (AR) models while avoiding the limitations of AE.

The paper was presented by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, and Quoc V. Le (researchers from the Google AI Brain Team and Carnegie Mellon University) on 19th June 2019, with an update on 2nd January 2020.

Let’s dive into the details before using this model in a classification task.

XLNet

XLNet is an extension of the Transformer-XL model. It learns bidirectional contexts using an autoregressive method. To better understand the XLNet architecture, let's first look at the shortcomings of BERT and how BERT learns from data.

In BERT we randomly mask some tokens in the input data and try to predict these removed tokens. Here m_t is a mask indicator that is 1 when token t is masked and 0 elsewhere. Pretraining the model with this task helps it learn patterns in the data, and the model then performs well on downstream tasks using the learned parameters. The masked language modelling objective of BERT can be represented mathematically as follows.
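(In the notation of the XLNet paper, where x̂ is the corrupted input and x̄ is the set of masked tokens, this is approximately:)

\[
\max_{\theta}\;\; \log p_{\theta}(\bar{x} \mid \hat{x}) \;\approx\; \sum_{t=1}^{T} m_t \,\log p_{\theta}\!\left(x_t \mid \hat{x}\right)
\]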

Drawbacks of BERT

There are two problems with this objective 

Independence Assumption

We assume that each masked token depends on all the unmasked input tokens but is independent of the other masked tokens. Let's see an example of why this is problematic.

Dogs love _____ _____

The model ranks [‘eating’, ’playing’] as the most probable words for the first blank. It also ranks [‘fetch’, ’meat’] as the most probable words for the second blank. Because of the independence assumption, picking the most probable word for each blank independently can produce the sentence

Dogs love eating fetch

instead of 

“ Dogs love eating meat ” or “ Dogs love playing fetch ”

Pre-train-Finetune Discrepancy

During BERT pre-training the input contains <mask> tokens, but the real data seen during finetuning has no such tokens, which creates a discrepancy between pre-training and finetuning.

These problems do not exist in AR modelling, as the objective is to predict the next word given the sequence of words before (or after) it. The mathematical formulation is as follows.
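(For a forward factorization, in the paper's notation, where x_{<t} denotes the tokens before position t:)

\[
\max_{\theta}\;\; \log p_{\theta}(x) \;=\; \sum_{t=1}^{T} \log p_{\theta}\!\left(x_t \mid x_{<t}\right)
\]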

Since we naturally factorize the joint probability into a product of conditional probabilities, we make no independence assumptions and use no artificial <mask> tokens.

XLNet uses this kind of AR modelling. But there is a major problem with this approach: we depend only on a unidirectional context, i.e. either forward or backward. The following example illustrates this problem.

Dogs love ____ meat. The model may predict playing instead of eating because it doesn’t see the token meat yet. AE models don’t have this problem as they see all the non-missing words.

Permutations over Factorization

To capture the bidirectional context, all permutations of the factorization order of the given sequence are used for training. The token order of the sequence itself is unchanged; the permutation only changes which tokens each prediction is allowed to attend to.

The set of all T! permutations of the factorization order is denoted Z_T, and z denotes one such order. Mathematically this is given as
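(Following the paper's notation, where z_t and z_{<t} are the t-th element and the first t−1 elements of a permutation z:)

\[
\max_{\theta}\;\; \mathbb{E}_{z \sim \mathcal{Z}_T}\!\left[\, \sum_{t=1}^{T} \log p_{\theta}\!\left(x_{z_t} \mid x_{z_{<t}}\right) \right]
\]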

The authors' visualization of attention masks for different factorization orders makes this clear.

Log-probabilities are calculated under such factorization orders, and their expected value over the orders is maximized (in practice, factorization orders are sampled rather than enumerating all T! of them).

Now we have successfully captured the bidirectional context as well as created a language model without making troublesome assumptions. But there is still one problem left to tackle.

Since we try to predict the next word given only the sequence of preceding words, the predicted distribution is the same no matter which position we are predicting; we fail to distinguish between distributions based on the target position. Let's see an example.

Dogs love _____ . Now both the words eating and meat are highly probable since we only have the context of "Dogs love". Ideally, we want eating to be more probable in the 3rd position than meat. This kind of position awareness is what distinguishes this model from a simple bag-of-words model.

A new representation g is introduced to capture this positional awareness along with the context of the previous words. The new conditional probability is
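(In the paper's notation, where e(x) is the embedding of token x and g_θ(x_{z<t}, z_t) additionally conditions on the target position z_t:)

\[
p_{\theta}\!\left(X_{z_t} = x \mid x_{z_{<t}}\right) \;=\; \frac{\exp\!\left(e(x)^{\top} g_{\theta}(x_{z_{<t}}, z_t)\right)}{\sum_{x'} \exp\!\left(e(x')^{\top} g_{\theta}(x_{z_{<t}}, z_t)\right)}
\]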

Two Stream Self Attention

How do we calculate g? This is done using a two-stream self-attention mechanism.

Hidden states h of transformers like BERT represent information about both the content of a word and its position. These states are calculated using self-attention. But we cannot directly use these states as g, because they contain the content of the current word, which is exactly the word we need to predict. What we can do instead is add an attention layer on top of the previous hidden states, in which the query vector is a representation of the position information while the key and value vectors are the hidden states. Therefore we need two different attention streams with shared parameters; their update rules are sketched after the list below.

Content stream: sees the content of the current token as well as the content of the previous tokens, like standard self-attention.
Query stream: sees only the position of the current token (plus the content of the previous tokens) and is used to compute g.
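(Schematically, following the paper's notation, the two streams at layer m are updated as:)

\[
g_{z_t}^{(m)} \;\leftarrow\; \mathrm{Attention}\!\left(Q = g_{z_t}^{(m-1)},\; KV = h_{z_{<t}}^{(m-1)}\right) \quad \text{(query stream: uses the position } z_t \text{ but not the content } x_{z_t}\text{)}
\]

\[
h_{z_t}^{(m)} \;\leftarrow\; \mathrm{Attention}\!\left(Q = h_{z_t}^{(m-1)},\; KV = h_{z_{\le t}}^{(m-1)}\right) \quad \text{(content stream: uses both } z_t \text{ and } x_{z_t}\text{)}
\]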

That is the backbone structure of XLNet. This itself seems like a very good solution, but there is one more critical problem: the length of the sequences we can model is fixed (sequence lengths are fixed for BERT too). This problem is usually solved by segmenting the sequence, but then the information in the previous segment is lost when we make predictions on the current segment. XLNet solves this, following Transformer-XL, by introducing a cache memory: the hidden state vectors are saved at the end of processing one segment, and these saved vectors are reused at the beginning of the next segment.

Here the hidden-state vectors cached from the previous segment, denoted h̃, are reused when processing the current segment.
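(In the paper's notation, the content-stream update with this memory becomes roughly the following, where [· ; ·] denotes concatenation along the sequence dimension:)

\[
h_{z_t}^{(m)} \;\leftarrow\; \mathrm{Attention}\!\left(Q = h_{z_t}^{(m-1)},\; KV = \left[\tilde{h}^{(m-1)};\, h_{z_{\le t}}^{(m-1)}\right]\right)
\]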

There are several such clever tricks that XLNet uses to improve the performance and efficiency of its representations. Let's see how to use this model in a classification task.

Usage

Let's try using the XLNet base model for a classification task.

Unfortunately, XLNet isn't available on TensorFlow Hub yet, but we can clone the official implementation from GitHub and work with it. The model is huge, so it requires a system with lots of VRAM. Use TensorFlow 1.x, as the current implementation may not work with 2.0.

Command to clone GitHub repository

! git clone https://github.com/zihangdai/xlnet.git

You can get pre-trained weights using

 ! wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
 ! unzip cased_L-12_H-768_A-12.zip 

Let’s get the IMDb reviews dataset using wget

 ! wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
 ! tar zxf aclImdb_v1.tar.gz 
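The training command below references several directory variables (DATA_DIR, OUTPUT_DIR, CHECKPOINT_DIR, PRETRAINED_MODEL_DIR). A minimal, hypothetical way to define them in a Colab-style notebook, so that shell commands launched from the notebook can see them, might be:

import os

# Hypothetical paths for a Colab-style setup; adjust to your environment.
DATA_DIR = "aclImdb"                                  # extracted IMDb reviews
PRETRAINED_MODEL_DIR = "xlnet_cased_L-12_H-768_A-12"  # check the actual folder name after unzipping
OUTPUT_DIR = "proc_data/imdb"                         # processed data will be written here
CHECKPOINT_DIR = "exp/imdb"                           # fine-tuned checkpoints will be written here

for d in (OUTPUT_DIR, CHECKPOINT_DIR):
    os.makedirs(d, exist_ok=True)

# Export them so the training command can reference ${DATA_DIR}, ${OUTPUT_DIR}, etc.
os.environ.update(
    DATA_DIR=DATA_DIR,
    PRETRAINED_MODEL_DIR=PRETRAINED_MODEL_DIR,
    OUTPUT_DIR=OUTPUT_DIR,
    CHECKPOINT_DIR=CHECKPOINT_DIR,
)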

A classifier head on top of XLNet is available in the repository. All you need to do is run the run_classifier.py file to train the model on the current data.

python xlnet/run_classifier.py \
  --do_train=True \
  --do_eval=True \
  --eval_all_ckpt=True \
  --task_name=imdb \
  --data_dir=${DATA_DIR} \
  --output_dir=${OUTPUT_DIR} \
  --model_dir=${CHECKPOINT_DIR} \
  --uncased=False \
  --spiece_model_file=${PRETRAINED_MODEL_DIR}/spiece.model \
  --model_config_path=${PRETRAINED_MODEL_DIR}/xlnet_config.json \
  --init_checkpoint=${PRETRAINED_MODEL_DIR}/xlnet_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=8 \
  --eval_batch_size=8 \
  --num_hosts=1 \
  --num_core_per_host=1 \
  --learning_rate=2e-5 \
  --train_steps=4000 \
  --warmup_steps=500 \
  --save_steps=500 \
  --iterations=500

Results will be saved in OUTPUT_DIR.

Results

90% validation accuracy with such a simple classifier head is really good. This shows the power of XLNet's representations.

Conclusion

The performance boost obtained by XLNet comes mainly from its novel language modelling objective. This shows that there is still scope for improving models by improving the language-modelling objectives used for pre-training.
