Guide To Video GPT: A Transformer-Based Architecture For Video Generation


Video GPT

Video GPT is a novel machine learning architecture that uses likelihood-based generative modelling for video synthesis. It was recently introduced by Wilson Yan, Yunzhi Zhang, Pieter Abbeel and Aravind Srinivas in the research paper "VideoGPT: Video Generation using VQ-VAE and Transformers".

Before going into the detailed workings of Video GPT, let us take a quick look at some background terminology.

  • Autoencoder is an artificial neural network trained in an unsupervised manner. It learns to compress (encode) input data into a lower-dimensional representation, discarding noise and other unimportant detail, and then to reconstruct (decode) from that encoded form an output that resembles the original input.
  • Latent space is the space of these compressed representations, organized so that similar data points lie close to each other.
  • Variational AutoEncoder (VAE): unlike a standard autoencoder, a VAE does not output a single value for each encoding dimension. Instead, it outputs a probability distribution for each attribute in the latent space.
  • Vector Quantization (VQ) is an encoding-decoding technique in which the encoder maps each input vector to the index of the closest codeword in a codebook and passes only that index to the decoder. The decoder uses the index to look up the codeword, which serves as an approximation of the original input vector.
  • VQ-VAE (Vector Quantized Variational AutoEncoder) adds a discrete codebook to a standard autoencoder. It compares the output of the encoder network to all the vectors of the codebook; the closest vector is then fed to the decoder network. Thus, the VQ concept is combined with VAE.
  • GPT (Generative Pre-trained Transformer) is a transformer language model that is pre-trained on a large corpus of text and then fine-tuned for specific tasks (see the article on OpenAI’s GPT).
  • Self-attention: attention works with three kinds of vectors, ‘query (Q)’, ‘key (K)’ and ‘value (V)’. Query and key vectors are multiplied to produce scores, which are converted into probabilities that decide how much of each value is passed on to the subsequent layer. ‘Self-attention’ is the case where Q, K and V are all derived from the same input sequence (usually through separate learned linear projections). See the research paper on ‘attention’ for details.
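The vector-quantization idea behind VQ and VQ-VAE can be illustrated in a few lines of plain Python. The sketch below is a toy, not code from the VideoGPT repository: the codebook and the "encoder outputs" are made-up 2-D points, and distance is squared Euclidean. A real VQ-VAE learns its codebook during training and quantizes whole tensors at once.

```python
import math


def nearest_codeword(vector, codebook):
    """Return the index of the codebook entry closest to `vector` (squared Euclidean distance)."""
    best_index, best_dist = None, math.inf
    for index, codeword in enumerate(codebook):
        dist = sum((v - c) ** 2 for v, c in zip(vector, codeword))
        if dist < best_dist:
            best_index, best_dist = index, dist
    return best_index


def quantize(vectors, codebook):
    """Encoder side: replace each vector by the index of its nearest codeword."""
    return [nearest_codeword(v, codebook) for v in vectors]


def dequantize(indices, codebook):
    """Decoder side: look the indices back up in the shared codebook."""
    return [codebook[i] for i in indices]


# Toy codebook with three 2-D codewords and three toy encoder outputs
codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 2.0)]
encoder_outputs = [(0.1, -0.2), (0.9, 1.2), (-0.8, 1.7)]

codes = quantize(encoder_outputs, codebook)
print(codes)                        # [0, 1, 2]
print(dequantize(codes, codebook))  # [(0.0, 0.0), (1.0, 1.0), (-1.0, 2.0)]
```

Only the integer indices need to be stored or modelled, which is what makes the latent representation discrete.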
[Figure: the attention mechanism]

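The figure above illustrates attention. Q is first multiplied by the transpose of K using matrix multiplication; in the standard (scaled dot-product) formulation, the resulting scores are divided by sqrt(d_k), where d_k is the dimension of the keys. A softmax then turns each row of scores into a probability distribution, which is multiplied with V.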
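A plain-Python sketch of scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, follows. It is an illustration only: the helper names are our own, and for self-attention the query, key and value matrices are simply set to the same input x, whereas a real transformer obtains them from x through separate learned projections.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    """Swap rows and columns of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def attention(q, k, v):
    """Return softmax(Q K^T / sqrt(d_k)) V together with the attention weights."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))                                      # (n_q x n_k)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights                                    # (n_q x d_v)


# Self-attention on a toy sequence of three 2-D vectors:
# Q, K and V are all taken directly from x (identity projections, for brevity).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, weights = attention(x, x, x)
```

Each row of `weights` is a probability distribution over the sequence positions, and each output row is the corresponding weighted average of the value vectors.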

Overview of Video GPT

Video GPT is a simple architecture for video generation. It uses a VQ-VAE, built from 3D convolutions and axial self-attention, to learn downsampled discrete latent representations of raw video, and then trains a GPT-like transformer to model these discrete latents autoregressively.

[Figure: working of Video GPT (source: research paper)]

The figure above explains the working of the Video GPT architecture. The left side of the figure depicts the first stage, which is simply training a standard VQ-VAE on videos. In the second stage (right side), the trained VQ-VAE encodes raw videos into sequences of discrete latent codes, and a GPT-like transformer is trained to model these sequences autoregressively. To generate a video, new latent sequences are sampled from the transformer and passed through the VQ-VAE decoder, producing a new video sample that resembles the training data.
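The sampling part of the second stage can be sketched with toy stand-ins. In the code below, `toy_prior` and `toy_decode` are hypothetical placeholders: the real Video GPT uses a trained GPT-like transformer over a flattened 3D grid of latent codes and the trained VQ-VAE decoder, neither of which is modelled here. The point is only the control flow: codes are drawn one at a time, each conditioned on all previously drawn codes, and the finished sequence is then decoded.

```python
import random

CODEBOOK_SIZE = 4  # toy codebook size (real models use far larger codebooks)


def toy_prior(prefix):
    """Hypothetical prior: probabilities for each code given the codes so far.
    This toy version simply favours repeating the previous code."""
    if not prefix:
        return [1.0 / CODEBOOK_SIZE] * CODEBOOK_SIZE
    probs = [0.1] * CODEBOOK_SIZE
    probs[prefix[-1]] = 1.0 - 0.1 * (CODEBOOK_SIZE - 1)
    return probs


def sample_latents(prior, length, rng):
    """Autoregressive sampling: draw each code conditioned on all previous codes."""
    codes = []
    for _ in range(length):
        probs = prior(codes)
        codes.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return codes


def toy_decode(codes, codebook):
    """Hypothetical stand-in for the VQ-VAE decoder: look up each code's embedding."""
    return [codebook[c] for c in codes]


rng = random.Random(0)  # fixed seed so the sample is reproducible
latents = sample_latents(toy_prior, length=8, rng=rng)
frames = toy_decode(latents, codebook=[0.0, 0.25, 0.5, 0.75])
```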

Pre-trained VQ-VAE models used by Video GPT

  • ucf101_stride4x4x4: trained on 16-frame videos at 128x128 resolution from the UCF-101 dataset.
  • kinetics_stride4x4x4: trained on 16-frame videos at 128x128 resolution from the Kinetics-600 dataset.
  • kinetics_stride2x4x4: trained on the same data as kinetics_stride4x4x4, but with twice as many latent codes along the time axis (temporal stride 2 instead of 4), resulting in better video reconstruction.

Note: The strides in the model names denote the amount of downsampling the encoder applies along T, H and W (time, i.e. number of frames; height; width). For example, kinetics_stride2x4x4 halves the number of frames and reduces the height and width by a factor of 4.
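The strides translate directly into the size of the latent grid that the transformer has to model. The small helper below is hypothetical (not part of the videogpt package) and assumes the encoder output size is exactly the input size divided by the stride, which holds when each dimension is divisible by its stride. It computes the grid for a 16-frame 128x128 clip, as used by the pretrained models above.

```python
def latent_shape(input_thw, stride_thw):
    """Shape (T', H', W') of the latent grid for an encoder that downsamples
    time, height and width by the given strides."""
    shape = []
    for size, stride in zip(input_thw, stride_thw):
        if size % stride != 0:
            raise ValueError(f"size {size} is not divisible by stride {stride}")
        shape.append(size // stride)
    return tuple(shape)


clip = (16, 128, 128)  # 16 frames at 128x128
for name, stride in [("kinetics_stride4x4x4", (4, 4, 4)),
                     ("kinetics_stride2x4x4", (2, 4, 4))]:
    t, h, w = latent_shape(clip, stride)
    print(name, (t, h, w), "->", t * h * w, "codes")
# kinetics_stride4x4x4 (4, 32, 32) -> 4096 codes
# kinetics_stride2x4x4 (8, 32, 32) -> 8192 codes
```

Halving the temporal stride doubles the number of discrete codes per clip, which preserves more temporal detail at the cost of a longer sequence for the transformer.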

Practical implementation

Here’s a demonstration of how to reconstruct a video with one of Video GPT's pretrained VQ-VAE models, i.e. encode it into discrete latent codes and decode it back. The code has been implemented using Python 3.7.10, matplotlib 3.2.2, torch 1.7.1, torchvision 0.8.2, and scikit-video 1.1.11. Step-wise implementation of the code is as follows:

  1. Install Video GPT from GitHub.

!pip install git+https://github.com/wilson1yan/VideoGPT.git

  2. Install scikit-video, a Python library for video processing, and PyAV (av), which torchvision uses to decode video files.

!pip install scikit-video av

  3. Import the required libraries and modules.
import os
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import torch
from torchvision.io import read_video, read_video_timestamps
from videogpt import download, load_vqvae
from videogpt.data import preprocess
  4. Create a dictionary of sample videos to choose from for reconstruction; each value is the file ID passed to download() below.
vid = {
    'breakdancing': '1OZBnG235-J9LgB_qHv-waHZ4tjofiDgj',
    'bear': '16nIaqq2vbPh-WMo_7hs9feVSe0jWVXLF',
    'jaywalking': '1UxKCVrbyXhvMz_H7dI4w5hjPpRGCAApy',
    'cartoon': '1ONcTMSEuGuLYIDbX-KeFqd390vbTIH9d'
}
  5. Load the pretrained kinetics_stride2x4x4 VQ-VAE model.
# Run on the GPU if one is available (CUDA tensors behave like CPU tensors,
# but computations are performed on the GPU); otherwise fall back to the CPU
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Download (if needed) and load the pretrained VQ-VAE
vqvae = load_vqvae('kinetics_stride2x4x4', device=dev).to(dev)
  6. Select the video from vid to be reconstructed.

vid_name = 'bear'

Set the resolution of the video to be reconstructed to the resolution the VQ-VAE was trained on. Height and width must be divisible by the encoder's spatial stride (4 for kinetics_stride2x4x4).

resolution = vqvae.hparams.resolution

Set the number of frames to read from the video. It should be divisible by the encoder's temporal stride (2 for kinetics_stride2x4x4).

seq_length = 64

  7. Download the video file.
vid_fname = download(vid[vid_name], f'{vid_name}.mp4')
  8. Read the presentation timestamps (in seconds) of all frames in the video. read_video_timestamps() returns a (timestamps, fps) pair, and [0] keeps the list of timestamps.
pts = read_video_timestamps(vid_fname, pts_unit='sec')[0]
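  9. Read the frames from the first up to the seq_length-th timestamp of the .mp4 file. read_video() returns the video frames, the audio frames and a metadata dictionary; [0] keeps only the video frames, a uint8 tensor of shape (T, H, W, C).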
video = read_video(vid_fname, pts_unit='sec', start_pts=pts[0], end_pts=pts[seq_length - 1])[0]
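The end index seq_length - 1 in the call above relies on read_video() including both the start and the end timestamp; this is our reading of torchvision's behaviour, so treat it as an assumption. The hypothetical helper below (not part of torchvision or videogpt) mirrors that selection on made-up 30 fps timestamps, and fails early with a clear message when the clip is too short instead of raising an IndexError from pts[seq_length - 1].

```python
def clip_bounds(timestamps, seq_length):
    """Return (start, end) timestamps covering the first `seq_length` frames,
    assuming the end bound is inclusive."""
    if len(timestamps) < seq_length:
        raise ValueError(f"video has only {len(timestamps)} frames, need {seq_length}")
    return timestamps[0], timestamps[seq_length - 1]


# Made-up timestamps (in seconds) for a 3-second, 30 fps video
fake_pts = [i / 30 for i in range(90)]
start, end = clip_bounds(fake_pts, seq_length=64)

# Inclusive selection, as assumed for read_video(start_pts=..., end_pts=...)
selected = [t for t in fake_pts if start <= t <= end]
print(len(selected))  # 64
```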
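  10. Preprocess the video. The frames arrive in THWC layout (T: number of frames, H: height, W: width, C: number of color channels); preprocess() resizes and center-crops them to the model's resolution, keeps seq_length frames and returns a CTHW tensor. unsqueeze(0) then adds a batch dimension before the tensor is moved to the device.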
video = preprocess(video, resolution, seq_length).unsqueeze(0).to(dev)
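The pixel values also change range during preprocessing. As far as we can tell from the VideoGPT repository, preprocess() divides the uint8 pixels by 255 and subtracts 0.5, so the model works with values in [-0.5, 0.5]; this is why the reconstruction is later clamped to [-0.5, 0.5] and mapped back with (x + 0.5) * 255 for display. The two scalar helpers below are illustrative only and are not part of videogpt.

```python
def to_model_range(pixel):
    """Map a uint8 pixel value in [0, 255] to a float in [-0.5, 0.5]."""
    return pixel / 255.0 - 0.5


def to_display_range(value):
    """Map a model value in [-0.5, 0.5] back to [0, 255].
    int() truncates toward zero, like .astype('uint8') does for in-range values."""
    return int((value + 0.5) * 255)


print(to_model_range(0), to_model_range(255))         # -0.5 0.5
print(to_display_range(-0.5), to_display_range(0.5))  # 0 255
```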
  11. Encode the video into discrete latent codes and decode them back into a reconstruction.
# Encode the video with the VQ-VAE into discrete latent codes
with torch.no_grad():  # disable gradient computation (inference only)
    enc = vqvae.encode(video)
    # Decode the latent codes to reconstruct the video
    vid_recon = vqvae.decode(enc)
    # Clamp the reconstructed tensor's values to the range [-0.5, 0.5]
    vid_recon = torch.clamp(vid_recon, -0.5, 0.5)

Note: The clamp() method used above works as follows:


It limits every element of a tensor to the given range. With the range [-0.5, 0.5], an element such as -1.7, which lies below the lower bound, becomes -0.5. Similarly, an element above the upper bound 0.5, such as 1.8, becomes 0.5. Elements already inside the range are left unchanged.
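The same behaviour can be sketched element-wise in plain Python. The helper below only mirrors what torch.clamp() does to each element of a tensor; it is an illustration, not a replacement for it.

```python
def clamp(values, lo=-0.5, hi=0.5):
    """Element-wise clamp: values below `lo` become `lo`, values above `hi` become `hi`."""
    return [min(max(v, lo), hi) for v in values]


print(clamp([-1.7, -0.2, 0.3, 1.8]))  # [-0.5, -0.2, 0.3, 0.5]
```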

  12. Visualize the original and reconstructed videos side by side.
# Concatenate the original and reconstructed videos along the width axis
videos = torch.cat((video, vid_recon), dim=-1)
# Drop the batch dimension and permute from C*T*H*W to T*H*W*C
# C: number of color channels
# T: number of frames
# H: height of each frame
# W: width of each frame
videos = videos[0].permute(1, 2, 3, 0)
# Map values from [-0.5, 0.5] back to [0, 255]. NumPy does not support CUDA tensors,
# so move the tensor to the CPU first, then convert it to unsigned 8-bit integers
videos = ((videos + 0.5) * 255).cpu().numpy().astype('uint8')
# Create a matplotlib figure
fig = plt.figure()
# Title of the plot
plt.title('Original video (left), Reconstructed video (right)')
# Hide the axes
plt.axis('off')
# Show the first frame; the animation below updates this image
img = plt.imshow(videos[0, :, :, :])
# Close the figure so the notebook does not display a static copy
plt.close()
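torch.cat(..., dim=-1) in the code above joins the two videos along the last (width) axis, which is why the original appears on the left and the reconstruction on the right of every frame. A toy version on a single 2x2 "frame" stored as nested lists (hypothetical values, plain Python instead of tensors):

```python
def side_by_side(frame_a, frame_b):
    """Concatenate two H x W frames (lists of rows) along the width axis,
    as torch.cat(..., dim=-1) does for the last dimension."""
    if len(frame_a) != len(frame_b):
        raise ValueError("frames must have the same height")
    return [row_a + row_b for row_a, row_b in zip(frame_a, frame_b)]


original = [[1, 2], [3, 4]]
reconstructed = [[9, 8], [7, 6]]
print(side_by_side(original, reconstructed))  # [[1, 2, 9, 8], [3, 4, 7, 6]]
```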

Define a function that draws the first frame, used to initialize the animation.

def init():
    img.set_data(videos[0, :, :, :])

Define a function to be called at each frame for animation.

def animate(i):
    img.set_data(videos[i, :, :, :])
    return img

Create an animation by repeatedly calling the animate() function defined above.

anmt = animation.FuncAnimation(fig, animate, init_func=init, frames=videos.shape[0], interval=100) 
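With interval=100, each frame is shown for 100 ms. Assuming to_html5_video() encodes the clip at 1000 / interval frames per second (our understanding of matplotlib's behaviour) and that the clip contains all seq_length = 64 frames, the resulting video plays at 10 frames per second for 6.4 seconds. The arithmetic as a small illustrative helper:

```python
def playback_stats(n_frames, interval_ms):
    """Frame rate and total duration of an animation showing one frame every `interval_ms` ms."""
    fps = 1000.0 / interval_ms
    duration_s = n_frames * interval_ms / 1000.0
    return fps, duration_s


print(playback_stats(64, 100))  # (10.0, 6.4)
```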

Convert the animation to an HTML5 video tag and display it in the notebook.

HTML(anmt.to_html5_video())

[Output video: original (left) and reconstructed (right)]

Copyright Analytics India Magazine Pvt Ltd

Scroll To Top