Google has launched a grammar correction feature built directly into Gboard on the Pixel 6. The tech giant said the feature detects and suggests corrections for grammatical errors while the user is typing. At launch, it works only on English sentences, but Google plans to expand to more languages soon. Previously, Google improved language-error correction in Google Docs using neural grammar correction.
Google AI has recently introduced several innovations built on neural networks and machine learning. It released Task Affinity Groupings (TAG), a method that determines which tasks should be trained together in multi-task neural networks. It also came out with a new differentially private clustering algorithm that generates representative data points privately.
Google trained a sequence-to-sequence neural network to take an input sentence (or a sentence prefix) and output its grammatically correct version. If the original text is already grammatically correct, the model's output is identical to its input, indicating that no corrections are needed.
What does the model use?
Transformer encoder: a neural network architecture based on a self-attention mechanism. Recurrent neural networks (RNNs) have been the common architecture for translation, but because they process language sequentially, they need multiple steps to make decisions that depend on words far apart from each other. That sequential nature also makes it difficult to take full advantage of fast parallel hardware.
Transformer, on the other hand, performs a small, constant number of steps, chosen empirically. In each step, it applies a self-attention mechanism that directly models relationships between all words in a sentence, regardless of their positions.
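Self-attention can be sketched in a few lines. The toy NumPy version below uses identity query/key/value projections for clarity (a real Transformer learns these as separate weight matrices), but it shows the key property: every output position attends to all input positions in a single step, no matter how far apart they are.

```python
import numpy as np

def self_attention(x):
    """Single-head self-attention over a sequence of word vectors.

    x: (seq_len, d) array. Query/key/value projections are identity
    matrices here for illustration; a real model learns them.
    """
    d = x.shape[1]
    scores = x @ x.T / np.sqrt(d)    # pairwise relevance of every word to every other
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)    # softmax over each row
    return weights @ x               # each output mixes all positions at once

# Four "word" vectors: each output row depends on all inputs in one step.
seq = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
out = self_attention(seq)
print(out.shape)  # (4, 2)
```

Because the attention weights couple every pair of positions directly, a dependency between the first and last word costs the same single step as one between neighbors.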
LSTM decoder: Long short-term memory (LSTM) is a recurrent neural network architecture with feedback connections, capable of processing entire sequences of data. An LSTM unit typically consists of a cell, an input gate, an output gate, and a forget gate. It finds applications in unsegmented connected handwriting recognition, speech recognition, and more.
Image: Google (Overview of the grammatical error correction (GEC) model architecture)
Google used techniques such as shared embedding, factorized embedding and quantization to solve the problem of limited memory and computational power.
Reducing file size
- Shared embedding
Some of the model weights are shared between the Transformer encoder and the LSTM decoder. This leads to a reduction in model file size without impacting accuracy.
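Weight sharing of this kind can be sketched simply: both components hold a reference to the same embedding table, so the weights exist (and are stored) only once. The classes below are illustrative stand-ins, not Gboard's actual model code.

```python
import random

vocab, dim = 1000, 64
random.seed(0)

# One embedding table (vocab x dim), referenced by both halves of the model.
shared_embedding = [[random.random() for _ in range(dim)] for _ in range(vocab)]

class Encoder:
    def __init__(self, table):
        self.table = table                      # a reference, not a copy

    def embed(self, token_ids):
        return [self.table[t] for t in token_ids]

class Decoder:
    def __init__(self, table):
        self.table = table                      # the same object as the encoder's

    def embed(self, token_ids):
        return [self.table[t] for t in token_ids]

enc, dec = Encoder(shared_embedding), Decoder(shared_embedding)
# Both components index the same table, so the saved model stores it once.
assert enc.table is dec.table
```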
- Factorized embedding
The model splits a sentence into a sequence of predefined tokens, which is needed to achieve good quality but increases the model size. A factorized embedding separates the size of the hidden layers from the size of the vocabulary embedding.
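The saving from decoupling the vocabulary embedding from the hidden size is easy to quantify. The sizes below are illustrative, not Gboard's actual dimensions: instead of one vocab-by-hidden table, the factorization uses a small bottleneck dimension between the two.

```python
# Parameter-count arithmetic for a factorized embedding (illustrative sizes).
vocab, hidden, bottleneck = 30_000, 512, 64

direct = vocab * hidden                                # one vocab-by-hidden table
factorized = vocab * bottleneck + bottleneck * hidden  # vocab-by-e plus e-by-hidden

print(direct, factorized)  # 15360000 1952768
```

With these sizes the factorized embedding uses under a seventh of the parameters, because the expensive vocab-sized table is multiplied by the small bottleneck dimension rather than the full hidden size.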
- Quantization
Google performs post-training quantization, allowing it to store each 32-bit floating-point weight using only 8 bits. Each weight is stored with lower fidelity, but the quality of the model is not affected.
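A minimal sketch of post-training quantization, assuming a simple affine (scale and zero-point) scheme: each float32 weight is mapped to an 8-bit integer, shrinking storage 4x, and can be approximately reconstructed at inference time.

```python
import numpy as np

def quantize(weights):
    """Map float32 weights to uint8 plus a scale/zero-point (4x smaller)."""
    lo, hi = weights.min(), weights.max()
    scale = (hi - lo) / 255.0
    zero_point = lo
    q = np.round((weights - zero_point) / scale).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Approximately reconstruct the original float32 weights."""
    return q.astype(np.float32) * scale + zero_point

w = np.random.default_rng(1).standard_normal(1000).astype(np.float32)
q, s, z = quantize(w)
w_hat = dequantize(q, s, z)

# 8-bit storage is a quarter of the 32-bit original; the per-weight
# reconstruction error is bounded by about half the quantization step.
assert q.nbytes == w.nbytes // 4
print(float(np.abs(w - w_hat).max()))  # small, well under one step `s`
```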
Google needed training data in the form of <original, corrected> text pairs and found that hard distillation yields training data better matched to the on-device domain.
How does hard distillation work?
- Google collected hundreds of millions of English sentences from across the public web.
- It used the cloud-based grammar model to generate grammar corrections for those sentences.
- The training dataset of <original, corrected> sentence pairs is used to train a smaller on-device model that can correct full sentences.
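The steps above can be sketched as a small pipeline. Here `cloud_grammar_model` is a hypothetical stand-in for Google's large cloud-based teacher model; the stub fixes one toy error pattern purely for illustration.

```python
# Hard-distillation sketch: a large "teacher" model labels raw sentences,
# and the resulting <original, corrected> pairs train a small student model.

def cloud_grammar_model(sentence):
    # Stub teacher: corrects one toy error pattern for illustration only.
    return sentence.replace("I is", "I am")

def build_training_pairs(raw_sentences):
    """Label each raw sentence with the teacher's correction."""
    pairs = []
    for original in raw_sentences:
        corrected = cloud_grammar_model(original)
        pairs.append((original, corrected))
    return pairs

web_sentences = ["I is happy today.", "She writes well."]
training_data = build_training_pairs(web_sentences)
print(training_data)
# [('I is happy today.', 'I am happy today.'), ('She writes well.', 'She writes well.')]
```

Note that already-correct sentences yield identical pairs, which teaches the student model to leave correct text unchanged.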
Google found that the on-device model built from this training dataset produces significantly higher quality suggestions than a similar-sized on-device model built on the original data used to train the cloud-based model.
Even before training the model on this data, the model has to be able to handle sentence prefixes. This matters even more in messaging apps, where the user often omits the final period and presses send as soon as they finish typing. Google solved this with a heuristic: if a given sentence prefix can be completed to form a grammatically correct sentence, it is considered grammatically correct; otherwise, it is considered incorrect.
Google created a second dataset, suitable for training a large cloud-based model, focused on sentence prefixes. Using the heuristic above, it generated data from the <original, corrected> sentence pairs in the cloud-based model's training dataset by randomly sampling aligned prefixes from them.
- Then, it autocompletes each original prefix to a full sentence using a neural language model.
- If a full-sentence grammar model finds no errors in the completed sentence, there is at least one way to complete the original prefix without making a grammatical error. In that case, the original prefix is considered correct and <original prefix, original prefix> is taken as a training example; otherwise, <original prefix, corrected prefix> is used.
- The training data is used to train a large cloud-based model that can correct sentence prefixes. Google then uses that model for hard distillation, generating new <original, corrected> sentence prefix pairs that are better matched to the on-device domain.
- Combining the new sentence prefix pairs with the full-sentence pairs yields the final training data for the on-device model, which can correct both full sentences and sentence prefixes.
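The prefix-labeling heuristic can be sketched as follows. `complete` and `grammar_errors` are hypothetical stand-ins for the neural language model and the full-sentence grammar model described above; both are toy stubs here.

```python
# Prefix-labeling heuristic sketch (stubs stand in for the real models).

def complete(prefix):
    """Toy autocompletion standing in for a neural language model."""
    return prefix.rstrip() + " today."

def grammar_errors(sentence):
    """Toy grammar check standing in for the full-sentence grammar model."""
    return "I is" in sentence

def label_prefix(original_prefix, corrected_prefix):
    full = complete(original_prefix)
    if not grammar_errors(full):
        # Some completion is error-free, so the prefix counts as correct.
        return (original_prefix, original_prefix)
    return (original_prefix, corrected_prefix)

print(label_prefix("She writes", "She writes"))  # ('She writes', 'She writes')
print(label_prefix("I is happy", "I am happy"))  # ('I is happy', 'I am happy')
```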
Image: Google (Training data for the on-device model is generated from cloud-based models)
What happens when the user types?
When the user has typed more than three words, Gboard sends a request to the on-device grammar model. Gboard underlines the grammar mistakes and provides replacement suggestions. Because the model outputs only corrected sentences, the mistakes must be converted into replacement suggestions.
Google aligns the original sentence and the corrected sentence by minimizing the Levenshtein distance, i.e., the number of edits needed to transform the original sentence into the corrected one. Finally, the method transforms insertion edits and deletion edits into replacement edits.
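This last step can be approximated with Python's standard `difflib`, which computes a similar (though not strictly Levenshtein) minimal-edit alignment over words. Folding each insertion or deletion into a neighboring word, a simplification of the approach described, yields replacement-only suggestions that the UI can underline.

```python
import difflib

def replacement_suggestions(original, corrected):
    """Align two sentences word-by-word and emit replacement edits only.

    Insertions and deletions are folded into an adjacent word so the UI
    can always underline existing text and offer a replacement for it.
    """
    src, dst = original.split(), corrected.split()
    ops = difflib.SequenceMatcher(a=src, b=dst).get_opcodes()
    suggestions = []
    for tag, i1, i2, j1, j2 in ops:
        if tag == "equal":
            continue
        if tag == "insert" and i1 > 0:
            # Replace the preceding word with itself plus the inserted words.
            i1 -= 1
            j1 -= 1
        elif tag == "delete" and i2 < len(src):
            # Replace the deleted words plus the following word with that word.
            i2 += 1
            j2 += 1
        suggestions.append((" ".join(src[i1:i2]), " ".join(dst[j1:j2])))
    return suggestions

print(replacement_suggestions("I is going to store", "I am going to the store"))
# [('is', 'am'), ('to', 'to the')]
```

The inserted word "the" becomes a replacement of the preceding "to" with "to the", so every suggestion maps some span of the user's existing text to new text.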