Hands-On Guide to Implement ResNet50 in PyTorch with TPU


PyTorch consistently gives a boost to the field of computer vision and deep learning by providing a number of powerful tools and techniques. In computer vision, where deep learning executions have to deal with heavy image datasets, an accelerated environment is needed to speed up the execution process while maintaining an acceptable accuracy level. PyTorch provides this feature through XLA (Accelerated Linear Algebra), a compiler for linear algebra that can target multiple types of hardware, including GPUs and TPUs. The PyTorch/XLA environment integrates with Google Cloud TPUs to achieve this accelerated speed of execution.
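For instance, once torch_xla is installed (the setup is shown later in this article), a TPU core is exposed to PyTorch as an ordinary device; a minimal sketch:

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()              # acquire a TPU core as a torch device
t = torch.randn(2, 2, device=dev)  # this tensor lives on the TPU
print(t.device)                    # e.g. xla:1

Ordinary tensor operations on t are then compiled and executed through XLA.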

In this article, we will demonstrate the implementation of ResNet50, a Deep Convolutional Neural Network, in PyTorch with TPU. The model will be trained and tested in the PyTorch/XLA environment in the task of classifying the CIFAR10 dataset. We will also check the time consumed in training this model in 50 epochs.

Implementing ResNet50 in PyTorch

To avail the facility of TPU, this implementation was done in Google Colab. To start with, we need to select TPU as the hardware accelerator under the notebook settings.



After selecting the TPU, we will verify the environment using the below lines of code:

import os
assert os.environ['COLAB_TPU_ADDR']

It will execute successfully if the TPU is enabled; otherwise, it will throw a KeyError: 'COLAB_TPU_ADDR'. You can also check the TPU by printing its address.

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)

In the next step, we will install the XLA environment to accelerate the execution process. We did the same in one of our previous articles, where we implemented a convolutional neural network.

VERSION = "20200516"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

Now, we will import all the required libraries here.

from matplotlib import pyplot as plt
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms
from google.colab.patches import cv2_imshow
import cv2

After importing the libraries, we will define and initialize the required parameters.

# Define Parameters
FLAGS = {}
FLAGS['data_dir'] = "/tmp/cifar"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.02
FLAGS['momentum'] = 0.9
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
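Note that batch_size here is the per-core batch size: with num_cores set to 8, the effective global batch is 128 × 8 = 1024 images per step. This is also why the learning rate is scaled by xm.xrt_world_size() later in the training code.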

In the next step, we will define the ResNet50 model.

class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_planes, self.expansion * planes,
                    kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(self.expansion * planes))

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out

class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)

def ResNet50():
  # Note: with BasicBlock, [3, 4, 6, 3] gives the ResNet-34 layout; the
  # canonical ResNet-50 uses Bottleneck blocks with the same block counts.
  return ResNet(BasicBlock, [3, 4, 6, 3])
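
As a quick sanity check (our addition, not part of the original walkthrough), the network can be run on a dummy CIFAR10-sized batch on the CPU to confirm the output shape:

# A batch of 4 CIFAR10-sized images should yield log-probabilities
# over the 10 classes.
net = ResNet50()
dummy = torch.randn(4, 3, 32, 32)
print(net(dummy).shape)  # torch.Size([4, 10])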

The below code snippet defines the functions to load the CIFAR10 dataset, prepare the training and test datasets, and run the training and test processes.

SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())
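
Here, xmp.MpSerialExecutor runs a function on one process at a time, which we use below so that the CIFAR10 download happens only once rather than in all eight processes, and xmp.MpModelWrapper instantiates the model weights only once in host memory, with each process moving that single copy to its own TPU device.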

def train_resnet50():

  def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'], train=True, download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'], train=False, download=True,
        transform=transform_test)
    return train_dataset, test_dataset

  # Using the serial executor avoids multiple processes
  # downloading the same data.
  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset, num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(), shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=FLAGS['batch_size'],
      sampler=train_sampler, num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset, batch_size=FLAGS['batch_size'], shuffle=False,
      num_workers=FLAGS['num_workers'], drop_last=True)

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)  # all-reduces gradients across TPU cores
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.2f} Rate={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), tracker.rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    model.eval()
    total_samples = 0
    correct = 0
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target

Now, we will begin the training of ResNet50. The training will run for 50 epochs, as defined in the parameters. Before starting the training, we will record the start time, and after training, we will print the total time taken.
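
One caveat: the training function below calls a plot_results helper that the article does not show. A minimal sketch of such a helper is given here (our assumption; the class-name list, grid layout, and RESULT_IMG_PATH are illustrative, not from the original article); it should be defined before xmp.spawn is called:

RESULT_IMG_PATH = '/tmp/test_result.png'  # assumed output path, not from the article

def plot_results(images, preds, targets, n=8):
  # Map CIFAR10 class indices to names for readable titles.
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
             'dog', 'frog', 'horse', 'ship', 'truck')
  # Roughly undo the channel-wise normalization for display.
  mean = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
  std = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)
  fig, axes = plt.subplots(1, n, figsize=(2 * n, 2))
  for i, ax in enumerate(axes):
    img = images[i].numpy() * std + mean
    ax.imshow(np.clip(img.transpose(1, 2, 0), 0, 1))
    ax.set_title('{}/{}'.format(classes[preds[i].item()],
                                classes[targets[i].item()]))
    ax.axis('off')
  fig.savefig(RESULT_IMG_PATH)  # saved to disk; displayed after training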

start_time = time.time()
# Start training processes
def training(rank, flags):
  global FLAGS
  FLAGS = flags
  accuracy, data, pred, target = train_resnet50()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(training, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

After the training, we will print the time taken during the training process.

end_time = time.time()
print("Time taken = ", end_time-start_time)

Finally, the plot_results helper visualizes the predictions made by the model on sample test images.
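
The cv2 and cv2_imshow imports at the top suggest the saved figure is read back and displayed in the notebook; assuming the RESULT_IMG_PATH used in the plot_results sketch above, this would look like:

img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
cv2_imshow(img)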


So, by analyzing the training performance and predictions, we can conclude that the model achieved more than 82% accuracy, with the 50 epochs of training completed in around 50 minutes. We have also learnt how to implement the ResNet50 model in PyTorch with TPU on the CIFAR10 dataset, which opens the scope for testing other deep convolutional neural network models on the same or different benchmark datasets.

Dr. Vaibhav Kumar
Vaibhav Kumar has experience in the field of Data Science and Machine Learning, including research and development. He holds a PhD degree in which he has worked in the area of Deep Learning for Stock Market Prediction. He has published/presented more than 15 research papers in international journals and conferences. He has an interest in writing articles related to data science, machine learning and artificial intelligence.
