Hands-on TransUNet: Transformers For Medical Image Segmentation

TransUNet, a Transformers-based U-Net framework, achieves state-of-the-art performance in medical image segmentation applications

U-Net, a U-shaped convolutional neural network, has become a de facto standard for medical image segmentation after numerous successes. It is a symmetric encoder-decoder network whose skip-connections carry high-resolution encoder features to the decoder and help retain fine detail. U-Net has excellent representational power, but because convolution is an intrinsically local operation, it struggles to model long-range dependencies. This matters in medical image segmentation, where target structures vary widely in texture, shape, and size. CNN-based variants of U-Net that add self-attention mechanisms improve performance, but only to a limited extent.

Transformers, which rely solely on attention mechanisms, offer an alternative to U-Nets. They are good at modelling global context and, when pretrained at large scale, transfer well to downstream tasks; they have shown exceptional performance in many machine learning tasks, including natural language processing and image recognition. However, a Transformer used on its own concentrates on global context and fails to capture detailed localization information: it produces low-resolution features, so small but valuable details are lost.

Many extensions and variations of both U-Nets and Transformers have been proposed to improve performance. TransUNet, introduced by Jieneng Chen, Yongyi Lu, Qihang Yu, Xiangde Luo, Ehsan Adeli, Yan Wang, Le Lu, Alan L. Yuille, and Yuyin Zhou, is a hybrid of U-Net and Transformers that combines the strengths of both architectures: the Transformer encodes long-range context across the whole image, while the U-Net-style decoder, fed by skip-connections, recovers detailed localization information at several network stages. TransUNet yields better results than CNN-based self-attention models. It is built upon the Vision Transformer (ViT), the state-of-the-art method on ImageNet classification at the time TransUNet was proposed. TransUNet achieves state-of-the-art performance in applications such as multi-organ segmentation and cardiac segmentation.

Architecture of TransUNet Framework

The Transformer unit of the framework consists of 12 Transformer layers. Each layer is a stack of a Layer Normalization (LN) layer, a Multi-head Self-Attention (MSA) block, a second Layer Normalization layer, and a multi-layer perceptron (MLP), with a residual connection around the MSA block and another around the MLP.
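The layer structure described above can be sketched in PyTorch as follows. This is a minimal illustration, not the code from the TransUNet repository: the class names TransformerLayer and TransformerEncoder are hypothetical, and the default sizes (hidden size 768, 12 heads, MLP size 3072, dropout 0.1) are assumed ViT-Base-style values chosen for the example.

```python
import torch
import torch.nn as nn


class TransformerLayer(nn.Module):
    """One pre-norm layer: LN -> MSA -> residual, then LN -> MLP -> residual.

    Hypothetical illustration; default sizes are assumed ViT-Base-like values.
    """

    def __init__(self, dim: int = 768, heads: int = 12, mlp_dim: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, num_tokens, dim)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                  # residual connection around MSA
        x = x + self.mlp(self.norm2(x))   # residual connection around the MLP
        return x


class TransformerEncoder(nn.Module):
    """A stack of `depth` layers followed by a final LayerNorm."""

    def __init__(self, depth: int = 12, dim: int = 768, heads: int = 12, mlp_dim: int = 3072):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer(dim, heads, mlp_dim) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


if __name__ == "__main__":
    encoder = TransformerEncoder()          # 12 layers, as in the text
    tokens = torch.randn(1, 196, 768)       # e.g. a 14x14 grid of tokens
    print(encoder(tokens).shape)            # torch.Size([1, 196, 768])
```

Because every layer maps a (batch, tokens, dim) tensor to the same shape, the layers can be stacked to any depth; the sequence length never changes inside the Transformer.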

Input images are first fed to a CNN unit (a ResNet-50 in the R50-ViT-B_16 variant), which acts as the encoder part of the framework and extracts features at different depths. The final CNN feature map is divided into patches, each patch is linearly projected into an embedding, and position embeddings are added; the resulting token sequence is fed to the Transformer unit to obtain the global context of the features. The decoder then upsamples the Transformer's output step by step, combining it with skip-connections from the encoder at matching resolutions. Finally, the decoder output is fed to a segmentation head, which produces a pixel-wise segmentation mask of the original image.
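To make this data flow concrete, here is a minimal PyTorch sketch of a hybrid CNN-Transformer U-shaped network. It is not the official TransUNet model: the class TinyTransUNetSketch and the helper conv_block are hypothetical, a three-stage toy CNN stands in for ResNet-50, PyTorch's built-in nn.TransformerEncoder stands in for the ViT encoder, only two skip-connections are used, and all channel sizes, the 64×64 input and the 9 output classes (for example 8 organs plus background) are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_block(cin: int, cout: int, stride: int = 1) -> nn.Sequential:
    """3x3 convolution + ReLU (hypothetical helper)."""
    return nn.Sequential(nn.Conv2d(cin, cout, 3, stride=stride, padding=1), nn.ReLU(inplace=True))


class TinyTransUNetSketch(nn.Module):
    """Illustrative hybrid CNN-Transformer U-shaped network (not the official model)."""

    def __init__(self, in_ch=1, num_classes=9, dim=128, depth=2, heads=4, img_size=64):
        super().__init__()
        # Encoder: each stage halves the resolution; outputs are kept as skips.
        self.stage1 = conv_block(in_ch, 16, stride=2)  # 1/2 resolution
        self.stage2 = conv_block(16, 32, stride=2)     # 1/4 resolution
        self.stage3 = conv_block(32, 64, stride=2)     # 1/8 resolution
        self.grid = img_size // 8
        # Linear projection of each 1x1 patch of the CNN feature map into a token.
        self.proj = nn.Conv2d(64, dim, kernel_size=1)
        self.pos = nn.Parameter(torch.zeros(1, self.grid * self.grid, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, dim_feedforward=2 * dim,
                                           batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
        # Decoder: upsample, concatenate the matching skip, then convolve.
        self.dec3 = conv_block(dim + 32, 64)  # 1/8 -> 1/4, skip from stage2
        self.dec2 = conv_block(64 + 16, 32)   # 1/4 -> 1/2, skip from stage1
        self.dec1 = conv_block(32, 16)        # 1/2 -> full resolution
        self.head = nn.Conv2d(16, num_classes, kernel_size=1)  # segmentation head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s1 = self.stage1(x)
        s2 = self.stage2(s1)
        f = self.stage3(s2)
        tokens = self.proj(f).flatten(2).transpose(1, 2)    # (B, N, dim)
        tokens = self.transformer(tokens + self.pos)         # global self-attention
        b, n, d = tokens.shape
        y = tokens.transpose(1, 2).reshape(b, d, self.grid, self.grid)  # back to 2D
        y = F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=False)
        y = self.dec3(torch.cat([y, s2], dim=1))
        y = F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=False)
        y = self.dec2(torch.cat([y, s1], dim=1))
        y = F.interpolate(y, scale_factor=2, mode="bilinear", align_corners=False)
        y = self.dec1(y)
        return self.head(y)                                  # (B, num_classes, H, W)


if __name__ == "__main__":
    model = TinyTransUNetSketch()
    print(model(torch.randn(2, 1, 64, 64)).shape)  # torch.Size([2, 9, 64, 64])
```

The point the sketch illustrates is the division of labour: self-attention runs only over the low-resolution CNN feature map (here (64/8)² = 64 tokens), which keeps it affordable, while the skip-connections reintroduce the higher-resolution features that the Transformer never sees.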

Dice Similarity Coefficient for 0-skip, 1-skip, and 3-skip connections

Among versions with 0, 1, and 3 skip-connections, the version with 3 skip-connections yields the highest average Dice Similarity Coefficient (DSC). The best version of the TransUNet framework therefore has 3 skip-connections between the encoder and the decoder parts.
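The DSC compares a predicted mask A with a ground-truth mask B as 2|A ∩ B| / (|A| + |B|), ranging from 0 (no overlap) to 1 (perfect overlap). A minimal NumPy sketch follows; the function names are hypothetical, and returning 1.0 when both masks are empty is a convention chosen here, not necessarily the one used in the TransUNet evaluation code.

```python
import numpy as np


def dice_coefficient(pred: np.ndarray, target: np.ndarray) -> float:
    """DSC = 2|A ∩ B| / (|A| + |B|) for two binary masks of the same shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    intersection = np.logical_and(pred, target).sum()
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a convention)
    return float(2.0 * intersection / denom)


def mean_dice(pred_labels: np.ndarray, true_labels: np.ndarray, num_classes: int) -> float:
    """Average DSC over foreground classes 1..num_classes-1 (class 0 = background)."""
    scores = [dice_coefficient(pred_labels == c, true_labels == c) for c in range(1, num_classes)]
    return float(np.mean(scores))


if __name__ == "__main__":
    a = np.array([[1, 1, 0, 0]])
    b = np.array([[1, 0, 0, 0]])
    print(dice_coefficient(a, b))  # one shared pixel out of 2 + 1 pixels: 2/3
```

For a multi-organ label map, the per-organ scores are computed one class at a time and then averaged, which is what an "average DSC" refers to.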

Python Implementation of TransUNet

Pre-trained Vision Transformer checkpoints for the framework are publicly available in Google Cloud Storage in NumPy (.npz) format. Among the available models, we use the R50+ViT-B_16 checkpoint (a ResNet-50 backbone followed by a ViT-Base Transformer with patch size 16), pre-trained on ImageNet-21k. The following steps implement TransUNet in a Python environment.


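The pre-trained checkpoint can be downloaded to the local machine with the following commands, which create a new directory and save the weights into it so that the training code can locate them. The training configuration looks for the weights at ../model/vit_checkpoint/imagenet21k/ relative to the repository folder, so run these commands from inside the TransUNet directory once it has been cloned (see the next step). In a notebook such as Colab or Jupyter, prefix each shell command with !, and use %cd rather than !cd to change directory, because a !cd does not persist to later commands.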

 mkdir -p ../model/vit_checkpoint/imagenet21k &&
 wget -O ../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz https://storage.googleapis.com/vit_models/imagenet21k/R50%2BViT-B_16.npz
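Before training, it can be worth confirming that the downloaded file is intact and loadable. The helper below is a hypothetical sketch using only NumPy's np.load; it makes no assumption about the parameter names stored in the checkpoint and simply lists the first few arrays with their shapes.

```python
import numpy as np


def summarize_checkpoint(path: str, limit: int = 10) -> list:
    """Return (name, shape) pairs for the first `limit` arrays (sorted by name) in an .npz file."""
    with np.load(path) as ckpt:       # NpzFile supports the context-manager protocol
        names = sorted(ckpt.files)    # names of the arrays stored in the archive
        return [(name, ckpt[name].shape) for name in names[:limit]]
```

Usage: pass the path where the .npz file was saved, e.g. `for name, shape in summarize_checkpoint(path_to_npz): print(name, shape)`. A corrupted or truncated download typically fails at np.load.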


The TransUNet source code can be cloned to the local machine from its GitHub repository:

 git clone https://github.com/Beckschen/TransUNet.git


This downloads the repository's Python code for the network, loading the pre-trained weights, training, testing, and dataset handling.

The presence of the files in the directory can be checked with:

 cd TransUNet
 ls -p 



The necessary libraries and packages can be installed with pip from the requirements.txt file included in the cloned repository. Make sure a CUDA-capable GPU is available; in Google Colab, a GPU runtime can be enabled from the Runtime menu (Runtime > Change runtime type).

 cd TransUNet
 pip install -r requirements.txt 

Note that this implementation runs on the PyTorch framework.
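Since the code is PyTorch-based, a quick way to confirm that the GPU runtime is actually visible before launching training is the check below. The function name pick_device is hypothetical; torch.cuda.is_available and torch.cuda.get_device_name are standard PyTorch calls.

```python
import torch


def pick_device() -> torch.device:
    """Return the first CUDA device if PyTorch can see one, else the CPU."""
    if torch.cuda.is_available():
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        return torch.device("cuda:0")
    print("CUDA is not available; training on the CPU will be very slow.")
    return torch.device("cpu")


if __name__ == "__main__":
    device = pick_device()
```

If this reports the CPU in Colab, the runtime type has not been switched to GPU.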


The TransUNet model can be built, initialized from the pre-trained weights, and trained using the following command. Note that the command uses the R50-ViT-B_16 model variant on the Synapse dataset; for a different model variant or dataset, modify the arguments accordingly.

 cd TransUNet/
 CUDA_VISIBLE_DEVICES=0 python train.py --dataset Synapse --vit_name R50-ViT-B_16 

Depending on the device's memory, the batch size can be reduced from the default of 24 to 12 or 6. In that case, the 'base_lr' variable in the code should be decreased linearly with the batch size; with this adjustment, the smaller batch sizes reach similar performance.
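The linear rule keeps the ratio base_lr / batch_size constant. A small sketch of the arithmetic follows; the reference values (batch size 24, learning rate 0.01) are assumptions used as defaults here, so substitute whatever your copy of the training script uses. If the script you run already rescales base_lr when the batch size changes, do not apply the adjustment a second time.

```python
def scaled_base_lr(batch_size: int, ref_batch_size: int = 24, ref_lr: float = 0.01) -> float:
    """Linear scaling rule: learning rate proportional to batch size.

    ref_batch_size and ref_lr are assumed reference values, not read from the repo.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return ref_lr * batch_size / ref_batch_size


if __name__ == "__main__":
    for bs in (24, 12, 6):
        print(bs, scaled_base_lr(bs))  # halving the batch size halves the learning rate
```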


Once trained, the model can be tested on a suitable dataset using the following command. Here, testing is done on the Synapse dataset, as in the original research paper.

 cd TransUNet/
 python test.py --dataset Synapse --vit_name R50-ViT-B_16 

To test the model on your own dataset, save the data in the 'dataset' directory after the same preprocessing: clip the 3D CT images to the intensity range [-125, 275] (Hounsfield units) and normalize them to [0, 1]. For training cases, extract 2D slices from each 3D volume and save them in NumPy format; for testing cases, keep the full 3D volume in h5 format.
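The clipping, normalization and slicing steps can be sketched with NumPy as follows. This is an illustration under stated assumptions, not the repository's preprocessing script: the function names, the (depth, height, width) axis order, the file naming scheme, and the 'image'/'label' key names are all hypothetical and should be matched to the data loader you actually use. Writing the h5 test volumes would additionally need a library such as h5py and is not shown.

```python
import os

import numpy as np


def normalize_ct(volume: np.ndarray, lo: float = -125.0, hi: float = 275.0) -> np.ndarray:
    """Clip CT intensities to [lo, hi] and rescale them linearly to [0, 1]."""
    clipped = np.clip(volume.astype(np.float32), lo, hi)
    return (clipped - lo) / (hi - lo)


def save_training_slices(volume: np.ndarray, label: np.ndarray, case_name: str, out_dir: str) -> list:
    """Save each axial slice of a (depth, H, W) volume as its own .npz file.

    Axis order, file names and the 'image'/'label' keys are assumptions.
    """
    os.makedirs(out_dir, exist_ok=True)
    image = normalize_ct(volume)
    paths = []
    for i in range(image.shape[0]):
        path = os.path.join(out_dir, f"{case_name}_slice{i:03d}.npz")
        np.savez(path, image=image[i], label=label[i])
        paths.append(path)
    return paths
```

With this window, an intensity of 75 maps to (75 + 125) / 400 = 0.5, while everything below -125 becomes 0 and everything above 275 becomes 1.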

The presence of the dataset in the right directory can be confirmed with the 'tree' utility, which can be installed and then run from the folder that holds the data:

 sudo apt-get install tree
 tree



Performance of TransUNet

The TransUNet framework is evaluated on the Synapse multi-organ segmentation dataset, which consists of abdominal CT scans annotated for 8 abdominal organs: aorta, gallbladder, left kidney, right kidney, liver, pancreas, spleen, and stomach. In the figures, the ground-truth segments are color-coded by organ. TransUNet achieves a higher average Dice Similarity Coefficient (DSC) than the compared state-of-the-art methods, and the lowest average Hausdorff distance (HD). HD measures the largest disagreement between the boundaries of the predicted and ground-truth segments, so lower values indicate less mis-segmentation.
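The Hausdorff distance between two point sets is the larger of the two directed distances: the farthest any predicted point lies from its nearest ground-truth point, and vice versa. The brute-force NumPy sketch below (function name hypothetical) uses every foreground pixel and holds all pairwise distances in memory, so it only suits small masks; for large volumes, scipy.spatial.distance.directed_hausdorff is a more efficient building block. Published evaluations often restrict the computation to surface points or report a 95th-percentile variant, which this plain maximum does not do.

```python
import numpy as np


def hausdorff_distance(pred: np.ndarray, target: np.ndarray) -> float:
    """Symmetric Hausdorff distance (in pixel/voxel units) between two binary masks."""
    p = np.argwhere(np.asarray(pred, dtype=bool)).astype(np.float64)
    t = np.argwhere(np.asarray(target, dtype=bool)).astype(np.float64)
    if len(p) == 0 or len(t) == 0:
        raise ValueError("Hausdorff distance is undefined when a mask is empty")
    # dist[i, j] = Euclidean distance between predicted point i and target point j
    dist = np.sqrt(((p[:, None, :] - t[None, :, :]) ** 2).sum(axis=-1))
    forward = dist.min(axis=1).max()   # farthest predicted point from the target
    backward = dist.min(axis=0).max()  # farthest target point from the prediction
    return float(max(forward, backward))


if __name__ == "__main__":
    a = np.zeros((5, 5), dtype=bool)
    a[0, 0] = True
    b = np.zeros((5, 5), dtype=bool)
    b[3, 4] = True
    print(hausdorff_distance(a, b))  # 5.0 (a 3-4-5 triangle)
```

Unlike the DSC, a single stray predicted pixel far from the organ can dominate the HD, which is why the two metrics are reported together.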

Qualitative comparison of TransUNet with existing state-of-the-art methods.

TransUNet outperforms existing state-of-the-art methods such as Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation (V-Net), Domain Adaptive Relational Reasoning (DARR), Convolutional Networks for Biomedical Image Segmentation (U-Net), and Attention Gated Networks (AttnUNet) in terms of both average Dice Similarity Coefficient and average Hausdorff distance.

Note: Images other than the code outputs are obtained from the original research paper of TransUNet


Rajkumar Lakshmanamoorthy
A geek in Machine Learning with a Master's degree in Engineering and a passion for writing and exploring new things. Loves reading novels, cooking, practicing martial arts, and occasionally writing novels and poems.
