PyTorch continues to push the field of computer vision and deep learning forward by providing a number of powerful tools and techniques. Deep learning workloads in computer vision typically involve large image datasets, so an accelerated environment is needed to speed up execution while maintaining an acceptable level of accuracy. PyTorch offers this through XLA (Accelerated Linear Algebra), a compiler for linear algebra that can target multiple types of hardware, including GPUs and TPUs. The PyTorch/XLA package integrates with Google Cloud TPUs to deliver this accelerated execution.
In this article, we will demonstrate the implementation of ResNet50, a deep convolutional neural network, in PyTorch with a TPU. The model will be trained and tested in the PyTorch/XLA environment on the task of classifying the CIFAR10 dataset. We will also measure the time taken to train this model for 50 epochs.
Implementing ResNet50 in PyTorch
To make use of a TPU, this implementation was done in Google Colab. To start, we need to select TPU as the hardware accelerator under the notebook settings.
After selecting the TPU, we verify the environment using the following lines of code:
import os
assert os.environ['COLAB_TPU_ADDR']

It will execute successfully if the TPU is enabled; otherwise, it will throw a KeyError: 'COLAB_TPU_ADDR'. You can also check the TPU by printing its address.

TPU_Path = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU Address:', TPU_Path)
In the next step, we will install the XLA environment to accelerate execution, just as we did in one of our earlier articles where we implemented a convolutional neural network on a TPU.

VERSION = "20200516"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION
Now, we will import all the required libraries here.
from matplotlib import pyplot as plt
import numpy as np
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

import torchvision
from torchvision import datasets, transforms

import cv2
from google.colab.patches import cv2_imshow
After importing the libraries, we will define and initialize the required parameters.
# Define parameters
FLAGS = {}
FLAGS['data_dir'] = "/tmp/cifar"
FLAGS['batch_size'] = 128
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.02
FLAGS['momentum'] = 0.9
FLAGS['num_epochs'] = 50
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False
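For reference, batch_size here is the per-core batch size: the training function defined later spawns one process per TPU core and scales the learning rate by the world size. The effective settings therefore work out roughly as follows (a simple illustration, not part of the original code):

# With 8 TPU cores, each core processes 128 images per step, and the base
# learning rate of 0.02 is multiplied by the world size inside train_resnet50().
effective_batch_size = FLAGS['batch_size'] * FLAGS['num_cores']    # 128 * 8 = 1024
scaled_learning_rate = FLAGS['learning_rate'] * FLAGS['num_cores']  # 0.02 * 8 = 0.16
print(effective_batch_size, scaled_learning_rate)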
In the next step, we will define the ResNet50 model.
class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(
        in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(
        planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes)

    # Identity shortcut, replaced by a 1x1 projection when the shape changes.
    self.shortcut = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_planes, self.expansion * planes, kernel_size=1,
              stride=stride, bias=False),
          nn.BatchNorm2d(self.expansion * planes))

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = F.relu(out)
    return out


class ResNet(nn.Module):

  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64
    self.conv1 = nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks - 1)
    layers = []
    for stride in strides:
      layers.append(block(self.in_planes, planes, stride))
      self.in_planes = planes * block.expansion
    return nn.Sequential(*layers)

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = torch.flatten(out, 1)
    out = self.linear(out)
    return F.log_softmax(out, dim=1)


def ResNet50():
  # Stage depths follow ResNet-50's [3, 4, 6, 3] layout. Note that this
  # CIFAR-style variant stacks BasicBlocks rather than the Bottleneck blocks
  # used in the original ResNet-50 paper.
  return ResNet(BasicBlock, [3, 4, 6, 3])
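As a quick, optional sanity check (not part of the original article), we can instantiate the model on the CPU and pass a dummy CIFAR10-sized batch through it to confirm that it produces log-probabilities over the 10 classes:

# Run a dummy batch of four 32x32 RGB images through the network on the CPU.
test_model = ResNet50()
dummy = torch.randn(4, 3, 32, 32)
out = test_model(dummy)
print(out.shape)         # expected: torch.Size([4, 10])
print(out.exp().sum(1))  # rows of log-probabilities sum to ~1 after exp()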
The code snippet below defines the functions that load the CIFAR10 dataset, prepare the training and test sets, and run the training and evaluation loops.

SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate the model weights once in memory.
WRAPPED_MODEL = xmp.MpModelWrapper(ResNet50())

def train_resnet50():
  torch.manual_seed(1)

  def get_dataset():
    norm = transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = datasets.CIFAR10(
        root=FLAGS['data_dir'],
        train=False,
        download=True,
        transform=transform_test)
    return train_dataset, test_dataset

  # Using the serial executor avoids multiple processes
  # downloading the same data.
  train_dataset, test_dataset = SERIAL_EXEC.run(get_dataset)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to the number of cores.
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = WRAPPED_MODEL.to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      if x % FLAGS['log_steps'] == 0:
        print('[xla:{}]({}) Loss={:.2f} Time={}'.format(
            xm.get_ordinal(), x, loss.item(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]
    accuracy = 100.0 * correct / total_samples
    print('[xla:{}] Accuracy={:.2f}%'.format(
        xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print("Finished training epoch {}".format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target = test_loop_fn(
        para_loader.per_device_loader(device))
    if FLAGS['metrics_debug']:
      xm.master_print(met.metrics_report(), flush=True)

  return accuracy, data, pred, target
Now, we will begin training ResNet50. Training runs for 50 epochs, as defined in the parameters above. Before starting, we record the start time so that we can print the total training time afterwards. Note that the training entry point below calls a plot_results helper and later reads a RESULT_IMG_PATH, neither of which appears in the snippets above; a sketch of such a helper is shown right after this paragraph.
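The following is a minimal, hypothetical version of that helper, assuming it draws a small grid of test images annotated with predicted and true labels and saves the figure to RESULT_IMG_PATH for display after training; the helper used in the original run may differ.

# Hypothetical visualization helper (not shown in the original snippets).
from matplotlib import pyplot as plt

RESULT_IMG_PATH = '/tmp/test_result.png'

def plot_results(images, preds, labels, rows=4, cols=6):
  # Plot the first rows*cols images with predicted vs. true class indices.
  fig, axes = plt.subplots(rows, cols, figsize=(12, 8))
  for ax, img, pred, label in zip(axes.flat, images, preds, labels):
    img = img.permute(1, 2, 0).numpy()                          # CHW -> HWC
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)    # rescale for display
    ax.imshow(img)
    ax.set_title('pred: {} / true: {}'.format(int(pred), int(label)), fontsize=8)
    ax.axis('off')
  fig.tight_layout()
  fig.savefig(RESULT_IMG_PATH)
  plt.close(fig)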
start_time = time.time()

# Start training processes
def training(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target = train_resnet50()
  if rank == 0:
    # Retrieve tensors that are on TPU core 0 and plot.
    plot_results(data.cpu(), pred.cpu(), target.cpu())

xmp.spawn(training, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')
After training completes, we print the total time taken by the training process.

end_time = time.time()
print("Time taken = ", end_time - start_time)
Finally, we visualize the predictions the model made on sample test data, using the result image saved during training.

img = cv2.imread(RESULT_IMG_PATH, cv2.IMREAD_UNCHANGED)
cv2_imshow(img)
Analyzing the training performance and predictions, we can conclude that the model reached more than 82% accuracy during training, with the 50 epochs completing in roughly 50 minutes. We have also seen how to implement the ResNet50 model in PyTorch with a TPU on the CIFAR10 dataset, which opens the door to testing other deep convolutional neural network models on the same or different benchmark datasets.