How To Convert A Sketch Into Colored Image Using Conditional GAN

Sketch-to-color transformation using conditional GANs

Conditional generative adversarial networks (cGANs) are an extension of DCGANs in which images are generated based on a certain condition. Generation can be conditioned on a class, which allows a particular type of image to be produced. Like a DCGAN, a cGAN comprises a generator and a discriminator built with deep convolutional layers. So what is the difference between a DCGAN and a cGAN? While training a DCGAN to generate images, we have no way of asking it for a particular image, and we do not know in advance which image will be generated. Conditional GANs overcome this problem.

In this article we will look at the architecture and working of a cGAN and learn how to implement a simple image-to-image translation model using TensorFlow. Let's get started!

The Architecture and working of a cGAN

Architecture of a conditional GAN

In a conditional GAN, the generator and discriminator receive additional input information in the form of a vector. This information could be the label of the input image or some other property. The label is one-hot encoded and fed to the generator along with its usual input. The generator uses this vector to control features of the generated image, such as the class (male or female if we are generating faces) or properties like hair, nose and eyes. Because this information is incorporated into the images, the outputs are no longer completely random.
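To make the conditioning concrete, here is a minimal sketch of a label-conditioned generator (this is not the model we build later in this article): a one-hot class label is concatenated with the usual noise vector before the first layer. The ten-class label space and the layer sizes are illustrative assumptions only.

import tensorflow as tf

NUM_CLASSES = 10   # assumed label space, for illustration only
NOISE_DIM = 100

noise = tf.keras.layers.Input(shape=(NOISE_DIM,))
label = tf.keras.layers.Input(shape=(NUM_CLASSES,))            # one-hot encoded condition
x = tf.keras.layers.Concatenate()([noise, label])              # the condition enters the generator here
x = tf.keras.layers.Dense(7 * 7 * 128, activation='relu')(x)
x = tf.keras.layers.Reshape((7, 7, 128))(x)
x = tf.keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', activation='relu')(x)
out = tf.keras.layers.Conv2DTranspose(1, 4, strides=2, padding='same', activation='tanh')(x)  # 28x28 grayscale image
label_conditioned_generator = tf.keras.Model([noise, label], out)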

Generator network

 

The generator network uses a special architecture known as U-net. A U-net contains encoder and decoder blocks. The encoder (green blocks) compresses the input image into a smaller representation, so that the final encoding layer holds a high-level representation of the data. The decoder blocks (blue blocks) do the opposite: they progressively upsample back to the original resolution, and skip connections link each encoder block to the decoder block of the same spatial size. U-net models are known for capturing fine details such as pixels and boundaries, and are widely used for image segmentation.
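As a quick illustration of the skip connections (the full U-net generator is built in the Implementation section below), here is a minimal two-level sketch; the filter counts and kernel sizes are illustrative assumptions, not the values used later.

import tensorflow as tf

inp = tf.keras.layers.Input(shape=(256, 256, 3))
enc = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same', activation='relu')(inp)                  # 128x128x64
bottleneck = tf.keras.layers.Conv2D(128, 4, strides=2, padding='same', activation='relu')(enc)          # 64x64x128
dec = tf.keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', activation='relu')(bottleneck)  # 128x128x64
dec = tf.keras.layers.Concatenate()([dec, enc])   # skip connection: encoder details reach the decoder directly
out = tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(dec)          # 256x256x3
tiny_unet = tf.keras.Model(inp, out)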

Discriminator network

Discriminator model of a conditional GAN

The discriminator is a simple convolutional neural network with batch normalization. Its job is to distinguish real images from generated ones. Rather than producing a single probability for the whole image, its output layer produces a grid of values, each giving the probability that the corresponding patch of the input image is real or fake. This design is called a PatchGAN.
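In other words, the last convolution outputs a grid of scores, one per patch of the input, instead of a single real/fake value. As a sanity check, once the Discriminator() from the Implementation section below is defined, the following snippet (assuming 256x256 inputs, as used in this article) should print a 30x30x1 grid of patch scores.

disc = Discriminator()
sketch = tf.zeros([1, 256, 256, 3])
colored = tf.zeros([1, 256, 256, 3])
print(disc([sketch, colored]).shape)  # expected: (1, 30, 30, 1)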

Now that we have understood the architecture and working of a cGAN, let us develop an image-to-image translation model with it.

Implementation

We will implement a model that converts a sketch into a colored image. The dataset for this needs to be in the form of sketch-color pairs.

The dataset used in this project is available here for download. It consists of multiple Pokémon images, each stored as a side-by-side sketch-color pair.

Setting up the requirements 

Loading and visualizing the dataset

import numpy as np
import pandas as pd
import os
import time
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Mount Google Drive, where the dataset is stored
from google.colab import drive
drive.mount('/content/gdrive/')
input_path = '/content/gdrive/My Drive/pokemon-sketch/pokemon_pix2pix_dataset'

def image_loader(image_file):
  # Each file holds the sketch and the colored image side by side,
  # so we split it down the middle into an (input, original) pair
  image_read = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image_read)
  dim = tf.shape(image)[1]
  dim = dim // 2
  input = image[:, :dim, :]
  original = image[:, dim:, :]
  input = tf.cast(input, tf.float32)
  original = tf.cast(original, tf.float32)
  return input, original

inp_img, orig_img = image_loader(input_path + '/train/3.jpg')
plt.figure()
plt.imshow(inp_img / 255.0)
plt.figure()
plt.imshow(orig_img / 255.0)

 Data augmentation 

For better accuracy and faster convergence, I have resized the images, normalized them, and applied data augmentation in the form of a random crop and a random left-right flip. Here are the functions.

def resize_image(input, original, height, width):
  # Resize both images to the given size
  input = tf.image.resize(input, [height, width],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  original = tf.image.resize(original, [height, width],
                             method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return input, original

def cutout(input, original):
  # Randomly crop the same region from both images
  stacked = tf.stack([input, original], axis=0)
  cut_image = tf.image.random_crop(stacked, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
  return cut_image[0], cut_image[1]

def normalize(input, original):
  # Scale pixel values from [0, 255] to [-1, 1]
  input = (input / 127.5) - 1
  original = (original / 127.5) - 1
  return input, original

def add_noise(input, original):
  # Random jitter: resize, crop, and flip both images identically
  input, original = resize_image(input, original, 256, 256)
  input, original = cutout(input, original)
  if tf.random.uniform(()) > 0.5:
    input = tf.image.flip_left_right(input)
    original = tf.image.flip_left_right(original)
  return input, original

Generator model 

Let us build the U-net model for the generator of the cGAN. To avoid code redundancy, I will write the convolution blocks as functions.

def build_model_downsample(filters, size, apply_batchnorm=True):
  # Conv2D -> (BatchNorm) -> LeakyReLU, halving the spatial resolution
  initializer = tf.random_normal_initializer(0., 0.02)
  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))
  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())
  result.add(tf.keras.layers.LeakyReLU())
  return result
def build_model_upsample(filters, size, apply_dropout=False):
  # Conv2DTranspose -> BatchNorm -> (Dropout) -> ReLU, doubling the spatial resolution
  initializer = tf.random_normal_initializer(0., 0.02)
  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))
  result.add(tf.keras.layers.BatchNormalization())
  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))
  result.add(tf.keras.layers.ReLU())
  return result
OUTPUT_CHANNELS = 3  # number of channels in the generated (RGB) image
def Generator():
  # Encoder: repeatedly downsample the input (256x256 during training) to a compact bottleneck
  down_stack = [
    build_model_downsample(64, 4, apply_batchnorm=False),
    build_model_downsample(128, 4), 
    build_model_downsample(256, 4),
    build_model_downsample(512, 4), 
    build_model_downsample(512, 4),
    build_model_downsample(512, 4), 
    build_model_downsample(512, 4), 
    build_model_downsample(512, 4), 
  ]
  # Decoder: upsample back to full resolution, using skip connections from the encoder
  up_stack = [
    build_model_upsample(512, 4, apply_dropout=True), 
    build_model_upsample(512, 4, apply_dropout=True), 
    build_model_upsample(512, 4, apply_dropout=True), 
    build_model_upsample(512, 4), 
    build_model_upsample(256, 4), 
    build_model_upsample(128, 4), 
    build_model_upsample(64, 4), 
  ]
  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh') 
  concat = tf.keras.layers.Concatenate()
  inputs = tf.keras.layers.Input(shape=[None,None,3])
  x = inputs
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)
  # Reverse the encoder outputs (excluding the bottleneck) to pair with the decoder blocks
  skips = reversed(skips[:-1])
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = concat([x, skip])
  x = last(x)
  return tf.keras.Model(inputs=inputs, outputs=x)

Discriminator model 

def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)
  input = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
  target = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
  # The sketch and the target/generated image are concatenated channel-wise
  concat = tf.keras.layers.concatenate([input, target])
  layer1 = build_model_downsample(64, 4, False)(concat)
  layer2 = build_model_downsample(128, 4)(layer1)
  layer3 = build_model_downsample(256, 4)(layer2)
  pad1 = tf.keras.layers.ZeroPadding2D()(layer3)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(pad1)
  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
  pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) 
  # A grid of patch scores (PatchGAN) rather than a single real/fake value
  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(pad2)
  return tf.keras.Model(inputs=[input, target], outputs=last)

Loss functions

Before feeding the data, let us define the loss functions for the generator and the discriminator. The discriminator uses binary cross-entropy on the real and generated outputs, while the generator loss adds an L1 (mean absolute error) term between the generated image and the target, weighted by alpha = 100, to the adversarial loss.

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
alpha = 100  # weight of the L1 term in the generator loss

def discriminator_loss(disc_real_output, disc_generated_output):
  # Real images should be classified as 1, generated images as 0
  actual_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
  total = actual_loss + generated_loss
  return total

def generator_loss(disc_generated_output, gen_output, target):
  # Adversarial loss plus an L1 penalty that keeps the output close to the target
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
  mean_error = tf.reduce_mean(tf.abs(target - gen_output))
  total = gan_loss + (alpha * mean_error)
  return total

Training 

Let us set the hyperparameters and load our data into the model so that it can be trained.

# Hyperparameters
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
EPOCHS = 101

def load_training_data(image_file):
  input, original = image_loader(image_file)
  input, original = add_noise(input, original)
  input, original = normalize(input, original)
  return input, original

train_image = tf.data.Dataset.list_files(input_path + '/train/*.jpg')
train_image = train_image.map(load_training_data,
                              num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_image = train_image.cache().shuffle(BUFFER_SIZE)
train_image = train_image.batch(BATCH_SIZE)

We will do the same for the testing data.

def load_test_data(image_file):
  input, original = image_loader(image_file)
  input, original = add_noise(input, original)
  input, original = normalize(input, original)
  return input, original

test_image = tf.data.Dataset.list_files(input_path + '/test/*.jpg')
test_image = test_image.map(load_test_data,
                            num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_image = test_image.cache().shuffle(BUFFER_SIZE)
test_image = test_image.batch(BATCH_SIZE)

# Instantiate the models and their optimizers
generator = Generator()
discriminator = Discriminator()
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
def train_step(input_image, target):
  # One training step: forward pass, compute both losses, update both networks
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)
    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)
    gen_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
  generator_gradients = gen_tape.gradient(gen_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)
  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15,15))
  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'ground_truth', 'Predicted Image']
  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
def fit(train_ds, epochs, test_ds):
  for epoch in range(epochs):
    start = time.time()
    for input_image, target in train_ds:
      train_step(input_image, target)
    # Show a sample prediction on the test set after every epoch
    clear_output(wait=True)
    for example_input, example_target in test_ds.take(1):
      generate_images(generator, example_input, example_target)
    if (epoch + 1) % 20 == 0:
      print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                         time.time() - start))

fit(train_image, EPOCHS, test_image)

Here we have directly used the test data as validation during the training process itself.

Output:

Output after training the conditional GAN

This is the output at the 54th iteration.

As we can see, the generated images are colored, and as training progresses, the results get better.

Conclusion

GANs are models with a wide range of applications. With cGANs we get more control over the generated images compared to other types of GANs, and we can tell the model what to generate. cGANs are now widely used in the fashion industry for creating textures for clothes, shoes and other items.

Bhoomika Madhukar

I am an aspiring data scientist with a passion for teaching. I am a computer science graduate from Dayananda Sagar Institute. I have experience in building models in deep learning and reinforcement learning. My goal is to use AI in the field of education to make learning meaningful for everyone.
