A CycleGAN is designed for image-to-image translation, and it learns from unpaired training data.
It gives us a way to learn the mapping between one image domain and another using an unsupervised approach.
Examples of image data in both sets:
Translating summer landscapes to winter landscapes (or the reverse).
Unpaired Training Data
These images do not come with labels or pairings: no image in domain X has a matching counterpart in domain Y, so we do not have to extract corresponding features from individual image pairs. In the original CycleGAN work (whose code is on GitHub), the authors were able to translate horses into zebras even though there are no images of zebras in exactly the same positions as the horses. CycleGANs therefore enable learning a mapping from domain X to domain Y without having to find perfectly matched training pairs!
A CycleGAN is made of two types of networks: discriminators and generators. In this example, the discriminators are responsible for classifying images as real or fake (for both X and Y images), and the generators are responsible for generating convincing fake images of both kinds.
Figure: a simple example of the CycleGAN, showing how data flows through the network.
You’ll need to download the data as a zip file here.
First, install PyTorch and import all the libraries for this project.
Then build the data loaders with PyTorch, storing each dataset using ImageFolder. The data loader takes these parameters:
image_type: the directory where the X and Y images are stored
image_dir: the main directory for the train and test images
image_size: the dimension the images are resized to
batch_size: the number of images in one batch of data
Visualization of Training and Testing Data
D_X and D_Y, in this CycleGAN, are convolutional neural networks that see an image and attempt to classify it as real or fake.
Define a Discriminator class to create the model in PyTorch. Input images pass through a series of convolutional layers that use the ReLU activation function. We have provided a helper function that creates a convolutional layer plus an optional batch norm layer.
With this helper function, the Discriminator class is easy to create.
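Here is one possible version of the helper and the Discriminator class. It is a sketch that assumes 128x128 RGB inputs, 4x4 kernels, and conv_dim=64. These sizes are assumptions and are not taken from the tutorial. Four stride-2 layers shrink the image to 8x8, and a final one-channel layer outputs a 7x7 grid of real/fake scores instead of a single number.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Helper: a Conv2d layer followed by an optional BatchNorm2d layer."""
    # Skip the conv bias when batch norm follows, since BN adds its own shift
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)


class Discriminator(nn.Module):
    def __init__(self, conv_dim=64):
        super().__init__()
        # Spatial size 128 -> 64 -> 32 -> 16 -> 8; no batch norm on the first layer
        self.conv1 = conv(3, conv_dim, 4, batch_norm=False)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        self.conv4 = conv(conv_dim * 4, conv_dim * 8, 4)
        # Final layer: one output channel, 8x8 -> 7x7 grid of real/fake scores
        self.conv5 = conv(conv_dim * 8, 1, 4, stride=1, batch_norm=False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # No activation on the output: raw scores suit a mean-squared-error loss
        return self.conv5(x)
```

You would instantiate it twice, once per domain: D_X = Discriminator() and D_Y = Discriminator().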
The generators, G_XtoY and G_YtoX, are each made of an encoder, which is a conv net responsible for turning an image into a smaller feature representation, and a decoder, which is a transpose_conv net responsible for turning that representation into a transformed image.
The input goes through three convolutional layers with BatchNorm and ReLU activation functions and then reaches a series of residual blocks. The residual blocks are made of convolutional and batch normalization layers.
Define the Generator class using the same helper function as the Discriminator.
G_XtoY and G_YtoX have the same architecture, so we only need to define one class and can later instantiate two generators.
The generator has three parts: an encoder, a transformer, and a decoder. Use convolutional layers and PyTorch's Sequential function to define the generator, and write a feed-forward function that applies ReLU activations. In the last layer, use the tanh function.
The transformer is built from residual blocks, which connect the encoder and the decoder. Each residual block consists of two convolutional layers. Its output must have the same size as its input so that the input can be added back to the output.
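The encoder, residual-block transformer, and decoder described above can be sketched as follows. Several choices here are assumptions for illustration: the channel counts, the 4x4 kernels, the deconv helper name, and the default of six residual blocks. The conv helper is repeated so that the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """Conv2d + optional BatchNorm2d (same helper as in the discriminator sketch)."""
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)


def deconv(in_channels, out_channels, kernel_size, stride=2, padding=1, batch_norm=True):
    """ConvTranspose2d + optional BatchNorm2d; with k=4, s=2, p=1 it doubles height and width."""
    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                 bias=not batch_norm)]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions whose output is added back to the input (same shape in and out)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = conv(channels, channels, 3, stride=1, padding=1)
        self.conv2 = conv(channels, channels, 3, stride=1, padding=1)

    def forward(self, x):
        return x + self.conv2(F.relu(self.conv1(x)))


class Generator(nn.Module):
    def __init__(self, conv_dim=64, n_res_blocks=6):
        super().__init__()
        # Encoder: 128x128x3 -> 64x64 -> 32x32 -> 16x16 with conv_dim*4 channels
        self.conv1 = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        # Transformer: residual blocks keep the 16x16 feature map size
        self.res_blocks = nn.Sequential(*[ResidualBlock(conv_dim * 4) for _ in range(n_res_blocks)])
        # Decoder: transpose convolutions back up to 128x128x3
        self.deconv1 = deconv(conv_dim * 4, conv_dim * 2, 4)
        self.deconv2 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv3 = deconv(conv_dim, 3, 4, batch_norm=False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.res_blocks(x)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        # tanh keeps outputs in [-1, 1], matching images scaled to that range
        return torch.tanh(self.deconv3(x))
```

Both generators come from the same class: G_XtoY = Generator() and G_YtoX = Generator().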
Putting it all Together:
Create two generators, G_XtoY and G_YtoX, and two discriminators, D_X and D_Y, for the full network. Train the model on a GPU for faster processing if one is available, or otherwise on the CPU.
Discriminator and Generator Losses
Computing the discriminator and generator losses is key to getting a CycleGAN to train.
Image from the original paper by Jun-Yan Zhu et al.
The discriminator loss is the mean squared error between the discriminator's output and the target label: 1 for real images and 0 for fake images.
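The equation below writes out this least-squares objective for the discriminator D_Y, which judges images in domain Y. G stands for the X-to-Y generator. The loss for D_X is the same with X and Y swapped and with G replaced by the Y-to-X generator. This is the least-squares form described in the original CycleGAN paper.

```latex
\mathcal{L}_{D_Y} = \mathbb{E}_{y \sim Y}\big[(D_Y(y) - 1)^2\big] + \mathbb{E}_{x \sim X}\big[D_Y(G(x))^2\big]
```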
Calculating the generator losses looks somewhat similar to calculating the discriminator loss. There are still steps in which you generate fake images that look like they belong to the set of X images but are based on real images in set Y, and vice versa.
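The first adversarial loss is calculated on the generator G (X to Y) and the discriminator D_Y, which judges whether G(x) looks like a real image from Y. The second adversarial loss is calculated on the generator F (Y to X) and the discriminator D_X, which judges whether F(y) looks like a real image from X.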
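In the same least-squares form, each generator is trained to make its discriminator output 1 (the "real" label) on the fakes it produces. G is the X-to-Y generator and F is the Y-to-X generator:

```latex
\mathcal{L}_{GAN}(G, D_Y) = \mathbb{E}_{x \sim X}\big[(D_Y(G(x)) - 1)^2\big], \qquad
\mathcal{L}_{GAN}(F, D_X) = \mathbb{E}_{y \sim Y}\big[(D_X(F(y)) - 1)^2\big]
```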
Cycle Consistency Loss
In addition to the adversarial losses, CycleGAN uses a cycle consistency loss. A cycle consistent mapping function is one that can translate an image x from domain A to an image y in domain B and then translate it back to the original image.
A forward cycle consistent mapping function looks as follows:
x -> G(x) -> F(G(x)) ≈ x
A backward cycle consistent mapping function looks as follows:
y -> F(y) -> G(F(y)) ≈ y
As a reminder, each generator network sees a 128x128x3 image and compresses it into a feature representation by passing it through three convolutional layers, after which it reaches a series of residual blocks.
To calculate the cycle consistency loss, let G be our generator from A to B and F our generator from B to A. The reconstructed image should then match the original:
â = F(G(a)) ≈ a.
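Here the domains are written as X and Y (A and B above), and the loss is spelled out for both directions. In the original paper the cycle consistency loss is an L1 penalty. The full objective adds it, weighted by λ, to the two adversarial losses L_GAN(G, D_Y) and L_GAN(F, D_X), and the paper uses λ = 10.

```latex
\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x \sim X}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim Y}\big[\lVert G(F(y)) - y \rVert_1\big]

\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda\, \mathcal{L}_{cyc}(G, F)
```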
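All loss functions:
real_mse_loss: the mean squared error between the discriminator's output on real images and the target 1.
fake_mse_loss: the mean squared error between the discriminator's output on fake images and the target 0.
cycle_consistency_loss: the mean absolute difference between real images and their reconstructions, which is added to the adversarial losses to form the total loss.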
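These three losses can be sketched in a few lines of PyTorch. The function names follow the list above. The signatures and the default lambda_weight=10 (the λ used in the CycleGAN paper) are assumptions for illustration.

```python
import torch


def real_mse_loss(D_out):
    """Least-squares loss pushing discriminator scores toward 1 (the "real" label)."""
    return torch.mean((D_out - 1) ** 2)


def fake_mse_loss(D_out):
    """Least-squares loss pushing discriminator scores toward 0 (the "fake" label)."""
    return torch.mean(D_out ** 2)


def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight=10.0):
    """Weighted mean absolute (L1) difference between images and their reconstructions."""
    return lambda_weight * torch.mean(torch.abs(real_im - reconstructed_im))
```

For the generator's adversarial loss, the same real_mse_loss is applied to the discriminator's scores on fake images, because the generator wants its fakes to be classified as real.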
We train the model in two parts.
Part 1, training the discriminators:
Calculate the discriminator losses on real images for both discriminators.
Generate fake images.
Calculate the losses on the fake images and add them to get the total loss for both discriminators.
Part 2, training the generators:
Generate fake X images from real Y images.
Generate new Y images based on the fake X images.
Calculate the cycle consistency loss between the real Y images and the new Y images.
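The two training parts can be combined into a single update function, sketched below. This is a minimal illustration and not the tutorial's own loop. The function name train_step and its arguments are hypothetical, and the loss helpers are repeated so the block runs on its own. Optimizer setup is left to the caller. The sketch also trains the X -> Y -> X direction and the generators' adversarial terms, which the steps above leave implicit.

```python
import torch
import torch.nn as nn


def real_mse_loss(D_out):
    return torch.mean((D_out - 1) ** 2)


def fake_mse_loss(D_out):
    return torch.mean(D_out ** 2)


def cycle_consistency_loss(real_im, reconstructed_im, lambda_weight=10.0):
    return lambda_weight * torch.mean(torch.abs(real_im - reconstructed_im))


def train_step(images_X, images_Y, G_XtoY, G_YtoX, D_X, D_Y,
               d_x_optimizer, d_y_optimizer, g_optimizer):
    """One CycleGAN update (hypothetical helper): discriminators first, then both generators."""
    # ---- Part 1: train the discriminators ----
    d_x_optimizer.zero_grad()
    fake_X = G_YtoX(images_Y)
    # detach() so this step does not backpropagate into the generator
    d_x_loss = real_mse_loss(D_X(images_X)) + fake_mse_loss(D_X(fake_X.detach()))
    d_x_loss.backward()
    d_x_optimizer.step()

    d_y_optimizer.zero_grad()
    fake_Y = G_XtoY(images_X)
    d_y_loss = real_mse_loss(D_Y(images_Y)) + fake_mse_loss(D_Y(fake_Y.detach()))
    d_y_loss.backward()
    d_y_optimizer.step()

    # ---- Part 2: train the generators ----
    g_optimizer.zero_grad()
    # Y -> X -> Y: fake X from real Y, adversarial loss on D_X, then reconstruct Y
    fake_X = G_YtoX(images_Y)
    g_YtoX_loss = real_mse_loss(D_X(fake_X))
    reconstructed_Y = G_XtoY(fake_X)
    reconstructed_y_loss = cycle_consistency_loss(images_Y, reconstructed_Y)
    # X -> Y -> X: the same steps in the other direction
    fake_Y = G_XtoY(images_X)
    g_XtoY_loss = real_mse_loss(D_Y(fake_Y))
    reconstructed_X = G_YtoX(fake_Y)
    reconstructed_x_loss = cycle_consistency_loss(images_X, reconstructed_X)

    g_total_loss = g_YtoX_loss + g_XtoY_loss + reconstructed_y_loss + reconstructed_x_loss
    g_total_loss.backward()
    g_optimizer.step()
    return d_x_loss.item(), d_y_loss.item(), g_total_loss.item()
```

The fake images are detached in the discriminator step, so only the discriminators' weights are updated there. The generator step then backpropagates through both discriminators. The leftover discriminator gradients are cleared by zero_grad() at the start of the next discriminator update.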
Visualization of Images
Here are sample images after 100 epochs:
1.) Image transformation from X to Y after 100 epochs
2.) Image transformation from Y to X after 100 epochs
Now see the results after 5000 epochs:
1.) Image transformation from X to Y after 5000 epochs
2.) Image transformation from Y to X after 5000 epochs
We have learned how to use a CycleGAN for image-to-image translation. We started with an introduction to CycleGANs and explored the architectures of the networks involved. We also explored the different loss functions required to train CycleGANs. This was followed by an implementation of CycleGAN in the PyTorch framework. We trained the CycleGAN on the available dataset and visualized the generated images, the losses, and the graphs for the different networks.