
Guide to Differentiable Augmentation for Data-Efficient GAN Training

Data-Efficient GAN

Training a GAN usually requires a large dataset, such as FFHQ or ImageNet. Building such datasets can take months or even years, and in some cases, such as rare animals, collecting a large dataset is not feasible at all. Training GANs with less data is therefore an important goal. However, training with fewer images degrades the quality of the generated images. A common remedy is data augmentation, which increases the number of training samples without collecting new data.

Cropping, flipping, scaling, colour jittering, and region masking are the most widely used data augmentation techniques. If these transformations are applied only to real images, the generator learns to match the distribution of the augmented images, which causes a distribution shift. If they are applied to both real and generated images but only when updating the discriminator, the optimization of the generator and discriminator becomes unbalanced. To address this, researchers from MIT, Tsinghua University, Adobe Research, and CMU proposed Differentiable Augmentation for Data-Efficient GAN Training. The method was presented by Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.



Differentiable Augmentation (DiffAugment) is a simple method that applies the same differentiable augmentation to both real and generated images during GAN training, for both the generator and the discriminator updates. The output below shows the effectiveness of DiffAugment. It was obtained by training on 100 head-shot photos of President Obama, and on 160 cat photos and 389 dog photos from the AnimalFace dataset.
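To make the idea concrete, here is a minimal NumPy sketch of a DiffAugment-style transformation T applied to a batch of real images and a batch of generated ones. The three op groups (color, translation, cutout) follow the policy names the authors use, but the function names, parameter ranges, and implementation details are assumptions for illustration, not the repository's code. NumPy has no automatic differentiation, so this only shows the operations themselves; written in TensorFlow or PyTorch, the same element-wise ops would let gradients flow back to the generator.

```python
import numpy as np

# Hypothetical NumPy sketch of a DiffAugment-style policy ("color,translation,cutout").
# Images are float arrays in NHWC layout; parameter ranges are illustrative assumptions.

def rand_brightness(x, rng):
    # Add one random offset in [-0.5, 0.5) to every pixel of each image
    return x + (rng.random((x.shape[0], 1, 1, 1)) - 0.5)

def rand_contrast(x, rng):
    # Scale each image's deviation from its own mean by a factor in [0.5, 1.5)
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    return (x - mean) * (rng.random((x.shape[0], 1, 1, 1)) + 0.5) + mean

def rand_translation(x, rng, ratio=0.125):
    # Shift each image by up to ratio * size pixels; vacated pixels become 0
    n, h, w, _ = x.shape
    py, px = int(h * ratio), int(w * ratio)
    padded = np.pad(x, ((0, 0), (py, py), (px, px), (0, 0)))  # zero padding
    out = np.empty_like(x)
    for i in range(n):
        dy = rng.integers(-py, py + 1)
        dx = rng.integers(-px, px + 1)
        out[i] = padded[i, py - dy:py - dy + h, px - dx:px - dx + w]
    return out

def rand_cutout(x, rng, ratio=0.5):
    # Zero out one (ratio*h) x (ratio*w) square per image, placed fully inside the frame
    n, h, w, _ = x.shape
    ch, cw = int(h * ratio), int(w * ratio)
    out = x.copy()
    for i in range(n):
        top = rng.integers(0, h - ch + 1)
        left = rng.integers(0, w - cw + 1)
        out[i, top:top + ch, left:left + cw, :] = 0.0
    return out

POLICY = {
    "color": [rand_brightness, rand_contrast],
    "translation": [rand_translation],
    "cutout": [rand_cutout],
}

def diff_augment(x, rng, policy="color,translation,cutout"):
    # Apply every op in the policy, in order, with freshly sampled random parameters
    for name in policy.split(","):
        for op in POLICY[name]:
            x = op(x, rng)
    return x

rng = np.random.default_rng(0)
real = rng.random((4, 64, 64, 3))  # stand-in for a batch of real images x
fake = rng.random((4, 64, 64, 3))  # stand-in for a batch of generated images G(z)
real_aug = diff_augment(real, rng)  # T(x)
fake_aug = diff_augment(fake, rng)  # T(G(z)): same policy, independent random draws
```

The key design choice is that both batches pass through the same augmentation family, so the discriminator never sees un-augmented images from either side and the generator is not pushed toward producing augmented-looking images.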

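With Differentiable Augmentation, the discriminator and generator losses are:

$$L_D = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[f_D(-D(T(x)))\big] + \mathbb{E}_{z \sim p(z)}\big[f_D(D(T(G(z))))\big]$$

$$L_G = \mathbb{E}_{z \sim p(z)}\big[f_G(-D(T(G(z))))\big]$$

where

D = discriminator,

G = generator,

L_D = discriminator loss,

L_G = generator loss,

f_D and f_G = loss functions (for example, f_D(x) = f_G(x) = log(1 + e^x) for the non-saturating GAN loss),

G(z) = the image the generator produces from the latent vector z,

T = a random, differentiable augmentation (transformation).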

The augmentation T is applied to both the real image x and the generated image G(z). To update G, gradients must back-propagate through T, so T has to be differentiable with respect to its input. The image below shows the whole process.
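The loss formulas can be sketched in a few lines of NumPy. This minimal example assumes the non-saturating loss (f_D = f_G = softplus), a toy linear discriminator, and a fixed affine map standing in for a random augmentation T; all names (`diffaug_losses`, `D`, `T`) are hypothetical. The last lines compare the chain-rule gradient of L_G with respect to a generated image, which passes through T, against a finite-difference estimate.

```python
import numpy as np

def softplus(v):
    # log(1 + e^v), computed stably; f_D = f_G = softplus gives the non-saturating loss
    return np.logaddexp(0.0, v)

def diffaug_losses(D, T, x_real, x_fake):
    # T is applied to BOTH the real and the generated batch before D sees them
    d_real = D(T(x_real))
    d_fake = D(T(x_fake))
    loss_d = softplus(-d_real).mean() + softplus(d_fake).mean()  # L_D
    loss_g = softplus(-d_fake).mean()                            # L_G
    return loss_d, loss_g

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8, 3))

def D(x):
    # Toy linear discriminator: one logit per image
    return (x * w).sum(axis=(1, 2, 3))

def T(x):
    # Fixed affine map standing in for a random, differentiable augmentation
    return 0.5 * x + 0.1

x_real = rng.random((4, 8, 8, 3))
x_fake = rng.random((4, 8, 8, 3))  # stand-in for G(z)
loss_d, loss_g = diffaug_losses(D, T, x_real, x_fake)

# Chain rule: dL_G/dx_fake = dL_G/dD * dD/dT * dT/dx, where dT/dx = 0.5
d_fake = D(T(x_fake))
grad_g = (-1.0 / (1.0 + np.exp(d_fake)))[:, None, None, None] * w * 0.5 / len(x_fake)

# Central finite difference on one pixel of one generated image
eps = 1e-6
x_plus, x_minus = x_fake.copy(), x_fake.copy()
x_plus[1, 2, 3, 0] += eps
x_minus[1, 2, 3, 0] -= eps
fd = (diffaug_losses(D, T, x_real, x_plus)[1]
      - diffaug_losses(D, T, x_real, x_minus)[1]) / (2 * eps)
```

If T were not differentiable (for example, a non-differentiable re-encoding of the image), the factor dT/dx would not exist and the generator could not be updated through the augmented branch.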

Datasets used for Experiments

  1. Requirements and Installation for Differentiable Augmentation

Clone the repository and change into its directory:

!git clone
%cd data-efficient-gans/DiffAugment-stylegan2
Requirements:
  • TensorFlow 1.15 with GPU support.
  • tensorflow-datasets version <= 2.1.0.
  • 4 or 8 GPUs with at least 12 GB of DRAM are recommended for training.
# Install the prerequisites
!pip uninstall -y tensorflow tensorflow-probability
!pip install tensorflow-gpu==1.15.0 tensorflow-datasets==2.1.0
  2. Train the model

The following example trains on the 100 head-shot images of President Obama at a resolution of 64×64:

    !python --dataset=100-shot-obama --resolution=64

You can also train on the other datasets listed here. The full demo is available here, and other examples are available here.

  3. Train your own model

The library also lets you train a model on your own dataset with Differentiable Augmentation. Further details are available here.

  4. Differentiable Augmentation Demo Using a Pre-Trained Model

4.1 Clone the repository and install all the requirements.

4.2 Import all the required libraries and packages. The code snippet is available here.

4.3 Define the helper functions (discussed below).

4.4 Load the data and visualize it. Here we use the 100-shot Obama dataset:

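import os
import numpy as np
import PIL.Image
# dataset_tool is a repository module, imported in step 4.2.

# Prepare the 100-shot Obama dataset and get the directory that holds its images
data_dir = dataset_tool.create_dataset('100-shot-obama')
training_images = []
for fname in sorted(os.listdir(data_dir)):
  if fname.endswith('.jpg'):
    training_images.append(np.array(PIL.Image.open(os.path.join(data_dir, fname))))
# Arrange the 100 images in a 5 x 20 grid
imgs = np.reshape(training_images, [5, 20, *training_images[0].shape])
imgs = np.concatenate(imgs, axis=1)
imgs = np.concatenate(imgs, axis=1)
# Downscale for display (LANCZOS is the filter that ANTIALIAS was an alias for in Pillow)
PIL.Image.fromarray(imgs).resize((1000, 250), PIL.Image.LANCZOS)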

The training images are shown below:

4.5 For comparison, we generate images with the baseline StyleGAN2 and with StyleGAN2 + Differentiable Augmentation. This step needs two helper functions, generate and _generate: _generate loads the weights of a pre-trained model and generates a grid of images, and generate runs _generate in a separate worker process.

Helper functions

import numpy as np
import PIL.Image
import tensorflow as tf
from multiprocessing import Pool
# misc and tflib are repository modules, imported in step 4.2.

def _generate(network_name, num_rows, num_cols, seed, resolution):
  if seed is not None:
    np.random.seed(seed)  # make the latent vectors reproducible
  with tf.Session():
    # Load the weights; Gs is the moving-average copy of the generator
    _, _, Gs = misc.load_pkl(network_name)
    z = np.random.randn(num_rows * num_cols, Gs.input_shape[1])
    # Generate the images as uint8 arrays in NHWC layout
    outputs = Gs.run(z, None, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
    # Tile the batch into a num_rows x num_cols grid
    outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
    outputs = np.concatenate(outputs, axis=1)
    outputs = np.concatenate(outputs, axis=1)
    img = PIL.Image.fromarray(outputs)
    img = img.resize((resolution * num_cols, resolution * num_rows), PIL.Image.LANCZOS)
  return img

def generate(network_name, num_rows, num_cols, seed=None, resolution=128):
  # Run _generate in a separate worker process
  with Pool(1) as pool:
    return pool.apply(_generate, (network_name, num_rows, num_cols, seed, resolution))
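Both the data-loading snippet and _generate use the same reshape-and-concatenate trick to turn a batch of images into a single grid. Below is that step as a standalone NumPy function, so the intermediate shapes are easy to follow. The name tile_images is hypothetical and does not come from the repository.

```python
import numpy as np

def tile_images(images, num_rows, num_cols):
    # images: (num_rows * num_cols, H, W, C) -> one (num_rows*H, num_cols*W, C) grid,
    # filled row by row (image k lands in row k // num_cols, column k % num_cols)
    images = np.asarray(images)
    grid = np.reshape(images, [num_rows, num_cols, *images.shape[1:]])  # (R, Cc, H, W, C)
    grid = np.concatenate(grid, axis=1)  # stack the rows vertically:  (Cc, R*H, W, C)
    grid = np.concatenate(grid, axis=1)  # place columns side by side: (R*H, Cc*W, C)
    return grid

batch = np.zeros((10, 64, 64, 3), dtype=np.uint8)  # e.g. 10 generated 64x64 images
grid = tile_images(batch, num_rows=2, num_cols=5)  # shape (128, 320, 3)
```

Each np.concatenate call iterates over the first axis of its input, which is why two calls with axis=1 are enough to merge both grid dimensions.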

Load the pre-trained baseline StyleGAN2 model and generate images by passing it to the helper function above.

# Load the baseline StyleGAN2 model and generate the results
generate('mit-han-lab:stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)

Now load the pre-trained StyleGAN2 model trained with Differentiable Augmentation and pass it to the same helper function to generate images.

# Load the pre-trained Differentiable Augmentation model for the 100-shot-obama dataset
generate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)

The outputs are shown below, and the difference between the two models is clearly visible: images from the baseline StyleGAN2 are more distorted than those from StyleGAN2 with Differentiable Augmentation. A likely reason is overfitting: with only 100 training images, the baseline discriminator memorizes the training set, so its accuracy on held-out images drops sharply and it gives the generator less useful feedback.

4.6 Evaluate both models by computing the Fréchet Inception Distance (FID); a lower FID indicates better results. An overview of the evaluation measures used for GANs is available here. The code for calculating the FID is given below.
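For reference, FID compares the mean and covariance of Inception-network features of real and generated images: FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The NumPy sketch below computes it from feature arrays you already have; extracting the Inception features is left out, the function names are hypothetical, and this is not the repository's metric code (which, judging by the name fid5k-train, presumably compares 5,000 generated samples against the training set).

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})
    diff = mu1 - mu2
    # Tr((S1 S2)^{1/2}) equals the sum of square roots of the eigenvalues of S1 S2,
    # which are real and non-negative for covariance matrices (up to round-off)
    eig = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eig.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)

def fid_from_features(feats_real, feats_fake):
    # feats_*: (num_samples, feature_dim) arrays, e.g. Inception pool features
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s1 = np.atleast_2d(np.cov(feats_real, rowvar=False))
    s2 = np.atleast_2d(np.cov(feats_fake, rowvar=False))
    return frechet_distance(mu1, s1, mu2, s2)

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 4))
fake_feats = rng.normal(loc=0.5, size=(500, 4))
score = fid_from_features(real_feats, fake_feats)  # larger when the distributions differ more
```

In one dimension the formula reduces to (μ1 − μ2)² + (σ1 − σ2)², which makes it easy to sanity-check: means 0 and 3 with variances 1 and 4 give 9 + 1 = 10.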

Helper functions to calculate the FID

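# dataset_tool, EasyDict, metric_base, metric_defaults and Pool are imported in step 4.2.
def _evaluate(network_name, dataset, resolution, metric):
  # Build the dataset at the requested resolution
  dataset = dataset_tool.create_dataset(dataset, resolution)
  dataset_args = EasyDict(tfrecord_dir=dataset, resolution=resolution, from_tfrecords=True)
  # Run the chosen metric (e.g. FID) on the pre-trained network
  metric_group = metric_base.MetricGroup([metric_defaults[metric]])
  metric_group.run(network_name, dataset_args=dataset_args, log_results=False)
  return metric_group.metrics[0]._results[0].value

def evaluate(network_name, dataset, resolution=256, metric='fid5k-train'):
  # Run _evaluate in a separate worker process
  with Pool(1) as pool:
    return pool.apply(_evaluate, (network_name, dataset, resolution, metric))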

Now calculate the FID with the function above for both the baseline StyleGAN2 and StyleGAN2 with Differentiable Augmentation.

print('Evaluating StyleGAN2 (baseline)...')
fid_baseline = evaluate('mit-han-lab:stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('Baseline FID:', fid_baseline, '\n')
print('Evaluating StyleGAN2 + DiffAugment (ours)...')
fid_ours = evaluate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('Ours FID:', fid_ours, '\n')
# Plot the two FID scores side by side
plt.figure(figsize=(2, 3))
plt.bar([0, 1], [fid_baseline, fid_ours], color=['gray', 'darkred'])
plt.xticks([0, 1], ['Baseline', 'Differentiable Augmentation'])
plt.ylabel('FID')

The output of the above code is shown below: Differentiable Augmentation achieves a clearly lower, and therefore better, FID than the baseline model.

4.7 Finally, generate an interpolation animation with the pre-trained model. The code is given below.

!python -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o interp.gif --num-rows=2 --num-cols=3 --seed=1
IPython.display.Image(open('interp.gif', 'rb').read())

The output of the above code is a .gif file, shown below.
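An interpolation animation is made by generating images from a sequence of latent vectors that move smoothly between random endpoints. The sketch below shows two common ways to build such a sequence, linear interpolation and spherical interpolation (slerp). Whether the repository's script uses either of them is an assumption, and the function names are hypothetical. Each row of the result would be fed to the generator, for example through Gs.run as in _generate.

```python
import numpy as np

def lerp_latents(z0, z1, num_steps):
    # Points evenly spaced on the straight line from z0 to z1, endpoints included
    t = np.linspace(0.0, 1.0, num_steps)[:, None]
    return (1.0 - t) * z0 + t * z1

def slerp_latents(z0, z1, num_steps):
    # Spherical interpolation along the arc between the directions of z0 and z1
    cos_omega = np.dot(z0, z1) / (np.linalg.norm(z0) * np.linalg.norm(z1))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    if np.isclose(np.sin(omega), 0.0):
        return lerp_latents(z0, z1, num_steps)  # degenerate (anti)parallel case
    t = np.linspace(0.0, 1.0, num_steps)[:, None]
    return (np.sin((1.0 - t) * omega) * z0 + np.sin(t * omega) * z1) / np.sin(omega)

rng = np.random.default_rng(1)
# 512 is an assumed latent size; in practice use Gs.input_shape[1]
z0, z1 = rng.standard_normal(512), rng.standard_normal(512)
path = slerp_latents(z0, z1, num_steps=30)  # shape (30, 512): one latent per frame
```

Slerp is often preferred for Gaussian latents because a straight line between two samples passes through points of unusually small norm, which the generator rarely saw during training.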

You can check the full demo here.


In this article, we discussed a data augmentation technique called Differentiable Augmentation and compared it with a baseline StyleGAN2. The comparison shows that Differentiable Augmentation can generate high-quality images from a limited amount of data (in this case, only 100 images).

The official code, documentation, and tutorial are available at:


Aishwarya Verma
A data science enthusiast and a post-graduate in Big Data Analytics. Creative and organized with an analytical bent of mind.
