Any machine learning solution comes with two primary challenges: accuracy and training time, and there is usually a trade-off between the two. In some applications accuracy cannot be traded away, as with self-driving cars. This is part of the reason such cars are not yet common on the road: engineers are still training their models with a vast number of features. In other applications, however, long hours of training time simply can't be spared.
So, what if we could look at a network's weights alone, estimate how accurate it is, and use that to decide whether a model is worth training at all?
This could dramatically reduce the computational cost of an ML pipeline.
Researchers from Google Brain propose a formal setting that frames this task: predicting the accuracy of a neural network given only its weights.
Watching The Weights
In their paper, the authors show how this framework can yield rich insights from studying the weights. The premise of their experiment is to find out whether the accuracy of a neural network can be inferred from its parameters alone. The task is treated as a supervised learning problem: the authors collected 80k trained networks and learned to predict the behaviour of these networks from their weights.
To set up the framework, the authors consider a fixed, unknown data-generating distribution P(X, Y), where X consists of images and Y contains the corresponding class labels.
As shown in the illustration above, the nodes of this setting are the hyperparameters λ, the weights W, and the expected accuracy Acc_P(W) of the network on data drawn from P.
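Because P is unknown, Acc_P(W) can only be estimated from a finite sample. The sketch below illustrates this as a Monte Carlo estimate; the classifier `predict` (a toy linear model) and the synthetic data standing in for draws from P are hypothetical, not part of the paper's setup.

```python
import numpy as np

def predict(W: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Hypothetical classifier: a linear model with weight matrix W (features x classes)."""
    return np.argmax(X @ W, axis=1)

def empirical_accuracy(W: np.ndarray, X: np.ndarray, Y: np.ndarray) -> float:
    """Estimate Acc_P(W) = E_{(x, y) ~ P}[predict(W, x) == y]
    from samples (X, Y) assumed to be drawn i.i.d. from P."""
    return float(np.mean(predict(W, X) == Y))

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))         # synthetic inputs standing in for images
true_W = rng.normal(size=(5, 3))
Y = np.argmax(X @ true_W, axis=1)      # labels produced by a "true" linear rule

print(empirical_accuracy(true_W, X, Y))                    # 1.0 by construction
print(empirical_accuracy(rng.normal(size=(5, 3)), X, Y))   # some lower value
```

The more samples are used, the closer the estimate gets to the true expected accuracy; the paper's estimators try to predict this quantity without touching the data at all.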
The hyperparameters λ include choices such as the activation function, learning rate, initialisation variance, weight regularisation and the fraction of the training set used.
For the experiments, each convolutional neural network is trained with hyperparameters λ, producing a weight vector W: a flattened vector containing all of the network's weights.
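Flattening can be sketched with NumPy as below. The layer list and its shapes are hypothetical and only illustrate the idea; the key assumption is that every network is flattened in the same fixed order, so that position i of W refers to the same parameter in every network and W can be used as a feature vector.

```python
import numpy as np

# Hypothetical per-layer parameter arrays of a tiny network (shapes are illustrative).
layers = [
    np.ones((3, 3, 1, 4)),   # conv kernel: height x width x in_channels x filters
    np.zeros(4),             # conv bias
    np.full((4, 10), 0.5),   # dense kernel
    np.zeros(10),            # dense bias
]

def flatten_weights(params):
    """Concatenate every parameter array, in a fixed order, into one flat vector W."""
    return np.concatenate([p.ravel() for p in params])

W = flatten_weights(layers)
print(W.shape)  # (90,)
```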
As a first step, the authors study convolutional neural networks (CNNs) trained in the under-parameterised regime.
The CNNs trained on each of the four datasets are split so that 15k networks form the training split and the remaining ones are held out as the test split. All training and hyperparameter selection for the estimators took place on the training split, using 3-fold cross-validation.
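The split protocol can be sketched as index bookkeeping. This is an illustration only: the random shuffle and the collection size of 20,000 networks are assumptions, while the 15k training split and the 3 folds come from the text. The cross-validation folds are drawn from the training split only, so the test split is never touched during model selection.

```python
import numpy as np

def train_test_split_indices(n_models, n_train=15_000, seed=0):
    """Shuffle model indices; the first n_train form the training split, the rest the test split."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_models)
    return idx[:n_train], idx[n_train:]

def k_fold(train_idx, k=3):
    """Yield (fit, validation) index pairs for k-fold cross-validation on the training split."""
    folds = np.array_split(train_idx, k)
    for i in range(k):
        fit = np.concatenate([folds[j] for j in range(k) if j != i])
        yield fit, folds[i]

# Hypothetical collection of 20,000 CNNs trained on one dataset.
train_idx, test_idx = train_test_split_indices(20_000)
print(len(train_idx), len(test_idx))  # 15000 5000
for fit, val in k_fold(train_idx):
    print(len(fit), len(val))         # 10000 5000 for each of the 3 folds
```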
For the experiments, the authors explore three different estimators: a logit-linear model (L-Linear), a gradient boosting machine using regression trees (GBM), and a fully connected deep neural network. These estimators are trained to minimise the mean squared error between predicted and actual accuracy.
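Of the three estimators, the logit-linear model is the simplest to write down: a linear function of the input features passed through a sigmoid, fitted by minimising mean squared error. The sketch below is a from-scratch NumPy version using plain gradient descent on synthetic data; the optimiser, learning rate and data are assumptions for illustration, not the authors' training setup, and the GBM and DNN estimators are not shown.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_logit_linear(X, y, lr=1.0, steps=3000):
    """Fit y ≈ sigmoid(X @ w + b) by full-batch gradient descent on mean squared error."""
    n, d = X.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = sigmoid(X @ w + b)
        # Gradient of mean((p - y)**2) with respect to the logit z = X @ w + b.
        g = 2.0 * (p - y) * p * (1.0 - p) / n
        w -= lr * (X.T @ g)
        b -= lr * g.sum()
    return w, b

# Synthetic stand-in data: 4 features per "network" and accuracies in (0, 1).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
w_true = np.array([0.5, -0.3, 0.2, 0.0])
y = sigmoid(X @ w_true + 0.1)

w, b = fit_logit_linear(X, y)
mse = float(np.mean((sigmoid(X @ w + b) - y) ** 2))
print(f"training MSE: {mse:.2e}")
```

The sigmoid keeps every prediction inside (0, 1), which matches the range of an accuracy.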
To make predicting accuracy from the flattened weight vector feasible, the architecture is kept small: 3 convolutional layers with 16 filters each, followed by global average pooling and a fully connected layer, for a total of 4970 learnable weights.
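The figure of 4970 can be checked with simple arithmetic. The article does not state the kernel size, input channels or number of classes, so the sketch assumes 3x3 kernels, single-channel (grayscale) input, a bias in every layer and 10 output classes; under these assumptions the count comes out to exactly 4970.

```python
def conv_params(kernel, in_ch, out_ch):
    """Weights plus one bias per filter for a 2-D convolution."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_features, out_features):
    """Weights plus one bias per output unit for a fully connected layer."""
    return in_features * out_features + out_features

# Assumptions (not stated in the article): 3x3 kernels, 1 input channel, 10 classes.
layer_counts = {
    "conv1": conv_params(3, 1, 16),    # 160
    "conv2": conv_params(3, 16, 16),   # 2320
    "conv3": conv_params(3, 16, 16),   # 2320
    # Global average pooling has no parameters and leaves 16 features.
    "dense": dense_params(16, 10),     # 170
}
total = sum(layer_counts.values())
print(total)  # 4970
```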
According to the authors, networks with this architecture achieve accuracies of 55% on CIFAR10 and 75% on SVHN.
Rather than training each CNN until it converges or reaches a target accuracy, the authors train every network for a fixed 18 epochs, so that CNNs are studied under general conditions.
Finally, models in which numerical instabilities were detected are discarded.
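The article does not say how instabilities were detected. One plausible criterion, assumed here purely for illustration, is to drop any run whose weights or measured accuracy contain NaN or infinite values, which is what diverged training typically produces. The helper `is_numerically_stable` and the example runs are hypothetical.

```python
import numpy as np

def is_numerically_stable(W: np.ndarray, acc: float) -> bool:
    """Hypothetical check: a run is unstable if any weight or its accuracy is NaN or inf."""
    return bool(np.all(np.isfinite(W)) and np.isfinite(acc))

# Hypothetical (weights, accuracy) pairs for four trained networks.
runs = [
    (np.array([0.1, -0.2, 0.3]), 0.52),
    (np.array([0.1, np.nan, 0.3]), 0.10),   # NaN weight: training diverged
    (np.array([np.inf, 0.0, 0.0]), 0.10),   # overflowed weight
    (np.array([0.0, 0.4, -0.1]), 0.61),
]
kept = [(W, acc) for W, acc in runs if is_numerically_stable(W, acc)]
print(len(kept))  # 2
```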
The main contributions and findings of this paper are:
- A new formal setting that frames the task and relates it to previous work.
- A new, large dataset of trained CNNs, together with strong baselines and extensive empirical results.
- Evidence that a network's accuracy can be predicted from its trained weights alone.
- Evidence that just a few statistics of the weights are sufficient for highly accurate predictions.
- Evidence that neural networks trained on an unknown dataset can be ranked just by observing their trained weights, without ever having access to the dataset itself.
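The last two findings can be illustrated with a small sketch. The statistics chosen here (mean, variance and five quantiles per layer) are an assumed, hypothetical feature set that turns any network into a short fixed-length vector; ranking quality is measured with Kendall's tau, one common rank-correlation measure, shown here without tie correction. The predicted and actual accuracies in the usage lines are made-up toy values.

```python
import numpy as np

QUANTILES = (0.0, 0.25, 0.5, 0.75, 1.0)

def layer_statistics(layer: np.ndarray) -> np.ndarray:
    """Summarise one layer's weights by mean, variance and a few quantiles."""
    flat = layer.ravel()
    return np.concatenate([[flat.mean(), flat.var()], np.quantile(flat, QUANTILES)])

def weight_features(layers) -> np.ndarray:
    """Concatenate per-layer statistics into a short, fixed-length feature vector."""
    return np.concatenate([layer_statistics(layer) for layer in layers])

def kendall_tau(a, b) -> float:
    """Kendall rank correlation (tau-a, no tie correction) between two score lists."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    n = len(a)
    s = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            s += np.sign(a[i] - a[j]) * np.sign(b[i] - b[j])
    return s / (n * (n - 1) / 2)

# Hypothetical network: conv kernel, conv bias, dense kernel, dense bias.
rng = np.random.default_rng(0)
layers = [rng.normal(size=s) for s in [(3, 3, 1, 16), (16,), (16, 10), (10,)]]
print(weight_features(layers).shape)  # (28,): 7 statistics x 4 layers

# Toy predicted vs. actual accuracies for four networks (same ordering).
predicted = [0.41, 0.55, 0.38, 0.62]
actual = [0.40, 0.52, 0.35, 0.60]
print(kendall_tau(predicted, actual))  # 1.0
```

A tau of 1.0 means the predicted accuracies order the networks exactly as their true accuracies do, which is all that is needed to pick the best model, even if the absolute predictions are off.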
Weights are among the most important characteristics of any deep neural network, and any insights drawn from them in the early steps of an experiment can reduce its overall cost.