Implementing AlexNet Using PyTorch As A Transfer Learning Model In Multi-Class Classification


AlexNet is one of the most popular convolutional neural network architectures in deep learning. In the last article, we implemented the AlexNet model using the Keras library with the TensorFlow backend on the CIFAR-10 multi-class classification problem. In that experiment, we defined a simple convolutional neural network based on the AlexNet architecture as proposed in the research work of Alex Krizhevsky. We did not use the transfer learning approach, nor did we use network weights pre-trained on the ImageNet dataset. With that setup, we achieved an average classification accuracy of 64.8%.

In this article, we will employ the AlexNet model provided by PyTorch as a transfer learning framework with pre-trained ImageNet weights. The network will be trained on the CIFAR-10 dataset for a multi-class image classification problem, and finally we will analyze its classification accuracy on unseen test images. Our aim is to compare the performance of the AlexNet model with and without transfer learning.

Implementation of AlexNet in PyTorch

Now, we are going to implement the pre-trained AlexNet model in PyTorch. Before proceeding further, make sure that PyTorch is installed if you are working on your local system. If you are working in Google Colab with the hosted runtime, PyTorch is already available and no local installation is required. In Colab, if you wish to use the CUDA interface, select GPU as the hardware accelerator in the notebook settings. The code below was implemented in Google Colab and then downloaded as a .py file.


In the next step, we import the most important libraries. The remaining libraries will be imported along with the code segments that use them, to better describe their role. First, we import torch because we are going to implement our AlexNet model in PyTorch. The torchvision library is required to download the dataset and for other operations. The transforms module will be used to convert the downloaded images into tensors compatible with the network.

import torch
import torchvision
import torchvision.transforms as transforms

Here, we define an object that transforms each image into the form required by the AlexNet model. Using the code snippet below, each input image is first resized to 256×256 pixels and then center-cropped to 224×224 pixels, as the AlexNet model requires input images of size 224×224. The image is then converted to the PyTorch tensor data type. Finally, the image is normalized using the per-channel mean and standard deviation values that PyTorch suggests for its ImageNet-pretrained models.

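transform = transforms.Compose([
    transforms.Resize(256),        # resize the square 32x32 CIFAR-10 image to 256x256
    transforms.CenterCrop(224),    # keep the central 224x224 region
    transforms.ToTensor(),         # PIL image -> float tensor in [0, 1], shape (3, H, W)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])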
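To make the preprocessing concrete, here is a small NumPy sketch of the arithmetic behind the center crop and `transforms.Normalize`. It assumes square images (true for CIFAR-10 after resizing to 256×256) and uses the same ImageNet mean and standard deviation as above. The helper names `center_crop_box`, `normalize` and `denormalize` are hypothetical and chosen for illustration; this is not torchvision's implementation.

```python
import numpy as np

# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def center_crop_box(size, crop):
    """Start/end index of a square center crop along each axis (assumes a square input)."""
    offset = int(round((size - crop) / 2.0))
    return offset, offset + crop

def normalize(chw):
    """Per-channel (x - mean) / std on a (C, H, W) array with values in [0, 1]."""
    return (chw - MEAN) / STD

def denormalize(chw):
    """Inverse of normalize(): x * std + mean, needed before plotting."""
    return chw * STD + MEAN

# Cropping 224x224 out of 256x256 keeps rows and columns 16..239
print(center_crop_box(256, 224))  # (16, 240)

# A random (C, H, W) "image" survives a normalize/denormalize round trip
x = np.random.default_rng(0).random((3, 224, 224))
assert np.allclose(denormalize(normalize(x)), x)
```

The `denormalize` step is what any visualization function has to apply before plotting images produced by this transform.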

In the code segment below, the CIFAR-10 dataset is downloaded through PyTorch's torchvision datasets library and transformed on the fly into the required shape using the transform defined above. The DataLoader handles operations on the downloaded data such as customizing the data loading order, automatic batching, and automatic memory pinning.

#Downloading training data
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)

#Downloading test data
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False, num_workers=2)

#Class labels

classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')
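With `batch_size=4`, the 50,000 training images are split into 12,500 mini-batches per epoch and the 10,000 test images into 2,500. The hypothetical helper below is a simplified pure-Python illustration of what `shuffle=True` and `batch_size` conceptually do; it is not the DataLoader's actual implementation, which also handles worker processes, collation into tensors and memory pinning.

```python
import random

def make_batches(num_samples, batch_size, shuffle, seed=0):
    """Hypothetical helper: split sample indices into mini-batches, keeping
    a smaller last batch if num_samples is not a multiple of batch_size."""
    indices = list(range(num_samples))
    if shuffle:
        # A new random order of the samples (DataLoader reshuffles every epoch)
        random.Random(seed).shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]

train_batches = make_batches(50000, 4, shuffle=True)
test_batches = make_batches(10000, 4, shuffle=False)
print(len(train_batches), len(test_batches))  # 12500 2500
```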

[Figure: CIFAR-10 data set in PyTorch]

Once the dataset is downloaded, we will visualize some random images from the dataset using the below function.

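import matplotlib.pyplot as plt
import numpy as np

#Function to show some random images
def imshow(img):
    # Undo transforms.Normalize (x * std + mean, per channel) before plotting
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    npimg = img.numpy().transpose((1, 2, 0))   # (C, H, W) -> (H, W, C)
    npimg = np.clip(npimg * std + mean, 0, 1)
    plt.imshow(npimg)
    plt.show()

#Get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

#Show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))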

[Figure: sample CIFAR-10 images]

Once we have verified the downloaded image dataset, we proceed further and instantiate the AlexNet model. The code snippet below loads the AlexNet model pre-trained on the ImageNet dataset.

#Now using the AlexNet
AlexNet_model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)

#Model description
print(AlexNet_model)

[Figure: AlexNet using PyTorch, model description]

As we are going to use this network for image classification on the CIFAR-10 dataset, there are 10 output class labels to be predicted by the network. For this purpose, we need to update the network: as we can see in the above image, the final classifier layer, (6): Linear(), has 1000 nodes at the output. This must be changed to 10. We can also see a very large number of nodes in the dense layers of the network (see the in_features and out_features of the Linear() layers under classifier). This may cause the network to overfit or incur heavy losses during training. To address this, we update the last two linear layers of the classifier using the lines of code below.

import torch.nn as nn

#Updating the second linear layer of the classifier: 4096 -> 1024 nodes
AlexNet_model.classifier[4] = nn.Linear(4096,1024)

#Updating the last linear layer, which is the output layer of the network.
#It must have 10 output nodes to predict the 10 CIFAR-10 class labels.
AlexNet_model.classifier[6] = nn.Linear(1024,10)
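To see how much these two changes shrink the classifier, the parameter count of a fully connected layer can be computed as in_features × out_features weights plus out_features biases. The sketch below assumes the standard torchvision AlexNet classifier layout, whose linear layers sit at indices 1, 4 and 6 and whose first linear layer takes 256 × 6 × 6 = 9216 inputs; `linear_params` is a hypothetical helper.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected (nn.Linear) layer."""
    return in_features * out_features + out_features

# (in_features, out_features) of classifier[1], classifier[4], classifier[6]
original = [(9216, 4096), (4096, 4096), (4096, 1000)]
updated = [(9216, 4096), (4096, 1024), (1024, 10)]

orig_total = sum(linear_params(i, o) for i, o in original)
new_total = sum(linear_params(i, o) for i, o in updated)
print(orig_total, new_total)  # 58631144 41958410
```

The first linear layer, classifier[1], is left unchanged and still accounts for roughly 37.8 million parameters, so most of the reduction of about 16.7 million parameters comes from replacing classifier[4].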

Once updated, we will again check the description of the model.


[Figure: AlexNet using PyTorch, updated model description]


As we can see in the above description, the last two classifier layers have been updated and we now have 10 nodes as the output features. To speed up training, we will use the CUDA interface with a GPU.

#Instantiating CUDA device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Verifying CUDA
print(device)

#Move the AlexNet_model to GPU for speed if available (the inputs are moved inside the training loop)
AlexNet_model.to(device)

[Figure: AlexNet using PyTorch]

Now, we will define the optimizer and the loss function. Stochastic gradient descent with momentum will be used as the optimizer and cross-entropy as the loss.


import torch.optim as optim
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(AlexNet_model.parameters(), lr=0.001, momentum=0.9)
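For reference, here is a NumPy sketch of what these two objects compute. `nn.CrossEntropyLoss` applies a log-softmax to the raw network outputs and averages the negative log-probability of the true class over the batch. SGD with momentum, with PyTorch's defaults of no dampening and no weight decay, keeps a velocity buffer updated as `buf = momentum * buf + grad` and then steps `param -= lr * buf`. The function names below are hypothetical and the code is an illustration, not PyTorch's implementation.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean of -log softmax(logits)[target] over the batch (numerically stable)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def sgd_momentum_step(param, grad, buf, lr=0.001, momentum=0.9):
    """One SGD step with momentum: buf <- momentum * buf + grad; param <- param - lr * buf."""
    buf = momentum * buf + grad
    return param - lr * buf, buf

# Uniform logits over 10 classes give a loss of log(10), about 2.3026
print(cross_entropy(np.zeros((4, 10)), np.array([0, 3, 5, 9])))

# Two steps with a constant gradient of 1: the second step is 1.9x larger
p, buf = np.array([1.0]), np.zeros(1)
for _ in range(2):
    p, buf = sgd_momentum_step(p, np.array([1.0]), buf)
print(p)
```

With `lr=0.001` and `momentum=0.9` as above, a gradient that keeps pointing the same way produces steps that grow towards 10 times the plain SGD step, which is why momentum speeds up progress along consistent directions.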

Training the AlexNet

In the next step, we will train the AlexNet model using the below code snippet.

for epoch in range(10):  # loop over the dataset multiple times
    AlexNet_model.train()  # enable dropout during training
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        output = AlexNet_model(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training of AlexNet')
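Since running_loss is reset at the start of every epoch and after each report, each printed value is the mean loss over the previous 2000 mini-batches. With 12,500 mini-batches per epoch (50,000 images at batch_size=4), the quick calculation below shows that the loss is reported six times per epoch and that the last 500 mini-batches of each epoch never contribute to a printed value.

```python
batches_per_epoch = 50000 // 4   # CIFAR-10 training images / batch_size
print_every = 2000

# 0-based loop indices i at which the training loop prints (i % 2000 == 1999)
print_points = [i for i in range(batches_per_epoch) if i % print_every == print_every - 1]
print([i + 1 for i in print_points])  # [2000, 4000, 6000, 8000, 10000, 12000]

# Mini-batches after the last report in an epoch
print(batches_per_epoch - (print_points[-1] + 1))  # 500
```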

Testing the AlexNet Model

Once the training is over, we will test the classification accuracy of our trained model on 10,000 test images.

#Testing Accuracy
AlexNet_model.eval()  # disable dropout for evaluation
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = AlexNet_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %.2f %%' % (
    100 * correct / total))
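`torch.max(outputs, 1)` returns both the maximum logit of each row and its index; only the index, i.e. the predicted class, is used. The NumPy sketch below mirrors this bookkeeping on a made-up batch of four images, using `argmax` in place of `torch.max`. The helper name, logits and labels are invented for illustration.

```python
import numpy as np

def batch_accuracy_counts(logits, labels):
    """Hypothetical helper: (number correct, batch size) for one batch of logits."""
    predicted = logits.argmax(axis=1)  # index of the largest logit in each row
    return int((predicted == labels).sum()), labels.shape[0]

logits = np.zeros((4, 10))
logits[[0, 1, 2, 3], [3, 8, 8, 1]] = 1.0  # predicted classes: 3, 8, 8, 1
labels = np.array([3, 8, 0, 1])           # true classes
correct, total = batch_accuracy_counts(logits, labels)
print(correct, total, 100 * correct / total)  # 3 4 75.0
```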

So, as we can see above, the model achieved 84.41% accuracy in classifying the unseen test images after 10 epochs of training. Now, we will check the classification accuracy of our model on the images of each individual class.

#Testing classification accuracy for individual classes.
AlexNet_model.eval()  # disable dropout for evaluation
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = AlexNet_model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        for i in range(labels.size(0)):  # also handles a smaller final batch
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
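The per-class loop can also be written in vectorized form once all predictions and labels are collected. The sketch below assumes they have been gathered into 1-D NumPy integer arrays (for example with `.cpu().numpy()` on each batch and concatenation); `per_class_accuracy` is a hypothetical helper, and classes with no samples are reported as 0 rather than raising a division-by-zero error.

```python
import numpy as np

def per_class_accuracy(predicted, labels, num_classes=10):
    """Percentage of correctly classified samples for each class."""
    # Samples per class (the class_total list)
    class_total = np.bincount(labels, minlength=num_classes)
    # Correct predictions per class (the class_correct list)
    class_correct = np.bincount(labels[predicted == labels], minlength=num_classes)
    # Guard against division by zero for classes absent from the data
    return 100.0 * class_correct / np.maximum(class_total, 1)

# Made-up predictions for 6 samples over 3 classes
labels = np.array([0, 0, 1, 1, 2, 2])
predicted = np.array([0, 1, 1, 1, 2, 0])
print(per_class_accuracy(predicted, labels, num_classes=3))  # class accuracies: 50, 100 and 50 percent
```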

In the end, we will compare the average of the per-class accuracies with the accuracy of the entire network.

avg = 0
for i in range(10):
    temp = (100 * class_correct[i] / class_total[i])
    avg = avg + temp
avg = avg/10
print('Average accuracy = ', avg)
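The average of the per-class accuracies equals the overall accuracy computed from the same predictions whenever every class has the same number of test images, which holds for the CIFAR-10 test set with 1,000 images per class. The sketch below checks this with made-up counts and shows that the two quantities diverge on an imbalanced set; both function names are hypothetical.

```python
def overall_accuracy(class_correct, class_total):
    """Accuracy over all samples, in percent."""
    return 100.0 * sum(class_correct) / sum(class_total)

def mean_per_class_accuracy(class_correct, class_total):
    """Unweighted mean of the per-class accuracies, in percent."""
    per_class = [100.0 * c / t for c, t in zip(class_correct, class_total)]
    return sum(per_class) / len(per_class)

# Balanced, made-up counts: 1,000 test images in each of 10 classes
correct = [900, 800, 700, 600, 850, 750, 950, 880, 920, 860]
total = [1000] * 10
print(overall_accuracy(correct, total), mean_per_class_accuracy(correct, total))  # both about 82.1

# Imbalanced, made-up counts: 100 images of one class, 900 of another
correct_imb, total_imb = [90, 450], [100, 900]
print(overall_accuracy(correct_imb, total_imb))         # 54.0
print(mean_per_class_accuracy(correct_imb, total_imb))  # 70.0
```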


Finally, we can observe that the pre-trained AlexNet model achieved about 83% average accuracy in multi-class image classification. This accuracy may improve further if we train for more epochs, say 100 or 200. Even so, it is far better than the 64.8% achieved by the AlexNet we defined in the last article in Keras, which did not use weights pre-trained on the ImageNet dataset. So, in this experiment, the AlexNet model performed very well when used as a transfer learning framework.

Note: This article is inspired by the PyTorch’s tutorial on training a classifier in which a simple neural network model has been defined for multiclass image classification.


Copyright Analytics India Magazine Pvt Ltd
