The idea that we look at all the words in proportion to their relevance while understanding a word in a sequence is the prime factor behind the success of transformers in natural language processing. However, this attention mechanism comes at a cost: it restricts the possible length of the sequence of words. In NLP settings where you have to model long-range dependencies between words, this becomes a major showstopper. In this article, let’s explore Transformer XL, a Transformer model that allows us to model long-range dependencies while not disrupting temporal coherence.
This model was introduced by Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le and Ruslan Salakhutdinov (researchers at Carnegie Mellon University and Google Brain) on 2 June 2019.
Transformers like BERT are constrained to fixed sequence lengths. But real-world documents are of variable length. We can pad shorter documents to reduce this problem, but that is not very efficient. A more efficient solution is to chunk the documents into fixed-length segments.
This creates the problem of context fragmentation, i.e., information in the previous or later segments is not used in the current segment.
For example, if we are trying to understand the meaning of an abbreviation that is defined in an earlier segment, we need context from previous segments.
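A minimal sketch of the chunking step and the fragmentation it causes. The toy document and the helper `split_into_segments` are assumptions made up for illustration; they do not come from the paper or from any library.

```python
def split_into_segments(tokens, segment_len):
    """Split a token list into consecutive fixed-length segments (the last may be shorter)."""
    return [tokens[i:i + segment_len] for i in range(0, len(tokens), segment_len)]


# Toy document: the abbreviation "NLP" is defined early and used again later.
document = (
    "natural language processing ( NLP ) studies text . "
    "later sections assume the reader knows what NLP means"
).split()

segments = split_into_segments(document, segment_len=9)

# A model that processes each segment independently sees only that segment,
# so the definition of "NLP" is out of reach when the second segment is read.
definition_visible = "processing" in segments[1]
print(segments[1])
print("definition visible in second segment:", definition_visible)
```

The second segment still contains the token "NLP", but the words that define it sit in the first segment, which is exactly the context a fixed-length model loses.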
Transformer XL solves this by saving the hidden states from previous segments and using them in the current segment. This enables information to propagate over longer distances than in vanilla transformer models.
It is important to note that we stop gradients from flowing into the cached memory in order to keep the training cost from growing with every cached segment.
Mathematically this can be expressed as
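A sketch of the recurrence in the notation of the Transformer XL paper. The symbols are the paper's and are not defined elsewhere in this article: $h_{\tau}^{n}$ is the hidden state of layer $n$ for segment $\tau$, $\operatorname{SG}$ is the stop-gradient operator, $\circ$ is concatenation along the sequence dimension, and $W_q$, $W_k$, $W_v$ are the query, key and value projections.

```latex
\begin{aligned}
\tilde{h}_{\tau+1}^{\,n-1} &= \left[\operatorname{SG}\!\left(h_{\tau}^{\,n-1}\right) \circ h_{\tau+1}^{\,n-1}\right] \\
q_{\tau+1}^{\,n},\; k_{\tau+1}^{\,n},\; v_{\tau+1}^{\,n} &= h_{\tau+1}^{\,n-1} W_q^{\top},\; \tilde{h}_{\tau+1}^{\,n-1} W_k^{\top},\; \tilde{h}_{\tau+1}^{\,n-1} W_v^{\top} \\
h_{\tau+1}^{\,n} &= \operatorname{Transformer\text{-}Layer}\!\left(q_{\tau+1}^{\,n},\, k_{\tau+1}^{\,n},\, v_{\tau+1}^{\,n}\right)
\end{aligned}
```

Note that only the keys and values are computed from the extended context $\tilde{h}_{\tau+1}^{\,n-1}$; the queries come from the current segment alone.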
Relative Positional Encoding
This paper also introduces a robust way to encode positional information using relative encodings. When calculating attention, we inject a bias term dynamically. This bias term contains information about the position of the token. Traditional absolute position encoding creates temporal confusion, because tokens in the cached memory would receive the same positional encodings as tokens in the current segment.
A relative position encoding scheme is used as a workaround. When calculating the attention for a query-key pair, all we need is the distance between the query token and the key token to fully capture the positional information. This is given by
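A sketch of the attention score in the notation of the Transformer XL paper. The symbols are the paper's and are assumptions with respect to this article: $E_{x_i}$ is the embedding of the token at position $i$, $R_{i-j}$ is the encoding of the relative distance $i-j$, $W_q$ is the query projection, $W_{k,E}$ and $W_{k,R}$ are separate key projections for content and position, and $u$, $v$ are trainable vectors.

```latex
A_{i,j}^{\mathrm{rel}} =
\underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)}
+ \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)}
+ \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)}
+ \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}
```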
Each of the four terms represents an intuitive concept:
- Content-Based Addressing Term. This part of the attention is decided from the content of the query and the key alone.
- Content-Dependent Positional Bias. This term takes into account the relative position of the key with respect to the query, together with the content of the query.
- Global Content Bias. u is a trainable parameter that is the same for all queries, so this term depends only on the content of the key.
- Global Position Bias. v is a trainable parameter that is the same for all queries, so this term depends only on the relative position.
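The four terms listed above can be sketched for a single query-key pair as follows. This is an illustration in plain Python under simplifying assumptions, not the library's implementation: the vectors `q`, `k` and `r` stand for the already projected query content, key content and relative position encoding, and all names and numbers are made up for the example.

```python
def dot(a, b):
    """Dot product of two equal-length vectors given as lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def add(a, b):
    """Element-wise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def relative_attention_score(q, k, r, u, v):
    """Unnormalized attention score for one (query, key) pair.

    q: projected query content, k: projected key content,
    r: projected encoding of the relative distance between query and key,
    u, v: global bias vectors shared by all queries.
    """
    terms = {
        "content_addressing": dot(q, k),
        "content_dependent_positional_bias": dot(q, r),
        "global_content_bias": dot(u, k),
        "global_positional_bias": dot(v, r),
    }
    return sum(terms.values()), terms


# Toy 2-dimensional vectors, chosen by hand for illustration.
q, k, r = [1.0, 2.0], [0.5, -1.0], [2.0, 0.0]
u, v = [0.0, 1.0], [1.0, 1.0]

score, terms = relative_attention_score(q, k, r, u, v)

# Grouping the terms gives the same score with only two dot products.
grouped = dot(add(q, u), k) + dot(add(q, v), r)
print(terms, score, grouped)
```

Because u and v do not depend on the query, the four terms collapse into (q + u)·k + (q + v)·r, which is why the two biases add very little computation.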
Now one more question naturally follows: why use only one previous segment for attention caching?
In the paper, m segments are used for caching instead of 1. GPU memory is the only constraint that keeps us from using higher values of m. We can use 1 segment while training and more segments while running inference.
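A minimal sketch of how such a cache could be maintained, assuming a hypothetical helper `update_memory` and using integers as stand-ins for hidden states. In the real model the cache holds per-layer hidden states with gradients stopped; this example only tracks how many positions stay reachable.

```python
def update_memory(memory, new_states, mem_len):
    """Append the newest segment's states and keep only the last `mem_len` of them."""
    if mem_len == 0:
        return []
    return (memory + new_states)[-mem_len:]


segment_len = 4
# Four consecutive segments of four positions each (0..15).
segments = [list(range(start, start + segment_len)) for start in range(0, 16, segment_len)]

for mem_len in (4, 8):  # cache one segment vs. two segments
    memory = []
    for segment in segments:
        # Positions the current segment can attend to: cached states plus itself.
        context = memory + segment
        memory = update_memory(memory, segment, mem_len)
    print(f"mem_len={mem_len}: last segment attends over {len(context)} positions")
```

Doubling the cache from one segment to two extends what the last segment can attend to, at the price of holding more states in GPU memory.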
Transformer XL is a large model, so it needs a high-memory GPU setup to pre-train or fine-tune. We will stick to just running inference in this article due to memory constraints.
Hugging Face provides this transformer model in its transformers package. The library also provides a sequence classification head on top of Transformer XL.
from transformers import TransfoXLTokenizer, TransfoXLForSequenceClassification
We will use the IMDb movie reviews dataset and try to predict the polarity of these reviews.
There are two steps we need to do as part of the inference pipeline.
1. Tokenize the inputs and format them as the model requires. This is done by a custom tokenizer for this model.
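# `inputs` is assumed to be a list of review strings
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
# truncate each review to at most 128 tokens; one review per call, so no padding is needed
tokenized_inputs = [tokenizer(text, return_tensors='pt', truncation=True, max_length=128) for text in inputs]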
2. Pass the tokenized inputs into the model and collect the outputs.
model = TransfoXLForSequenceClassification.from_pretrained('transfo-xl-wt103')
outputs = [model(**batch) for batch in tokenized_inputs]
We got an average binary cross entropy loss of 0.69.
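For reference, a binary classifier that assigns probability 0.5 to every label has a cross-entropy of ln 2, so a loss of about 0.69 is what one would expect from a classification head that has not been fine-tuned. This is one plausible reading of the number, not a claim from the experiment itself. A quick check, with a helper name chosen for illustration:

```python
import math


def binary_cross_entropy(p_true_class):
    """Cross-entropy loss for one example, given the probability assigned to its true label."""
    return -math.log(p_true_class)


# Loss of a classifier that is maximally unsure about every review.
chance_loss = binary_cross_entropy(0.5)
print(round(chance_loss, 4))  # 0.6931
```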
Transformer XL is an important variation of Transformers, as it addresses a major shortcoming of transformers: context fragmentation. It allows the model to capture longer dependencies, and by reusing cached hidden states it speeds up evaluation compared with a vanilla Transformer. Successors that build on it, such as XLNet, have outperformed BERT on a range of language tasks.