In the wake of growing privacy and security concerns, federated learning has become one of the fastest-growing areas of research. Apple, Google, Facebook and others are looking for ways to train a wide variety of deep learning models on edge devices, which are inherently limited in bandwidth and connectivity.
Google researchers recently released FedJAX, an open-source, JAX-based library for federated learning simulations that focuses on ease of use in research. The library aims to make developing and evaluating federated algorithms faster and easier. It provides simple building blocks for implementing federated algorithms, prepackaged datasets and models, and fast simulation speed.
Check out the source code of FedJAX on GitHub, along with tutorial notebooks and examples.
Here are some of the highlights of FedJAX
- FedJAX's API closely resembles the pseudocode used to describe novel algorithms in academic papers, making it easy to get started.
- While FedJAX provides building blocks for federated learning, users can replace them with the most basic implementations using just NumPy and JAX while keeping overall training reasonably fast.
Datasets and Models
Federated learning research relies on a wide variety of commonly used datasets and models, spanning tasks such as image recognition and language modelling. A growing number of these datasets and models can be used straight out of the box in FedJAX, so the preprocessed datasets and models do not have to be written from scratch. According to Google researchers, this encourages valid comparisons between federated algorithms and accelerates the development of new ones.
Currently, FedJAX comes with the following datasets and sample models: EMNIST-62 (a character recognition task), Shakespeare (a next-character prediction task), and Stack Overflow (a next-word prediction task).
FedJAX also provides tools for creating new datasets and models that work with the rest of the library. It ships with standard implementations of federated averaging, which trains a shared model on decentralised examples, and of other federated algorithms such as adaptive federated optimisers, agnostic federated averaging, and Mime. This makes it easier to compare and evaluate new work against existing algorithms and models.
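Since FedJAX's building blocks can be replaced with plain NumPy and JAX, the core idea of federated averaging fits in a few lines. The following is a minimal NumPy-only sketch, not FedJAX's API or implementation: the linear-regression model and the helper names `client_update` and `fedavg_round` are hypothetical. Each client starts from the current global weights, runs a few local gradient steps on its own data, and the server averages the returned weights in proportion to each client's number of examples.

```python
import numpy as np


def client_update(weights, x, y, lr=0.1, local_steps=5):
    """Hypothetical client step: a few iterations of full-batch gradient descent.

    The model is linear regression (predictions are x @ w) trained with
    mean squared error on this client's data only.
    """
    w = weights.copy()
    for _ in range(local_steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)  # gradient of the MSE loss
        w -= lr * grad
    return w


def fedavg_round(weights, clients, lr=0.1, local_steps=5):
    """One round of federated averaging over a list of (x, y) client datasets.

    Every client starts from the same global weights; the server then averages
    the returned weights, weighted by each client's example count.
    """
    total = sum(len(y) for _, y in clients)
    new_weights = np.zeros_like(weights)
    for x, y in clients:
        local_w = client_update(weights, x, y, lr, local_steps)
        new_weights += (len(y) / total) * local_w
    return new_weights


# Usage: three synthetic clients of different sizes share one noise-free target.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (20, 50, 80):
    x = rng.normal(size=(n, 2))
    clients.append((x, x @ true_w))

w = np.zeros(2)
for _ in range(50):
    w = fedavg_round(w, clients)
print(w)  # close to true_w
```

Because the synthetic clients share the same noise-free target, the global weights converge to `true_w`. With real, heterogeneous client data, the averaged model is instead a compromise across clients.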
Outcome
Google researchers benchmarked a standard FedJAX implementation of adaptive federated averaging on two tasks: image recognition on the federated EMNIST-62 dataset and next-word prediction on the Stack Overflow dataset.
Federated EMNIST-62 is a smaller dataset of writing samples from 3,400 users, while the Stack Overflow dataset is much larger, consisting of millions of questions and answers from the Stack Overflow forum contributed by hundreds of thousands of users.
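Adaptive federated averaging, the algorithm benchmarked on both datasets, keeps the client-side training of plain federated averaging but changes the server step: the averaged client update is treated as a pseudo-gradient and passed to an adaptive optimiser such as Adam. The sketch below assumes the FedAdam variant from the adaptive federated optimisers line of work, without bias correction. The class name, hyperparameter defaults, and the `tau` stabiliser value are illustrative assumptions, not FedJAX's implementation or the benchmark's settings.

```python
import numpy as np


class FedAdamServer:
    """Sketch of a FedAdam-style server optimiser (no bias correction).

    The example-weighted average of client deltas (client weights minus the
    current global weights) acts as a pseudo-gradient, and the server applies
    an Adam-like update along that direction.
    """

    def __init__(self, dim, lr=0.1, beta1=0.9, beta2=0.99, tau=1e-3):
        self.lr, self.beta1, self.beta2, self.tau = lr, beta1, beta2, tau
        self.m = np.zeros(dim)  # running first-moment estimate
        self.v = np.zeros(dim)  # running second-moment estimate

    def step(self, weights, client_weights, num_examples):
        total = float(sum(num_examples))
        delta = np.zeros_like(weights)
        for cw, n in zip(client_weights, num_examples):
            delta += (n / total) * (cw - weights)
        self.m = self.beta1 * self.m + (1 - self.beta1) * delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * delta**2
        # Per-coordinate scaling by sqrt(v) is what makes the step adaptive;
        # tau keeps the division stable when v is close to zero.
        return weights + self.lr * self.m / (np.sqrt(self.v) + self.tau)


# Usage: two hypothetical clients both moved their weights toward [1, -1].
server = FedAdamServer(dim=2)
w = np.zeros(2)
w = server.step(w, [np.array([1.0, -1.0]), np.array([0.5, -0.5])], [30, 10])
print(w)
```

Replacing the plain average with this step leaves the clients untouched, which is why adaptive server optimisers can be swapped into an existing federated averaging loop.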
The researchers also measured performance on a range of hardware specialised for machine learning. For federated EMNIST-62, they trained a model for 1,500 rounds with 50 clients per round on GPU (NVIDIA V100) and TPU (1 TensorCore) accelerators. For Stack Overflow, they trained a model for 1,500 rounds with 50 clients per round on GPU (NVIDIA V100) using jax.jit, on TPU (1 TensorCore on Google TPU v2) using only jax.jit, and on multi-core TPU (8 TensorCores on Google TPU v2) using jax.pmap.
Benchmark results for federated EMNIST-62 (Source: Google)
Benchmark results for Stack Overflow (Source: Google)
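The round structure used in these benchmarks (a fixed number of rounds, each with a fixed number of participating clients) can be sketched as a client-sampling schedule. This assumes clients are drawn uniformly without replacement within each round, which the source does not specify, and `sample_clients` is a hypothetical helper rather than a FedJAX function. The client count of 3,400 mirrors federated EMNIST-62.

```python
import numpy as np


def sample_clients(client_ids, clients_per_round, round_num, seed=0):
    """Hypothetical helper: uniformly sample distinct clients for one round.

    Seeding with (seed, round_num) makes each round's selection reproducible
    independently of every other round.
    """
    rng = np.random.default_rng([seed, round_num])
    return rng.choice(client_ids, size=clients_per_round, replace=False)


# Setup mirroring the EMNIST-62 benchmark shape:
# 3,400 clients, 50 clients per round, 1,500 rounds.
client_ids = np.arange(3400)
num_rounds, clients_per_round = 1500, 50
schedule = [sample_clients(client_ids, clients_per_round, r) for r in range(num_rounds)]
total_client_updates = num_rounds * clients_per_round  # 75,000 local training runs
```

Seeding each round from `(seed, round_num)` lets any round's selection be regenerated on its own, which helps when resuming or comparing runs.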
The team recorded the average training round completion time, the time taken for full evaluation on test data, and the overall execution time, which includes both training and full evaluation. With standard hyperparameters and TPUs, they reported, the full experiments complete in a few minutes for federated EMNIST-62 and in about an hour for Stack Overflow. The image below shows the average Stack Overflow training round duration as the number of clients per round increases:
(Source: Google)
What’s next?
Google researchers said they hope FedJAX will foster further investigation of, and interest in, federated learning. “Moving forward, we plan to continually grow our existing collection of algorithms, aggregation mechanisms, datasets, and models,” they said.