Google open-sourced TensorFlow in 2015, and major organisations such as Uber, Airbnb, NASA, GE Healthcare and Twitter adopted it to optimise their operations. The open-source library, used mainly through its Python API, is primarily designed for training deep neural networks and running inference on them. For example, GE Healthcare uses TensorFlow to improve the speed and accuracy with which MRI scans identify specific body parts.
Meta AI (then Facebook AI Research) released PyTorch, an open-source machine learning framework, in 2016. PyTorch allows quicker prototyping than TensorFlow. It is also more tightly integrated with the Python ecosystem, which makes debugging much simpler: PyTorch code can be debugged with any of the many widely available Python debugging tools, whereas TensorFlow developers have to learn tfdbg (the TensorFlow debugger) to evaluate TensorFlow expressions at runtime.
Moreover, PyTorch offers an extensive suite of tools, libraries, pre-trained models and datasets for each stage of development, helping the developer community build and deploy AI applications quickly and at scale.
TensorFlow's own shortcomings have also fuelled PyTorch's rise. For instance, despite upgraded binaries, TensorFlow 2.9.0, the latest release at the time of writing, does not work with CUDA 11.7, the latest CUDA version.
JAX, the rising star
In 2020, DeepMind Technologies, a subsidiary of Alphabet Inc., said it was using JAX to accelerate its AI/ML research, sparking discussion about the future of TensorFlow.
Just After Execution (JAX) is a relatively new machine learning framework developed by Google Research for high-performance numerical computing. Its Python API is modelled on NumPy, the widely used library of array functions for scientific computing.
JAX provides an extensible system of composable function transformations, such as automatic differentiation, vectorisation and just-in-time (JIT) compilation, to support machine learning research. ML researchers often apply a single function to large amounts of data, and JAX's automatic vectorisation simplifies this. The NumPy-based API and JIT compilation also let researchers scale high-performance code to one or many accelerators with little extra effort.
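To make these transformations concrete, here is a minimal sketch in JAX. The toy model (`predict`, `loss`) and the data are illustrative assumptions rather than code from any particular project; `jax.grad`, `jax.vmap` and `jax.jit` are JAX's standard transformation functions, and `jax.numpy` is its NumPy-style API.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # jax.numpy mirrors NumPy: jnp.dot and jnp.tanh behave like np.dot and np.tanh.
    return jnp.tanh(jnp.dot(x, w))

def loss(w, x, y):
    # Squared error for a single example; returns a scalar, as jax.grad requires.
    return (predict(w, x) - y) ** 2

# Differentiation: a new function that returns d(loss)/d(w).
grad_loss = jax.grad(loss)

# Vectorisation: apply grad_loss to a whole batch without a Python loop.
# in_axes=(None, 0, 0) shares w and maps over the leading axis of x and y.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT compilation: compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([0.5, -0.25, 0.1])
xs = jnp.ones((8, 3))   # 8 examples with 3 features each
ys = jnp.zeros(8)
per_example_grads = fast_batched_grad(w, xs, ys)
print(per_example_grads.shape)  # (8, 3): one gradient per example
```

Each transformation returns an ordinary Python function, which is why they compose: `jax.jit(jax.vmap(jax.grad(loss), ...))` is simply another function that can be called, transformed again or reused elsewhere.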
JAX offers a more straightforward design for tackling the most complex problems in machine learning. It is highly optimised and supports large-scale data parallelism, so a workload can be distributed across multiple TPUs instead of running separate pieces of code on individual chips. With JAX, a program can use as many TPUs as are available. Code written for one JAX project can also be reused in other projects.
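As a rough illustration of data parallelism, the sketch below uses `jax.pmap` to split a batch across whatever local devices are available (TPU cores, GPUs, or a single CPU) and averages the per-device gradients with `jax.lax.pmean`. The linear model, the data and the axis name `"devices"` are hypothetical choices made for this example; newer JAX releases also offer sharding-based alternatives to `pmap`.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # number of local devices; 1 on a plain CPU machine

def loss(w, x, y):
    # Mean squared error of a linear model on one shard of the batch.
    return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def parallel_grad(w, x, y):
    # Runs once per device, each on its own slice of the leading (device) axis.
    g = jax.grad(loss)(w, x, y)
    # Average the per-device gradients so every replica ends up with the same result.
    return jax.lax.pmean(g, axis_name="devices")

w = jnp.array([0.1, 0.2, 0.3])
# Leading axis = device axis: each device receives a (4, 3) shard of the inputs.
x = jnp.arange(n_dev * 4 * 3, dtype=jnp.float32).reshape(n_dev, 4, 3) / 10.0
y = jnp.ones((n_dev, 4))
w_replicated = jnp.stack([w] * n_dev)  # one copy of the parameters per device

grads = parallel_grad(w_replicated, x, y)  # shape (n_dev, 3), identical rows
```

Because every shard has the same size, the averaged gradient equals the gradient over the full batch, which is what a synchronous data-parallel training step needs. The same code runs unchanged whether one device or many are present.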
Is JAX overshadowing TensorFlow?
JAX is flexible, easy to adopt and enables rapid experimentation with novel algorithms and architectures. However, it faces several challenges. First, JAX has no built-in way to load and pre-process data easily, so it relies on TensorFlow or PyTorch for much of the setup. Second, JAX itself offers no higher-level model abstractions and lacks deployment portability. Third, its library ecosystem is limited, so for most projects researchers must write much of their code from scratch.
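To illustrate the first and third points, the following sketch writes everything by hand in plain JAX: a NumPy-based batching loop stands in for a data loader, and the parameter update is a hand-rolled SGD step rather than a library optimiser. The synthetic linear-regression task, the helper names (`make_batches`, `sgd_step`, `train`) and the hyperparameters are hypothetical, chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical hyperparameter

def make_batches(x, y, batch_size, rng):
    # Hand-written shuffling and batching: the job a data loader would normally do.
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

def init_params():
    # Parameters are a plain dict (a JAX "pytree"), not a model object.
    return {"w": jnp.zeros(()), "b": jnp.zeros(())}

def loss_fn(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Hand-rolled SGD: compute gradients, then update every leaf of the pytree.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

def train(num_epochs=50, batch_size=32, seed=0):
    rng = np.random.default_rng(seed)
    # Synthetic, noise-free data for y = 2x + 1 (an assumption for this demo).
    x = rng.uniform(-1.0, 1.0, size=256).astype(np.float32)
    y = (2.0 * x + 1.0).astype(np.float32)
    params = init_params()
    for _ in range(num_epochs):
        for xb, yb in make_batches(x, y, batch_size, rng):
            params = sgd_step(params, xb, yb)
    return params

params = train()
print(float(params["w"]), float(params["b"]))  # should approach 2.0 and 1.0
```

None of this is difficult, but every piece (batching, parameter containers, the update rule) is the developer's responsibility, which is the gap that TensorFlow's Keras layers, optimisers and data pipelines fill out of the box.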
TensorFlow, on the other hand, has built extensive infrastructure over the years: open-source projects, pre-trained models, tutorials, higher-level abstractions (via Keras) and portability to many deployment targets. In May 2022, Google released TensorFlow 2.9 with several new and improved features:
The oneDNN performance library is now integrated with TensorFlow for better performance on Intel CPUs.
DTensor, a new TensorFlow API for distributed model processing, lets models move seamlessly from data parallelism to model parallelism.
TraceType for tf.function makes it easy to understand retracing rules.
A new experimental version of the Keras Optimizer API provides a more unified and expanded catalogue of built-in optimisers.
Improved performance for deterministic operations.
Official support for Windows via WSL2.
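The retracing behaviour that TraceType governs can be observed with a small experiment. In the sketch below, which is an illustration of our own rather than an example from the release notes, a Python side effect inside a `tf.function` runs only when the function is traced, so recording each execution shows when TensorFlow reuses an existing trace and when it retraces for a new input shape. The function name `square` and the list `traced_signatures` are hypothetical.

```python
import tensorflow as tf

traced_signatures = []  # filled only while tf.function is tracing

@tf.function
def square(x):
    # Python side effects run during tracing, not on every call,
    # so this list records one entry per trace.
    traced_signatures.append((tuple(x.shape.as_list()), x.dtype.name))
    return x * x

square(tf.constant(2.0))          # first call: traced for a float32 scalar
square(tf.constant(3.0))          # same shape and dtype: existing trace reused
square(tf.constant([1.0, 2.0]))   # new shape (2,): retraced
square(tf.constant([4.0, 5.0]))   # shape (2,) again: second trace reused

print(traced_signatures)
```

Four calls produce only two traces, one per distinct input shape and dtype. Avoiding unnecessary retraces matters because each trace builds and optimises a new graph.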
“JAX is a high-performance framework for general numerical/scientific computation (e.g. matrix operations, etc.) at scale. Since it is more than just a deep learning framework, it provides much more flexibility compared to TensorFlow through both high-level as well as low-level APIs. On the flip side, since it is still under development and not as mature as TensorFlow, it is currently being used primarily by researchers as opposed to developers,” said Anshuman Gupta, VP Data Science, MiQ Digital India.
TensorFlow has been adopted in industries such as healthcare, retail, social media and education. Several products have been built on TensorFlow with the help of Keras, a widely used deep learning library for building machine learning models. JAX, by contrast, is still a fledgling framework.
If JAX matures to the point where it can run complex models with comparable feasibility and flexibility, Google might eventually favour it over TensorFlow as its framework of choice.