Beginner’s Guide To Machine Learning With Apache Spark

PySpark is a tool created by the Apache Spark community for using Python with Spark. It allows you to work with Resilient Distributed Datasets (RDDs) and DataFrames in Python.

Spark is known as a fast, easy-to-use, general-purpose engine for big data processing. Like Hadoop MapReduce, it is a distributed computing engine used to process and analyse large amounts of data. It is considerably faster than many other processing engines and can handle data from a variety of platforms. There is strong demand in industry for engines that can handle such workloads. Sooner or later, your company or client will ask you to develop sophisticated models that uncover new opportunities or the risks associated with them, and all of this can be done with PySpark. Python and SQL are not hard to learn, so getting started is easy.

PySpark has numerous features that make it easy to use, and it includes MLlib, a library for machine learning. When it comes to huge amounts of data, PySpark offers fast and near-real-time processing, flexibility, in-memory computation and various other features. In simple words, it is a Python library that gives you access to Spark, combining the simplicity of Python with the efficiency of Spark.

Here is a brief overview of the PySpark architecture, based on the official documentation:

PySpark not only lets you write applications using Python APIs but also provides the PySpark shell for interactively analysing your data in a distributed environment. PySpark supports most Spark features, such as Spark SQL, DataFrames, Streaming, MLlib (machine learning) and Spark Core.

Let’s take a look at that one by one.

Spark SQL and DataFrame:

It is a module for structured data processing. It provides an abstraction called DataFrame and can also act as a distributed SQL query engine.


MLlib:

MLlib is a high-level machine learning library that provides a set of APIs to help users create and tune practical machine learning models. It supports many common algorithms, including classification, regression, collaborative filtering and so on.


Streaming:

With the streaming feature, we can process real-time data from various sources, and the processed data can be pushed to file systems, databases or even live dashboards.

Spark Core:

Spark Core is the foundation of the whole project. It provides a specialised data structure called the Resilient Distributed Dataset (RDD) and in-memory computing capabilities.

Today we will focus only on MLlib and common data-handling techniques in Spark. Lastly, we will build a logistic regression model with Spark and demonstrate how to do hypothesis testing.

Code Implementation Machine Learning With Apache Spark

The following code implementation is adapted from the official implementation.

Import all dependencies: 
 from pyspark.sql import SparkSession
 from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
 from pyspark.ml import Pipeline
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.evaluation import BinaryClassificationEvaluator
 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt 
Explore the data:

The dataset is taken from Kaggle and relates to advertising: the task is to predict which kinds of users are more likely to click on an ad.

 spark = SparkSession.builder.appName('Logistic Regression').getOrCreate()
 # read the CSV with a header row and let Spark infer the column types
 df = spark.read.csv('advertising.csv', header=True, inferSchema=True)

The input features are: Daily Time Spent on Site, Age, Area Income, Daily Internet Usage, Ad Topic Line, City, Male and Country.

The output variable is Clicked on Ad.

The Timestamp column is not relevant to our analysis, so we do not use it.

Let's look at a summary and a correlation (scatter-matrix) plot of the numeric columns:

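 # keep only numeric columns (inferSchema maps CSV numbers to 'int' or 'double')
 Num_features = [name for name, dtype in df.dtypes if dtype in ('int', 'double')]
 df.select(Num_features).describe().show()  # summary statistics
 Num_data = df.select(Num_features).toPandas()
 plot = pd.plotting.scatter_matrix(Num_data, figsize=(15, 15))
 plt.show()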
Preparing data for ML:

The correlation plot above shows no strong multicollinearity among the features, so we keep all of them for modelling. The preparation consists of indexing the categorical columns, one-hot encoding them, and using VectorAssembler, which merges multiple columns into a single vector column.

 cat_Columns = ['Ad Topic Line', 'City', 'Country']
 stages = []
 # index and one-hot encode each categorical column
 for Col in cat_Columns:
     String_indexer = StringIndexer(inputCol=Col, outputCol=Col + 'Index')
     encode = OneHotEncoder(inputCols=[String_indexer.getOutputCol()], outputCols=[Col + 'classVec'])
     stages += [String_indexer, encode]
 # index the target column into a numeric 'label' column
 label_string = StringIndexer(inputCol='Clicked on Ad', outputCol='label')
 stages += [label_string]
 # assemble the encoded categorical vectors and the numeric columns into one 'features' vector
 numeri_Col = ['Daily Time Spent on Site', 'Age', 'Area Income', 'Daily Internet Usage', 'Male']
 assembler_inputs = [c + 'classVec' for c in cat_Columns] + numeri_Col
 assemble = VectorAssembler(inputCols=assembler_inputs, outputCol='features')
 stages += [assemble]

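A Pipeline chains the transformers and estimators defined above so that they are fit and applied in a single step. Fitting it on the training data only is what keeps test-set information out of the preprocessing and so avoids data leakage; here, for simplicity, it is fit on the full dataset before the split.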

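 cols = df.columns  # original columns, kept alongside 'label' and 'features'
 pipe = Pipeline(stages=stages)
 pipe_model = pipe.fit(df)
 df = pipe_model.transform(df)
 selected_cols = ['label', 'features'] + cols
 df = df.select(selected_cols)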
Train test split:
 train_, test_ = df.randomSplit([0.7, 0.3], seed=2000)
 print("Training Count: " + str(train_.count()))
 print("Test Count: " + str(test_.count()))
 # Output:
 # Training Count: 683
 # Test Count: 313
Load and fit the Logistic regression Model:
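 l_r = LogisticRegression(featuresCol='features', labelCol='label', maxIter=15)
 lr_model = l_r.fit(train_)  # fit on the training split only
 prediction = lr_model.transform(test_)  # adds rawPrediction, probability and prediction columns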
Let's plot some evaluation metrics, such as the ROC curve and the precision-recall curve:
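 training_summary = lr_model.summary
 # ROC curve on the training data: false positive rate (x) vs true positive rate (y)
 roc = training_summary.roc.toPandas()
 plt.plot(roc['FPR'], roc['TPR'])
 plt.xlabel('False Positive Rate')
 plt.ylabel('True Positive Rate')
 plt.title('ROC Curve')
 plt.show()
 print('Training ROC: ' + str(training_summary.areaUnderROC))
 # precision-recall curve on the training data
 pr = training_summary.pr.toPandas()
 plt.plot(pr['recall'], pr['precision'])
 plt.xlabel('Recall')
 plt.ylabel('Precision')
 plt.show()
 # area under ROC on the held-out test set (the evaluator's default metric)
 evaluator = BinaryClassificationEvaluator()
 print('Test ROC', evaluator.evaluate(prediction))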


Test ROC: 0.93

Example of hypothesis testing:
 from pyspark import SparkContext
 from pyspark.mllib.linalg import Matrices, Vectors
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.stat import Statistics
 if __name__ == '__main__':
     sc = SparkContext(appName='Hypothesis_Testing')
     # vector composed of the frequencies of events
     vec = Vectors.dense(0.1, 0.15, 0.20, 0.25, 0.30)
     # Compute a goodness-of-fit test. If a second vector to test against
     # is not supplied, the test runs against a uniform distribution.
     goodness_result = Statistics.chiSqTest(vec)
     # The result includes the p-value, degrees of freedom, test statistic,
     # method used and null hypothesis.
     print('%s\n' % goodness_result)
     # a 3x2 contingency matrix (values are given in column-major order)
     mat = Matrices.dense(3, 2, [1.0, 3.0, 5.0, 2.0, 4.0, 6.0])
     # Pearson's independence test on the contingency matrix
     independence_result = Statistics.chiSqTest(mat)
     print('%s\n' % independence_result)
     # LabeledPoint(label, features)
     obs = sc.parallelize(
         [LabeledPoint(1.0, [1.0, 0.0, 3.0]),
          LabeledPoint(1.0, [1.0, 2.0, 0.0]),
          LabeledPoint(1.0, [-1.0, 0.0, -0.5])])
     # A contingency table is constructed from the RDD of LabeledPoint and used to conduct
     # the independence test. Returns a ChiSqTestResult for every feature against the label.
     feature_results = Statistics.chiSqTest(obs)
     for i, result in enumerate(feature_results):
         print("Column %d:\n%s" % (i + 1, result))
     sc.stop()
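To show what the goodness-of-fit result reports for the frequency vector used above, the statistic can be worked out by hand. Pearson's chi-squared statistic is sum((observed - expected)^2 / expected). Without an expected distribution, the comparison is against a uniform distribution with the same total. This pure-Python arithmetic is a sketch that should match the statistic Spark prints for `goodness_result`. It does not compute the p-value.

```python
# Pearson's chi-squared goodness-of-fit statistic, computed by hand.
observed = [0.1, 0.15, 0.20, 0.25, 0.30]

# No expected distribution supplied: compare against a uniform distribution
# that has the same total as the observed values.
total = sum(observed)
expected = [total / len(observed)] * len(observed)

statistic = sum((o - e) ** 2 / e for o, e in zip(observed, expected))
degrees_of_freedom = len(observed) - 1

print(round(statistic, 4), degrees_of_freedom)  # 0.125 4
```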


In this article, we have seen an overview of Spark and its functionality. We then learned in more detail how to load a CSV file with the PySpark API, plot the relationships between the numeric features, prepare the data with a pipeline, and build and evaluate a model. Lastly, we looked at examples of hypothesis testing with chi-squared tests. More examples of ML algorithms are included in the Notebook.


Vijaysinh Lendave
Vijaysinh is an enthusiast of machine learning and deep learning. He is skilled in ML algorithms, data manipulation, handling and visualization, and model building.
