Hands-On Tutorial on Mean Shift Clustering Algorithm

Mean shift is based on the idea of kernel density estimation (KDE): each data point iteratively climbs uphill to the nearest peak on the KDE surface, and the bandwidth parameter controls the shape of that surface.

Mean shift is a centroid-based clustering algorithm that is useful in many unsupervised learning tasks and is one of the most popular algorithms in image processing and computer vision. It works by iteratively shifting each data point towards the mean of the points in its neighbourhood, which is why it is also known as the mode-seeking algorithm. Its main advantage is that it assigns clusters to the data automatically, without a predefined number of clusters; the number of clusters follows from the chosen bandwidth.

Kernel Density Estimation

Like some other clustering algorithms, mean shift is built on the concept of Kernel Density Estimation (KDE), a way to estimate the probability density function of a random variable. KDE makes inferences about the population by smoothing the data: it places a weight function, called a kernel, on each data point. There are many kinds of kernels; a common choice is the Gaussian kernel. Adding all those kernels together produces a density function (a probability surface), and the shape of the resulting surface depends on the bandwidth parameter used.
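As a small illustration of this idea (a sketch, assuming a 1-D sample and a Gaussian kernel; the data values and bandwidth here are made up), the density estimate is literally the sum of one kernel per data point:

```python
import numpy as np

def gaussian_kde_1d(data, grid, bandwidth):
    # one Gaussian kernel per data point, evaluated on the whole grid
    u = (grid[:, None] - data[None, :]) / bandwidth
    kernels = np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)
    # adding all the kernels together gives the density surface
    return kernels.sum(axis=1) / (len(data) * bandwidth)

data = np.array([1.0, 1.2, 3.0, 3.1, 3.2])
grid = np.linspace(0.0, 5.0, 101)
density = gaussian_kde_1d(data, grid, bandwidth=0.3)
# the highest peak sits over the denser group of points around 3
print(grid[np.argmax(density)])
```

With a larger bandwidth the two bumps would merge into a single wide hill, which is exactly the effect discussed below.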



In the image below, we can see the distribution of some data points in a surface plot.


And in the image below, we can see the KDE surface built from the data points of the first plot. The hills can be considered as the kernels.


In the contour plot of the KDE surface, we can see exactly how our data points have been smoothed.


From the images, we can understand how KDE smooths the data set to make inferences from the data points. As the circles in the contour plot get smaller, the density of the data points increases: most of the points under a kernel concentrate around the small circle. This is where mean shift comes into the picture, moving the points along the density function.

Mean shift

Mean shift is based on the idea of KDE, but what makes it different is its use of the bandwidth parameter. We can make the points climb uphill to the nearest peak on the KDE surface, iteratively shifting each point until it reaches a peak.

The bandwidth parameter controls the shape of the KDE surface. A tall, skinny kernel corresponds to a small bandwidth, while a short, fat kernel corresponds to a large bandwidth. A small bandwidth makes the KDE surface hold a peak for every data point (more formally, each point can end up in its own cluster), whereas a large bandwidth produces fewer peaks and therefore fewer clusters.
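This effect is easy to check with sklearn's MeanShift (a sketch on made-up blobs; the two bandwidth values are illustrative):

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# three well-separated blobs
X, _ = make_blobs(n_samples=150, centers=[[0, 0], [5, 5], [10, 0]],
                  cluster_std=0.5, random_state=42)

for bw in (1.0, 10.0):
    ms = MeanShift(bandwidth=bw).fit(X)
    print("bandwidth=%.1f -> %d clusters" % (bw, len(ms.cluster_centers_)))
```

The small bandwidth keeps one peak per blob, while the large one flattens the KDE surface into a single hill covering all the points.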


Here we can see the kernels formed when the bandwidth value is equal to two.


In the image, we can see what happens when the bandwidth value is low.

Let’s consider a kernel function K(xi − x) that gives the weight to nearby points when defining the mean. The weighted mean of the density in the window around x is then

m(x) = ( Σxi ∈ N(x) K(xi − x) xi ) / ( Σxi ∈ N(x) K(xi − x) )

where N(x) is the neighbourhood of x.

The value of m(x) − x is called the mean shift.

As discussed before, we can understand from the formula that mean shift shifts each point, and when this is performed iteratively, the point moves to a peak of the KDE surface.
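The weighted mean above translates directly into NumPy (a sketch, using a Gaussian kernel as K and a made-up toy data set):

```python
import numpy as np

def mean_shift_step(x, points, bandwidth):
    # K(xi - x): Gaussian weights based on squared distance to x
    dist2 = np.sum((points - x) ** 2, axis=1)
    weights = np.exp(-dist2 / (2 * bandwidth ** 2))
    # m(x): weighted mean of the points in the window
    m = (weights[:, None] * points).sum(axis=0) / weights.sum()
    return m  # m - x is the mean shift vector

points = np.array([[1.0, 1.0], [1.2, 0.9], [0.9, 1.1], [5.0, 5.0]])
x = np.array([0.0, 0.0])
print(mean_shift_step(x, points, bandwidth=1.0))  # pulled towards the dense group near (1, 1)
```

The faraway point at (5, 5) gets an almost-zero weight, so the update is dominated by the dense group near (1, 1).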

Basically, the algorithm makes a copy of the data points and iteratively shifts the copies, while the original points stay fixed, until each copy reaches the peak of its kernel surface. Next in the article, we will implement the algorithm in Python with randomly generated data points and find the clusters according to the data and the bandwidth parameter.
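That copy-and-shift procedure can be sketched end to end (a minimal flat-kernel version on made-up data; sklearn's implementation adds seeding and binning on top of the same idea):

```python
import numpy as np

def mean_shift(points, bandwidth, n_iter=30):
    shifted = points.copy()  # shift the copies; the originals stay fixed
    for _ in range(n_iter):
        for i, p in enumerate(shifted):
            # flat kernel: all original points within one bandwidth of p
            mask = np.linalg.norm(points - p, axis=1) < bandwidth
            shifted[i] = points[mask].mean(axis=0)  # climb to their mean
    return shifted

# two well-separated groups of points
rng = np.random.RandomState(0)
points = np.vstack([rng.normal(0.0, 0.3, size=(30, 2)),
                    rng.normal(4.0, 0.3, size=(30, 2))])
shifted = mean_shift(points, bandwidth=1.0)
# every copy ends up near one of the two density peaks
print(shifted.round(2)[:3])
```

Copies that converge to the same peak form one cluster, which is how the algorithm ends up with one label per mode.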

Implementations in Python

Importing the libraries:

import numpy as np
import pandas as pd
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

Setting up the coordinates and generating the random data around them:

coordinates = [[2, 2, 3], [6, 7, 8], [5, 10, 13]]
X, _ = make_blobs(n_samples=120, centers=coordinates, cluster_std=0.60)

Visualizing the data points:

data_fig = plt.figure(figsize=(12, 10))
ax = data_fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', color='green')


Here we can see how the data is distributed in space. From the coordinates, we can easily infer that there should be 3 clusters. Now we will use mean shift to predict the clusters and define their centroids. Sklearn provides a function that estimates the bandwidth from the data, so we don't need to worry about the bandwidth parameter.

Importing libraries:

from sklearn.cluster import estimate_bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

Now we can define the mean shift cluster model and fit it into our data.

msc = MeanShift(bandwidth=bandwidth, bin_seeding=True)
msc.fit(X)
cluster_centers = msc.cluster_centers_
labels = msc.labels_
labels_unique = np.unique(labels)
n_clusters = len(labels_unique)
print("number of estimated clusters : %d" % n_clusters)


Here we can see that, just as we estimated, the model has predicted 3 clusters.

Visualizing the clusters:

msc_fig = plt.figure(figsize=(12, 10))
ax = msc_fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', color='yellow')
ax.scatter(cluster_centers[:, 0], cluster_centers[:, 1],
           cluster_centers[:, 2], marker='o', color='green',
           s=300, linewidth=5, zorder=10)
plt.title('Estimated number of clusters: %d' % n_clusters)


Here, in green, we can see the cluster centroids, and we can easily separate the data into 3 clusters. As we have discussed, mean shift is very useful for image processing and computer vision. Next in the article, I am going to separate the colours of an image using the mean shift clustering algorithm. More formally, this is image segmentation using mean shift, since the pixel values of any image are based on the colours present in it. I am using a thermograph as the image because its colours are well distributed and few in number, so the procedure stays easy to follow.

Importing the libraries:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
from itertools import cycle
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
%matplotlib inline

 Loading and visualizing the image using PIL and Matplotlib:


img = Image.open('/content/drive/MyDrive/Yugesh/Mean Shift Clustering Algo/Thermography_results_sm.jpg')
img = np.array(img)

# saving the image shape
shape = img.shape

# reshaping the image into a list of RGB pixel values
reshape_img = np.reshape(img, [-1, 3])

# plotting the image
plt.imshow(img)
plt.title('Image shape: %s' % str(shape))


Here we can see the image, with its size shown in the title. We have reshaped the image to flatten it, so that it has the array shape the model requires. As we discussed sklearn's bandwidth function earlier, here I am defining the bandwidth using that function.


bandwidth = estimate_bandwidth(reshape_img, quantile=0.1, n_samples=100)


Fitting the mean shift on reshape_img:

msc = MeanShift(bandwidth=bandwidth, bin_seeding=True)
msc.fit(reshape_img)


Checking the insights of the model so that we can know what is going behind:

print("shape of labels : %d" % msc.labels_.shape)
print( msc.cluster_centers_.shape)
print("number of estimated clusters : %d" % len(np.unique(msc.labels_)))


Here we can see that it has generated 8 clusters, which means the image has been clustered into 8 colour segments. Next, we change the shape of the labels back to the shape of the original image.

labels = msc.labels_
result_image = np.reshape(labels, shape[:2])

Let’s draw the original and the segmented image.

fig = plt.figure(2, figsize=(14, 12))
ax = fig.add_subplot(121)
plt.imshow(img)
ax = fig.add_subplot(122)
plt.imshow(result_image)


Here we can see the original image and the resulting image. Using the pixel values of the image (note: pixel values vary between 0 and 255), we have generated the clusters with the mean shift algorithm. This is one of the easiest techniques for image segmentation and other image processing problems. We have seen earlier how the algorithm works and assigns centroids to the data points, and how it uses the mean shift on the KDE surface. The algorithm has several advantages, such as robustness to outliers, efficiency on datasets with complex structure, and no need to choose the number of clusters in advance.


Yugesh Verma
Yugesh is a graduate in automobile engineering and worked as a data analyst intern. He completed several Data Science projects. He has a strong interest in Deep Learning and writing blogs on data science and machine learning.
