Want To Simplify Neural Networks? Transform Them Into Decision Trees


Deep learning's applications have grown so broad that it has made great headway into areas that require heavy computing. Artificial neural networks (ANNs), the building blocks of DL models, can be said to be the driving force behind the field.

Neural networks perform tasks such as classification very effectively. The problem lies in understanding them: it is hard to explain how a network arrives at a particular decision, and why it classifies so well is still not clearly understood. This article elaborates on a research paper by computer scientists Geoffrey Hinton and Nicholas Frosst, in which they describe in depth how the knowledge in a trained neural net can be transferred into a soft decision tree.

Interpretation Through Decision Trees

In the paper Distilling a Neural Network Into a Soft Decision Tree, Hinton and Frosst explain why decision trees, although easier to interpret, generalise poorly when learned directly from data: “Unlike the hidden units in a neural net, a typical node at the lower levels of a decision tree is only used by a very small fraction of the training data so the lower parts of the decision tree tend to overfit unless the size of the training set is exponentially large compared with the depth of the tree.”
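To make the quoted point concrete, here is a small worked calculation. It assumes, purely for illustration, a perfectly balanced tree in which every inner node sends exactly half of its data to each child, so a node at depth d sees only 1/2^d of the training set. MNIST's 60,000 training images are used as the starting count; the function name `examples_per_node` is hypothetical.

```python
def examples_per_node(n_train: int, depth: int) -> float:
    """Average number of training examples that reach a single node at `depth`,
    assuming (hypothetically) that every split divides the data exactly in half."""
    return n_train / 2 ** depth


if __name__ == "__main__":
    n_train = 60_000  # size of the MNIST training set
    for depth in range(0, 11, 2):
        count = examples_per_node(n_train, depth)
        print(f"depth {depth:2d}: {count:9.1f} examples per node")
```

Each extra level halves the data behind every node, so keeping the per-leaf count constant requires doubling the training set for each added level of depth. That is the exponential growth the quote refers to.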

Hinton and Frosst use a trained neural network to train a decision tree. Decision trees have poor statistical efficiency unless they are given very large amounts of data, and using the network's predictions as training targets helps compensate for this. Note that the neural net trains the soft decision tree, not the other way round. In a soft decision tree, each inner node sends an input to its left or right child with a certain probability, and each leaf (terminal node) holds a learned probability distribution over the classes. During training, all leaves contribute to the final decision, each weighted by the probability of the path that leads to it.
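The mechanics of a soft decision tree can be sketched in plain Python. This is a minimal illustration, not the authors' code: it assumes a complete binary tree stored in heap order, a sigmoid at each inner node giving the probability of taking the right branch, and a softmax over learned logits at each leaf. The class `SoftTree`, its methods and the toy numbers are all hypothetical. A leaf's path probability is the product of the branch probabilities along its path; `mixture` weights every leaf by that probability, while `predict` uses only the single most probable leaf.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


class SoftTree:
    """Complete binary soft decision tree (illustrative sketch).

    Inner nodes are stored in heap order: node i has children 2i+1 and 2i+2.
    Inner node i holds a filter weights[i] and a bias biases[i]; leaf l holds
    logits leaf_logits[l], whose softmax is that leaf's class distribution.
    """

    def __init__(self, depth: int, weights, biases, leaf_logits):
        assert len(weights) == len(biases) == 2 ** depth - 1
        assert len(leaf_logits) == 2 ** depth
        self.depth = depth
        self.weights = weights
        self.biases = biases
        self.leaf_logits = leaf_logits

    def path_probabilities(self, x: Sequence[float]) -> List[float]:
        """Probability of reaching each leaf: the product of branch probabilities on its path."""
        reach = [1.0]  # probability of reaching each node on the current level
        first = 0      # heap index of the leftmost node on the current level
        for _ in range(self.depth):
            next_reach = []
            for offset, r in enumerate(reach):
                i = first + offset
                p_right = sigmoid(dot(self.weights[i], x) + self.biases[i])
                next_reach.extend([r * (1.0 - p_right), r * p_right])
            reach = next_reach
            first = 2 * first + 1
        return reach

    def mixture(self, x: Sequence[float]) -> List[float]:
        """Training-time view: every leaf contributes, weighted by its path probability."""
        out = [0.0] * len(self.leaf_logits[0])
        for r, logits in zip(self.path_probabilities(x), self.leaf_logits):
            for k, q in enumerate(softmax(logits)):
                out[k] += r * q
        return out

    def predict(self, x: Sequence[float]) -> Tuple[int, List[float]]:
        """Test-time view: the single leaf with the highest path probability decides."""
        reach = self.path_probabilities(x)
        best = max(range(len(reach)), key=reach.__getitem__)
        return best, softmax(self.leaf_logits[best])


if __name__ == "__main__":
    # Hypothetical depth-2 tree on 2-D inputs with 3 classes; all numbers are made up.
    tree = SoftTree(
        depth=2,
        weights=[[4.0, 0.0], [0.0, 4.0], [0.0, -4.0]],
        biases=[0.0, 0.0, 0.0],
        leaf_logits=[[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0], [1.0, 1.0, 1.0]],
    )
    x = [1.0, 1.0]
    print("path probabilities:", [round(p, 3) for p in tree.path_probabilities(x)])
    print("mixture over leaves:", [round(p, 3) for p in tree.mixture(x)])
    print("most probable leaf:", tree.predict(x)[0])
```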

Mixture Of Experts And Gradient Descent

The soft decision tree in Hinton and Frosst's study is trained with gradient descent and relies on a machine learning technique called a hierarchical ‘mixture of experts’, in which the problem space is divided among several ‘experts’. Here the leaves act as the experts, and the inner nodes act as gates that softly route each input between them. Each inner node has a learned filter (weight) and a bias, and each leaf has a learned distribution over the classes. Altogether, this setup forms the soft decision tree model; at test time, the leaf with the highest path probability provides the output.
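The gradient-descent training can be illustrated on the smallest possible soft tree: one gating node and two leaves. This is a minimal sketch under stated assumptions, not the paper's exact procedure. It assumes the loss for an input is the cross-entropy of each leaf's distribution against the soft targets, weighted by the probability of reaching that leaf, and it uses hand-derived gradients with plain per-example SGD. The soft targets in `DATA` are hypothetical stand-ins for a trained network's predicted class probabilities, and `SoftStump` and `mean_loss` are illustrative names.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(targets: Sequence[float], probs: Sequence[float]) -> float:
    return -sum(t * math.log(q) for t, q in zip(targets, probs))


class SoftStump:
    """Depth-1 soft tree: one gating node (w, b) and two leaf 'experts' (phi[0], phi[1])."""

    def __init__(self, dim: int, num_classes: int):
        self.w = [0.1] * dim  # small non-zero start so the two leaves can specialise
        self.b = 0.0
        self.phi = [[0.0] * num_classes, [0.0] * num_classes]  # [left leaf, right leaf]

    def p_right(self, x: Sequence[float]) -> float:
        return sigmoid(sum(wj * xj for wj, xj in zip(self.w, x)) + self.b)

    def loss(self, x: Sequence[float], t: Sequence[float]) -> float:
        """Path-probability-weighted cross-entropy against the soft targets t."""
        p = self.p_right(x)
        return ((1.0 - p) * cross_entropy(t, softmax(self.phi[0]))
                + p * cross_entropy(t, softmax(self.phi[1])))

    def sgd_step(self, x: Sequence[float], t: Sequence[float], lr: float) -> None:
        p = self.p_right(x)
        q_left, q_right = softmax(self.phi[0]), softmax(self.phi[1])
        ce_left, ce_right = cross_entropy(t, q_left), cross_entropy(t, q_right)
        # Gate: dL/dz = (ce_right - ce_left) * p * (1 - p), where z = w.x + b.
        g_z = (ce_right - ce_left) * p * (1.0 - p)
        # Leaves: dCE/dphi_k = q_k - t_k (t sums to 1), scaled by each leaf's path probability.
        g_left = [(1.0 - p) * (q - tk) for q, tk in zip(q_left, t)]
        g_right = [p * (q - tk) for q, tk in zip(q_right, t)]
        self.w = [wj - lr * g_z * xj for wj, xj in zip(self.w, x)]
        self.b -= lr * g_z
        self.phi[0] = [v - lr * g for v, g in zip(self.phi[0], g_left)]
        self.phi[1] = [v - lr * g for v, g in zip(self.phi[1], g_right)]


def mean_loss(model: SoftStump, data) -> float:
    return sum(model.loss(x, t) for x, t in data) / len(data)


# Hypothetical soft targets, standing in for a trained network's class probabilities.
DATA = [
    ([-1.0], [0.9, 0.1]),
    ([-0.5], [0.8, 0.2]),
    ([0.5], [0.2, 0.8]),
    ([1.0], [0.1, 0.9]),
]

if __name__ == "__main__":
    stump = SoftStump(dim=1, num_classes=2)
    print("initial loss:", round(mean_loss(stump, DATA), 3))
    for _ in range(300):
        for x, t in DATA:
            stump.sgd_step(x, t, lr=0.5)
    print("trained loss:", round(mean_loss(stump, DATA), 3))
```

The gate's gradient is proportional to the difference between the two leaves' cross-entropies, so the inner node learns to route each input toward the leaf that currently explains its target better, while each leaf drifts toward the targets of the inputs routed to it.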

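To avoid poor solutions during training, in which an inner node sends almost all of its data down one branch, a penalty term is also included. It encourages each inner node to make equal use of its left and right subtrees by comparing the node's average probability of sending data to each side against an even 50/50 split. A hyperparameter in the term sets the strength of the penalty. Because nodes deeper in the tree see only a small fraction of the data, their estimates of this average become less accurate, so the penalty strength is decayed as the tree is descended from parent to child nodes.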
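As we read the paper, the penalty is a cross-entropy between the target split (0.5, 0.5) and each inner node's average split, C = -λ Σ_i [0.5 log(α_i) + 0.5 log(1 - α_i)], where α_i = Σ_x P^i(x) p_i(x) / Σ_x P^i(x), P^i(x) is the probability that input x reaches node i, and p_i(x) is the probability that node i sends it right. The strength is decayed with node depth d, proportional to 2^-d. The sketch below computes this over a batch; the function name `balance_penalty` and the toy inputs are hypothetical.

```python
import math
from typing import Sequence


def balance_penalty(reach: Sequence[Sequence[float]],
                    p_right: Sequence[Sequence[float]],
                    depths: Sequence[int],
                    lam: float) -> float:
    """Penalty that pushes each inner node toward a 50/50 split (illustrative sketch).

    reach[i][n]:   probability that example n reaches inner node i, P^i(x_n)
    p_right[i][n]: probability that node i sends example n right, p_i(x_n)
    depths[i]:     depth of inner node i (root = 0)
    lam:           base penalty strength, the hyperparameter lambda
    """
    total = 0.0
    for node_reach, node_p, d in zip(reach, p_right, depths):
        # alpha_i: average probability of going right, weighted by how likely
        # each example is to reach this node in the first place.
        alpha = sum(r * p for r, p in zip(node_reach, node_p)) / sum(node_reach)
        strength = lam * 2.0 ** (-d)  # decay the penalty exponentially with depth
        # Cross-entropy between the target split (0.5, 0.5) and (alpha, 1 - alpha).
        total += -strength * (0.5 * math.log(alpha) + 0.5 * math.log(1.0 - alpha))
    return total


if __name__ == "__main__":
    everyone = [[1.0, 1.0, 1.0, 1.0]]  # four examples, all of which reach the root
    balanced = balance_penalty(everyone, [[0.9, 0.1, 0.9, 0.1]], [0], lam=1.0)
    skewed = balance_penalty(everyone, [[0.9, 0.9, 0.8, 0.9]], [0], lam=1.0)
    print(round(balanced, 4), round(skewed, 4))
```

A perfectly balanced node pays the minimum penalty of λ·log 2 (scaled by its depth factor), and the penalty grows as the node's average split drifts toward one side.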

Even though the soft decision trees fared somewhat worse than the neural nets in testing, they still capture the gist of how the network reaches its decisions.

Training On The MNIST Dataset

When trained on the MNIST dataset, soft decision trees show less overfitting than neural nets, although at a slight cost in accuracy. A neural net with two convolutional hidden layers reached a test accuracy of 99 percent, while a soft decision tree trained directly on the data reached 94 percent. When the soft decision tree is instead trained on that neural net's predictions, as described earlier, its accuracy rises to 96 percent.

Visualisation of a soft decision tree trained on MNIST. The images at the inner nodes are the learned filters, and the images at the leaves visualise the learned probability distributions over the classes. (Image courtesy: Geoffrey Hinton and Nicholas Frosst)

What do soft decision tree models simplify? Decision trees rely on hierarchical decisions rather than on the hierarchical features that neural nets learn. Beyond the first layer or two of a network, it is quite difficult to explain why it behaves the way it does. Many studies have proposed possible explanations but have not backed them with strong evidence.
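Because a soft tree relies on hierarchical decisions, one natural way to explain a particular classification is to list the decisions on the path to the leaf with the highest path probability; in the figure, those decisions correspond to the learned filters at the inner nodes. The sketch below assumes a complete binary tree stored in heap order (node i has children 2i+1 and 2i+2) with a sigmoid at each inner node. The function `most_probable_path` and the toy weights are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def most_probable_path(x: Sequence[float],
                       weights: Sequence[Sequence[float]],
                       biases: Sequence[float],
                       depth: int) -> Tuple[int, float, List[Tuple[int, str, float]]]:
    """Return (leaf number, path probability, decisions) for the leaf with the
    highest path probability, checking every leaf exhaustively."""
    best: Tuple[int, float, List[Tuple[int, str, float]]] = (-1, -1.0, [])

    def visit(i: int, level: int, prob: float, decisions: List[Tuple[int, str, float]]) -> None:
        nonlocal best
        if level == depth:
            leaf = i - (2 ** depth - 1)  # convert heap index to leaf number
            if prob > best[1]:
                best = (leaf, prob, decisions)
            return
        p_right = sigmoid(sum(w * v for w, v in zip(weights[i], x)) + biases[i])
        visit(2 * i + 1, level + 1, prob * (1.0 - p_right), decisions + [(i, "left", 1.0 - p_right)])
        visit(2 * i + 2, level + 1, prob * p_right, decisions + [(i, "right", p_right)])

    visit(0, 0, 1.0, [])
    return best


if __name__ == "__main__":
    # Hypothetical depth-2 tree on 2-D inputs; all numbers are made up.
    weights = [[4.0, 0.0], [0.0, 4.0], [0.0, -4.0]]
    biases = [0.0, 0.0, 0.0]
    leaf, prob, decisions = most_probable_path([1.0, 1.0], weights, biases, depth=2)
    print(f"leaf {leaf} reached with probability {prob:.3f}")
    for node, branch, p in decisions:
        print(f"  node {node}: went {branch} (p = {p:.3f})")
```

The search visits every leaf rather than greedily following the more probable branch at each node, because the greedy path does not always end at the leaf with the highest overall path probability. For a tree of moderate depth, such as 256 leaves at depth 8, this exhaustive check is cheap.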

Hinton and Frosst’s study brings a fresh approach to explaining neural networks: as the figure shows, the network's knowledge is expressed in the form of a decision tree. Doing so reduces the abstraction involved in following how an input is processed. For networks with very large layers, the decision tree structure can also help with overfitting, in addition to making the large network easier to comprehend. Ultimately, it is always important to check why certain models perform well or poorly for specific tasks or use cases.


Abhishek Sharma
I research and cover the latest happenings in data science. My fervent interests are in the latest technology and humor/comedy (an odd combination!). When I'm not busy reading on these subjects, you'll find me watching movies or playing badminton.
