Transformers are known for modeling long-range interactions in sequential data and are easily adapted to different domains, be it Natural Language Processing, Computer Vision or audio. Unlike Convolutional Neural Networks (CNNs), transformers have no built-in inductive bias toward local interactions, so they are free to learn complex relationships anywhere in the input. This increases expressivity, but self-attention compares every element of a sequence with every other element, so its cost grows quadratically with sequence length. That makes transformers computationally impractical for long sequences or high-resolution images. In December 2020, Patrick Esser, Robin Rombach and Björn Ommer, AI researchers from Heidelberg University, Germany, published a paper that combines CNNs and transformers to overcome this problem and synthesize high-resolution images: Taming Transformers for High-Resolution Image Synthesis.
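To make the quadratic cost concrete, the short sketch below counts the pairwise attention scores that a single self-attention head computes over an image. It assumes, for illustration only, that every pixel is one token, and that a hypothetical downsampling factor of 16 is used when the image is first compressed into a grid of codes; the function name attention_entries is our own.

```python
def attention_entries(height: int, width: int, downsample: int = 1) -> int:
    """Count the pairwise scores one self-attention head computes over an image.

    Each spatial position (a pixel, or a latent code after downsampling) is one
    token, and self-attention scores every token against every other token.
    """
    tokens = (height // downsample) * (width // downsample)
    return tokens * tokens


pixels = attention_entries(256, 256)                # 256 * 256 = 65,536 tokens
codes = attention_entries(256, 256, downsample=16)  # 16 * 16 = 256 tokens
print(f"{pixels:,} vs {codes:,} scores ({pixels // codes:,}x fewer)")
```

For a 256 × 256 image this is about 4.3 billion scores per head on raw pixels versus 65,536 on a 16 × 16 grid of codes, which is why compressing the image before the transformer sees it matters so much.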
To make transformers more efficient, the Taming Transformers method combines the inductive bias of CNNs with the expressivity of transformers. To produce high-resolution images, the proposed method:
- Uses a convolutional VQGAN to learn a codebook of context-rich visual parts.
- Uses a transformer to efficiently model the composition of these parts within high-resolution images.
The Model Architecture of Taming Transformers
The first stage is a convolutional VQGAN: an encoder, a decoder and a learned codebook, trained with a perceptual reconstruction loss plus an adversarial loss from a patch-based discriminator. The encoder compresses an image into a small grid of feature vectors, and each vector is replaced by its nearest codebook entry. The image thus becomes a grid of codebook indices, an efficient and context-rich representation (instead of pixels) in a compressed form that can be read sequentially, and the decoder maps such a grid back to a high-resolution image. Once this stage is trained, a transformer is trained on these index sequences to predict the distribution of the next index given the previous ones, as in an autoregressive model. To synthesize an image, the transformer samples a sequence of indices, which the VQGAN decoder then turns into the output image.
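The quantization step at the heart of the first stage can be sketched in a few lines of plain Python. This is a simplified illustration, not the repository's implementation: it assumes a tiny hand-written codebook and encoder output, uses squared Euclidean distance to find the nearest entry, and flattens the grid in raster (row-by-row) order. The helpers quantize and dequantize are hypothetical names.

```python
Vector = tuple[float, ...]


def quantize(vectors: list[Vector], codebook: list[Vector]) -> list[int]:
    """Replace each encoder vector by the index of its nearest codebook entry."""
    def sq_dist(a: Vector, b: Vector) -> float:
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [min(range(len(codebook)), key=lambda k: sq_dist(v, codebook[k]))
            for v in vectors]


def dequantize(indices: list[int], codebook: list[Vector]) -> list[Vector]:
    """Look indices back up in the codebook (the vectors the decoder receives)."""
    return [codebook[k] for k in indices]


# A hypothetical 2 x 2 grid of 2-D encoder outputs, flattened row by row.
codebook = [(0.0, 0.0), (1.0, 1.0), (5.0, 5.0)]
grid = [[(0.1, 0.2), (0.9, 1.2)],
        [(4.0, 4.0), (6.0, 5.0)]]
sequence = quantize([v for row in grid for v in row], codebook)
print(sequence)  # [0, 1, 2, 2]
```

The transformer is trained on index sequences like sequence, and dequantize(sequence, codebook) recovers the quantized vectors that the decoder turns back into pixels.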
Some of the results from Taming Transformers
Tasks performed by Taming Transformers
- Image Completion
- Depth-to-image generation
- Label-to-image generation
- Pose-to-human generation
- Super Resolution
Quick Start with Taming Transformers
Installation & Dependencies
Clone the repository via git and download the pretrained checkpoint and configuration file used in this demo.
```
!git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers
!mkdir -p logs/2020-11-09T13-31-51_sflckr/checkpoints
!wget 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' -O 'logs/2020-11-09T13-31-51_sflckr/checkpoints/last.ckpt'
!mkdir logs/2020-11-09T13-31-51_sflckr/configs
!wget 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' -O 'logs/2020-11-09T13-31-51_sflckr/configs/2020-11-09T13-31-51-project.yaml'
```
Install the dependencies via pip and add the repository root to the Python path so that the taming package can be imported.
```
%pip install omegaconf==2.0.0 pytorch-lightning==1.0.8
import sys
sys.path.append(".")  # make the cloned `taming` package importable
```
Demo – Taming Transformers with a pretrained model
- Load the configuration file and print it.
```python
import yaml
from omegaconf import OmegaConf

config_path = "logs/2020-11-09T13-31-51_sflckr/configs/2020-11-09T13-31-51-project.yaml"
config = OmegaConf.load(config_path)
# Convert the config to a plain dict and pretty-print it as YAML
print(yaml.dump(OmegaConf.to_container(config)))
```
- Initialize the model.
```python
from taming.models.cond_transformer import Net2NetTransformer

# Build the conditional transformer (together with its VQGAN stages) from the config
model = Net2NetTransformer(**config.model.params)
```
- Load the pretrained checkpoint.
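```python
import torch

ckpt_path = "logs/2020-11-09T13-31-51_sflckr/checkpoints/last.ckpt"
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
# strict=False tolerates keys that differ between the checkpoint and the model
missing, unexpected = model.load_state_dict(sd, strict=False)

model.cuda().eval()            # move to the GPU and switch to inference mode
torch.set_grad_enabled(False)  # no gradients are needed for sampling
```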
- Load an example semantic segmentation map and convert it into a one-hot tensor. This demo uses label-to-image generation as its example.
```python
import numpy as np
from PIL import Image

segmentation_path = "data/sflckr_segmentations/norway/25735082181_999927fe5a_b.png"
segmentation = Image.open(segmentation_path)
segmentation = np.array(segmentation)     # (H, W) array of class labels
segmentation = np.eye(182)[segmentation]  # one-hot over 182 classes -> (H, W, 182)
# Reorder to (1, 182, H, W) and move to the model's device as float32
segmentation = torch.tensor(segmentation.transpose(2,0,1)[None]).to(dtype=torch.float32, device=model.device)
```
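The np.eye(182)[segmentation] line relies on a common NumPy idiom: indexing the identity matrix with an array of integer labels returns one one-hot row per label. The toy example below uses a made-up 2 × 2 label map with 3 classes instead of 182, and shows the shapes produced at each step of the cell above.

```python
import numpy as np

# Hypothetical 2 x 2 label map with 3 classes (the demo uses 182 classes).
labels = np.array([[0, 2],
                   [1, 2]])

one_hot = np.eye(3)[labels]             # (H, W, classes) = (2, 2, 3)
chw = one_hot.transpose(2, 0, 1)[None]  # (batch, classes, H, W) = (1, 3, 2, 2)

print(one_hot.shape, chw.shape)
print(chw[0, :, 1, 0])                  # one-hot vector for the label at row 1, column 0
```

Every position has exactly one active channel, which is the layout the conditioning-stage encoder expects.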
- Visualize the segmentation.
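```python
from IPython.display import display

def show_segmentation(s):
    # (1, C, H, W) -> (H, W, 1, C)
    s = s.detach().cpu().numpy().transpose(0,2,3,1)[0,:,:,None,:]
    # Fixed random color per class, normalized over the class axis
    colorize = np.random.RandomState(1).randn(1,1,s.shape[-1],3)
    colorize = colorize / colorize.sum(axis=2, keepdims=True)
    s = s@colorize   # map class channels to colors -> (H, W, 1, 3)
    s = s[...,0,:]   # (H, W, 3)
    s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
    s = Image.fromarray(s)
    display(s)

show_segmentation(segmentation)
```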
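- Encode the segmentation with the conditioning-stage VQGAN to obtain its quantized code (c_code) and codebook indices (c_indices), then decode the code again to check the reconstruction.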
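```python
# Encode the segmentation with the conditioning-stage VQGAN
c_code, c_indices = model.encode_to_c(segmentation)
print("c_code", c_code.shape, c_code.dtype)
print("c_indices", c_indices.shape, c_indices.dtype)
# One codebook index per spatial position of the latent grid
assert c_code.shape[2]*c_code.shape[3] == c_indices.shape[1]
# Decode the quantized code back into a segmentation to check the reconstruction
segmentation_rec = model.cond_stage_model.decode(c_code)
show_segmentation(torch.softmax(segmentation_rec, dim=1))
```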
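- Initialize the image with random codebook indices on the same latent grid as the segmentation and decode them with the first-stage VQGAN decoder. This noisy image is the starting point that the transformer overwrites, index by index, during sampling.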
```python
def show_image(s):
    s = s.detach().cpu().numpy().transpose(0,2,3,1)[0]  # (1, 3, H, W) -> (H, W, 3)
    s = ((s+1.0)*127.5).clip(0,255).astype(np.uint8)
    s = Image.fromarray(s)
    display(s)

# Number of entries in the first-stage (image) codebook
codebook_size = config.model.params.first_stage_config.params.n_embed
z_indices_shape = c_indices.shape
z_code_shape = c_code.shape
# Random image indices on the same latent grid as the conditioning
z_indices = torch.randint(codebook_size, z_indices_shape, device=model.device)
x_sample = model.decode_to_img(z_indices, z_code_shape)
show_image(x_sample)
```
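- The last step is sampling: the pretrained transformer replaces the random image indices one position at a time in raster order, predicting each index from the segmentation's indices and the image indices already sampled, and the VQGAN decoder turns the result into the final image. The code snippet is available here. The intermediate images show how the image builds up step by step. The transformer attends over a sliding window of 16 × 16 codebook indices.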
- To run the Streamlit demo on Colab, check this link.
- Pretrained models for image generation, depth-to-image generation, super-resolution, etc. are available here.
- To train a model from scratch on your own data, refer here.
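The sampling step described above can be sketched without any model code. The sketch makes two assumptions: the 16 × 16 attention window is centered on the position being sampled and clamped at the borders of the code grid, and each index is drawn from the transformer's logits after temperature scaling and top-k filtering. Both helpers, window_bounds and sample_top_k, are hypothetical names written in plain Python; in the real demo the logits come from the model.

```python
import math
import random


def window_bounds(pos: int, grid: int, window: int = 16) -> tuple[int, int, int]:
    """Return (start, end, local) of a sliding window along one grid axis.

    The window is centered on `pos` where possible and clamped to [0, grid);
    `local` is the offset of `pos` inside the window.
    """
    if grid < window:
        raise ValueError("the code grid must be at least as large as the window")
    start = min(max(pos - window // 2, 0), grid - window)
    return start, start + window, pos - start


def sample_top_k(logits: list[float], top_k: int, temperature: float,
                 rng: random.Random) -> int:
    """Draw one codebook index from temperature-scaled, top-k-filtered logits."""
    scaled = [x / temperature for x in logits]
    kth = sorted(scaled, reverse=True)[top_k - 1]  # smallest logit that is kept
    peak = max(scaled)
    # Softmax weights for the kept logits; everything below the k-th is dropped.
    weights = [math.exp(x - peak) if x >= kth else 0.0 for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


# Row positions on a hypothetical grid that is 32 codes tall.
for row in (3, 20, 30):
    print(row, window_bounds(row, grid=32))

rng = random.Random(0)
print(sample_top_k([0.0, 5.0, 1.0, 10.0], top_k=2, temperature=1.0, rng=rng))
```

On a 32-row grid, row 3 gets the window 0 to 16, row 20 sits at the center of the window 12 to 28, and row 30 is clamped to the bottom rows 16 to 32. Restricting attention to a fixed window keeps the cost of each sampling step constant regardless of how large the image is.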
In this article, we have given an overview of Taming Transformers for High-Resolution Image Synthesis. Instead of working on pixels, a VQGAN first learns a codebook that compresses images into sequences of discrete indices; a transformer then models these sequences autoregressively, and the VQGAN decoder turns the sampled sequences back into images. This post discussed the need for the method, its model architecture, and its results and tasks. It also includes a basic tutorial on using the pretrained Taming Transformers models. According to the paper, this method outperforms previous state-of-the-art methods based on convolutional architectures.
Official references are available at: