Generative adversarial networks (GANs) have become remarkably good at synthesizing photorealistic images from randomly sampled latent codes. The generated images can also be edited easily (e.g., adding a smile or glasses) by tweaking the latent code. However, because large generators are computationally expensive, it usually takes seconds to see the result of a single edit on edge devices. Inspired by the quick-preview option in modern creative software, Ji Lin, Richard Zhang, et al. proposed Anycost GAN for more immediate, interactive image editing in their paper "Anycost GANs for Interactive Image Synthesis and Editing". The anycost generator can be executed at a wide range of computational costs while producing visually consistent outputs. A low-cost sub-generator provides fast, responsive previews during editing, and the full-cost generator then renders the high-quality final output.
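"Tweaking the latent code" usually means moving the code along a direction associated with an attribute. Below is a minimal NumPy sketch of that idea. It assumes such a direction is already available, for example a "smile" direction found with a linear classifier in latent space. Here `smile_direction` is a random, hypothetical stand-in and does not come from the paper or the repository.

```python
import numpy as np

def edit_latent(w, direction, strength):
    """Move latent code w along a unit-normalized attribute direction."""
    unit = direction / np.linalg.norm(direction)
    return w + strength * unit

rng = np.random.default_rng(0)
w = rng.standard_normal(512)                  # a 512-d latent code
smile_direction = rng.standard_normal(512)    # hypothetical attribute direction

w_edit = edit_latent(w, smile_direction, strength=2.0)

# The component of the code along the direction grows by exactly `strength`.
unit = smile_direction / np.linalg.norm(smile_direction)
print(round(float((w_edit - w) @ unit), 6))   # 2.0
```

Every slider movement in an editor produces a new code like `w_edit`. The appeal of an anycost generator is that these intermediate codes can be rendered cheaply, and only the final code needs the full generator.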
Architecture & Approach
Anycost GAN uses a subset of the full generator, a sub-generator G′, to independently produce an output x′ = G′(w) that is perceptually similar to the full generator's output G(w). G′(w) is used for fast previews during interactive editing, and the full G(w) renders the final high-quality output. Because sub-generators of different costs are available, the same model can be deployed on diverse hardware, and end users can choose a preview quality that suits their device.
Sampling-Based Multi-Resolution Training
The obvious way to enable a range of inference costs is to vary the image resolution. Architectures in the ProGAN and StyleGAN families already produce lower-resolution intermediate images, but these are not visually similar to the final output. To make the lower-resolution intermediate outputs faithful previews, Anycost GAN enforces a multi-scale training objective. The generator produces progressively higher-resolution outputs after each block $g_k$:

$$x_k = G_k(w) = g_k \circ g_{k-1} \circ \cdots \circ g_1(w), \quad \forall k \in \{1, \dots, K\}$$

where $K$ is the total number of network blocks. The intermediate low-resolution outputs $x_k$ are used for previews.
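The nested structure $x_k = g_k \circ \cdots \circ g_1(w)$ can be illustrated with a toy NumPy generator. This is only a sketch under stated assumptions:

- Each hypothetical block upsamples its feature map 2× and applies a random 1×1 "convolution".
- A per-block `to_rgb` projection turns the features into an image.
- For simplicity, the toy seeds the 4×4 map directly from w. StyleGAN instead starts from a learned constant and injects w through style modulation.

The point is that running only the first k blocks yields the preview x_k without touching the later, more expensive blocks.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4   # number of blocks (hypothetical)
C = 8   # feature channels (hypothetical)

# Hypothetical per-block weights: a 1x1 conv (C x C) and a to-RGB projection (3 x C).
block_weights = [rng.standard_normal((C, C)) * 0.1 for _ in range(K)]
to_rgb_weights = [rng.standard_normal((3, C)) * 0.1 for _ in range(K)]

def g_block(h, weight):
    """One toy block: nearest-neighbour 2x upsampling, then a 1x1 conv + tanh."""
    h = h.repeat(2, axis=1).repeat(2, axis=2)            # (C, H, W) -> (C, 2H, 2W)
    return np.tanh(np.einsum('oc,chw->ohw', weight, h))

def generate(w, k):
    """Return x_k by running only blocks g_1 ... g_k, then the k-th to-RGB head."""
    h = np.tile(w[:C, None, None], (1, 4, 4))            # toy 4x4 starting map from w
    for i in range(k):
        h = g_block(h, block_weights[i])
    return np.einsum('oc,chw->ohw', to_rgb_weights[k - 1], h)

w = rng.standard_normal(512)
for k in range(1, K + 1):
    print(k, generate(w, k).shape)
# 1 (3, 8, 8)
# 2 (3, 16, 16)
# 3 (3, 32, 32)
# 4 (3, 64, 64)
```

In the real model every x_k comes from trained blocks. The multi-scale objective is what makes each x_k look like a downsampled version of the final image.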
Existing approaches to multi-resolution training, such as MSG-GAN, train the generator to support different resolutions using a single discriminator for images of all resolutions. This kind of training degrades the quality of the generated results on large-scale datasets. To avoid this degradation, Anycost GAN uses a sampling-based training objective. At each iteration, a single resolution is sampled, and both the generator G and the discriminator D are trained at that resolution, as shown in the figure.
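When a lower resolution is sampled, the translucent parts in the figure are not executed. The generator's intermediate output for that resolution passes through a fromRGB convolution layer, which increases the number of channels, and is then fed to an intermediate layer of D. The objective function of this multi-resolution training is:

$$\mathcal{L}_{\text{multi-res}} = \mathbb{E}_{x,k}\left[\log D(\tilde{x}_k)\right] + \mathbb{E}_{w,k}\left[\log\left(1 - D(G_k(w))\right)\right]$$

where $x$ is a real image and $\tilde{x}_k$ is its downsampled version at the $k$-th resolution.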
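The sampling scheme can be sketched as a schedule that picks one resolution per iteration and records which generator and discriminator blocks execute. This pure-Python sketch makes three assumptions:

- There is one generator block per resolution from 4 to 1024 px.
- Training samples from the four highest resolutions, matching the four preview resolutions plotted later.
- The discriminator's blocks mirror the generator's.

It does no real training and only shows the bookkeeping. The helper names are hypothetical.

```python
import random

G_RESOLUTIONS = [4, 8, 16, 32, 64, 128, 256, 512, 1024]  # one generator block per resolution
TRAIN_RESOLUTIONS = [128, 256, 512, 1024]                # assumption: resolutions sampled in training

def plan_iteration(target_res):
    """Return (generator blocks run, discriminator blocks run) for one sampled resolution."""
    # G stops early: blocks above target_res are skipped (the "translucent" parts).
    g_blocks = [r for r in G_RESOLUTIONS if r <= target_res]
    # D mirrors G (high -> low resolution). The image enters through a fromRGB
    # layer at target_res, so D's higher-resolution blocks are skipped too.
    d_blocks = [r for r in reversed(G_RESOLUTIONS) if r <= target_res]
    return g_blocks, d_blocks

def training_schedule(num_iters, seed=0):
    """Sample exactly one resolution per iteration, shared by G and D."""
    rng = random.Random(seed)
    for step in range(num_iters):
        res = rng.choice(TRAIN_RESOLUTIONS)
        # A real loop would downsample the real batch to `res`, run G up to `res`,
        # and compute the GAN loss on D's outputs for both.
        yield step, res, plan_iteration(res)

for step, res, (g_blocks, d_blocks) in training_schedule(3):
    print(step, res, 'G stops at', g_blocks[-1], 'D enters at', d_blocks[0])
```

Sampling one resolution per step, rather than feeding all resolutions to one discriminator at once, keeps each update focused on a single scale.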
Adaptive Channel Training
Reducing the resolution from 1024×1024 to 256×256 reduces the computational cost by only 1.7×, despite producing 16× fewer pixels. Varying the resolution alone therefore isn't enough to speed up the generator significantly, so the anycost generator is also trained to support variable channel widths. In each training iteration, a randomly sampled channel-multiplier configuration is used, and only the corresponding subset of weights is updated.

To preserve the most "important" channels when sampling, the model is initialized from the multi-resolution generator of the previous stage. The channels of each convolutional layer are then sorted by kernel magnitude, from highest to lowest. A sub-network keeps the $\alpha c$ most important channels, where $\alpha \in \{0.25, 0.5, 0.75, 1\}$ and $c$ is the number of channels in the layer. The adaptive-channel training objective is:

$$\mathcal{L}_{\text{ada-ch}} = \mathbb{E}_{x,k}\left[\log D(\tilde{x}_k)\right] + \mathbb{E}_{w,k,C}\left[\log\left(1 - D(G^{C}_k(w))\right)\right]$$

Here $C$ denotes the channel configuration of each layer, and $\tilde{x}_k$ is a real image downsampled to the $k$-th resolution. To keep the output consistent across sub-networks, a consistency loss combining MSE and LPIPS is added:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{ada-ch}} + \mathbb{E}_{w,k,C}\left[\ell\left(G^{C}_k(w), G_k(w)\right)\right]$$

where $\ell$ is the sum of the MSE and LPIPS distances between the sub-generator output and the full generator output.
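The arithmetic and mechanics of this section can be checked with a short NumPy sketch.

**Why resolution alone saves so little.** The sketch assumes StyleGAN2 config-f channel widths: 512 channels up to 64×64, then halving at each higher resolution, with two 3×3 convolutions per resolution (one at 4×4). It counts only those convolutions' multiply-accumulates and ignores toRGB layers, style modulation and noise. Under these assumptions, stopping at 256×256 saves about 1.7×, because the wide, expensive layers sit at low resolution. Scaling every width by α scales each convolution's cost by α², so α = 0.5 alone gives 4×.

**Channel selection.** Output channels are sorted once by kernel magnitude (the L1 norm here, which is an assumption), and a sub-network then slices the first αc of them.

**Consistency loss.** The loss is MSE plus an optional perceptual term. `perceptual_fn` is a hypothetical hook where an LPIPS distance would go.

All helper names are illustrative, not the repository's API.

```python
import numpy as np

# Assumed StyleGAN2 config-f widths: conv channels at each resolution.
CHANNELS = {4: 512, 8: 512, 16: 512, 32: 512, 64: 512,
            128: 256, 256: 128, 512: 64, 1024: 32}

def conv_macs(max_res, alpha=1.0):
    """Multiply-accumulates of the 3x3 convs up to `max_res`, every width scaled by alpha."""
    total, prev = 0, None
    for res, ch in CHANNELS.items():            # dicts keep insertion (ascending) order
        if res > max_res:
            break
        c = int(alpha * ch)
        c_in = c if prev is None else prev
        extra = 0 if res == 4 else c * c        # 4x4 has one conv; others add a second c -> c conv
        total += res * res * 9 * (c_in * c + extra)
        prev = c
    return total

def sort_channels(weight):
    """Reorder output channels of a (c_out, c_in, kh, kw) kernel by descending L1 norm.
    Done once at initialization; in a full network the next layer's input
    channels would have to be permuted with the same `order`."""
    importance = np.abs(weight).sum(axis=(1, 2, 3))
    order = np.argsort(-importance, kind='stable')
    return weight[order], order

def sub_kernel(sorted_weight, alpha):
    """A sub-network with ratio alpha uses the first alpha * c (most important) channels."""
    return sorted_weight[: int(alpha * sorted_weight.shape[0])]

def consistency_loss(sub_img, full_img, perceptual_fn=None, lam=1.0):
    """MSE between sub- and full-generator outputs plus an optional perceptual term.
    `perceptual_fn` is a hypothetical hook where an LPIPS distance would go."""
    loss = float(np.mean((sub_img - full_img) ** 2))
    if perceptual_fn is not None:
        loss += lam * perceptual_fn(sub_img, full_img)
    return loss

print(round(conv_macs(1024) / conv_macs(256), 2))      # 1.69
print(conv_macs(1024) / conv_macs(1024, alpha=0.5))    # 4.0

rng = np.random.default_rng(0)
kernel = rng.standard_normal((8, 4, 3, 3))
sorted_kernel, order = sort_channels(kernel)
print(sub_kernel(sorted_kernel, 0.25).shape)           # (2, 4, 3, 3)
```

Because the channels are sorted once and every sub-network takes a prefix, all configurations share the same weights. Smaller sub-networks always reuse the highest-magnitude channels.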
Unlike regular GAN training, Anycost GAN trains many sub-generators with different channel widths and resolutions at the same time. Using a single discriminator for all sub-generators with different channel configurations results in poor performance. To overcome this, the anycost discriminator is conditioned on the generator architecture through a learned conditioning. The channel configuration is first one-hot encoded into a g_arch vector. g_arch is then passed through a fully connected layer to produce a per-channel modulation. Each feature map is modulated with the resulting conditioned weight and bias before being passed to the next layer.
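The conditioning step can be sketched in NumPy under these assumptions:

- Each generator layer's channel ratio is one-hot encoded over the four choices {0.25, 0.5, 0.75, 1}.
- The per-layer codes are concatenated into g_arch.
- A single fully connected layer maps g_arch to a per-channel scale and bias for one discriminator feature map. Its weights are random here and would be learned in practice.

The `1 + gamma` parameterization (so a zero output leaves the features unchanged) and the helper names are illustrative choices, not taken from the repository.

```python
import numpy as np

RATIOS = [0.25, 0.5, 0.75, 1.0]

def encode_arch(channel_config):
    """One-hot encode each layer's channel ratio and concatenate the codes into g_arch."""
    codes = []
    for ratio in channel_config:
        one_hot = np.zeros(len(RATIOS))
        one_hot[RATIOS.index(ratio)] = 1.0
        codes.append(one_hot)
    return np.concatenate(codes)

def modulate(feat, g_arch, fc_weight, fc_bias):
    """Scale and shift each channel of a (C, H, W) feature map using values predicted from g_arch."""
    c = feat.shape[0]
    params = fc_weight @ g_arch + fc_bias      # fully connected layer -> (2 * C,)
    gamma, beta = params[:c], params[c:]
    return feat * (1.0 + gamma[:, None, None]) + beta[:, None, None]

rng = np.random.default_rng(0)
config = [1.0, 0.75, 0.5, 0.5, 0.25]           # hypothetical per-layer ratios of one sub-generator
g_arch = encode_arch(config)                   # length 5 * 4 = 20

C = 16                                         # channels of the D feature map being modulated
fc_weight = rng.standard_normal((2 * C, g_arch.size)) * 0.01
fc_bias = np.zeros(2 * C)
feat = rng.standard_normal((C, 8, 8))
out = modulate(feat, g_arch, fc_weight, fc_bias)
print(g_arch.size, out.shape)                  # 20 (16, 8, 8)
```

Because the modulation depends on g_arch, the same discriminator weights can respond differently to each sub-generator they judge.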
Interactive Editing with Anycost GAN
- Clone the Anycost GAN repository and navigate into the cloned directory:
```bash
git clone https://github.com/mit-han-lab/anycost-gan.git
cd anycost-gan
```
- Install the dependencies by creating a new Anaconda environment from environment.yml:
```bash
conda env create -f environment.yml
```
Alternatively, update an existing environment:
```bash
conda env update --name myenv --file environment.yml
```
- Let's first visualize the reduced-cost previews of a generated image at the four resolutions and four channel widths.
Import the necessary libraries and helper functions:
```python
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFont, ImageDraw

# `models` is the package inside the cloned anycost-gan repository,
# so run this from the repository root.
import models
from models.dynamic_channel import set_sub_channel_config, remove_sub_channel_config, set_uniform_channel_ratio, reset_generator

# Show plots inline when running in a Jupyter notebook.
%matplotlib inline

# Run on the GPU if one is available; later cells expect `device` to be defined.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```
Download and load the anycost generator trained on the FFHQ dataset:
```python
g_ffhq = models.get_pretrained('generator', 'anycost-ffhq-config-f').to(device).eval()
```
Create helper functions for plotting the 4×4 grid of preview images:
```python
# Assumption: point this at any TrueType font available on your system.
font_path = '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf'

def torch_to_np_image(x):
    # Convert a (1, 3, H, W) tensor in [-1, 1] to an (H, W, 3) uint8 array.
    assert x.shape[0] == 1
    x = x.squeeze(0)
    return ((x.permute(1, 2, 0) + 1) * 0.5 * 255).cpu().numpy().astype('uint8')

def text_size(draw, text, font):
    # ImageDraw.textsize() was removed in Pillow 10; textbbox() provides the text extent instead.
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    return right - left, bottom - top

def add_legend_to_figure(x, full_res):
    # Pad the grid on the top and left, then label rows (channel ratio) and columns (resolution).
    h, w, c = x.shape
    x_pad_rate = 0.1
    y_pad_rate = 0.05
    pad_h = int(h + h * y_pad_rate)
    pad_w = int(w + w * x_pad_rate)
    pad_x = np.full([pad_h, pad_w, c], 255, dtype=x.dtype)
    pad_x[-h:, -w:] = x
    img = Image.fromarray(pad_x)
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype(font_path, int(100 * full_res / 1024))
    for i, text in enumerate(['1x', '0.75x', '0.5x', '0.25x']):
        text_w, text_h = text_size(draw, text, font)
        x_coord = x_pad_rate * w * 0.9 - text_w
        y_coord = (y_pad_rate + i * 1. / 4 + 1. / 8) * h - 0.5 * text_h
        draw.text((x_coord, y_coord), text, (0, 0, 0), font=font)
    for i in range(4):
        text = str(full_res // (2 ** i))
        text_w, text_h = text_size(draw, text, font)
        x_coord = (x_pad_rate + i * 1. / 4 + 1. / 8) * w - 0.5 * text_w
        y_coord = y_pad_rate * h * 0.9 - text_h
        draw.text((x_coord, y_coord), text, (0, 0, 0), font=font)
    return np.asarray(img)

def get_4x4_grid(g, latent, truncation=0.5, crop=False):
    # `latent` is a z code; rows are channel ratios, columns the four highest resolutions.
    images = []
    full_resolution = g.resolution
    if truncation < 1:
        mean_style = g.mean_style(10000)
    else:
        mean_style = None
    with torch.no_grad():
        for channel_mult in [1, .75, .5, .25]:
            set_uniform_channel_ratio(g, channel_mult)
            _, all_rgbs = g(latent, return_rgbs=True, truncation=truncation,
                            truncation_style=mean_style, randomize_noise=False)
            # Keep the four highest-resolution outputs, largest first.
            all_rgbs = all_rgbs[-4:][::-1]
            # Upsample every output to full resolution so the grid cells line up.
            all_rgbs = [F.interpolate(rgb.clamp(-1, 1), size=(full_resolution, full_resolution),
                                      mode='bilinear', align_corners=True) for rgb in all_rgbs]
            images.append(torch.cat(all_rgbs, dim=3))  # concatenate along width
    reset_generator(g)  # restore the full-width generator for later use
    image_grid = torch.cat(images, dim=2)              # stack rows along height
    image_grid_np = torch_to_np_image(image_grid)
    return add_legend_to_figure(image_grid_np, full_resolution)
```
Plot the image grid:
```python
# Sample a random z code and plot the 4x4 preview grid.
latent = torch.randn(1, 1, 512, device=device)
img_out_np = get_4x4_grid(g_ffhq, latent)
plt.figure(figsize=(6, 6))
plt.imshow(img_out_np)
plt.axis('off')
plt.show()
```
- Now, let’s try the interactive editor
A PyQt5 window will open. Rendering an edit with the full generator G(w) takes 3 to 8 seconds, while the sub-generator G′(w) renders a preview in roughly 1 second.