Practical Comparison of Transfer Learning Models in Multi-Class Image Classification

In this article, we will compare the multi-class classification performance of three popular transfer learning architectures – VGG16, VGG19 and ResNet50. All three models are pre-trained on the ImageNet dataset. For the experiment, we use the CIFAR-10 image dataset, a popular benchmark in image classification, and compare the three models using their confusion matrices and average accuracies.

Computer vision is trending nowadays due to the latest developments in deep learning, and researchers and developers continuously propose interesting applications of it using deep learning frameworks. In the last article, ‘Transfer Learning for Multi-Class Image Classification Using Deep Convolutional Network’, we used the VGG19 model as a transfer learning framework to classify CIFAR-10 images into 10 classes. Now we will explore other popular transfer learning architectures on the same task and compare their classification performance.


The Dataset

In this experiment, we will use the CIFAR-10 dataset, a publicly available image dataset provided by the Canadian Institute for Advanced Research (CIFAR). It consists of 60,000 32×32 colour images in 10 classes, with 6,000 images per class. The 10 classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 50,000 training images and 10,000 test images in this dataset.
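
As a quick, optional sanity check, these numbers can be verified directly with the Keras loader that we also use later in this article:

#Optional sanity check of the CIFAR-10 split and class balance
import numpy as np
from keras.datasets import cifar10
(x_tr, y_tr), (x_te, y_te) = cifar10.load_data()
print(x_tr.shape, x_te.shape)      #(50000, 32, 32, 3) (10000, 32, 32, 3)
print(np.bincount(y_tr.ravel()))   #5000 images of each class in the training split
print(np.bincount(y_te.ravel()))   #1000 images of each class in the test split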

VGG16 and VGG19

VGG16 and VGG19 are variants of VGGNet, a deep convolutional neural network proposed by Karen Simonyan and Andrew Zisserman of the University of Oxford in their paper ‘Very Deep Convolutional Networks for Large-Scale Image Recognition’. The name of this model was inspired by the name of their research group, the Visual Geometry Group (VGG). VGG16 has 16 weight layers (convolutional and fully connected) in its architecture, while VGG19 has 19.
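
The difference in depth can be checked directly from the Keras implementations. The snippet below is only an illustrative check that counts the convolutional and fully connected layers of each architecture (the models are built without downloading any pre-trained weights):

#Counting the weight layers of VGG16 and VGG19 (illustrative check, no pre-trained weights)
from keras.applications import VGG16, VGG19
from keras.layers import Conv2D, Dense
for name, model in [('VGG16', VGG16(weights=None)), ('VGG19', VGG19(weights=None))]:
    weight_layers = [layer for layer in model.layers if isinstance(layer, (Conv2D, Dense))]
    print(name, ':', len(weight_layers), 'weight layers')   #16 for VGG16, 19 for VGG19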

ResNet50

ResNet is short for Residual Network, and ResNet50 is the variant of this architecture with 50 layers. It is a deep convolutional neural network built from residual (skip-connection) blocks, and we use it here as a transfer learning framework with weights pre-trained on ImageNet.
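
The central idea behind ResNet is the residual (skip) connection, which adds the input of a block back to its output so that the block only has to learn a residual correction. The snippet below is a minimal, illustrative identity block written with the Keras functional API; it is not the exact bottleneck block used inside ResNet50:

#A minimal residual (identity) block - an illustrative sketch, not the exact ResNet50 block
from keras.layers import Input, Conv2D, BatchNormalization, Activation, Add
from keras.models import Model

def residual_block(x, filters):
    shortcut = x                                   #the skip connection keeps the block input
    y = Conv2D(filters, (3,3), padding='same')(x)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(filters, (3,3), padding='same')(y)
    y = BatchNormalization()(y)
    y = Add()([y, shortcut])                       #adding the input back: the residual connection
    return Activation('relu')(y)

inputs = Input(shape=(32,32,64))
outputs = residual_block(inputs, 64)
Model(inputs, outputs).summary()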

Implementation of Transfer Learning Models in Python

Here, we are going to import all the required libraries. Make sure that you have installed TensorFlow if you are working on your local system. For the transfer learning implementation, the three models VGG19, VGG16 and ResNet50 are also imported here.

#importing other required libraries
import numpy as np
import pandas as pd
from sklearn.utils.multiclass import unique_labels
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import itertools
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from keras import Sequential
from keras.applications import VGG19, VGG16, ResNet50
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD,Adam
from keras.callbacks import ReduceLROnPlateau
from keras.layers import Flatten, Dense, BatchNormalization, Activation,Dropout
from keras.utils import to_categorical
import tensorflow as tf
import random

Once the libraries are imported successfully, we will download the CIFAR-10 dataset, which is available directly through Keras.

#Keras library for CIFAR dataset
from keras.datasets import cifar10
(x_train, y_train),(x_test, y_test)=cifar10.load_data()

After downloading the dataset, we will plot some random images from the CIFAR-10 dataset to verify whether it has been downloaded correctly.

#Plotting 25 random training images in a 5x5 grid
W_grid=5
L_grid=5
fig,axes = plt.subplots(L_grid,W_grid,figsize=(10,10))
axes=axes.ravel()
n_training=len(x_train)
for i in np.arange(0,L_grid * W_grid):
    index=np.random.randint(0,n_training) #picking a random image index
    axes[i].imshow(x_train[index])
    axes[i].set_title(y_train[index]) #integer class label as the title
    axes[i].axis('off')
plt.subplots_adjust(hspace=0.4)
[Output: a 5x5 grid of random CIFAR-10 training images]

We will split our dataset into training and validation sets. The training and validation sets will be used during training, and the test set will be used for the final predictions on images the models have not seen.

#Train-validation-test split
x_train,x_val,y_train,y_val=train_test_split(x_train,y_train,test_size=.3)

After the split, we will one-hot encode the labels because our output has 10 classes. First, we will print the shapes of the datasets, and after one-hot encoding, we will verify the final shapes.

#Dimension of the CIFAR10 dataset
print((x_train.shape,y_train.shape))
print((x_val.shape,y_val.shape))
print((x_test.shape,y_test.shape))
[Output: shapes of the training, validation and test sets before one-hot encoding]

#Onehot Encoding the labels.
#Since we have 10 classes we should expect the shape[1] of y_train,y_val and y_test to change from 1 to 10
y_train=to_categorical(y_train)
y_val=to_categorical(y_val)
y_test=to_categorical(y_test)

#Verifying the dimension after one hot encoding
print((x_train.shape,y_train.shape))
print((x_val.shape,y_val.shape))
print((x_test.shape,y_test.shape))
[Output: shapes after one-hot encoding; the label dimension changes from 1 to 10]

To preprocess the images and make them ready for training the deep learning models, the image data augmentation steps below will be performed. (Note that augmentation is normally applied only to the training set; here the same generator settings are defined for all three sets.)

#Image Data Augmentation
train_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True, zoom_range=.1 )

val_generator = ImageDataGenerator(rotation_range=2, horizontal_flip=True,zoom_range=.1)

test_generator = ImageDataGenerator(rotation_range=2,  horizontal_flip= True, zoom_range=.1)

#Fitting the augmentation defined above to the data
train_generator.fit(x_train)
val_generator.fit(x_val)
test_generator.fit(x_test)

Now, we will define the learning rate annealer. As we discussed in the previous article, the learning rate annealer decreases the learning rate if the monitored metric, here the validation accuracy, does not improve for a certain number of epochs.

#Learning Rate Annealer
lrr= ReduceLROnPlateau(monitor='val_accuracy', factor=.01,  patience=3, min_lr=1e-5)

VGG19 Transfer Learning Model

In the next step, we will initialize our VGG19 model. As we are going to use VGG19 as a transfer learning framework, we will initialize it with the pre-trained ImageNet weights.

base_model_VGG19 = VGG19(include_top=False, weights='imagenet', input_shape=(32,32,3), classes=y_train.shape[1])

Now we will add the layers to the VGG19 network that we have initialized above.

#Adding the final layers to the above base models where the actual classification is done in the dense layers
model_vgg19 = Sequential()
model_vgg19.add(base_model_VGG19) 
model_vgg19.add(Flatten()) 
model_vgg19.add(Dense(1024,activation=('relu'),input_dim=512))
model_vgg19.add(Dense(512,activation=('relu'))) 
model_vgg19.add(Dense(256,activation=('relu'))) 
#model_vgg19.add(Dropout(.3))
model_vgg19.add(Dense(128,activation=('relu')))
#model_vgg19.add(Dropout(.2))
model_vgg19.add(Dense(10,activation=('softmax')))

After adding all the layers, we will check the model’s summary.

#VGG19 Model Summary
model_vgg19.summary()

VGG19 ; transfer learning models

Next, we will define the training hyperparameters and compile our model. For optimization, we will use stochastic gradient descent (SGD) with momentum.

#Defining the hyperparameters
batch_size= 100
epochs=50
learn_rate=.001
sgd=SGD(lr=learn_rate,momentum=.9,nesterov=False)

#Compiling the VGG19 model
model_vgg19.compile(optimizer=sgd,loss='categorical_crossentropy',metrics=['accuracy'])

After defining all the hyperparameters, we will train our model for the number of epochs set above.

model_vgg19.fit_generator(train_generator.flow(x_train, y_train, batch_size = batch_size), epochs=epochs, steps_per_epoch = x_train.shape[0]//batch_size, validation_data = val_generator.flow(x_val, y_val, batch_size = batch_size), validation_steps = 250, callbacks = [lrr], verbose = 1)

[Output: VGG19 training log]

Now we will visualize the training performance in terms of loss and accuracy on the training and validation sets.

#Plotting the training and validation loss
f,ax=plt.subplots(2,1) #Creates 2 subplots under 1 column
#Training loss and validation loss
ax[0].plot(model_vgg19.history.history['loss'],color='b',label='Training Loss')
ax[0].plot(model_vgg19.history.history['val_loss'],color='r',label='Validation Loss')
#Training accuracy and validation accuracy
ax[1].plot(model_vgg19.history.history['accuracy'],color='b',label='Training  Accuracy')
ax[1].plot(model_vgg19.history.history['val_accuracy'],color='r',label='Validation Accuracy')

[Output: VGG19 training and validation loss and accuracy curves]

To plot the confusion matrix, we will define a function here.

#Defining function for confusion matrix plot
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Computing confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    # Visualizing the confusion matrix
    fig, ax = plt.subplots(figsize=(7,7))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotating the tick labels and setting their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Looping over data dimensions and creating text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
np.set_printoptions(precision=2)

We will make predictions with the trained VGG19 model on the test image dataset.

#Making prediction
y_pred1 = model_vgg19.predict_classes(x_test)
y_true = np.argmax(y_test,axis=1)

Now, we will plot the non-normalized confusion matrix to visualize the exact counts of classifications and the normalized confusion matrix to visualize the per-class proportions.

#Computing the confusion matrix
confusion_mtx=confusion_matrix(y_true,y_pred1)

class_names=['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

#Plotting non-normalized confusion matrix
plot_confusion_matrix(y_true, y_pred1, classes = class_names,  title = 'Non-Normalized VGG19 Confusion Matrix')

[Output: non-normalized VGG19 confusion matrix]

#Plotting normalized confusion matrix
plot_confusion_matrix(y_true, y_pred1, classes = class_names, normalize = True, title = 'Normalized VGG19 Confusion matrix')

[Output: normalized VGG19 confusion matrix]

Finally, we will see the average classification accuracy of VGG19.

#Accuracy of VGG19
from sklearn.metrics import accuracy_score
accuracy_score(y_true, y_pred1)
[Output: VGG19 test accuracy ≈ 0.8508]

VGG16 Transfer Learning Model

Next, we will repeat the above steps for the VGG16 model.

#VGG16 Model
base_model_vgg16 = VGG16(include_top = False, weights= 'imagenet', input_shape = (32,32,3), classes = y_train.shape[1])

#Adding the final layers to the above base models where the actual classification is done in the dense layers
model_vgg16= Sequential()
model_vgg16.add(base_model_vgg16) 
model_vgg16.add(Flatten())
#Adding the Dense layers along with activation
model_vgg16.add(Dense(1024,activation=('relu'),input_dim=512))
model_vgg16.add(Dense(512,activation=('relu'))) 
model_vgg16.add(Dense(256,activation=('relu'))) 
#model.add(Dropout(.3))
model_vgg16.add(Dense(128,activation=('relu')))
#model.add(Dropout(.2))
model_vgg16.add(Dense(10,activation=('softmax')))

#Checking the final VGG16 model summary
model_vgg16.summary()

[Output: VGG16 model summary]


#Compiling VGG16
model_vgg16.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])


#Training VGG16
model_vgg16.fit_generator(train_generator.flow(x_train, y_train, batch_size = batch_size), epochs = epochs, steps_per_epoch = x_train.shape[0]//batch_size, validation_data = val_generator.flow(x_val, y_val, batch_size = batch_size), validation_steps=250, callbacks=[lrr], verbose=1)

[Output: VGG16 training log]


#Plotting the VGG16 training and validation loss
f,ax=plt.subplots(2,1) #Creates 2 subplots under 1 column
#Training loss and validation loss
ax[0].plot(model_vgg16.history.history['loss'],color='b',label='Training Loss')
ax[0].plot(model_vgg16.history.history['val_loss'],color='r',label='Validation Loss')
#Training accuracy and validation accuracy
ax[1].plot(model_vgg16.history.history['accuracy'],color='b',label='Training  Accuracy')
ax[1].plot(model_vgg16.history.history['val_accuracy'],color='r',label='Validation Accuracy')

[Output: VGG16 training and validation loss and accuracy curves]

#Making prediction
y_pred2=model_vgg16.predict_classes(x_test)
y_true=np.argmax(y_test,axis=1)

#Computing the confusion matrix
confusion_mtx=confusion_matrix(y_true,y_pred2)

#Plotting non-normalized confusion matrix
plot_confusion_matrix(y_true, y_pred2, classes = class_names,title = 'Non-Normalized VGG16 Confusion Matrix')

[Output: non-normalized VGG16 confusion matrix]

#Plotting normalized confusion matrix
plot_confusion_matrix(y_true, y_pred2, classes = class_names, normalize = True, title= 'Normalized VGG16 Confusion matrix')

[Output: normalized VGG16 confusion matrix]


#Accuracy of VGG16
from sklearn.metrics import accuracy_score
accuracy_score(y_true, y_pred2)
[Output: VGG16 test accuracy ≈ 0.8454]

ResNet50 Transfer Learning Model

Next, we will perform the same steps with the ResNet50 model.

#Initializing ResNet50
base_model_resnet = ResNet50(include_top = False, weights = 'imagenet', input_shape = (32,32,3), classes = y_train.shape[1])
#Adding layers to the ResNet50 base model
model_resnet=Sequential()
model_resnet.add(base_model_resnet)
model_resnet.add(Flatten())
#Adding the Dense layers along with activation and dropout
model_resnet.add(Dense(1024,activation=('relu'),input_dim=512))
model_resnet.add(Dense(512,activation=('relu'))) 
model_resnet.add(Dropout(.4))
model_resnet.add(Dense(256,activation=('relu'))) 
model_resnet.add(Dropout(.3))
model_resnet.add(Dense(128,activation=('relu')))
model_resnet.add(Dropout(.2))
model_resnet.add(Dense(10,activation=('softmax')))

#Summary of ResNet50 Model
model_resnet.summary()

[Output: ResNet50 model summary]

#Compiling ResNet50
model_resnet.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])


#Training the ResNet50 model
model_resnet.fit_generator(train_generator.flow(x_train, y_train, batch_size=batch_size), epochs=epochs, steps_per_epoch = x_train.shape[0]//batch_size, validation_data = val_generator.flow(x_val, y_val, batch_size = batch_size), validation_steps = 250, callbacks = [lrr], verbose=1)

[Output: ResNet50 training log]


#Plotting the training and validation loss
f,ax=plt.subplots(2,1) #Creates 2 subplots under 1 column
#Training loss and validation loss
ax[0].plot(model_resnet.history.history['loss'],color='b',label='Training Loss')
ax[0].plot(model_resnet.history.history['val_loss'],color='r',label='Validation Loss')
#Training accuracy and validation accuracy
ax[1].plot(model_resnet.history.history['accuracy'],color='b',label='Training  Accuracy')
ax[1].plot(model_resnet.history.history['val_accuracy'],color='r',label='Validation Accuracy')

[Output: ResNet50 training and validation loss and accuracy curves]

#Making prediction
y_pred3=model_resnet.predict_classes(x_test)
y_true=np.argmax(y_test,axis=1)

#Computing the confusion matrix
confusion_mtx=confusion_matrix(y_true,y_pred3)

#Plotting non-normalized confusion matrix
plot_confusion_matrix(y_true, y_pred3, classes = class_names, title = 'Non-Normalized ResNet50 Confusion Matrix')

[Output: non-normalized ResNet50 confusion matrix]

#Plotting normalized confusion matrix
plot_confusion_matrix(y_true, y_pred3, classes=class_names, normalize = True, title = 'Normalized ResNet50 Confusion Matrix')

[Output: normalized ResNet50 confusion matrix]

#ResNet50 Classification accuracy
from sklearn.metrics import accuracy_score
accuracy_score(y_true, y_pred3)
[Output: ResNet50 test accuracy ≈ 0.7948]

The accuracy scores of the three models are as follows:

Model       Accuracy (%)
VGG19       85.08
VGG16       84.54
ResNet50    79.48
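
If the predictions of the three models (y_pred1, y_pred2 and y_pred3 from the cells above) are still in memory, the same comparison can be printed programmatically; the short snippet below simply reuses accuracy_score:

#Printing the comparison from the stored predictions (assumes y_true, y_pred1, y_pred2 and y_pred3 exist)
from sklearn.metrics import accuracy_score
results = {'VGG19': y_pred1, 'VGG16': y_pred2, 'ResNet50': y_pred3}
for name, preds in results.items():
    print('{:<10s}{:6.2f}%'.format(name, 100 * accuracy_score(y_true, preds)))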

Finally, we have all the evaluation metrics needed to analyze the three transfer learning-based deep convolutional neural network models. By analyzing the accuracy scores and confusion matrices of the three models (VGG19, VGG16 and ResNet50), we can conclude that VGG19 has the best performance among them. The above scores were obtained with the training setup described in this article; they may improve further if the models are trained for more epochs.

Dr. Vaibhav Kumar
Dr. Vaibhav Kumar is a seasoned data science professional with great exposure to machine learning and deep learning. He has good exposure to research, where he has published several research papers in reputed international journals and presented papers at reputed international conferences. He has worked across industry and academia and has led many research and development projects in AI and machine learning. Along with his current role, he has also been associated with many reputed research labs and universities where he contributes as visiting researcher and professor.
