How To Avoid Overfitting In Neural Networks

Bhoomika Madhukar

Deep neural networks have a very large number of trainable parameters. This capacity lets them fit many kinds of datasets, which is what makes them so powerful. Sometimes, however, this same power becomes a weakness: a network with enough capacity can memorize individual training examples instead of learning general patterns, so it performs well on the training data but poorly on the test data. This is called overfitting.

Overfitting occurs when the model learns the noise in the training data along with the underlying signal. An overfitted model makes inaccurate predictions on new data because the trend it has learned does not reflect the true relationship in the data. A few techniques can be used to overcome this.
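To make the idea concrete outside of neural networks, here is a toy sketch (our own illustration, not part of the original experiment) that fits polynomials to 10 noisy samples of a sine curve. The sine is the signal and the Gaussian term is the noise; a degree-9 polynomial has enough coefficients to pass through every training point, so it memorizes the noise. The seed, noise level and degrees are arbitrary assumptions.

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(42)

def noisy_sine(x):
    """Signal (a sine curve) plus Gaussian noise with standard deviation 0.2."""
    return np.sin(2 * np.pi * x) + rng.normal(scale=0.2, size=x.shape)

x_train = np.linspace(0.0, 1.0, 10)   # only 10 training points
y_train = noisy_sine(x_train)
x_test = rng.uniform(0.0, 1.0, 200)   # fresh points drawn from the same process
y_test = noisy_sine(x_test)

def mse(model, x, y):
    """Mean squared error of a fitted polynomial on (x, y)."""
    return float(np.mean((model(x) - y) ** 2))

errors = {}
for degree in (3, 9):
    # Degree 9 has 10 coefficients: enough to pass through all 10 training points
    model = Polynomial.fit(x_train, y_train, degree)
    errors[degree] = (mse(model, x_train, y_train), mse(model, x_test, y_test))
    print(f"degree {degree}: train MSE {errors[degree][0]:.4f}, test MSE {errors[degree][1]:.4f}")
```

The degree-9 fit ends up with essentially zero training error but a clearly larger test error, which is the same signature to look for in a neural network.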

In this article, we will:

  1. Learn the different techniques to avoid overfitting of the model
  2. Implement these techniques in a deep learning model

Methods to Avoid Overfitting of a Model 

You can tell that your model is overfitting when it performs well on the training data but poorly on new, unseen data. You can also diagnose the problem through the bias-variance trade-off: an overfitted model typically has low bias but high variance. So how do you fix it? Here are some techniques you can use to overcome overfitting in your neural network.

  1. Data Augmentation: Collecting more, and more varied, data is the most direct way to avoid overfitting. Data augmentation increases the size of your dataset by applying transformations such as flipping, cropping, rotation, scaling and translation to the existing images. Besides enlarging the dataset, it exposes the model to different angles and conditions and reduces the bias in the dataset, which lowers the chance of overfitting.
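As a minimal sketch of what augmentation does under the hood, the numpy code below applies a random horizontal flip and a random translation of up to 2 pixels (both arbitrary assumptions) to each image and appends the augmented copies to the dataset. The function names `augment_image` and `augment_dataset` are hypothetical. Note that flips suit natural photos but can change the meaning of handwritten digits, which is why the MNIST example later in this article disables them.

```python
import numpy as np

def augment_image(image, rng, max_shift=2):
    """Return a randomly flipped and translated copy of a 2-D image.

    Rotation, cropping and scaling are omitted to keep the sketch short.
    """
    out = np.fliplr(image) if rng.random() < 0.5 else image   # random horizontal flip
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)  # random shift in pixels
    out = np.roll(out, shift=(dy, dx), axis=(0, 1))            # np.roll returns a new array
    # np.roll wraps pixels around the border; blank the wrapped pixels instead
    if dy > 0:
        out[:dy, :] = 0
    elif dy < 0:
        out[dy:, :] = 0
    if dx > 0:
        out[:, :dx] = 0
    elif dx < 0:
        out[:, dx:] = 0
    return out

def augment_dataset(images, copies=3, seed=0):
    """Append `copies` augmented versions of every image to the original set."""
    rng = np.random.default_rng(seed)
    extra = [augment_image(img, rng) for img in images for _ in range(copies)]
    return np.concatenate([images, np.stack(extra)])

images = np.zeros((5, 28, 28))
images[:, 12:16, 12:16] = 1.0   # a small bright square in the middle of each image
bigger = augment_dataset(images)
print(images.shape, "->", bigger.shape)   # (5, 28, 28) -> (20, 28, 28)
```

Each original image now has three slightly different variants, so the model sees the same content in more positions and orientations than the raw data provides.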

2. Regularization Techniques: This method adds an extra term to the loss function. The extra term acts as a critic that penalizes the model for using larger weights than needed. Because the penalty grows with the size of the weights, it limits the flexibility of the model. The two most popular regularization methods are L1 and L2 regularization. L1 regularization penalizes the sum of the absolute values of the weights; it tends to drive the weights of less important features to zero, effectively eliminating those features from further calculations.

L2 regularization penalizes the sum of the squared weights. Because the penalty grows with the square of each weight, large weights are penalized much more heavily than small ones, so L2 shrinks all weights toward zero but rarely makes any of them exactly zero. Unlike L1, it therefore does not remove features; it spreads the weight across them instead.

Increasing the penalty coefficient pushes the optimizer toward smaller weights, which reduces overfitting. If the coefficient is set too high, however, the model becomes too constrained and underfits.
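The difference between the two penalties shows up clearly in how they move the weights. The numpy sketch below (the coefficient, learning rate and starting weights are arbitrary assumptions) applies each penalty on its own, with no data loss, for 100 steps. For L1 it uses a proximal (soft-thresholding) step, a standard way to apply L1 that makes weights hit exactly zero; plain gradient descent, as used when Keras trains a model with an L1 regularizer, only drives them close to zero. For L2 it uses an ordinary gradient step, since the gradient of λw² is 2λw.

```python
import numpy as np

def l1_penalty(w, lam):
    """L1 penalty: lam * sum(|w|)."""
    return lam * np.sum(np.abs(w))

def l2_penalty(w, lam):
    """L2 penalty: lam * sum(w ** 2)."""
    return lam * np.sum(w ** 2)

def l1_step(w, lam, lr):
    """Proximal step for the L1 penalty alone: every weight moves toward
    zero by lr * lam and stops at exactly zero instead of crossing it."""
    return np.sign(w) * np.maximum(np.abs(w) - lr * lam, 0.0)

def l2_step(w, lam, lr):
    """Gradient step for the L2 penalty alone: the gradient is 2 * lam * w,
    so every weight shrinks by the same factor and never reaches zero."""
    return w - lr * 2 * lam * w

w = np.array([0.05, -0.5, 2.0])
w_l1, w_l2 = w.copy(), w.copy()
for _ in range(100):
    w_l1 = l1_step(w_l1, lam=0.01, lr=0.1)
    w_l2 = l2_step(w_l2, lam=0.01, lr=0.1)

print("L1:", w_l1)
print("L2:", w_l2)
```

Under L1, the small first weight reaches exactly zero after about 50 steps and the others drop by a fixed 0.1 in total. Under L2, all three weights are multiplied by the same factor (about 0.82 after 100 steps), so the small weight shrinks but never disappears.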

3. Dropout: Regularization techniques prevent overfitting by modifying the cost function. Dropout, on the other hand, prevents overfitting by modifying the network itself. It works as follows. Every neuron apart from those in the output layer is assigned a probability p of being temporarily ignored during training. p is called the dropout rate, and 0.5 is a common choice for hidden layers. At each training iteration, every neuron is dropped independently at random with probability p, so each iteration effectively trains a different, smaller sub-network. At inference time, dropout is switched off and the full network is used. Because any neuron can disappear at any iteration, the network cannot rely on a few specific neurons or features; it has to spread what it learns across many of them, which makes it less sensitive to noise and reduces overfitting.
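Here is a minimal numpy sketch of dropout as applied to one layer's activations. It uses the "inverted" formulation that the Keras `Dropout` layer documents: units that survive are scaled by 1 / (1 - rate) during training, so the expected activation stays the same and nothing needs to change at inference. The function name and the 1,000-unit layer are hypothetical.

```python
import numpy as np

def dropout(activations, rate, rng, training=True):
    """Inverted dropout: zero each unit independently with probability `rate`
    during training and scale the survivors by 1 / (1 - rate).
    At inference (training=False) the input passes through unchanged."""
    if not training or rate == 0.0:
        return activations
    keep = rng.random(activations.shape) >= rate   # True with probability 1 - rate
    return activations * keep / (1.0 - rate)

rng = np.random.default_rng(0)
h = np.ones(1000)                                  # activations of a hypothetical hidden layer
h_train = dropout(h, rate=0.5, rng=rng)            # roughly half become 0.0, the rest 2.0
h_infer = dropout(h, rate=0.5, rng=rng, training=False)
print(np.mean(h_train == 0), h_train.mean(), h_infer.mean())
```

Calling `dropout` again draws a new random mask, which is exactly why every training iteration sees a different sub-network.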

4. Early Stopping: Early stopping prevents over-training and hence overfitting of the model. An over-trained model tends to memorize the training data points. With early stopping, you specify a large, somewhat arbitrary number of training epochs, and training is halted as soon as the model's performance stops improving on the validation dataset.

[Figure: training and test error over iterations; training is stopped where the test error begins to rise]

As you can see in the above figure, after some iterations, test error has started to increase while the training error is still decreasing. So the training is stopped early to prevent the model from overfitting. 
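The stopping rule itself is simple. The plain-Python sketch below replays a list of per-epoch validation losses and reports when training would stop. It implements a simplified patience rule of our own (stop after `patience` consecutive epochs without an improvement larger than `min_delta`); the Keras `EarlyStopping` callback implements its own variant with similarly named parameters, so treat this as an illustration rather than a copy of it. The loss values are made up.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Replay per-epoch validation losses and return (stop_epoch, best_epoch),
    both 1-based. Training stops after `patience` consecutive epochs in which
    the loss fails to improve on the best value by more than `min_delta`."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch   # new best: reset the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch            # ran out of epochs without stopping

# Hypothetical validation losses: they improve until epoch 4, then rise
losses = [0.9, 0.6, 0.45, 0.40, 0.42, 0.44, 0.47, 0.50, 0.55]
print(early_stopping(losses, patience=3))   # (7, 4)
```

Training stops at epoch 7, three epochs after the best validation loss at epoch 4; keeping the weights from epoch 4 is what "restoring the best weights" refers to.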

Implementation of Techniques to Avoid Overfitting

Let us now apply all of the above techniques to a neural network model. To keep things simple, we will use MNIST, a small, well-known dataset on which a model can easily overfit, and use the techniques above to try to avoid that.

We will import the required libraries and load our dataset. 

import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, Activation, Flatten, BatchNormalization, Conv2D, MaxPooling2D
from tensorflow.keras.utils import to_categorical
# ImageDataGenerator is deprecated in recent TensorFlow releases but still illustrates the idea
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras.datasets import mnist
%matplotlib inline

# Load MNIST: 28x28 grayscale images of handwritten digits and their labels
(X_train, y_train), (X_test, y_test) = mnist.load_data()
plt.imshow(X_train[4], cmap='gray')

Before adding data augmentation, we will pre-process the data by reshaping the inputs, normalizing them and converting the targets into categorical (one-hot) values.

# Add a channel axis: (n, 28, 28) -> (n, 28, 28, 1)
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
# Scale pixel values from [0, 255] to [0, 1]
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# One-hot encode the 10 digit classes
Y_train = to_categorical(y_train, 10)
Y_test = to_categorical(y_test, 10)

Data Augmentation

I have used Keras's built-in ImageDataGenerator class to augment the dataset, but you can also use other tools, such as the Albumentations library.

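augment = ImageDataGenerator(featurewise_center=True,             # subtract the training-set mean
                             featurewise_std_normalization=True,  # divide by the training-set std
                             rotation_range=50,                   # random rotations of up to 50 degrees
                             width_shift_range=0.01,              # horizontal shifts of up to 1% of the width
                             height_shift_range=0.01,             # vertical shifts of up to 1% of the height
                             horizontal_flip=False,
                             vertical_flip=False)
# Compute the mean and standard deviation required by the featurewise options
augment.fit(X_train)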

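Here, I have added random rotations and small shifts to produce the augmented images, together with featurewise centering and standard normalization, which standardize every image using the mean and standard deviation of the training set. Horizontal and vertical flips are explicitly disabled, since flipping a handwritten digit can turn it into a different or invalid symbol.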

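# Draw one augmented version of each of five training images and display them side by side
aug = augment.flow(X_train[1:6], batch_size=1, shuffle=False)
for i in range(1, 6):
    plt.subplot(1, 5, i)
    plt.axis("off")
    image = next(aug)[0]  # one augmented image of shape (28, 28, 1)
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()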
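The featurewise options are worth a closer look, because the statistics they use must come from the training set and then be reused, unchanged, for any validation or test data. The numpy sketch below mirrors what we understand `ImageDataGenerator.fit` and featurewise standardization to compute: a per-channel mean and standard deviation over all images, rows and columns. The function names, the random stand-in images and the epsilon of 1e-6 are assumptions for illustration.

```python
import numpy as np

def fit_featurewise(x):
    """Per-channel mean and std over images, rows and columns; x has shape (n, h, w, c)."""
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    std = x.std(axis=(0, 1, 2), keepdims=True)
    return mean, std

def standardize(x, mean, std, eps=1e-6):
    """Centre and scale x using statistics computed on the training set."""
    return (x - mean) / (std + eps)

rng = np.random.default_rng(0)
x_train = rng.random((100, 28, 28, 1))   # stand-ins for training images scaled to [0, 1]
x_test = rng.random((20, 28, 28, 1))     # stand-ins for test images

mean, std = fit_featurewise(x_train)          # statistics come from the training set only
x_train_std = standardize(x_train, mean, std)
x_test_std = standardize(x_test, mean, std)   # the same statistics are reused for the test set
print(x_train_std.mean(), x_train_std.std())
```

The standardized training data has mean 0 and standard deviation 1 by construction; the test data comes out close to, but not exactly, those values, which is expected when the statistics are not recomputed on it.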

Dropout and Regularization

model = Sequential()
# 28x28x1 -> 26x26x16
model.add(Conv2D(16, (3, 3), input_shape=(28, 28, 1), kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Activation('relu'))
# 26x26x16 -> 24x24x16
model.add(Conv2D(16, (3, 3), kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Activation('relu'))
# 1x1 convolution: 24x24x16 -> 24x24x10
model.add(Conv2D(10, (1, 1), kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))  # 24x24x10 -> 12x12x10
# 12x12x10 -> 10x10x8
model.add(Conv2D(8, (3, 3), kernel_regularizer=l2(0.01)))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Activation('relu'))
# A 10x10 kernel collapses the 10x10x8 map to 1x1x10, so Flatten yields exactly 10 class scores
# (a 4x4 kernel here would leave a 7x7x10 map, i.e. 490 outputs instead of 10)
model.add(Conv2D(10, (10, 10), kernel_regularizer=l2(0.01)))
model.add(Flatten())
model.add(Activation('softmax'))

The model above includes both dropout and L2 regularization. For this implementation, I have used L2 regularization because its penalty is smooth (differentiable everywhere), which usually makes it easier to optimize with gradient descent than L1.

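As you can see above, Keras provides built-in support for both techniques. The kernel_regularizer argument receives an L2 regularizer, and the value 0.01 is the regularization factor: for each regularized layer, Keras adds 0.01 times the sum of the squared kernel weights to the loss, so the model pays a price whenever it uses larger weights than it needs. Similarly, each Dropout layer uses a rate of 0.5, meaning that during training each input unit of the layer is set to zero with probability 0.5. Keras does not set a default rate, but 0.5 is a common choice for hidden layers.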
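To see the size of that penalty, the numpy sketch below reproduces the formula Keras documents for its L2 regularizer, `factor * sum(square(weights))`, on a randomly initialized kernel with the shape of the model's first Conv2D layer (3x3, 1 input channel, 16 filters). The weights and the 0.30 data loss are hypothetical values chosen only to show the arithmetic; in a real model, Keras adds one such penalty per regularized layer to the training loss.

```python
import numpy as np

def l2_penalty(kernel, factor=0.01):
    """L2 penalty in the form Keras documents for its l2 regularizer: factor * sum(kernel ** 2)."""
    return factor * np.sum(np.square(kernel))

rng = np.random.default_rng(0)
# Hypothetical weights shaped like the first Conv2D kernel: (height, width, in_channels, filters)
kernel = rng.normal(scale=0.1, size=(3, 3, 1, 16))

data_loss = 0.30                              # hypothetical cross-entropy for one batch
total_loss = data_loss + l2_penalty(kernel)   # this layer's penalty is added to the data loss
print(f"penalty = {l2_penalty(kernel):.4f}, total loss = {total_loss:.4f}")
```

Because the penalty is quadratic, doubling every weight quadruples it, which is why the optimizer is strongly discouraged from letting any single weight grow large.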

Early Stopping

The last technique is to add early stopping to the model so that it stops the model from overtraining. 

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping()

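Early stopping is passed as a callback when the model is fitted on the training and validation sets. If no arguments are given, the defaults are used: the callback monitors the validation loss (val_loss) with patience=0, so no extra epochs are allowed once the validation loss stops improving. You can customize this behaviour, for example with EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True), which tolerates three epochs without improvement before stopping and then restores the weights from the best epoch.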

Now, it is time to compile and fit the model. 

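model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
batch_size = 32
# model.fit accepts the augmentation iterator directly (fit_generator is deprecated)
history = model.fit(augment.flow(X_train, Y_train, batch_size=batch_size),
                    steps_per_epoch=X_train.shape[0] // batch_size,
                    epochs=40, verbose=1,
                    # Standardize the validation images with the training-set statistics,
                    # as the generator does for the training images (no augmentation applied)
                    validation_data=(augment.standardize(X_test.copy()), Y_test),
                    callbacks=[early_stopping])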

I have set the number of epochs to 40. A model like this can typically be trained in about 10 epochs, so 40 is deliberately generous, and early stopping is expected to end training well before that.

[Figure: training log showing the run halted by early stopping]

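As you can see, the model automatically stopped training after 4 epochs. By default, EarlyStopping monitors the validation loss, so training ended because the validation loss stopped improving; in this run, that coincided with the validation accuracy starting to decrease while the training accuracy kept increasing. In other words, early stopping prevented the model from overfitting after just 4 epochs.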

We can plot the graph of this to understand better. 

def graph_plot(history, metric):
    # Plot a training metric and its validation counterpart for each epoch
    train = history.history[metric]
    validation = history.history['val_' + metric]
    epochs = range(1, len(train) + 1)
    plt.plot(epochs, train, label='training')
    plt.plot(epochs, validation, label='validation')
    plt.title('Training and validation ' + metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend()
    plt.show()

graph_plot(history, 'accuracy')
[Figure: training and validation accuracy per epoch]

The graph shows that early stopping halted training at the point where the validation accuracy dropped steeply.

Conclusion

Recognizing and eliminating overfitting is a key skill for any machine learning engineer. The goal of this article was to give a brief overview of the methods available for reducing overfitting, so that the models you build generalize well to real-world data.

Copyright Analytics India Magazine Pvt Ltd
