Convolutional neural networks have largely dominated the computer vision domain, but in the last few years, feature-map attention architectures like SE-Net and SK-Net have started to assert dominance. In their paper “ResNeSt: Split-Attention Networks”, Hang Zhang, Chongruo Wu, et al. proposed a new ResNet variant that combines the best of both worlds. The ResNeSt architecture combines channel-wise attention with multi-path representation in a single unified Split-Attention block. It learns cross-channel feature correlations while preserving independent representation in the meta structure.
Architecture & Approach
Split-Attention block
ResNeSt introduces the Split-Attention block, which enables feature-map attention across different feature-map groups. The block consists of a feature-map group operation and a split-attention operation. As in the ResNeXt block, the features are divided into several groups, and the number of feature-map groups is controlled by a cardinality hyperparameter K. ResNeSt adds a new radix hyperparameter R that indicates the number of splits within a cardinal group, so the total number of feature groups is G = KR.
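The bookkeeping behind the two hyperparameters can be made concrete with a small sketch. The helper names `total_groups` and `cardinal_members` are hypothetical, and the 1-based, cardinality-major numbering (the R splits of one cardinal group sit next to each other) is an assumption taken from the paper's indexing rather than from any particular code base.

```python
def total_groups(cardinality, radix):
    """Total number of feature-map groups: G = K * R."""
    return cardinality * radix


def cardinal_members(k, radix):
    """1-based indices of the feature groups that belong to cardinal group k.

    Assumes cardinality-major numbering, so cardinal group k owns the
    consecutive groups R*(k-1)+1 ... R*k.
    """
    start = radix * (k - 1) + 1
    return list(range(start, radix * k + 1))


# Example: K = 2 cardinal groups, R = 2 splits each.
print(total_groups(2, 2))      # 4
print(cardinal_members(1, 2))  # [1, 2]
print(cardinal_members(2, 2))  # [3, 4]
```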
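A combined representation for each cardinal group is obtained by fusing its splits with an element-wise summation. The k-th cardinal group is represented as:
$$\hat{U}^k = \sum_{j=R(k-1)+1}^{Rk} U_j$$
Here $\hat{U}^k \in \mathbb{R}^{H \times W \times C/K}$ for $k \in \{1, 2, \dots, K\}$, and $H$, $W$ and $C$ are the block output feature-map sizes. Global contextual information with embedded channel-wise statistics is gathered with global average pooling across the spatial dimensions, giving $s^k \in \mathbb{R}^{C/K}$. The c-th component is calculated as:
$$s^k_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} \hat{U}^k_c(i, j)$$
A weighted aggregation of the cardinal group representation $V^k \in \mathbb{R}^{H \times W \times C/K}$ is done using channel-wise soft attention. The c-th channel is then calculated as:
$$V^k_c = \sum_{i=1}^{R} a^k_i(c)\, U_{R(k-1)+i}$$
Here $a^k_i(c)$ denotes the soft assignment weight given by:
$$a^k_i(c) = \begin{cases} \dfrac{\exp\left(G^c_i(s^k)\right)}{\sum_{j=1}^{R} \exp\left(G^c_j(s^k)\right)} & \text{if } R > 1, \\[2ex] \dfrac{1}{1 + \exp\left(-G^c_i(s^k)\right)} & \text{if } R = 1. \end{cases}$$
The mapping $G^c_i$ determines the weight of each split for channel $c$ based on the global context representation $s^k$. Cardinal group representations are then concatenated along the channel dimension: $V = \text{Concat}\{V^1, V^2, \dots, V^K\}$. The final output $Y$ of the Split-Attention block is produced using a shortcut connection: $Y = V + X$.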
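The steps above (fuse the splits, pool, compute assignment weights, aggregate) can be sketched for a single cardinal group in plain Python. This is a minimal illustration and not the reference implementation: feature maps are nested lists, and `gate` is a hypothetical stand-in for the learned mapping G. As in the paper, the weights are a softmax across the splits when R > 1 and a sigmoid when R = 1.

```python
import math


def split_attention(splits, gate):
    """Split attention for one cardinal group.

    splits: list of R feature maps; each is a list of channels, and each
            channel is a flat list of H*W activations.
    gate:   hypothetical stand-in for the learned mapping G; takes the pooled
            vector (one value per channel) and returns R lists of logits,
            one logit per channel.
    """
    radix = len(splits)
    num_channels = len(splits[0])
    num_pixels = len(splits[0][0])

    # Fuse the splits by element-wise summation.
    fused = [
        [sum(split[c][p] for split in splits) for p in range(num_pixels)]
        for c in range(num_channels)
    ]
    # Global average pooling over the spatial positions.
    pooled = [sum(channel) / num_pixels for channel in fused]

    logits = gate(pooled)  # radix x num_channels
    output = []
    for c in range(num_channels):
        if radix > 1:
            # Softmax across the R splits, computed separately per channel.
            peak = max(logits[i][c] for i in range(radix))
            exps = [math.exp(logits[i][c] - peak) for i in range(radix)]
            weights = [e / sum(exps) for e in exps]
        else:
            # A single split falls back to a sigmoid gate.
            weights = [1.0 / (1.0 + math.exp(-logits[0][c]))]
        # Weighted combination of the splits for channel c.
        output.append([
            sum(weights[i] * splits[i][c][p] for i in range(radix))
            for p in range(num_pixels)
        ])
    return output


# R = 2 splits, one channel, two spatial positions.
splits = [[[1.0, 3.0]], [[5.0, 7.0]]]
equal_gate = lambda pooled: [[0.0], [0.0]]  # equal logits give weights 0.5 / 0.5
print(split_attention(splits, equal_gate))  # [[3.0, 5.0]]
```

With equal logits, the block reduces to a plain average of the splits; a trained gate would instead shift weight toward the more useful split on a per-channel basis.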
Radix-Major Implementation
The layout described so far, in which feature-map groups with the same cardinal index reside next to each other physically, is called the cardinality-major implementation. It is straightforward but difficult to modularize and accelerate using standard CNN operators. To address this, an equivalent radix-major implementation is used. In this layout, the input feature-map is first divided into RK groups, each of which has a cardinality-index and a radix-index. The groups with the same radix-index reside next to each other. A summation across the different splits then fuses the feature-map groups that share a cardinality-index but differ in radix-index. A global pooling layer aggregates over the spatial dimension while keeping the channel dimension separated. This is equivalent to applying global pooling to each individual cardinal group and then concatenating the results.
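The difference between the two layouts is only the order in which the RK groups are stored. A minimal sketch of the index arithmetic, assuming 0-based cardinality index k and radix index r (the function names are hypothetical):

```python
def cardinality_major_index(k, r, radix):
    """Position of group (k, r) when the splits of one cardinal group are adjacent."""
    return k * radix + r


def radix_major_index(k, r, cardinality):
    """Position of group (k, r) when groups with the same radix index are adjacent."""
    return r * cardinality + k


def to_radix_major(groups, cardinality, radix):
    """Reorder a cardinality-major list of groups into radix-major order."""
    return [
        groups[cardinality_major_index(k, r, radix)]
        for r in range(radix)
        for k in range(cardinality)
    ]


# K = 2 cardinal groups, R = 3 splits, stored cardinality-major.
groups = ["k0r0", "k0r1", "k0r2", "k1r0", "k1r1", "k1r2"]
print(to_radix_major(groups, cardinality=2, radix=3))
# ['k0r0', 'k1r0', 'k0r1', 'k1r1', 'k0r2', 'k1r2']
```

In radix-major order, each contiguous block of K groups holds one split of every cardinal group, so summing the R blocks element-wise fuses the splits of all cardinal groups at once.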
Two consecutive fully connected (FC) layers, with the number of groups equal to the cardinality, are added after the pooling layer to predict each split’s attention weights. Using grouped FC layers is equivalent to applying each pair of FC layers separately on top of each cardinal group. In this implementation, the first 1 × 1 convolutional layers can be unified into one layer. The 3 × 3 convolutional layers can be implemented using a single grouped convolution with RK groups. Therefore, the ResNeSt Split-Attention block is now modularized using standard CNN operators.
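The claim that a grouped FC layer matches separate per-group FC layers can be checked with a toy example. The sketch below uses plain lists and hypothetical helpers `dense` and `grouped_dense`; biases and activations are omitted, and in a real network these layers would typically be grouped 1 × 1 convolutions.

```python
def dense(x, weight):
    """Fully connected layer without bias: y = W x."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def grouped_dense(x, group_weights):
    """Grouped FC layer: split x into equal chunks, one per group, and apply
    each group's weight matrix to its own chunk only."""
    groups = len(group_weights)
    chunk = len(x) // groups
    out = []
    for g, weight in enumerate(group_weights):
        out.extend(dense(x[g * chunk:(g + 1) * chunk], weight))
    return out


# Pooled statistics of K = 2 cardinal groups, two channels each.
x = [1.0, 2.0, 3.0, 4.0]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # FC weights for cardinal group 0
    [[2.0, 0.0], [0.0, 2.0]],  # FC weights for cardinal group 1
]

grouped = grouped_dense(x, weights)
separate = dense(x[:2], weights[0]) + dense(x[2:], weights[1])
print(grouped)              # [1.0, 2.0, 6.0, 8.0]
print(grouped == separate)  # True
```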
Network Tweaks
It is essential to preserve spatial information for transfer learning on dense prediction tasks such as detection or segmentation. ResNet implementations generally apply the strided convolution at the 3×3 layer instead of the 1×1 layer to better preserve such information. However, convolutional layers must handle feature-map boundaries with zero-padding strategies, which is often suboptimal when transferring to other dense prediction tasks. Instead of using a strided convolution at the transitioning block, ResNeSt uses an average pooling layer with a kernel size of 3×3. It also adopts two modifications introduced by ResNet-D: it replaces the first 7×7 convolutional layer with three consecutive 3×3 convolutional layers, and it adds a 2×2 average pooling layer to the shortcut connection before the 1×1 convolutional layer for the transitioning blocks with a stride of two.
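A quick way to see that these tweaks leave the feature-map sizes unchanged is to run the standard output-size formula for convolution and pooling layers. The input sizes (224 and 56), the paddings, and the stride pattern (2, 1, 1) for the three 3×3 stem convolutions are assumptions chosen to mirror common ResNet configurations, not values stated in this article.

```python
def output_size(size, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# One 7x7 stride-2 stem convolution versus three stacked 3x3 convolutions.
single = output_size(224, kernel=7, stride=2, padding=3)
stacked = 224
for stride in (2, 1, 1):
    stacked = output_size(stacked, kernel=3, stride=stride, padding=1)
print(single, stacked)  # 112 112

# 3x3 average pooling (stride 2, padding 1) in place of a strided convolution.
print(output_size(56, kernel=3, stride=2, padding=1))  # 28
# 2x2 average pooling (stride 2) on the shortcut path.
print(output_size(56, kernel=2, stride=2))  # 28
```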
Image Classification with ResNeSt
- Install PyTorch, torchvision (used below for preprocessing), and fvcore.
pip install torch torchvision
pip install fvcore
- Download the pre-trained ResNeSt model from Torch Hub and set it to eval mode for inference.
import torch

# Load ResNeSt-50 with ImageNet weights from the authors' repository.
model = torch.hub.load('zhanghang1989/ResNeSt', 'resnest50', pretrained=True)
model.eval()
- Get the image(s) for making inferences. You can get sample ImageNet images here.
filename = "n01491361_tiger_shark.jfif"
- Download the ImageNet class labels
wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
- Read the image, preprocess it to match the input specification of ResNeSt, and create a mini-batch for inference.
from PIL import Image
from torchvision import transforms

# Convert to RGB so that grayscale or RGBA files still yield three channels.
input_image = Image.open(filename).convert("RGB")
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # add the batch dimension
- Switch to GPU if available, and run inference.
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
- Apply softmax to the confidence scores of the 1000 ImageNet classes to get the probabilities. Print the top 3 categories (change the value of k passed to topk to print more).
probabilities = torch.nn.functional.softmax(output[0], dim=0)

with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

top3_prob, top3_catid = torch.topk(probabilities, 3)
for i in range(top3_prob.size(0)):
    print(categories[top3_catid[i]], top3_prob[i].item())
The above implementation is also available as a Colab notebook.