In some cases, neural networks reach perfect generalisation well past the point of overfitting by "grokking" a pattern in the data. In a potentially groundbreaking study, researchers from OpenAI (Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, Vedant Misra) explored how neural networks generalise on small algorithmically generated datasets. The team studied generalisation as a function of dataset size and found that smaller datasets require increasingly large amounts of optimisation before the network generalises.
The generalisation of overparameterised neural networks has long intrigued the machine learning community because it runs counter to the intuitions drawn from classical learning theory. The researchers showed that networks trained on small algorithmically generated datasets exhibit unusual generalisation patterns, decoupled from performance on the training set, far more conspicuously than networks trained on datasets derived from natural data. The experiments can be reproduced on a single GPU.
What is Grokking
Suppose you train an overparameterised neural network (one with more parameters than there are data points in the dataset) past the point where it has memorised the training data, as indicated by low training loss and high validation loss. In some cases, the network will then suddenly learn to generalise, as indicated by a rapid decrease in validation loss; the network is said to “grok” the data. Practitioners rarely see this, because they usually stop training at the first hint of overfitting, that is, when the gap between training and validation loss starts to widen. Using overparameterised models at all goes against traditional statistical wisdom, which recommends underparameterised models that are forced to learn the rule (and hence generalise to new situations).
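The delay between memorising and generalising can be read off logged accuracy curves. The sketch below is only an illustration: the helper names `first_step_reaching` and `grokking_gap`, the 0.99 threshold, and the curves themselves are made up for this example and do not come from the paper.

```python
def first_step_reaching(curve, threshold):
    """Return the index of the first value >= threshold, or None if never reached."""
    for step, value in enumerate(curve):
        if value >= threshold:
            return step
    return None


def grokking_gap(train_acc, val_acc, threshold=0.99):
    """Number of logged steps between fitting the training set and generalising.

    Returns None if either curve never reaches the threshold.
    """
    fit_step = first_step_reaching(train_acc, threshold)
    gen_step = first_step_reaching(val_acc, threshold)
    if fit_step is None or gen_step is None:
        return None
    return gen_step - fit_step


# Hypothetical curves: training accuracy saturates early,
# validation accuracy stays near chance and then jumps.
train_acc = [0.2, 0.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
val_acc = [0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.9, 1.0]

gap = grokking_gap(train_acc, val_acc)
print(gap)
```

A large gap means the network kept training long after it had fitted the training set before validation accuracy caught up, which is the pattern the article describes.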
In their paper “Grokking: Generalization Beyond Overfitting On Small Algorithmic Datasets”, the authors present several conclusions on grokking and generalisation:
- Neural networks can generalise to the empty slots in a variety of binary operation tables.
- Validation accuracy can abruptly climb from chance level to perfect generalisation long after significant overfitting. This is referred to as ‘grokking.’
- The authors present data efficiency curves for a variety of binary operations.
- Empirically, as the dataset size decreases, the amount of optimisation necessary for generalisation increases rapidly.
- Weight decay is especially helpful at improving generalisation on the grokking tasks.
- Symbol embeddings discovered by these networks occasionally reveal the discernible structure of the mathematical objects represented.
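To make the "empty slots in a binary operation table" idea concrete, here is a minimal sketch that builds the full table for one operation and hides part of it for validation. Modular addition, the small modulus `p=7`, the 50% split, and the function name `make_binary_op_dataset` are assumptions chosen for illustration, not the authors' exact setup.

```python
import random


def make_binary_op_dataset(p=7, train_fraction=0.5, seed=0):
    """Build the table for a o b = (a + b) mod p and split it into train/validation."""
    # Every cell of the p x p operation table becomes one ((a, b), result) example.
    table = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)
    rng.shuffle(table)
    n_train = int(len(table) * train_fraction)
    # Training cells are the filled-in entries; validation cells are the empty slots.
    return table[:n_train], table[n_train:]


train, val = make_binary_op_dataset(p=7, train_fraction=0.5)
print(len(train), len(val))
```

Lowering `train_fraction` shrinks the training set, which is the regime where the paper reports that far more optimisation is needed before the network generalises.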
Deep learning practitioners are used to seeing small improvements in validation accuracy after validation loss stops decreasing. A double descent of validation loss has been observed in rare cases, but it is regarded as unusual.
The researchers saw generalisation improve after initial overfitting for a range of models, optimisers, and dataset sizes. They noted that such behaviour is typical across all the binary operations when the dataset size is close to the smallest size at which the network generalised within the allotted optimisation budget. For larger dataset sizes, the training and validation curves tend to track each other more closely.
Earlier research used convolutional neural networks to investigate a wide variety of generalisation and complexity measures and to identify which of them predict generalisation performance. Flatness-based measures, which evaluate a trained network’s sensitivity to parameter perturbations, were found to be the most predictive. Power et al. therefore hypothesise that the grokking they report may be due to SGD noise driving optimisation towards flatter, simpler solutions that generalise better.
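As a rough illustration of what "sensitivity to parameter perturbations" means, the sketch below estimates the average loss increase when Gaussian noise is added to the parameters. This is a simple proxy written for this article, not the specific measure used in the earlier research or by Power et al.; the function name `sharpness`, the noise scale `sigma`, and the two toy quadratic losses are all assumptions.

```python
import random


def sharpness(loss_fn, params, sigma=0.1, n_samples=200, seed=0):
    """Average increase in loss when Gaussian noise is added to the parameters."""
    rng = random.Random(seed)
    base = loss_fn(params)
    total = 0.0
    for _ in range(n_samples):
        # Perturb every parameter independently with N(0, sigma^2) noise.
        noisy = [w + rng.gauss(0.0, sigma) for w in params]
        total += loss_fn(noisy) - base
    return total / n_samples


# Two toy losses with the same minimum at the origin but different curvature.
def flat_loss(w):
    return 0.1 * sum(x * x for x in w)


def sharp_loss(w):
    return 10.0 * sum(x * x for x in w)


minimum = [0.0, 0.0]
flat_score = sharpness(flat_loss, minimum)
sharp_score = sharpness(sharp_loss, minimum)
print(flat_score < sharp_score)
```

A lower score means the loss barely changes under perturbation, so the solution sits in a flatter region; the hypothesis described above is that such solutions tend to generalise better.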
Moreover, the researchers noticed an interesting phenomenon: the number of optimisation steps needed to reach a given level of performance increases quickly as the size of the training dataset is reduced. Since this represents a way to trade compute for performance on smaller amounts of data, it would be useful for future work to investigate whether the effect is also present for other datasets.