Hands-On Guide To Pixel2Style2Pixel: Image-To-Image Translation

Pixel2Style2Pixel (pSp) is an image translation framework that builds upon the representative power of a pre-trained StyleGAN generator and the W+ latent space.
Pixel2Style2Pixel for image translation

Pixel2Style2Pixel (pSp) is an end-to-end image translation framework that builds upon the representative power of a pre-trained StyleGAN generator and the W+ latent space. The framework proposes a new encoder network that can directly embed real images into W+ latent space. It also introduces an identity loss which helps improve image reconstruction performance. 

Architecture & Approach

Pixel2Style2Pixel architecture

Pixel2Style2Pixel uses a fixed StyleGAN2 generator trained on the FFHQ dataset, and its encoder uses a ResNet-IR architecture pre-trained for face recognition as the backbone network. StyleGAN showed that its style inputs control different levels of detail, roughly divided into three groups: coarse, medium, and fine. Building on this, the Pixel2Style2Pixel encoder backbone is extended with a feature pyramid that produces feature maps at three scales, and a small intermediate network, map2style, extracts style vectors from each of them. Each style vector is fed into the StyleGAN generator layer that corresponds to its scale. The generator then produces the output image, completing the translation from input pixels to output pixels through the intermediate style representation.
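The routing from pyramid levels to style inputs can be sketched in NumPy. This is a toy under stated assumptions: the split of the 18 style slots into coarse (0-2), medium (3-6) and fine (7-17) follows our reading of the official repository's encoder; map2style is modelled as a stack of stride-2 3x3 convolutions with LeakyReLU followed by a linear layer; the weights are random placeholders; and the channel width is tiny so the example runs quickly (pSp itself uses 512-dimensional styles). All function names are hypothetical.

```python
import numpy as np

# Style slots of a 1024x1024 StyleGAN2 generator (0-indexed):
# 0-2 coarse, 3-6 medium, 7-17 fine.
N_STYLES, COARSE_END, MEDIUM_END = 18, 3, 7


def leaky_relu(x, slope=0.01):
    return np.where(x > 0, x, slope * x)


def conv3x3_stride2(x, weight):
    """Naive 3x3 convolution, stride 2, zero padding 1: (C_in, H, W) -> (C_out, H/2, W/2)."""
    _, h, w = x.shape
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.empty((weight.shape[0], h // 2, w // 2))
    for i in range(h // 2):
        for j in range(w // 2):
            patch = padded[:, 2 * i:2 * i + 3, 2 * j:2 * j + 3]
            out[:, i, j] = np.einsum("oikl,ikl->o", weight, patch)
    return out


def map2style(feat, conv_weights, fc_weight):
    """Toy map2style: stride-2 convs + LeakyReLU until 1x1, then a linear layer."""
    x = feat
    for weight in conv_weights:
        x = leaky_relu(conv3x3_stride2(x, weight))
    return fc_weight @ x.reshape(-1)  # x has shape (dim, 1, 1) at this point


def pyramid_to_wplus(coarse, medium, fine, dim, seed=0):
    """Route each pyramid level to its style slots; returns an (N_STYLES, dim) code."""
    rng = np.random.default_rng(seed)
    styles = []
    for slot in range(N_STYLES):
        feat = coarse if slot < COARSE_END else medium if slot < MEDIUM_END else fine
        n_convs = int(np.log2(feat.shape[-1]))  # halve H and W until they reach 1
        # Random weights stand in for trained parameters (one map2style per slot).
        convs = [0.1 * rng.standard_normal((dim, feat.shape[0] if k == 0 else dim, 3, 3))
                 for k in range(n_convs)]
        fc = 0.1 * rng.standard_normal((dim, dim))
        styles.append(map2style(feat, convs, fc))
    return np.stack(styles)


# Toy pyramid: 16x16 (coarse), 32x32 (medium), 64x64 (fine) maps with 8 channels.
rng = np.random.default_rng(1)
coarse = rng.standard_normal((8, 16, 16))
medium = rng.standard_normal((8, 32, 32))
fine = rng.standard_normal((8, 64, 64))
w_plus = pyramid_to_wplus(coarse, medium, fine, dim=8)
print(w_plus.shape)  # (18, 8)
```

Each slot gets its own map2style head, so the 11 fine slots all read the same 64x64 map but can extract different styles from it.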

Another benefit of StyleGAN's layer-wise representation is its ability to disentangle semantic objects. This ability to manipulate semantic attributes independently leads to another desirable property: support for multi-modal synthesis. Several image-to-image translation tasks are ambiguous, i.e., a single input image may correspond to several outputs, so it is desirable to be able to sample all of these possible outputs. While standard image-to-image architectures need specialized changes to support this, Pixel2Style2Pixel supports it inherently, simply by sampling style vectors.


Pixel2Style2Pixel architecture for multi-modal synthesis

This is done by randomly sampling a 512-dimensional vector and generating a corresponding latent code in W+. The styles of this randomly generated latent code, wR, and of the latent code computed for the input image, wI, are mixed by inserting selected layers of wR into the corresponding layers of wI, with a parameter α controlling the blend between the two styles. As shown in the figure above, layers 1-7 are taken from the latent code of the input image, wI, while layers 8-18 are taken from the sampled code, wR. This produces multiple outputs with similar coarse and medium features but varying fine features.
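The mixing rule can be sketched in NumPy on stand-in codes. Assumptions: W+ codes are (n_styles, 512) arrays with n_styles = 2·log2(resolution) - 2, which gives 18 for a 1024x1024 StyleGAN2 generator; the random arrays replace the encoded input wI and the sampled code wR, which in pSp come from the encoder and from feeding a random vector through the generator; `mix_styles` and `n_styles` are hypothetical names. With `alpha=None` the selected layers are copied outright, otherwise they are blended as α·wR + (1 - α)·wI.

```python
import numpy as np


def n_styles(resolution):
    """Number of style inputs of a StyleGAN2 generator at a given output resolution."""
    return 2 * int(np.log2(resolution)) - 2


def mix_styles(w_input, w_random, layers, alpha=None):
    """Insert (or blend) the selected layers of w_random into a copy of w_input.

    `layers` are 0-indexed rows of the (n_styles, 512) W+ codes. With alpha=None
    the rows are copied outright; otherwise each becomes
    alpha * w_random[i] + (1 - alpha) * w_input[i].
    """
    mixed = w_input.copy()
    for i in layers:
        if alpha is None:
            mixed[i] = w_random[i]
        else:
            mixed[i] = alpha * w_random[i] + (1 - alpha) * w_input[i]
    return mixed


# Layers 1-7 (1-indexed) come from the input image, layers 8-18 from the random sample.
FINE_LAYERS = list(range(7, n_styles(1024)))

rng = np.random.default_rng(0)
w_i = rng.standard_normal((18, 512))  # stand-in for the encoded input image, wI
w_r = rng.standard_normal((18, 512))  # stand-in for the randomly sampled code, wR
mixed = mix_styles(w_i, w_r, FINE_LAYERS, alpha=0.5)
```

Drawing a fresh wR for each output while keeping wI fixed is what yields several outputs that share coarse and medium features.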

Loss functions used in Pixel2Style2Pixel

The Pixel2Style2Pixel encoder is trained with a weighted combination of three loss functions. The first is the pixel-wise L2 loss:

$$\mathcal{L}_{2}(x) = \lVert x - \mathrm{pSp}(x) \rVert_{2}$$

Here x denotes the input image and pSp(x) is the output of the pSp network (the encoder followed by the generator).

In addition, to learn perceptual similarities, the encoder is trained with the LPIPS loss, which has been shown to preserve image quality better than the more standard perceptual loss:

$$\mathcal{L}_{\mathrm{LPIPS}}(x) = \lVert F(x) - F(\mathrm{pSp}(x)) \rVert_{2}$$

Here F(·) denotes the perceptual feature extractor.

Identity preservation between the input and output images is an important aspect of face generation tasks, and neither of the losses above is sensitive to facial identity. To incorporate this objective into the overall loss function, Pixel2Style2Pixel uses a dedicated recognition loss based on the cosine similarity between the face-recognition embeddings of the output image and its source:

$$\mathcal{L}_{\mathrm{ID}}(x) = 1 - \langle R(x), R(\mathrm{pSp}(x)) \rangle$$

Here R is a pre-trained ArcFace network for face recognition. Combining these three losses gives the total loss function for the encoder:

$$\mathcal{L}(x) = \lambda_{1}\mathcal{L}_{2}(x) + \lambda_{2}\mathcal{L}_{\mathrm{LPIPS}}(x) + \lambda_{3}\mathcal{L}_{\mathrm{ID}}(x)$$

Here λ1, λ2, and λ3 are constants defining the loss weights.
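The three terms and their weighted sum can be sketched in NumPy. This is a toy under stated assumptions: `feature_fn` and `embed_fn` are hypothetical stand-ins (built from a fixed random projection) for the LPIPS feature extractor F and the ArcFace network R, the argument `y` plays the role of pSp(x), and the λ weights passed in are placeholders, not the values used to train pSp. The identity term normalizes both embeddings so that the inner product equals their cosine similarity.

```python
import numpy as np


def pixel_l2_loss(x, y):
    """L2 term ||x - pSp(x)||_2, with y standing in for pSp(x)."""
    return float(np.linalg.norm((x - y).ravel()))


def perceptual_loss(x, y, feature_fn):
    """LPIPS-style term ||F(x) - F(pSp(x))||_2 for a stand-in extractor F."""
    return float(np.linalg.norm((feature_fn(x) - feature_fn(y)).ravel()))


def identity_loss(x, y, embed_fn):
    """ID term 1 - <R(x), R(pSp(x))> on unit-normalized embeddings (1 - cosine similarity)."""
    a = embed_fn(x).ravel()
    b = embed_fn(y).ravel()
    return float(1.0 - (a / np.linalg.norm(a)) @ (b / np.linalg.norm(b)))


def total_loss(x, y, feature_fn, embed_fn, lambdas):
    """Weighted sum lambda1 * L2 + lambda2 * LPIPS + lambda3 * ID."""
    lam_pix, lam_lpips, lam_id = lambdas
    return (lam_pix * pixel_l2_loss(x, y)
            + lam_lpips * perceptual_loss(x, y, feature_fn)
            + lam_id * identity_loss(x, y, embed_fn))


# Hypothetical stand-ins: in pSp, F is the LPIPS network and R is ArcFace.
rng = np.random.default_rng(0)
proj = rng.standard_normal((16, 3 * 8 * 8))


def feature_fn(img):
    return np.tanh(proj @ img.ravel())


def embed_fn(img):
    return proj @ img.ravel()


x = rng.random((3, 8, 8))                          # toy "input image"
y = x + 0.1 * rng.standard_normal(x.shape)         # toy "reconstruction"
loss = total_loss(x, y, feature_fn, embed_fn, lambdas=(1.0, 1.0, 1.0))
```

Note that the identity term ignores the scale of the embeddings: rescaling the output's embedding leaves it unchanged, which is why it complements the pixel and perceptual terms rather than duplicating them.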

Multi-modal synthesis and StyleGAN inversion using Pixel2Style2Pixel

  1. Clone the repo and move into it:

 !git clone https://github.com/eladrich/pixel2style2pixel.git
 %cd pixel2style2pixel

  2. Import the required libraries and classes:
 import os
 import sys
 import time
 import pprint
 from argparse import Namespace

 import numpy as np
 from PIL import Image
 import torch
 import torchvision.transforms as transforms

 # make the repository's modules importable (we are inside the repo root)
 sys.path.append(".")
 from datasets import augmentations
 from utils.common import tensor2im, log_input_image
 from models.psp import pSp
  3. Create a model, download the weights, and load them:
 # name of the cloned repository (the notebook is currently inside it)
 CODE_DIR = "pixel2style2pixel"

 # dictionary of model paths to easily generate download links
 MODEL_PATHS = {
     "ffhq_encode": {"id": "1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0", "name": "psp_ffhq_encode.pt"},
     "ffhq_frontalize": {"id": "1_S4THAzXb-97DbpXmanjHtXRyKxqjARv", "name": "psp_ffhq_frontalization.pt"},
     "celebs_sketch_to_face": {"id": "1lB7wk7MwtdxL-LL4Z_T76DuCfk00aSXA", "name": "psp_celebs_sketch_to_face.pt"},
     "celebs_seg_to_face": {"id": "1VpEKc6E6yG3xhYuZ0cq8D2_1CbT0Dstz", "name": "psp_celebs_seg_to_face.pt"},
     "celebs_super_resolution": {"id": "1ZpmSXBpJ9pFEov6-jjQstAlfYbkebECu", "name": "psp_celebs_super_resolution.pt"},
     "toonify": {"id": "1YKoiVuFaqdvzDP5CZaqa3k5phL-VDmyz", "name": "psp_ffhq_toonify.pt"}
 }

 # function that generates the model download command
 def get_download_model_command(file_id, file_name):
     """ Get wget download command for downloading the desired model and save it to the repository's pretrained_models directory. """
     current_directory = os.getcwd()
     save_path = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
     if not os.path.exists(save_path):
         os.makedirs(save_path)
     command = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
     return command

 # available models:
 # ['ffhq_encode', 'ffhq_frontalize', 'celebs_sketch_to_face', 'celebs_seg_to_face', 'celebs_super_resolution', 'toonify']

 # generate the link and download the model weights
 path = MODEL_PATHS['celebs_seg_to_face']  # model to be downloaded: 'celebs_seg_to_face'
 download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
 # the command string already starts with "wget", so run it directly
 !{download_command}

 # load the model
 model_path = os.path.join("pretrained_models", path["name"])
 ckpt = torch.load(model_path, map_location='cpu')
 opts = ckpt['opts']
 # pSp loads its weights from opts.checkpoint_path
 opts['checkpoint_path'] = model_path
 if 'learn_in_w' not in opts:
     opts['learn_in_w'] = False
 opts = Namespace(**opts)
 net = pSp(opts)
 net.eval()
 net.cuda()
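The long wget command in `get_download_model_command` makes two requests: it first fetches the Google Drive page while saving cookies, pulls a `confirm=` token out of it with `sed`, and then downloads the file with that token and the same file ID. The Python sketch below mirrors only the token-extraction and URL-building steps, assuming the token appears once in the page; `extract_confirm_token` and `build_download_url` are hypothetical helpers, and nothing is actually downloaded.

```python
import re
from urllib.parse import urlencode

DRIVE_URL = "https://docs.google.com/uc"


def extract_confirm_token(page_html):
    """Roughly what the sed expression does: return the value after 'confirm=', if any."""
    match = re.search(r"confirm=([0-9A-Za-z_]+)", page_html)
    return match.group(1) if match else None


def build_download_url(file_id, token):
    """URL of the second request: export=download, the confirm token, and the file id."""
    return DRIVE_URL + "?" + urlencode({"export": "download", "confirm": token, "id": file_id})


page = '<a href="/uc?export=download&amp;confirm=t0K_en&amp;id=abc">Download anyway</a>'
url = build_download_url("abc", extract_confirm_token(page))
```

The character class `[0-9A-Za-z_]+` stops at the first `&`, which is why only the token itself is captured.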
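  4. Next, let's run multi-modal synthesis. Recall that here a single input produces multiple outputs, as in conditional image synthesis:
 # indices of the style layers to take from the random vectors
 # (0-indexed 7-17, i.e. layers 8-18 as described above)
 latent_mask = list(range(7, 18))
 n_outputs_to_generate = 5
 # None copies the injected layers outright; a float in [0, 1] blends them
 mix_alpha = None

 def get_multi_modal_outputs(input_image, vectors_to_inject):
     results = []
     with torch.no_grad():
         for vec_to_inject in vectors_to_inject:
             cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda")
             # get latent vector to inject into our input image
             _, latent_to_inject = net(cur_vec, input_code=True, return_latents=True)
             # get output image with injected style vector
             res = net(input_image.unsqueeze(0).to("cuda").float(),
                       latent_mask=latent_mask,
                       inject_latent=latent_to_inject,
                       alpha=mix_alpha)
             results.append(res[0])
     return results

 # `transformed_image` must already hold the input segmentation mask after the model's input transforms
 # randomly draw the latents to use for style mixing
 vectors_to_inject = np.random.randn(n_outputs_to_generate, 512).astype('float32')
 multi_results = get_multi_modal_outputs(transformed_image, vectors_to_inject)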
  5. Visualize the output:
 from IPython.display import display
 input_vis_image = log_input_image(transformed_image, opts)
 res = np.array(input_vis_image.resize((256, 256)))
 for output in multi_results:
     output = tensor2im(output)
     res = np.concatenate([res, np.array(output.resize((256, 256)))], axis=1)
 res_image = Image.fromarray(res)
 display(res_image)
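The visualization loop relies on `tensor2im` to turn a generator output into a displayable image and then tiles the images horizontally. Below is a NumPy sketch of that conversion, assuming the outputs are channels-first arrays in [-1, 1] (which is how we understand `tensor2im` to treat them); `to_uint8_image` and `tile_horizontally` are hypothetical helpers, not functions from the repository.

```python
import numpy as np


def to_uint8_image(chw):
    """(3, H, W) float array in [-1, 1] -> (H, W, 3) uint8 array in [0, 255]."""
    hwc = np.transpose(chw, (1, 2, 0))               # channels-last, as PIL expects
    scaled = np.clip((hwc + 1.0) / 2.0, 0.0, 1.0)    # [-1, 1] -> [0, 1], clipping outliers
    return (scaled * 255).astype(np.uint8)


def tile_horizontally(images):
    """Concatenate equally sized (H, W, 3) images along the width axis."""
    return np.concatenate(images, axis=1)


rng = np.random.default_rng(0)
outputs = [np.tanh(rng.standard_normal((3, 256, 256))) for _ in range(3)]
strip = tile_horizontally([to_uint8_image(o) for o in outputs])
```

Concatenating along `axis=1` works because every image has been resized to the same height, which is also why the loop above resizes everything to 256x256 first.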

The output should look something like this:

Pixel2Style2Pixel multi-modal sampling
  6. Now let's try StyleGAN inversion: encoding an arbitrary image into latent style vectors from which the original image can be reconstructed.

Set up the pre-trained model:

 path = MODEL_PATHS['ffhq_encode']
 download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
 !{download_command}
 # point model_path at the newly downloaded checkpoint
 model_path = os.path.join("pretrained_models", path["name"])
 ckpt = torch.load(model_path, map_location='cpu')
 opts = ckpt['opts']
 # update the training options
 opts['checkpoint_path'] = model_path
 if 'learn_in_w' not in opts:
     opts['learn_in_w'] = False
 opts = Namespace(**opts)
 net = pSp(opts)
 net.eval()
 net.cuda()
 print('Model successfully loaded!')

Download and extract the input images:

 def get_download_images_command(file_id, file_name):
     """ Get wget download command for downloading the inversion images and save them to the current directory. """
     save_path = os.getcwd()
     url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
     return url

 inversion_images_file_name = "inversion_images.zip"
 save_path = "./inversion_images"
 download_command = get_download_images_command("1wfCiWuHjsj3oGDeYF9Lrkp8vwhTvleBu", inversion_images_file_name)

 !{download_command}
 !mkdir {save_path}
 !unzip {inversion_images_file_name}

Visualize the input images:

 import matplotlib.pyplot as plt
 %matplotlib inline
 image_paths = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith(".jpg")]
 n_images = len(image_paths)
 images = []
 n_cols = int(np.ceil(n_images / 2))  # add_subplot expects an integer
 fig = plt.figure(figsize=(20, 4))
 for idx, image_path in enumerate(image_paths):
     ax = fig.add_subplot(2, n_cols, idx + 1)
     img = Image.open(image_path).convert("RGB")
     images.append(img)  # keep the images for the inversion step
     ax.set_xticks([])
     ax.set_yticks([])
     ax.imshow(img)
 plt.show()

It should look like this.

Run the inversion model.

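 # input transforms for the FFHQ encoder: resize to 256x256 and scale pixels to [-1, 1]
 img_transforms = transforms.Compose([
     transforms.Resize((256, 256)),
     transforms.ToTensor(),
     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

 def run_on_batch(inputs, net):
     # plain reconstruction: encode each image into W+ and decode it with StyleGAN
     return net(inputs.to("cuda").float(), randomize_noise=False)

 transformed_images = [img_transforms(image) for image in images]
 batched_images = torch.stack(transformed_images, dim=0)
 with torch.no_grad():
     tic = time.time()
     result_images = run_on_batch(batched_images, net)
     toc = time.time()
 print('Inference took {:.4f} seconds.'.format(toc - tic))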
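The preprocessing before inversion can be sketched in NumPy, assuming the FFHQ encoder's inputs are normalized with mean 0.5 and standard deviation 0.5 per channel (what we understand the repository's inference notebook to use). `normalize`, `denormalize` and `make_batch` are hypothetical helpers that mirror the arithmetic of `transforms.Normalize` and `torch.stack`, applied to images already scaled to [0, 1].

```python
import numpy as np

# Per-channel statistics, assuming Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).
MEAN = np.full((3, 1, 1), 0.5)
STD = np.full((3, 1, 1), 0.5)


def normalize(chw01):
    """Per-channel (x - mean) / std: maps [0, 1] pixels to [-1, 1]."""
    return (chw01 - MEAN) / STD


def denormalize(chw):
    """Inverse of normalize: maps [-1, 1] back to [0, 1]."""
    return chw * STD + MEAN


def make_batch(images01):
    """Normalize each (3, H, W) image and stack them into an (N, 3, H, W) batch."""
    return np.stack([normalize(img) for img in images01], axis=0)


rng = np.random.default_rng(0)
imgs = [rng.random((3, 256, 256)) for _ in range(4)]
batch = make_batch(imgs)
```

The [-1, 1] range matches the range of the generator's outputs, so inputs and reconstructions can be compared on the same scale.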
  7. Visualize and compare the reconstructions with the original images:
 from IPython.display import display
 couple_results = []
 for original_image, result_image in zip(images, result_images):
     result_image = tensor2im(result_image)
     res = np.concatenate([np.array(original_image.resize((256, 256))),
                           np.array(result_image.resize((256, 256)))], axis=1)
     res_im = Image.fromarray(res)
     couple_results.append(res_im)
     display(res_im)

Last Epoch (Endnote)

This post discussed Pixel2Style2Pixel, an image-to-image translation framework that introduces a novel encoder architecture for encoding an arbitrary image directly into the W+ latent space. The encoder is based on the Feature Pyramid Network: style vectors extracted from different pyramid scales are fed directly into a fixed, pre-trained StyleGAN generator at the layers corresponding to their spatial scales. Pixel2Style2Pixel can also be applied to other image-to-image translation tasks, such as face frontalization, super-resolution, inpainting, and the fascinating task of generating face images from sketches.

Pixel2Style2Pixel sketch2image

To learn more about Pixel2Style2Pixel, and to try out the other tasks/models see the following links:


Aditya Singh
A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.
