Google Embraces JAX, Introduces Open-Source Library For Federated Simulation

FedJAX resembles the pseudo-code used to describe novel algorithms in academic papers.

Amid growing privacy and security concerns, federated learning has become one of the fastest-growing areas of research. Companies such as Apple, Google and Facebook are all looking at ways to train a wide variety of deep learning models on edge devices, which are inherently limited in bandwidth and connectivity.

Recently, Google researchers rolled out FedJAX, a JAX-based open-source library for federated learning simulations that focuses on ease of use in research. The library aims to make developing and evaluating federated algorithms faster and easier. It provides simple building blocks for implementing federated algorithms, prepackaged datasets and models, and fast simulation speed.

Check out the source code of FedJAX on GitHub, along with tutorial notebooks and examples.


Here are some of the highlights of FedJAX:

  • FedJAX resembles the pseudo-code used to describe novel algorithms in academic papers, making it easy to get started.
  • While FedJAX provides building blocks for federated learning, users can replace them with basic implementations written in just NumPy and JAX while still keeping overall training reasonably fast.
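To make the second point concrete, here is a minimal sketch of federated averaging written with nothing but JAX. It is not FedJAX's API: the linear model and the names `loss_fn`, `client_update` and `fedavg_round`, as well as the learning rate and number of local steps, are hypothetical choices for illustration. Each client runs a few steps of local gradient descent starting from the current server model, and the server then replaces its model with the average of the client models, weighted by how many examples each client holds.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: linear regression with parameters (w, b).
def loss_fn(params, x, y):
    w, b = params
    return jnp.mean((x @ w + b - y) ** 2)

# Compiled gradient of the loss with respect to params.
grad_fn = jax.jit(jax.grad(loss_fn))

def client_update(params, x, y, lr=0.1, local_steps=5):
    """Run a few steps of local gradient descent on one client's data."""
    for _ in range(local_steps):
        grads = grad_fn(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params

def fedavg_round(server_params, clients):
    """One round of federated averaging, weighting clients by example count."""
    total = sum(x.shape[0] for x, _ in clients)
    new_params = jax.tree_util.tree_map(jnp.zeros_like, server_params)
    for x, y in clients:
        local = client_update(server_params, x, y)
        weight = x.shape[0] / total
        new_params = jax.tree_util.tree_map(
            lambda acc, p: acc + weight * p, new_params, local)
    return new_params

# Usage: three clients of different sizes drawn from the same linear model.
key = jax.random.PRNGKey(0)
true_w = jnp.array([2.0, -1.0])
clients = []
for i in range(3):
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (20 + 10 * i, 2))
    clients.append((x, x @ true_w + 0.5))

params = (jnp.zeros(2), jnp.array(0.0))
for _ in range(50):
    params = fedavg_round(params, clients)
print(params)  # w should end up close to [2, -1] and b close to 0.5
```

Because the loop mirrors the usual "for each client, train locally; then average" pseudo-code, it is easy to read against a paper's algorithm box, at the cost of running clients one after another.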

Datasets and Models

Federated learning research relies on a wide variety of commonly used datasets and models, covering tasks such as image recognition and language modelling. A growing number of these can be used straight out of the box in FedJAX, so the preprocessed datasets and models do not have to be written from scratch. According to the Google researchers, this encourages valid comparisons between different federated algorithms and accelerates the development of new ones.

Currently, FedJAX comes with the following datasets and sample models: EMNIST-62 (character recognition), Shakespeare (next-character prediction) and Stack Overflow (next-word prediction).

FedJAX also provides tools to create new datasets and models that work with the rest of the library. It comes with standard implementations of federated averaging and other federated algorithms for training a shared model on decentralised examples, such as adaptive federated optimisers, agnostic federated averaging and Mime. This makes it easier to compare and evaluate new work against existing algorithms and models.
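Adaptive federated optimisers keep the client side of federated averaging but change the server step. Below is a minimal sketch of one such variant, often called FedAdam, assuming the server treats the weighted average of client model deltas as a pseudo-gradient and applies an Adam-style update to it (without bias correction). The function names, state layout and hyperparameter values are hypothetical and are not meant to reflect how FedJAX implements its adaptive optimisers.

```python
import jax
import jax.numpy as jnp

def init_server_state(params):
    """First- and second-moment estimates, both starting at zero."""
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return (zeros, zeros)

def fedadam_server_update(server_params, client_params, client_weights, state,
                          lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
    """Adam-style server step on the averaged client delta (the pseudo-gradient)."""
    total = sum(client_weights)
    # Weighted average of (client model - server model).
    delta = jax.tree_util.tree_map(jnp.zeros_like, server_params)
    for cp, w in zip(client_params, client_weights):
        delta = jax.tree_util.tree_map(
            lambda d, c, s: d + (w / total) * (c - s), delta, cp, server_params)
    m, v = state
    m = jax.tree_util.tree_map(lambda m_, d: beta1 * m_ + (1 - beta1) * d, m, delta)
    v = jax.tree_util.tree_map(lambda v_, d: beta2 * v_ + (1 - beta2) * d ** 2, v, delta)
    # Move in the direction of the clients' average change, scaled per parameter.
    new_params = jax.tree_util.tree_map(
        lambda p, m_, v_: p + lr * m_ / (jnp.sqrt(v_) + tau), server_params, m, v)
    return new_params, (m, v)

# Usage: two clients that both moved every parameter from 0 to 1.
server = {"w": jnp.zeros(3)}
state = init_server_state(server)
clients = [{"w": jnp.ones(3)}, {"w": jnp.ones(3)}]
server, state = fedadam_server_update(server, clients, [10, 30], state)
print(server["w"])  # each entry is 0.1 * 0.1 / (0.1 + 0.001), about 0.099
```

Other adaptive rules, such as Adagrad, differ mainly in how the second-moment estimate `v` is accumulated, while the pseudo-gradient computation stays the same.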


Google researchers benchmarked a standard FedJAX implementation of adaptive federated averaging on two tasks: image recognition on the federated EMNIST-62 dataset and next-word prediction on the Stack Overflow dataset.

Federated EMNIST-62 is a smaller dataset consisting of writing samples from 3,400 users, while the Stack Overflow dataset is much larger, containing millions of questions and answers from the Stack Overflow forum contributed by hundreds of thousands of users.

The researchers also measured performance on various hardware specialised for machine learning. For federated EMNIST-62, they trained a model for 1,500 rounds with 50 clients per round on GPU (NVIDIA V100) and TPU (1 TensorCore) accelerators. For Stack Overflow, they trained models for 1,500 rounds with 50 clients per round on GPU (NVIDIA V100) using jax.jit, on TPU (1 TensorCore on Google TPU v2) using only jax.jit, and on multi-core TPU (8 TensorCores on Google TPU v2) using jax.pmap.
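The benchmarks above compile training with jax.jit. The sketch below shows one way a round can be structured for that, under several hypothetical simplifications: every client has the same number of examples, so `jax.vmap` can run all 50 client updates inside a single compiled function; the same 50 clients are reused every round, whereas a real simulation would sample a fresh set; and the model is a toy least-squares problem. The first call is excluded from timing because it includes compilation. None of these names come from FedJAX. For spreading work across several accelerator cores, JAX also offers `jax.pmap`; this sketch stays on one device.

```python
import time

import jax
import jax.numpy as jnp

NUM_CLIENTS_PER_ROUND = 50  # clients per round, as in the benchmark setting
EXAMPLES_PER_CLIENT = 32    # hypothetical: equal sizes let vmap batch clients
DIM = 10

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def client_update(w, x, y, lr=0.05, steps=5):
    """A few local gradient steps, written with lax.scan so they compile once."""
    def step(w, _):
        return w - lr * jax.grad(loss_fn)(w, x, y), None
    w, _ = jax.lax.scan(step, w, None, length=steps)
    return w

@jax.jit
def train_round(w, xs, ys):
    # vmap runs client_update for every client in one compiled call.
    client_ws = jax.vmap(client_update, in_axes=(None, 0, 0))(w, xs, ys)
    return jnp.mean(client_ws, axis=0)  # unweighted average: equal client sizes

xs = jax.random.normal(jax.random.PRNGKey(0),
                       (NUM_CLIENTS_PER_ROUND, EXAMPLES_PER_CLIENT, DIM))
true_w = jnp.arange(DIM, dtype=jnp.float32)
ys = xs @ true_w  # shape (clients, examples)

w = train_round(jnp.zeros(DIM), xs, ys).block_until_ready()  # includes compile
num_rounds = 100
start = time.perf_counter()
for _ in range(num_rounds):
    w = train_round(w, xs, ys)
w.block_until_ready()  # wait for asynchronous dispatch before stopping the clock
avg_round_s = (time.perf_counter() - start) / num_rounds
print(f"average round time: {avg_round_s * 1e3:.3f} ms")
```

Calling `block_until_ready()` matters because JAX dispatches work asynchronously; without it, the timer could stop before the last round has actually finished.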

Benchmark results for federated EMNIST-62 (Source: Google)

Benchmark results for Stack Overflow (Source: Google)

The team also recorded the average training round completion time, the time taken for a full evaluation on test data, and the overall execution time, which includes both training and full evaluation. With standard hyperparameters and TPUs, they said, the full experiments take a few minutes for federated EMNIST-62 and about an hour for Stack Overflow. The image below shows the average training round time for Stack Overflow as the number of clients per round increases:

(Source: Google)

What’s next? 

Google researchers said they hope FedJAX will foster more investigation of, and interest in, federated learning. “Moving forward, we plan to continually grow our existing collection of algorithms, aggregation mechanisms, datasets, and models,” they said.

Amit Raja Naik
Amit Raja Naik is a seasoned technology journalist who covers everything from data science to machine learning and artificial intelligence for Analytics India Magazine, where he examines the trends, challenges, ideas, and transformations across the industry.
