Earlier this month, PyTorch Tabular v0.7.0 was released on PyPI. This latest version of PyTorch Tabular aims to make deep learning with tabular data easy and accessible for real-world use cases and research. The core principles behind the library’s design include low-resistance usability, easy customisation, and easy deployment and scalability. The source code is available on GitHub.
Why PyTorch Tabular?
Despite its unreasonable effectiveness in modalities like image and text, deep learning has long lagged behind gradient boosting on tabular data, both in popularity and performance. In the last few years, however, newer models created explicitly for tabular data have pushed the performance of deep learning models. When it comes to popularity, though, challenges remain, as there are no ready-to-use libraries such as scikit-learn for deep learning.
Developed by Manu Joseph, PyTorch Tabular is a new deep learning library that makes working with deep learning and tabular data easy and fast. The library is built on top of the PyTorch and PyTorch Lightning frameworks and works directly on pandas dataframes. In addition, many state-of-the-art models like NODE and TabNet have already been implemented and integrated into the library with a unified API.
The models currently available in PyTorch Tabular include:
- FeedForward Network: A simple feed-forward network with embedding layers for the categorical columns.
- Neural oblivious decision ensembles (NODE): A model for deep learning on tabular data presented at ICLR 2020. On many datasets, it has beaten well-tuned gradient boosting models.
- TabNet: Developed by Google Research, this attentive interpretable tabular learning model uses ‘sparse attention’ across multiple decision steps to model the outcome.
- Mixture density networks use ‘Gaussian components’ to approximate the target function and provide a probabilistic prediction.
- AutoInt: Automatic feature interaction learning using self-attentive neural networks, a model that attempts to learn interactions between features and create a better representation, which is then used in downstream tasks.
- TabTransformer: An adaptation of the transformer model for tabular data that creates contextual representations for categorical features.
PyTorch Tabular Design
According to the author, PyTorch Tabular is designed to make the standard modelling pipeline easy enough for practitioners and standard enough for production deployment, alongside a focus on customisation to enable wide usage in research. To satisfy these objectives, PyTorch Tabular has adopted a ‘config-driven’ approach. Five configs drive the whole process: DataConfig, ModelConfig, TrainerConfig, OptimizerConfig, and ExperimentConfig. These configs can be set programmatically or through YAML files, making the library easy for both data scientists and ML engineers.
In addition to this, PyTorch Tabular uses a BaseModel, an abstract class that implements the standard parts of any model definition, such as loss and metric calculation; a Data Module, which unifies and standardises data processing; and a TabularModel, which brings the configs together, initialises the right model and the data module, and handles training and prediction through methods like ‘fit’ and ‘predict.’
Deep learning for tabular data is gaining popularity in the research community and the industry. With this rise in popularity, it becomes essential to have a unified, easy-to-use API for tabular data, similar to what scikit-learn has done for classical machine learning algorithms.
PyTorch Tabular aims to lower the entry barrier to using new SOTA deep learning architectures and to reduce the ‘engineering’ work for researchers and developers.
In the coming days, PyTorch Tabular looks to:
- Add more contributors
- Add GaussRank as Feature Transformation
- Add ability to use custom activations in CategoryEmbeddingModel
- Add differential dropouts (layer-wise) in CategoryEmbeddingModel
- Add Fourier Encoding for cyclic time variables
- Integrate Optuna Hyperparameter Tuning
- Add Text and Image Modalities for mixed modal problems
- Add Variable Importance
- Integrate SHAP for interpretability