Transfer Learning For Multi-Class Image Classification Using Deep Convolutional Neural Network

In this article, we will implement multiclass image classification using the VGG-19 deep convolutional network as a transfer learning framework, with the VGGNet pre-trained on the ImageNet dataset. For the experiment, we will use the CIFAR-10 dataset and classify the images into 10 classes. The classification accuracy of the VGG-19 model will be visualized using non-normalized and normalized confusion matrices.

Image classification has attracted growing research interest due to the development of new, high-performing machine learning frameworks. Advances in artificial neural networks, and the development of deep learning architectures built on them such as the convolutional neural network, have driven the application of multiclass image classification and the recognition of objects belonging to multiple categories. Each new generation of machine learning frameworks typically offers advantages over older ones in terms of performance and complexity.


What is Transfer Learning?

Transfer learning is a research problem in machine learning that focuses on storing the knowledge gained while solving one problem and applying it to a different but related problem. For example, the knowledge gained while learning to recognize cats could apply when trying to recognize cheetahs. In deep learning, transfer learning is a technique whereby a neural network model is first trained on a problem similar to the one being solved, and one or more layers of the trained model are then reused in a new model trained on the problem of interest. Transfer learning can decrease the training time of a model and can result in lower generalization error.

VGGNet – The Deep Convolutional Network

VGGNet is a deep convolutional neural network that was proposed by Karen Simonyan and Andrew Zisserman of the University of Oxford in their research work ‘Very Deep Convolutional Networks for Large-Scale Image Recognition’. The name of the model comes from their research group, the ‘Visual Geometry Group (VGG)’. VGG-19 is named for its 19 weight layers: 16 convolutional layers and 3 fully connected layers. A key design choice of the model is the use of very small (3×3) convolution filters throughout, so that the network can be made deeper while each stack of small filters needs fewer parameters than a single larger filter with the same receptive field. Below is the block diagram of VGG-19 that illustrates its architecture.

VGG 19 deep convolutional neural network

(Image source: mc.ai)
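To make the parameter argument concrete, here is a rough weight count (biases ignored) for C input and C output channels, with C = 256 chosen only as an example value: two stacked 3×3 convolutions cover the same 5×5 receptive field as a single 5×5 convolution, but need 18C² weights instead of 25C². The helper name conv_weights is hypothetical.

```python
def conv_weights(kernel_size, in_channels, out_channels):
    """Number of weights in a 2-D convolution layer (biases ignored)."""
    return kernel_size * kernel_size * in_channels * out_channels

C = 256  # example channel count, chosen only for illustration

# Two stacked 3x3 convolutions together see a 5x5 region of the input
stacked_3x3 = 2 * conv_weights(3, C, C)
# A single 5x5 convolution with the same receptive field
single_5x5 = conv_weights(5, C, C)

print(stacked_3x3, single_5x5)              # 1179648 1638400
print(round(stacked_3x3 / single_5x5, 2))   # 0.72
```

The stacked version also inserts an extra non-linearity between the two convolutions, which is another reason VGG favours small filters.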

The biggest advantage of this network is that a version pre-trained on more than a million images from the ImageNet database can be loaded directly. This pre-trained network can classify images into 1,000 object categories. Because of this advantage, we apply the model to the CIFAR-10 image dataset, which has 10 object categories.

The Dataset

In this experiment, we will use the CIFAR-10 dataset, a publicly available image dataset provided by the Canadian Institute for Advanced Research (CIFAR). It consists of 60000 32×32 colour images in 10 classes, with 6000 images per class. The 10 classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 50000 training images and 10000 test images in this dataset.

Implementation in Python

We will import the library to download the CIFAR-10 data set.

#Keras library for CIFAR-10 dataset
from keras.datasets import cifar10


#Downloading the CIFAR dataset
(x_train,y_train),(x_test,y_test)=cifar10.load_data()

We will import the remaining libraries that are going to be required in our experiment.

#importing other required libraries
import numpy as np
import pandas as pd
from sklearn.utils.multiclass import unique_labels
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from keras import Sequential
from keras.applications import VGG19 #For Transfer Learning
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD,Adam
from keras.callbacks import ReduceLROnPlateau
from keras.layers import Flatten,Dense,BatchNormalization,Activation,Dropout
from keras.utils import to_categorical

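Next, we split the original training set into a training set and a validation set; the test set provided with CIFAR-10 is kept as it is.

#Splitting the original training set into training (70%) and validation (30%) sets
x_train,x_val,y_train,y_val=train_test_split(x_train,y_train,test_size=.3)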

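Once the data is split, we check its shape. Together, the training and validation sets should contain the 50000 training images given in the dataset description on its website, and the test set should contain the 10000 test images.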

#Dimension of the dataset
print((x_train.shape,y_train.shape))
print((x_val.shape,y_val.shape))
print((x_test.shape,y_test.shape))
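The expected sizes can be worked out by hand. The sketch below roughly mirrors scikit-learn's rule for a float test_size (the held-out part is rounded up, the remainder goes to training); the helper name split_sizes is hypothetical.

```python
import math

def split_sizes(n_samples, test_size):
    """Approximate scikit-learn's sizing for a float test_size:
    the held-out part is rounded up, the rest is used for training."""
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test

n_train, n_val = split_sizes(50000, 0.3)
print(n_train, n_val)  # 35000 15000
```

So the printed shapes should be (35000, 32, 32, 3) and (35000, 1) for the training set, (15000, 32, 32, 3) and (15000, 1) for the validation set, and (10000, 32, 32, 3) and (10000, 1) for the test set.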


Since the labels are integer class indices from 0 to 9 and the model will be trained with categorical cross-entropy, we one-hot encode them. After encoding, shape[1] of y_train, y_val and y_test should change from 1 to 10.

#One Hot Encoding
y_train=to_categorical(y_train)
y_val=to_categorical(y_val)
y_test=to_categorical(y_test)
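What one-hot encoding does can be sketched in a few lines of NumPy. This is a minimal stand-in for keras.utils.to_categorical on integer labels, not its actual implementation; the helper name one_hot is hypothetical. Taking np.argmax along axis 1 recovers the original class indices, which is how y_true is obtained from y_test later on.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Minimal stand-in for keras.utils.to_categorical on integer labels."""
    labels = np.asarray(labels).reshape(-1)              # (n, 1) -> (n,)
    encoded = np.zeros((labels.size, num_classes), dtype="float32")
    encoded[np.arange(labels.size), labels] = 1.0        # one 1 per row
    return encoded

y = np.array([[3], [0], [9]])        # shape (3, 1), like the CIFAR-10 labels
y_encoded = one_hot(y, 10)
print(y_encoded.shape)               # (3, 10)
print(np.argmax(y_encoded, axis=1))  # [3 0 9]
```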

After one hot encoding, we will ensure that we have obtained the required shape.

#Verifying the dimension after one hot encoding
print((x_train.shape,y_train.shape))
print((x_val.shape,y_val.shape))
print((x_test.shape,y_test.shape))


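Next, we perform image data augmentation, a technique that expands the effective size of a training dataset by creating modified versions of its images (here: small rotations, horizontal flips and zooms). We first define separate ImageDataGenerator instances for the training, validation and test data and then call fit() on each of them. Note that fit() only computes data-dependent statistics, which are needed only when featurewise_center, featurewise_std_normalization or zca_whitening is enabled; with the options used here it has no effect. Augmentation is also normally applied only to the training data, so the validation and test generators could be left without random transformations.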

#Image Data Augmentation
train_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1)

val_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1)

test_generator = ImageDataGenerator(rotation_range=2, horizontal_flip= True, zoom_range=.1)

#Computing data statistics (only needed for featurewise normalization or ZCA whitening)
train_generator.fit(x_train)
val_generator.fit(x_val)
test_generator.fit(x_test)
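To show what a single augmentation does, here is a minimal NumPy sketch of a random horizontal flip, roughly what horizontal_flip=True applies to each image. It is not how ImageDataGenerator is implemented; it assumes channels-last batches of shape (batch, height, width, channels), as returned by cifar10.load_data(), and the helper name random_horizontal_flip is hypothetical.

```python
import numpy as np

def random_horizontal_flip(images, rng, p=0.5):
    """Flip each image in a batch left-right with probability p.

    images: array of shape (batch, height, width, channels)."""
    flip = rng.random(images.shape[0]) < p     # which images to flip
    out = images.copy()
    out[flip] = out[flip][:, :, ::-1, :]       # reverse the width axis
    return out

rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
augmented = random_horizontal_flip(batch, rng, p=1.0)  # p=1.0 flips every image
print(augmented.shape)  # (4, 32, 32, 3)
```

A flip changes pixel positions but not the class label, which is why it is a safe augmentation for object categories such as cars or horses.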

We will use a learning rate annealer in this experiment. A learning rate annealer decreases the learning rate when a monitored metric stops improving for a certain number of epochs. Here, we monitor the validation accuracy: if it does not improve for 3 consecutive epochs, the learning rate is multiplied by a factor of 0.01 (that is, reduced to 1% of its current value), but never below the minimum of 1e-5.

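#Learning Rate Annealer: multiply the learning rate by 0.01 if validation accuracy does not improve for 3 epochs
#(with metrics=['accuracy'], the validation metric is logged as 'val_accuracy')
lrr= ReduceLROnPlateau(monitor='val_accuracy', factor=.01, patience=3, min_lr=1e-5)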
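The plateau logic can be illustrated with a simplified pure-Python sketch. It assumes a metric to be maximized (as Keras infers for accuracy) and the Keras default min_delta of 1e-4, and it leaves out options such as cooldown; the function name reduce_lr_on_plateau and the accuracy values are hypothetical.

```python
def reduce_lr_on_plateau(val_scores, lr=1e-3, factor=0.01, patience=3,
                         min_lr=1e-5, min_delta=1e-4):
    """Simplified plateau schedule for a metric that should increase
    (e.g. validation accuracy). Returns the learning rate after each epoch."""
    best = float("-inf")
    wait = 0
    history = []
    for score in val_scores:
        if score > best + min_delta:      # improvement: remember it, reset the counter
            best = score
            wait = 0
        else:                             # no improvement this epoch
            wait += 1
            if wait >= patience:          # plateau reached: shrink the learning rate
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

# Hypothetical validation accuracies that stall for three epochs
history = reduce_lr_on_plateau([0.50, 0.60, 0.60, 0.60, 0.60, 0.70])
print([f"{lr:.0e}" for lr in history])
# ['1e-03', '1e-03', '1e-03', '1e-03', '1e-05', '1e-05']
```

With a starting learning rate of 0.001 and a factor of 0.01, the first reduction already reaches min_lr = 1e-5, so in this configuration the annealer can lower the learning rate only once.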

Next, we instantiate VGG19, pre-trained on ImageNet, as the base of our transfer learning model.

Defining VGG19 as a Deep Convolutional Neural Network
#Defining the VGG Convolutional Neural Net without its ImageNet classifier head
#(the 'classes' argument is only used when include_top=True, so it is omitted here)
base_model = VGG19(include_top = False, weights = 'imagenet', input_shape = (32,32,3))

Now we build the full classifier: a Keras Sequential model that stacks the VGG19 base and a Flatten layer, to which we will then add several dense layers.

#Adding the final layers to the above base model; the actual classification is done in the dense layers
model= Sequential()
model.add(base_model)
model.add(Flatten())

Before adding further layers, we check the output dimension of the model so far.

#Model summary
model.summary()
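As a back-of-the-envelope check of the summary, the sketch below computes the output size of the convolutional base, assuming the standard VGG19 layout: 3×3 convolutions with 'same' padding (which keep the spatial size), five 2×2 max-pooling layers with stride 2, and 512 filters in the last block. For a 32×32 input it gives 1×1×512, so Flatten produces a 512-dimensional vector, which matches the input size of 512 used for the first dense layer. The helper name vgg19_feature_size is hypothetical.

```python
def vgg19_feature_size(height, width, channels_out=512, num_pools=5):
    """Output size of VGG19's convolutional base (include_top=False).

    The 3x3 convolutions use 'same' padding, so only the five
    2x2 max-pooling layers (stride 2) shrink the spatial size."""
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return height, width, channels_out

h, w, c = vgg19_feature_size(32, 32)
print((h, w, c), h * w * c)   # (1, 1, 512) 512
```

For the original 224×224 ImageNet input, the same arithmetic gives a 7×7×512 feature map.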

#Adding the Dense layers with ReLU activations and dropout
model.add(Dense(1024,activation=('relu'),input_dim=512))
model.add(Dense(512,activation=('relu')))
model.add(Dense(256,activation=('relu')))
model.add(Dropout(.3))
model.add(Dense(128,activation=('relu')))
#model.add(Dropout(.2))
model.add(Dense(10,activation=('softmax')))

#Checking the final model summary
model.summary()
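As a sanity check on the final summary, the trainable parameters of the classification head can be counted by hand: each Dense layer has n_in × n_out weights plus n_out biases, and Dropout adds none. The sketch assumes the head described above (a Flatten output of 512 features, followed by Dense layers of 1024, 512, 256, 128 and 10 units); dense_params is a hypothetical helper.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# Flatten output size followed by the widths of the Dense layers
widths = [512, 1024, 512, 256, 128, 10]
per_layer = [dense_params(a, b) for a, b in zip(widths, widths[1:])]
print(per_layer)       # [525312, 524800, 131328, 32896, 1290]
print(sum(per_layer))  # 1215626
```

These numbers cover only the dense head; the VGG19 convolutional base contributes its own parameters on top of them.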


With the model defined, we now initialize the hyperparameters required for training and then compile the model.

#Initializing the hyperparameters
batch_size= 100
epochs=50
learn_rate=.001
sgd=SGD(learning_rate=learn_rate,momentum=.9,nesterov=False)
#Adam is defined as an alternative optimizer; only SGD is used in compile() below
adam=Adam(learning_rate=learn_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
model.compile(optimizer=sgd,loss='categorical_crossentropy',metrics=['accuracy'])
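To see what the step counts used during training correspond to, here is the arithmetic for one full pass over each set, assuming the 35,000/15,000 training/validation split produced by test_size=.3 and batch_size = 100.

```python
import math

batch_size = 100
n_train, n_val = 35000, 15000   # sizes after the 70/30 split

steps_per_epoch = n_train // batch_size            # full batches per training epoch
validation_steps = math.ceil(n_val / batch_size)   # batches to cover every validation image
print(steps_per_epoch, validation_steps)           # 350 150
```

Requesting more validation steps than this asks the generator for more batches than a single pass over the validation set contains.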

Now we start training VGG19, our deep convolutional neural network model.

#Training the model (fit() accepts generators; fit_generator() is deprecated)
model.fit(train_generator.flow(x_train, y_train, batch_size = batch_size),
          epochs = epochs,
          steps_per_epoch = x_train.shape[0]//batch_size,
          validation_data = val_generator.flow(x_val, y_val, batch_size = batch_size),
          validation_steps = x_val.shape[0]//batch_size,
          callbacks = [lrr],
          verbose = 1)

training VGG19 deep convolutional neural networks

As we can see in the above picture, the model reached a training accuracy of 99.22% and a validation accuracy of 85.41%. Next, we visualize the loss and accuracy during training.

#Plotting the training and validation loss and accuracy
f,ax=plt.subplots(2,1)

#Loss
ax[0].plot(model.history.history['loss'],color='b',label='Training Loss')
ax[0].plot(model.history.history['val_loss'],color='r',label='Validation Loss')
ax[0].legend()

#Accuracy
ax[1].plot(model.history.history['accuracy'],color='b',label='Training Accuracy')
ax[1].plot(model.history.history['val_accuracy'],color='r',label='Validation Accuracy')
ax[1].legend()
ax[1].set_xlabel('Epoch')
plt.show()

training performance of VGG19 deep convolutional neural network

We will make image class predictions through this model using the test data set.

#Making prediction (predict_classes() has been removed from recent Keras versions)
y_pred=np.argmax(model.predict(x_test),axis=1)
y_true=np.argmax(y_test,axis=1)

Performance of VGG19 – The Deep Convolutional Neural Network

Finally, we will visualize the classification performance on test data using confusion matrices. 

#Defining function for confusion matrix plot
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):

    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    #Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    #print(cm)

    fig, ax = plt.subplots(figsize=(7,7))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')


    #Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

np.set_printoptions(precision=2)

First, we look at the exact number of correct and incorrect classifications using the non-normalized confusion matrix, and then at the same results as proportions of each true class using the normalized confusion matrix.

#Computing the confusion matrix
confusion_mtx = confusion_matrix(y_true, y_pred)

#Defining the class labels
class_names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Plotting non-normalized confusion matrix
plot_confusion_matrix(y_true, y_pred, classes = class_names, title='Confusion matrix, without normalization')

confusion matrix of VGG19 
#Plotting normalized confusion matrix
plot_confusion_matrix(y_true, y_pred, classes = class_names, normalize = True, title = 'Normalized confusion matrix')

confusion matrix of VGG19 

As the confusion matrices show, when classifying images into 10 classes, the per-class accuracy of the model ranges from a minimum of 72% to a maximum of 95%. The training parameters could be tuned further and the model re-trained to look for possible improvements in classification, but the results obtained in this experiment serve as a baseline. Out of the 10 classes, the accuracy is below 80% for only 3 classes and above 90% for 5 classes.
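The per-class figures above are the diagonal entries of the normalized confusion matrix. Here is a minimal NumPy sketch of how to read them off, using the same row normalization as plot_confusion_matrix (rows are true labels); the helper name per_class_accuracy and the toy 3-class matrix are hypothetical and are not results from this experiment.

```python
import numpy as np

def per_class_accuracy(cm):
    """Row-normalize a confusion matrix (rows = true labels) and return its
    diagonal, i.e. the fraction of each true class predicted correctly."""
    cm = np.asarray(cm, dtype=float)
    normalized = cm / cm.sum(axis=1, keepdims=True)
    return np.diag(normalized)

# Toy 3-class confusion matrix (illustrative numbers only)
toy_cm = [[8, 1, 1],
          [2, 6, 2],
          [0, 1, 9]]
acc = per_class_accuracy(toy_cm)
print(acc)                    # [0.8 0.6 0.9]
print(acc.min(), acc.max())   # 0.6 0.9
```

Applied to confusion_mtx, the same helper would return the ten per-class accuracies, from which the minimum, the maximum and the number of classes above or below a threshold can be read directly.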

Dr. Vaibhav Kumar
Dr. Vaibhav Kumar is a seasoned data science professional with great exposure to machine learning and deep learning. He has good exposure to research, where he has published several research papers in reputed international journals and presented papers at reputed international conferences. He has worked across industry and academia and has led many research and development projects in AI and machine learning. Along with his current role, he has also been associated with many reputed research labs and universities where he contributes as visiting researcher and professor.
