So far, we have seen various applications of neural networks. Most tasks performed by neural networks involve unstructured data such as images, videos, text, and audio files, and neural networks have done a pretty good job on them. However, most business decisions are made with the help of tabular data, and the performance of neural networks on tabular data has been observed to be not up to the mark. Popular ensemble models such as Random Forest, Gradient Boosting, AdaBoost, and XGBoost often outperform neural networks on tabular data, and they also offer better interpretability.
A few days back, a novel architecture named ‘XBNet’, which stands for ‘Extremely Boosted Neural Network’, was launched. It combines gradient boosted trees with a feed-forward neural network, which, according to its authors, makes the model robust across performance metrics. In this approach, trees are trained in every layer of the architecture, and the feature importance given by the trees, together with the weights determined by gradient descent, is used to adjust the weights of the layers where the trees are trained.
XBNet takes raw tabular data as input and is trained with an optimization technique called Boosted Gradient Descent, which is initialized using the feature importance of gradient boosted trees. It then updates the weights of each layer of the neural network in two steps:
- Update the weights by gradient descent.
- Update the weights using the feature importance of the gradient boosted trees.
Before moving further, let’s briefly summarize ‘gradient boosting’ and ‘feature importance in trees’.
What is gradient boosting?
Gradient boosting is an ML technique for regression, classification, and other tasks that produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees. When the weak learner is a decision tree, the resulting algorithm is called gradient boosted trees, which often outperforms Random Forest. Like other boosting methods, it builds the model stage-wise, and it generalizes them by allowing the optimization of an arbitrary differentiable loss function.
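To make the stage-wise idea concrete, here is a minimal sketch of gradient boosting for regression, assuming squared-error loss and scikit-learn's `DecisionTreeRegressor` as the weak learner. With squared error, the negative gradient of the loss is simply the residual, so each new tree is fit to the residuals of the current ensemble and added with a small learning rate. The helper names `fit_gradient_boosting` and `predict_gradient_boosting` are hypothetical and not part of XBNet; this illustrates the general technique, not the paper's implementation.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def fit_gradient_boosting(X, y, n_stages=50, learning_rate=0.1, max_depth=2):
    """Stage-wise gradient boosting for regression with squared-error loss."""
    base = float(np.mean(y))  # stage 0: the constant that minimizes squared error
    pred = np.full(len(y), base)
    trees = []
    for _ in range(n_stages):
        # For L = 0.5 * (y - F)^2 the negative gradient w.r.t. F is the residual y - F.
        residual = y - pred
        tree = DecisionTreeRegressor(max_depth=max_depth, random_state=0)
        tree.fit(X, residual)  # the weak learner fits the negative gradient
        pred = pred + learning_rate * tree.predict(X)  # take a shrunken step
        trees.append(tree)
    return {"base": base, "learning_rate": learning_rate, "trees": trees}


def predict_gradient_boosting(model, X):
    """Sum the constant base prediction and every tree's shrunken contribution."""
    pred = np.full(X.shape[0], model["base"])
    for tree in model["trees"]:
        pred = pred + model["learning_rate"] * tree.predict(X)
    return pred


# Usage on a noisy sine curve (synthetic toy data).
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=200)
model = fit_gradient_boosting(X, y)
train_mse = float(np.mean((predict_gradient_boosting(model, X) - y) ** 2))
print(f"baseline MSE {np.var(y):.3f}, boosted MSE {train_mse:.3f}")
```

Each stage corrects what the previous stages got wrong, which is why the training error drops well below the baseline of always predicting the mean.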
What is feature importance in trees?
Generally, feature importance provides a score indicating how useful or valuable each feature was in constructing the boosted decision trees within the model. The importance is calculated explicitly for each attribute in the dataset, allowing attributes to be ranked and compared with each other.
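As a quick illustration of ranking features by importance, the sketch below fits scikit-learn's `GradientBoostingClassifier` on the Iris dataset and sorts its `feature_importances_`. Note that scikit-learn's scores are impurity-based (mean decrease in impurity of its internal regression trees), so they illustrate the idea of a per-feature score rather than XBNet's exact computation.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

data = load_iris()
gbt = GradientBoostingClassifier(n_estimators=100, random_state=0)
gbt.fit(data.data, data.target)

# One non-negative score per input feature; scikit-learn normalizes them to sum to 1.
importances = gbt.feature_importances_
ranking = sorted(zip(data.feature_names, importances), key=lambda t: t[1], reverse=True)
for name, score in ranking:
    print(f"{name:20s} {score:.3f}")
```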
The feature importance of gradient boosted trees is determined by the information gain of the tree features, which indicates which attribute in a given set of features is most useful for distinguishing the classes being learned. Information gain is calculated with the help of entropy, which measures the homogeneity of a sample. Entropy and information gain are calculated as below (all the formulas are taken from the official research paper):
Let $P$ be a probability distribution such that
$P = (p_1, p_2, p_3, \ldots, p_n)$
where $p_i$ is the probability that a data point belongs to the subset $d_i$ of the dataset.
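For reference, the standard textbook definitions of entropy and information gain, written in terms of the distribution $P$ above, are sketched below; the research paper's exact notation may differ. The entropy of $P$ is

$$H(P) = -\sum_{i=1}^{n} p_i \log_2 p_i$$

and, if a dataset $D$ is split on an attribute $a$ into subsets $D_1, \ldots, D_k$, with $H(D)$ denoting the entropy of the class distribution in $D$, the information gain of the split is

$$\mathrm{IG}(D, a) = H(D) - \sum_{j=1}^{k} \frac{|D_j|}{|D|}\, H(D_j)$$

A perfectly homogeneous subset has entropy 0, so a split that separates the classes completely achieves the largest possible gain, $H(D)$.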
Later on, the calculated information gain is used to determine the feature importance of the boosted trees.
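The sketch below computes entropy and information gain with NumPy and then shows one common way gains become feature importances: summing, per feature, the sample-weighted entropy reduction of every split that uses it, then normalizing. For a single scikit-learn `DecisionTreeClassifier` trained with `criterion="entropy"`, this reproduces its `feature_importances_`. The helper names `entropy`, `information_gain`, and `gain_based_importance` are hypothetical, and boosted-tree libraries may aggregate gains differently.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def entropy(labels):
    """Shannon entropy (in bits) of the class distribution in `labels`."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))


def information_gain(labels, split_mask):
    """Entropy reduction obtained by splitting `labels` with a boolean mask."""
    labels = np.asarray(labels)
    left, right = labels[split_mask], labels[~split_mask]
    n = len(labels)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(labels) - weighted


def gain_based_importance(clf):
    """Sum each feature's weighted impurity decrease over all splits, then normalize."""
    t = clf.tree_
    w = t.weighted_n_node_samples
    imp = np.zeros(t.n_features)
    for node in range(t.node_count):
        left, right = t.children_left[node], t.children_right[node]
        if left == -1:  # leaf node: no split, no gain
            continue
        gain = w[node] * t.impurity[node] - w[left] * t.impurity[left] - w[right] * t.impurity[right]
        imp[t.feature[node]] += gain
    return imp / imp.sum()


y_toy = np.array([0, 0, 1, 1])
print(entropy(y_toy))                        # 1.0
print(information_gain(y_toy, y_toy == 0))   # 1.0 (perfect split)

iris = load_iris()
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
clf.fit(iris.data, iris.target)
print(np.allclose(gain_based_importance(clf), clf.feature_importances_))  # True
```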
Training and Optimization:
The XBNet architecture creates a sequential structure of layers with input and output layers. The gradient boosted trees are trained, and their feature importances obtained, at the time the model is initialized. As shown in the above picture, a gradient boosted tree is connected to each layer.
During training, the input data completes a forward and a backward pass, and the weights of each layer are updated once according to gradient descent. Then, before moving on to the next training epoch, the algorithm goes through all the layers and updates their weights again according to the feature importance of the gradient boosted trees.
To ensure a proper balance between the weights given by gradient descent and those given by feature importance, the feature-importance values are scaled down to the same order of magnitude (power of ten) as the gradient-descent weights. This is necessary because, after some epochs, the feature importance stays at the same order of magnitude by its nature, whereas the weights provided by gradient descent become several orders of magnitude smaller.
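The two-step update and the rescaling described above can be sketched in PyTorch as follows. This is a minimal sketch of our reading of the description, not the XBNet source, and it rests on several assumptions: scikit-learn's `GradientBoostingClassifier` stands in for whatever boosted-tree library XBNet uses; the trees are refit every epoch on each layer's current inputs (whether XBNet refits them or reuses the trees built at initialization is not specified here); each tree's per-input-feature importance is added to the matching column of the layer's weight matrix; "same power" is read as the same power of ten; and training is full-batch on Iris. The helpers `match_order_of_magnitude` and `layer_inputs` are hypothetical.

```python
import numpy as np
import torch
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier


def match_order_of_magnitude(importance, weight):
    """Rescale `importance` so its mean magnitude has the same power of ten as `weight`'s."""
    imp_exp = np.floor(np.log10(np.mean(np.abs(importance)) + 1e-12))
    w_exp = np.floor(np.log10(np.mean(np.abs(weight)) + 1e-12))
    return importance * 10.0 ** (w_exp - imp_exp)


def layer_inputs(layers, x):
    """Run a forward pass; return the input seen by each linear layer and the logits."""
    inputs = []
    for i, layer in enumerate(layers):
        inputs.append(x)
        x = layer(x)
        if i < len(layers) - 1:  # ReLU between layers, none after the last one
            x = torch.relu(x)
    return inputs, x


torch.manual_seed(0)
data = load_iris()
X = torch.tensor(data.data, dtype=torch.float32)
y = torch.tensor(data.target)  # int64 class labels, as CrossEntropyLoss expects
layers = torch.nn.ModuleList([torch.nn.Linear(4, 10), torch.nn.Linear(10, 3)])
optimizer = torch.optim.Adam(layers.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(5):
    # Step 1: forward + backward pass and one gradient-descent update (full batch).
    optimizer.zero_grad()
    _, logits = layer_inputs(layers, X)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

    # Step 2: for every layer, fit boosted trees on that layer's inputs and use
    # their feature importances (one score per input feature) to adjust the weights.
    with torch.no_grad():
        inputs, _ = layer_inputs(layers, X)
        for layer, inp in zip(layers, inputs):
            gbt = GradientBoostingClassifier(n_estimators=10, max_depth=2, random_state=0)
            gbt.fit(inp.detach().numpy(), data.target)
            imp = match_order_of_magnitude(gbt.feature_importances_, layer.weight.detach().numpy())
            # weight has shape (out_features, in_features); broadcasting adds imp[j]
            # to every weight in column j, i.e. every weight that reads input feature j.
            layer.weight += torch.tensor(imp, dtype=torch.float32)
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

Matching the power of ten, rather than the exact scale, keeps the tree-based nudge comparable to the gradient-descent weights without letting it dominate them as those weights shrink.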
The major and unique highlight of this architecture is that the layers’ weights depend on both the gradient descent algorithm and the feature importance of gradient boosted trees, which, according to the authors, boosts the architecture’s performance.
Here we will compare the performance of XBNet with that of a custom neural network, keeping the same number of epochs and batch size.
Install XBNet using pip:
! pip install --upgrade git+https://github.com/tusharsarkar3/XBNet.git
Import all the dependencies:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from XBNet.training_utils import training, predict
from XBNet.models import XBNETClassifier
from XBNet.run import run_XBNET
Load the Iris dataset, set the input and output features, and create the train/test split:
data = load_iris()
x = data.data
y = data.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=True)
Initialize the architecture with the training data. During initialization, you need to set the input and output dimensions of each layer. Here I have set the number of layers to two, so I need to enter the dimensions of both layers manually. Don’t worry; it is pretty straightforward, as you will be prompted for each value, as shown below.
model = XBNETClassifier(x_train, y_train, num_layers=2)
Output:
Enter dimensions of linear layers:
Enter input dimensions of layer 1: 10
Enter output dimensions of layer 1: 10
Set bias as True or False: False
Enter input dimensions of layer 2: 10
Enter output dimensions of layer 2: 10
Set bias as True or False: False
Enter your last layer
1. Sigmoid
2. Softmax
3. None
3
Set the loss function and optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
Run the architecture using run_XBNET
m,acc, lo, val_ac, val_lo = run_XBNET(x_train,x_test,y_train,y_test,model,criterion,optimizer,epochs=100,batch_size=32)
Classification report for training and validation, respectively.
Training:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        36
           1       1.00      0.91      0.95        32
           2       0.93      1.00      0.96        37

    accuracy                           0.97       105
   macro avg       0.97      0.97      0.97       105
weighted avg       0.97      0.97      0.97       105

Validation:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        14
           1       1.00      1.00      1.00        18
           2       1.00      1.00      1.00        13

    accuracy                           1.00        45
   macro avg       1.00      1.00      1.00        45
weighted avg       1.00      1.00      1.00        45
Plot of accuracy and loss for XBNet:
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
# Left panel: accuracy curves
plt.subplot(1, 2, 1)
plt.plot(acc, label='training accuracy')
plt.plot(val_ac, label='validation accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend()
# Right panel: loss curves
plt.subplot(1, 2, 2)
plt.plot(lo, label='training loss')
plt.plot(val_lo, label='validation loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend()
plt.show()
Custom neural network:
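from keras.models import Sequential
from keras.layers import Dense, Conv1D, Flatten
from keras.optimizers import Adam

# Conv1D expects 3-D input (samples, steps, channels), so add a channel axis
x_train_cnn = x_train.reshape(-1, 4, 1)

model = Sequential()
model.add(Conv1D(30, 3, input_shape=(4, 1), activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Flatten())
model.add(Dense(3, activation='softmax'))
# y_train holds integer class labels, so use the sparse variant of categorical cross-entropy
model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
history = model.fit(x_train_cnn, y_train, epochs=100, batch_size=32, validation_split=0.2)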
Plot of accuracy and loss for the custom neural network:
Some test results from the official research paper
As I mentioned before, the main highlight of this architecture is the way weights are distributed through the layers, maintaining a balance between feature importance and gradient descent. In this comparison, it produced very good results on both training and validation data compared with the custom neural network. According to its authors, the performance, interpretability, and scalability of this architecture set a new benchmark.
Note: Most of the content of this article is taken from the research paper.
Vijaysinh is a machine learning and deep learning enthusiast. He is skilled in ML algorithms, data manipulation, handling and visualization, and model building.