Training and fine-tuning language models on new knowledge is a tedious process that eats up both time and resources. But what if a language model could acquire new knowledge simply by reading new data and memorising it? Google’s research paper ‘Memorizing Transformers’, presented at ICLR 2022, discusses how this can be done. The paper notes that attention can act as a form of fast learning over long sequences: a model can memorise facts by storing them as key-value pairs in a long-term memory, then retrieve that information later by issuing a query.
Adding memory: Paper
Increasing attention context
Transformer performance on a task depends on the length of its attention context. This context is usually short, but the paper shows that it can be extended using approximate k-nearest-neighbour (kNN) lookup, a method commonly used for information retrieval. Approximate kNN lookup can be scaled up with libraries such as ScaNN and Faiss.
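To make the idea concrete, here is a minimal sketch of a kNN lookup over stored key-value pairs, written in plain NumPy with made-up sizes. It performs an exact dot-product search for clarity, whereas libraries such as ScaNN and Faiss provide the approximate, large-scale version.

```python
import numpy as np

def knn_lookup(query, memory_keys, memory_values, k=3):
    """Toy exact kNN lookup: find the k stored keys most similar to the
    query (by dot product) and return their values. Libraries such as
    ScaNN or Faiss do this approximately, trading a little accuracy for
    large speed-ups on big memories."""
    scores = memory_keys @ query            # similarity of the query to every stored key
    top_k = np.argsort(scores)[-k:][::-1]   # indices of the k best matches
    return memory_values[top_k], scores[top_k]

# 10,000 stored (key, value) pairs of dimension 64 -- illustrative numbers only.
rng = np.random.default_rng(0)
keys = rng.normal(size=(10_000, 64))
values = rng.normal(size=(10_000, 64))
query = rng.normal(size=64)

retrieved_values, retrieved_scores = knn_lookup(query, keys, values, k=3)
print(retrieved_values.shape)  # (3, 64)
```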
This study approaches the problem of increasing the attention context from a different angle. While traditional long-sequence attention models average over tokens at long distances, kNN lookup retrieves exact values even from the distant context. In addition, gradients are not backpropagated into the external memory; instead, the proposed method reuses keys and values computed on prior training steps.
Backpropagating gradients into the external memory would require recomputing all the keys and values with the current parameters on every single training step, which limits scalability. Reusing previously computed keys and values, on the other hand, cuts down the amount of computation for large memories. This keeps step time reasonable while allowing the external memory to scale to sequences of 131k or even 262k tokens on a single TPU device.
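The memory itself can be pictured as a simple buffer of previously computed keys and values. The sketch below is not the paper’s code; it just illustrates a non-differentiable store with a hypothetical capacity of 131k entries that evicts the oldest entries when full. Because the stored arrays are plain data rather than trainable parameters, no gradients ever flow back into them.

```python
import numpy as np

class ExternalKVMemory:
    """Sketch of a non-differentiable external memory. Keys and values
    computed on earlier training steps are stored as plain arrays and
    simply reused at lookup time. `capacity` and `dim` are placeholders."""

    def __init__(self, capacity=131_072, dim=64):
        self.capacity = capacity
        self.keys = np.empty((0, dim))
        self.values = np.empty((0, dim))

    def add(self, new_keys, new_values):
        # Append this step's keys/values; drop the oldest entries if over capacity.
        self.keys = np.concatenate([self.keys, new_keys])[-self.capacity:]
        self.values = np.concatenate([self.values, new_values])[-self.capacity:]

    def lookup(self, queries, k=32):
        # Approximate kNN search in the real system; exact dot-product search here.
        scores = queries @ self.keys.T                   # (num_queries, memory_size)
        top_k = np.argsort(scores, axis=-1)[:, -k:]      # (num_queries, k)
        return self.keys[top_k], self.values[top_k]      # (num_queries, k, dim)
```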
Extending transformers: Paper
Methodology
The input text is first tokenised, and the tokens are then embedded into a vector space. These embeddings are passed through a stack of transformer layers, each of which performs dense self-attention followed by a feed-forward network (FFN). Because the language model is decoder-only, a causal attention mask is used, and the output embeddings of the final layer are used to predict the next token.
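For illustration, a toy single-head version of the dense, causally masked self-attention step might look like the following NumPy sketch; the random projection matrices stand in for learned weights and are not part of the paper.

```python
import numpy as np

def causal_self_attention(x):
    """Minimal single-head dense self-attention with a causal mask.
    `x` has shape (seq_len, dim); the projections are random stand-ins
    for learned weights. Each position may only attend to itself and to
    earlier positions, as in a decoder-only language model."""
    seq_len, dim = x.shape
    rng = np.random.default_rng(0)
    w_q, w_k, w_v = [rng.normal(size=(dim, dim)) / np.sqrt(dim) for _ in range(3)]

    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(dim)                    # (seq_len, seq_len)
    mask = np.triu(np.ones((seq_len, seq_len)), 1)     # 1s above the diagonal
    scores = np.where(mask == 1, -1e9, scores)         # block attention to future tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)     # row-wise softmax
    return weights @ v                                 # (seq_len, dim)

tokens = np.random.default_rng(1).normal(size=(8, 64))   # 8 token embeddings
print(causal_self_attention(tokens).shape)                # (8, 64)
```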
Long documents are then split into subsequences of 512 tokens, and each subsequence is used as the input for one training step. Whereas subsequences are normally shuffled together, in this case long documents were fed into the transformer sequentially, as is done with Transformer-XL.
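The splitting step itself is straightforward; a small illustrative helper (the function name and the 1,800-token document are hypothetical) could look like this:

```python
def split_into_subsequences(token_ids, subseq_len=512):
    """Split one long document into consecutive 512-token subsequences.
    Each chunk becomes the input of one training step, and chunks are fed
    in order rather than shuffled, so the external memory built up from
    earlier chunks stays relevant to later ones."""
    return [token_ids[i:i + subseq_len]
            for i in range(0, len(token_ids), subseq_len)]

# Hypothetical usage: `document` stands in for the token ids of one long paper.
document = list(range(1800))
chunks = split_into_subsequences(document)
print([len(c) for c in chunks])   # [512, 512, 512, 264]
```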
The kNN-augmented attention layer, which is one of the transformer layers near the top of the stack, combines two types of attention: standard dense self-attention on the local context and an approximate k-nearest-neighbour search into the external memory. Using approximate rather than exact kNN search speeds up the model considerably.
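A rough sketch of how the two attention results might be combined is shown below. It is an illustration under stated assumptions, not the authors’ implementation: the retrieval here is exact rather than approximate, the memory is passed in as plain arrays, and a fixed scalar `gate` stands in for the learned gating the paper describes.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def knn_augmented_attention(q, local_k, local_v, mem_keys, mem_values,
                            top_k=8, gate=0.5):
    """Sketch of kNN-augmented attention for a single query vector: dense
    attention over the local context is combined with attention over the
    top-k (key, value) pairs retrieved from external memory."""
    dim = q.shape[-1]

    # 1. Standard dense attention over the local context.
    local_out = softmax(local_k @ q / np.sqrt(dim)) @ local_v

    # 2. kNN retrieval from external memory (exact search in this sketch).
    nearest = np.argsort(mem_keys @ q)[-top_k:]
    mem_out = softmax(mem_keys[nearest] @ q / np.sqrt(dim)) @ mem_values[nearest]

    # 3. Gated combination of local and retrieved attention results.
    return gate * mem_out + (1 - gate) * local_out

# Hypothetical shapes: 512 local tokens, 131k memorised tokens, dimension 64.
rng = np.random.default_rng(0)
q = rng.normal(size=64)
local_k, local_v = rng.normal(size=(2, 512, 64))
mem_keys, mem_values = rng.normal(size=(2, 131_072, 64))
print(knn_augmented_attention(q, local_k, local_v, mem_keys, mem_values).shape)  # (64,)
```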
Datasets
Average perplexities of each model: Paper
The research used datasets including a corpus of papers from arXiv Math, open-source code files from GitHub, formal mathematical theories from the Isabelle proof assistant, a large collection of web documents from C4 and English-language books from PG-19.
Findings
By presenting kNN-augmented attention, a simple extension to the transformer, the research shows that the context length of a language model can be increased. The memorizing transformer demonstrated improved perplexity over the baseline across all the architectures and datasets tested. Beyond this, external memory remains beneficial even as the transformer is scaled from 200M to 8B parameters.
The study concludes that the transformer does not need to be pre-trained from scratch. Rather, memory can be added to an existing pre-trained model, which is then fine-tuned to reap the benefits. kNN retrieval can also scale up to much larger memory sizes, opening the door to models that draw on entire code repositories and huge knowledge bases.
Criticism
Context within datasets
As promising as this alternative sounds, author and data scientist Minhaaj Rehman sparked a discussion on LinkedIn about how effective the method would be in practice. Rehman noted that it was the “nature of the language itself that was elusive, and not how inference servers process the data.” Another data scientist suggested that language models could instead be improved with the RETRO transformer, or Retrieval-Enhanced Transformer, which lets models retrieve data from a 2-trillion-token database. According to Rehman, however, this misses the point: faster inference does not compensate for a model’s inability to correctly understand contextual information.
According to a 2020 paper titled ‘Context pre-modeling: an empirical analysis for classification based user-centric context-aware predictive modelling’, even a slightly different context can lead to a vastly different outcome, even with a simple dataset. The context of a dataset can influence results far more than the predictive power of a model, so building a context-aware model requires context pre-modelling.