The effectiveness of a machine learning model is often marred by its inability to explain its decisions to users. To address this problem, a whole new branch of AI, explainable AI (XAI), has emerged, and researchers are actively pursuing different methodologies to build user-friendly AI.
But what about existing XAI approaches? Are they any good? Where do they fail? To answer these questions, a team of researchers from UC Berkeley and Boston University investigated the challenges and possible solutions. Their exploration led to a novel technique, which is discussed later in this article.
Explaining The Inexplicable
To illustrate the problem, one of the authors, Alvin Wan, uses the examples of saliency maps and decision trees in a blog post.
Saliency maps are heat maps over an image that highlight the pixels that most influenced a model's prediction. They are meant to shed light on the rationale behind a prediction, and they are one of the most widely used XAI methods. However, as Wan wrote, a saliency map can highlight the correct object even when the prediction itself is incorrect. The map does not explain why the model went wrong, and answering that question could help us improve the model.
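To make the idea concrete, here is a minimal sketch of one common variant, vanilla-gradient saliency, which scores each pixel by how strongly the class score reacts to it. It assumes a hypothetical 3x3 toy model (`WEIGHTS`, `class_score`) in place of a real network and estimates gradients with finite differences rather than backpropagation; none of these names come from Wan's post.

```python
import math

# Hypothetical 3x3 "model" weights standing in for a trained network:
# the class score depends mostly on the centre pixel.
WEIGHTS = [[0.0, 0.1, 0.0],
           [0.2, 0.9, 0.2],
           [0.0, 0.1, 0.0]]


def class_score(image):
    """Toy class score: tanh of a weighted sum of pixel intensities."""
    total = sum(w * p
                for w_row, p_row in zip(WEIGHTS, image)
                for w, p in zip(w_row, p_row))
    return math.tanh(total)


def saliency_map(score_fn, image, eps=1e-4):
    """Vanilla-gradient saliency: |d score / d pixel| for every pixel.

    The gradient is estimated with central finite differences instead of
    backpropagation, so the sketch needs no deep learning framework.
    """
    saliency = []
    for r, row in enumerate(image):
        sal_row = []
        for c, _ in enumerate(row):
            plus = [list(x) for x in image]
            minus = [list(x) for x in image]
            plus[r][c] += eps
            minus[r][c] -= eps
            grad = (score_fn(plus) - score_fn(minus)) / (2 * eps)
            sal_row.append(abs(grad))
        saliency.append(sal_row)
    return saliency


image = [[0.5] * 3 for _ in range(3)]  # a uniform grey test image
sal = saliency_map(class_score, image)
```

The centre pixel receives the largest value because the toy score depends on it most. Note what the output is: a map of where the model looked, with nothing about why it chose the class, which is the gap Wan describes.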
Using decision trees to explain AI predictions, by contrast, is a more traditional approach than saliency maps.
The above picture depicts how a decision tree works its way through its branches to classify a burger. With low-dimensional tabular data, as shown above, the decision rules in a decision tree are simple to interpret: for example, if the dish contains a bun, pick the right child. However, decision rules are not as straightforward for high-dimensional inputs such as images. Decision trees also fall short on accuracy: Wan noted that they lag behind neural networks by up to 40% accuracy on image classification datasets.
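The burger example can be sketched as a tiny hand-built tree. The features (`bun`, `patty`, `lettuce`) and labels below are hypothetical stand-ins for the figure, and `Node` and `classify` are illustrative names; the point is that the list of rules applied is itself the explanation.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    """A binary decision-tree node: leaves carry a label, inner nodes a feature."""
    label: Optional[str] = None
    feature: Optional[str] = None
    left: Optional["Node"] = None   # followed when the feature is absent
    right: Optional["Node"] = None  # followed when the feature is present


def classify(node: Node, dish: Dict[str, bool]) -> Tuple[str, List[str]]:
    """Walk from the root to a leaf, recording each decision rule applied."""
    path = []
    while node.label is None:
        present = dish.get(node.feature, False)
        path.append(f"{node.feature}? {'yes' if present else 'no'}")
        node = node.right if present else node.left
    return node.label, path


# Hypothetical tree echoing the burger example: "contains a bun?" -> right child.
tree = Node(
    feature="bun",
    right=Node(feature="patty", right=Node(label="burger"), left=Node(label="hot dog")),
    left=Node(feature="lettuce", right=Node(label="salad"), left=Node(label="soup")),
)

label, path = classify(tree, {"bun": True, "patty": True})
print(label, path)  # burger ['bun? yes', 'patty? yes']
```

This works because every feature is a human-readable fact about the dish. An image offers only pixel intensities, so there is no equivalent of a `bun` question to branch on, which is where plain decision trees struggle.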
So, to combine the interpretability of decision trees with the performance of neural networks, Wan and his colleagues introduced a new approach: Neural-Backed Decision Trees (NBDTs).
Overview Of Neural-Backed Decision Trees
This is not the first time that decision trees and deep learning have been combined, but existing methods, wrote the authors, resulted in models that achieved lower accuracies than modern neural networks even on small datasets (e.g. MNIST) and required significantly different architectures, forcing practitioners to trade off accuracy against interpretability.
Unlike previous methods, Neural-Backed Decision Trees show that interpretability improves with accuracy.
When an NBDT was used to run inference on an image of a zebra, its intermediate decisions were also accurate: the path through the tree shows that a zebra is both an animal and an ungulate (a hoofed animal).
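The zebra example suggests how inference walks a hierarchy. Below is a minimal sketch of hard (greedy) inference under stated assumptions: the class vectors, the hierarchy, the inner-node names and the feature vector are all hypothetical, and each inner node is represented by the mean of the class vectors beneath it, a choice made for this sketch. At every node, the child whose vector has the largest inner product with the sample's features is chosen.

```python
from typing import Dict, List

Vector = List[float]

# Hypothetical final-layer rows, one per class.
LEAF_VECTORS: Dict[str, Vector] = {
    "zebra": [1.0, 0.9, 0.0],
    "horse": [1.0, 0.1, 0.0],
    "cat": [0.2, 0.0, 1.0],
    "car": [-1.0, 0.0, 0.1],
    "truck": [-1.0, 0.3, -0.1],
}

# Hypothetical hierarchy: inner node -> children. Inner-node names are
# hand-picked here so the path reads like the zebra example.
CHILDREN: Dict[str, List[str]] = {
    "root": ["animal", "vehicle"],
    "animal": ["ungulate", "cat"],
    "ungulate": ["zebra", "horse"],
    "vehicle": ["car", "truck"],
}


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def leaves(node: str) -> List[str]:
    """All classes in the subtree rooted at `node`."""
    if node not in CHILDREN:
        return [node]
    return [leaf for child in CHILDREN[node] for leaf in leaves(child)]


def representative(node: str) -> Vector:
    """Assumption for this sketch: mean of the class vectors under `node`."""
    vectors = [LEAF_VECTORS[leaf] for leaf in leaves(node)]
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def hard_inference(features: Vector) -> List[str]:
    """Greedily descend from the root, recording every decision."""
    node, path = "root", []
    while node in CHILDREN:
        node = max(CHILDREN[node], key=lambda child: dot(features, representative(child)))
        path.append(node)
    return path


zebra_features = [0.9, 1.0, 0.1]  # hypothetical backbone output for a zebra image
print(hard_inference(zebra_features))  # ['animal', 'ungulate', 'zebra']
```

Each element of the returned path is an intermediate prediction, so a wrong final label can be traced back to the first decision that went astray.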
The training and inference process for a Neural-Backed Decision Tree can be broken down into four steps:
- First, a hierarchy for the decision tree, called the Induced Hierarchy, is constructed; it determines which sets of classes the NBDT must decide between.
- This hierarchy yields a particular loss function, called the Tree Supervision Loss, which is used to train the original neural network, without any modifications.
- A sample is passed through the neural network backbone for inference.
- Inference is completed by running the final fully-connected layer as a sequence of decision rules, which are called Embedded Decision Rules. These decisions culminate in the final prediction.
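The four steps can be sketched end to end on a toy problem. This is an illustration under assumptions, not the authors' implementation: the final-layer rows in `fc_rows` and the feature vector are hypothetical; the hierarchy is induced by greedily merging the two most cosine-similar clusters, with each new node represented by the mean of its children; inference uses a softmax over inner products at each node (soft inference), so a leaf's probability is the product of the decisions along its path; and the tree supervision term is sketched as plain cross-entropy on those leaf probabilities.

```python
import math
from itertools import combinations
from typing import Dict, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def normalize(v: Vector) -> Vector:
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def mean(vectors: List[Vector]) -> Vector:
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def induce_hierarchy(fc_rows: Dict[str, Vector]) -> Tuple[str, Dict[str, List[str]], Dict[str, Vector]]:
    """Step 1 (sketch): agglomeratively cluster the final-layer rows into a binary tree."""
    reps = {name: normalize(row) for name, row in fc_rows.items()}
    children: Dict[str, List[str]] = {}
    active = list(reps)
    while len(active) > 1:
        # Merge the pair of clusters whose representatives are most cosine-similar.
        a, b = max(combinations(active, 2),
                   key=lambda pair: dot(normalize(reps[pair[0]]), normalize(reps[pair[1]])))
        parent = f"n{len(children)}"
        children[parent] = [a, b]
        reps[parent] = mean([reps[a], reps[b]])  # assumption: mean of the children
        active = [n for n in active if n not in (a, b)] + [parent]
    return active[0], children, reps


def soft_inference(features: Vector, root: str, children, reps) -> Dict[str, float]:
    """Steps 3-4 (sketch): soft decision rules, one softmax per inner node."""
    leaf_probs: Dict[str, float] = {}
    stack = [(root, 1.0)]
    while stack:
        node, prob = stack.pop()
        if node not in children:  # a leaf: its path probability is final
            leaf_probs[node] = prob
            continue
        scores = [dot(features, reps[c]) for c in children[node]]
        top = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        for child, e in zip(children[node], exps):
            stack.append((child, prob * e / total))
    return leaf_probs


def tree_supervision_loss(leaf_probs: Dict[str, float], target: str, eps: float = 1e-12) -> float:
    """Step 2 (sketch): cross-entropy on the tree's leaf probabilities."""
    return -math.log(leaf_probs[target] + eps)


# Hypothetical final fully-connected layer: one row per class.
fc_rows = {
    "zebra": [1.0, 0.9, 0.0],
    "horse": [1.0, 0.1, 0.0],
    "cat": [0.2, 0.0, 1.0],
    "car": [-1.0, 0.0, 0.1],
}
root, children, reps = induce_hierarchy(fc_rows)
features = [0.9, 1.0, 0.1]  # hypothetical backbone output for one sample
probs = soft_inference(features, root, children, reps)
prediction = max(probs, key=probs.get)
loss = tree_supervision_loss(probs, "zebra")
print(prediction)  # zebra
```

In training, a term like `tree_supervision_loss` would be added to the usual cross-entropy on the network's logits; how the two are weighted is not modeled here.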
The results show that NBDTs narrow the accuracy gap between neural networks and decision trees to 1% on CIFAR10, CIFAR100 and TinyImageNet, and to 2% on ImageNet, advancing the state of the art for interpretable methods by ∼14% on ImageNet, to 75.30% top-1 accuracy.
Key Findings
The whole work can be summarised as follows:
- Neural-Backed Decision Trees remove the dilemma between accuracy and interpretability; the authors found that interpretability improves with accuracy
- Any classification neural network can be converted into an NBDT
- As a fortuitous side effect, the tree supervision loss also boosts the original neural network's accuracy by 0.5%