How To Implement CNN Model Using PyTorch With TPU

Deep learning models are now used successfully in a wide variety of applications, and the goal is to obtain results that are not only accurate but also fast to produce. The size of the data matters for accuracy, but when a larger dataset lengthens the training time of machine learning models, that is always a concern. One way to reduce training time is to train on a TPU runtime. PyTorch supports such cutting-edge hardware accelerators: its support for Cloud TPUs is achieved via integration with XLA (Accelerated Linear Algebra), a compiler for linear algebra that can target multiple types of hardware, including CPU, GPU, and TPU.

This article demonstrates how we can implement a deep learning model using PyTorch with a TPU to accelerate the training process. Here, we define a Convolutional Neural Network (CNN) model using PyTorch and train it in the PyTorch/XLA environment. XLA connects the CNN model with the Google Cloud TPU (Tensor Processing Unit) in a distributed multiprocessing environment; in this implementation, 8 TPU cores are used, one process per core. We will test this setup on Fashion MNIST classification and observe the training time and accuracy.
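
Conceptually, the 8-core setup is data parallelism: each core trains an identical copy of the model on a different slice of the data, and the per-core gradients are combined before every optimizer step (to our understanding, PyTorch/XLA performs this reduction inside xm.optimizer_step, which the training code calls). The pure-Python sketch below uses a toy one-parameter linear model, not the CNN, and hypothetical helper names, to show why averaging gradients from equal-sized shards gives the same result as computing the gradient on the whole batch.

```python
# Toy illustration of data-parallel gradient averaging (not the article's CNN).
# Model: y_hat = w * x, loss: mean squared error over the batch.

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_grad(w, xs, ys, num_replicas):
    """Split the batch into equal shards (one per hypothetical core), compute
    a gradient on each shard, then average them, which is the role an
    all-reduce plays across TPU cores."""
    assert len(xs) % num_replicas == 0, "shards must be equal-sized"
    shard = len(xs) // num_replicas
    grads = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(num_replicas)
    ]
    return sum(grads) / num_replicas

xs = [float(i) for i in range(16)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

full = mse_grad(w, xs, ys)                                # one device, whole batch
parallel = data_parallel_grad(w, xs, ys, num_replicas=8)  # 8 "cores", 2 samples each
print(full, parallel)
```

The equivalence holds only when every shard has the same size, which is one reason the training loader below uses drop_last=True.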

Implementing CNN Using PyTorch With TPU

We will run the code in Google Colab because it provides a free Cloud TPU (Tensor Processing Unit). Before proceeding, in the Colab notebook, go to ‘Edit’, open ‘Notebook settings’, and select ‘TPU’ as the ‘Hardware accelerator’ from the list, as shown in the screenshot below.

[Screenshot: Colab ‘Notebook settings’ dialog with ‘TPU’ selected as the hardware accelerator]

To verify that the TPU environment is working properly, run the following code.

import os
assert os.environ['COLAB_TPU_ADDR']

The cell runs without output if the TPU is enabled; otherwise, it raises KeyError: 'COLAB_TPU_ADDR'. You can also check the TPU by printing its address.

TPU_Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)
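
If you prefer a check that reports a missing TPU instead of raising an exception, the sketch below reads the same COLAB_TPU_ADDR variable with a default. The function name tpu_address is our own, hypothetical helper; it assumes the Colab TPU runtime sets COLAB_TPU_ADDR as it did when this article was written, and newer Colab TPU runtimes may not set this variable at all.

```python
import os

def tpu_address(environ=os.environ):
    """Return the gRPC address of the Colab TPU, or None if none is attached.

    Hypothetical helper: it reads COLAB_TPU_ADDR like the assert above, but
    returns None instead of raising KeyError when the variable is missing.
    """
    addr = environ.get('COLAB_TPU_ADDR')
    return 'grpc://' + addr if addr else None

addr = tpu_address()
print('TPU Address:', addr if addr else 'not found; enable the TPU runtime first')
```

Passing a plain dictionary as environ makes the helper easy to test without a TPU.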



After enabling the TPU, we install the compatible PyTorch/XLA wheels and dependencies to set up the XLA environment using the code below.

VERSION = "20200516" 
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

Once the installation succeeds, we define the functions for loading the dataset, initializing the CNN model, and training and testing it. First, we import the required libraries.

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
from torchvision import datasets, transforms

Next, we define the hyperparameters used in the rest of the code.

# Define Parameters
FLAGS = {}
FLAGS['datadir'] = "/tmp/mnist"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.01
FLAGS['momentum'] = 0.5
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
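
These settings interact with how train_mnist (defined below) distributes the work: with num_cores = 8, each process receives its own shard of the training set through DistributedSampler, each optimizer step consumes one mini-batch per core, and the learning rate is multiplied by the number of processes. The pure-Python sketch below works through that arithmetic, assuming the full 60,000-image Fashion MNIST training split. The shard_indices helper is our own illustration that mimics, as far as we understand it, how PyTorch's DistributedSampler assigns indices when shuffle=False; with shuffle=True, as in the article, the index list is permuted with a shared seed first, so each core gets a random subset of the same size.

```python
import math

def shard_indices(dataset_len, num_replicas, rank):
    """Hypothetical helper mimicking DistributedSampler (shuffle=False,
    drop_last=False): pad the index list by wrapping around until it divides
    evenly, then take every num_replicas-th index starting at rank."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    while len(indices) < total_size:  # wrap-around padding
        indices += indices[:total_size - len(indices)]
    return indices[rank:total_size:num_replicas]

# Values copied from FLAGS; 60,000 is the size of the Fashion MNIST training split.
batch_size, num_cores, base_lr, num_train = 128, 8, 0.01, 60000

per_core_samples = len(shard_indices(num_train, num_cores, rank=0))  # 7,500
steps_per_epoch = per_core_samples // batch_size  # drop_last=True discards the partial batch
global_batch = batch_size * num_cores             # samples consumed per optimizer step
scaled_lr = base_lr * num_cores                   # lr = learning_rate * xrt_world_size()

print(per_core_samples, steps_per_epoch, global_batch, scaled_lr)

# Small example: 10 samples over 3 replicas; indices 0 and 1 are reused as padding.
for r in range(3):
    print(r, shard_indices(10, 3, r))
```

Scaling the learning rate with the number of cores is a common heuristic for keeping training dynamics comparable when the effective batch size grows from 128 to 1,024; it is not guaranteed to be optimal for every model.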

The code below defines the CNN model as a PyTorch nn.Module, along with the functions for loading the data, training the model, and testing the model.

SERIAL_EXEC = xmp.MpSerialExecutor()

class FashionMNIST(nn.Module):

  def __init__(self):
    super(FashionMNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

# Only instantiate model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(FashionMNIST())

def train_mnist():
  torch.manual_seed(1)
 
  def get_dataset():
    norm = transforms.Normalize((0.1307,), (0.3081,))
    train_dataset = datasets.FashionMNIST(
        FLAGS['datadir'],
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), norm]))
    test_dataset = datasets.FashionMNIST(
        FLAGS['datadir'],
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), norm]))

  
    return train_dataset, test_dataset


  # Using the serial executor avoids multiple processes to
  # download the same data.
  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to world size
  lr = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS['momentum'])
  loss_fn = nn.NLLLoss()

  def train_fun(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.5f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)

  def test_fun(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
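  for epoch in range(1, FLAGS['num_epochs'] + 1):
    # Give DistributedSampler a new seed each epoch so the shuffle order changes.
    train_sampler.set_epoch(epoch)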
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_fun(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_fun(para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target
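
In the FashionMNIST class above, the first fully connected layer is nn.Linear(320, 50). The 320 input features follow from the 28 × 28 input: each 5 × 5 convolution without padding shrinks each spatial dimension by 4, each 2 × 2 max-pooling halves it, and BatchNorm and ReLU leave the shape unchanged. The helper functions below are our own illustration of the standard output-size formulas, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel):
    """Spatial output size of F.max_pool2d with stride = kernel, no padding."""
    return (size - kernel) // kernel + 1

size = 28                                  # Fashion MNIST images are 28 x 28
size = pool_out(conv_out(size, 5), 2)      # conv1 -> 24, max-pool -> 12
size = pool_out(conv_out(size, 5), 2)      # conv2 -> 8,  max-pool -> 4
flat_features = 20 * size * size           # conv2 produces 20 channels
print(size, flat_features)
```

If you change the kernel sizes or the input resolution, the same calculation tells you the new in_features value for fc1.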

Next, to plot the predicted and actual labels for the test images, we use the following function.

# Result Visualization
import math
from matplotlib import pyplot as plt

M, N = 5, 5
RESULT_IMG_PATH = '/tmp/test_result.png'

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(12, 12))
  fig.suptitle('Predicted Labels')

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.squeeze() # [1,Y,X] -> [Y,X]
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(
          'Actual {} / Predicted {}'.format(label, prediction), color='blue')
    else:
      ax.set_title(
          'Actual {} / Predicted {}'.format(label, prediction), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)
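
plot_results undoes the input normalization with a second Normalize whose mean is -0.1307/0.3081 and whose standard deviation is 1/0.3081. Normalize computes (x - mean) / std per channel, so applying it with those parameters gives (z + 0.1307/0.3081) × 0.3081 = 0.3081·z + 0.1307, which is exactly the original pixel value. The scalar check below illustrates this; the normalize function is our own stand-in for the per-channel formula, not the torchvision class.

```python
MEAN, STD = 0.1307, 0.3081   # the values passed to transforms.Normalize above

def normalize(x, mean, std):
    """Per-channel formula applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

pixel = 0.75                                   # a pixel value after ToTensor, in [0, 1]
z = normalize(pixel, MEAN, STD)                # forward normalization
restored = normalize(z, -MEAN / STD, 1 / STD)  # the inv_norm parameters
print(z, restored)
```

restored equals the original 0.75 up to floating-point rounding, which is why the plotted images look like the raw Fashion MNIST pictures.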

We are now ready to train the model on the Fashion MNIST dataset. We record the start time before training begins and, once training has finished, record the end time and print the total training time for 50 epochs.

# Start training processes
def train_cnn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_mnist()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    # plot_results expects (images, labels, preds), so pass targets before predictions.
    plot_results(data.cpu(), target.cpu(), pred.cpu())

start_time = time.time()  # record the start time just before launching the 8 processes
xmp.spawn(train_cnn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

Once the training ends successfully, we will print the total time taken during the training.

end_time = time.time()
print('Total Training time = ', end_time - start_time)

In our run, training the model for 50 epochs took 269 seconds, or about 4.5 minutes. Finally, we visualize the predictions made by the trained model.
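
To put the 269-second figure in perspective, the arithmetic below converts it into an average time per epoch and a nominal training throughput. It assumes all 60,000 training images are processed in every epoch (drop_last=True actually discards a few hundred per epoch), and the measured time likely also includes the dataset download, per-epoch evaluation, data loading, and XLA compilation, so treat the result as a rough end-to-end figure rather than raw TPU throughput.

```python
# Figures reported in this article: 50 epochs over the 60,000-image training split in 269 s.
epochs, train_images, seconds = 50, 60000, 269

seconds_per_epoch = seconds / epochs
images_per_second = epochs * train_images / seconds
print(f'{seconds_per_epoch:.2f} s/epoch, {images_per_second:,.0f} training images/s')
```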

from google.colab.patches import cv2_imshow
import cv2
img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
cv2_imshow(img)

We can conclude that using a TPU to train a deep learning model makes training fast: the CNN model was trained on the 60,000 Fashion MNIST training images for 50 epochs in less than 5 minutes, while reaching more than 89% accuracy on the test set. Training deep learning models on TPUs can therefore substantially reduce training time while still reaching good accuracy.

Dr. Vaibhav Kumar

Dr. Vaibhav Kumar is a seasoned data science professional with great exposure to machine learning and deep learning. He has good exposure to research, where he has published several research papers in reputed international journals and presented papers at reputed international conferences. He has worked across industry and academia and has led many research and development projects in AI and machine learning. Along with his current role, he has also been associated with many reputed research labs and universities where he contributes as visiting researcher and professor.