How to generate images from text using DALL-E Mini?



Zero-shot text-to-image generation works by combining observed and unobserved categories of text descriptions through auxiliary information that encodes observable, distinguishing properties of objects. In this article, we will use DALL-E Mini to implement zero-shot text-to-image generation and generate images for a given text string. The following topics are covered in this article.

Table of contents

  1. What is Zero-shot Learning?
  2. How does Zero-shot Text-to-Image generation work?
  3. Generating image using DALL-E Mini


Let’s start with understanding Zero-shot Learning (ZSL).

What is Zero-shot Learning?

The goal of zero-shot learning (ZSL) is to learn intermediate semantic layers and their properties so that, at inference time, the model can predict classes it never saw during training. For example, a model trained to distinguish between images of cats and dogs could also identify images of birds. In this setting, the classes covered during training are known as the “seen” classes, while classes for which no labelled training instances are available are called the “unseen” classes.

Knowledge about seen classes can be transferred to unseen classes through a high-dimensional vector space called the semantic space. Using the semantic space together with a visual feature representation of the image content, ZSL can be solved by projecting the visual feature vector and the class prototypes into a joint embedding space. The projected image feature vector is then matched to an unseen class using nearest-neighbour search (NNS).
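A minimal NumPy sketch of the matching step, assuming the image has already been projected into the semantic space. The helper name nearest_class, the three hand-made attributes and all vectors are hypothetical; real systems learn the projection and use much higher-dimensional embeddings.

```python
import numpy as np

def nearest_class(embedding, prototypes, class_names):
    """Return the class whose prototype is closest, by cosine similarity,
    to a projected image embedding (hypothetical helper for illustration)."""
    # Normalise both sides so that a dot product equals cosine similarity
    e = embedding / np.linalg.norm(embedding)
    p = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    similarities = p @ e
    return class_names[int(np.argmax(similarities))]

# Toy semantic space with three attributes: [has_wings, has_fur, barks]
prototypes = np.array([
    [0.0, 1.0, 0.0],  # cat (seen)
    [0.0, 1.0, 1.0],  # dog (seen)
    [1.0, 0.0, 0.0],  # bird (unseen: described only by its attributes)
])
names = ["cat", "dog", "bird"]

# Embedding of a bird image after projection into the semantic space
projected_image = np.array([0.9, 0.1, 0.0])
print(nearest_class(projected_image, prototypes, names))  # → bird
```

The bird image is matched to the bird prototype even though no bird images were needed, only the attribute description of the class.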

Zero-shot learning is a subfield of transfer learning, since it transfers the knowledge obtained from the training instances to the classification of test instances.

There are three main components of ZSL on which the whole learning process depends.

  1. Seen classes: the data classes used to train the deep learning model.
  2. Unseen classes: the data classes, used for validation, to which the trained model needs to generalize.
  3. Auxiliary information: since there is no prior knowledge about the unseen classes, some auxiliary information is needed to solve the ZSL problem. It contains descriptions, semantic information, or word embeddings related to the unseen classes.
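To make the three components concrete, here is a minimal sketch of a classic attribute-based ZSL setup in NumPy. It is an illustrative assumption, not how DALL-E works internally: a linear projection from visual features to the semantic space is learned by ridge regression on the seen classes only, and unseen classes are then recognized by nearest-prototype search using their attribute vectors (the auxiliary information). All data is synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
n_attr, n_feat, n_per_class = 5, 20, 30

# Auxiliary information: one attribute vector per class (synthetic, hypothetical)
seen_attrs = rng.normal(size=(6, n_attr))    # 6 seen classes
unseen_attrs = rng.normal(size=(2, n_attr))  # 2 unseen classes

# Hypothetical "true" mapping from attributes to visual features
mixing = rng.normal(size=(n_feat, n_attr))

def sample_features(attrs):
    """Draw noisy visual features and integer labels for each class in attrs."""
    labels = np.repeat(np.arange(len(attrs)), n_per_class)
    feats = attrs[labels] @ mixing.T + 0.05 * rng.normal(size=(len(labels), n_feat))
    return feats, labels

# Training sees only the seen classes: learn W so that features @ W ≈ attributes
X_seen, y_seen = sample_features(seen_attrs)
lam = 1.0  # ridge penalty
W = np.linalg.solve(X_seen.T @ X_seen + lam * np.eye(n_feat),
                    X_seen.T @ seen_attrs[y_seen])

# At test time, project images of unseen classes and match them to unseen prototypes
X_unseen, y_unseen = sample_features(unseen_attrs)
projected = X_unseen @ W
dists = np.linalg.norm(projected[:, None, :] - unseen_attrs[None, :, :], axis=-1)
predictions = dists.argmin(axis=1)
accuracy = (predictions == y_unseen).mean()
print(f"Unseen-class accuracy: {accuracy:.2f}")
```

The model never sees a labelled example of the two unseen classes; it recognizes them only because their attribute vectors live in the same semantic space as those of the seen classes.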

Let’s understand how to use ZSL to synthesize images from text descriptions.


How does Zero-shot Text-to-Image generation work?

The overall training procedure can be viewed as maximizing the evidence lower bound (ELB) on the joint likelihood of the model distribution over images, their text descriptions, and the tokens of the encoded RGB image.

  • The evidence lower bound is a quantity from variational Bayesian inference that bounds the log-likelihood of the observed data from below. The "evidence" is the log-likelihood itself, which is usually intractable to compute directly; maximizing the tractable lower bound instead pushes the log-likelihood up.
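For reference, the standard bound and the form given in the DALL-E paper ("Zero-Shot Text-to-Image Generation", Ramesh et al., 2021) are written out below. In the second line, x is the image, y the caption, z the image tokens, q_φ the dVAE encoder's distribution over image tokens, p_θ the dVAE decoder's distribution over RGB images, p_ψ the transformer's joint distribution over text and image tokens, and β a weight on the KL term (the inequality itself holds for β = 1).

```latex
% Standard ELBO for data x with latent variable z and approximate posterior q(z | x)
\ln p(x) \;\ge\; \mathbb{E}_{z \sim q(z \mid x)}\big[\ln p(x \mid z)\big] - D_{\mathrm{KL}}\big(q(z \mid x) \,\Vert\, p(z)\big)

% Bound maximized by DALL-E over images x, captions y and image tokens z
\ln p_{\theta,\psi}(x, y) \;\ge\; \mathbb{E}_{z \sim q_{\phi}(z \mid x)}\Big[\ln p_{\theta}(x \mid y, z) - \beta\, D_{\mathrm{KL}}\big(q_{\phi}(y, z \mid x) \,\Vert\, p_{\psi}(y, z)\big)\Big]
```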

If pixels were used directly as image tokens, high-resolution images would require an inordinate amount of memory. In addition, the likelihood objective tends to prioritize modelling short-range dependencies between pixels, so much of the modelling capacity would be spent capturing high-frequency details rather than the low-frequency structure that makes objects visually recognizable. To address both problems, training is divided into two stages.

Training Discrete Variational AutoEncoder 

The discrete variational autoencoder (dVAE) compresses each 256 × 256 RGB image into a 32 × 32 grid of image tokens, each of which can take one of 8192 possible values. This reduces the context size of the transformer by a factor of 192 without a large degradation in visual quality, and it speeds up training. The idea is similar to how we can recognize objects from a distance even when we cannot see them clearly: we have learned to identify them from their outlines and overall shading.
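The arithmetic behind this compression can be checked in a few lines of Python. The 256 × 256 input and the 32 × 32 token grid are the figures reported in the DALL-E paper; the attention-memory estimate is an illustrative assumption (a single dense float16 attention matrix for one head in one layer), not a measurement.

```python
# Sequence lengths the transformer would have to model
height, width, channels = 256, 256, 3
pixel_tokens = height * width * channels  # one token per colour value
grid = 32
image_tokens = grid * grid                # one dVAE token per grid cell
reduction = pixel_tokens // image_tokens
print(pixel_tokens, image_tokens, reduction)  # → 196608 1024 192

# Rough memory for one dense attention matrix in float16 (2 bytes per entry)
for n in (pixel_tokens, image_tokens):
    print(f"{n:>7} tokens -> {n * n * 2 / 2**30:.3f} GiB per attention matrix")
```

Because attention cost grows quadratically with sequence length, the 192× shorter sequence makes the attention matrix about 192² ≈ 37,000 times smaller.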

Next, the initial prior is set to the uniform categorical distribution over the K = 8192 codebook vectors, and the encoder's distribution is a categorical distribution parameterized by the 8192 logits that the encoder outputs at the same spatial position in the 32 × 32 grid. The ELB now becomes difficult to optimize: because the encoder's distribution is discrete, the reparameterization gradient cannot be used to maximize it.
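With a uniform prior, the KL term of the ELB has a simple closed form: at each grid position it equals log K minus the entropy of the encoder's categorical distribution. The NumPy sketch below computes it; the random logits stand in for a real encoder output, and the small 4 × 4 grid (instead of 32 × 32) only keeps the example fast.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_to_uniform(logits):
    """KL(q || Uniform(K)) at each grid position, where q = softmax(logits).

    KL(q || u) = sum_k q_k * (log q_k - log(1/K)) = log K - H(q)
    """
    q = softmax(logits)
    K = logits.shape[-1]
    entropy = -(q * np.log(q + 1e-12)).sum(axis=-1)
    return np.log(K) - entropy

K = 8192                               # codebook size used in DALL-E
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, K))    # hypothetical encoder logits, one row per grid cell
kl = kl_to_uniform(logits)
print(kl.shape, float(kl.sum()))       # KL per cell, and summed over the grid
```

A perfectly uniform encoder output gives a KL of 0, while a fully confident (one-hot) output gives the maximum value, log K.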

To get around this, the categorical distribution is relaxed into a continuous one by reparameterizing the sample as a deterministic function of the parameters plus independent noise drawn from a fixed distribution. The Gumbel-softmax trick is well suited to this. Its temperature τ controls the relaxation: as τ approaches 0 the relaxed samples approach one-hot vectors and the relaxation becomes tight, while as τ grows very large they approach the uniform distribution. The relaxed ELB is then maximized using exponentially weighted iterate averaging. The likelihood of the RGB images generated by the dVAE decoder is evaluated using the logit-Laplace distribution.
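The Gumbel-softmax relaxation can be sketched in NumPy as follows. The function name and the toy three-way distribution are illustrative; in the dVAE the same operation would be applied to the encoder's logits at every grid position.

```python
import numpy as np

def gumbel_softmax(logits, temperature, rng):
    """Draw a relaxed (continuous) sample from Categorical(softmax(logits)).

    The randomness is moved into fixed Gumbel(0, 1) noise, so the sample is a
    deterministic, differentiable function of the logits (reparameterization).
    """
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel_noise = -np.log(-np.log(u))
    y = (logits + gumbel_noise) / temperature
    y = y - y.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

logits = np.log(np.array([0.7, 0.2, 0.1]))
for tau in (0.05, 1.0, 100.0):
    sample = gumbel_softmax(logits, tau, np.random.default_rng(1))
    print(tau, np.round(sample, 3))
```

With the same noise, a small τ typically yields a sample close to one-hot and a large τ a sample close to uniform. Taking the argmax of any relaxed sample gives an exact draw from the categorical distribution (the Gumbel-max trick), because dividing by τ and applying softmax do not change which entry is largest.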

In the second stage (prior learning, described below), the cross-entropy losses for the text and image tokens are normalized by the total number of tokens of each kind in a batch. Since the main interest is image modelling, the text cross-entropy loss is multiplied by 1/8 and the image cross-entropy loss by 7/8.
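A minimal NumPy sketch of this weighting, assuming the per-token logits and targets for the text and image parts of a batch are already available as flat arrays. The function names are hypothetical.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Summed token-level cross-entropy; logits: (n_tokens, vocab), targets: (n_tokens,)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].sum()

def combined_loss(text_logits, text_targets, image_logits, image_targets):
    # Normalise each loss by the number of tokens of that kind in the batch
    text_loss = cross_entropy(text_logits, text_targets) / len(text_targets)
    image_loss = cross_entropy(image_logits, image_targets) / len(image_targets)
    # Weight image modelling more heavily: 1/8 for text, 7/8 for image
    return text_loss / 8 + 7 * image_loss / 8

# Uniform predictions give a per-token loss of log(vocabulary size) for each modality
text_logits, text_targets = np.zeros((4, 16384)), np.array([1, 2, 3, 4])
image_logits, image_targets = np.zeros((16, 8192)), np.arange(16)
loss = combined_loss(text_logits, text_targets, image_logits, image_targets)
print(loss)
```

Normalizing per kind keeps the weighting stable even though a sample contains many more image tokens than text tokens.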

Prior Learning

In this stage, the encoded text tokens are concatenated with the image tokens, and an autoregressive transformer is trained to model the joint distribution over them.

The text is lowercased and encoded with byte-pair encoding (BPE) using at most 256 tokens from a vocabulary of 16,384, while the image is represented by 32 × 32 = 1024 tokens from a vocabulary of 8192. BPE assigns frequently occurring words a single token and breaks rarer words into several sub-word tokens. As explained above, the image tokens are obtained from the dVAE encoder logits by taking the most likely code at each grid position. The text and image tokens are then concatenated and modelled autoregressively as a single stream of data.
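The assembly of that single stream can be sketched in NumPy. The sequence lengths and vocabulary sizes follow the figures above; the padding id, the function name build_stream, and the choice to shift image ids past the text vocabulary so both share one id space are hypothetical assumptions for illustration.

```python
import numpy as np

TEXT_LEN, TEXT_VOCAB = 256, 16384  # caption: up to 256 BPE tokens
GRID, IMAGE_VOCAB = 32, 8192       # image: 32 x 32 dVAE tokens
PAD_ID = 0                         # hypothetical padding id

def build_stream(text_ids, encoder_logits):
    """Concatenate caption tokens and image tokens into one sequence.

    text_ids: BPE ids of the lowercased caption.
    encoder_logits: (GRID, GRID, IMAGE_VOCAB) dVAE encoder logits.
    """
    text = np.full(TEXT_LEN, PAD_ID, dtype=np.int64)
    ids = np.asarray(text_ids[:TEXT_LEN], dtype=np.int64)
    text[: len(ids)] = ids
    # Image tokens: most likely code at each grid cell, flattened in raster order
    image = encoder_logits.argmax(axis=-1).reshape(-1)
    # Shift image ids past the text vocabulary so the two never collide (an assumption)
    return np.concatenate([text, image + TEXT_VOCAB])

rng = np.random.default_rng(0)
fake_logits = rng.standard_normal((GRID, GRID, IMAGE_VOCAB), dtype=np.float32)
stream = build_stream([12, 845, 3071], fake_logits)
print(stream.shape)  # → (1280,)
```

The transformer then predicts each position of this 1280-token stream from the positions before it.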

Here, a 12-billion-parameter sparse transformer learns the prior distribution over the text and image tokens by maximizing the ELB. The attention mask in each of its 64 self-attention layers lets every image token attend to all text tokens, and attention among image tokens uses three kinds of sparse attention masks:

  • The convolutional attention mask is only used in the last self-attention layer.
  • The row attention mask
  • The column attention mask

The parts of the attention masks for text-to-text attention use the standard causal mask, and the parts of the image-to-image attention use either a row, column, or convolutional attention mask.
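The row and column patterns can be illustrated with boolean masks in NumPy. This is a simplified sketch: the function name is hypothetical, the convolutional mask is omitted, and the real DALL-E masks differ in detail. It only shows how the causal text mask, full image-to-text attention, and row- or column-restricted image-to-image attention fit together in one matrix.

```python
import numpy as np

def dalle_style_masks(n_text, grid):
    """Return {"row": mask, "column": mask}; mask[i, j] is True if token i may attend to token j.

    Simplified: causal attention everywhere, every image token sees all text tokens,
    and image-to-image attention is limited to earlier tokens in the same grid row
    (row mask) or the same grid column (column mask).
    """
    n_img = grid * grid
    n = n_text + n_img
    causal = np.tril(np.ones((n, n), dtype=bool))  # standard causal mask
    rows = np.arange(n_img) // grid                # grid row of each image token
    cols = np.arange(n_img) % grid                 # grid column of each image token
    patterns = {
        "row": rows[:, None] == rows[None, :],
        "column": cols[:, None] == cols[None, :],
    }
    masks = {}
    for name, same_line in patterns.items():
        mask = causal.copy()
        mask[n_text:, n_text:] &= same_line        # restrict only the image-to-image block
        masks[name] = mask
    return masks

masks = dalle_style_masks(n_text=4, grid=3)
print(masks["row"][4:, 4:].astype(int))  # image-to-image block of the row mask
```

Text tokens never attend to image tokens because the image follows the text in the stream, and the causal mask already blocks attention to later positions.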

Workflow of Transformer (Image Source)

The figure above shows the workflow of a transformer resblock with two passes: forward propagation (solid line) and backpropagation (dashed line). When backpropagation moves from the output layer towards the input layer, the gradients can shrink at every layer and approach zero, leaving the weights of the lower layers almost unchanged, so gradient descent fails to converge to a good solution.

This is known as the vanishing gradient problem. Conversely, the gradients can also grow at every layer during backpropagation, producing very large weight updates that make gradient descent diverge; this is the exploding gradient problem.
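Both effects can be reproduced with a toy linear network in NumPy. This is an illustrative assumption, not DALL-E's architecture: each layer's Jacobian is a random orthogonal matrix multiplied by a constant factor, so after L layers the gradient norm is exactly factor^L times its starting value.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, depth = 64, 50

def backprop_norm(scale):
    """Norm of a unit gradient after flowing back through depth linear layers.

    Each layer's Jacobian is an orthogonal matrix times scale, so every layer
    multiplies the gradient norm by exactly scale.
    """
    grad = rng.normal(size=dim)
    grad /= np.linalg.norm(grad)  # start from a unit-norm gradient
    for _layer in range(depth):
        q, _r = np.linalg.qr(rng.normal(size=(dim, dim)))  # random orthogonal matrix
        grad = (scale * q).T @ grad  # chain rule: multiply by the transposed Jacobian
    return np.linalg.norm(grad)

for scale in (0.5, 1.0, 1.5):
    print(f"per-layer factor {scale}: gradient norm after {depth} layers = {backprop_norm(scale):.3e}")
```

A per-layer factor of 0.5 shrinks the gradient to about 9e-16 after 50 layers (vanishing), while 1.5 inflates it to about 6e8 (exploding).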

Residual block mechanism (Image source)

To deal with these problems, residual blocks (ResBlocks) are used. Unlike a plain stack of layers, a residual block passes its input forward through a skip connection that bypasses two or three layers and is added to their output. In the figure above, the identity ‘x’ skips the intermediate layers and rejoins the path after them.
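The effect of the identity path on gradients can be shown with a small NumPy experiment. It is a sketch under stated assumptions: random two-layer ReLU blocks with deliberately small weights, compared with and without the skip connection. The function name and weight scale are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, depth = 64, 50

def grad_norm_through_stack(residual, weight_scale=0.01):
    """Back-propagate a unit gradient through depth two-layer blocks.

    Plain block:    y = W2 @ relu(W1 @ x)
    Residual block: y = x + W2 @ relu(W1 @ x)   (the identity path skips both layers)
    """
    x = rng.normal(size=dim)
    caches = []
    for _layer in range(depth):  # forward pass, storing what backprop needs
        w1 = weight_scale * rng.normal(size=(dim, dim))
        w2 = weight_scale * rng.normal(size=(dim, dim))
        h = w1 @ x
        y = w2 @ np.maximum(h, 0.0)
        caches.append((w1, w2, h > 0))
        x = x + y if residual else y
    grad = np.ones(dim) / np.sqrt(dim)  # unit-norm gradient at the output
    for w1, w2, active in reversed(caches):  # backward pass (chain rule)
        through_layers = w1.T @ (active * (w2.T @ grad))
        grad = grad + through_layers if residual else through_layers
    return np.linalg.norm(grad)

print(f"plain:    {grad_norm_through_stack(residual=False):.3e}")
print(f"residual: {grad_norm_through_stack(residual=True):.3e}")
```

Without the skip connection the gradient all but disappears after 50 blocks; with it, the identity term carries the gradient back essentially unchanged, because the derivative of x + F(x) is the identity plus the derivative of F.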

In DALL-E's 16-bit mixed-precision training, the activations and gradients along the identity path are stored in 32-bit precision, and each resblock uses its own gradient scale. The “filter” operation sets all infinite and NaN values in the activation gradient to zero. Without it, a non-finite value in the current resblock would cause the gradient scales of all preceding resblocks to drop unnecessarily, resulting in underflow.
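The two ideas in this paragraph, scaling gradients to avoid 16-bit underflow and filtering non-finite values, can be illustrated in NumPy. This is a simplified sketch: a single fixed scale of 2^16 and a standalone helper stand in for DALL-E's per-resblock dynamic scales, and the helper name is hypothetical.

```python
import numpy as np

def filter_nonfinite(grad):
    """Replace Inf and NaN entries of a gradient with zero (the 'filter' step)."""
    return np.where(np.isfinite(grad), grad, 0.0)

# Small gradients underflow to zero in 16-bit floats unless they are scaled up first
tiny_grad = np.array([1e-8, 3e-9, -2e-8], dtype=np.float32)
print(tiny_grad.astype(np.float16))              # every entry underflows to zero
scale = 2.0 ** 16
scaled = (tiny_grad * scale).astype(np.float16)  # survives in float16
recovered = scaled.astype(np.float32) / scale    # unscale in 32-bit precision
print(recovered)

# A non-finite value in one block's gradient is zeroed instead of being propagated
grad = np.array([0.5, np.inf, -1.25, np.nan], dtype=np.float32)
print(filter_nonfinite(grad))
```

Filtering keeps a single overflow in one block from forcing every earlier block to lower its scale, which is exactly what would push their small gradients back into the underflow range.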

Generating image using DALL-E Mini

We will use DALL-E Mini, a pre-trained model available through Hugging Face, to generate images from text descriptions in a zero-shot setting.

Installing required libraries

!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
!pip install -q git+https://github.com/borisdayma/dalle-mini.git

Loading DALL-E mini model and tokenizer

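from flax.jax_utils import replicate
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# The DALL-E Mini checkpoint is a Weights & Biases artifact (requires a wandb API key)
DALLE_MODEL = "dalle-mini/dalle-mini/wzoooa1c:latest"
DALLE_COMMIT_ID = None

VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

CLIP_REPO = "openai/clip-vit-large-patch14"
CLIP_COMMIT_ID = None

# DalleBart maps text to image tokens; its processor tokenizes the prompts
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
# VQGAN decodes image tokens into pixels; CLIP scores how well an image matches the text
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)

# Copy the DALL-E Mini and VQGAN weights to every device so they can be used under jax.pmap
model._params = replicate(model.params)
vqgan._params = replicate(vqgan.params)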

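The DalleBart model generates image tokens from the tokenized description, the VQGAN model decodes those tokens into images, and the CLIP model scores how well each generated image matches the description. The DALL-E Mini checkpoint is downloaded through the Weights & Biases (wandb) API, so you will need a wandb API key, which can easily be generated from your wandb account. The code follows the dalle-mini API available at the time of writing; later releases changed how the model and its parameters are loaded, so you may need to install a matching older version.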

Defining image description

We use two image descriptions, one more complex and one less complex, to see how the model's performance changes.

from flax.jax_utils import replicate

# More complex description
descript = "Dog wearing clothes"
tokenized_prompt = processor([descript])
tokenized_prompt = replicate(tokenized_prompt)  # one copy per device for pmap

# Less complex description
descript_1 = "Earth from space"
tokenized_prompt_1 = processor([descript_1])
tokenized_prompt_1 = replicate(tokenized_prompt_1)

Generating images

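import random
from functools import partial

import jax
import numpy as np
from flax.training.common_utils import shard_prng_key
from PIL import Image
from tqdm.notebook import trange

# Generate image tokens on every device in parallel (arguments 3-6 are static Python values)
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# Decode image tokens into RGB pixels with VQGAN on every device
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

key = jax.random.PRNGKey(random.randint(0, 2**32 - 1))  # random seed for sampling

n_predictions = 10  # total number of images to generate
gen_top_k = None    # None disables top-k filtering
gen_top_p = None    # None disables nucleus (top-p) filtering
temperature = 0.85  # sampling temperature
cond_scale = 3.0    # conditioning scale; higher values follow the prompt more closely

images = []
for _ in trange(max(n_predictions // jax.device_count(), 1)):
    key, subkey = jax.random.split(key)
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        model.params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # Drop the beginning-of-sequence token, keeping only the image tokens
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for img in decoded_images:
        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))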

Displaying the images with their scores

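# Score each generated image against the prompt with CLIP (higher means a closer match)
clip_inputs = clip_processor(text=[descript], images=images, return_tensors="np",
                             padding="max_length", max_length=77, truncation=True)
logits = np.asarray(clip(**clip_inputs).logits_per_image).squeeze(-1)
print(f"Prompt: {descript}\n")
for idx in logits.argsort()[::-1]:  # best-scoring image first
    display(images[idx])
    print(f"Score: {logits[idx]:.2f}\n")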

Here are two of the ten images generated for this prompt; they resemble the text description fairly closely.

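Let’s see the results for the other, less complex description. To reproduce them, rerun the generation and scoring steps with tokenized_prompt_1 and descript_1 in place of tokenized_prompt and descript.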

In this example, the less complex description produced better images than the more complex one.

Final words

Zero-shot learning has proved to be an effective approach to text-to-image generation: a model trained on image-text pairs can synthesize plausible images for descriptions it was never explicitly trained on. This article explained ZSL, how it is used to synthesize images from text descriptions, and how to put the theory into practice in Python with DALL-E Mini. You can experiment further with the code in this article.



Sourabh Mehta
Sourabh has worked as a full-time data scientist for an ISP organisation, experienced in analysing patterns and their implementation in product development. He has a keen interest in developing solutions for real-time problems with the help of data both in this universe and metaverse.
