What makes a machine learning model trustworthy? High accuracy on the test set? Metrics like accuracy, precision and recall might instil confidence in data-savvy people. For non-technical stakeholders such as managers, business owners and end-users, these metrics don’t carry the same weight. Take the example of a doctor who has been given a machine learning model for diagnosing patients. How can he be sure that the model is reliable? Sure, he could compare his diagnosis with the model’s for a few patients, but maybe the model only works well for a small subset of the population. Or worse, maybe it’s making the right decision for the wrong reasons, i.e., it’s relying on the wrong features or symptoms. This is where model explainability comes in. Continuing the doctor example, suppose the model listed the symptoms and their respective contributions for every diagnosis. The doctor could then easily check its reasoning and would be more likely to trust it.
Model interpretability and explainability have been the focus of many research papers and open-source contributions. However, most of these are geared towards data practitioners and specialists. Shapash is a Python library created by the folks at MAIF that visualises the decision-making process of machine learning models. It aims to make machine learning models trustworthy for everyone by making them more transparent and easier to understand. Shapash creates easy-to-read visualisations of both global and local explainability. It can also generate a web application that provides a lot of value to end-users and business owners. Shapash is compatible with most scikit-learn, LightGBM, XGBoost and CatBoost models and can be used for both regression and classification tasks. It uses a SHAP backend to calculate the local contribution of each feature, but this can be replaced with any other technique for calculating local contributions. Data scientists can use the Shapash explainer to explore and debug their models. They can also deploy it to provide visualisations alongside every inference.
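To make "local contribution" concrete, here is a minimal sketch of exact Shapley values for a single prediction, computed by brute force over feature coalitions. It assumes a hypothetical three-feature `toy_model` and a single baseline point that stands in for "feature absent". The SHAP library behind Shapash uses faster, model-specific algorithms and averages over background data, so treat this as an illustration of the idea rather than what Shapash actually runs.

```python
from itertools import combinations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict at point x, relative to one baseline point.

    Features outside a coalition are replaced by their baseline value, so
    v(S) = predict(x for features in S, baseline for the others).
    """
    n = len(x)

    def value(coalition):
        # Hybrid input: real values for the coalition, baseline values elsewhere
        point = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return predict(point)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(len(others) + 1):
            # Shapley weight: |S|! * (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                # Marginal effect of adding feature i to coalition S
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Hypothetical toy "wine quality" model with one interaction term (not a trained model)
def toy_model(features):
    alcohol, acidity, sulphates = features
    return 5.0 + 0.4 * alcohol - 1.5 * acidity + 2.0 * alcohol * sulphates


x = [1.0, 0.5, 0.5]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print(phi)
print(sum(phi), toy_model(x) - toy_model(baseline))
```

The contributions come out as roughly `[0.9, -0.75, 0.5]`. They sum to the difference between the prediction and the baseline prediction (0.65), and this additivity is what lets a per-row plot explain a prediction feature by feature. The interaction term's effect is split equally between alcohol and sulphates.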
Interpreting a RandomForestRegressor with Shapash
We will be creating and explaining a RandomForestRegressor model for the red wine quality dataset available on Kaggle.
- Install Shapash from PyPI.
```
# Run in a Jupyter notebook cell; drop the leading "!" in a terminal
!pip install shapash
```
- Import necessary libraries and classes.
```python
import random

import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from shapash.explainer.smart_explainer import SmartExplainer
```
- Load the dataset and train the model we are going to interpret.
```python
# Every column except the 'quality' target is used as a feature
df = pd.read_csv("winequality_red.csv")
y = df['quality']
X = df.drop(['quality'], axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

model = RandomForestRegressor(max_depth=5, random_state=42, n_estimators=12)
model.fit(X_train, y_train)

# Keep the predictions aligned with X_test's index so they match the explained rows
y_pred = pd.DataFrame(model.predict(X_test), columns=['pred'], index=X_test.index)
```
- Create the Shapash `SmartExplainer` object for the model and plot the feature importance.
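```python
# Shapash 2.x: the model is passed to the constructor.
# (Older 1.x releases used SmartExplainer() and xpl.compile(x=X_test, model=model, y_pred=y_pred).)
xpl = SmartExplainer(model=model)
xpl.compile(x=X_test, y_pred=y_pred)
xpl.plot.features_importance()
```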
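The feature-importance plot summarises many local contributions into one number per feature. The sketch below shows one common way to do that aggregation. It sums the absolute contribution of each feature over all rows and normalises the totals so they add up to 1. This is an assumption about Shapash's internals rather than its actual code, and the contribution values are made up for illustration.

```python
def global_importance(contributions):
    """Aggregate per-row contributions into normalised global importances.

    contributions: list of dicts mapping feature name -> local contribution
    for one row. Returns (feature, share) pairs, most important first.
    """
    totals = {}
    for row in contributions:
        for feature, value in row.items():
            # Importance cares about magnitude; the sign only gives the direction
            totals[feature] = totals.get(feature, 0.0) + abs(value)
    grand_total = sum(totals.values())
    importance = {feature: total / grand_total for feature, total in totals.items()}
    return sorted(importance.items(), key=lambda item: item[1], reverse=True)


# Hypothetical contributions for three rows (not real model output)
rows = [
    {"alcohol": 0.30, "sulphates": 0.10, "volatile acidity": -0.20},
    {"alcohol": -0.25, "sulphates": 0.05, "volatile acidity": -0.10},
    {"alcohol": 0.45, "sulphates": -0.15, "volatile acidity": 0.30},
]
for feature, share in global_importance(rows):
    print(f"{feature:>16}: {share:.2f}")
```

Alcohol ranks first even though one of its contributions is negative, because only the magnitude counts. Restricting `rows` to a subset of the data gives a subset-specific ranking.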
- Using the `selection` argument, we can plot the feature importance computed over a subset of the data.
```python
# Sample 50 distinct test rows (random.choices could pick the same row twice)
subset = random.sample(list(X_test.index), k=50)
xpl.plot.features_importance(selection=subset)
```
- The `contribution_plot()` method can be used to analyse a single feature’s contribution across the dataset. The plot type adjusts to the kind of feature you are interested in (categorical or continuous) and to the type of task (regression or classification).
```python
xpl.plot.contribution_plot('volatile acidity')
```
- The `local_plot()`, `compare_plot()` and `filter()` methods can be used to create plots for particular data items and understand the reasoning behind an inference.
You can create a contribution plot for an individual data item using the `local_plot()` method.
```python
xpl.plot.local_plot(index=random.choice(X_test.index))
```
The `filter()` method lets you be more specific about which contributions are displayed. It takes four arguments:
- `max_contrib` sets the maximum number of features to display.
- `threshold` hides features whose contribution falls below a minimum absolute value.
- `positive` selects the sign of the contributions to show: only positive when `True`, only negative when `False`, and both when left as `None`.
- `features_to_hide` lists features you don’t want to display.
```python
xpl.filter(max_contrib=8, features_to_hide=None, threshold=0.05)
xpl.plot.local_plot(index=random.choice(X_test.index))
```
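To make the interplay of these arguments concrete, here is a plain-Python sketch of how such a filter could be applied to the contributions of a single prediction. It is a hypothetical re-implementation for illustration, not Shapash's code. The function name `filter_contributions`, the order in which the rules are applied and the contribution values are all assumptions.

```python
def filter_contributions(contrib, max_contrib=None, threshold=None,
                         positive=None, features_to_hide=None):
    """Hypothetical sketch of the filter options described above.

    contrib: dict mapping feature name -> contribution for one prediction.
    Returns (feature, contribution) pairs sorted by decreasing magnitude.
    """
    hidden = set(features_to_hide or [])
    kept = []
    for feature, value in contrib.items():
        if feature in hidden:
            continue  # explicitly hidden by the user
        if threshold is not None and abs(value) < threshold:
            continue  # too small to be worth displaying
        if positive is True and value <= 0:
            continue  # keep only contributions that push the prediction up
        if positive is False and value >= 0:
            continue  # keep only contributions that push the prediction down
        kept.append((feature, value))
    # Largest effects first, as in a local contribution plot
    kept.sort(key=lambda item: abs(item[1]), reverse=True)
    if max_contrib is not None:
        kept = kept[:max_contrib]  # assumption: the cap is applied last
    return kept


# Made-up contributions for one prediction
row = {"alcohol": 0.42, "sulphates": 0.18, "volatile acidity": -0.31,
       "pH": 0.03, "density": -0.04}
print(filter_contributions(row, max_contrib=2, threshold=0.05))
# [('alcohol', 0.42), ('volatile acidity', -0.31)]
```

The threshold removes pH and density first, then `max_contrib=2` keeps the two largest remaining effects. Applying the cap last means it only counts features that could actually be displayed.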
The `compare_plot()` method can be used to compare the contributions of different features across multiple data items.
```python
# row_num takes positional row numbers; sample 5 distinct ones within X_test
xpl.plot.compare_plot(row_num=random.sample(range(len(X_test)), k=5))
```
- Most of this can also be done interactively through a web application, which you can launch with the `run_app()` method.
```python
app = xpl.run_app()
```
To stop the application, use the `kill()` method.
```python
app.kill()
```
The above implementation can be found here.
References
To learn more about Shapash, please refer to the following resources: