
Make Machine Learning Interpretable With Shapash


What makes a machine learning model trustworthy? High accuracy on the test set? Metrics like accuracy, precision and recall might instil confidence in data-savvy practitioners, but for non-technical stakeholders like managers, business owners and end-users, the metrics don’t carry the same weight. Let’s take the example of a doctor who has been given a machine learning model for diagnosing patients. How can they be sure that the model is reliable? Sure, they could compare their own diagnoses with the model’s for a few patients, but maybe the model only works well for a small subset of the whole population. Or worse, maybe it’s making the right decision for the wrong reasons, i.e., it’s relying on the wrong features or symptoms. This is where model explainability comes in. Continuing the doctor example, if the model were to list all the symptoms and their respective contributions for every diagnosis, the doctor would easily be able to verify it and would be more likely to trust the model.
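The idea of per-feature contributions can be made concrete with a toy sketch (illustrative numbers only, not Shapash code): an additive explanation assigns each feature a signed contribution such that a base value plus all contributions equals the model’s prediction, which is the property Shap values satisfy.

```python
# Toy additive explanation: base value + per-feature contributions = prediction.
# All numbers below are made up for illustration.
base_value = 0.20          # average model output over the training data
contributions = {          # signed contribution of each symptom to this diagnosis
    "fever": 0.35,
    "cough": 0.15,
    "age": -0.05,
}

prediction = base_value + sum(contributions.values())
print(f"predicted risk: {prediction:.2f}")

# Sorting by absolute contribution shows which symptoms drove the decision.
ranked = sorted(contributions.items(), key=lambda kv: abs(kv[1]), reverse=True)
for name, value in ranked:
    print(f"{name:>6}: {value:+.2f}")
```

Reading such a list, the doctor can check whether the strongest contributor is a symptom that should plausibly drive the diagnosis.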

Model interpretability and explainability have been the focus of many research papers and open-source contributions. However, most of these are geared towards data practitioners and specialists. Shapash is a Python library created by the folks at MAIF that visualises a machine learning model’s decision-making process. It aims to make machine learning models trustworthy for everyone by making them more transparent and easier to understand. Shapash creates easy-to-read visualisations of both global and local explainability, and it facilitates building a web application that can provide a lot of value to end-users and business owners. Shapash is compatible with most sklearn, lightgbm, xgboost and catboost models and can be used for regression and classification tasks. By default it uses a Shap backend to calculate the local contributions of features, but this can be replaced with any other technique for calculating local contributions. Data scientists can use the Shapash explainer to explore and debug their models, or deploy it to serve visualisations with every inference.


Interpreting a RandomForestRegressor with Shapash

We will create and explain a RandomForestRegressor model for the red wine quality dataset available on Kaggle.

  1. Install Shapash from PyPI.

!pip install shapash

  2. Import the necessary libraries and classes.

import pandas as pd
import random
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from shapash.explainer.smart_explainer import SmartExplainer
  3. Load the dataset and train the model we are going to interpret.

df = pd.read_csv("winequality_red.csv")
y = df['quality']
X = df.drop(['quality'], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = RandomForestRegressor(max_depth=5, random_state=42, n_estimators=12)
model.fit(X_train, y_train)
y_pred = pd.DataFrame(model.predict(X_test), columns=['pred'], index=X_test.index)
  4. Create the Shapash SmartExplainer object for the model and plot the feature importance.

xpl = SmartExplainer()
xpl.compile(x=X_test, model=model, y_pred=y_pred)
xpl.plot.features_importance()
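Under the hood, Shap-style global feature importance is typically the mean absolute local contribution of each feature across the dataset. A minimal plain-Python sketch of that aggregation (made-up contribution values, not Shapash internals):

```python
# Mean absolute contribution per feature across samples, as a proxy for
# global feature importance. Contribution values are invented for illustration.
rows = [
    {"alcohol": 0.30, "volatile acidity": -0.20, "sulphates": 0.05},
    {"alcohol": -0.10, "volatile acidity": -0.25, "sulphates": 0.10},
    {"alcohol": 0.20, "volatile acidity": 0.18, "sulphates": -0.05},
]

features = rows[0].keys()
importance = {f: sum(abs(r[f]) for r in rows) / len(rows) for f in features}

# Strongest feature first, as in the features_importance bar chart.
for name, score in sorted(importance.items(), key=lambda kv: -kv[1]):
    print(f"{name:>16}: {score:.3f}")
```

The sign of each local contribution is discarded here on purpose: a feature that strongly pushes predictions down is just as important globally as one that pushes them up.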
  5. Using the selection argument, we can plot the contribution of features for a subset of the data.

subset = random.choices(X_test.index, k=50)
xpl.plot.features_importance(selection=subset)
  6. The contribution_plot can be used to analyse individual features’ contributions. The plot type adjusts depending on the type of feature you are interested in, categorical or continuous, and on the type of the task, i.e., regression or classification.

xpl.plot.contribution_plot('volatile acidity')

  7. The local_plot(), compare_plot(), and filter() methods can be used to create plots for particular data items and understand the reasoning behind an inference.

You can create contribution plots for individual data items using the local_plot() method.

xpl.plot.local_plot(index=X_test.index[0])


The filter() method can be used to be more specific about the features we are interested in. The argument max_contrib controls the maximum number of features to display, threshold filters the features based on a minimum absolute contribution, positive toggles whether to display negative contributions or not, and features_to_hide is used to list features you don’t want to display.

xpl.filter(max_contrib=5,
           threshold=None,
           positive=None,
           features_to_hide=None)
xpl.plot.local_plot(index=X_test.index[0])
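To make those four arguments concrete, here is a plain-Python sketch of the filtering semantics described above (an illustration, not Shapash’s actual implementation):

```python
def filter_contributions(contributions, max_contrib=None, threshold=None,
                         positive=None, features_to_hide=None):
    """Illustrative re-creation of the filter() semantics described above."""
    items = dict(contributions)
    # Drop explicitly hidden features.
    for name in (features_to_hide or []):
        items.pop(name, None)
    # Keep only positive contributions when requested.
    if positive:
        items = {k: v for k, v in items.items() if v > 0}
    # Keep features whose absolute contribution passes the threshold.
    if threshold is not None:
        items = {k: v for k, v in items.items() if abs(v) >= threshold}
    # Keep at most max_contrib features, strongest first.
    ranked = sorted(items.items(), key=lambda kv: abs(kv[1]), reverse=True)
    if max_contrib is not None:
        ranked = ranked[:max_contrib]
    return dict(ranked)

contribs = {"alcohol": 0.30, "volatile acidity": -0.25,
            "sulphates": 0.05, "pH": -0.02}
print(filter_contributions(contribs, max_contrib=2))
# {'alcohol': 0.3, 'volatile acidity': -0.25}
```

Applying a filter before local_plot() keeps the resulting chart focused on the handful of features that actually drove the prediction.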

The compare_plot() method can be used to compare the contributions of different features across multiple data items.

xpl.plot.compare_plot(row_num=random.choices(range(len(X_test)), k=5))

  8. Most of this can also be done interactively through a web application, which you can create using the run_app() method.

app = xpl.run_app()

Shapash web application

To stop the application, use the kill() method.

app.kill()


The above implementation can be found here.


To learn more about Shapash, please refer to the following resources:


Copyright Analytics India Magazine Pvt Ltd
