The mean shift clustering algorithm is a centroid-based algorithm that is useful in many unsupervised learning tasks, and it is widely used in image processing and computer vision. It works by repeatedly shifting each data point towards the mean of the points in its neighbourhood, so that the points move towards the regions of highest density; these regions become the cluster centroids. Because it looks for the modes (peaks) of the data's density, it is also known as a mode-seeking algorithm. Its main advantage is that the number of clusters does not have to be specified in advance: it is determined automatically by the chosen bandwidth.
Kernel Density Estimation
Mean shift is built on the concept of Kernel Density Estimation (KDE), a way to estimate the probability density function of a random variable. KDE infers the shape of a population's distribution by smoothing the observed data: it places a weighting function, called a kernel, on each data point. There are many kinds of kernels; a common choice is the Gaussian kernel. Adding all those kernels together (and normalising the sum) gives a density function, a probability surface. How smooth or bumpy the resulting density is depends on the bandwidth parameter.
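To make the "sum of kernels" idea concrete, here is a minimal one-dimensional sketch with a Gaussian kernel. The helper `gaussian_kde_1d` and the two-group sample are hypothetical, chosen only for illustration: one kernel is centred on every data point, the kernels are summed, and the sum is divided by n × bandwidth so that the estimate integrates to about 1.

```python
import numpy as np

def gaussian_kde_1d(grid, data, bandwidth):
    """Evaluate a 1-D Gaussian KDE on `grid`: one kernel per data point, summed and normalised."""
    # (len(grid), len(data)) matrix of scaled distances between grid points and data points
    u = (grid[:, None] - data[None, :]) / bandwidth
    kernels = np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)  # standard Gaussian kernel
    # sum the kernels and divide by n * h so the estimate integrates to 1
    return kernels.sum(axis=1) / (len(data) * bandwidth)

rng = np.random.default_rng(0)
# hypothetical sample: two groups of points around -2 and +2
data = np.concatenate([rng.normal(-2, 0.5, 50), rng.normal(2, 0.5, 50)])
grid = np.linspace(-6, 6, 1201)
density = gaussian_kde_1d(grid, data, bandwidth=0.4)

# simple Riemann-sum check that the density integrates to roughly 1
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```

The resulting curve has a hill around each group of points, which is the one-dimensional version of the surface shown in the images below.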
In the image below, we can see the distribution of some data points in a surface plot.
The next image shows the KDE surface built from the data points in the surface plot above (first image). The hills are formed by the kernels placed on the data points.
The contour plot of the KDE surface shows more clearly how our data points have been smoothed.
From the images, we can see how KDE smooths a data set so that we can make inferences from it. In the contour plot, the smaller the circles become, the higher the density, so the innermost circles mark the peaks of the surface. This is where mean shift comes into the picture: it moves points towards these peaks, in the direction in which the density function increases.
Mean shift builds on KDE, but instead of only estimating the density, it uses the kernel and its bandwidth to move the points: each point is made to climb uphill to the nearest peak of the KDE surface, and this shift is repeated iteratively until the point reaches the peak.
The shape of the kernels, and therefore of the KDE surface, depends on the bandwidth. A tall, skinny kernel corresponds to a small bandwidth, while a short, fat kernel corresponds to a large bandwidth. With a very small bandwidth, the KDE surface has a peak at almost every data point; more formally, each point can end up in its own cluster. A large bandwidth, on the other hand, smooths the surface into fewer peaks and therefore produces fewer clusters.
Here we can see the kernels formed with a bandwidth equal to two.
In the image, we can see what happens when the bandwidth value is low.
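The effect of the bandwidth on the number of peaks can be checked numerically. This is a small sketch under stated assumptions: a one-dimensional Gaussian KDE, a hypothetical two-group sample, and a hypothetical helper `count_peaks` that counts local maxima of the density sampled on a fine grid. Since mean shift sends every point to a peak, the number of peaks is the number of clusters it would find.

```python
import numpy as np

def kde(grid, data, h):
    """1-D Gaussian KDE evaluated on `grid` with bandwidth h."""
    u = (grid[:, None] - data[None, :]) / h
    return np.exp(-0.5 * u ** 2).sum(axis=1) / (len(data) * h * np.sqrt(2 * np.pi))

def count_peaks(values):
    """Count interior local maxima of a curve sampled on a grid."""
    mid = values[1:-1]
    return int(np.sum((mid > values[:-2]) & (mid >= values[2:])))

rng = np.random.default_rng(1)
# hypothetical data: two groups of 20 points around -2 and +2
data = np.concatenate([rng.normal(-2, 0.5, 20), rng.normal(2, 0.5, 20)])
grid = np.linspace(-6, 6, 4001)

# small, medium and large bandwidths
peaks = {h: count_peaks(kde(grid, data, h)) for h in (0.05, 0.5, 5.0)}
for h, n in peaks.items():
    print(f"bandwidth={h}: {n} peak(s)")
```

With the tiny bandwidth, nearly every point gets its own peak; with the large one, the two groups merge into a single hill.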
Let's consider a kernel function K(xᵢ − x) that gives weights to the points near x when computing the mean. The weighted mean of the points in a window around x is then:
$$m(x) = \frac{\sum_{x_i \in N(x)} K(x_i - x)\, x_i}{\sum_{x_i \in N(x)} K(x_i - x)}$$
Where N(x) is the neighbourhood of x.
The value of m(x) – x is called the mean shift.
The formula shows that each step moves x to the weighted mean m(x) of its neighbours. When the step is performed iteratively, the point moves uphill until it reaches a peak of the KDE surface, where the mean shift becomes (close to) zero.
In the full algorithm, the data points are first copied; the copies are then shifted iteratively, with the weights always computed against the original, unmoved points, until each copy reaches the peak of its kernel surface. Copies that end up at the same peak form one cluster. Next in the article, we will see how to implement the algorithm in Python on randomly generated data points and find the clusters according to the bandwidth parameter.
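The iteration x ← m(x) can be written in a few lines of NumPy. This sketch assumes a Gaussian kernel K and takes N(x) to be all points, since distant points receive almost zero weight; the function name `mean_shift_point` and the two-blob data are hypothetical. The loop stops once the mean shift m(x) − x becomes very small.

```python
import numpy as np

def mean_shift_point(x, X, bandwidth, tol=1e-6, max_iter=500):
    """Move x uphill with repeated mean-shift steps until the shift m(x) - x is tiny."""
    x = np.asarray(x, dtype=float)
    for _ in range(max_iter):
        d2 = np.sum((X - x) ** 2, axis=1)           # squared distances to every point
        w = np.exp(-d2 / (2 * bandwidth ** 2))      # Gaussian kernel weights K(x_i - x)
        m = (w[:, None] * X).sum(axis=0) / w.sum()  # weighted mean m(x)
        shift = m - x                               # the mean shift vector
        x = m
        if np.linalg.norm(shift) < tol:
            break
    return x

rng = np.random.default_rng(42)
# hypothetical data: two blobs centred at (0, 0) and (5, 5)
X = np.vstack([rng.normal([0, 0], 0.3, size=(50, 2)),
               rng.normal([5, 5], 0.3, size=(50, 2))])
mode = mean_shift_point([1.0, 1.0], X, bandwidth=1.0)
print(mode)  # ends up near the centre of the first blob
```

Starting from (1, 1), the point climbs to the peak of the nearer blob; a start close to (5, 5) would climb to the other peak instead.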
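Before turning to scikit-learn, the copy-and-shift procedure described above can be sketched from scratch. This is a minimal illustration, not scikit-learn's implementation: it assumes a Gaussian kernel over all points (scikit-learn's `MeanShift` uses a flat kernel), and the function `mean_shift_cluster` and its `merge_tol` parameter (copies whose peaks are closer than half the bandwidth are merged) are hypothetical choices.

```python
import numpy as np

def mean_shift_cluster(X, bandwidth, max_iter=300, tol=1e-4, merge_tol=None):
    """Shift copies of X uphill on a Gaussian KDE of the original X, then group copies by peak."""
    if merge_tol is None:
        merge_tol = bandwidth / 2  # assumption: peaks closer than this count as one
    X = np.asarray(X, dtype=float)
    shifted = X.copy()  # the copies move; X itself never changes
    for _ in range(max_iter):
        # squared distances between every shifted copy and every original point
        d2 = ((shifted[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
        w = np.exp(-d2 / (2 * bandwidth ** 2))       # Gaussian kernel weights
        new = (w @ X) / w.sum(axis=1, keepdims=True)  # m(x) for every copy at once
        done = np.abs(new - shifted).max() < tol
        shifted = new
        if done:
            break
    # copies that ended at (nearly) the same peak share a cluster
    centers, labels = [], np.empty(len(X), dtype=int)
    for i, p in enumerate(shifted):
        for k, c in enumerate(centers):
            if np.linalg.norm(p - c) < merge_tol:
                labels[i] = k
                break
        else:  # no existing peak nearby: start a new cluster
            centers.append(p)
            labels[i] = len(centers) - 1
    return np.array(centers), labels

rng = np.random.default_rng(0)
coordinates = [[2, 2, 3], [6, 7, 8], [5, 10, 13]]
# hypothetical data: 40 points around each coordinate
X = np.vstack([rng.normal(c, 0.6, size=(40, 3)) for c in coordinates])
centers, labels = mean_shift_cluster(X, bandwidth=1.5)
print(len(centers))
print(np.round(centers, 2))
```

Computing all pairwise distances costs O(n²) memory per iteration, which is fine for a few hundred points but is one reason library implementations use neighbour searches and seeding strategies instead.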
Implementations in Python
Importing the libraries:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs  # sklearn.datasets.samples_generator was removed in newer scikit-learn
```
Setting up the coordinates and generating the random data around the coordinates:
```python
# three cluster centres in 3-D space
coordinates = [[2, 2, 3], [6, 7, 8], [5, 10, 13]]
# 120 points scattered around the centres
X, _ = make_blobs(n_samples=120, centers=coordinates, cluster_std=0.60)
```
Visualizing the data points:
```python
data_fig = plt.figure(figsize=(12, 10))
ax = data_fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', color='green')
plt.show()
```
Here we can see how the data is distributed in space. Given the three coordinates we generated the points around, we can expect 3 clusters. Now we will apply mean shift to find the clusters and their centroids. Scikit-learn provides the estimate_bandwidth function, which estimates the bandwidth from the data so that we do not have to choose it by hand (the result still depends on its quantile argument). Importing the function and estimating the bandwidth:
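```python
from sklearn.cluster import estimate_bandwidth

# n_samples caps how many points are used; with only 120 points, all of them are used here
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
```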
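To make the quantile argument less of a black box, here is a sketch of what estimate_bandwidth computes, as I understand its implementation: for every point, the distance to its k-th nearest neighbour (the point itself counts as the first neighbour), with k = int(n × quantile), averaged over all points. The helper `bandwidth_by_hand` and the sample data are hypothetical; the script compares the hand computation with the library on the same data.

```python
import numpy as np
from sklearn.cluster import estimate_bandwidth

def bandwidth_by_hand(X, quantile):
    """Mean distance from each point to its k-th nearest neighbour, k = int(n * quantile)."""
    n = X.shape[0]
    k = max(1, int(n * quantile))
    # full matrix of pairwise Euclidean distances
    dists = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))
    kth = np.sort(dists, axis=1)[:, k - 1]  # column 0 is the zero distance to the point itself
    return kth.mean()

rng = np.random.default_rng(0)
centers = np.array([[2, 2, 3], [6, 7, 8], [5, 10, 13]], dtype=float)
X = np.vstack([rng.normal(c, 0.6, size=(40, 3)) for c in centers])

manual = bandwidth_by_hand(X, quantile=0.2)
library = estimate_bandwidth(X, quantile=0.2)  # n_samples=None: every point is used
print(manual, library)
```

A larger quantile looks further out to more distant neighbours, which gives a larger bandwidth and therefore fewer clusters.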
Now we can define the mean shift clustering model and fit it to our data.
```python
msc = MeanShift(bandwidth=bandwidth, bin_seeding=True)
msc.fit(X)
cluster_centers = msc.cluster_centers_
labels = msc.labels_
labels_unique = np.unique(labels)  # distinct cluster ids
n_clusters = len(labels_unique)
print(n_clusters)
```
Here we can see that the model has found 3 clusters, as we expected.
Visualizing the clusters:
```python
msc_fig = plt.figure(figsize=(12, 10))
ax = msc_fig.add_subplot(111, projection='3d')
# data points
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', color='yellow')
# estimated cluster centres
ax.scatter(cluster_centers[:, 0], cluster_centers[:, 1], cluster_centers[:, 2],
           marker='o', color='green', s=300, linewidth=5, zorder=10)
plt.title('Estimated number of clusters: %d' % n_clusters)
plt.show()
```
Here, the green markers are the cluster centroids, and the data is cleanly separated into 3 clusters. As we have discussed, mean shift is very useful for image processing and computer vision. Next in the article, I am going to separate the colors of an image using the mean shift clustering algorithm. More formally, this is image segmentation using mean shift: the pixel values of an image are determined by the colors present in it, so clustering the pixel values groups the pixels by color. Here I am using a thermograph as the image because its colors are well separated and there are only a few of them, so the procedure is easy to follow.
Importing the libraries.
```python
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import MeanShift, estimate_bandwidth
# in a Jupyter/Colab notebook you can also run: %matplotlib inline
```
Loading and visualizing the image using PIL and Matplotlib:
```python
image = Image.open('/content/drive/MyDrive/Yugesh/Mean Shift Clustering Algo/Thermography_results_sm.jpg')
img = np.array(image.convert('RGB'))  # convert('RGB') guards against alpha or greyscale images
# saving the image shape
shape = img.shape
# flatten to one row per pixel, one column per RGB channel
reshape_img = np.reshape(img, [-1, 3])
# plotting the image
plt.imshow(img)
plt.title(str(img.shape))
plt.show()
```
Here we can see the image with its shape as the title. We reshaped the image into a 2-D array with one row per pixel and one column per RGB channel, because scikit-learn models expect a 2-D array of shape (n_samples, n_features). As discussed above, scikit-learn's estimate_bandwidth function can estimate the bandwidth, so I am using it here.
```python
bandwidth = estimate_bandwidth(reshape_img, quantile=0.1, n_samples=100)
print(bandwidth)
```
Fitting mean shift on reshape_img:
```python
msc = MeanShift(bandwidth=bandwidth, bin_seeding=True)
msc.fit(reshape_img)
```
Inspecting the fitted model to see what it has found:
```python
print("shape of labels:", msc.labels_.shape)
print("shape of cluster centres:", msc.cluster_centers_.shape)
print("number of estimated clusters: %d" % len(np.unique(msc.labels_)))
```
Here we can see that it has generated 8 clusters, which means that the image has been segmented into 8 color segments. Next, we reshape the labels back to the height and width of the original image.
```python
labels = msc.labels_
# one label per pixel -> back to the image's height x width
result_image = np.reshape(labels, shape[:2])
```
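The flatten-then-unflatten round trip can be checked on a tiny array. In this sketch the 2×3 "image" and the label values are hypothetical stand-ins for the thermograph and msc.labels_; the point is that NumPy's row-major reshape keeps row i of the flattened array aligned with pixel (i // width, i % width), so the labels land back on the right pixels.

```python
import numpy as np

# a tiny hypothetical 2x3 RGB "image"
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
shape = img.shape

pixels = np.reshape(img, [-1, 3])  # one row per pixel, one column per channel
print(pixels.shape)                # (6, 3)

# stand-in for msc.labels_: one integer label per pixel row (hypothetical values)
labels = np.array([0, 0, 1, 1, 2, 2])
label_image = np.reshape(labels, shape[:2])  # back to height x width
print(label_image)
```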
Let's plot the original and the segmented images side by side.
```python
fig = plt.figure(2, figsize=(14, 12))
ax1 = fig.add_subplot(121)
ax1.imshow(img)            # original image
ax2 = fig.add_subplot(122)
ax2.imshow(result_image)   # one color per cluster label
plt.show()
```
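The segmented image above is drawn from integer labels, so its colors come from Matplotlib's colormap. An optional variation, not part of the original walkthrough, is to paint every pixel with the color of its cluster centre, msc.cluster_centers_[msc.labels_]. So that it runs on its own, this sketch uses a hypothetical synthetic image of three noisy color bands, and bandwidth=40 is an assumption chosen for that toy image.

```python
import numpy as np
from sklearn.cluster import MeanShift

# hypothetical 20x30 test image: red, green and blue vertical bands plus noise
rng = np.random.default_rng(0)
h, w = 20, 30
img = np.zeros((h, w, 3))
img[:, :10] = [255, 0, 0]
img[:, 10:20] = [0, 255, 0]
img[:, 20:] = [0, 0, 255]
img = np.clip(img + rng.normal(0, 5, img.shape), 0, 255)

pixels = img.reshape(-1, 3)  # one row per pixel
msc = MeanShift(bandwidth=40, bin_seeding=True).fit(pixels)

# replace every pixel by the centroid color of its cluster, then restore the image shape
segmented = msc.cluster_centers_[msc.labels_].reshape(h, w, 3).astype(np.uint8)
print(len(msc.cluster_centers_), segmented.shape)
```

With real images the same indexing trick applies unchanged; only the bandwidth needs to come from estimate_bandwidth or tuning.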
Here we can see the original image and the resulting segmented image. Using the pixel values of the image, we have generated clusters with the mean shift algorithm; each cluster groups pixels with similar values (note: pixel values range from 0 to 255). This is one of the simplest techniques for image segmentation and related image processing problems. Earlier in the article, we saw how the algorithm works, how it provides centroids for the data points, and how it uses the mean shift on the KDE surface. The algorithm has several advantages, such as relative robustness to outliers, the ability to handle datasets with complex structure, and no need to specify or search over the number of clusters in advance.
Yugesh is a graduate in automobile engineering and worked as a data analyst intern. He completed several Data Science projects. He has a strong interest in Deep Learning and writing blogs on data science and machine learning.