Converting a Model From PyTorch to TensorFlow: Guide to ONNX

Open Neural Network Exchange (ONNX) is an open format built to represent machine learning models. The final outcome of training any machine learning or deep learning algorithm is a model file that efficiently represents the mapping from input data to output predictions. These models are stored in different file formats depending on the framework they were created in: .pkl for Scikit-learn, .pb for TensorFlow, .pth for PyTorch, and so on. Therein lies the problem: you can’t easily take a model created and trained in one framework and use or deploy it in a different framework.

The intent behind ONNX is to be the “USB standard” of the machine learning world. Before the introduction of USB, computers and peripherals used ad hoc proprietary interfaces. Much like hardware in the pre-USB era, today's machine learning models come in ad hoc formats. ONNX overcomes this framework lock-in by providing a universal intermediary model format that frameworks can save to and load from. This allows ML developers to create models in the framework of their choice without worrying about the deployment environment. ONNX also makes it easier to access the hardware acceleration provided by different frameworks and runtime environments.

Installing ONNX and other required libraries

Requirement: Python 3.6 or higher

1. Install ONNX
pip:
pip install onnx

Conda:
conda install -c conda-forge onnx

2. Install TensorFlow and onnx-tensorflow

pip install tensorflow
pip install tensorflow-addons
git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow && pip install -e . 

3. Install PyTorch and torchvision
pip install torch
pip install torchvision
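
Before moving on, it can help to confirm that the packages actually installed. The snippet below is a minimal sketch that uses only the standard library and assumes Python 3.8 or later for importlib.metadata. The helper name installed_versions is ours, and the distribution names are assumed to be the PyPI names behind the commands above (onnx-tf for the onnx-tensorflow checkout).

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI (assumed to match the installs above).
PACKAGES = ["onnx", "tensorflow", "tensorflow-addons", "onnx-tf", "torch", "torchvision"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:20s} {ver or 'NOT INSTALLED'}")
```

Any package reported as NOT INSTALLED should be installed again before continuing with the conversion steps.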

Converting a PyTorch model to TensorFlow

  1. Import required libraries and classes
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 from torchvision import datasets, transforms
 from torch.autograd import Variable
 import onnx
 from onnx_tf.backend import prepare
 
  2. Define a basic CNN model
 class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
         self.conv2_drop = nn.Dropout2d()
         self.fc1 = nn.Linear(320, 50)
         self.fc2 = nn.Linear(50, 10)
     def forward(self, x):
         x = F.relu(F.max_pool2d(self.conv1(x), 2))
         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
         x = x.view(-1, 320)
         x = F.relu(self.fc1(x))
         x = F.dropout(x, training=self.training)
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)
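
The 320 in nn.Linear(320, 50) and x.view(-1, 320) follows from the layer arithmetic. Here is a small sketch, assuming the usual output-size formula for convolution and pooling windows and the 28x28 MNIST input used later in this guide. The helper conv_out is ours and is not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1, 5x5 kernel  -> 24
size = conv_out(size, kernel=2, stride=2)   # max_pool2d(..., 2) -> 12
size = conv_out(size, kernel=5)             # conv2, 5x5 kernel  -> 8
size = conv_out(size, kernel=2, stride=2)   # max_pool2d(..., 2) -> 4

flat_features = 20 * size * size            # conv2 produces 20 channels
print(flat_features)                        # 320
```

If the input size or kernel sizes change, the first argument of fc1 and the view call have to change with them.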
 
  3. Create the train and test methods
 def train(model, device, train_loader, optimizer, epoch):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
         optimizer.step()
         if batch_idx % 1000 == 0:
             print('Train Epoch: {} \tLoss: {:.6f}'.format(
                     epoch,  loss.item()))

 def test(model, device, test_loader):
     model.eval()
     test_loss = 0
     correct = 0
     with torch.no_grad():
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
             test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
             pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
     test_loss /= len(test_loader.dataset)
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
         100. * correct / len(test_loader.dataset))) 
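
The model returns log-probabilities (log_softmax), the loss is the negative log-likelihood of the correct class (nll_loss), and the prediction is the index with the largest log-probability. The sketch below reproduces that arithmetic for a single sample in plain Python. The function names mirror the PyTorch ones for readability but are our own simplified versions, and the scores are made up.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_sum for v in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]


logits = [1.0, 3.0, 0.5]             # made-up scores for a 3-class toy problem
log_probs = log_softmax(logits)
predicted = max(range(len(log_probs)), key=log_probs.__getitem__)

print(predicted)                      # index of the largest log-probability
print(nll_loss(log_probs, target=1))  # small when the target class scores highest
```
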
  4. Download the train and test datasets, normalize them and create data loaders.
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
     batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
     datasets.MNIST('../data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
     batch_size=1000, shuffle=True)
 
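  5. Create the model, define the optimizer and train it
 # Fall back to the CPU when no CUDA-capable GPU is available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")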
 model = Net().to(device)
 optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 for epoch in range(21):
     train(model, device, train_loader, optimizer, epoch)
     test(model, device, test_loader) 

  

  6. Save the trained model
    torch.save(model.state_dict(), 'mnist.pth')
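  7. Load the saved model. Generate and pass a random input so the PyTorch exporter can trace the model and save it to an ONNX file.
 trained_model = Net()
 # map_location lets weights saved from a GPU run load on a CPU-only machine
 trained_model.load_state_dict(torch.load('mnist.pth', map_location='cpu'))
 trained_model.eval()  # disable dropout before tracing
 dummy_input = torch.randn(1, 1, 28, 28)  # a plain tensor is enough; Variable is no longer needed
 torch.onnx.export(trained_model, dummy_input, "mnist.onnx")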
ONNX file visualized using Netron
  8. Load the ONNX file and import it to TensorFlow
 model = onnx.load('mnist.onnx')
 tf_rep = prepare(model) 
  9. Run and test the TensorFlow model.
 import numpy as np
 from IPython.display import display
 from PIL import Image

 def load_digit(path):
     # Match the training preprocessing: grayscale, 28x28, scale to [0, 1], then normalize
     img = Image.open(path).resize((28, 28)).convert('L')
     display(img)
     pixels = (np.asarray(img, dtype=np.float32) / 255.0 - 0.1307) / 0.3081
     return pixels[np.newaxis, np.newaxis, :, :]  # shape (1, 1, 28, 28)

 print('Image 1:')
 output = tf_rep.run(load_digit('two.png'))
 print('The digit is classified as ', np.argmax(output))
 print('Image 2:')
 output = tf_rep.run(load_digit('three.png'))
 print('The digit is classified as ', np.argmax(output))
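
Because the network ends in log_softmax, the values it returns are log-probabilities rather than probabilities. The following sketch shows one way to turn such a row into probabilities and list the most likely digits. The scores are invented for illustration and are not output from the trained model.

```python
import numpy as np

# Stand-in scores for digits 0-9 (invented numbers, not real model output).
scores = np.array([0.1, 0.3, 4.0, 1.5, 0.2, 0.1, 0.0, 0.9, 0.4, 0.2], dtype=np.float32)
log_probs = scores - np.log(np.sum(np.exp(scores)))   # log_softmax, as in Net.forward

probs = np.exp(log_probs)             # back to probabilities that sum to 1
top3 = np.argsort(probs)[::-1][:3]    # indices of the three most likely digits

for digit in top3:
    print(f"digit {digit}: p = {probs[digit]:.3f}")
```

Looking at the runner-up classes as well as the argmax gives a rough sense of how confident the converted model is on a given image.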


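  10. Save the TensorFlow model. With recent onnx-tf releases built on TensorFlow 2, export_graph may write a SavedModel directory at this path rather than a single .pb file.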
    tf_rep.export_graph('mnist.pb')

Aditya Singh
A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.
