This New Technique Called Distillation Can Vastly Speed Up Today’s Neural Networks

Today’s machine learning requires a large amount of data, and the whole deep learning community is trying to cut training time. To that end, researchers at Google recently proposed a way to train large-scale distributed neural networks using a method called online distillation. Google has also announced the ambitious aim of building a ‘One Model to Rule Them All’: a universal model that can take in various inputs and learn abstract representations and tasks.

Even though the aim is to speed up the training time, it is important to note that the goal should not be achieved by compromising accuracy. This research is remarkable since it doesn’t add any complexity to the already available training procedure. Increasing the available resources such as data, infrastructure and computing power, among others, can help us improve our performance. But there is a limit to this.

Dominant Training Procedures

Distributed stochastic gradient descent (SGD), in both its synchronous and asynchronous forms, has emerged as the dominant algorithm for large-scale neural network training in a distributed setting. As previously mentioned, we can add resources to train better models, but there are diminishing returns after a certain point. These diminishing returns limit the scalability of the SGD algorithm.
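To make the synchronous form concrete, here is a minimal sketch of one synchronous SGD step. It is an illustration only: the function name `sync_sgd_step`, the one-parameter model y = w * x, the toy data and the four "workers" (plain Python lists standing in for machines) are all assumptions, not details from the research.

```python
def sync_sgd_step(w, shards, lr):
    """One synchronous step: every worker computes a gradient on its own
    shard, the gradients are averaged, and a single update is applied."""
    grads = []
    for shard in shards:  # each shard stands in for one worker (assumption)
        # Gradient of the mean squared error for the 1-D model y = w * x.
        g = sum(2 * (w * x - y) * x for x, y in shard) / len(shard)
        grads.append(g)
    # Averaging is the synchronisation point: no update happens until
    # every worker has reported its gradient.
    avg_grad = sum(grads) / len(grads)
    return w - lr * avg_grad


# Toy data generated from y = 3x, split across 4 hypothetical workers.
data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0)]
shards = [data[i::4] for i in range(4)]

w = 0.0
for _ in range(200):
    w = sync_sgd_step(w, shards, lr=0.02)
```

In the asynchronous form, each worker would instead apply its own gradient as soon as it is ready, without waiting for the others.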

To improve performance, the researchers relied on a process called distillation. This process refines an n-way ensemble of models into a single, still-servable model in two phases: first the researchers use n x M machines to train an n-way ensemble with distributed SGD, and then they use M machines to train the servable student network to mimic the n-way ensemble. Because it adds another phase to the training process and uses more machines, distillation in general increases training time and complexity in return for a quality improvement. The additional training costs discourage practitioners from using ensemble distillation, even though it improves results.
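As a rough illustration of what "mimic the n-way ensemble" can mean in the second phase, the sketch below averages the class probabilities of several teachers and scores a student against that average with cross-entropy. The function names, the three-class logits and the use of a plain probability average are assumptions for illustration; the article does not specify the exact loss the researchers used.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_average(teacher_logits):
    """Average the class probabilities predicted by the n teachers."""
    probs = [softmax(logits) for logits in teacher_logits]
    n = len(probs)
    return [sum(p[k] for p in probs) / n for k in range(len(probs[0]))]


def distillation_loss(student_logits, teacher_probs):
    """Cross-entropy of the student against the ensemble's soft targets."""
    student_probs = softmax(student_logits)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


# Hypothetical logits from a 3-way ensemble on one example with 3 classes.
teachers = [[2.0, 0.5, -1.0], [1.5, 1.0, -0.5], [2.5, 0.0, -1.5]]
target = ensemble_average(teachers)

close = distillation_loss([2.0, 0.5, -1.0], target)  # student agrees with teachers
far = distillation_loss([-1.0, 0.5, 2.0], target)    # student disagrees
```

A student whose predictions resemble the ensemble's gets the lower loss, which is the signal that pushes it to mimic the ensemble during the second phase.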


As a solution to the problem above, the researchers describe a simpler online variant called codistillation, which is easier to use in practice than a multi-phase distillation training procedure. In designing the algorithm, much more emphasis was laid on making communication between the models being trained efficient.

The researchers use codistillation to refer to distillation performed:

  1. Using the same architecture for all the models
  2. Using the same dataset to train all the models
  3. Using the distillation loss during training before any model has fully converged

Codistillation Algorithm
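The three conditions listed above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the researchers' implementation: it uses two tiny logistic regression models in pure Python, and the names `codistill`, `burn_in`, `refresh` and `alpha`, the toy data, and the choice to exchange weights only every few steps are all hypothetical.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(w, x):
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))


def codistill(data, steps=300, burn_in=50, refresh=10, alpha=0.5, lr=0.1, seed=0):
    rng = random.Random(seed)
    dim = len(data[0][0])
    # Condition 1: both models share one architecture (logistic regression);
    # only the random initialisation differs.
    models = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(2)]
    stale = [list(w) for w in models]  # possibly out-of-date peer copies
    for step in range(steps):
        if step % refresh == 0:
            # Assumption: peers exchange weights only occasionally.
            stale = [list(w) for w in models]
        for i in range(2):
            w, peer = models[i], stale[1 - i]
            grad = [0.0] * dim
            for x, y in data:  # Condition 2: the same dataset for every model
                p = predict(w, x)
                err = p - y  # gradient of the label cross-entropy w.r.t. the logit
                if step >= burn_in:
                    # Condition 3: the distillation term is switched on during
                    # training, before either model has fully converged.
                    err += alpha * (p - predict(peer, x))
                for k in range(dim):
                    grad[k] += err * x[k] / len(data)
            models[i] = [wk - lr * gk for wk, gk in zip(w, grad)]
    return models


# Toy data: (bias, feature) -> label, where the label is 1 when feature > 0.
data = [((1.0, -2.0), 0), ((1.0, -1.0), 0), ((1.0, -0.5), 0),
        ((1.0, 0.5), 1), ((1.0, 1.0), 1), ((1.0, 2.0), 1)]
models = codistill(data)
```

Each model is trained on its usual label loss plus a term that pulls its predictions towards those of its peer, so no separate second phase is needed.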


Data Sets, Models And Experiments

In order to understand the scalability of distributed training using codistillation, the researchers needed a task that is representative of important large-scale neural network training problems. They chose neural language modelling, which is simple to evaluate and uses an undemanding pipeline.

Common Crawl is an open repository of web crawl data. The researchers used English-language documents with long paragraphs because they wanted data that allowed modelling of long-range dependencies. The researchers constructed batches of 32 word pieces. To begin with, the goal was to determine the maximum number of GPU workers that can be employed for SGD. The researchers tried asynchronous SGD with 32 and 128 workers and found that with a large number of workers it is difficult to keep training stable. They also found that the maximum number of GPU workers that can be productively employed for synchronous SGD depends on infrastructure limits, tail latency and batch size effects.

In the experiments, synchronous SGD with 128 workers had the strongest baseline performance in terms of training time and accuracy. Hence the researchers focussed the rest of the experiments on comparisons with 128-worker synchronous SGD and studied codistillation that uses synchronous SGD as a subroutine.

Reducing Prediction Churn

There is a general reproducibility problem in which retraining a neural network after a minor change alters its predictions. This is known as prediction churn. On large datasets with a stable training pipeline, even minor changes to model architecture can cause dramatic changes in the results. With rare features and unique examples, different networks tend to do well on different test cases. The results of a model trained by stochastic gradient descent depend heavily on initialisation, data presentation order and other infrastructure issues.
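One simple way to put a number on churn is the fraction of examples whose predicted class changes between two training runs. The sketch below assumes that definition; the function name `prediction_churn` and the example labels are hypothetical.

```python
def prediction_churn(preds_before, preds_after):
    """Fraction of examples whose predicted class differs between two runs."""
    if not preds_before or len(preds_before) != len(preds_after):
        raise ValueError("need two non-empty prediction lists of equal length")
    changed = sum(a != b for a, b in zip(preds_before, preds_after))
    return changed / len(preds_before)


# Hypothetical predictions on the same five examples before and after retraining.
before = ["cat", "dog", "dog", "cat", "bird"]
after = ["cat", "dog", "cat", "cat", "dog"]
rate = prediction_churn(before, after)  # 2 of the 5 predictions changed
```

Note that churn is measured against the previous model, not against the true labels, so two equally accurate models can still show high churn.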

Model averaging is a very natural way to reduce prediction churn. Averaging away the differences introduced by the training procedure tends to make the predictions more consistent. The researchers also claim that, because codistillation achieves many of the benefits of model averaging, it should similarly help reduce prediction churn.
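A small simulation can illustrate why averaging helps. It is a toy under loud assumptions: a "training run" is modelled as a fixed underlying score per example plus independent Gaussian noise, which stands in for differences in initialisation and data order. Nothing here comes from the researchers' experiments, and all names and constants are hypothetical.

```python
import random


def churn(a, b):
    """Fraction of examples on which two label lists disagree."""
    return sum(x != y for x, y in zip(a, b)) / len(a)


def train_once(true_scores, rng, noise=0.2):
    """Stand-in for one training run: underlying score plus run-specific noise."""
    return [s + rng.gauss(0.0, noise) for s in true_scores]


def averaged_model(true_scores, rng, k):
    """Average the scores of k independent runs."""
    runs = [train_once(true_scores, rng) for _ in range(k)]
    return [sum(col) / k for col in zip(*runs)]


def labels(scores):
    return [int(s > 0.5) for s in scores]


rng = random.Random(0)
true_scores = [rng.random() for _ in range(2000)]

# Churn between two independently "retrained" single models.
single_churn = churn(labels(train_once(true_scores, rng)),
                     labels(train_once(true_scores, rng)))

# Churn between two independently "retrained" 10-model averages.
avg_churn = churn(labels(averaged_model(true_scores, rng, 10)),
                  labels(averaged_model(true_scores, rng, 10)))
```

Under these assumptions the averaged models disagree with each other less often than the single models do, because averaging shrinks the run-specific noise that flips borderline predictions.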

Hence distillation is really useful because it can be used to accelerate training, improve quality, distribute training, communicate in efficient ways and reduce prediction churn.


Abhijeet Katte
As a thorough data geek, most of Abhijeet's day is spent in building and writing about intelligent systems. He also has deep interests in philosophy, economics and literature.