Hands-On Guide To Custom Training With TensorFlow Strategy

As models get bigger and more complex, it can become infeasible to train them on a single CPU or GPU, so you may need to distribute training across several devices. TensorFlow provides several strategies for this. Training can be spread over a cluster of machines, and each machine can have one or more devices that share the work of large-scale training. The approach used to split this work across devices is called a distribution strategy.

Distributed training in TensorFlow is built around data parallelism: the same model architecture is replicated on multiple devices, and each copy processes a different slice of the input data. A device here can be a CPU, a GPU or a TPU. Each batch of data is divided into equal slices, one per available device, and each model copy trains on its own slice. Because every copy sees different data, each one computes different gradients; these gradients are aggregated so that all copies apply the same update and end up with the same weights.
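To make the slicing step concrete, here is a minimal NumPy sketch that splits one global batch into equal per-replica slices. It only imitates what a strategy does internally; the helper `split_global_batch` and the choice of two replicas are hypothetical, for illustration only.

```python
import numpy as np

def split_global_batch(images, labels, num_replicas):
    """Hypothetical helper: split one global batch into per-replica slices."""
    # np.array_split splits along axis 0, the batch axis
    image_slices = np.array_split(images, num_replicas)
    label_slices = np.array_split(labels, num_replicas)
    return list(zip(image_slices, label_slices))

# A fake global batch of 64 grayscale 28 x 28 images with integer labels
rng = np.random.default_rng(0)
images = rng.random((64, 28, 28, 1), dtype=np.float32)
labels = rng.integers(0, 10, size=64)

slices = split_global_batch(images, labels, num_replicas=2)
for replica_id, (x, y) in enumerate(slices):
    print(f"replica {replica_id}: images {x.shape}, labels {y.shape}")
```

Each of the two hypothetical replicas receives 32 images and 32 labels.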

In TensorFlow, distributed training is handled by the tf.distribute.Strategy API, which supports several distribution strategies and works with high-level APIs such as Keras as well as with custom training loops. Ease of use was a primary design goal, so adopting a strategy requires only minimal code changes, and Keras functionality such as layers, models, metrics, summaries and checkpoints remains available.
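As a rough sketch of how little changes with Keras, assuming TensorFlow 2.x: create and compile the model inside `strategy.scope()`, then call `fit()` as usual. The tiny dense model and the random data are placeholders, not the model used later in this article.

```python
import numpy as np
import tensorflow as tf

# Uses all visible GPUs; falls back to a single CPU replica if there are none
strategy = tf.distribute.MirroredStrategy()

# Variables (model weights, optimizer state) must be created inside the scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

# Tiny random stand-in for real data
x = np.random.rand(256, 28, 28).astype("float32")
y = np.random.randint(0, 10, size=256)

# fit() splits each batch across the replicas automatically
history = model.fit(x, y, batch_size=64, epochs=1, verbose=0)
print(history.history["loss"])
```

Outside the `with` block the training call is unchanged; the strategy only affects where variables live and how batches are split.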


There are a few new terms to know when discussing distribution strategies: device, replica, worker and mirrored variable. Let's look at them one by one.

A device is a processor that TensorFlow can run operations on, such as a CPU, a GPU or a TPU. For example, a machine with one CPU and two GPUs offers three devices on which a model can be trained.

A replica is one copy of the model (its computation and parameters) running on one slice of the input data; during training, each device holds one such copy. A worker is the machine, or the process running on it, that executes the replicated computation; one worker can host one or more replicas. Finally, mirrored variables are variables, such as the model weights, that are kept in sync across all devices.
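These terms can be inspected directly in code. A minimal sketch, assuming TensorFlow 2.x and whatever devices happen to be visible (on a CPU-only machine MirroredStrategy falls back to a single replica): list the devices, create a strategy, and look at the per-replica copies of a variable created inside its scope.

```python
import tensorflow as tf

# Devices TensorFlow can see on this machine (CPUs, GPUs, ...)
print(tf.config.list_logical_devices())

strategy = tf.distribute.MirroredStrategy()
print("Replicas:", strategy.num_replicas_in_sync)

# A variable created inside the scope is mirrored:
# one synchronized copy per replica
with strategy.scope():
    v = tf.Variable(1.0)

# experimental_local_results returns the per-replica copies held by this worker
copies = strategy.experimental_local_results(v)
print(len(copies), [float(c.numpy()) for c in copies])
```

With two GPUs you would expect two copies, both holding 1.0.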


We will implement the distribution strategy on the Fashion-MNIST dataset, which contains 60K training images and 10K test images of size 28 x 28. For greater flexibility and control, we will use a custom training loop instead of model.fit().

Implementation of Custom Training With Tensorflow Strategy

The following code is based on the official TensorFlow tutorial on custom training with tf.distribute.Strategy.

Import all dependencies:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
Load the dataset:

The convolutional layers expect 4D inputs of shape (batch size, height, width, channels), so we add a channel axis to the 28 x 28 images. Such networks also train better on scaled values, so we scale the pixel values to the range 0 to 1:

dataset = tf.keras.datasets.fashion_mnist
(train_ima, train_labe), (test_ima, test_labe) = dataset.load_data()
# add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
train_ima = train_ima[..., None]
test_ima = test_ima[..., None]
# scale pixel values from [0, 255] to [0, 1]
train_ima = train_ima / np.float32(255)
test_ima = test_ima / np.float32(255)
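Both preprocessing steps can be checked on a tiny stand-in array. This NumPy sketch uses two synthetic 28 x 28 images (one all black, one all white) in place of the real dataset.

```python
import numpy as np

# Stand-in for two Fashion-MNIST images: uint8 pixels in [0, 255], shape (N, 28, 28)
raw = np.array([np.zeros((28, 28)), np.full((28, 28), 255)], dtype=np.uint8)

# [..., None] appends a channel axis -> (N, 28, 28, 1), the 4D layout Conv2D expects
images = raw[..., None]

# Dividing by a float32 scalar rescales pixels to [0, 1] and yields float32
images = images / np.float32(255)

print(images.shape, images.dtype, images.min(), images.max())
```

Dividing by `np.float32(255)` rather than the integer `255` keeps the result in float32 instead of float64, which halves the memory footprint and matches TensorFlow's default float type.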
Create a Strategy:

We use tf.distribute.MirroredStrategy, which works as follows:

  • It replicates the model graph and all related variables onto each available replica.
  • Each global batch of input is split into slices, and one slice is sent to each replica.
  • Each replica computes the loss and gradients for its own slice.
  • The gradients are synced across all replicas by summing them (an all-reduce).
  • After the sync, the same update is applied to the variable copies on every replica, so they stay identical.
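The steps above can be simulated without TensorFlow. This NumPy sketch is a simplified stand-in, assuming a linear model with a squared-error loss instead of a CNN: it copies the weights to two hypothetical replicas, computes a gradient per slice, sums the gradients, and applies the same update everywhere. The result matches an ordinary full-batch gradient step.

```python
import numpy as np

rng = np.random.default_rng(0)
num_replicas, global_batch = 2, 8
x = rng.normal(size=(global_batch, 3))
y = rng.normal(size=global_batch)
w = np.zeros(3)   # starting weights, identical on every replica
lr = 0.1

def grad(w, xb, yb, global_batch):
    # Gradient of sum(0.5 * (xb @ w - yb) ** 2) / global_batch
    return xb.T @ (xb @ w - yb) / global_batch

# 1. Copy the weights to every replica
replica_weights = [w.copy() for _ in range(num_replicas)]
# 2. Split the global batch into one slice per replica
x_slices = np.array_split(x, num_replicas)
y_slices = np.array_split(y, num_replicas)
# 3. Each replica computes the gradient on its own slice
local_grads = [grad(rw, xs, ys, global_batch)
               for rw, xs, ys in zip(replica_weights, x_slices, y_slices)]
# 4. All-reduce: sum the gradients across replicas
summed = np.sum(local_grads, axis=0)
# 5. Every replica applies the same update, so the copies stay identical
replica_weights = [rw - lr * summed for rw in replica_weights]

# Compare with one step of plain gradient descent on the full batch
full_step = w - lr * grad(w, x, y, global_batch)
print(np.allclose(replica_weights[0], full_step),
      np.allclose(replica_weights[0], replica_weights[1]))
```

Both checks print `True`: summing slice gradients that were each divided by the global batch size reproduces the full-batch gradient.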
# available devices are detected automatically
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync)) # number of replicas in sync

buffer_size = len(train_ima)
batch_size_per_replica = 64
# each replica gets 64 examples, so the global batch grows with the replica count
global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
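To see how the global batch size scales with the number of replicas, here is a small worked example in plain Python. The replica counts 1, 2 and 4 are illustrative. Because `batch()` is called without `drop_remainder`, the last batch of each epoch may be smaller, which is why the step count is rounded up.

```python
import math

batch_size_per_replica = 64
train_examples = 60_000  # Fashion-MNIST training set size

for replicas in (1, 2, 4):
    global_batch_size = batch_size_per_replica * replicas
    steps = math.ceil(train_examples / global_batch_size)
    print(f"{replicas} replica(s): global batch {global_batch_size}, {steps} steps/epoch")
```

This gives 938, 469 and 235 steps per epoch respectively. More replicas mean fewer, larger steps.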

Create the datasets with from_tensor_slices and distribute them across the replicas:

train_dataset = tf.data.Dataset.from_tensor_slices((train_ima,train_labe)).shuffle(buffer_size).batch(global_batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_ima,test_labe)).batch(global_batch_size)
distri_train_dataset = strategy.experimental_distribute_dataset(train_dataset)

distri_test_dataset = strategy.experimental_distribute_dataset(test_dataset)
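To see what `experimental_distribute_dataset` produces, the following sketch pulls one distributed batch and unpacks it with `strategy.experimental_local_results` to show the per-replica batch sizes. It assumes TensorFlow 2.x, uses small random arrays in place of Fashion-MNIST, and sets 8 examples per replica as an arbitrary choice.

```python
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 8 * strategy.num_replicas_in_sync

# Small random stand-in for the Fashion-MNIST arrays
images = np.random.rand(32, 28, 28, 1).astype("float32")
labels = np.random.randint(0, 10, size=32)

dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch_images, batch_labels in dist_dataset:
    # batch_images is a per-replica value; unpack it into local tensors
    local = strategy.experimental_local_results(batch_images)
    print([t.shape[0] for t in local])  # per-replica batch sizes
    break
```

The per-replica sizes add up to the global batch size.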
Create CNN model:

Here we create a Sequential model in which convolutional and max-pooling layers extract features, followed by dense layers for classification:

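def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(80, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)  # one logit per class (the loss uses from_logits=True)
  ])
  return model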
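If each Conv2D layer is followed by a default 2 x 2 MaxPooling2D layer, as the description above implies, the feature-map sizes can be traced by hand. This plain-Python sketch uses the Keras defaults ('valid' padding, pooling stride equal to the window size) to compute the spatial size after each stage and the number of features that reach the dense layers. The helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    # 'valid' padding (the Conv2D default): output = input - kernel + 1
    return size - kernel + 1

def pool_out(size, pool=2):
    # MaxPooling2D default: 2 x 2 window, stride 2, 'valid' padding
    return size // pool

size = 28
size = pool_out(conv_out(size))   # Conv2D(64) + pooling: 28 -> 26 -> 13
size = pool_out(conv_out(size))   # Conv2D(80) + pooling: 13 -> 11 -> 5
flat_features = size * size * 80  # Flatten over the 80 channels
print(size, flat_features)        # 5 2000
```

So the Dense(64) layer receives 2,000 features per image.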

Define a directory in which to store the training checkpoints:

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,'ckpt')
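Here is a minimal sketch of the save/restore cycle with `tf.train.Checkpoint`, assuming TensorFlow 2.x. It uses a single variable and a temporary directory instead of the real model and `./training_checkpoints`, so it can be run on its own.

```python
import os
import tempfile
import tensorflow as tf

checkpoint_dir = tempfile.mkdtemp()   # throwaway directory for the demo
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

v = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(v=v)

save_path = checkpoint.save(checkpoint_prefix)   # writes ckpt-1.* files
v.assign(5.0)                                    # change the value after saving

# Restore the most recent checkpoint found in the directory
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
print(save_path, v.numpy())                      # the value is back to 1.0
```

Each call to `save()` appends an increasing counter to the prefix, and `tf.train.latest_checkpoint` picks the newest one.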
Define Loss Function:

Suppose we have 2 GPUs and a global batch size of 64. Each batch is split across the 2 replicas, so each replica receives 32 examples. After the forward pass, each replica computes its loss, but instead of dividing the summed loss by its local batch size of 32, it divides by the global batch size, 64. This is necessary because the gradients computed by the replicas are then synced by summing them. Dividing by the local batch size would scale the update up by the number of replicas.
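The arithmetic behind this can be checked with NumPy. In this sketch the per-example losses are random stand-in values. The code splits a global batch of 64 across two hypothetical replicas and compares the two ways of scaling the loss.

```python
import numpy as np

rng = np.random.default_rng(0)
global_batch_size = 64
num_replicas = 2

per_example_loss = rng.random(global_batch_size)                 # stand-in losses
replica_losses = np.array_split(per_example_loss, num_replicas)  # 32 per replica

# Correct: each replica divides its SUM by the GLOBAL batch size ...
scaled = [part.sum() / global_batch_size for part in replica_losses]
# ... so summing across replicas gives the mean over the whole batch
print(np.isclose(sum(scaled), per_example_loss.mean()))

# Dividing by the per-replica size (32) instead doubles the result after summing
wrong = [part.mean() for part in replica_losses]
print(np.isclose(sum(wrong), num_replicas * per_example_loss.mean()))
```

Both lines print `True`, confirming that the local-mean approach overshoots by exactly the number of replicas.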

with strategy.scope():
  # set reduction to NONE so the loss is returned per example;
  # compute_loss then divides the sum by the global batch size
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits = True,
      reduction = tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_obj(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
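Here is a small sketch showing what `reduction` and `tf.nn.compute_average_loss` do, assuming TensorFlow 2.x. It passes the string `"none"`, which is the value of `tf.keras.losses.Reduction.NONE` in tf.keras 2.x and is assumed here to be accepted by newer Keras versions as well (worth checking against your version). With uniform logits over three classes, every example's loss is ln(3), so averaging two examples over a global batch size of 4 gives ln(3)/2.

```python
import numpy as np
import tensorflow as tf

# reduction="none" keeps one loss value per example instead of averaging
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none")

labels = tf.constant([0, 1])
logits = tf.zeros((2, 3))       # uniform logits -> each loss is ln(3)

per_example_loss = loss_obj(labels, logits)
print(per_example_loss.shape)   # (2,)

# Sum of the per-example losses divided by the GLOBAL batch size (here 4)
avg = tf.nn.compute_average_loss(per_example_loss, global_batch_size=4)
print(float(avg), np.log(3) / 2)
```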
Define metrics:

The following metrics track the test loss and the training and test accuracy:

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  train_accu = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
  test_accu = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
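Keras metrics keep running totals across `update_state` calls until they are reset, which is why per-epoch figures need a reset at the end of each epoch. Here is a small sketch with made-up labels and predictions, assuming TensorFlow 2.5 or later (where `reset_state` is available):

```python
import tensorflow as tf

acc = tf.keras.metrics.SparseCategoricalAccuracy(name="demo_accuracy")

# First batch: predicted classes [0, 0], labels [0, 1] -> 1 of 2 correct
acc.update_state([0, 1], [[0.9, 0.1], [0.8, 0.2]])
first = float(acc.result())        # 0.5

# Second batch: 1 of 1 correct; the running total is now 2 of 3
acc.update_state([1], [[0.1, 0.9]])
running = float(acc.result())      # about 0.667

# Without a reset, the next epoch would keep accumulating into these totals
acc.reset_state()
print(first, running, float(acc.result()))
```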
Training loop:

Instead of using model.compile() and model.fit(), we implement the functionality of these methods ourselves. First, we create the model, the optimizer and the checkpoint inside the strategy scope defined earlier:

with strategy.scope():
  # create model, optimizer, and checkpoints 
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

Define the training step and the test step:

def train_step(inputs):
  images, labels = inputs
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_accu.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs
  predictions = model(images, training=False)
  t_loss = loss_obj(labels, predictions)
  # record the test loss and accuracy for this batch
  test_loss.update_state(t_loss)
  test_accu.update_state(labels, predictions)

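We wrap the distributed steps in tf.function so that they are traced into TensorFlow graphs and run efficiently. strategy.run executes a step on every replica, and strategy.reduce sums the per-replica losses. We then iterate over the distributed train and test datasets with a for loop: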

@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  # sum the per-replica losses (each already divided by the global batch size)
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

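EPOCHS = 10  # assumed value; the number of epochs is not stated above

for epoch in range(EPOCHS):
  # training loop
  total_loss = 0.0
  num_batches = 0
  for x in distri_train_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # test loop
  for x in distri_test_dataset:
    distributed_test_step(x)

  # save a checkpoint every second epoch
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                        train_accu.result() * 100, test_loss.result(),
                        test_accu.result() * 100))

  # reset the metrics so each epoch is reported on its own
  test_loss.reset_state()
  train_accu.reset_state()
  test_accu.reset_state()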

Restore the latest checkpoint into a new model and evaluate it without a strategy:

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_ima, test_labe)).batch(global_batch_size)

@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

# restore the weights saved during training
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))


Accuracy is: 90.2900


For complex projects, such as training your own GAN from scratch, the conventional single-device approach can take days or even weeks. Distributing training across devices can significantly reduce training time, and it makes it practical for developers to build larger and deeper models.


Vijaysinh Lendave
Vijaysinh is an enthusiast in machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, model building.
