
Hands-On Guide To Custom Training With Tensorflow Strategy


As your models get bigger and more complex, it may become infeasible to train them on a single CPU or GPU, so you need a way to distribute training across multiple devices. To achieve this, TensorFlow provides several strategies: we can distribute our data across clusters of machines, where each cluster has one or more devices that carry out large-scale training on the model. This is typically called a distribution strategy.

Distributed training in TensorFlow is built around data parallelism, where we replicate the same model architecture on multiple devices and run different slices of the input data on them. Here a device is simply a unit of CPU + GPU, or separate units of GPUs and TPUs. The method works like this: the entire dataset is divided into a number of slices, determined by the number of devices available for training, and each slice is fed to a copy of the model. Because the data is different for each copy, the weights computed by each copy also differ, so they ultimately need to be aggregated into a new master model.

All of this distributed training is handled by TensorFlow's tf.distribute.Strategy class, which supports different distribution strategies on top of high-level APIs such as Keras. Ease of use has been the primary focus while designing these APIs, and adopting a strategy requires only minimal code changes. We can also continue to use all the functionality of the Keras API, such as layers, models, metrics, summaries and checkpoints.
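To get a sense of how little code this requires, here is a minimal sketch (assuming a TensorFlow 2.x setup; the tiny Dense model and the train_dataset variable are placeholders, not from the original article) of wrapping a standard Keras workflow in a MirroredStrategy scope: only the model creation and compile() move inside the scope, while fit() stays unchanged.

import tensorflow as tf

# a minimal sketch: the strategy only changes where the model is built and compiled
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# model.fit(train_dataset, epochs=10)  # fit() runs outside the scope and is distributed automatically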

A few new terms come up when discussing distribution strategies: device, replica, worker and mirrored variable. Let's look at them one by one.

A device is any machine used to train ML models. For example, a device setup may consist of a CPU and two GPUs, giving us three units on which to train our model.

During training, copies of the model parameters are placed on these devices; these copies are referred to as replicas. A worker is the software dedicated to running training on the replicas. Finally, some variables need to be kept in sync across all devices; these are called mirrored variables.
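As a quick, hedged illustration of these terms (not part of the original walkthrough), you can list the devices TensorFlow can see and check that a variable created inside a MirroredStrategy scope is kept as a mirrored variable with one copy per replica:

import tensorflow as tf

# list the physical devices available on this machine
print(tf.config.list_physical_devices('CPU'))
print(tf.config.list_physical_devices('GPU'))

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)  # created under the strategy, so a copy lives on each replica
print(type(v))          # typically a MirroredVariable wrapper rather than a plain tf.Variable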

To learn more about distributed training, you can watch Google's I/O session on the topic.

We will be using the Fashion MNIST dataset, containing 60K training images and 10K test images of size 28 x 28, to implement these distribution strategies. Additionally, for better flexibility and control, we will use custom training loops.

Implementation of Custom Training With Tensorflow Strategy

The following code implementation is in reference to the official implementation.

Import all dependencies:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
Load the dataset:

We are working with convolutional layers, which expect 4D inputs of shape (batch size, height, width, channels). Such networks perform better on scaled values, so we scale the pixel values to the range 0 to 1:

dataset = tf.keras.datasets.fashion_mnist
(train_ima, train_labe), (test_ima, test_labe) = dataset.load_data()
# add a channel dimension: (28, 28) -> (28, 28, 1)
train_ima = train_ima[..., None]
test_ima = test_ima[..., None]
# scale pixel values to the range [0, 1]
train_ima = train_ima / np.float32(255)
test_ima = test_ima / np.float32(255)
Create a Strategy:

For this work, we are using tf.distribute.MirroredStrategy(), which works as follows:

  • It replicates the model graph and all model variables on each available replica.
  • The slices of each input batch are distributed to these replicas.
  • Each replica calculates the loss and gradients for its slice of the input.
  • Gradients are synced across all the replicas by summing them.
  • After the sync, the same update is applied to the variables on every replica.
# available devices are detected automatically
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))  # check how many replicas are in sync

buffer_size = len(train_ima)
batch_size_per_replica = 64
global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
EPOCHS = 10

Create the datasets using tf.data.Dataset.from_tensor_slices and distribute them across the replicas:

train_dataset = tf.data.Dataset.from_tensor_slices((train_ima,train_labe)).shuffle(buffer_size).batch(global_batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_ima,test_labe)).batch(global_batch_size)
distri_train_dataset = strategy.experimental_distribute_dataset(train_dataset)

distri_test_dataset = strategy.experimental_distribute_dataset(test_dataset)
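To see how the distribution actually splits a batch, a short optional sketch (using only the objects defined above) pulls one element from the distributed dataset and inspects its per-replica components with strategy.experimental_local_results:

# optional: inspect the slice each replica receives from one global batch
for dist_batch in distri_train_dataset:
  images, labels = dist_batch
  per_replica_images = strategy.experimental_local_results(images)
  print('local replicas:', len(per_replica_images))
  print('per-replica batch shape:', per_replica_images[0].shape)  # e.g. (64, 28, 28, 1) with a single device
  break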
Create CNN model:

Here we create a Sequential model consisting of convolutional and max-pooling layers for effective feature extraction:

def create_model():
  model = tf.keras.Sequential([
         tf.keras.layers.Conv2D(64, 3, activation='relu'),
         tf.keras.layers.MaxPooling2D(),
         tf.keras.layers.Conv2D(80, 3, activation='relu'),
         tf.keras.layers.MaxPooling2D(),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(64, activation='relu'),
         tf.keras.layers.Dense(10)                    
  ])
  return model

Create a checkpoint directory to store the checkpoints for all epochs;

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,'ckpt')
Define Loss Function:

Say we have 2 GPUs and a global batch size of 64. Each batch of input is distributed across the 2 replicas, so each replica gets an input of size 32. After the forward pass, each replica calculates its loss, but instead of dividing that loss by 32, it is divided by the GLOBAL_BATCH_SIZE of 64. This is necessary because the gradients computed on each replica are later synced by summing them across replicas, and scaling each replica's loss by the global batch size keeps that sum equivalent to the average over the full batch.

with strategy.scope():
  # set reduction to NONE so we can divide by the global batch size ourselves
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits = True,
      reduction = tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_obj(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size)
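To make the scaling concrete, here is a small illustrative sketch (the numbers are made up) of what tf.nn.compute_average_loss does: it sums the per-example losses and divides by the global batch size rather than by the per-replica batch size.

# purely illustrative: 4 per-example losses on one replica, global batch of 8 (2 replicas x 4)
per_example_demo = tf.constant([0.5, 1.0, 1.5, 2.0])
manual = tf.reduce_sum(per_example_demo) / 8
auto = tf.nn.compute_average_loss(per_example_demo, global_batch_size=8)
print(manual.numpy(), auto.numpy())  # both print 0.625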
Define metrics:

The metrics below track the test loss as well as the training and test accuracy:

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  train_accu = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
  test_accu = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
Training loop:

Instead of using model.compile() and model.fit(), we explicitly define all the functionality these methods normally provide. The model, optimizer and checkpoint are created under the scope of the strategy we defined earlier.

with strategy.scope():
  # create model, optimizer, and checkpoints 
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

Define the training step and test step:

def train_step(inputs):
  images, labels = inputs
  with tf.GradientTape() as tape:
    predictions = model(images, training = True)
    loss = compute_loss(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_accu.update_state(labels, predictions)
  return loss
def test_step(inputs):
  images, labels = inputs
  predictions = model(images, training=False)
  t_loss = loss_obj(labels, predictions)
  test_loss.update_state(t_loss)
  test_accu.update_state(labels, predictions)

Here the training and test steps are wrapped in tf.function to make the model portable; we then iterate over the distributed train and test datasets using a for loop.

@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))
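Because compute_loss already divides each replica's loss by the global batch size, summing the per-replica results with ReduceOp.SUM recovers the mean loss over the full global batch. A tiny illustrative sketch with made-up numbers:

# each replica returns sum(its per-example losses) / global_batch_size
replica_0_loss = 20.0 / 64  # sum of 32 per-example losses on replica 0
replica_1_loss = 12.0 / 64  # sum of 32 per-example losses on replica 1
print(replica_0_loss + replica_1_loss)  # 0.5, the mean loss over the global batch of 64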

for epoch in range(EPOCHS):

  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0

  for x in distri_train_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in distri_test_dataset:
    distributed_test_step(x)
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)
  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accu.result()*100, test_loss.result(),
                         test_accu.result()*100))

  test_loss.reset_states()
  train_accu.reset_states()
  test_accu.reset_states()

Restore the latest checkpoint and evaluate the model without a strategy:

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_ima, test_labe)).batch(global_batch_size)

@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)
print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))

Output:

Accuracy is: 90.2900

Conclusion:

When we work on complex projects, or want to train a model such as a GAN from scratch, conventional single-device training can take a very long time, often days or even weeks. By using a distributed training approach, you can significantly reduce both training time and cost. In addition, distributed training opens the way for developers to build highly scalable, deep models.

