How To Implement CNN Model Using PyTorch With TPU

With the successful application of deep learning models across a variety of domains, the goal is now to obtain results that are not only accurate but also fast. Larger datasets generally yield more accurate results, but the growth in data size drives up the training time of machine learning models, which is always a concern. The TPU runtime environment can be used to accelerate training. PyTorch supports such cutting-edge hardware accelerators: its support for Cloud TPUs is achieved via integration with XLA (Accelerated Linear Algebra), a compiler for linear algebra that can target multiple types of hardware, including CPUs, GPUs, and TPUs.

This article demonstrates how to implement a deep learning model using PyTorch with a TPU to accelerate the training process. We define a Convolutional Neural Network (CNN) model in PyTorch and train it in the PyTorch/XLA environment. XLA connects the CNN model to the Google Cloud TPU (Tensor Processing Unit) in a distributed multiprocessing environment; in this implementation, 8 TPU cores are used. We will test this setup on Fashion MNIST classification and observe the training time and accuracy.


Implementing CNN Using PyTorch With TPU

We will run the implementation in Google Colab because it provides a free Cloud TPU (Tensor Processing Unit). Before proceeding further, in the Colab notebook, go to 'Edit' and then 'Notebook Settings' and select 'TPU' as the 'Hardware accelerator', as shown in the screenshot below.

(Screenshot: selecting 'TPU' as the hardware accelerator in Colab's notebook settings)

To verify that the TPU environment is working properly, run the lines of code below.

import os
assert os.environ['COLAB_TPU_ADDR']

It will execute successfully if the TPU is enabled; otherwise it will raise `KeyError: 'COLAB_TPU_ADDR'`. You can also check the TPU by printing its address.

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)
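The assertion above fails hard when no TPU is attached. As a small sketch (the helper name is our own, not part of the original notebook), the same check can be written defensively so that it returns the address, or None when the notebook is not running on a TPU runtime:

```python
import os

def tpu_address():
    # Return the gRPC address of the attached Colab TPU, or None
    # when the notebook is not running with a TPU accelerator.
    addr = os.environ.get('COLAB_TPU_ADDR')
    return 'grpc://' + addr if addr else None
```

This lets the rest of a notebook branch on `tpu_address() is None` instead of crashing with a `KeyError`.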

After enabling the TPU, we will install the compatible wheels and dependencies to set up the XLA environment using the code below.

VERSION = "20200516"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

Once the installation succeeds, we will proceed to define the methods for loading the dataset, initializing the CNN model, and training and testing it. First of all, we will import the required libraries.

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
from torchvision import datasets, transforms

After that, we will define the hyperparameters required later.

# Define Parameters
FLAGS = {}
FLAGS['datadir'] = "/tmp/mnist"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.01
FLAGS['momentum'] = 0.5
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
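With data parallelism, each of the 8 TPU cores processes its own batch of 128 samples, so the effective global batch size, and, as done later in `train_mnist`, the learning rate, scale with the number of cores. A plain-Python sketch of that arithmetic (the helper name is hypothetical, mirroring `lr = FLAGS['learning_rate'] * xm.xrt_world_size()` below):

```python
def scale_for_cores(batch_size, learning_rate, num_cores):
    # Each core sees its own batch of `batch_size` samples, so the
    # global batch and the learning rate both scale by `num_cores`.
    return batch_size * num_cores, learning_rate * num_cores

# With the FLAGS above: 128 * 8 = 1024 samples per global step, lr 0.01 -> 0.08.
global_batch, scaled_lr = scale_for_cores(128, 0.01, 8)
```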

The below code snippet defines the CNN model as a PyTorch `nn.Module`, along with the functions for loading the data, training the model, and testing it.

SERIAL_EXEC = xmp.MpSerialExecutor()

class FashionMNIST(nn.Module):

  def __init__(self):
    super(FashionMNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(FashionMNIST())
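The `in_features=320` of `fc1` follows from the shape arithmetic of the two convolution/pooling stages on a 28x28 Fashion MNIST image. A quick sanity check of that arithmetic in plain Python:

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output spatial size of a convolution (floor division, as in PyTorch).
    return (size + 2 * padding - kernel) // stride + 1

size = 28                      # Fashion MNIST images are 28x28
size = conv_out(size, 5) // 2  # conv1 (5x5) -> 24, then 2x2 max-pool -> 12
size = conv_out(size, 5) // 2  # conv2 (5x5) -> 8, then 2x2 max-pool -> 4
flattened = 20 * size * size   # 20 output channels from conv2 -> 20*4*4 = 320
```

The flattened size of 320 matches the first argument of `nn.Linear(320, 50)` in the model above.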

def train_mnist():
  def get_dataset():
    norm = transforms.Normalize((0.1307,), (0.3081,))
    train_dataset = datasets.FashionMNIST(
        FLAGS['datadir'],
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), norm]))
    test_dataset = datasets.FashionMNIST(
        FLAGS['datadir'],
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), norm]))

    return train_dataset, test_dataset

  # Using the serial executor avoids multiple processes
  # downloading the same data.
  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)

  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to world size
  lr = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
  loss_fn = nn.NLLLoss()

  def train_fun(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)  # All-reduces gradients across TPU cores
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)

  def test_fun(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    with torch.no_grad():
      for data, target in loader:
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_fun(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_fun(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target
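The accuracy computed in `test_fun` is simply the fraction of predicted labels matching the targets, expressed as a percentage. A plain-Python equivalent of that computation (a sketch for illustration, without tensors):

```python
def accuracy_pct(predictions, targets):
    # Mirrors test_fun's accuracy: 100 * (number of correct
    # predictions) / (total number of samples).
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return 100.0 * correct / len(targets)

# Three of four predictions correct -> 75.0%
accuracy_pct([1, 2, 3, 4], [1, 2, 0, 4])
```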

Now, to plot the predicted labels against the actual labels for the test images, the below function will be used.

# Result Visualization
import math
from matplotlib import pyplot as plt

M, N = 5, 5
RESULT_IMG_PATH = '/tmp/test_result.png'

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(12, 12))
  fig.suptitle('Predicted Labels')

  for i, ax in enumerate(fig.axes):
    if i >= num_images:
      ax.axis('off')
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.squeeze() # [1,Y,X] -> [Y,X]
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(
          'Actual {}/ Predicted {}'.format(label, prediction), color='blue')
    else:
      ax.set_title(
          'Actual {}/ Predicted {}'.format(label, prediction), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)
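The parameters of `inv_norm` are not arbitrary: `transforms.Normalize(mean, std)` computes `(x - mean) / std`, so applying it again with mean `-0.1307/0.3081` and std `1/0.3081` algebraically undoes the original normalization. A small arithmetic check in plain Python:

```python
MEAN, STD = 0.1307, 0.3081

def normalize(x, mean, std):
    # transforms.Normalize applies (x - mean) / std element-wise.
    return (x - mean) / std

def inv_normalize(x):
    # Undoes the forward normalization:
    # ((x - MEAN)/STD - (-MEAN/STD)) / (1/STD) = x
    return normalize(x, -MEAN / STD, 1.0 / STD)

pixel = 0.5
restored = inv_normalize(normalize(pixel, MEAN, STD))
```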

Now, we are all set to train the model on the Fashion MNIST dataset. We will record the start time before training begins and the end time once it finishes, then print the total training time for the 50 epochs.


# Start training processes
def train_cnn(rank, flags):
  global FLAGS
  FLAGS = flags
  accuracy, data, pred, target = train_mnist()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), target.cpu(), pred.cpu())

start_time = time.time()
xmp.spawn(train_cnn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')


Once the training ends successfully, we will print the total time taken during the training.

end_time = time.time()
print('Total Training time = ',end_time-start_time )

As we can see above, the training took 269 seconds, or about 4.5 minutes: the PyTorch model was trained for 50 epochs in under 5 minutes. Finally, we will visualize the predictions of the trained model.

from google.colab.patches import cv2_imshow
import cv2
img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
cv2_imshow(img)

So, we can conclude that implementing a deep learning model on a TPU results in fast training, as we have seen above. The CNN model was trained on 40,000 training images for 50 epochs in less than 5 minutes, and it reached more than 89% accuracy during training. Training deep learning models on TPUs is therefore a benefit in terms of both time and accuracy.


