Google Embraces JAX, Introduces Open-Source Library For Federated Simulation

FedJAX resembles the pseudo-code used to describe novel algorithms in academic papers.

In the wake of growing privacy and security concerns, federated learning has become one of the fastest-growing areas of research. Companies such as Apple, Google, and Facebook are all looking at ways to train a wide variety of deep learning models on edge devices that are inherently limited in bandwidth and connectivity.

Recently, Google researchers released FedJAX, a JAX-based open-source library for federated learning simulations that focuses on ease of use in research. The library aims to make developing and evaluating federated algorithms faster and easier. It offers simple building blocks for implementing federated algorithms, prepackaged datasets and models, and fast simulation speed.

The source code of FedJAX is available on GitHub, along with tutorial notebooks and examples.


Here are some of the highlights of FedJAX:

  • FedJAX resembles the pseudo-code used to describe novel algorithms in academic papers, making it easy to get started.
  • While FedJAX provides building blocks for federated learning, users can replace them with plain NumPy and JAX code while keeping overall training reasonably fast.
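
To illustrate what pseudo-code-like federated code can look like when written in plain JAX, the following is a minimal sketch of federated averaging on a toy linear regression problem. It does not use the FedJAX API. The names `client_update`, `fed_avg_round`, and `make_clients`, the synthetic data, and the hyperparameters are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

LOCAL_STEPS = 5      # assumed number of local gradient steps per client
LEARNING_RATE = 0.1  # assumed client learning rate


def loss(params, x, y):
    # Mean squared error of a linear model: y is approximated by x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def client_update(params, x, y):
    # Run a few local gradient-descent steps starting from the server model.
    for _ in range(LOCAL_STEPS):
        grads = jax.grad(loss)(params, x, y)
        params = jax.tree_util.tree_map(
            lambda p, g: p - LEARNING_RATE * g, params, grads
        )
    return params


def fed_avg_round(server_params, clients):
    # clients: list of (x, y) arrays, one pair per simulated client.
    client_params = [client_update(server_params, x, y) for x, y in clients]
    weights = jnp.array([x.shape[0] for x, _ in clients], dtype=jnp.float32)
    weights = weights / weights.sum()
    # Example-weighted average of the client models, leaf by leaf.
    return jax.tree_util.tree_map(
        lambda *leaves: sum(w * leaf for w, leaf in zip(weights, leaves)),
        *client_params,
    )


def make_clients(key, num_clients=4):
    # Hypothetical toy data: every client sees the same noise-free linear rule.
    true_w = jnp.array([2.0, -1.0])
    clients = []
    for i in range(num_clients):
        key, sub = jax.random.split(key)
        x = jax.random.normal(sub, (10 + 5 * i, 2))
        clients.append((x, x @ true_w + 0.5))
    return clients


clients = make_clients(jax.random.PRNGKey(0))
params = {"w": jnp.zeros(2), "b": jnp.zeros(())}
for _ in range(50):
    params = fed_avg_round(params, clients)
```

After the 50 rounds, `params["w"]` and `params["b"]` should sit close to the generating values `[2.0, -1.0]` and `0.5`. Real federated workloads sample a subset of clients each round and deal with non-identical client data, which this sketch leaves out.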

Datasets and Models

Federated learning research relies on a wide variety of commonly used datasets and models, covering tasks such as image recognition and language modelling. A growing number of these can be used straight out of the box in FedJAX, so the preprocessed datasets and models do not have to be written from scratch. According to the Google researchers, this encourages valid comparisons between federated algorithms and accelerates the development of new ones.

Currently, FedJAX comes with the following datasets and sample models: EMNIST-62 (character recognition), Shakespeare (next-character prediction), and Stack Overflow (next-word prediction).

Plus, FedJAX provides tools to create new datasets and models that can be used with the rest of the library. It comes with standard implementations of federated averaging and other federated algorithms for training a shared model on decentralised examples, such as adaptive federated optimisers, agnostic federated averaging, and Mime. This makes it easier to compare and evaluate new work against existing algorithms and models.


Google researchers benchmarked a standard FedJAX implementation of adaptive federated averaging on two tasks: image recognition on the federated EMNIST-62 dataset and next-word prediction on the Stack Overflow dataset.

Federated EMNIST-62 is the smaller dataset, consisting of 3,400 users and their writing samples. The Stack Overflow dataset is much larger and consists of millions of questions and answers from the Stack Overflow forum for hundreds of thousands of users.

Further, the researchers measured performance on various hardware specialised for machine learning. For federated EMNIST-62, they trained a model for 1500 rounds with 50 clients per round on GPU (NVIDIA V100) and TPU (1 TensorCore) accelerators. For Stack Overflow, they trained a model for 1500 rounds with 50 clients per round on GPU (NVIDIA V100) using jax.jit, TPU (1 TensorCore on Google TPU v2) using only jax.jit, and multi-core TPU (8 TensorCores on Google TPU v2) using jax.pmap.

Benchmark results for federated EMNIST-62 (Source: Google)

Benchmark results for Stack Overflow (Source: Google)
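
For readers unfamiliar with the JAX transforms behind these hardware setups, here is a minimal sketch of how the same update step can be compiled for one device with `jax.jit` or replicated across all local devices with `jax.pmap`, the standard JAX transform for multi-device execution. This is plain JAX, not FedJAX code. The function `sgd_step`, the toy data, and the learning rate are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value for illustration


def sgd_step(w, x, y):
    # One gradient step on mean squared error for a linear model.
    grad = jax.grad(lambda w_: jnp.mean((x @ w_ - y) ** 2))(w)
    return w - LEARNING_RATE * grad


# Single device: jax.jit compiles the step with XLA.
jit_step = jax.jit(sgd_step)

# Multiple devices: jax.pmap runs one copy of the step per device,
# mapping over the leading axis of every argument.
pmap_step = jax.pmap(sgd_step)

n_dev = jax.local_device_count()  # 1 on a CPU-only machine
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 16, 3))  # one batch per device
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w                      # shape (n_dev, 16)
w = jnp.zeros((n_dev, 3))           # one model copy per device

w_pmap = pmap_step(w, x, y)         # shape (n_dev, 3), computed in parallel
w_jit = jnp.stack([jit_step(w[i], x[i], y[i]) for i in range(n_dev)])
```

Both paths compute the same update; the difference is only where the work runs. On a machine with a single device the `pmap` call still works, with a leading axis of size one.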

The team also recorded the average training round completion time, the time taken for full evaluation on test data, and the time for the overall execution, which includes both training and full evaluation. They said that, with standard hyperparameters and TPUs, the full experiments complete in a few minutes for federated EMNIST-62 and in about an hour for Stack Overflow. The image below shows the average Stack Overflow training round time as the number of clients per round increases:

(Source: Google)
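
As a rough idea of how an average round time can be measured in JAX, the sketch below times a jitted toy update. It is not the benchmark code used by the researchers. The function names, the toy problem, and the round count are assumptions. The two details that matter are the warm-up call, which keeps compilation out of the measurement, and `block_until_ready`, which waits for JAX's asynchronous dispatch to finish before the clock is read.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def training_round(w, x, y):
    # Stand-in for one training round: a single gradient step on toy data.
    grad = jax.grad(lambda w_: jnp.mean((x @ w_ - y) ** 2))(w)
    return w - 0.1 * grad


def average_round_time(num_rounds=20):
    x = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
    y = x @ jnp.ones(8)
    w = jnp.zeros(8)
    # Warm-up call so that compile time is not counted as training time.
    w = training_round(w, x, y).block_until_ready()
    start = time.perf_counter()
    for _ in range(num_rounds):
        w = training_round(w, x, y)
    w.block_until_ready()  # wait for all queued work before stopping the clock
    elapsed = time.perf_counter() - start
    return elapsed / num_rounds, w


avg_seconds, final_w = average_round_time()
print(f"average round time: {avg_seconds * 1e3:.3f} ms")
```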

What’s next? 

Google researchers said that they hope that FedJAX will foster even more investigation and interest in federated learning in the future. “Moving forward, we plan to continually grow our existing collection of algorithms, aggregation mechanisms, datasets, and models,” said the Google researchers. 

Amit Raja Naik
Amit Raja Naik is a seasoned technology journalist who covers everything from data science to machine learning and artificial intelligence for Analytics India Magazine, where he examines the trends, challenges, ideas, and transformations across the industry.
