What makes a machine learning model trustworthy? High accuracy on the test set? Metrics like accuracy, precision and recall might instil confidence in the more data-savvy people, but for non-data people like managers, business owners and end-users, the metrics don’t carry the same weight. Let’s take the example of a doctor who has been given a machine learning model for diagnosing patients. How can he be sure that the model is reliable? Sure, he could compare his diagnosis with the model’s diagnosis for a few patients, but maybe the model only works well for a small subset of the whole population. Or worse, maybe it’s making the right decision for the wrong reasons, i.e., it’s considering the wrong features or symptoms. This is where model explicability comes in. Continuing the doctor example, if the model were to list all the symptoms and their respective contributions for every diagnosis, the doctor would easily be able to verify it and would be more likely to trust the model.
Model interpretability and explicability have been the focal point of many research papers and open source contributions. However, most of these are geared towards data practitioners and specialists. Shapash is a Python library created by the folks at MAIF that visualizes machine learning models’ decision-making process. It aims to make machine learning models trustworthy for everyone by making them more transparent and easy to understand. Shapash creates easy-to-understand visualisations of global and local explainability. It also facilitates creating a web application that can provide a lot of value to end-users and business owners. Shapash is compatible with most sklearn, lightgbm, xgboost and catboost models and can be used for regression and classification tasks. It uses a Shap backend to calculate the local contribution of features, but this can be replaced with any other technique for calculating local contributions. Data scientists can use the Shapash explainer to explore and debug their models, or deploy it to provide visualisations with every inference.
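To make the idea of a "local contribution" concrete, here is a minimal sketch that computes exact Shapley values for a toy three-feature model in plain Python. Everything in it is hypothetical (the scoring function, the feature names, the all-zero baseline), and this brute-force enumeration over feature orderings is only an illustration of the concept, not how Shap or Shapash compute contributions in practice.

```python
from itertools import permutations


def toy_model(features):
    # Hypothetical scoring function with one interaction term.
    return (2.0 * features["alcohol"]
            - 3.0 * features["acidity"]
            + features["alcohol"] * features["sulphates"])


def shapley_values(model, instance, baseline):
    """Average each feature's marginal effect over all feature orderings."""
    names = list(instance)
    totals = {name: 0.0 for name in names}
    orderings = list(permutations(names))
    for order in orderings:
        current = dict(baseline)      # start from the baseline input
        previous = model(current)
        for name in order:
            current[name] = instance[name]   # switch one feature on
            value = model(current)
            totals[name] += value - previous
            previous = value
    return {name: total / len(orderings) for name, total in totals.items()}


instance = {"alcohol": 1.0, "acidity": 0.5, "sulphates": 2.0}
baseline = {"alcohol": 0.0, "acidity": 0.0, "sulphates": 0.0}
contributions = shapley_values(toy_model, instance, baseline)
print(contributions)
```

The contributions always add up to the difference between the prediction for the instance and the prediction for the baseline, which is what lets a reader check each feature's share of a single prediction.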
Interpreting a RandomForestRegressor with Shapash
- Install Shapash from PyPI.
!pip install shapash
- Import necessary libraries and classes.
import pandas as pd
import random
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from shapash.explainer.smart_explainer import SmartExplainer
- Load the dataset and train the model we are going to interpret.
df = pd.read_csv("winequality_red.csv")
y = df['quality']
X = df.drop(['quality'], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = RandomForestRegressor(max_depth=5, random_state=42, n_estimators=12)
model.fit(X_train, y_train)
y_pred = pd.DataFrame(model.predict(X_test), columns=['pred'], index=X_test.index)
- Create the Shapash SmartExplainer object for the model and plot the feature importance.
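# Note: this is the Shapash 1.x call style. In newer releases (2.x) the model
# is passed to the constructor instead, e.g. SmartExplainer(model=model),
# and compile() is called with x and y_pred only.
xpl = SmartExplainer()
xpl.compile(x=X_test, model=model, y_pred=y_pred)
xpl.plot.features_importance()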
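A global feature-importance chart can be read as an aggregation of per-row local contributions. The sketch below illustrates that idea with made-up numbers, using a normalised sum of absolute contributions. The function name and the data are hypothetical, and Shapash's exact aggregation may differ.

```python
def global_importance(local_contributions):
    """Aggregate per-row contributions into normalised importance scores."""
    features = list(local_contributions[0])
    # Sum absolute contributions so positive and negative effects do not cancel.
    totals = {f: sum(abs(row[f]) for row in local_contributions) for f in features}
    grand_total = sum(totals.values())
    if grand_total == 0:
        return {f: 0.0 for f in features}
    return {f: totals[f] / grand_total for f in features}


# Hypothetical local contributions for three rows of the test set.
rows = [
    {"alcohol": 0.30, "sulphates": -0.10, "pH": 0.00},
    {"alcohol": -0.20, "sulphates": 0.20, "pH": 0.10},
    {"alcohol": 0.10, "sulphates": 0.00, "pH": 0.00},
]
importance = global_importance(rows)
print(importance)
```

Passing only some of the rows to such a function gives the importance for that subset, which is the intuition behind plotting feature importance for a selection of the data.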
- Using the selection argument, we can plot the contribution of features for a subset of the data.
subset = random.choices(X_test.index, k=50)
xpl.plot.features_importance(selection=subset)
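One detail worth knowing about the sampling step: random.choices draws with replacement, so the same row label can appear more than once in the subset, whereas random.sample draws without replacement. The sketch below shows the difference on a small hypothetical list of row labels, using only the standard library.

```python
import random

random.seed(0)  # fixed seed so the sketch is repeatable
population = [101, 205, 309, 412, 518]  # hypothetical row labels

# Drawing 10 items from 5 labels with replacement must repeat some labels.
with_replacement = random.choices(population, k=10)

# Drawing without replacement returns each label at most once.
without_replacement = random.sample(population, k=5)

print(with_replacement)
print(without_replacement)
```

If a subset of unique rows is wanted, something like random.sample(list(X_test.index), k=50) is one option. Converting the index to a list first is a precaution here, since recent Python versions require random.sample to receive a sequence.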
- contribution_plot can be used to analyse individual features’ contribution. The plot type adjusts depending on the type of feature you are interested in, categorical or continuous, and on the type of the task, i.e., regression or classification.
- The local_plot() and filter() methods can be used to create plots for particular data items and understand the reasoning behind an inference. You can create contribution plots for individual data items using local_plot(). The filter() method lets you be more specific about the features you are interested in: max_contrib controls the maximum number of features to display, threshold filters the features based on a minimum value of contribution, positive controls which sign of contribution is displayed (True hides negative contributions, False hides positive ones), and features_to_hide lists features you don’t want to display.
xpl.filter(max_contrib=8, features_to_hide=None, threshold=.05)
xpl.plot.local_plot(index=random.choice(X_test.index))
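To illustrate what these filter arguments do to a single row's contributions, here is a small pure-Python sketch. The function filter_contributions and the example values are hypothetical; it mirrors the behaviour described above under the assumption that threshold is applied to the absolute contribution, and it is not Shapash's own implementation.

```python
def filter_contributions(contributions, max_contrib=None, threshold=None,
                         positive=None, features_to_hide=None):
    """Hypothetical sketch of filtering one row's feature contributions."""
    hidden = set(features_to_hide or [])
    kept = {}
    for name, value in contributions.items():
        if name in hidden:
            continue  # explicitly hidden feature
        if threshold is not None and abs(value) < threshold:
            continue  # contribution too small to show
        if positive is True and value < 0:
            continue  # show positive contributions only
        if positive is False and value > 0:
            continue  # show negative contributions only
        kept[name] = value
    # Rank by magnitude and keep the largest max_contrib entries.
    ranked = sorted(kept.items(), key=lambda item: abs(item[1]), reverse=True)
    if max_contrib is not None:
        ranked = ranked[:max_contrib]
    return dict(ranked)


# Made-up contributions for one wine sample.
example = {"alcohol": 0.42, "volatile acidity": -0.31, "sulphates": 0.08,
           "pH": -0.02, "density": 0.01}
top = filter_contributions(example, max_contrib=3, threshold=0.05)
only_positive = filter_contributions(example, threshold=0.05, positive=True)
print(top)
print(only_positive)
```

Filtering like this keeps a local plot readable when a model has many features, most of which contribute very little to a given prediction.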
- The compare_plot() method can be used to compare the contributions of different features over multiple data items.
# Pick 5 distinct row positions from the test set rather than a hard-coded range.
xpl.plot.compare_plot(row_num=random.sample(range(len(X_test)), k=5))
- Most of this can also be done interactively through a web application, which you can create using the run_app() method.
app = xpl.run_app()
To stop the application, use the kill() method of the returned app object, i.e., app.kill().
The above implementation can be found here.
To learn more about Shapash, please refer to the following resources: