Generating High Resolution Images Using Transformers

Taming Transformers

Transformers are known for modelling long-range interactions in sequential data and are easily adaptable to different tasks, be it Natural Language Processing, Computer Vision or audio. Unlike Convolutional Neural Networks (CNNs), transformers contain no inductive bias and are therefore free to learn all the complex relationships in the given input. This increases expressivity on the one hand, but on the other makes them computationally impractical for long sequences or high-resolution images. In December 2020, Patrick Esser, Robin Rombach and Björn Ommer, AI researchers from Heidelberg University, Germany, published a paper that combines CNNs and transformers to overcome this problem of producing high-resolution images: Taming Transformers for High-Resolution Image Synthesis.

To make transformers more efficient, the Taming Transformers method integrates the inductive bias of CNNs with the expressivity of transformers. To produce high-resolution images, the proposed method:

  • uses a convolutional VQGAN to effectively learn a codebook of context-rich visual parts (a minimal sketch of this codebook lookup follows the list), and
  • utilizes a transformer to efficiently model their composition within high-resolution images.
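
The codebook idea can be illustrated with a small vector-quantization step: every continuous feature vector produced by the encoder is replaced by its nearest codebook entry, so an image becomes a grid of discrete codes that can be read as a sequence. The snippet below is a minimal, self-contained sketch of that lookup with made-up sizes; it is only an illustration, not the VQGAN implementation from the repository.

 import torch
 # hypothetical codebook: 1024 entries of dimension 256 (illustrative sizes, not the paper's settings)
 num_codes, code_dim = 1024, 256
 codebook = torch.randn(num_codes, code_dim)
 # stand-in for the encoder output: a 16x16 grid of continuous feature vectors
 z = torch.randn(16 * 16, code_dim)
 # replace each feature vector by its nearest codebook entry
 distances = torch.cdist(z, codebook)   # (256, 1024) pairwise distances
 indices = distances.argmin(dim=1)      # (256,) discrete codes read as a sequence
 z_q = codebook[indices]                # quantized features passed on to the decoder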

The Model Architecture of Taming Transformers

The model architecture uses a convolutional VQGAN, which combines an encoder-decoder with adversarial training to produce a codebook of efficient, context-rich representations of images. This GAN training teaches the first stage to reconstruct high-resolution images from the compressed codes. Once GAN training is over, only the discrete codebook indices produced by the encoder are handed to the transformer, while the VQGAN decoder is kept to map predicted codes back to pixels. The codebook holds an efficient and rich representation of images (instead of raw pixels) in a compressed form that can be read sequentially. The transformer is then trained on these sequences to predict the distribution of the next index, in the manner of an autoregressive model, and the sampled indices are decoded to synthesize the required output.

Model architecture of Taming Transformers. Source: https://compvis.github.io/taming-transformers/
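
To make the transformer stage more concrete, the toy sketch below shows how a causal transformer turns a sequence of codebook indices into a distribution over the next index. It is a self-contained illustration with made-up sizes (and assumes PyTorch 1.9 or newer for the batch_first flag); it is not the minGPT used in the repository.

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 # toy sizes, not the settings from the paper
 vocab_size, seq_len, d_model = 1024, 255, 64
 embed = nn.Embedding(vocab_size, d_model)
 layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
 toy_gpt = nn.TransformerEncoder(layer, num_layers=2)
 to_logits = nn.Linear(d_model, vocab_size)
 # a flattened grid of codebook indices standing in for a partially generated image
 indices = torch.randint(vocab_size, (1, seq_len))
 # additive causal mask so position t can only attend to positions <= t
 causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
 hidden = toy_gpt(embed(indices), mask=causal)
 next_logits = to_logits(hidden[:, -1])               # logits for the next codebook index
 next_index = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)

In the actual model, the conditioning codes (for example, the segmentation indices) are prepended to the image codes, so the predicted distribution also depends on the desired layout.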

Some of the results from Taming Transformers

Source: https://compvis.github.io/taming-transformers/

Tasks done by Taming Transformers

  • Image Completion
  • Depth-to-image generation
  • Label-to-image generation
  • Pose-to-human generation
  • Super Resolution

Quick Start with Taming Transformers

Installation & Dependencies

Clone the repository via git and download all the required models and configs.

 !git clone https://github.com/CompVis/taming-transformers
 %cd taming-transformers
 !mkdir -p logs/2020-11-09T13-31-51_sflckr/checkpoints
 !wget 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' -O 'logs/2020-11-09T13-31-51_sflckr/checkpoints/last.ckpt'
 !mkdir logs/2020-11-09T13-31-51_sflckr/configs
 !wget 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' -O 'logs/2020-11-09T13-31-51_sflckr/configs/2020-11-09T13-31-51-project.yaml' 

Install all the dependencies via pip and add the repository to the Python path so that the taming module can be imported.

 %pip install omegaconf==2.0.0 pytorch-lightning==1.0.8
 import sys
 sys.path.append(".") 

Demo – Taming Transformers via pretrained model

  1. Load the model configuration and print it.
 from omegaconf import OmegaConf
 config_path = "logs/2020-11-09T13-31-51_sflckr/configs/2020-11-09T13-31-51-project.yaml"
 config = OmegaConf.load(config_path)
 import yaml
 print(yaml.dump(OmegaConf.to_container(config))) 
  2. Initialize the model.
 from taming.models.cond_transformer import Net2NetTransformer
 model = Net2NetTransformer(**config.model.params) 
  3. Load the checkpoint.
 import torch
 ckpt_path = "logs/2020-11-09T13-31-51_sflckr/checkpoints/last.ckpt"
 sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
 missing, unexpected = model.load_state_dict(sd, strict=False)
 model.cuda().eval()
 torch.set_grad_enabled(False) 
  4. Load the example semantic segmentation map and convert it into a tensor. Here we take label-to-image generation as an example.
 from PIL import Image
 import numpy as np
 segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
 segmentation = Image.open(segmentation_path)
 segmentation = np.array(segmentation)
 # one-hot encode the 182 segmentation classes: (H, W) -> (H, W, 182)
 segmentation = np.eye(182)[segmentation]
 # add a batch dimension and move channels first: (1, 182, H, W)
 segmentation = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=model.device)
  5. Visualize the segmentation.
 def show_segmentation(s):
   s = s.detach().cpu().numpy().transpose(0,2,3,1)[0,:,:,None,:]
   colorize = np.random.RandomState(1).randn(1,1,s.shape[-1],3)
   colorize = colorize / colorize.sum(axis=2, keepdims=True)
   s = s@colorize
   s = s[...,0,:]
   s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
   s = Image.fromarray(s)
   display(s)
 show_segmentation(segmentation)
 
  6. Encode the segmentation tensor so as to obtain its codebook representation and indices via the VQGAN.
 c_code, c_indices = model.encode_to_c(segmentation)
 print("c_code", c_code.shape, c_code.dtype)
 print("c_indices", c_indices.shape, c_indices.dtype)
 assert c_code.shape[2]*c_code.shape[3] == c_indices.shape[1]
 segmentation_rec = model.cond_stage_model.decode(c_code)
 show_segmentation(torch.softmax(segmentation_rec, dim=1)) 
  7. Initialize the target image's codebook indices at random and decode them with the VQGAN to visualize the starting point that the transformer will refine.
 def show_image(s):
   s = s.detach().cpu().numpy().transpose(0,2,3,1)[0]
   s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
   s = Image.fromarray(s)
   display(s)
 codebook_size = config.model.params.first_stage_config.params.embed_dim
 z_indices_shape = c_indices.shape
 z_code_shape = c_code.shape
 # start from random codebook indices; the transformer will replace them with predicted codes
 z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)
 x_sample = model.decode_to_img(z_indices, z_code_shape)
 show_image(x_sample)
 
  8. The last step is to sample from the pretrained transformer autoregressively to produce the required output. The full code snippet is available here; a simplified sketch is also given below. You can check at each step how the transformer replaces the random codes with predicted ones to arrive at the final image. The sliding attention window used during sampling is 16 x 16.
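
The following is a simplified sketch of that sliding-window sampling loop, continuing from the variables created in the previous steps. It assumes that model.transformer is the GPT component of Net2NetTransformer and that calling it returns a (logits, loss) pair; these internals are assumptions about the repository, so refer to the official notebook for the exact code.

 import torch
 import torch.nn.functional as F
 temperature, top_k = 1.0, 100
 # reshape the flat index sequences back into their 2D code grids
 idx = z_indices.reshape(z_code_shape[0], z_code_shape[2], z_code_shape[3])
 cidx = c_indices.reshape(c_code.shape[0], c_code.shape[2], c_code.shape[3])
 for i in range(z_code_shape[2]):
   for j in range(z_code_shape[3]):
     # crop a 16x16 window around the current position, clipped to the code grid
     i_start = min(max(0, i - 8), z_code_shape[2] - 16)
     j_start = min(max(0, j - 8), z_code_shape[3] - 16)
     local_i, local_j = i - i_start, j - j_start
     patch = idx[:, i_start:i_start+16, j_start:j_start+16].reshape(idx.shape[0], -1)
     cpatch = cidx[:, i_start:i_start+16, j_start:j_start+16].reshape(cidx.shape[0], -1)
     # condition on the segmentation codes, then predict the code at the current position
     logits, _ = model.transformer(torch.cat((cpatch, patch), dim=1)[:, :-1])
     logits = logits[:, -256:, :].reshape(z_code_shape[0], 16, 16, -1)[:, local_i, local_j, :]
     logits = logits / temperature
     # keep only the top-k logits, then sample an index from the softmax
     values, _ = torch.topk(logits, top_k)
     logits[logits < values[..., [-1]]] = -float("inf")
     idx[:, i, j] = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1).squeeze(-1)
 x_sample = model.decode_to_img(idx.reshape(z_indices.shape), z_code_shape)
 show_image(x_sample)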

Important Links

  • Check this link to run streamlit on Colab.
  • Pre-trained models for image generation, depth-to-image generation, super-resolution, etc. are available here.
  • For training models from scratch on your own data, you can refer here.

Conclusion

In this article, we have given an overview of Taming Transformers for High-Resolution Image Synthesis. Instead of working on raw pixels, the model is trained on a codebook learned by a VQGAN; the discrete codes are then modelled by a transformer and decoded back into images to generate the required results. This post discussed the motivation, model architecture, results and tasks of the proposed method. It also includes a basic tutorial on using the Taming Transformers pre-trained models. The reported results show that this method outperforms previous state-of-the-art methods based on convolutional architectures.

Official references:

  • Paper: Taming Transformers for High-Resolution Image Synthesis
  • GitHub repository: https://github.com/CompVis/taming-transformers
  • Project page: https://compvis.github.io/taming-transformers/
