Hands-on Guide to Effective Image Captioning Using Attention Mechanism

How do we humans perceive a situation and describe it? Largely by paying attention, that is, by focusing on what matters in context. Say you see a sunset at the beach and are asked to describe the scene briefly: you would focus on the relatively important elements, such as the people and what they are wearing, and draw on your previous experience to phrase your description. We can take the same approach to automate this task. The backbone of such a system is an attention mechanism combined with a recurrent network.

Attention has become an increasingly popular mechanism across a wide range of neural architectures, and it has been realised in a variety of forms. It was originally developed to improve the performance of encoder-decoder (Seq2Seq) models on neural machine translation, and it has since become one of the most influential ideas in deep learning. Today it is used in many other problems, including image captioning.

How has the attention mechanism evolved in machine learning?

Before 2015, when the first attention model was proposed, neural machine translation was based on a simple encoder-decoder model built from stacked RNN or LSTM layers. The encoder processes the entire input sequence into a single context vector, which is expected to be a good summary of the input. The encoder's final state is then used as the decoder's initial state.

Put simply, the model consists of two RNN/LSTM networks. The first, the encoder, processes the input and summarises it into what we call the context vector. This context vector is passed to the second network, the decoder, which produces the translation from that vector alone.

The main drawback of this approach is that if the encoder produces a poor summary of the input, the decoder's output will be poor as well. In practice, the encoder tends to produce poor summaries for long sentences, which are common in real-world text. This is known as the long-range dependency problem of RNN and LSTM layers.

This is where the information gets lost. A simple RNN cannot remember long sequences because of vanishing gradients: as sequences get longer, the gradient signal flowing back to earlier time steps shrinks, so the influence of distant context fades. The network does, however, remember what it has seen recently. LSTMs were introduced to overcome the vanishing gradient problem, but they still fall short in some cases, such as time-series prediction. There is also a more basic limitation of the standard approach in NLP: it offers no way to give more importance to some parts of the text than to others.
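To make the vanishing-gradient argument concrete, here is a minimal sketch using a hypothetical one-unit RNN, h_t = tanh(w * h_{t-1} + x). By the chain rule, the gradient of the last state with respect to the first is a product of one factor w * (1 - h_t^2) per time step; when those factors are below 1, the product shrinks exponentially with the sequence length. The helper name grad_last_wrt_first and the values of w and x are illustrative assumptions, not taken from any trained model.

```python
import math

def grad_last_wrt_first(num_steps, w=0.9, x=0.5):
    """Return d h_T / d h_0 for the toy scalar RNN h_t = tanh(w * h_{t-1} + x).

    Hypothetical helper: w and x are arbitrary illustrative values.
    """
    h = 0.0     # initial hidden state h_0
    grad = 1.0  # running product of the per-step derivatives
    for _ in range(num_steps):
        h = math.tanh(w * h + x)
        # d h_t / d h_{t-1} = w * (1 - tanh(.)^2) = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h)
    return grad

for steps in (1, 5, 20, 50):
    print(f"{steps:>2} steps: gradient = {grad_last_wrt_first(steps):.2e}")
```

Each extra time step multiplies the gradient by another factor smaller than one, so words far back in the sequence barely influence the weight updates.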

In 2015, Dzmitry Bahdanau and his co-authors proposed making all of the encoder's outputs available when building the context vector, while giving each input word a relative importance. In other words, each time the model generates a word, it searches the input sentence for the positions where the most relevant information is concentrated.

The encoder is a bidirectional LSTM that produces a sequence of annotations h1, h2, h3…hTx for each input sentence, one annotation for each of the Tx words in the sentence. All of these annotations are used, whereas in the simple encoder-decoder model only the last LSTM state served as the context vector.
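The scoring step can be written as additive (Bahdanau-style) attention: each annotation h_j is scored against the decoder state s as e_j = v^T tanh(W1 h_j + W2 s), the scores are normalised with a softmax into weights alpha_j, and the context vector is the weighted sum c = sum_j alpha_j h_j. Below is a minimal NumPy sketch of a single step. The helper names and sizes are hypothetical, and the random matrices W1, W2 and v merely stand in for learned parameters. The Attention_mecha layer later in this article follows the same pattern.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def additive_attention(annotations, s, W1, W2, v):
    """One step of additive attention (hypothetical helper).

    annotations: (T_x, d_h) encoder annotations h_1..h_Tx
    s:           (d_s,)     current decoder state
    W1: (d_h, d_a), W2: (d_s, d_a), v: (d_a,)  (learned in a real model)
    """
    # e_j = v^T tanh(W1 h_j + W2 s), computed for every j at once -> (T_x,)
    scores = np.tanh(annotations @ W1 + s @ W2) @ v
    weights = softmax(scores)        # alpha_j: non-negative, sum to 1
    context = weights @ annotations  # c = sum_j alpha_j h_j -> (d_h,)
    return context, weights

rng = np.random.default_rng(0)
T_x, d_h, d_s, d_a = 6, 8, 8, 10  # illustrative sizes
annotations = rng.normal(size=(T_x, d_h))
s = rng.normal(size=d_s)
W1 = rng.normal(size=(d_h, d_a))
W2 = rng.normal(size=(d_s, d_a))
v = rng.normal(size=d_a)

context, weights = additive_attention(annotations, s, W1, W2, v)
print(weights.round(3), context.shape)
```

In the captioning model, the annotations are the image-location feature vectors from InceptionV3 instead of word annotations, and s is the GRU decoder's hidden state.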

Below is a demonstration of a captioning model that combines the attention mechanism with a recurrent network. Its architecture is similar to the one described in Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, and the code below implements it.

The model is trained on the MS-COCO dataset. InceptionV3 is used to extract features from a subset of the images; these features are cached, and the encoder and decoder are trained on them.

Implementation of Image captioning: 

The code below is adapted from TensorFlow’s official implementation.

Import all dependencies:
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm
import collections
import random
import numpy as np
import os
import time
import json
from PIL import Image
Load the dataset:

The training set contains nearly 82,000 images with roughly five captions each. The image download is around 13 GB, so consider using a cloud service.

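# download the caption annotation file (official MS-COCO 2014 URL)
annot_folder = '/annotations/'
if not os.path.exists(os.path.abspath('.') + annot_folder):
  annot_zip = tf.keras.utils.get_file('captions.zip',
                                      cache_subdir=os.path.abspath('.'),
                                      origin='http://images.cocodataset.org/annotations/annotations_trainval2014.zip',
                                      extract=True)
  os.remove(annot_zip)
annot_fl = os.path.abspath('.') + '/annotations/captions_train2014.json'

# download the image files (official MS-COCO 2014 URL)
img_fold = '/train2014/'
if not os.path.exists(os.path.abspath('.') + img_fold):
  img_zip = tf.keras.utils.get_file('train2014.zip',
                                    cache_subdir=os.path.abspath('.'),
                                    origin='http://images.cocodataset.org/zips/train2014.zip',
                                    extract=True)
  os.remove(img_zip)
PATH = os.path.abspath('.') + img_fold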

To train faster, we limit the data to a subset of 6,000 images and the roughly 30,000 captions associated with them. Because of this limit, the results will be somewhat weaker.

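with open(annot_fl, 'r') as f:
  annotations = json.load(f)

# group all captions by image path
img_path_to_cap = collections.defaultdict(list)
for val in annotations['annotations']:
  caption = f"<start> {val['caption']} <end>"
  image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (val['image_id'])
  img_path_to_cap[image_path].append(caption)

# shuffle the image paths and keep the first 6000 images
img_pts = list(img_path_to_cap.keys())
random.shuffle(img_pts)
tra_img_pts = img_pts[:6000]

# one entry per caption: tra_cap[i] is a caption of the image img_name_[i]
tra_cap = []
img_name_ = []
for image_path in tra_img_pts:
  cap_li = img_path_to_cap[image_path]
  tra_cap.extend(cap_li)
  img_name_.extend([image_path] * len(cap_li))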

This leaves 6,000 training images, each listed once per caption.

Load InceptionV3 and preprocess the image:

Because we are using attention, we take the output of the last convolutional layer rather than the classification head; its shape is 8 x 8 x 2048.

img_mod = tf.keras.applications.InceptionV3(include_top=False,weights='imagenet')
new_ip = img_mod.input
hid_lyr = img_mod.layers[-1].output
image_features_extract_model = tf.keras.Model(new_ip, hid_lyr)
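For attention, the 8 x 8 grid of the last convolutional layer is treated as 64 separate image locations, each described by a 2048-dimensional vector. That is what the tf.reshape call in the caching step below does, and it is why atte_feature_shp is later set to 64. The NumPy sketch below reproduces the reshape on a random stand-in tensor (only the shapes matter) and shows how a flat location index maps back to a grid cell, which is what np.resize(..., (8, 8)) relies on when the attention weights are plotted.

```python
import numpy as np

batch = 2
# random stand-in for InceptionV3's last conv output: (batch, height, width, channels)
conv_out = np.random.rand(batch, 8, 8, 2048).astype(np.float32)

# flatten the 8 x 8 grid into 64 "locations", keeping the 2048 channels
flat = conv_out.reshape(conv_out.shape[0], -1, conv_out.shape[3])
print(flat.shape)  # (2, 64, 2048)

# row-major order: location k is grid cell (row, col) = divmod(k, 8)
row, col = divmod(10, 8)
assert np.array_equal(flat[:, 10], conv_out[:, row, col])
```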

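Because InceptionV3 is a pre-trained model, each image must be preprocessed the way the network expects: resized to 299 x 299 and normalised with preprocess_input, which scales pixel values to the range [-1, 1].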

def load_image(image_path):
  img = tf.io.read_file(image_path)
  img = tf.image.decode_jpeg(img, channels=3)
  img = tf.image.resize(img, (299,299))
  img = tf.keras.applications.inception_v3.preprocess_input(img)
  return img, image_path
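For InceptionV3, tf.keras.applications.inception_v3.preprocess_input maps pixel values from [0, 255] to [-1, 1] by computing x / 127.5 - 1, so 0 becomes -1, 127.5 becomes 0 and 255 becomes 1. The NumPy sketch below reproduces that arithmetic with a hypothetical helper, scale_like_inception, purely to show what the network receives; in the pipeline itself, keep using the Keras function.

```python
import numpy as np

def scale_like_inception(pixels):
    """Hypothetical helper: map [0, 255] pixel values to [-1, 1] as x / 127.5 - 1."""
    return pixels.astype(np.float32) / 127.5 - 1.0

pixels = np.array([0, 127.5, 255])
print(scale_like_inception(pixels))
```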
Caching the features extracted by the model:
# unique image paths to encode
encode_trn = sorted(set(img_name_))
img_data = tf.data.Dataset.from_tensor_slices(encode_trn)
img_data = img_data.map(load_image, num_parallel_calls=tf.data.AUTOTUNE).batch(16)
for img, path in tqdm(img_data):
  batch_feature = image_features_extract_model(img)
  # (batch, 8, 8, 2048) -> (batch, 64, 2048)
  batch_feature = tf.reshape(batch_feature,
                             (batch_feature.shape[0], -1, batch_feature.shape[3]))

  for bf, p in zip(batch_feature, path):
    path_of_feature = p.numpy().decode('utf-8')
    # np.save appends '.npy', so the features are stored as <image path>.npy
    np.save(path_of_feature, bf.numpy())
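The caching loop relies on a NumPy detail: np.save appends a .npy extension when the file name does not already end in one. Each image's features therefore end up next to the image as "<image path>.npy", which is why map_func later adds '.npy' to the image name before calling np.load. A small sketch of that round trip, using a temporary directory and a made-up COCO-style file name:

```python
import os
import tempfile
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # hypothetical image path; the real loop uses the downloaded COCO image paths
    image_path = os.path.join(tmp, 'COCO_train2014_000000000009.jpg')
    features = np.random.rand(64, 2048).astype(np.float32)

    np.save(image_path, features)  # actually writes image_path + '.npy'
    saved_as = image_path + '.npy'
    print(os.path.exists(saved_as))  # True

    loaded = np.load(saved_as)  # the same lookup map_func performs per example
    assert np.array_equal(loaded, features)
```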
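Process and tokenize the captions:

First, we tokenize the captions, which gives us the vocabulary of all unique words. We then limit the vocabulary to the 5,000 most frequent words (every other word is replaced with the <unk> token) and pad each sequence to the length of the longest caption.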

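def calc_max_length(tensor):
  # length of the longest caption
  return max(len(t) for t in tensor)

# keep only the 5000 most frequent words
top_k = 5000
# '<' and '>' are left out of the filters so that <start> and <end> survive
tok = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, oov_token="<unk>",
                                            filters='!"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')
tok.fit_on_texts(tra_cap)

# reserve index 0 for padding
tok.word_index['<pad>'] = 0
tok.index_word[0] = '<pad>'
train_seqs = tok.texts_to_sequences(tra_cap)
# pad with zeros at the end, up to the longest caption
cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')
max_lenght = calc_max_length(train_seqs)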
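Conceptually, the tokenizer ranks words by frequency, keeps only the most common ones, maps everything else to <unk>, and the padding step appends zeros so that all sequences have the same length. The pure-Python sketch below mimics that behaviour on a toy corpus with a vocabulary of 7 words instead of 5,000. The helpers build_vocab and to_padded_ids are hypothetical simplifications written for illustration, not Keras's Tokenizer (which, for example, also strips punctuation through its filters argument and indexes words slightly differently).

```python
from collections import Counter

def build_vocab(captions, top_k):
    """Hypothetical helper: 0 = <pad>, 1 = <unk>, then the top_k most frequent words."""
    counts = Counter(word for cap in captions for word in cap.lower().split())
    vocab = {'<pad>': 0, '<unk>': 1}
    for word, _ in counts.most_common(top_k):
        vocab[word] = len(vocab)
    return vocab

def to_padded_ids(captions, vocab):
    """Hypothetical helper: map words to ids and post-pad with 0 to equal length."""
    seqs = [[vocab.get(w, vocab['<unk>']) for w in cap.lower().split()]
            for cap in captions]
    max_len = max(len(s) for s in seqs)  # the longest caption sets the length
    return [s + [0] * (max_len - len(s)) for s in seqs], max_len

captions = ['<start> a dog runs on the grass <end>',
            '<start> a dog sleeps on the mat <end>',
            '<start> a cat on the mat <end>']
vocab = build_vocab(captions, top_k=7)
padded, max_len = to_padded_ids(captions, vocab)
for row in padded:
    print(row)
```

Rare words such as "runs" and "cat" collapse into <unk>, and the shortest caption gets a trailing 0, which is exactly the padding that the masked loss later ignores.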

Train/validation split:

# group the padded caption vectors by image
img_to_vec = collections.defaultdict(list)
for img, cap in zip(img_name_, cap_vector):
  img_to_vec[img].append(cap)

# split by image (80% train, 20% validation) so no image appears in both sets
img_keys = list(img_to_vec.keys())
random.shuffle(img_keys)
slce_indx = int(len(img_keys)*0.8)
img_nme_tra_keys, img_nme_val_keys = img_keys[:slce_indx], img_keys[slce_indx:]

img_name_train = []
cap_train = []
for imgt in img_nme_tra_keys:
  capt_len = len(img_to_vec[imgt])
  img_name_train.extend([imgt] * capt_len)
  cap_train.extend(img_to_vec[imgt])

img_name_val = []
cap_val = []
for imgv in img_nme_val_keys:
  capv_len = len(img_to_vec[imgv])
  img_name_val.extend([imgv] * capv_len)
  cap_val.extend(img_to_vec[imgv])

len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)
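The split is done over image keys rather than over individual captions, so all captions of an image land on the same side and no validation image is seen during training. A pure-Python sketch of the same idea on hypothetical toy data (ten images with three captions each, and a fixed seed so the example is repeatable):

```python
import random
from collections import defaultdict

# hypothetical (image, caption) pairs: three captions per image
pairs = [(f'img{i}.jpg', f'caption {j} of image {i}') for i in range(10) for j in range(3)]

by_image = defaultdict(list)
for img, cap in pairs:
    by_image[img].append(cap)

keys = list(by_image)
random.Random(0).shuffle(keys)  # fixed seed for a repeatable example
cut = int(len(keys) * 0.8)      # 80% of the images go to training
train_keys, val_keys = keys[:cut], keys[cut:]

# expand back to one (image, caption) row per caption
train = [(k, c) for k in train_keys for c in by_image[k]]
val = [(k, c) for k in val_keys for c in by_image[k]]
print(len(train), len(val))  # 24 6
```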
Create the dataset for training: 
batch_size = 64
BUFFER_SIZE = 1000  # shuffle buffer size (an assumed value)
embedding_dim = 256
units = 512
vocab_size = top_k + 1
num_steps = len(img_name_train) // batch_size
# the cached InceptionV3 features have shape (64, 2048)
features_shape = 2048
atte_feature_shp = 64

def map_func(img_name, cap):
  # load the cached features saved as <image path>.npy
  img_tensor = np.load(img_name.decode('utf-8') + '.npy')
  return img_tensor, cap

data_set = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))
# load the numpy files in parallel
data_set = data_set.map(lambda item1, item2: tf.numpy_function(
                          map_func, [item1, item2], [tf.float32, tf.int32]),
                        num_parallel_calls=tf.data.AUTOTUNE)
# shuffle, batch and prefetch
data_set = data_set.shuffle(BUFFER_SIZE).batch(batch_size)
data_set = data_set.prefetch(buffer_size=tf.data.AUTOTUNE)
Build the model:

Below is the main model: an encoder and a decoder, where the decoder uses the attention layer discussed above.

class Attention_mecha(tf.keras.Model):
  def __init__(self, units):
    super(Attention_mecha, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, features, hidden):
    # features: (batch, 64, embedding_dim); hidden: (batch, units)
    hidden_with_time_axis = tf.expand_dims(hidden, 1)
    # additive attention: (batch, 64, units)
    attention_hidden_layer = tf.nn.tanh(self.W1(features) +
                                        self.W2(hidden_with_time_axis))
    # one score per image location: (batch, 64, 1)
    score = self.V(attention_hidden_layer)
    attention_weights = tf.nn.softmax(score, axis=1)
    # weighted sum of the features: (batch, embedding_dim)
    context_vector = attention_weights * features
    context_vector = tf.reduce_sum(context_vector, axis=1)
    return context_vector, attention_weights

class cnn_encoder(tf.keras.Model):
    def __init__(self, embedding_dim):
        super(cnn_encoder, self).__init__()
        self.fc = tf.keras.layers.Dense(embedding_dim)
    def call(self, x):
        x = self.fc(x)
        x = tf.nn.relu(x)
        return x

class rnn_decoder(tf.keras.Model):
  def __init__(self, embedding_dim, units, vocab_size):
    super(rnn_decoder, self).__init__()
    self.units = units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    # return the full output sequence and the final state
    self.gru = tf.keras.layers.GRU(self.units,
                                   return_sequences=True,
                                   return_state=True)
    self.fc1 = tf.keras.layers.Dense(self.units)
    self.fc2 = tf.keras.layers.Dense(vocab_size)
    self.attention = Attention_mecha(self.units)

  def call(self, x, features, hidden):
    # attend over the image locations using the previous hidden state
    context_vector, attention_weights = self.attention(features, hidden)
    # x: (batch, 1) -> (batch, 1, embedding_dim)
    x = self.embedding(x)
    # prepend the context vector: (batch, 1, 2 * embedding_dim)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
    output, state = self.gru(x)
    x = self.fc1(output)
    # (batch, 1, units) -> (batch, units)
    x = tf.reshape(x, (-1, x.shape[2]))
    # logits over the vocabulary: (batch, vocab_size)
    x = self.fc2(x)
    return x, state, attention_weights

  def reset_state(self, batch_size):
    return tf.zeros((batch_size, self.units))

Instantiating the encoder and decoder:

encoder = cnn_encoder(embedding_dim)
decoder = rnn_decoder(embedding_dim, units, vocab_size)

Defining the optimizer and the loss:

opti = tf.keras.optimizers.Adam()
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
  # ignore padded positions (index 0) when computing the loss
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_obj(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask
  return tf.reduce_mean(loss_)
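The mask matters because captions are post-padded with index 0; without it, the model would also be trained to predict <pad> after every caption ends. Here is a NumPy sketch of a masked sparse categorical cross-entropy on made-up logits; masked_sparse_ce is a hypothetical helper, and like loss_function it averages over the whole batch, padded positions included, after zeroing them.

```python
import numpy as np

def masked_sparse_ce(real, logits):
    """Hypothetical helper: mean cross-entropy with padding targets (id 0) zeroed."""
    # log-softmax over the vocabulary axis, stabilised by subtracting the row max
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(len(real)), real]  # per-example loss
    mask = (real != 0).astype(loss.dtype)          # 0 where the target is <pad>
    return np.mean(loss * mask)

real = np.array([3, 5, 0, 0])  # the last two targets are padding
logits = np.random.default_rng(0).normal(size=(4, 10))
print(masked_sparse_ce(real, logits))
```

Whatever the model predicts at padded positions has no effect on this loss, so it contributes no gradient there.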
Training Phase:

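First, the cached image features are passed through the encoder. The encoder output, together with the decoder's hidden state and the <start> token, is fed to the decoder, which returns predictions for the next word. During training we use teacher forcing: at each step the ground-truth word, not the decoder's own prediction, becomes the next input.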
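With teacher forcing, the decoder input at step i is the ground-truth token from step i-1. The pure-Python sketch below lists the (input, expected output) pairs that train_step effectively feeds the decoder for one toy caption; teacher_forcing_pairs is a hypothetical helper that works on words instead of token ids.

```python
def teacher_forcing_pairs(tokens):
    """Hypothetical helper: (decoder input, expected output) pairs for one caption."""
    pairs = []
    dec_input = tokens[0]  # always '<start>'
    for i in range(1, len(tokens)):
        pairs.append((dec_input, tokens[i]))
        dec_input = tokens[i]  # the ground-truth word, not the model's prediction
    return pairs

caption = ['<start>', 'a', 'dog', 'runs', '<end>']
for inp, tgt in teacher_forcing_pairs(caption):
    print(f'{inp:>8} -> {tgt}')
```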

loss_plot = []

@tf.function
def train_step(img_tensor, target):
  loss = 0
  # reset the decoder's hidden state for every batch
  hidden = decoder.reset_state(batch_size=target.shape[0])
  # every caption starts with the <start> token
  dec_input = tf.expand_dims([tok.word_index['<start>']] * target.shape[0], 1)
  with tf.GradientTape() as tape:
    features = encoder(img_tensor)
    for i in range(1, target.shape[1]):
      predictions, hidden, _ = decoder(dec_input, features, hidden)
      loss += loss_function(target[:, i], predictions)
      # teacher forcing: feed the ground-truth word as the next input
      dec_input = tf.expand_dims(target[:, i], 1)

  total_loss = (loss / int(target.shape[1]))
  trainable_variables = encoder.trainable_variables + decoder.trainable_variables
  gradients = tape.gradient(loss, trainable_variables)
  opti.apply_gradients(zip(gradients, trainable_variables))
  return loss, total_loss

Training loop:

EPOCHS = 20      # number of epochs (an assumed value)
start_epoch = 0
for epoch in range(start_epoch, EPOCHS):
  start = time.time()
  total_loss = 0
  for (batch, (img_tensor, target)) in enumerate(data_set):
    batch_loss, t_loss = train_step(img_tensor, target)
    total_loss += t_loss
    if batch % 100 == 0:
      average_batch_loss = batch_loss.numpy() / int(target.shape[1])
      print(f'Epoch {epoch+1} Batch {batch} Loss {average_batch_loss:.4f}')

  # store the epoch's mean loss for plotting
  loss_plot.append(total_loss / num_steps)

  print(f'Epoch {epoch+1} Loss {total_loss/num_steps:.6f}')
  print(f'Time taken for 1 epoch {time.time()-start:.2f} sec\n')
Generating a caption for a new image:

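The evaluate function below is similar to the training step, except that there is no teacher forcing: at each step the decoder receives the word it predicted at the previous step, and generation stops when it predicts the <end> token.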
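Note that evaluate picks each word with tf.random.categorical, which samples from the predicted distribution instead of always taking the most likely word, so running it twice on the same image can give different captions. A NumPy sketch contrasting greedy and sampled decoding on made-up logits for a hypothetical three-word vocabulary:

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])  # hypothetical scores for a 3-word vocabulary
probs = np.exp(logits - logits.max())
probs /= probs.sum()  # softmax probabilities

greedy_id = int(np.argmax(logits))  # deterministic: always the top-scoring word
rng = np.random.default_rng(0)
sampled = rng.choice(len(logits), size=1000, p=probs)  # stochastic draws

print(greedy_id, probs.round(3), np.bincount(sampled, minlength=3) / 1000)
```

Replacing the sampling line in evaluate with an argmax over predictions should give deterministic, greedy captions, usually at the cost of less varied wording.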

def evaluate(image):
  attention_plot = np.zeros((max_lenght, atte_feature_shp))
  hidden = decoder.reset_state(batch_size=1)
  # preprocess the image and extract its InceptionV3 features
  temp_input = tf.expand_dims(load_image(image)[0], 0)
  img_tensor_val = image_features_extract_model(temp_input)
  img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1,
                                               img_tensor_val.shape[3]))
  features = encoder(img_tensor_val)
  dec_input = tf.expand_dims([tok.word_index['<start>']], 0)
  result = []
  for i in range(max_lenght):
    predictions, hidden, attention_weights = decoder(dec_input,
                                                     features, hidden)
    attention_plot[i] = tf.reshape(attention_weights, (-1,)).numpy()
    # sample the next word from the predicted distribution
    predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
    result.append(tok.index_word[predicted_id])
    if tok.index_word[predicted_id] == '<end>':
      return result, attention_plot
    # feed the predicted word back as the next input
    dec_input = tf.expand_dims([predicted_id], 0)

  attention_plot = attention_plot[:len(result), :]
  return result, attention_plot

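def plot_attention(image, result, attention_plot):
  temp_image = np.array(Image.open(image))
  fig = plt.figure(figsize=(15, 15))
  len_result = len(result)
  for i in range(len_result):
    # map the 64 attention weights back onto the 8 x 8 grid
    temp_att = np.resize(attention_plot[i], (8, 8))
    grid_size = int(max(np.ceil(len_result / 2), 2))
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.set_title(result[i])
    img = ax.imshow(temp_image)
    # overlay the attention map, stretched to the image's extent
    ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())
  plt.tight_layout()
  plt.show()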

The code below downloads an image from the given URL and predicts its caption. Because we trained on a limited subset of the data, the captions may be somewhat odd.

image_url = 'https://www.incontextinternational.org/wp-content/uploads/2017/05/action.jpg'
image_extension = image_url[-4:]
image_path = tf.keras.utils.get_file('action' + image_extension, origin=image_url)
result, attention_plot = evaluate(image_path)
print('Predicted Caption:', ' '.join(result))
plot_attention(image_path, result, attention_plot)


Predicted Caption: a man is flying at the person on a dark background <end>


In this article, we have seen how the attention mechanism works and how to use it for image captioning. The attention plots above visualise where the model focuses: we can see that it attends to objects in the image while generating the related words. Training on the full dataset should improve the quality of the generated captions.


Vijaysinh Lendave
Vijaysinh is an enthusiast in machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, model building.
