How Score-Based SDE Excels In Generative Modeling

Score-based generative models have recently shown strong performance in image generation. In classical statistics, the score is the gradient of the log-likelihood with respect to the model parameters; in score-based generative modeling, the term refers to the gradient of the log probability density with respect to the data itself, ∇x log p(x). Typically, when training such a generative model, noise is added to the original image and the model learns to revert the noisy image to its original form. In a score-based generative model, noise is added in steps so that the final noisy image follows a predefined probability distribution. A trained model then starts from a sample of that predefined distribution and generates an image by following the score it estimates at each noise level, reversing the noising process step by step.
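To make the definition of the score concrete, here is a minimal sketch that is not taken from the paper or the repository. It computes the score of a one-dimensional Gaussian analytically and checks it against a finite-difference derivative of the log-density. The function names are illustrative.

```python
import math

def gaussian_log_density(x, mu=0.0, sigma=1.0):
    """Log-density of N(mu, sigma^2) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2.0 * sigma ** 2)

def gaussian_score(x, mu=0.0, sigma=1.0):
    """Score: derivative of the log-density with respect to the data point x."""
    return -(x - mu) / sigma ** 2

def finite_difference_score(x, mu=0.0, sigma=1.0, h=1e-5):
    """Central-difference approximation of the same derivative."""
    upper = gaussian_log_density(x + h, mu, sigma)
    lower = gaussian_log_density(x - h, mu, sigma)
    return (upper - lower) / (2.0 * h)

x = 1.5
analytic = gaussian_score(x, mu=0.5, sigma=2.0)
numeric = finite_difference_score(x, mu=0.5, sigma=2.0)
print(f"analytic score = {analytic:.6f}, finite difference = {numeric:.6f}")
```

The score points from x toward regions of higher density (here, toward the mean), which is the direction a sampler follows when it removes noise.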

Yang Song and Stefano Ermon from Stanford University, and Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar and Ben Poole from Google Brain have introduced Stochastic Differential Equations (SDEs) into score-based generative modeling, replacing the finite sequence of noise perturbations and denoising steps with a continuous-time process. Here, a stochastic differential equation is used to transform a complex data distribution smoothly and continuously into a predefined simple distribution. In this framework, the input image is noised smoothly by a forward SDE until it follows the predefined prior distribution. The forward SDE itself has nothing to learn; a time-dependent neural network is trained to estimate the score of the perturbed data distribution at every point along this transition. A reverse-time SDE driven by that estimated score can then smoothly remove the noise from a sample of the prior distribution and generate an image.
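As a rough illustration of the forward noising process, the sketch below simulates a variance-exploding (VE) SDE on a single scalar with the Euler-Maruyama method, using only the Python standard library. The noise range (0.01 to 10), the step count, and all function names are assumptions chosen for this toy example and are not the repository's settings.

```python
import math
import random

# Illustrative noise range; the repository's config uses its own values.
SIGMA_MIN, SIGMA_MAX = 0.01, 10.0

def sigma(t):
    """Noise scale at time t in [0, 1], growing geometrically."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def diffusion(t):
    """Diffusion coefficient g(t) of the VE SDE dx = g(t) dw."""
    return sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))

def forward_ve_sde(x0, rng, n_steps=500):
    """Euler-Maruyama simulation of the forward SDE from t=0 to t=1."""
    dt = 1.0 / n_steps
    x = x0
    for i in range(n_steps):
        x += diffusion(i * dt) * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x

rng = random.Random(0)
x0 = 0.3  # a single scalar standing in for a data point
noisy = [forward_ve_sde(x0, rng) for _ in range(2000)]
noisy_mean = sum(noisy) / len(noisy)
noisy_std = math.sqrt(sum((x - noisy_mean) ** 2 for x in noisy) / len(noisy))
print(f"mean={noisy_mean:.2f}, std={noisy_std:.2f}, sigma_max={SIGMA_MAX}")
```

At t=1 the samples are spread with a standard deviation close to sigma_max, so the starting value is almost entirely buried in noise. This is the sense in which the noisy data follows a predefined prior.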

The major advantage of formulating the noising process as an SDE is that the forward SDE has no trainable, data-dependent parameters, and its reverse-time counterpart depends on the data only through the time-dependent score. Therefore, the generative process of the reverse-time SDE needs just one time-dependent neural network that estimates the score at any intermediate time, trained on samples produced by the forward SDE. This approach can be applied to several generative domains, including images, audio, shapes, and graphs.
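The role of the time-dependent score can be sketched on a toy problem where the score is known in closed form, so no neural network is needed. Below, the data distribution is assumed to be a standard Gaussian, the exact score of the perturbed distribution stands in for the trained model, and the reverse-time VE SDE is integrated with Euler-Maruyama. All constants and names are hypothetical.

```python
import math
import random

SIGMA_MIN, SIGMA_MAX = 0.01, 10.0   # illustrative noise range
DATA_STD = 1.0                      # toy data distribution: N(0, 1)
LOG_RATIO = math.log(SIGMA_MAX / SIGMA_MIN)

def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def marginal_var(t):
    """Variance of the perturbed data at time t under the VE SDE."""
    return DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2

def score(x, t):
    """Exact score of the perturbed Gaussian; stands in for the trained network."""
    return -x / marginal_var(t)

def reverse_sde_sample(rng, n_steps=500, eps=1e-5):
    """Integrate the reverse-time SDE from t=1 down to t=eps."""
    x = rng.gauss(0.0, math.sqrt(marginal_var(1.0)))  # draw from the prior
    dt = (1.0 - eps) / n_steps
    for i in range(n_steps):
        t = 1.0 - i * dt
        g2 = 2.0 * LOG_RATIO * sigma(t) ** 2          # g(t)^2
        x += g2 * score(x, t) * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0)
    return x

rng = random.Random(0)
generated = [reverse_sde_sample(rng) for _ in range(2000)]
gen_mean = sum(generated) / len(generated)
gen_std = math.sqrt(sum((x - gen_mean) ** 2 for x in generated) / len(generated))
print(f"generated std = {gen_std:.3f} (data std = {DATA_STD})")
```

The only data-dependent ingredient in the loop is the call to score(x, t). In the real model that call is the neural network.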

An overview of forward and reverse-time SDEs in image generation (Source).

This framework introduces two samplers, while remaining flexible enough to use any general-purpose SDE solver on its reverse-time SDE. The first is the Predictor-Corrector (PC) sampler, which combines a numerical SDE solver (the predictor) with score-based MCMC steps such as Langevin dynamics (the corrector). The second is a deterministic sampler based on the probability flow ODE. Since the reverse-time SDE can be adapted to a conditioning signal using the unconditional score, a single unconditionally trained score model can be applied to inverse problems such as image inpainting and colorization without re-training or fine-tuning; class-conditional generation additionally uses a classifier trained on noisy data. The unified framework also covers earlier score-based models as special cases, and the authors report that their architectural and sampling improvements produce better results than the original versions of those models. Moreover, the authors demonstrate high-fidelity generation of 1024 × 1024 images, which they report as a first for score-based generative models.
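The Predictor-Corrector idea can be sketched on the same kind of toy problem: a Langevin corrector step followed by a reverse-SDE predictor step at every time mark. The analytic score of a Gaussian stands in for the trained network, and the step-size rule based on a signal-to-noise ratio (snr) follows my reading of the repository's Langevin corrector for the VE case, so treat it as an assumption. All names and constants are hypothetical.

```python
import math
import random

SIGMA_MIN, SIGMA_MAX, DATA_STD = 0.01, 10.0, 1.0  # toy values
LOG_RATIO = math.log(SIGMA_MAX / SIGMA_MIN)

def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def score(x, t):
    """Exact score of the perturbed toy data; stands in for the trained network."""
    return -x / (DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2)

def corrector(xs, t, rng, snr, n_steps):
    """Langevin MCMC steps at a fixed time t, with a step size set from snr."""
    for _ in range(n_steps):
        grads = [score(x, t) for x in xs]
        noise = [rng.gauss(0.0, 1.0) for _ in xs]
        grad_norm = sum(abs(g) for g in grads) / len(xs)
        noise_norm = sum(abs(z) for z in noise) / len(xs)
        step = 2.0 * (snr * noise_norm / grad_norm) ** 2
        xs = [x + step * g + math.sqrt(2.0 * step) * z
              for x, g, z in zip(xs, grads, noise)]
    return xs

def predictor(xs, t, dt, rng):
    """One Euler-Maruyama step of the reverse-time VE SDE."""
    g2 = 2.0 * LOG_RATIO * sigma(t) ** 2
    return [x + g2 * score(x, t) * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0)
            for x in xs]

def pc_sample(n_samples=1000, n_steps=300, snr=0.16, corrector_steps=1,
              eps=1e-5, seed=0):
    rng = random.Random(seed)
    prior_std = math.sqrt(DATA_STD ** 2 + SIGMA_MAX ** 2 - SIGMA_MIN ** 2)
    xs = [rng.gauss(0.0, prior_std) for _ in range(n_samples)]
    dt = (1.0 - eps) / n_steps
    for i in range(n_steps):
        t = 1.0 - i * dt
        xs = corrector(xs, t, rng, snr, corrector_steps)
        xs = predictor(xs, t, dt, rng)
    return xs

pc_samples = pc_sample()
pc_mean = sum(pc_samples) / len(pc_samples)
pc_std = math.sqrt(sum((x - pc_mean) ** 2 for x in pc_samples) / len(pc_samples))
print(f"PC sample std = {pc_std:.3f} (data std = {DATA_STD})")
```

The predictor moves the samples to the next time mark, while the corrector nudges them toward the correct distribution at the current one. The snr and corrector_steps arguments play the same role as snr and n_steps in the sampling code later in this article.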

The forward SDE maps the data distribution to a prior (predefined noise) distribution, and the reverse-time SDE maps the prior back to the data distribution (Source).

The score-based SDE approach is flexible: different models can be plugged in and hyperparameters tuned. By varying the error tolerance of the numerical solver in the probability flow ODE sampler, the number of score function evaluations (NFE) can be varied widely. The authors report that the visual quality of the generated images is largely preserved even at much lower NFE. Thus, the probability flow ODE sampler offers a way to trade a small amount of accuracy for considerably faster sampling.
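The trade-off between NFE and accuracy can be illustrated with a toy probability flow ODE. The sketch below swaps in a fixed-step Runge-Kutta (RK4) integrator, where the step count is the knob that sets the NFE; an adaptive solver would expose an error tolerance instead. The analytic Gaussian score again stands in for the trained model, and every name and constant is an assumption made for this example.

```python
import math

SIGMA_MIN, SIGMA_MAX, DATA_STD = 0.01, 10.0, 1.0  # toy values
LOG_RATIO = math.log(SIGMA_MAX / SIGMA_MIN)

def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def marginal_var(t):
    return DATA_STD ** 2 + sigma(t) ** 2 - SIGMA_MIN ** 2

def score(x, t):
    """Exact score of the perturbed toy data; stands in for the trained network."""
    return -x / marginal_var(t)

def ode_drift(x, t):
    """Probability flow ODE of the VE SDE: dx/dt = -0.5 * g(t)^2 * score(x, t)."""
    g2 = 2.0 * LOG_RATIO * sigma(t) ** 2
    return -0.5 * g2 * score(x, t)

def probability_flow_rk4(x, n_steps, eps=1e-5):
    """Integrate deterministically from t=1 down to t=eps; returns (x, nfe)."""
    h = -(1.0 - eps) / n_steps  # negative step: time runs backward
    t, nfe = 1.0, 0
    for _ in range(n_steps):
        k1 = ode_drift(x, t)
        k2 = ode_drift(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = ode_drift(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = ode_drift(x + h * k3, t + h)
        x += h * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0
        t += h
        nfe += 4  # four score evaluations per RK4 step
    return x, nfe

x_T = 5.0  # a fixed sample from the prior
# For this Gaussian toy problem the ODE has a closed-form solution.
exact = x_T * math.sqrt(marginal_var(1e-5) / marginal_var(1.0))
results = []
for n_steps in (5, 20, 100):
    x_0, nfe = probability_flow_rk4(x_T, n_steps)
    results.append((nfe, abs(x_0 - exact)))
    print(f"NFE={nfe:4d}  |error|={abs(x_0 - exact):.2e}")
```

Because the ODE is deterministic, the same prior sample always maps to the same output, and fewer, larger steps lower the NFE at the cost of integration error.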

Sampling quality against the evaluation number in the Probability Flow ODE sampler (Source).

Python Implementation of Score-based SDE

The Score-SDE code requires a PyTorch environment and a CUDA GPU runtime. Most of the code below follows the official Score-based SDE notebook. Download the source code from the official repository.

!git clone https://github.com/yang-song/score_sde_pytorch.git

Change into the repository directory before installing the dependencies and running the source code.

 %cd score_sde_pytorch/
 !ls -p 

Install dependencies and other requirements with pip command as shown below.

 # install dependencies
 !pip install -r requirements.txt 

Download the pre-trained model checkpoint (around 1 GB) from the official storage as shown below.

!gdown --id 1JInV8bPGy18QiIzZcS1iECGHCuXL6_Nz

Create a directory exp/ve/cifar10_ncsnpp_continuous/ to move the downloaded checkpoint (further processing expects this path).

 # the later steps expect this path: /content/score_sde_pytorch/exp/ve/cifar10_ncsnpp_continuous
 %cd /content/score_sde_pytorch/
 # -p creates the intermediate directories in one call
 !mkdir -p exp/ve/cifar10_ncsnpp_continuous/

Move the downloaded checkpoint to the newly created directory using the following command.

 # move checkpoint to the newly created directory
 %cd /content/score_sde_pytorch/
 !mv checkpoint_24.pth /content/score_sde_pytorch/exp/ve/cifar10_ncsnpp_continuous/ 
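Outside a notebook, the directory creation and the move can be done in plain Python. The sketch below uses pathlib and shutil from the standard library. The helper name place_checkpoint is hypothetical, and the demonstration runs in a temporary directory with a placeholder file instead of the real checkpoint.

```python
import shutil
import tempfile
from pathlib import Path

def place_checkpoint(checkpoint, repo_root, subdir="exp/ve/cifar10_ncsnpp_continuous"):
    """Create repo_root/subdir if needed and move the checkpoint file into it."""
    target_dir = Path(repo_root) / subdir
    target_dir.mkdir(parents=True, exist_ok=True)  # same effect as mkdir -p
    destination = target_dir / Path(checkpoint).name
    shutil.move(str(checkpoint), str(destination))
    return destination

# Demonstration with a dummy file in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    dummy = Path(tmp) / "checkpoint_24.pth"
    dummy.write_bytes(b"placeholder")
    moved = place_checkpoint(dummy, tmp)
    moved_exists = moved.exists()
    relative_path = moved.relative_to(tmp).as_posix()

print(relative_path)
```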

Check the directory for file availability.

 %cd /content/score_sde_pytorch/exp/ve/cifar10_ncsnpp_continuous/
 !ls  

Set up the environment by importing the necessary libraries and modules.

 %load_ext autoreload
 %autoreload 2
 from dataclasses import dataclass, field
 import matplotlib.pyplot as plt
 import io
 import csv
 import numpy as np
 import pandas as pd
 import seaborn as sns
 import matplotlib
 import importlib
 import os
 import functools
 import itertools
 import torch
 from losses import get_optimizer
 from models.ema import ExponentialMovingAverage
 import torch.nn as nn
 import tensorflow as tf
 import tensorflow_datasets as tfds
 import tensorflow_gan as tfgan
 import tqdm
 import likelihood
 import controllable_generation
 from utils import restore_checkpoint
 # one call, so the second sns.set does not reset font_scale to its default
 sns.set(font_scale=2, style="whitegrid")
 import models
 from models import utils as mutils
 from models import ncsnv2
 from models import ncsnpp
 from models import ddpm as ddpm_model
 from models import layerspp
 from models import layers
 from models import normalization
 import sampling
 from likelihood import get_likelihood_fn
 from sde_lib import VESDE, VPSDE, subVPSDE
 from sampling import (ReverseDiffusionPredictor, 
                       LangevinCorrector, 
                       EulerMaruyamaPredictor, 
                       AncestralSamplingPredictor, 
                       NoneCorrector, 
                       NonePredictor,
                       AnnealedLangevinDynamics)
 import datasets 

Load the configuration and set up the variance-exploding (VE) SDE using the following code.

 %cd /content/score_sde_pytorch/
 from configs.ve import cifar10_ncsnpp_continuous as configs
 ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
 config = configs.get_config()  
 sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
 sampling_eps = 1e-5 
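The three arguments passed to VESDE describe a noise schedule. As a hedged sketch of what such a schedule looks like, the code below builds a geometric sequence of noise scales between sigma_min and sigma_max using only the standard library. The values 0.01, 50 and 1000 are assumed here for illustration; check config.model for the values actually used, and note that the repository builds its own sequences internally.

```python
import math

def geometric_sigmas(sigma_min, sigma_max, num_scales):
    """Noise scales equally spaced in log-space from sigma_min to sigma_max."""
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    step = (log_max - log_min) / (num_scales - 1)
    return [math.exp(log_min + i * step) for i in range(num_scales)]

# Assumed values for illustration only.
sigmas_demo = geometric_sigmas(0.01, 50.0, 1000)
ratio = sigmas_demo[1] / sigmas_demo[0]
print(f"first={sigmas_demo[0]:.4f}, last={sigmas_demo[-1]:.2f}, ratio={ratio:.5f}")
```

Each scale is a constant factor larger than the previous one, so small noise levels are sampled as densely as large ones. In the continuous-time limit this becomes sigma(t) = sigma_min * (sigma_max / sigma_min) ** t for t in [0, 1].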

Build a score-based generative model and restore the downloaded checkpoint.

 batch_size = 64
 config.training.batch_size = batch_size
 config.eval.batch_size = batch_size
 random_seed = 0
 sigmas = mutils.get_sigmas(config)
 # scaler / inverse_scaler map images to and from the range the model expects
 scaler = datasets.get_data_scaler(config)
 inverse_scaler = datasets.get_data_inverse_scaler(config)
 score_model = mutils.create_model(config)
 optimizer = get_optimizer(config, score_model.parameters())
 ema = ExponentialMovingAverage(score_model.parameters(),
                                decay=config.model.ema_rate)
 # restore_checkpoint fills this state dict from the checkpoint file
 state = dict(step=0, optimizer=optimizer,
              model=score_model, ema=ema)
 state = restore_checkpoint(ckpt_filename, state, config.device)
 # sample with the exponential-moving-average weights
 ema.copy_to(score_model.parameters())
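The final call above copies exponential-moving-average (EMA) weights into the model. As a rough sketch of the idea only, the code below applies the plain EMA update rule to a list of numbers. The repository's ExponentialMovingAverage class works on PyTorch parameters and may include details that are not reproduced here; the function name ema_update is hypothetical.

```python
def ema_update(shadow, params, decay):
    """One EMA step: move each shadow value a fraction (1 - decay) toward the parameter."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]

shadow = [0.0, 0.0]
current_params = [1.0, 2.0]  # stand-in for the model weights after a training step
for _ in range(3):
    shadow = ema_update(shadow, current_params, decay=0.5)

print(shadow)
```

A decay close to 1 makes the shadow weights change slowly, which smooths out the noise of individual training steps. Sampling from the averaged weights is a common choice for that reason.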

Define helper functions to display images for the generative examples that follow.

 def image_grid(x):
   size = config.data.image_size
   channels = config.data.num_channels
   img = x.reshape(-1, size, size, channels)
   w = int(np.sqrt(img.shape[0]))
   img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))
   return img 
 def show_samples(x):
   x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
   img = image_grid(x)
   plt.figure(figsize=(8,8))
   plt.axis('off')
   plt.imshow(img)
   plt.show() 
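To see what the reshape and transpose in image_grid do, the sketch below applies the same steps to four tiny 2×2 single-channel images, each filled with its own index. The function tile_images is a hypothetical variant that takes size and channels as arguments instead of reading them from config, and it assumes NumPy is installed.

```python
import numpy as np

def tile_images(x, size, channels):
    """Same reshaping as image_grid above, with size and channels passed explicitly."""
    img = x.reshape(-1, size, size, channels)
    w = int(np.sqrt(img.shape[0]))  # images per row and per column
    return (img.reshape((w, w, size, size, channels))
               .transpose((0, 2, 1, 3, 4))   # interleave grid rows with pixel rows
               .reshape((w * size, w * size, channels)))

# Four 2x2 images; image k is filled with the value k.
batch = np.stack([np.full((2, 2, 1), k, dtype=np.float32) for k in range(4)])
grid = tile_images(batch, size=2, channels=1)
print(grid[..., 0])
```

Images are laid out row by row, so images 0 and 1 form the top row of the grid and images 2 and 3 the bottom row. The reshape only works when the batch size is a perfect square, which holds for the batch size of 64 used here.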

Choose a predictor (ReverseDiffusionPredictor) and a corrector (LangevinCorrector) to perform Predictor-Corrector (PC) sampling.

 img_size = config.data.image_size
 channels = config.data.num_channels
 shape = (batch_size, channels, img_size, img_size)
 predictor = ReverseDiffusionPredictor 
 corrector = LangevinCorrector 
 snr = 0.16 
 n_steps =  1 
 probability_flow = False 
 sampling_fn = sampling.get_pc_sampler(sde, shape, predictor, corrector,
                                       inverse_scaler, snr, n_steps=n_steps,
                                       probability_flow=probability_flow,
                                       continuous=config.training.continuous,
                                       eps=sampling_eps, device=config.device)
 x, n = sampling_fn(score_model)
 show_samples(x) 

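Use the probability flow ODE sampler to generate images deterministically from prior samples. The same ODE also gives exact likelihoods and a uniquely identifiable latent representation of each image, which the following steps use.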

 shape = (batch_size, 3, 32, 32)
 sampling_fn = sampling.get_ode_sampler(sde, shape, inverse_scaler, denoise=True, eps=sampling_eps, device=config.device)
 x, nfe = sampling_fn(score_model)
 show_samples(x) 

Compute the likelihood, reported in bits per dimension (bpd), for each batch of images in the evaluation dataset.

 train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=True, evaluation=True)
 eval_iter = iter(eval_ds)
 bpds = []
 likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler, eps=1e-5)
 for batch in eval_iter:
   img = batch['image']._numpy()
   img = torch.tensor(img).permute(0, 3, 1, 2).to(config.device)
   img = scaler(img)
   bpd, z, nfe = likelihood_fn(score_model, img)
   bpds.extend(bpd)
   print(f"average bpd: {torch.tensor(bpds).mean().item()}, NFE: {nfe}") 

Reconstruct the original images using the embedded representations of data.

 train_ds, eval_ds, _ = datasets.get_dataset(config, uniform_dequantization=False, evaluation=True)
 eval_batch = next(iter(eval_ds))
 eval_images = eval_batch['image']._numpy()
 shape = (batch_size, 3, 32, 32)
 likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler, eps=1e-5)
 sampling_fn = sampling.get_ode_sampler(sde, shape, inverse_scaler, denoise=True, eps=sampling_eps, device=config.device)
 plt.figure(figsize=(18, 6))
 plt.subplot(1, 2, 1)
 plt.axis('off')
 plt.imshow(image_grid(eval_images))
 plt.title('Original images')
 eval_images = torch.from_numpy(eval_images).permute(0, 3, 1, 2).to(config.device)
 _, latent_z, _ = likelihood_fn(score_model, scaler(eval_images))
 x, nfe = sampling_fn(score_model, latent_z)
 x = x.permute(0, 2, 3, 1).cpu().numpy()
 plt.subplot(1, 2, 2)
 plt.axis('off')
 plt.imshow(image_grid(x))
 plt.title('Reconstructed images') 

We can see that the probability flow ODE reconstructs the images from their latent representations with high visual fidelity.

Now, we can use the restored model for controllable generation. First, we perform image inpainting: the model generates the image regions that are masked out in the original images.

 train_ds, eval_ds, _ = datasets.get_dataset(config)
 eval_iter = iter(eval_ds)
 predictor = ReverseDiffusionPredictor
 corrector = LangevinCorrector
 snr = 0.16
 n_steps = 1
 probability_flow = False
 pc_inpainter = controllable_generation.get_pc_inpainter(
     sde, predictor, corrector, inverse_scaler,
     snr=snr, n_steps=n_steps, probability_flow=probability_flow,
     continuous=config.training.continuous, denoise=True
 )
 batch = next(eval_iter)
 img = batch['image']._numpy()
 img = torch.from_numpy(img).permute(0, 3, 1, 2).to(config.device)
 show_samples(img)
 mask = torch.ones_like(img)
 mask[:, :, :, 16:] = 0.
 show_samples(img * mask)
 x = pc_inpainter(score_model, scaler(img), mask)
 show_samples(x) 

Output (three image grids): original images, masked images, and inpainted images.

In the output, most of the masked regions are filled in with content close to the originals; an inpainted region is a sample from the model and need not match the original pixels exactly. Secondly, we apply the model to colourization tasks. The model can generate coloured images from grayscale images.

 train_ds, eval_ds, _ = datasets.get_dataset(config)
 eval_iter = iter(eval_ds)
 predictor = ReverseDiffusionPredictor 
 corrector = LangevinCorrector 
 snr = 0.16 
 n_steps = 1
 probability_flow = False
 batch = next(eval_iter)
 img = batch['image']._numpy()
 img = torch.from_numpy(img).permute(0, 3, 1, 2).to(config.device)
 show_samples(img)
 gray_scale_img = torch.mean(img, dim=1, keepdims=True).repeat(1, 3, 1, 1)
 show_samples(gray_scale_img)
 gray_scale_img = scaler(gray_scale_img)
 pc_colorizer = controllable_generation.get_pc_colorizer(
     sde, predictor, corrector, inverse_scaler,
     snr=snr, n_steps=n_steps, probability_flow=probability_flow,
     continuous=config.training.continuous, denoise=True
 )
 x = pc_colorizer(score_model, gray_scale_img)
 show_samples(x) 

Output (three image grids): original images, grayscale images, and colourized (generated) images.

Most colours are restored close to the originals, while a few differ from them.

Wrapping Up Score-based SDE

Score-based generative modeling through Stochastic Differential Equations is a general framework that, while demonstrated here on images, is not restricted to image generation. With a PC sampler and a probability flow ODE sampler, score-based SDE models offer a choice between higher sample quality and faster sampling. The authors report an Inception Score of 9.89 and an FID of 2.20 for unconditional image generation on the CIFAR-10 dataset. At the time of publication, these results made the score-based SDE approach state-of-the-art on that benchmark, and the same framework supports class-conditional image generation, image inpainting, image colourization, and high-fidelity, high-resolution image generation.

A few of the high-fidelity, high-resolution images generated from the celebrity dataset.
