Now Reading
Hands-On Guide to Implement ResNet50 in PyTorch with TPU

Hands-On Guide to Implement ResNet50 in PyTorch with TPU

ResNet50 in PyTorch with TPU

PyTorch is consistently adding a boost to the field of Computer Vision and Deep Learning by providing a number of powerful tools and techniques. In the field of computer vision where the deep learning based executions are to be dealt with heavy image datasets, an accelerated environment is needed to fasten the execution process with an acceptable accuracy level. PyTorch provides this feature through the XLA (Accelerated Linear Algebra), a compiler for linear algebra that can target multiple types of hardware, including GPU, and TPU. The PyTorch/XLA environment is integrated with the Google Cloud TPU and an accelerated speed of execution is achieved.

In this article, we will demonstrate the implementation of ResNet50, a Deep Convolutional Neural Network, in PyTorch with TPU. The model will be trained and tested in the PyTorch/XLA environment in the task of classifying the CIFAR10 dataset. We will also check the time consumed in training this model in 50 epochs.

Register for our upcoming Masterclass>>

Implementing ResNet50 in Pytorch

To avail the facility of TPU, this implementation was done in the Google Colab. To start with, first, we need to select the TPU from Hardware accelerators under the notebook settings.

ResNet50 in PyTorch with TPU

After selecting the TPU, we will verify the environment using the below line of codes:-

Looking for a job change? Let us help you.
import os
assert os.environ['COLAB_TPU_ADDR']

It will be executed successfully if the TPU is enabled otherwise it will throw the ‘KeyError: ‘COLAB_TPU_ADDR’’. You can also check the TPU by printing its address.

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)

In the next step, we will install the XLA environment to accelerate the execution process. The same we have done in one of our last articles where we have implemented the convolutional neural network.

VERSION = "20200516"
!curl -o
!python --version $VERSION

Now, we will import all the required libraries here.

from matplotlib import pyplot as plt
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms
import time
from google.colab.patches import cv2_imshow
import cv2

After importing the libraries, we will define and initialize the required parameters.

# Define Parameters
FLAGS = {}
FLAGS['data_dir'] = "/tmp/cifar"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.02
FLAGS['momentum'] = 0.9
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False

In the next step, we will define the ResNet50 model.

class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
              self.expansion * planes,
              bias=False), nn.BatchNorm2d(self.expansion * planes))

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out

class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)

def ResNet50():
  return ResNet(BasicBlock, [3, 4, 6, 4, 3])

The below code snippet will define the functions to load the CIFAR10 dataset, preparing training and test dataset, the training process and the test process.

SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())

def train_resnet50():

  def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
    transform_test = transforms.Compose([
    train_dataset = datasets.CIFAR10(
    test_dataset = datasets.CIFAR10(
    return train_dataset, test_dataset
  # Using the serial executor avoids multiple processes
  # to download the same data.
  train_dataset, test_dataset =

  train_sampler =
  train_loader =
  test_loader =

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model =
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    for x, (data, target) in enumerate(loader):
      output = model(data)
      loss = loss_fn(output, target)
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.2f} Time={}'.format(xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target

Now, we will begin the training of ResNet50. The training will be done in the 50 epochs as we have defined in the parameters. Before starting the training, we will record the training time and after training, we will print the total time taken.

start_time = time.time()
# Start training processes
def training(rank, flags):
  global FLAGS
  FLAGS = flags
  accuracy, data, pred, target = train_resnet50()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(training, args=(FLAGS,), nprocs=FLAGS['num_cores'],

After the training, we will print the time taken during the training process.

end_time = time.time()
print("Time taken = ", end_time-start_time)
ResNet50 in PyTorch with TPU

Finally, we visualize the predictions made by the model on the sample test data during training.

ResNet50 in PyTorch with TPU

So, we can conclude by analyzing the training performance and predictions, the model has acquired more than 82% of accuracy during the training where the 50 epochs of training were performed in just around 50 minutes. In the end, we can also conclude that we have learnt how to implement the ResNet50 model in PyTorch with TPU on the CIFAR10 dataset. It opens a scope for testing other deep convolutional neural network models on the same or different benchmark datasets.

What Do You Think?

Join Our Discord Server. Be part of an engaging online community. Join Here.

Subscribe to our Newsletter

Get the latest updates and relevant offers by sharing your email.

Copyright Analytics India Magazine Pvt Ltd

Scroll To Top