
Make Machine Learning Interpretable With Shapash

Shapash is a Python library that aims to make machine learning models trustworthy for everyone by making them more transparent and easy to understand.

What makes a machine learning model trustworthy? High accuracy on the test set? Metrics like accuracy, precision and recall might instil confidence in the more data-savvy people, but for non-data people like managers, business owners and end-users, the metrics don’t carry the same weight. Let’s take the example of a doctor who has been given a machine learning model for diagnosing patients. How can he be sure that the model is reliable? Sure, he could compare his diagnosis with the model’s diagnosis for a few patients, but maybe the model only works well for a small subset of the whole population. Or worse, maybe it’s making the right decision for the wrong reasons, i.e., it’s considering the wrong features or symptoms. This is where model explainability comes in. Continuing the doctor example, if the model were to list all the symptoms and their respective contributions for every diagnosis, the doctor could easily verify it and would be more likely to trust the model.

Model interpretability and explainability have been the focus of many research papers and open-source contributions. However, most of these are geared towards data practitioners and specialists. Shapash is a Python library created by the folks at MAIF that visualises a machine learning model’s decision-making process. It aims to make machine learning models trustworthy for everyone by making them more transparent and easier to understand. Shapash creates easy-to-understand visualisations of global and local explainability, and it also facilitates creating a web application that can provide a lot of value to end-users and business owners. Shapash is compatible with most sklearn, lightgbm, xgboost and catboost models and can be used for regression and classification tasks. It uses a SHAP backend to calculate the local contribution of features, but this can be replaced with any other technique for calculating local contributions. Data scientists can use the Shapash explainer to explore and debug their models, or deploy it to provide visualisations with every inference.

Interpreting a RandomForestRegressor with Shapash

We will be creating and explaining a RandomForestRegressor model for the wine quality dataset available on Kaggle.

  1. Install Shapash from PyPI.

!pip install shapash

  2. Import necessary libraries and classes.

import pandas as pd
import random
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from shapash.explainer.smart_explainer import SmartExplainer
  3. Load the dataset and train the model we are going to interpret.

df = pd.read_csv("winequality_red.csv")
y = df['quality']
X = df.drop(['quality'], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = RandomForestRegressor(max_depth=5, random_state=42, n_estimators=12)
model.fit(X_train, y_train)
y_pred = pd.DataFrame(model.predict(X_test), columns=['pred'], index=X_test.index)
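
Before handing the model to Shapash, it can help to sanity-check its hold-out performance so that the explanations come from a reasonably accurate model. This quick check is not part of the original walkthrough; it only uses standard sklearn metrics.

from sklearn.metrics import mean_absolute_error, r2_score

# Quick hold-out check before moving on to explainability
print("Test MAE:", mean_absolute_error(y_test, y_pred['pred']))
print("Test R2:", r2_score(y_test, y_pred['pred']))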
  4. Create the Shapash SmartExplainer object for the model and plot the feature importance.

xpl = SmartExplainer()
xpl.compile(x=X_test, model=model, y_pred=y_pred)
xpl.plot.features_importance()
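
As mentioned earlier, the SHAP backend can be swapped for any other way of computing local contributions. The sketch below is one illustration of that idea: it assumes compile() accepts a contributions argument holding one value per feature per row (verify the exact argument name against the Shapash documentation for your version). Here the contributions happen to come from the shap package, but any DataFrame of the right shape would do.

import shap

# Compute local contributions ourselves and hand them to Shapash directly;
# the 'contributions' argument is assumed to bypass the default SHAP backend.
explainer = shap.TreeExplainer(model)
contributions = pd.DataFrame(explainer.shap_values(X_test),
                             columns=X_test.columns,
                             index=X_test.index)

xpl_custom = SmartExplainer()
xpl_custom.compile(x=X_test, model=model, y_pred=y_pred,
                   contributions=contributions)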
  5. Using the selection argument, we can plot the contribution of features for a subset of the data.

subset = random.choices(X_test.index, k=50)
xpl.plot.features_importance(selection=subset)
  6. The contribution_plot can be used to analyse an individual feature’s contribution. The plot type adjusts depending on the type of feature you are interested in, categorical or continuous, and on the type of the task, i.e., regression or classification.

xpl.plot.contribution_plot('volatile acidity')
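
The contribution plot can also be restricted to the same kind of subset used in the previous step. The snippet below assumes contribution_plot accepts a selection argument like features_importance does; confirm this against the Shapash documentation.

# Hedged sketch: contribution of a single feature, restricted to the
# 50-row subset drawn earlier (assumes a 'selection' parameter exists)
xpl.plot.contribution_plot('alcohol', selection=subset)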

  7. The local_plot(), compare_plot(), and filter() methods can be used to create plots for particular data items and understand the reasoning behind an inference.


You can create contribution plots for individual data items using the local_plot() method. 

xpl.plot.local_plot(index=random.choice(X_test.index))

The filter() method can be used to be more specific about the features we are interested in. The argument max_contrib controls the maximum number of features to display, threshold filters features based on a minimum contribution value, positive toggles whether to display negative contributions or not, and features_to_hide lists the features you don’t want to display.

xpl.filter(max_contrib=8,
           features_to_hide=None,
           threshold=.05)
xpl.plot.local_plot(index=random.choice(X_test.index))
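
The snippet above leaves positive at its default. The hedged sketch below shows one way the remaining arguments might be combined, keeping only positive contributions and hiding a feature we are not interested in; the argument behaviour is as described above, but confirm it against the Shapash documentation.

# Keep at most 5 features, show only positive contributions, hide 'density'
xpl.filter(max_contrib=5,
           positive=True,
           features_to_hide=['density'],
           threshold=.05)
xpl.plot.local_plot(index=random.choice(X_test.index))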

The compare_plot() method can be used to compare the contributions of different features over multiple data items.

xpl.plot.compare_plot(row_num=random.choices(range(320), k=5))
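
If you would rather refer to rows by their original index labels than by positional row numbers, compare_plot also appears to accept an index argument; the sketch below assumes that parameter and should be checked against the Shapash documentation.

# Compare five rows picked by their X_test index labels instead of row numbers
xpl.plot.compare_plot(index=random.sample(list(X_test.index), k=5))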

  8. Most of this can also be done interactively through a web application, which you can create using the run_app() method.

app = xpl.run_app()

Shapash web application

To stop the application, use the kill() method.

app.kill()
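
For the deployment scenario mentioned at the start, Shapash also provides a SmartPredictor object that bundles the model and its explainability for production use. The sketch below is a minimal illustration based on the workflow described in the Shapash tutorials; treat the exact method names and signatures (to_smartpredictor(), save(), load_smartpredictor(), add_input(), detail_contributions()) as something to verify for your version.

# Hedged sketch of a deployment workflow: package the explainer as a
# SmartPredictor, persist it, and compute contributions for new data.
predictor = xpl.to_smartpredictor()
predictor.save('./predictor.pkl')

from shapash.utils.load_smartpredictor import load_smartpredictor
predictor_loaded = load_smartpredictor('./predictor.pkl')

predictor_loaded.add_input(x=X_test.head(3))
print(predictor_loaded.detail_contributions())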

The above implementation can be found here.

Aditya Singh

A machine learning enthusiast with a knack for finding patterns. In my free time, I like to delve into the world of non-fiction books and video essays.