
A Residual-Based StyleGAN Encoder Via Iterative Refinement


A generative adversarial network (GAN) is an approach to generative modeling that uses deep neural networks, such as convolutional neural networks, and is effective at generating high-quality images. Generative modeling is an unsupervised machine learning task that involves automatically discovering and learning the regularities and patterns in input data, so that the model can output new examples that plausibly could have come from the original data.

The Style Generative Adversarial Network, or StyleGAN, is an extension of the GAN architecture that makes large changes to the generator model: a mapping network maps points in latent space (a compressed representation of the data) to an intermediate latent space, the intermediate latent space controls the style at each point in the generator model, and noise is introduced as a source of variation at each point in the generator model. As a result, the model can generate high-quality photos of faces and offers control over the style of the generated images.
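
To make the mapping step concrete, here is a minimal NumPy sketch of a mapping network that turns a latent code z into an intermediate code w, and then repeats w once per generator layer to form a W+ code. The weights are random and purely hypothetical (a real StyleGAN uses a trained network), and the names mapping_network, leaky_relu and the constants are assumptions made for this illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

LATENT_DIM = 512   # size of z and w in StyleGAN
NUM_LAYERS = 8     # depth of the mapping MLP described in the StyleGAN paper
NUM_STYLES = 18    # style inputs of a 1024x1024 StyleGAN generator

# Hypothetical random weights standing in for a trained mapping network.
weights = [rng.normal(0, 1 / np.sqrt(LATENT_DIM), (LATENT_DIM, LATENT_DIM))
           for _ in range(NUM_LAYERS)]


def leaky_relu(x, slope=0.2):
    """Leaky ReLU activation."""
    return np.where(x > 0, x, slope * x)


def mapping_network(z):
    """Map a latent code z to an intermediate latent code w."""
    w = z / np.sqrt(np.mean(z ** 2) + 1e-8)  # normalise the input code
    for weight in weights:
        w = leaky_relu(w @ weight)
    return w


z = rng.normal(size=LATENT_DIM)
w = mapping_network(z)

# A W+ code holds one style vector per generator layer; here the same w is repeated.
w_plus = np.tile(w, (NUM_STYLES, 1))
print(w.shape, w_plus.shape)
```

In W+ the rows are allowed to differ from each other, which is what gives an encoder extra freedom when it reconstructs a specific image.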



In this post, we are going to discuss the Restyle Encoder. 

Architecture and Method:

The authors of the Restyle Encoder propose a novel architecture: a fast iterative method for inverting images into the latent space of pretrained StyleGAN generators that achieves state-of-the-art quality at a lower inference time. The core idea is to start from the average latent vector in W+ and predict an offset that makes the generated image look more like the target image, then repeat this step with the new image and latent vector as the starting point. With this approach, a good inversion can be achieved in about ten steps.
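
The refinement loop can be sketched with a toy example. Everything below is an assumption made for illustration: the "generator" is a random linear map rather than StyleGAN, and the "encoder" is a hand-written least-squares step rather than a trained network. Only the loop structure (start from the average latent, predict an offset, regenerate, repeat) mirrors the idea described above.

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT_DIM, IMAGE_DIM = 8, 32

# Toy linear "generator": a stand-in for a pretrained StyleGAN generator.
A = rng.normal(size=(IMAGE_DIM, LATENT_DIM))


def generator(w):
    """Produce a toy 'image' (a flat vector) from a latent code."""
    return A @ w


def encoder(target, current_image):
    """Toy residual encoder: predict a latent offset from the target and the
    current reconstruction. A real encoder is a trained network; a scaled
    least-squares step is used here as a stand-in."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    return step * A.T @ (target - current_image)


def invert(target, w_avg, n_iters=10):
    """Iteratively refine the latent code, starting from the average latent."""
    w = w_avg.copy()
    image = generator(w)
    errors = []
    for _ in range(n_iters):
        w = w + encoder(target, image)   # add the predicted offset
        image = generator(w)             # regenerate from the updated latent
        errors.append(float(np.linalg.norm(target - image)))
    return w, errors


target = generator(rng.normal(size=LATENT_DIM))
w_avg = np.zeros(LATENT_DIM)
w_hat, errors = invert(target, w_avg)
print([round(e, 3) for e in errors])
```

In this toy setting the reconstruction error shrinks with each pass, which is the behaviour the iterative scheme aims for.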

The figure below, taken from the official research paper, illustrates the architecture with an example. For detailed information, you can check the paper.

Source: Architecture of Restyle Encoder

Official Code Implementation

The Restyle encoder supports several experiments: it can run inference on human faces, cars, horses, wild animals and churches, and it can toonify the input image. We only need to select the experiment on which we want to perform inference.

Note: The code below is taken from the official code repository.

Clone the repository to colab:
import os
CODE_DIR = 'restyle-encoder'
!git clone $CODE_DIR
!sudo unzip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 
Import all dependencies:
 from argparse import Namespace
 import time
 import os
 import sys
 import pprint
 import numpy as np
 from PIL import Image
 import torch
 import torchvision.transforms as transforms
 from utils.common import tensor2im
 from models.psp import pSp
 from models.e4e import e4e
 %load_ext autoreload
 %autoreload 2 
Select the type of experiment you want to perform as listed below
 #@title Select which experiment you wish to perform inference on: { run: "auto" }
 experiment_type = 'ffhq_encode' #@param ['ffhq_encode', 'cars_encode', 'church_encode', 'horse_encode', 'afhq_wild_encode', 'toonify'] 
Download the pretrained models

The authors have provided a pre-trained model for each experiment listed above. The code below builds the command that downloads the pre-trained model for the desired experiment.

 def get_download_model_command(file_id, file_name):
     """ Get wget download command for downloading the desired model and save to directory ../pretrained_models. """
     current_directory = os.getcwd()
     save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
     if not os.path.exists(save_path):
         os.makedirs(save_path)  # create the target directory on first use
     url = r"""wget --load-cookies /tmp/cookies.txt "$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate '{FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
     return url

 # Google Drive file id and output file name for each experiment
 MODEL_PATHS = {
     "ffhq_encode": {"id": "1sw6I2lRIB0MpuJkpc8F5BJiSZrc0hjfE", "name": ""},
     "cars_encode": {"id": "1zJHqHRQ8NOnVohVVCGbeYMMr6PDhRpPR", "name": ""},
     "church_encode": {"id": "1bcxx7mw-1z7dzbJI_z7oGpWG1oQAvMaD", "name": ""},
     "horse_encode": {"id": "19_sUpTYtJmhSAolKLm3VgI-ptYqd-hgY", "name": ""},
     "afhq_wild_encode": {"id": "1GyFXVTNDUw3IIGHmGS71ChhJ1Rmslhk7", "name": ""},
     "toonify": {"id": "1GtudVDig59d4HJ_8bGEniz5huaTSGO_0", "name": ""}
 }

path = MODEL_PATHS[experiment_type]

download_command = get_download_model_command(file_id=path["id"], file_name=path["name"]) 

Define the inference parameters

The inference parameters for each experiment specify the path to the pre-trained model, the path to the input image and the transformations applied to that image.

 EXPERIMENT_DATA_ARGS = {
     "ffhq_encode": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/face_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor(),  # Normalize expects a tensor, not a PIL image
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     },
     "cars_encode": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/car_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((192, 256)),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     },
     "church_encode": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/church_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     },
     "horse_encode": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/horse_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     },
     "afhq_wild_encode": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/afhq_wild_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     },
     "toonify": {
         "model_path": "pretrained_models/",
         "image_path": "notebooks/images/toonify_img.jpg",
         "transform": transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
     }
 }
 # parameters of the selected experiment, used in the steps below
 EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
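
As a rough illustration of what the Normalize step with mean 0.5 and standard deviation 0.5 does, the sketch below reproduces its arithmetic, (x - mean) / std, in plain NumPy. It assumes pixel values already scaled to [0, 1], and the helper names normalize and denormalize are hypothetical, not part of torchvision or the repository.

```python
import numpy as np

MEAN, STD = 0.5, 0.5


def normalize(pixels):
    """Same arithmetic as Normalize(mean=0.5, std=0.5): (x - mean) / std."""
    return (pixels - MEAN) / STD


def denormalize(values):
    """Inverse mapping, back to the [0, 1] pixel range."""
    return values * STD + MEAN


pixels = np.array([0.0, 0.25, 0.5, 1.0])   # example pixel values in [0, 1]
normalized = normalize(pixels)
print(normalized)
print(denormalize(normalized))
```

Pixel value 0 maps to -1 and pixel value 1 maps to 1, so the network sees inputs in the range [-1, 1]; converting an output back into a viewable image requires the inverse mapping.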


Load the pre-trained model
 model_path = EXPERIMENT_ARGS['model_path']
 ckpt = torch.load(model_path, map_location='cpu')
 opts = ckpt['opts']
 # update the training options
 opts['checkpoint_path'] = model_path
 opts = Namespace(**opts)
 if experiment_type == 'horse_encode': 
     net = e4e(opts)
 else:
     net = pSp(opts)
 net.eval()  # switch to inference mode
 net.cuda()  # the inference code below runs on the GPU
 print('Model successfully loaded!') 
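
The options handling above relies on argparse.Namespace, which turns a dictionary into an object with attribute access. The following small sketch shows the pattern in isolation; the dictionary contents and the placeholder path are hypothetical and only stand in for the options stored in a checkpoint.

```python
from argparse import Namespace

# Hypothetical options dictionary, standing in for the one stored in a checkpoint.
opts = {"output_size": 1024, "n_iters_per_batch": 5}
opts["checkpoint_path"] = "path/to/checkpoint.pt"  # placeholder path

# Convert the dictionary so that options are read as attributes.
opts = Namespace(**opts)
print(opts.output_size, opts.checkpoint_path)

# Attributes can be overridden later, and vars() recovers the dictionary view.
opts.n_iters_per_batch = 3
print(vars(opts)["n_iters_per_batch"])
```
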
Visualise the image for face experiment
 image_path = EXPERIMENT_DATA_ARGS[experiment_type]["image_path"]
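 original_image = Image.open(image_path).convert("RGB")
 input_image = original_image  # used by the transform step below; no alignment is applied here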
Perform the Inference

Generate the image corresponding to the average latent code because it will be used in an iterative refinement process.

 img_transforms = EXPERIMENT_ARGS['transform']
 transformed_image = img_transforms(input_image)
 def get_avg_image(net):
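     # Keyword arguments as used in the official repository's helper (verify against the repo)
     avg_image = net(net.latent_avg.unsqueeze(0),
                     input_code=True,
                     randomize_noise=False,
                     return_latents=False,
                     average_code=True)[0]
     avg_image = avg_image.to('cuda').float().detach()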
     if experiment_type == "cars_encode":
         avg_image = avg_image[:, 32:224, :]
     return avg_image 
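
The slice 32:224 used for the cars experiment is a centred crop along the height axis: it trims a 256-pixel-high image to 192 rows so that it matches the 192 x 256 input size of that experiment. A minimal NumPy sketch of the arithmetic, using a hypothetical blank image in channels-first layout:

```python
import numpy as np

# Hypothetical square image in channels-first layout: (channels, height, width).
square = np.zeros((3, 256, 256), dtype=np.float32)

target_height = 192
margin = (square.shape[1] - target_height) // 2  # (256 - 192) // 2 = 32

# Keep rows margin .. margin + target_height, i.e. rows 32 .. 224.
cropped = square[:, margin:margin + target_height, :]
print(margin, cropped.shape)
```
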

Now run the inference. 

 opts.n_iters_per_batch = 5
 opts.resize_outputs = False  # generate outputs at full resolution
 from utils.inference_utils import run_on_batch
 with torch.no_grad():
     avg_image = get_avg_image(net)
     tic = time.time()
     result_batch, result_latents = run_on_batch(transformed_image.unsqueeze(0).cuda(), net, opts, avg_image)
     toc = time.time()
     print('Inference took {:.4f} seconds.'.format(toc - tic)) 
Visualise the output result side by side 
 if opts.dataset_type == "cars_encode":
     resize_amount = (256, 192) if opts.resize_outputs else (512, 384)
 else:
     resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
 def get_coupled_results(result_batch, transformed_image):
     """ Visualize output images from left to right (the input image is on the right) """
     result_tensors = result_batch[0]  # there's one image in our batch
     result_images = [tensor2im(result_tensors[iter_idx]) for iter_idx in range(opts.n_iters_per_batch)]
     input_im = tensor2im(transformed_image)
     res = np.array(result_images[0].resize(resize_amount))
     for idx, result in enumerate(result_images[1:]):
         res = np.concatenate([res, np.array(result.resize(resize_amount))], axis=1)
     res = np.concatenate([res, input_im.resize(resize_amount)], axis=1)
     res = Image.fromarray(res)
     return res
 res = get_coupled_results(result_batch, transformed_image)
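
The side-by-side layout comes from concatenating image arrays along the width axis (axis=1 for height x width x channels arrays). Here is a small self-contained sketch with tiny constant-colour arrays as hypothetical stand-ins for the refinement outputs and the input image; side_by_side is an illustrative helper, not a function from the repository.

```python
import numpy as np


def side_by_side(images):
    """Concatenate equally sized H x W x C images along the width axis."""
    return np.concatenate(images, axis=1)


# Hypothetical stand-ins: three refinement outputs followed by the input image.
height, width = 4, 4
steps = [np.full((height, width, 3), fill, dtype=np.uint8) for fill in (50, 100, 150)]
input_image = np.full((height, width, 3), 255, dtype=np.uint8)

strip = side_by_side(steps + [input_image])
print(strip.shape)
```

All images must share the same height for the concatenation to work, which is why each result is resized to a common size first.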

Note that the original image is on the rightmost side.

Just by changing the experiment type earlier in the code, we can perform toonification of the given image. The result is shown below.


In this article, we covered the basics of GAN and StyleGAN and then looked at the Restyle Encoder, a novel architecture that performs inference on various categories. We showed two of these inferences: normal facial inference and toonification, which is the coolest thing. The quality of the output images may be improved by increasing the number of iterations per batch.



Vijaysinh Lendave
Vijaysinh is an enthusiast in machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, model building.
