Guide To ResNeSt: A Better ResNet With The Same Costs

The ResNeSt architecture combines channel-wise attention with multi-path representation in a single unified Split-Attention block.

Convolutional neural networks have largely dominated the computer vision domain, but in the last few years, feature-map attention architectures like SE-Net and SK-Net have started to assert dominance. In their paper “ResNeSt: Split-Attention Networks”, Hang Zhang, Chongruo Wu, et al. proposed a new ResNet variant that combines the best of both worlds. The ResNeSt architecture brings together channel-wise attention and multi-path representation in a single unified Split-Attention block. It learns cross-channel feature correlations while preserving independent representations in the meta structure.

Architecture & Approach

Split-Attention block

ResNeSt introduces the Split-Attention block, which enables feature-map attention across different feature-map groups. The Split-Attention block consists of feature-map grouping followed by split-attention operations. As in the ResNeXt block, the input feature map is divided into several groups, and the number of feature-map groups is controlled by a cardinality hyperparameter K. ResNeSt adds a new radix hyperparameter R that indicates the number of splits within a cardinal group, so the total number of feature groups is G = KR.

A combined representation for each cardinal group is obtained by fusing its splits with an element-wise summation: the representation of the k-th cardinal group, Û^k ∈ R^(H×W×C/K) for k ∈ 1, 2, …, K, is the sum of the R splits assigned to it, where H, W and C are the block output feature-map sizes (the formulas are sketched below).

Global contextual information with embedded channel-wise statistics, s^k ∈ R^(C/K), is then gathered with global average pooling across the spatial dimensions; its c-th component is the spatial mean of the c-th channel of Û^k.

A weighted aggregation of the cardinal group representation, V^k ∈ R^(H×W×C/K), is produced using channel-wise soft attention: each channel of V^k is a weighted combination over the splits, where a^k_i(c) denotes the soft assignment weight given by a softmax over the R splits. The mapping G^c_i determines the weight of each split for channel c based on the global context representation s^k.

The cardinal group representations are then concatenated along the channel dimension: V = Concat{V^1, V^2, …, V^K}. The final output Y of the Split-Attention block is produced using a shortcut connection: Y = V + X.
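
Following the formulation in the ResNeSt paper, these steps can be written compactly in LaTeX as below; here \mathcal{G}^c_i is the mapping that produces the attention logit of split i for channel c from the pooled context s^k:

 \hat{U}^k = \sum_{j=R(k-1)+1}^{Rk} U_j, \qquad \hat{U}^k \in \mathbb{R}^{H \times W \times C/K}
 s^k_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} \hat{U}^k_c(i, j)
 V^k_c = \sum_{i=1}^{R} a^k_i(c)\, U_{R(k-1)+i,\,c}
 a^k_i(c) = \frac{\exp\left(\mathcal{G}^c_i(s^k)\right)}{\sum_{j=1}^{R} \exp\left(\mathcal{G}^c_j(s^k)\right)} \;\; (R > 1), \qquad a^k_i(c) = \frac{1}{1 + \exp\left(-\mathcal{G}^c_i(s^k)\right)} \;\; (R = 1)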

[Figure: ResNeSt block, cardinality-major implementation]
Radix-Major Implementation

This layout, in which feature-map groups with the same cardinal index physically reside next to each other, is called the cardinality-major implementation. It is straightforward but difficult to modularize and accelerate using standard CNN operators. For this reason, an equivalent radix-major implementation is used. In this layout, the input feature map is first divided into RK groups, each with a cardinality index and a radix index, and the groups with the same radix index reside next to each other. A summation across the different splits is then performed so that the feature-map groups with the same cardinality index but different radix indices are fused. A global pooling layer aggregates over the spatial dimensions while keeping the channel dimension separated, which is equivalent to performing global pooling on each individual cardinal group and then concatenating the results.

[Figure: ResNeSt block, radix-major implementation]

Two consecutive fully connected layers with the number of groups equal to the cardinality are added after the pooling layer to predict each split’s attention weights. The use of grouped FC layers makes it identical to applying each pair of FC layers separately on top of each cardinal group. In this implementation, the first 1×1 convolutional layers can be unified into one layer, and the 3×3 convolutional layers can be implemented using a single grouped convolution with RK groups. The ResNeSt Split-Attention block is therefore modularized using standard CNN operators.
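
To make the radix-major layout concrete, here is a minimal PyTorch sketch of a Split-Attention module built from the steps above; the class name SplitAttention and the default hyperparameters are illustrative, not the official resnest package API.

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 class SplitAttention(nn.Module):
     """Sketch of a radix-major Split-Attention block (illustrative only)."""
     def __init__(self, channels, radix=2, cardinality=1, reduction=4):
         super().__init__()
         self.radix, self.cardinality = radix, cardinality
         inter_channels = max(channels * radix // reduction, 32)
         # one grouped 3x3 convolution produces all R*K feature-map groups at once
         self.conv = nn.Conv2d(channels, channels * radix, kernel_size=3, padding=1,
                               groups=cardinality * radix, bias=False)
         self.bn = nn.BatchNorm2d(channels * radix)
         # two consecutive grouped FC layers (1x1 convs) predict the per-split attention
         self.fc1 = nn.Conv2d(channels, inter_channels, 1, groups=cardinality)
         self.fc2 = nn.Conv2d(inter_channels, channels * radix, 1, groups=cardinality)

     def forward(self, x):
         batch = x.size(0)
         u = F.relu(self.bn(self.conv(x)))               # (B, C*R, H, W)
         splits = torch.chunk(u, self.radix, dim=1)      # R tensors of shape (B, C, H, W)
         gap = sum(splits)                               # fuse groups sharing a cardinal index
         gap = F.adaptive_avg_pool2d(gap, 1)             # global pooling over H and W
         attn = self.fc2(F.relu(self.fc1(gap)))          # (B, C*R, 1, 1)
         # softmax over the radix dimension within each cardinal group (r-SoftMax)
         attn = attn.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
         attn = F.softmax(attn, dim=1).reshape(batch, -1, 1, 1)
         weights = torch.chunk(attn, self.radix, dim=1)  # R attention maps of shape (B, C, 1, 1)
         return sum(w * s for w, s in zip(weights, splits))

 # quick shape check: (2, 64, 32, 32) in -> (2, 64, 32, 32) out
 # y = SplitAttention(64, radix=2, cardinality=1)(torch.randn(2, 64, 32, 32))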

Network Tweaks

It is essential to preserve spatial information for transfer learning on dense prediction tasks such as detection or segmentation. Generally, ResNet implementations apply the strided convolution at the 3×3 layer instead of the 1×1 layer to preserve such information. Convolutional layers require handling feature-map boundaries with zero-padding strategies, which is suboptimal when transferring to other dense prediction tasks. Instead of using a strided convolution at the transitioning block, ResNeSt uses an average pooling layer with a kernel size of 3×3. It also adopts two modifications introduced by ResNet-D: it replaces the 7×7 convolutional layer with three consecutive 3×3 convolutional layers, and it adds a 2×2 average pooling layer to the shortcut connection, before the 1×1 convolutional layer, for the transitioning blocks with a stride of two.
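
A minimal PyTorch sketch of these tweaks follows; the function names and the stem_width default are illustrative, not taken from the official implementation.

 import torch.nn as nn

 # Main path of a transitioning block: a 3x3 average pooling (stride 2) performs the
 # downsampling, so the following 3x3 convolution can keep stride 1.
 def transition_main_path(in_ch, out_ch):
     return nn.Sequential(
         nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
         nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
         nn.BatchNorm2d(out_ch),
     )

 # ResNet-D style shortcut: a 2x2 average pooling before the 1x1 convolution.
 def transition_shortcut(in_ch, out_ch):
     return nn.Sequential(
         nn.AvgPool2d(kernel_size=2, stride=2),
         nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, bias=False),
         nn.BatchNorm2d(out_ch),
     )

 # Deep stem: three consecutive 3x3 convolutions replace the single 7x7 convolution.
 def deep_stem(stem_width=32):
     return nn.Sequential(
         nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False),
         nn.BatchNorm2d(stem_width), nn.ReLU(inplace=True),
         nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
         nn.BatchNorm2d(stem_width), nn.ReLU(inplace=True),
         nn.Conv2d(stem_width, stem_width * 2, 3, stride=1, padding=1, bias=False),
         nn.BatchNorm2d(stem_width * 2), nn.ReLU(inplace=True),
     )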

Image Classification with ResNeSt

  1. Install PyTorch, torchvision (used for preprocessing below) and fvcore.
 pip install torch torchvision
 pip install fvcore
  2. Download the pre-trained ResNeSt model from Torch Hub and set it to eval mode for making inferences.
 import torch

 model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
 model.eval()
  3. Get the image(s) for making inferences. You can get sample ImageNet images here.

 filename = "n01491361_tiger_shark.jfif"
  4. Download the ImageNet class labels.

wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

  5. Read the image, process it to match the input specification of ResNeSt, and create a mini-batch for inference.
 from PIL import Image
 from torchvision import transforms

 input_image = Image.open(filename)
 preprocess = transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])

 input_tensor = preprocess(input_image)
 input_batch = input_tensor.unsqueeze(0)  
  6. Switch to GPU if available, and make the inference(s).
 if torch.cuda.is_available():
     input_batch = input_batch.to('cuda')
     model.to('cuda')

 with torch.no_grad():
     output = model(input_batch) 
  7. Apply softmax to the confidence scores of the 1000 ImageNet classes to get the probabilities. Print the top 3 (k = 3) categories.
 probabilities = torch.nn.functional.softmax(output[0], dim=0)
 with open("imagenet_classes.txt", "r") as f:
     categories = [s.strip() for s in f.readlines()]
 top3_prob, top3_catid = torch.topk(probabilities, 3)
 for i in range(top3_prob.size(0)):
     print(categories[top3_catid[i]], top3_prob[i].item()) 

For the Colab notebook, refer to the above implementation.
