Google AI Releases Method To Determine Which Tasks Neural Networks Should Learn Together

Google AI has proposed Task Affinity Groupings (TAG), a method that determines which tasks should be trained together in multi-task neural networks.


Traditional machine learning (ML) models typically focus on learning one task at a time. However, in some cases, learning from several related tasks simultaneously leads to better modelling performance. This is the domain of multi-task learning, a subfield of ML in which multiple objectives are trained within the same model at the same time.

Deciding which tasks to train together is not straightforward, because some tasks help one another while others interfere. To tackle this problem, TAG identifies which tasks should share a multi-task neural network.

The approach attempts to divide a set of tasks into smaller subsets such that performance across all tasks is maximised. To do so, it trains all tasks together in a single multi-task model and measures the degree to which one task's gradient update on the model's parameters would affect the loss of the other tasks in the network. This quantity is called inter-task affinity. The research team's experiments indicated that selecting groups of tasks that maximise inter-task affinity correlates strongly with overall model performance.
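
In symbols, and to the best of our understanding of the accompanying research paper, the affinity of task i onto task j at training step t compares task j's loss before and after a lookahead step taken on task i alone. Here θ_s denotes the shared parameters, θ_i and θ_j the task-specific parameters, X^t the training batch and η the learning rate; treat this notation as our reconstruction rather than a verbatim quote. A positive value means the step for task i lowered task j's loss; a negative value means it raised it.

```latex
\theta_{s|i}^{t+1} = \theta_s^{t} - \eta \, \nabla_{\theta_s} \mathcal{L}_i\left(X^t, \theta_s^{t}, \theta_i^{t}\right)

\mathcal{Z}_{i \rightarrow j}^{t} = 1 - \frac{\mathcal{L}_j\left(X^t, \theta_{s|i}^{t+1}, \theta_j^{t}\right)}{\mathcal{L}_j\left(X^t, \theta_s^{t}, \theta_j^{t}\right)}
```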

The team drew inspiration from meta-learning, a field of machine learning that trains neural networks which can be adapted quickly to new, previously unseen tasks. One of the classic meta-learning algorithms, MAML, applies a gradient update to the model's parameters for a collection of tasks and then updates its original parameters to minimise the loss on a subset of those tasks, computed at the updated parameter values.

TAG uses a similar mechanism to gain insight into the training dynamics of multi-task neural networks. Specifically, it updates the model's parameters with respect to a single task only, looks at how this change would affect the other tasks in the network, and then undoes the update. This process is repeated for every task, which reveals how each task in the network interacts with every other task. Training then continues as normal, updating the model's shared parameters with respect to all tasks in the network.

Image: Google AI
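
The lookahead-and-undo procedure described above can be sketched on a toy problem. This is a minimal illustration, not Google's implementation: it assumes hypothetical quadratic tasks that only share parameters (no task-specific heads), plain SGD, and made-up target values. The helper names (`make_quadratic_task`, `lookahead_affinities`, `train_and_collect`, `demo`) are our own, and affinities are averaged over every step, which may differ from how often the real method measures them.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
# A task is a (loss, gradient) pair over the shared parameters.
Task = Tuple[Callable[[Sequence[float]], float], Callable[[Sequence[float]], Vector]]
Affinity = Dict[str, Dict[str, float]]


def make_quadratic_task(target: Sequence[float]) -> Task:
    """Hypothetical toy task: loss is the squared distance to `target`."""
    def loss(theta: Sequence[float]) -> float:
        return sum((t - c) ** 2 for t, c in zip(theta, target))

    def grad(theta: Sequence[float]) -> Vector:
        return [2.0 * (t - c) for t, c in zip(theta, target)]

    return loss, grad


def sgd_step(theta: Sequence[float], grad: Sequence[float], lr: float) -> Vector:
    """Return a new parameter vector; the input is left untouched."""
    return [t - lr * g for t, g in zip(theta, grad)]


def lookahead_affinities(theta: Sequence[float], tasks: Dict[str, Task], lr: float) -> Affinity:
    """Z[i][j] = 1 - L_j(after a step on task i alone) / L_j(before the step)."""
    z: Affinity = {}
    for i, (_, grad_i) in tasks.items():
        # Lookahead: update the parameters for task i only. Because sgd_step
        # returns a copy, "undoing" the update is just discarding theta_i.
        theta_i = sgd_step(theta, grad_i(theta), lr)
        z[i] = {}
        for j, (loss_j, _) in tasks.items():
            if j != i:
                # Assumes loss_j(theta) > 0, which holds in the toy setup.
                z[i][j] = 1.0 - loss_j(theta_i) / loss_j(theta)
    return z


def train_step(theta: Sequence[float], tasks: Dict[str, Task], lr: float) -> Vector:
    """Ordinary multi-task step: sum all task gradients on the shared parameters."""
    total = [0.0] * len(theta)
    for _, grad in tasks.values():
        total = [a + b for a, b in zip(total, grad(theta))]
    return sgd_step(theta, total, lr)


def train_and_collect(theta: Sequence[float], tasks: Dict[str, Task],
                      lr: float, steps: int) -> Tuple[Vector, Affinity]:
    """Interleave lookahead measurements with normal training; average Z over steps."""
    sums = {i: {j: 0.0 for j in tasks if j != i} for i in tasks}
    current = list(theta)
    for _ in range(steps):
        for i, row in lookahead_affinities(current, tasks, lr).items():
            for j, value in row.items():
                sums[i][j] += value
        current = train_step(current, tasks, lr)
    return current, {i: {j: v / steps for j, v in row.items()} for i, row in sums.items()}


def demo() -> Tuple[Vector, Affinity]:
    # Made-up targets: "a" and "b" pull the parameters in similar directions,
    # while "c" pulls the opposite way.
    tasks = {
        "a": make_quadratic_task([1.0, 1.0]),
        "b": make_quadratic_task([1.2, 0.9]),
        "c": make_quadratic_task([-1.0, -1.0]),
    }
    return train_and_collect([0.0, 0.0], tasks, lr=0.1, steps=20)


if __name__ == "__main__":
    theta, affinity = demo()
    for i, row in affinity.items():
        print(i, {j: round(v, 3) for j, v in row.items()})
```

In this toy setup, the averaged affinities between "a" and "b" come out positive and those involving "c" come out negative, mirroring the beneficial and antagonistic relationships the article describes. Working on copies of the parameter vector keeps the "undo" step trivial; a real framework would instead snapshot and restore the model's weights.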

Collecting these statistics and tracking how they change over the course of training reveals that certain tasks consistently exhibit beneficial relationships, while others are antagonistic towards each other. A network selection algorithm can use this data to group tasks so that inter-task affinity is maximised, subject to the practitioner's choice of how many multi-task networks can be used during inference.
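
The article does not spell out the network selection algorithm, so the sketch below swaps in the simplest stand-in we could think of: an exhaustive search over every way to partition the tasks, keeping only partitions that fit within a budget of `max_networks` and scoring each group by the sum of pairwise affinities in both directions. Both the scoring rule and the brute-force search are our assumptions, not necessarily what Google used, and the affinity values in `example_affinity` are made up. Because the number of partitions grows very quickly with the number of tasks, this only suits small task sets.

```python
from itertools import combinations
from typing import Dict, Iterator, List, Tuple

Affinity = Dict[str, Dict[str, float]]


def partitions(items: List[str]) -> Iterator[List[List[str]]]:
    """Yield every way to split `items` into non-empty groups."""
    if not items:
        yield []
        return
    first, rest = items[0], items[1:]
    for part in partitions(rest):
        # Put `first` into each existing group in turn...
        for k in range(len(part)):
            yield part[:k] + [[first] + part[k]] + part[k + 1:]
        # ...or give it a group of its own.
        yield [[first]] + part


def group_score(group: List[str], z: Affinity) -> float:
    """Assumed score: sum of affinities in both directions over every pair."""
    return sum(z[i][j] + z[j][i] for i, j in combinations(group, 2))


def best_grouping(z: Affinity, max_networks: int) -> Tuple[List[List[str]], float]:
    """Exhaustively pick the partition with the highest total in-group affinity."""
    if max_networks < 1:
        raise ValueError("need at least one network")
    best: List[List[str]] = []
    best_score = float("-inf")
    for part in partitions(sorted(z)):
        if len(part) > max_networks:
            continue  # respect the practitioner's budget of networks
        score = sum(group_score(g, z) for g in part)
        if score > best_score:
            best, best_score = part, score
    return best, best_score


def example_affinity() -> Affinity:
    # Made-up averaged affinities: "a" and "b" help each other, "c" hurts both.
    return {
        "a": {"b": 0.3, "c": -0.2},
        "b": {"a": 0.4, "c": -0.1},
        "c": {"a": -0.5, "b": -0.3},
    }


if __name__ == "__main__":
    print(best_grouping(example_affinity(), max_networks=2))
```

With a budget of two networks, the search pairs "a" with "b" and leaves "c" on its own; with a budget of one, all three tasks are forced into a single network.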

On the CelebA and Taskonomy datasets, TAG shows competitive performance while operating between 11.5x and 32x faster than traditional task-grouping methods. On Taskonomy, this speedup translates to 2,008 fewer Tesla V100 GPU hours spent finding task groupings.

TAG therefore appears to be an efficient method for determining which tasks should be trained together, using only a single training run. The method looks at how tasks interact during training, specifically the effect that updating the model's parameters on one task has on the loss values of the other tasks in the network. The Google AI team found that selecting groups of tasks to maximise this score correlates strongly with model performance.

Victor Dey
Victor is an aspiring data scientist and holds a Master of Science in Data Science & Big Data Analytics. He is a researcher, a data science influencer and a former university football player. A keen learner of new developments in data science and artificial intelligence, he is committed to growing the data science community.