Training a GAN usually requires a large dataset (e.g. FFHQ or ImageNet). Building such large-scale datasets can take months or even years, and sometimes collecting a large amount of data is not feasible at all (for example, images of rare animals). Training GANs with less data is therefore an important goal. However, with fewer training images the quality of the generated images degrades, mainly because the discriminator overfits the small training set. A common way to counter overfitting is data augmentation, which increases the effective number of training samples without collecting new data.
Cropping, flipping, scaling, colour jittering and region masking are the most widely known data augmentation techniques. If these transformations are applied only to real images, the generator learns to match the distribution of the augmented images, so the augmentations leak into the generated outputs (a distribution shift). If they are applied to both real and generated images, but only when updating the discriminator, the generator and discriminator are optimized against different inputs, which leads to unbalanced optimization. To address this, researchers from MIT, Tsinghua University, Adobe Research and CMU proposed Differentiable Augmentation for Data-Efficient GAN Training. The method was presented at the 34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada, by Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu and Song Han.
Differentiable Augmentation (DiffAugment) is a straightforward method that applies the same differentiable augmentation to both real and generated images during GAN training. The figure below shows the effectiveness of DiffAugment. These results were obtained by training on 100 head-shot photos of President Obama, and on 160 cat photos and 389 dog photos from the AnimalFace dataset.
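The DiffAugment training objectives for the discriminator and the generator are:
$$L_D = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[f_D(-D(T(x)))\big] + \mathbb{E}_{z \sim p(z)}\big[f_D(D(T(G(z))))\big]$$
$$L_G = \mathbb{E}_{z \sim p(z)}\big[f_G(-D(T(G(z))))\big]$$
where:
D = discriminator,
G = generator,
L_D = discriminator loss,
L_G = generator loss,
f_D and f_G = loss functions (for the non-saturating loss, f_D(x) = f_G(x) = log(1 + e^x)),
G(z) = the image the generator produces from the latent vector z,
T = a random differentiable augmentation applied to every input of the discriminator.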
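To make the losses concrete, L_D = E[f_D(-D(T(x)))] + E[f_D(D(T(G(z))))] and L_G = E[f_G(-D(T(G(z))))] can be evaluated directly from discriminator scores. The minimal NumPy sketch below assumes the non-saturating choice f_D(x) = f_G(x) = log(1 + e^x) (softplus); `real_scores` and `fake_scores` are illustrative stand-ins for D(T(x)) and D(T(G(z))) on a batch, not outputs of the real model.

```python
import numpy as np

def softplus(x):
    # log(1 + e^x), computed stably as logaddexp(0, x)
    return np.logaddexp(0.0, x)

def discriminator_loss(real_scores, fake_scores):
    # L_D = E[f_D(-D(T(x)))] + E[f_D(D(T(G(z))))]
    return softplus(-real_scores).mean() + softplus(fake_scores).mean()

def generator_loss(fake_scores):
    # L_G = E[f_G(-D(T(G(z))))]
    return softplus(-fake_scores).mean()

# Illustrative scores for a batch of 4 (stand-ins for D(T(x)) and D(T(G(z))))
real_scores = np.array([2.0, 1.5, 0.5, 3.0])
fake_scores = np.array([-1.0, -2.0, 0.0, -0.5])
print(discriminator_loss(real_scores, fake_scores), generator_loss(fake_scores))

# An undecided discriminator (all scores 0) gives L_D = 2*log(2) and L_G = log(2)
zeros = np.zeros(4)
print(discriminator_loss(zeros, zeros), generator_loss(zeros))
```

A confident discriminator (large positive scores on real images, large negative scores on generated ones) drives L_D toward 0 while L_G grows; in both losses D only ever sees images after they pass through T.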
The differentiable augmentation T is applied to both the real image x and the generated image G(z). To update G, gradients must back-propagate through T, so T has to be differentiable with respect to its input. The image below shows the whole process.
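The NumPy sketch below illustrates the kind of transformations DiffAugment uses (the paper's color, translation and cutout policies) on image batches in N, H, W, C layout. It is a simplified illustration, not the repository's TensorFlow implementation: only the brightness part of the color policy is shown, and the parameter ranges (brightness offset in [-0.5, 0.5), translation up to 1/8 of the image size, a cutout square of half the image size) are assumptions. Each operation is built from addition, shifting with zero padding and multiplication by a mask, all of which are differentiable with respect to the input image, which is why an autodiff framework can back-propagate through T to G.

```python
import numpy as np

def rand_brightness(x, rng):
    # Add one random offset in [-0.5, 0.5) to all pixels of each image (x: N, H, W, C)
    shift = rng.uniform(-0.5, 0.5, size=(x.shape[0], 1, 1, 1))
    return x + shift

def rand_translation(x, rng, ratio=0.125):
    # Shift each image by up to ratio * size pixels; vacated pixels become 0
    n, h, w, _ = x.shape
    max_dy, max_dx = int(h * ratio), int(w * ratio)
    out = np.zeros_like(x)
    for i in range(n):
        dy = int(rng.integers(-max_dy, max_dy + 1))
        dx = int(rng.integers(-max_dx, max_dx + 1))
        # Matching source and destination windows for the shifted copy
        src_y, dst_y = slice(max(0, -dy), h - max(0, dy)), slice(max(0, dy), h - max(0, -dy))
        src_x, dst_x = slice(max(0, -dx), w - max(0, dx)), slice(max(0, dx), w - max(0, -dx))
        out[i, dst_y, dst_x] = x[i, src_y, src_x]
    return out

def rand_cutout(x, rng, ratio=0.5):
    # Multiply each image by a mask that zeroes one random square of side ratio * size
    n, h, w, _ = x.shape
    ch, cw = int(h * ratio), int(w * ratio)
    mask = np.ones((n, h, w, 1), dtype=x.dtype)
    for i in range(n):
        top = int(rng.integers(0, h - ch + 1))
        left = int(rng.integers(0, w - cw + 1))
        mask[i, top:top + ch, left:left + cw] = 0
    return x * mask

def diff_augment(x, rng):
    # T = cutout(translation(brightness(x))), with fresh random parameters per call
    return rand_cutout(rand_translation(rand_brightness(x, rng), rng), rng)

rng = np.random.default_rng(0)
reals = rng.uniform(-1, 1, size=(8, 32, 32, 3))   # stand-in for a batch of real images
fakes = rng.uniform(-1, 1, size=(8, 32, 32, 3))   # stand-in for a batch of G(z)
# The same augmentation policy T is applied to both batches before they reach D
aug_reals, aug_fakes = diff_augment(reals, rng), diff_augment(fakes, rng)
print(aug_reals.shape, aug_fakes.shape)
```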
- Requirements and Installation for Differentiable Augmentation
Clone the repository and change into the DiffAugment-stylegan2 directory:
!git clone https://github.com/mit-han-lab/data-efficient-gans
%cd data-efficient-gans/DiffAugment-stylegan2
- TensorFlow 1.15 with GPU support.
- tensorflow-datasets version <= 2.1.0
- 4 or 8 GPUs with at least 12 GB of memory (DRAM) each are recommended for training.
# install the prerequisites
!pip uninstall -y tensorflow tensorflow-probability
!pip install tensorflow-gpu==1.15.0 tensorflow-datasets==2.1.0
- Train the model
This is an example of training on the 100 head-shot images of President Obama at a resolution of 64×64.
!python run_low_shot.py --dataset=100-shot-obama --resolution=64
- Train your own model
The repository also provides a way to train your own model with Differentiable Augmentation. Further details are available in the official repository.
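When adding DiffAugment to your own GAN, the key detail is where T sits: it wraps every input to D, for real and generated images alike, in both the discriminator step and the generator step. The sketch below shows only that data flow with toy stand-ins; `D`, `G`, `T` and the step functions are hypothetical names, not the repository's API, and a real training loop would differentiate these losses with its framework's autodiff.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def d_step_loss(D, G, T, reals, z):
    # Discriminator update: augment real AND generated images with the same T
    real_scores = D(T(reals))
    fake_scores = D(T(G(z)))
    return softplus(-real_scores).mean() + softplus(fake_scores).mean()

def g_step_loss(D, G, T, z):
    # Generator update: T stays in the path, so gradients flow D -> T -> G
    fake_scores = D(T(G(z)))
    return softplus(-fake_scores).mean()

# Hypothetical toy models: G maps 4-d latents to 16-d "images", D scores them
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 16))

def G(z):
    return np.tanh(z @ W)

def D(x):
    return x.mean(axis=1)

def T(x):
    # A random brightness-like shift per sample, standing in for DiffAugment
    return x + rng.uniform(-0.5, 0.5, size=(x.shape[0], 1))

reals = rng.uniform(-1, 1, size=(8, 16))
z = rng.normal(size=(8, 4))
print(d_step_loss(D, G, T, reals, z), g_step_loss(D, G, T, z))
```

Applying T only to `reals`, or dropping it from the generator step, would reproduce the two failure modes described earlier: leaked augmentations and unbalanced optimization.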
- Differentiable Augmentation Demo using Pre-Trained Model
4.1 Clone the repository and install all the requirements.
4.2 Import all the required libraries and packages. The import code is available in the demo Colab notebook.
4.3 Define the helper functions (discussed below).
4.4 Load the data and visualize it. Here we use the 100-shot Obama dataset:
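import os
import numpy as np
import PIL.Image
import dataset_tool  # module from the DiffAugment-stylegan2 directory

data_dir = dataset_tool.create_dataset('100-shot-obama')
training_images = []
for fname in sorted(os.listdir(data_dir)):
    if fname.endswith('.jpg'):
        training_images.append(np.array(PIL.Image.open(os.path.join(data_dir, fname))))
# tile the 100 images (all of the same shape) into a 5 x 20 grid
imgs = np.reshape(training_images, [5, 20, *training_images[0].shape])
imgs = np.concatenate(imgs, axis=1)
imgs = np.concatenate(imgs, axis=1)
# LANCZOS is the filter behind the old ANTIALIAS alias, which Pillow 10 removed
PIL.Image.fromarray(imgs).resize((1000, 250), PIL.Image.LANCZOS)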
The resulting grid of training images is shown below.
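The two np.concatenate calls in the loading code are a compact way to tile a batch of images into a grid. Here is the same trick as a standalone function with the intermediate shapes spelled out; `make_grid` is a hypothetical helper name, and it assumes all images share one shape.

```python
import numpy as np

def make_grid(images, rows, cols):
    # images: rows * cols images, each of shape (H, W, C)
    images = np.asarray(images)
    if images.shape[0] != rows * cols:
        raise ValueError("need exactly rows * cols images")
    grid = images.reshape(rows, cols, *images.shape[1:])  # (rows, cols, H, W, C)
    grid = np.concatenate(grid, axis=1)  # stack rows vertically: (cols, rows*H, W, C)
    grid = np.concatenate(grid, axis=1)  # put columns side by side: (rows*H, cols*W, C)
    return grid

# 100 dummy 64x64 RGB images arranged as 5 rows x 20 columns
dummy = np.zeros((100, 64, 64, 3), dtype=np.uint8)
print(make_grid(dummy, rows=5, cols=20).shape)  # (320, 1280, 3)
```

The grid is filled row by row: the cell in row r and column c holds image number r * cols + c.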
4.5 For comparison, we generate images first with the baseline StyleGAN2 and then with StyleGAN2 + Differentiable Augmentation. The helper functions required for this step are generate and _generate: _generate loads the weights of a pre-trained model and generates a grid of output images, and generate runs _generate in a separate worker process.
from multiprocessing import Pool
import numpy as np
import PIL.Image
import tensorflow as tf
# misc and tflib are modules from the DiffAugment-stylegan2 code (see step 4.2)

def _generate(network_name, num_rows, num_cols, seed, resolution):
    if seed is not None:
        np.random.seed(seed)
    with tf.Session():
        # load the weights of the pre-trained generator
        _, _, Gs = misc.load_pkl(network_name)
        # one latent vector per grid cell
        z = np.random.randn(num_rows * num_cols, *Gs.input_shape[1:])
        # generate the images as uint8 arrays in NHWC layout
        outputs = Gs.run(z, None, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True))
        # tile the images into a num_rows x num_cols grid
        outputs = np.reshape(outputs, [num_rows, num_cols, *outputs.shape[1:]])
        outputs = np.concatenate(outputs, axis=1)
        outputs = np.concatenate(outputs, axis=1)
        img = PIL.Image.fromarray(outputs)
        img = img.resize((resolution * num_cols, resolution * num_rows), PIL.Image.LANCZOS)
    return img

def generate(network_name, num_rows, num_cols, seed=None, resolution=128):
    # run _generate in a separate single-worker process
    with Pool(1) as pool:
        return pool.apply(_generate, (network_name, num_rows, num_cols, seed, resolution))
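The generate wrapper runs _generate inside a single-worker process pool. A likely reason is isolation: each call builds its own TensorFlow graph and session, and when the worker process exits, the memory it held (including GPU memory, which TensorFlow 1.x does not return until its process ends) is released. The pattern itself is plain standard-library multiprocessing; `run_isolated` below is a hypothetical helper name, not part of the repository.

```python
from multiprocessing import Pool

def run_isolated(func, *args):
    # Run func(*args) in a fresh single-worker process and return its result.
    # Everything the worker allocates is freed when the pool shuts down.
    with Pool(1) as pool:
        return pool.apply(func, args)

if __name__ == "__main__":
    print(run_isolated(pow, 2, 10))  # 1024
```

The function passed in must be picklable (defined at module level), which is why _generate is a top-level function rather than a closure.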
Load the pre-trained baseline StyleGAN2 model and generate images by passing its name to the generate helper function defined above.
# load the baseline StyleGAN2 model and generate the results
generate('mit-han-lab:stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)
Now, load the pre-trained StyleGAN2 model with Differentiable Augmentation and pass it to the same helper function to generate the images.
# load the pre-trained Differentiable Augmentation model for the 100-shot-obama dataset
generate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', num_rows=2, num_cols=5, seed=1000)
The outputs are shown below, and the differences between the two models are clearly visible: the images from the baseline StyleGAN2 are more distorted than those from StyleGAN2 with Differentiable Augmentation. The likely reason is discriminator overfitting: with only 100 training images, the baseline discriminator memorizes the training set, its accuracy on held-out images drops significantly, and its feedback to the generator becomes much less useful.
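Both generate calls pass seed=1000, and _generate seeds NumPy before drawing z, so the baseline and DiffAugment generators receive identical latent vectors and the two grids compare like with like. This assumes both networks share the same latent size (512 for StyleGAN2). A quick NumPy check of that property; `sample_latents` is a hypothetical helper that mirrors the two sampling lines in _generate:

```python
import numpy as np

def sample_latents(num_rows, num_cols, seed, latent_dim=512):
    # Mirrors _generate: seed NumPy's global RNG, then draw one z per grid cell
    np.random.seed(seed)
    return np.random.randn(num_rows * num_cols, latent_dim)

z_baseline = sample_latents(2, 5, seed=1000)
z_diffaug = sample_latents(2, 5, seed=1000)
print(z_baseline.shape, np.array_equal(z_baseline, z_diffaug))  # (10, 512) True
```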
4.6 Evaluate both models by calculating the Fréchet Inception Distance (FID). The lower the FID, the closer the generated images are to the real ones. The code for calculating the FID is given below.
Functions to calculate the FID
# dataset_tool, EasyDict, metric_base and metric_defaults come from the
# DiffAugment-stylegan2 code (see step 4.2); Pool is multiprocessing.Pool
def _evaluate(network_name, dataset, resolution, metric):
    dataset = dataset_tool.create_dataset(dataset, resolution)
    dataset_args = EasyDict(tfrecord_dir=dataset, resolution=resolution, from_tfrecords=True)
    metric_group = metric_base.MetricGroup([metric_defaults[metric]])
    metric_group.run(network_name, dataset_args=dataset_args, log_results=False)
    # value of the first (and only) result of the first (and only) metric
    return metric_group.metrics[0]._results[0].value

def evaluate(network_name, dataset, resolution=256, metric='fid5k-train'):
    # run the evaluation in a separate single-worker process
    with Pool(1) as pool:
        return pool.apply(_evaluate, (network_name, dataset, resolution, metric))
Now, calculate the FID with the evaluate function for both the baseline StyleGAN2 and StyleGAN2 with Differentiable Augmentation.
import matplotlib.pyplot as plt

print('Evaluating StyleGAN2 (baseline)...')
fid_baseline = evaluate('mit-han-lab:stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('Baseline FID:', fid_baseline, '\n')
print('Evaluating StyleGAN2 + DiffAugment...')
fid_diffaug = evaluate('mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl', dataset='100-shot-obama')
print('DiffAugment FID:', fid_diffaug, '\n')
# bar chart of the two FID scores (lower is better)
plt.figure(figsize=(2, 3))
plt.bar([0, 1], [fid_baseline, fid_diffaug], color=['gray', 'darkred'])
plt.xticks([0, 1], ['Baseline', 'DiffAugment'])
plt.ylabel('FID')
plt.show()
The output of the above code is shown below. StyleGAN2 with Differentiable Augmentation achieves a clearly lower (better) FID than the baseline model.
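FID fits a Gaussian to Inception features of real images (mean μ_r, covariance Σ_r) and of generated images (μ_g, Σ_g) and computes FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The NumPy-only sketch below computes this distance from two feature matrices; it leaves out the Inception feature extraction and is not the repository's metric_base code. To avoid a general matrix square root, it uses the identity Tr((Σ_r Σ_g)^{1/2}) = Tr((Σ_r^{1/2} Σ_g Σ_r^{1/2})^{1/2}), whose argument is symmetric positive semi-definite.

```python
import numpy as np

def _sqrt_psd(m):
    # Square root of a symmetric PSD matrix via eigendecomposition: V diag(sqrt(w)) V^T
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T

def frechet_distance(feats_real, feats_fake):
    # feats_*: (num_samples, feature_dim) arrays, e.g. Inception pool features
    mu_r, mu_g = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_g = np.cov(feats_fake, rowvar=False)
    s = _sqrt_psd(sigma_r)
    # Tr((Σ_r Σ_g)^{1/2}) = sum of square roots of the eigenvalues of Σ_r^{1/2} Σ_g Σ_r^{1/2}
    eigs = np.linalg.eigvalsh(s @ sigma_g @ s)
    trace_sqrt = np.sqrt(np.clip(eigs, 0, None)).sum()
    return float(np.sum((mu_r - mu_g) ** 2)
                 + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))
print(frechet_distance(real, real))        # close to 0: identical feature sets
print(frechet_distance(real, real + 1.0))  # close to 8: mean shifted by 1 in each of 8 dims
```

Because the covariance term only vanishes when the two distributions match in spread as well as location, FID penalizes both distorted and low-diversity samples.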
4.7 Lastly, generate an interpolation animation (GIF) using a pre-trained model. The code for it is given below.
!python generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o interp.gif --num-rows=2 --num-cols=3 --seed=1
import IPython.display
IPython.display.Image(open('interp.gif', 'rb').read())
The output of the above code is a .gif file, shown below.
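Interpolation animations like this are commonly made by feeding the generator latent vectors taken along straight-line paths between random keyframes. The sketch below builds such a looping path in NumPy; the repository's generate_gif.py may interpolate differently (for example in W space or with smoothing), so treat `interpolate_latents` as a hypothetical illustration. Each row of the result would be passed to the generator to render one frame.

```python
import numpy as np

def interpolate_latents(keyframes, steps_per_segment):
    # keyframes: (K, latent_dim) array. Walk linearly from each keyframe to the
    # next, wrapping around to the first, so the animation loops seamlessly.
    frames = []
    k = len(keyframes)
    for i in range(k):
        start, end = keyframes[i], keyframes[(i + 1) % k]
        for t in np.linspace(0.0, 1.0, steps_per_segment, endpoint=False):
            frames.append((1.0 - t) * start + t * end)
    return np.stack(frames)

rng = np.random.default_rng(1)
keys = rng.standard_normal((4, 512))  # 4 random latent vectors (512-d, as in StyleGAN2)
path = interpolate_latents(keys, steps_per_segment=30)
print(path.shape)  # (120, 512)
```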
The full demo is available in the Colab notebooks listed at the end of this article.
In this article, we have discussed a new data augmentation technique called Differentiable Augmentation and compared it with the baseline. We concluded that Differentiable Augmentation can generate high-quality images from a limited amount of data (in this case, 100 images).
- Colab Notebook Differentiable Augmentation Demo – For Training
- Colab Notebook Differentiable Augmentation Demo – Compare Baseline StyleGAN and StyleGAN with Differentiable Augmentation
Official code, documentation and tutorials are available at: