Guide to Medical Transformer: Attention for Medical Image Segmentation

Medical Transformer relies on a gated position-sensitive axial attention mechanism that aims to work well on small datasets. It also introduces Local-Global (LoGo), a novel training methodology for modelling image data efficiently.

Transformers are revolutionizing the Natural Language Processing domain at an unprecedented pace. The novel idea behind the success of these seq2seq models is "attention", a mechanism that dynamically assigns importance to the relevant features of the input. This gives the model the ability to capture long-range dependencies. The computer vision domain has borrowed this idea to improve the deep convolutional networks at the forefront of vision-related tasks.

This model was introduced by Jeya Maria Jose Valanarasu, Poojan Oza, Ilker Hacihaliloglu, and Vishal M. Patel (researchers at Johns Hopkins University, Baltimore, MD, USA and Rutgers, The State University of New Jersey, NJ, USA) in a paper published on 21 Feb 2021.

Medical Image Segmentation

There are tons of applications for image segmentation across different industries, and medical imaging benefits greatly from automatic segmentation. Segmenting organs or lesions from a medical scan helps clinicians make an accurate diagnosis, plan the surgical procedure, and propose treatment strategies. In this article, let's explore the problem of segmenting tumors in various organs like the liver, breast, and colon.

Medical image segmentation faces a severe problem of data scarcity: datasets of medical images are very small compared to those available for other vision applications.

Let's explore the following dataset. It contains images of tumor tissue at very high magnification, and nuclei boundary annotations of the tumor cells are provided in an XML file.

Given such images, our interest is to create masks that segment the tumor cells. In later stages, these masks can be used for grading and staging of the tumors.
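As a rough sketch of how such a mask can be produced from an annotation file, the snippet below rasterizes boundary polygons into a binary mask. The XML tag names (Region, Vertex with X/Y attributes) and the image size are assumptions about the annotation format and may need adjusting for the actual files.

 import xml.etree.ElementTree as ET
 import numpy as np
 from PIL import Image, ImageDraw

 def xml_to_mask(xml_path, height, width):
     # Parse the annotation file and draw each nucleus boundary as a filled polygon.
     tree = ET.parse(xml_path)
     mask = Image.new("L", (width, height), 0)
     draw = ImageDraw.Draw(mask)
     for region in tree.iter("Region"):   # assume one Region element per nucleus
         polygon = [(float(v.attrib["X"]), float(v.attrib["Y"]))
                    for v in region.iter("Vertex")]
         if len(polygon) >= 3:
             draw.polygon(polygon, outline=255, fill=255)
     return np.array(mask)

 # Example (hypothetical paths and size): mask = xml_to_mask("annotations/0001.xml", 1000, 1000)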

Tissue Image
Segment mask created from a manual annotation file.

There are 30 images in the training set with 22,000 annotations in total, so the sample size is quite small. Medical Transformer is designed keeping the sparsity of medical imaging data in mind, so we can still expect decent masks from this model.

Architecture

In image data, we have to take care of both local dependencies and long-range dependencies. To handle both, the authors introduce a Local-Global (LoGo) approach: Medical Transformer's architecture contains two branches.

1. Global branch to capture the dependencies between each pixel and the entire image.

2. Local branch to capture finer dependencies among neighbouring pixels.

The image is passed through a convolution block before entering the global branch. The same image is broken into patches, and each patch is passed through a similar convolution block and then through the local branch. A resampler aggregates the outputs of the local branch according to each patch's position and generates the output feature maps.

A 1×1 convolutional layer then combines these output feature maps into a segmentation mask.
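Below is a minimal sketch of this Local-Global flow, not the authors' exact implementation. Here global_branch and local_branch stand in for the gated axial attention networks, and the patch size, channel count, and the assumption that both branches preserve spatial resolution are all illustrative.

 import torch
 import torch.nn as nn

 class LoGoSketch(nn.Module):
     def __init__(self, global_branch, local_branch, channels, num_classes, patch=32):
         super().__init__()
         self.global_branch = global_branch   # runs on the whole image
         self.local_branch = local_branch     # runs on individual patches
         self.patch = patch
         self.head = nn.Conv2d(channels, num_classes, kernel_size=1)  # 1x1 conv

     def forward(self, x):
         global_out = self.global_branch(x)   # global feature map
         # Local branch: process patches independently, then stitch the outputs
         # back together according to each patch's position (the "resampler").
         p = self.patch
         rows = []
         for i in range(0, x.shape[2], p):
             row = [self.local_branch(x[:, :, i:i + p, j:j + p])
                    for j in range(0, x.shape[3], p)]
             rows.append(torch.cat(row, dim=3))
         local_out = torch.cat(rows, dim=2)
         return self.head(global_out + local_out)  # fuse branches, project to a mask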

Convolutional Block

The primary purpose of this block is to extract feature maps on which the rest of the model can operate to produce segmentation masks. The block consists of the following layers:

 def forward(self, x):
        # Three conv -> batch norm -> ReLU stages extract the initial feature maps.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        return x
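For context, here is a sketch of how these layers might be declared in the block's __init__. The kernel sizes and channel widths are illustrative assumptions rather than the repository's exact values.

 import torch.nn as nn

 class ConvBlock(nn.Module):
     def __init__(self, in_ch=3, mid_ch=64, out_ch=128):
         super().__init__()
         # Assumed layer shapes; only the conv -> BN -> ReLU pattern matters here.
         self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = nn.BatchNorm2d(mid_ch)
         self.conv2 = nn.Conv2d(mid_ch, mid_ch, kernel_size=3, padding=1, bias=False)
         self.bn2 = nn.BatchNorm2d(mid_ch)
         self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=False)
         self.bn3 = nn.BatchNorm2d(out_ch)
         self.relu = nn.ReLU(inplace=True)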

Encoder Block

An encoder block is a bottleneck block built around the attention layers described below. The global branch contains only two encoder blocks, whereas the local branch has four such blocks.

The main idea of attention is to generate a representation of the input feature map in which each pixel's value takes into account the relevance of every other pixel.
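For a 2-D feature map, this is the standard self-attention computation; in the paper's notation (with $q$, $k$, $v$ the query, key, and value projections of the input), the output at position $(i, j)$ is:

$$y_{ij} = \sum_{h=1}^{H}\sum_{w=1}^{W} \mathrm{softmax}\!\left(q_{ij}^{\top} k_{hw}\right) v_{hw}$$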

But this quickly becomes computationally expensive: every pixel needs its own query, key, and value vectors, and attending over all H × W positions scales quadratically with the number of pixels. Axial attention is used instead of full pixel-wise attention to avoid this problem.

Axial attention factorizes the attention block into two attention blocks, one operating along the height axis and the other along the width axis. This formulation still ignores positional information, which is analogous to sequence order in the NLP domain. We should encourage the model to learn not only whether a pattern is present but also where it is present.

To do this, relative positional encodings are added as bias terms.

Now each pixel value is calculated using the following equation.
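Reconstructed here in LaTeX from the paper's formulation of position-sensitive axial attention, the width-axis output is:

$$y_{ij} = \sum_{w=1}^{W} \mathrm{softmax}\!\left(q_{ij}^{\top} k_{iw} + q_{ij}^{\top} r^{q}_{iw} + k_{iw}^{\top} r^{k}_{iw}\right)\left(v_{iw} + r^{v}_{iw}\right)$$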

Here the r terms are the relative positional encodings along the width axis. Another such equation is used to compute the output of the height-axis layer.

This works well when a large number of samples is available, but the positional encodings become increasingly difficult to learn as the number of samples shrinks, and poorly learned encodings hurt model performance significantly. A gated approach is introduced to deal with this issue.

Gates are added to the model as learnable parameters that control how much of the bias representing positional information is added when calculating the feature maps.
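In equation form (reconstructed from the paper, with learnable gates $G_Q$, $G_K$, $G_{V_1}$, and $G_{V_2}$), the gated width-axis attention becomes:

$$y_{ij} = \sum_{w=1}^{W} \mathrm{softmax}\!\left(q_{ij}^{\top} k_{iw} + G_{Q}\, q_{ij}^{\top} r^{q}_{iw} + G_{K}\, k_{iw}^{\top} r^{k}_{iw}\right)\left(G_{V_1} v_{iw} + G_{V_2}\, r^{v}_{iw}\right)$$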

Here’s the implementation of this attention layer.

 def forward(self, x):
         # x: (N, C, H, W). Move the axis being attended over to the end so the
         # same 1-D attention code serves both the height-axis and width-axis layers.
         if self.width:
             x = x.permute(0, 2, 1, 3)   # attend along the width
         else:
             x = x.permute(0, 3, 1, 2)   # attend along the height
         N, W, C, H = x.shape
         x = x.contiguous().view(N * W, C, H)
         # Transformations
         qkv = self.bn_qkv(self.qkv_transform(x))
         q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)
         # Calculate position embedding
         all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size)
         q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0)
         qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
         kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)
         qk = torch.einsum('bgci, bgcj->bgij', q, k)
         # multiply by factors
         qr = torch.mul(qr, self.f_qr)
         kr = torch.mul(kr, self.f_kr)
         stacked_similarity = torch.cat([qk, qr, kr], dim=1)
         stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
         # (N, groups, H, H, W)
         similarity = F.softmax(stacked_similarity, dim=3)
         sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
         sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
         # multiply by factors
         sv = torch.mul(sv, self.f_sv)
         sve = torch.mul(sve, self.f_sve)
         stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
         output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)
         return output 
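The factors f_qr, f_kr, f_sv, and f_sve that the similarities and outputs are multiplied by act as the gates. A minimal sketch of how such gates could be declared in the layer's __init__ is shown below (assuming torch and torch.nn as nn are imported); the initial values are illustrative assumptions, not necessarily the repository's exact choices.

 # Inside the attention layer's __init__ (sketch):
 self.f_qr = nn.Parameter(torch.tensor(0.1))   # gates the query positional bias
 self.f_kr = nn.Parameter(torch.tensor(0.1))   # gates the key positional bias
 self.f_sv = nn.Parameter(torch.tensor(1.0))   # gates the aggregated values
 self.f_sve = nn.Parameter(torch.tensor(0.1))  # gates the value positional bias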

Decoder Block

Decoder blocks contain routine convolutional blocks with skip connections from the corresponding encoder blocks, as in the decoder of the local branch.
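A minimal sketch of one such decoder step is shown below; it is not the repository's exact code. The deeper feature map is upsampled, the skip connection from the matching encoder block is added, and a convolutional block refines the result.

 import torch.nn.functional as F

 def decoder_step(x, skip, conv_block):
     # Upsample, add the encoder skip connection, then refine with convolutions.
     x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
     x = x + skip
     return conv_block(x)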


Now that we have an understanding of the Medical Transformer architecture, let's use the authors' implementation to solve our segmentation problem.

Usage

We need to clone the GitHub repository using the following command:

!git clone https://github.com/jeya-maria-jose/Medical-Transformer

After some data preparation (one possible approach is sketched after the listing), we will have the data in the following format:

 Train Folder-----
       img----
           0001.png
           0002.png
           .......
       label---
           0001.png
           0002.png
           .......
 Validation Folder-----
       img----
           0001.png
           0002.png
           .......
       label---
           0001.png
           0002.png
           .......
 Test Folder-----
       img----
           0001.png
           0002.png
           .......
       label---
           0001.png
           0002.png
           ....... 
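A rough sketch of one way to produce this layout is shown below: each tissue image and its mask are resized to the training resolution and saved with matching numeric names. The function name, paths, and the 128×128 size are illustrative assumptions.

 import os
 from PIL import Image

 def prepare(pairs, out_dir, size=(128, 128)):
     # `pairs` is a list of (image_path, mask_path) tuples for one split.
     os.makedirs(os.path.join(out_dir, "img"), exist_ok=True)
     os.makedirs(os.path.join(out_dir, "label"), exist_ok=True)
     for idx, (img_path, mask_path) in enumerate(pairs, start=1):
         name = f"{idx:04d}.png"
         Image.open(img_path).convert("RGB").resize(size).save(
             os.path.join(out_dir, "img", name))
         Image.open(mask_path).convert("L").resize(size, Image.NEAREST).save(
             os.path.join(out_dir, "label", name))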

Training can be done using the train.py file:

 train_DIR='/content/train folder'
 validation_DIR='/content/validation folder'
 test_DIR='/content/test folder'
 train_res_DIR='/content/train results'
 test_res_DIR='/content/test_results'
 command="python Medical-Transformer/train.py \
  --train_dataset \"{}\" \
  --val_dataset \"{}\" \
  --direc '{}' \
  --batch_size 4 \
  --epoch 400 \
  --save_freq 10 \
  --modelname \"gatedaxialunet\" \
  --learning_rate 0.001 \
  --imgsize 128 \
  --gray \"no\" \
 ".format(train_DIR,validation_DIR,train_res_DIR)
 !{command} 

For testing, we need to run test.py:

 command2="python Medical-Transformer/test.py \
 --loaddirec \"{}\" \
 --val_dataset \"{}\" \
 --direc '{}' \
 --batch_size 1 \
 --modelname \"gatedaxialunet\" \
 --imgsize 128 \
 --gray \"no\"".format('/content/train resultsfinal_model.pth',test_DIR,test_rese_DIR)
 !{command2} 

Here are the predictions on the test data.

Predictions of the model

We got an F1 score of 0.65 and an Intersection over Union (IoU) of 85% on these predictions.
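For reference, here is a minimal sketch (not the repository's evaluation code) of how F1 and IoU can be computed from a predicted mask and its ground truth:

 import numpy as np

 def f1_and_iou(pred, target, thresh=0.5):
     # Binarize, then count true positives, false positives, and false negatives.
     p = np.asarray(pred) > thresh
     t = np.asarray(target) > thresh
     tp = np.logical_and(p, t).sum()
     fp = np.logical_and(p, ~t).sum()
     fn = np.logical_and(~p, t).sum()
     f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)
     iou = tp / (tp + fp + fn + 1e-8)
     return f1, iou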

Conclusion

The attention mechanism works well for a variety of use cases, but its formulation needs to be carefully extended to remain efficient under different practical constraints. Medical Transformer is a great example of such an extension, and it shows that there is plenty of future scope for Transformer models in vision-related applications.

References

Paper

GitHub Repository

Colab notebook

Dataset

Pavan Kandru
AI enthusiast with a flair for NLP. I love playing with exotic data.
