Guide To Build Your First Convolutional Neural Network with PyTorch

When we consider Machine Learning as a professional skill, it ultimately comes down to choosing the right algorithms, packages and libraries, and there are many. Among packages alone there are several popular choices, such as TensorFlow, Theano, Keras and PyTorch. We will focus on PyTorch for now. Developed by Facebook, PyTorch has a lot to offer in the Machine Learning space.

In this tutorial, we will give a hands-on walkthrough on how to build a simple Convolutional Neural Network with PyTorch. 

This short tutorial is intended for beginners who have a basic understanding of how Convolutional Neural Networks work and want to get hands-on with the code using the PyTorch library.

Read some of our previous articles on Convolutional Neural Networks to get a good understanding before we dive into CNNs with PyTorch.

Why PyTorch?

So why learn PyTorch in a world that abounds with deep learning frameworks? PyTorch is a scientific computing package for deep learning developed by Facebook. It is built around dynamic computation graphs: the graph is created on the fly as the code runs, so users can change it with ordinary Python control flow. This was a standout feature compared with static-graph frameworks such as TensorFlow 1.x. PyTorch is also fast and has many easy-to-use APIs, and its large, active community makes it a library of choice for many Machine Learning developers.

Although other packages, especially TensorFlow, dominate in production, PyTorch has a large user base in research, which is all the more reason to learn it.
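To make the idea of dynamic computation more concrete, here is a small sketch (not part of the original tutorial) in which a plain Python loop decides, at run time, how many operations end up in the graph. The helper name `repeat_double` and the values used are hypothetical and only for illustration; `requires_grad`, `backward()` and `.grad` are standard PyTorch autograd features.

```python
import torch


def repeat_double(x: torch.Tensor, n_steps: int) -> torch.Tensor:
    # Hypothetical helper: the loop length is an ordinary Python value,
    # so each call records a different number of multiplications in the graph.
    for _ in range(n_steps):
        x = x * 2
    return x.sum()


grads = {}
for n_steps in (1, 3):
    x = torch.ones(3, requires_grad=True)   # fresh leaf tensor for each run
    repeat_double(x, n_steps).backward()    # backpropagate through the graph just built
    grads[n_steps] = x.grad
    print(n_steps, x.grad)
# prints: 1 tensor([2., 2., 2.])  then  3 tensor([8., 8., 8.])
```

Each call builds a fresh graph, so the gradient (2 to the power of `n_steps`) follows whatever path the code actually took.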

Let’s get started and build a simple Convolutional Neural Network.

Importing Torch Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F

Creating A Custom CNN

First, we will define a class that inherits from the nn.Module class in PyTorch. Defining the network as a custom class is a good approach because it gives more flexibility in coding and makes it easier to implement multiple networks.

  1. class CNN(nn.Module):
  2.   def __init__(self):
  3.     super(CNN, self).__init__()
  4.     self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5)
  5.     self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
  6.     self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=150)  # 32 channels x 5 x 5 (for 32x32 inputs)
  7.     self.fc2 = nn.Linear(in_features=150, out_features=90)
  8.     self.fc3 = nn.Linear(in_features=90, out_features=10)  # one output per class
  9.   def forward(self, x):
  10.     x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
  11.     x = F.max_pool2d(F.relu(self.conv2(x)), 2)
  12.     x = x.view(-1, self.num_flat_features(x))
  13.     x = F.relu(self.fc1(x))
  14.     x = F.relu(self.fc2(x))
  15.     return self.fc3(x)
  16.   def num_flat_features(self, x):
  17.     size = x.size()[1:]  # all dimensions except the batch dimension
  18.     num_features = 1
  19.     for s in size:
  20.       num_features *= s
  21.     return num_features


Let’s follow the above code block line by line.

1. Create a custom class called CNN that inherits from the nn.Module class of the PyTorch library.

2. Define the class initializer method.

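3. Calls the initializer of the parent class, nn.Module, through super(). This sets up the internal bookkeeping that nn.Module uses to register layers and their parameters, so it must run before any layers are assigned to self.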

4. Defining the first convolutional layer.

Here, in_channels is the number of channels in the input image: 1 for grayscale images and 3 for colour (RGB) images. The number of out_channels is a design choice with no fixed rule. More output channels let the layer extract more features from the images, but they use more memory and computation and increase the risk of overfitting on smaller datasets. The third argument, kernel_size, sets the size of the convolution filter: here it is 5, so a 5×5 filter slides over the image (setting it to 3 would use a 3×3 filter).

5. Adding a second convolutional layer to the network. Note that the in_channels of the second convolutional layer must equal the out_channels of the previous layer (16 here).

6. Creating a fully connected network. Here we have three layers: an input layer (line 6) that receives the flattened output of the convolutional layers, a hidden layer (line 7) and an output layer (line 8). The in_features of the input layer must equal the number of values per image after the second pooling step; for 32×32 input images this is 32 channels × 5 × 5 = 800, hence 32 * 5 * 5. The number of out_features in the output layer corresponds to the number of classes or categories of images that we need to classify (10 here).

9. Defining the forward method, which passes the inputs (images) through all the layers in the network.

10. Passes the input through the first convolutional layer, applies the ReLU (Rectified Linear Unit) activation function and then performs max pooling over a 2×2 window, whose size is passed as a tuple argument.

11. Does the same with the second convolutional layer: convolution, ReLU activation, then 2×2 max pooling, this time with the window size passed as an integer argument. The result is stored back in x.

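12. Flattens the pooled feature maps into one vector per image using the view method. The num_flat_features method supplies the number of features per image, and the -1 tells PyTorch to infer the batch dimension.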

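13. Feeds the flattened features through the fully connected layers: the input layer (line 13), the hidden layer (line 14) and the output layer (line 15). ReLU is applied after the first two layers; the output layer returns raw scores (logits), one per class.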

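16. Defining a method that computes the number of features per image (the product of all dimensions except the batch dimension), which is used to flatten the extracted features after pooling.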
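The walkthrough relies on the spatial size shrinking at each stage. As a rough check, the sketch below pushes a dummy batch through standalone copies of the two convolutional layers and records the shape after every step. It assumes 32×32 RGB inputs (for example CIFAR-10-sized images), which is the input size that makes 32 * 5 * 5 the right in_features for fc1. With stride 1 and no padding, a 5×5 convolution turns an H×W map into (H - 4)×(W - 4), and 2×2 max pooling halves each side. The batch size of 4 is arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumption: a batch of 4 RGB images of size 32x32 (e.g. CIFAR-10-sized).
x = torch.randn(4, 3, 32, 32)

conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5)
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)

shapes = {"input": tuple(x.shape)}
x = F.relu(conv1(x))
shapes["conv1"] = tuple(x.shape)   # 32 - 5 + 1 = 28
x = F.max_pool2d(x, (2, 2))
shapes["pool1"] = tuple(x.shape)   # 28 / 2 = 14
x = F.relu(conv2(x))
shapes["conv2"] = tuple(x.shape)   # 14 - 5 + 1 = 10
x = F.max_pool2d(x, 2)
shapes["pool2"] = tuple(x.shape)   # 10 / 2 = 5

# Features per image = product of all non-batch dimensions, the same value
# num_flat_features computes (32 * 5 * 5 = 800).
num_features = x[0].numel()
flat = x.view(-1, num_features)
shapes["flat"] = tuple(flat.shape)

for name, shape in shapes.items():
    print(f"{name:>6}: {shape}")
```

The last line printed should be `flat: (4, 800)`. PyTorch's built-in `torch.flatten(x, 1)` produces the same result as the view-based flattening and is a common alternative to a hand-written helper.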

Initialising the CNN

We now have a complete template to build as many CNNs as we need. Next, we will create an instance of the CNN.

#Initializing the Object for CNN
cnn = CNN()

The line above creates an instance of the custom CNN we built. Calling CNN() runs the __init__ method, which creates the layers with randomly initialised weights.

Let’s have a look at the summary of the CNN that we built.

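#Printing the summary of the CNN
print(cnn)

This prints a description of the CNN: each layer assigned in __init__ (conv1, conv2, fc1, fc2 and fc3) with its settings. The ReLU activations and pooling operations do not appear, because they are called as functions inside forward rather than registered as layers.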
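Because the numbered listing cannot be pasted into Python directly, here is a sketch of the same network without line numbers, followed by a quick sanity check: it prints the summary, counts the trainable parameters and runs a dummy batch to get one predicted class per image. The 32×32 input size and the batch of four random images are assumptions for illustration, and the predictions are meaningless until the network is trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=150)
        self.fc2 = nn.Linear(in_features=150, out_features=90)
        self.fc3 = nn.Linear(in_features=90, out_features=10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


cnn = CNN()
print(cnn)

# Trainable parameters = weights + biases of every registered layer.
num_params = sum(p.numel() for p in cnn.parameters() if p.requires_grad)
print(num_params)  # 148698

# Assumption: 4 random 32x32 RGB images stand in for real data.
images = torch.randn(4, 3, 32, 32)
with torch.no_grad():              # no gradients needed for a quick check
    logits = cnn(images)           # shape (4, 10): one score per class
predicted = logits.argmax(dim=1)   # index of the highest-scoring class per image
print(logits.shape, predicted.shape)
```

Most of the 148,698 parameters sit in fc1 (800 × 150 weights plus 150 biases), which is typical: the first fully connected layer after the convolutions is usually the largest part of a small CNN like this one.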


After building the CNN, we can train it and use it to predict image labels or classes, which we will learn about in an upcoming tutorial. Until then, happy coding!

Copyright Analytics India Magazine Pvt Ltd
