Transformers outperform convolutional neural networks and recurrent neural networks in many applications across domains, including natural language processing, image classification and medical image segmentation. Point Transformer adds further evidence by setting state-of-the-art results in 3D image data processing. It handles multiple tasks, including 3D semantic segmentation, 3D shape classification and 3D object part segmentation.
3D images differ from 2D images and are more complex to process. 2D images are collections of pixels arranged in a regular 2D grid, whereas 3D images are point clouds: sets of points embedded in a continuous 3D space. This difference makes standard deep learning networks for computer vision unsuitable for 3D image processing. A standard convolutional layer slides a kernel over the regular pixel grid of a 2D image, but that operator cannot be applied directly to a sparse, irregular cloud of 3D points.
Applications built on 3D image data, such as augmented reality, virtual reality, autonomous vehicles and robot navigation, are growing rapidly, creating strong demand for networks that are powerful yet efficient enough to deploy in production environments. Recent research has adapted convolutional networks to 3D data through voxelization, sparse convolution, continuous convolution and graph networks. However, those models do not fully meet the compute-efficiency requirements.
Hengshuang Zhao and Philip Torr of the University of Oxford, Li Jiang and Jiaya Jia of the Chinese University of Hong Kong, and Vladlen Koltun of Intel Labs have developed self-attention-based networks for 3D image processing. Their model, named Point Transformer, reaches a new milestone on several public 3D image datasets by outperforming the strongest prior models.
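To make the efficiency concern concrete, here is a minimal sketch of voxelization in plain Python. The `voxelize` helper, the voxel size and the sample points are illustrative assumptions, not taken from any of the models mentioned above. It quantizes points into a regular grid and shows how few grid cells end up occupied, which is one reason dense 3D convolutions spend much of their compute on empty space.

```python
import math

def voxelize(points, voxel_size):
    """Return the set of integer voxel indices that contain at least one point."""
    occupied = set()
    for x, y, z in points:
        occupied.add((
            math.floor(x / voxel_size),
            math.floor(y / voxel_size),
            math.floor(z / voxel_size),
        ))
    return occupied

# Hypothetical points inside a unit cube; the first two fall into the same voxel
points = [
    (0.05, 0.15, 0.25),
    (0.07, 0.12, 0.21),
    (0.93, 0.85, 0.14),
    (0.55, 0.52, 0.95),
]
occupied = voxelize(points, voxel_size=0.1)

grid_cells = 10 ** 3  # a dense 10 x 10 x 10 grid covers the unit cube
occupancy = len(occupied) / grid_cells
print(len(occupied), "of", grid_cells, "voxels occupied")
```

A finer voxel size preserves more geometry but makes the grid even emptier, which is the trade-off that sparse convolutions and set-based operators try to avoid.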
How does Point Transformer work?
Transformers and their variants, with the self-attention mechanism at their core, now lead many machine learning fields with powerful models such as Vision Transformer (image classification), TransUNet (medical image segmentation), ENCONTER (language modeling) and CTRL (controlled language generation). Ongoing research continues to bring the self-attention mechanism to new domains and tasks.
The self-attention mechanism in a Transformer is essentially a set operation. It does not depend on the order in which the input features are listed, and it works for any number of them. Since a 3D point cloud is locally a set of points, the self-attention operator is a natural fit. The Point Transformer layer applies self-attention and pointwise operations to the 3D point cloud, which lets it perform 3D scene understanding on sparse point clouds directly. A Point Transformer network is built by stacking these Point Transformer layers, and it can serve as a general backbone for 3D scene understanding applications.
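The set property can be checked with a small sketch. The code below is plain dot-product self-attention in pure Python with identity projections, an assumption made to keep it short; it is not the Point Transformer layer itself. It shows that reordering the input points only reorders the outputs in the same way, so the result for each individual point is unchanged, and that the same function accepts sets of any size.

```python
import math

def self_attention(features):
    """Dot-product self-attention over a list of equal-length feature vectors."""
    outputs = []
    for q in features:
        # Similarity of this query to every element of the set
        scores = [sum(a * b for a, b in zip(q, k)) for k in features]
        peak = max(scores)  # subtract the maximum for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        dim = len(q)
        outputs.append([
            sum(w * v[d] for w, v in zip(weights, features)) for d in range(dim)
        ])
    return outputs

points = [[0.1, 0.2, 0.3], [1.0, -0.5, 0.0], [0.4, 0.4, -0.2]]
order = [2, 0, 1]
shuffled = [points[i] for i in order]

out = self_attention(points)
out_shuffled = self_attention(shuffled)

# Row r of the shuffled output should match row order[r] of the original output
same_up_to_order = all(
    math.isclose(a, b, abs_tol=1e-9)
    for i, row in zip(order, out_shuffled)
    for a, b in zip(out[i], row)
)
print(same_up_to_order)

# The same function handles a set with a different number of points
smaller = self_attention(points[:2])
```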
The input to a Point Transformer layer is a 3D point together with its k nearest neighbors. Several networks operate on this input in parallel. One is a multi-layer perceptron that computes the position encoding. The others are pointwise feature transformation networks that apply simple linear projections to the input point features. The projected features are combined with the output of the position encoding, passed through a mapping function and then normalized. The normalization function is usually softmax and may vary with the application. The normalized attention weights are applied to the transformed features, and an aggregation step combines the results over the neighborhood to yield the output of the Point Transformer layer.
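In the notation of the Point Transformer paper, the layer can be written compactly as below. Here $\mathcal{X}(i)$ is the set of k nearest neighbors of point $x_i$; $\varphi$, $\psi$ and $\alpha$ are the pointwise feature transformations; $\delta$ is the position encoding, produced by an MLP $\theta$ from the coordinates $p_i$ and $p_j$; $\gamma$ is the mapping MLP; $\rho$ is the normalization function such as softmax; and $\odot$ is element-wise multiplication. The symbols have been matched to the prose here as closely as possible, so treat the mapping as a reading aid rather than a definitive statement.

```latex
y_i = \sum_{x_j \in \mathcal{X}(i)}
      \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\big)
      \odot \big(\alpha(x_j) + \delta\big),
\qquad
\delta = \theta(p_i - p_j)
```

Note that the position encoding $\delta$ enters twice: once inside the attention weights and once added to the transformed features.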
Python Implementation of Point Transformer
Point Transformer is available as a PyPI package and can be installed with pip. It is implemented in PyTorch and requires Python 3.7+, PyTorch 1.6+ and einops 0.3+.
!pip install point-transformer-pytorch
Import the necessary libraries and modules.
import torch
from point_transformer_pytorch import PointTransformerLayer
An example use of a Point Transformer layer is shown in the following code.
attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

attn(feats, pos, mask = mask) # (1, 16, 128)
The number of nearest neighbors can be controlled through the num_neighbors argument of the PointTransformerLayer module. In the following example, it is set to 16, so for each point the layer attends only to the 16 nearest points in the 3D cloud.
attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16 # only the 16 nearest neighbors would be attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

attn(feats, pos, mask = mask) # (1, 2048, 128)
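The neighbor selection behind the num_neighbors argument can be illustrated without PyTorch. The following is a minimal brute-force sketch; the `k_nearest_neighbors` helper and the sample points are hypothetical and are not part of the package. It ranks all points by Euclidean distance and keeps the k closest, so each point counts as its own nearest neighbor at distance zero.

```python
import math

def k_nearest_neighbors(points, k):
    """Return, for every point, the indices of its k closest points (itself included)."""
    neighbors = []
    for p in points:
        distances = [
            math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q))) for q in points
        ]
        ranked = sorted(range(len(points)), key=lambda j: distances[j])
        neighbors.append(ranked[:k])
    return neighbors

# Two well-separated pairs of points
points = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (5.0, 5.0, 5.0), (5.0, 5.1, 5.0)]
knn = k_nearest_neighbors(points, k=2)
print(knn)
```

This brute-force search compares every pair of points, so its cost grows quadratically with the size of the cloud; it is meant only to show what is being selected.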
The source implementation of PointTransformerLayer is shown in the following code. It begins by importing the necessary packages.
import torch
from torch import nn, einsum
from einops import repeat
Helper functions for the layer development are defined as follows:
def exists(val):
    return val is not None

def max_value(t):
    return torch.finfo(t.dtype).max

def batched_index_select(values, indices, dim = 1):
    value_dims = values.shape[(dim + 1):]
    values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
    indices = indices[(..., *((None,) * len(value_dims)))]
    indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
    value_expand_len = len(indices_shape) - (dim + 1)
    values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

    value_expand_shape = [-1] * len(values.shape)
    expand_slice = slice(dim, (dim + value_expand_len))
    value_expand_shape[expand_slice] = indices.shape[expand_slice]
    values = values.expand(*value_expand_shape)

    dim += value_expand_len
    return values.gather(dim, indices)
class PointTransformerLayer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        pos_mlp_hidden_dim = 64,
        attn_mlp_hidden_mult = 4,
        num_neighbors = None
    ):
        super().__init__()
        self.num_neighbors = num_neighbors

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, dim)
        )

        self.attn_mlp = nn.Sequential(
            nn.Linear(dim, dim * attn_mlp_hidden_mult),
            nn.ReLU(),
            nn.Linear(dim * attn_mlp_hidden_mult, dim),
        )

    def forward(self, x, pos, mask = None):
        # n is the number of points; x has shape (batch, n, dim)
        n, num_neighbors = x.shape[1], self.num_neighbors

        # get queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # calculate relative positional embeddings
        rel_pos = pos[:, :, None, :] - pos[:, None, :, :]
        rel_pos_emb = self.pos_mlp(rel_pos)

        # use subtraction of queries to keys. i suppose this is a better inductive bias for point clouds than dot product
        qk_rel = q[:, :, None, :] - k[:, None, :, :]

        # prepare mask
        if exists(mask):
            mask = mask[:, :, None] * mask[:, None, :]

        # expand values
        v = repeat(v, 'b j d -> b i j d', i = n)

        # determine k nearest neighbors for each point, if specified
        if exists(num_neighbors) and num_neighbors < n:
            rel_dist = rel_pos.norm(dim = -1)

            if exists(mask):
                mask_value = max_value(rel_dist)
                rel_dist.masked_fill_(~mask, mask_value)

            dist, indices = rel_dist.topk(num_neighbors, largest = False)

            v = batched_index_select(v, indices, dim = 2)
            qk_rel = batched_index_select(qk_rel, indices, dim = 2)
            rel_pos_emb = batched_index_select(rel_pos_emb, indices, dim = 2)
            mask = batched_index_select(mask, indices, dim = 2) if exists(mask) else None

        # add relative positional embeddings to value
        v = v + rel_pos_emb

        # use attention mlp, making sure to add relative positional embedding first
        sim = self.attn_mlp(qk_rel + rel_pos_emb)

        # masking
        if exists(mask):
            mask_value = -max_value(sim)
            sim.masked_fill_(~mask[..., None], mask_value)

        # attention: softmax over the neighbor axis, separately for each channel
        attn = sim.softmax(dim = -2)

        # aggregate
        agg = einsum('b i j d, b i j d -> b i d', attn, v)
        return agg
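The final two steps of the layer, the softmax over the neighbor axis and the weighted sum, can be illustrated for a single query point in plain Python. This is a simplified sketch under stated assumptions: the `vector_attention` helper and its inputs are hypothetical, and batching and masking are left out. The point it makes is that each feature channel gets its own set of attention weights over the neighbors, rather than one scalar weight per neighbor.

```python
import math

def vector_attention(sim, values):
    """Aggregate neighbor values for one query point.

    sim and values are nested lists of shape (num_neighbors, dim).
    A separate softmax over the neighbors is taken for every channel.
    """
    num_neighbors, dim = len(sim), len(sim[0])
    output = []
    for d in range(dim):
        column = [sim[j][d] for j in range(num_neighbors)]
        peak = max(column)  # subtract the maximum for numerical stability
        exps = [math.exp(s - peak) for s in column]
        total = sum(exps)
        weights = [e / total for e in exps]
        output.append(sum(w * values[j][d] for j, w in enumerate(weights)))
    return output

# Hypothetical scores and values for 2 neighbors and 2 channels.
# Channel 0 weights both neighbors equally; channel 1 strongly prefers neighbor 0.
sim = [[0.0, 10.0], [0.0, -10.0]]
values = [[1.0, 1.0], [3.0, 3.0]]
agg = vector_attention(sim, values)
print(agg)
```

With these inputs the first channel averages the two neighbors while the second is dominated by the first neighbor, something a single scalar weight per neighbor could not express.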
More details on the source code and setup procedure can be found at the official repository.
Performance of Point Transformer
Point Transformer is trained and evaluated on several public datasets for 3D shape classification, 3D object part segmentation and 3D semantic segmentation.
For semantic scene segmentation, the S3DIS dataset is used. It consists of 3D scenes of rooms in six areas from three different buildings, with every point labelled as one of 13 semantic categories such as ceiling, floor and table. For 3D shape classification, the ModelNet40 dataset is used. It consists of CAD models of 40 object categories. For 3D object part segmentation, the ShapeNetPart dataset is used. It consists of models from 16 shape categories.
In semantic scene segmentation, Point Transformer outperforms recent top models, including PointNet, SegCloud, SPGraph, MinkowskiNet and KPConv, on the mean intersection over union (mIoU), mean class accuracy (mAcc) and overall accuracy (OA) metrics, setting a new state of the art.
In 3D shape classification, Point Transformer sets a new state of the art on the accuracy metric, outperforming recent top models such as DGCNN and KPConv.
In 3D object part segmentation, Point Transformer outperforms recent top models, including PointNet, SPLATNet, SpiderCNN, PCNN, DGCNN, SGPN, PointConv, InterpCNN and KPConv, on the instance mIoU metric, again setting a new state of the art.
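For readers unfamiliar with these three metrics, the sketch below computes them from a confusion matrix using their standard definitions. The `segmentation_metrics` helper and the two-class matrix are illustrative assumptions and do not reproduce any reported result.

```python
def segmentation_metrics(confusion):
    """Compute OA, mAcc and mIoU from a square confusion matrix.

    confusion[t][p] counts points whose true class is t and predicted class is p.
    Assumes every class occurs at least once, so no denominator is zero.
    """
    num_classes = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[c][c] for c in range(num_classes))

    accuracies, ious = [], []
    for c in range(num_classes):
        true_positive = confusion[c][c]
        actual = sum(confusion[c])  # true positives + false negatives
        predicted = sum(confusion[t][c] for t in range(num_classes))  # true + false positives
        accuracies.append(true_positive / actual)
        ious.append(true_positive / (actual + predicted - true_positive))

    return {
        "OA": correct / total,                  # fraction of all points labelled correctly
        "mAcc": sum(accuracies) / num_classes,  # per-class accuracy, averaged over classes
        "mIoU": sum(ious) / num_classes,        # per-class IoU, averaged over classes
    }

# Made-up counts for a two-class problem
confusion = [
    [8, 2],
    [1, 9],
]
metrics = segmentation_metrics(confusion)
print(metrics)
```

OA can look strong when one class dominates the scene, which is why segmentation results are usually reported together with mAcc and mIoU.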