How Point Transformer Excels In 3D Image Processing

Point Transformer sets a new state of the art on several public 3D datasets, outperforming the strongest existing models

Transformers have outperformed convolutional and recurrent neural networks in many applications across domains, including natural language processing, image classification and medical image segmentation. Point Transformer adds further evidence by establishing state-of-the-art performance in 3D image (point cloud) processing. It handles multiple tasks, including 3D semantic segmentation, 3D shape classification and 3D object part segmentation.

3D images are quite different from, and more complex than, 2D images. A 2D image is a collection of pixels arranged on a regular 2D grid, whereas a 3D image is typically a point cloud: an unordered set of 3D points embedded in continuous space. This difference makes standard computer vision deep learning networks unsuitable for 3D image processing. A standard convolutional layer slides a kernel over the regular pixel grid of a 2D image, but this operator cannot be applied directly to sparse, irregular clouds of 3D points.

Applications of 3D data such as augmented reality, virtual reality, autonomous vehicles and robot navigation are growing rapidly, creating a strong need for powerful yet efficient networks that can be deployed in production environments. Many recent works have adapted convolutional networks to point clouds through voxelization, sparse convolution, continuous convolution and graph networks. However, these models do not fully meet the compute-efficiency requirements.
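To see why voxelization is a workaround with a cost, here is a minimal occupancy-grid sketch. It assumes a simple scheme (scale the cloud to its bounding box, then mark every voxel that contains at least one point); the helper name `voxelize` and the 64³ grid are illustrative choices, not the method of any particular paper. For points sampled on a surface, almost all voxels end up empty, which is the sparsity that dense 3D convolutions waste compute on.

```python
import torch

def voxelize(points, grid_size):
    """Map an (N, 3) point cloud to a boolean occupancy grid of shape (G, G, G)."""
    # scale each axis into [0, 1] using the cloud's bounding box
    mins = points.min(dim=0).values
    maxs = points.max(dim=0).values
    scaled = (points - mins) / (maxs - mins).clamp(min=1e-9)
    # integer voxel index per point; the maximum coordinate maps to the last voxel
    idx = (scaled * grid_size).long().clamp(max=grid_size - 1)
    grid = torch.zeros(grid_size, grid_size, grid_size, dtype=torch.bool)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = True
    return grid

# 2,048 points on the surface of a unit sphere, standing in for a scanned object's surface
torch.manual_seed(0)
points = torch.randn(2048, 3)
points = points / points.norm(dim=1, keepdim=True)

grid = voxelize(points, grid_size=64)
occupancy = grid.float().mean().item()
print(grid.shape, f"{occupancy:.4%} of voxels occupied")
```

With at most 2,048 occupied cells out of 262,144, well under 1% of the grid holds any data.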

Hengshuang Zhao and Philip Torr of the University of Oxford, Li Jiang and Jiaya Jia of the Chinese University of Hong Kong, and Vladlen Koltun of Intel Labs have built self-attention-based networks for 3D image processing. Their model, the Point Transformer, sets new state-of-the-art results on several public 3D datasets, outperforming the strongest existing models.

Point Transformer tasks
A Point Transformer handles three major 3D image tasks: classification, semantic segmentation and part segmentation

How does Point Transformer work?

Transformers and their variants, with the self-attention mechanism at their core, now lead many machine learning fields through powerful models such as Vision Transformer (image classification), TransUNet (medical image segmentation), ENCONTER (language modeling) and CTRL (controlled language generation). Many research efforts have sought to bring the self-attention mechanism to new domains and tasks.

The self-attention mechanism in a Transformer is essentially a set operator: it places no constraint on the number of input elements, and permuting the inputs simply permutes the outputs. Because a 3D point cloud is itself a set of points, self-attention is a natural fit. The Point Transformer layer applies self-attention and pointwise operations to 3D point clouds, which lets it perform 3D scene understanding directly on sparse point clouds. A Point Transformer network is built by stacking these layers and can serve as a general backbone for 3D scene understanding applications.
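The set property can be checked numerically. The sketch below uses plain scaled dot-product self-attention (the standard Transformer form, not Point Transformer's vector attention) with random, untrained weights; the function name `self_attention` and the sizes are illustrative choices. Shuffling the input points shuffles the output rows in exactly the same way, and the same weights apply to a set of any size.

```python
import torch

def self_attention(x, w_q, w_k, w_v):
    """Scaled dot-product self-attention over a set of N feature vectors x of shape (N, d)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    weights = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)  # (N, N), rows sum to 1
    return weights @ v

torch.manual_seed(0)
n, d = 16, 8
x = torch.randn(n, d, dtype=torch.float64)
w_q, w_k, w_v = (torch.randn(d, d, dtype=torch.float64) for _ in range(3))

out = self_attention(x, w_q, w_k, w_v)

# permuting the input set permutes the output rows in the same way
perm = torch.randperm(n)
out_perm = self_attention(x[perm], w_q, w_k, w_v)
print(torch.allclose(out[perm], out_perm))  # True

# the same weights work for a set of any size
print(self_attention(x[:5], w_q, w_k, w_v).shape)  # torch.Size([5, 8])
```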

A typical Point Transformer layer

The input to a Point Transformer layer is a 3D point together with its k nearest neighbors (their features and coordinates). Pointwise linear projections transform the features, producing a query for the centre point and a key and a value for each neighbor, while a multi-layer perceptron encodes the relative position between the centre point and each neighbor. The position encoding is added to the difference between the query and each key, and the result passes through a further MLP and a normalization function (usually softmax, though this may vary with the application) to give attention weights. Finally, these weights are multiplied element-wise with the values, to which the position encoding is also added, and summed over the neighbors to give the layer's output for that point.
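For reference, the vector self-attention of the Point Transformer paper can be written as follows, as we read it. Here $\varphi$, $\psi$ and $\alpha$ are the pointwise linear projections (query, key and value), $\theta$ is the positional-encoding MLP, $\gamma$ is the attention MLP, $\rho$ is the normalization (softmax over the neighbors), $\odot$ is element-wise multiplication, $p$ denotes point coordinates and $\mathcal{X}(i)$ is the set of k nearest neighbors of point $i$. In the PyTorch source code reproduced in this article, these correspond to `q`, `k`, `v`, `pos_mlp` and `attn_mlp`.

```latex
\delta_{ij} = \theta(p_i - p_j), \qquad
y_i = \sum_{x_j \in \mathcal{X}(i)}
      \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta_{ij})\big)
      \odot \big(\alpha(x_j) + \delta_{ij}\big)
```

Using the subtraction $\varphi(x_i) - \psi(x_j)$ instead of a dot product gives one attention weight per feature channel rather than a single scalar per neighbor.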

Python Implementation of Point Transformer

A PyTorch implementation of the Point Transformer layer is available as the PyPI package point-transformer-pytorch and can be installed with pip. Its requirements are Python 3.7+, PyTorch 1.6+ and einops 0.3+.

!pip install point-transformer-pytorch 

Import the necessary libraries and modules.

 import torch
 from point_transformer_pytorch import PointTransformerLayer 

The following code creates a Point Transformer layer and applies it to a small random point cloud.

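 attn = PointTransformerLayer(
     dim = 128,                  # feature dimension of each point
     pos_mlp_hidden_dim = 64,    # hidden size of the positional-encoding MLP
     attn_mlp_hidden_mult = 4    # attention MLP hidden size = dim * 4
 )
 feats = torch.randn(1, 16, 128)    # (batch, points, dim) point features
 pos = torch.randn(1, 16, 3)        # (batch, points, xyz) point coordinates
 mask = torch.ones(1, 16).bool()    # True marks a valid (non-padding) point
 attn(feats, pos, mask = mask) # (1, 16, 128)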


The number of nearest neighbors is controlled by the num_neighbors argument of PointTransformerLayer. In the following example it is set to 16, so for each point the layer attends only to its 16 nearest points in 3D space.

 attn = PointTransformerLayer(
     dim = 128,
     pos_mlp_hidden_dim = 64,
     attn_mlp_hidden_mult = 4,
     num_neighbors = 16          # only the 16 nearest neighbors are attended to for each point
 )
 feats = torch.randn(1, 2048, 128)
 pos = torch.randn(1, 2048, 3)
 mask = torch.ones(1, 2048).bool()
 attn(feats, pos, mask = mask) # (1, 2048, 128)
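The neighbor selection behind num_neighbors can be sketched on its own. Assuming the same brute-force strategy as the layer's source code (all pairwise distances, then `topk` with `largest=False`), the hypothetical helper `knn_indices` below returns each point's k nearest neighbors and then gathers their coordinates. Each point is its own nearest neighbor at distance 0.

```python
import torch

def knn_indices(pos, k):
    """Distances and indices (B, N, k) of the k nearest neighbors of every point in pos (B, N, 3)."""
    # pairwise relative positions p_i - p_j, shape (B, N, N, 3), as in the layer code
    rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
    rel_dist = rel_pos.norm(dim=-1)                        # (B, N, N) Euclidean distances
    dist, idx = rel_dist.topk(k, dim=-1, largest=False)    # k smallest, sorted ascending
    return dist, idx

pos = torch.randn(2, 256, 3)
dist, idx = knn_indices(pos, k=16)

# gather neighbor coordinates with advanced indexing: (B, N, k, 3)
batch = torch.arange(pos.shape[0])[:, None, None]
neighbours = pos[batch, idx]
print(idx.shape, neighbours.shape)  # torch.Size([2, 256, 16]) torch.Size([2, 256, 16, 3])
```

Like the layer itself, this builds an N × N distance matrix, so memory grows quadratically with the number of points; that is fine for a sketch but becomes the bottleneck for large scenes.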


The source code of PointTransformerLayer is shown below. It begins by importing the necessary packages.

 import torch
 from torch import nn, einsum
 from einops import repeat 

The layer relies on a few helper functions:

 def exists(val):
     # True if an optional argument was supplied
     return val is not None
 def max_value(t):
     # largest finite value of t's dtype, used as a fill value when masking
     return torch.finfo(t.dtype).max
 def batched_index_select(values, indices, dim = 1):
     # gather entries of `values` along `dim` with per-batch `indices`,
     # broadcasting over any trailing feature dimensions of `values`
     value_dims = values.shape[(dim + 1):]
     values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
     indices = indices[(..., *((None,) * len(value_dims)))]
     indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
     value_expand_len = len(indices_shape) - (dim + 1)
     values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
     value_expand_shape = [-1] * len(values.shape)
     expand_slice = slice(dim, (dim + value_expand_len))
     value_expand_shape[expand_slice] = indices.shape[expand_slice]
     values = values.expand(*value_expand_shape)
     dim += value_expand_len
     return values.gather(dim, indices)

Finally, the layer itself is defined as a subclass of PyTorch's nn.Module. Its forward method computes the attention weights, applies masking and aggregates the neighbor features.

 class PointTransformerLayer(nn.Module):
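     def __init__(
         self,
         *,
         dim,
         pos_mlp_hidden_dim = 64,
         attn_mlp_hidden_mult = 4,
         num_neighbors = None
     ):
         super().__init__()
         self.num_neighbors = num_neighbors
         # one linear projection producing queries, keys and values
         self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
         # positional-encoding MLP: relative xyz offset -> dim-dimensional embedding
         self.pos_mlp = nn.Sequential(
             nn.Linear(3, pos_mlp_hidden_dim),
             nn.ReLU(),
             nn.Linear(pos_mlp_hidden_dim, dim)
         )
         # attention MLP: one weight per feature channel (vector attention)
         self.attn_mlp = nn.Sequential(
             nn.Linear(dim, dim * attn_mlp_hidden_mult),
             nn.ReLU(),
             nn.Linear(dim * attn_mlp_hidden_mult, dim),
         )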
     def forward(self, x, pos, mask = None):
         n, num_neighbors = x.shape[1], self.num_neighbors
         # get queries, keys, values
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
         # calculate relative positional embeddings
         rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
         rel_pos_emb = self.pos_mlp(rel_pos)
         # use subtraction of queries to keys. i suppose this is a better inductive bias for point clouds than dot product
         qk_rel = q[:, :, None, :] - k[:, None, :, :]
         # prepare mask
         if exists(mask):
             mask = mask[:, :, None] * mask[:, None, :]
         # expand values
         v = repeat(v, 'b j d -> b i j d', i = n)
         # determine k nearest neighbors for each point, if specified
         if exists(num_neighbors) and num_neighbors < n:
             rel_dist = rel_pos.norm(dim = -1)
             if exists(mask):
                 mask_value = max_value(rel_dist)
                 rel_dist.masked_fill_(~mask, mask_value)
             dist, indices = rel_dist.topk(num_neighbors, largest = False)
             v = batched_index_select(v, indices, dim = 2)
             qk_rel = batched_index_select(qk_rel, indices, dim = 2)
             rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
             mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None
         # add relative positional embeddings to value
         v = v + rel_pos_emb
         # use attention mlp, making sure to add relative positional embedding first
         sim = self.attn_mlp(qk_rel + rel_pos_emb)
         # masking
         if exists(mask):
             mask_value = -max_value(sim)
             sim.masked_fill_(~mask[..., None], mask_value)
         # attention
         attn = sim.softmax(dim = -2)
         # aggregate
         agg = einsum('b i j d, b i j d -> b i d', attn, v)
         return agg 

More details on the source code and setup procedure can be found at the official repository.

Performance of Point Transformer

Point Transformer is trained and evaluated on public datasets for 3D shape classification, 3D object part segmentation and 3D semantic segmentation.

For semantic scene segmentation, the S3DIS dataset is used. It consists of 3D scans of rooms in six areas from three different buildings, with points labelled as one of 13 categories such as ceiling, floor and table. For 3D shape classification, the ModelNet40 dataset is used; it consists of CAD models from 40 object categories. For 3D object part segmentation, the ShapeNetPart dataset is used; it consists of models from 16 shape categories.

Point Transformer in 3D semantic segmentation

On S3DIS, Point Transformer outperforms strong prior models, including PointNet, SegCloud, SPGraph, MinkowskiNet and KPConv, on the mIoU, mAcc and OA metrics, setting a new state of the art for semantic scene segmentation.
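The three metrics are standard for semantic segmentation: overall accuracy (OA) is the fraction of correctly labelled points, mAcc averages the per-class accuracy, and mIoU averages the per-class intersection over union. A minimal NumPy sketch computing all three from a confusion matrix follows; the function name is hypothetical, and averaging only over classes present in the ground truth is an assumption (benchmark scripts differ in such details).

```python
import numpy as np

def segmentation_metrics(pred, target, num_classes):
    """Return (OA, mAcc, mIoU) for integer per-point label arrays of the same shape."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    # confusion[t, p] = number of points whose true class is t and predicted class is p
    confusion = np.bincount(
        target * num_classes + pred, minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)
    tp = np.diag(confusion).astype(np.float64)   # correctly labelled points per class
    gt_count = confusion.sum(axis=1)             # points per true class
    pred_count = confusion.sum(axis=0)           # points per predicted class
    present = gt_count > 0                       # assumption: average over classes in the ground truth
    oa = tp.sum() / confusion.sum()
    macc = np.mean(tp[present] / gt_count[present])
    iou = tp / np.maximum(gt_count + pred_count - tp, 1)
    miou = np.mean(iou[present])
    return float(oa), float(macc), float(miou)

oa, macc, miou = segmentation_metrics(
    pred=[0, 0, 1, 1, 2, 2], target=[0, 0, 1, 2, 2, 2], num_classes=3
)
print(round(oa, 4), round(macc, 4), round(miou, 4))  # 0.8333 0.8889 0.7222
```

mIoU is the strictest of the three, because a class is penalized both for the points it misses and for the points wrongly assigned to it.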

Point Transformer in 3D Shape classification and retrieval

In 3D shape classification, Point Transformer achieves state-of-the-art accuracy, outperforming strong recent models such as DGCNN and KPConv.

Point Transformer in 3D Object Part Segmentation

In 3D object part segmentation, Point Transformer outperforms prior models including PointNet, SPLATNet, SpiderCNN, PCNN, DGCNN, SGPN, PointConv, InterpCNN and KPConv on the instance mIoU metric, setting a new state of the art.
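Instance mIoU is computed per shape: the IoU of each part belonging to the shape's category is averaged for that shape, and the per-shape scores are then averaged over all test shapes. The sketch below follows a commonly used ShapeNetPart evaluation convention of scoring a part as IoU 1 when it is absent from both the prediction and the ground truth; the function names and toy data are hypothetical.

```python
import numpy as np

def shape_iou(pred, target, part_labels):
    """Mean IoU over the part labels of one shape's category."""
    ious = []
    for part in part_labels:
        p = pred == part
        t = target == part
        union = np.logical_or(p, t).sum()
        if union == 0:
            # part missing from both prediction and ground truth: scored as 1 (common convention)
            ious.append(1.0)
        else:
            ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious))

def instance_miou(shapes):
    """Average per-shape IoU; shapes is a list of (pred, target, part_labels) tuples."""
    return float(np.mean([shape_iou(np.asarray(p), np.asarray(t), parts) for p, t, parts in shapes]))

# toy data: a two-part shape and a three-part shape (part 4 absent in both)
shapes = [
    ([0, 0, 1, 1], [0, 1, 1, 1], [0, 1]),
    ([2, 2, 3], [2, 2, 3], [2, 3, 4]),
]
print(instance_miou(shapes))  # (7/12 + 1) / 2, about 0.7917
```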


Rajkumar Lakshmanamoorthy
A geek in Machine Learning with a Master's degree in Engineering and a passion for writing and exploring new things. Loves reading novels, cooking, practicing martial arts, and occasionally writing novels and poems.
