How To Establish Domain Transferability In Neural Models

Ram Sagar

If a neural network, say a CNN, is tasked with identifying the digits, it should do so easily, given its reputation in image classification tasks.

The image above shows digits in two styles. A CNN can achieve reasonably good accuracy (98%) when trained and evaluated on the source domain (SVHN). However, the same CNN model may perform poorly (67.1% accuracy) when evaluated on the target domain (MNIST).

This drop in performance generally stems from the difference between the data distributions of the two domains.

The images from the SVHN dataset contain various computer fonts, cluttered backgrounds from streets, and digits cropped near the image boundaries, whereas the images from the MNIST dataset contain handwritten strokes on a clean background.


To improve the accuracy on the target dataset, we need to deal with what is called the covariate shift problem.

What Is The Covariate Shift Problem?

Suppose part of the target set (i.e., raw images without labels) is available, and domain adaptation is performed to transfer the knowledge learned on the source to the target. The same CNN model can then obtain an immediate performance boost, from 67.1% to 98.9%. The underlying problem is called covariate shift: the distribution of the inputs differs between the two domains, while the relationship between an input and its label stays the same.
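To make the definition concrete, here is a toy, entirely hypothetical 2-D example (not the SVHN/MNIST setup). The label depends only on the first coordinate in both domains, so the labeling rule is unchanged and only the input distribution shifts. The second coordinate stands in for something like a background feature that happens to correlate with the label in the source data. A nearest-centroid classifier fitted on the source leans on that feature and drops to chance level on the target.

```python
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float]


def true_label(x: Point) -> int:
    # Same labeling rule in both domains: only x[0] determines the class.
    return 1 if x[0] > 0 else 0


def fit_centroids(points: Sequence[Point]) -> Dict[int, Point]:
    # Nearest-centroid "training": average the labeled source points of each class.
    groups: Dict[int, List[Point]] = {}
    for p in points:
        groups.setdefault(true_label(p), []).append(p)
    return {
        label: (sum(p[0] for p in ps) / len(ps), sum(p[1] for p in ps) / len(ps))
        for label, ps in groups.items()
    }


def predict(centroids: Dict[int, Point], x: Point) -> int:
    # Assign the class whose centroid is closest in squared Euclidean distance.
    return min(
        centroids,
        key=lambda c: (x[0] - centroids[c][0]) ** 2 + (x[1] - centroids[c][1]) ** 2,
    )


def accuracy(centroids: Dict[int, Point], points: Sequence[Point]) -> float:
    correct = sum(predict(centroids, p) == true_label(p) for p in points)
    return correct / len(points)


# Source domain: x[1] (a "background" feature) happens to track the label.
source = [(-1.0, -5.0), (-0.5, -5.0), (0.5, 5.0), (1.0, 5.0)]
# Target domain: same x[0] values and labels, but x[1] is a constant 4.0.
target = [(-1.0, 4.0), (-0.5, 4.0), (0.5, 4.0), (1.0, 4.0)]

centroids = fit_centroids(source)
print("source accuracy:", accuracy(centroids, source))  # 1.0
print("target accuracy:", accuracy(centroids, target))  # 0.5
```

The true rule (class 1 when x[0] > 0) would still be perfect on the target; the failure comes from the classifier relying on a feature whose distribution changed, which is the gap domain adaptation tries to close.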

Existing methods, such as adversarial learning-based domain adaptation at the pixel level, try to translate input images from one domain to the other, bringing the input distributions closer.

However, without knowing the current state of the task-specific decision boundary, adversarial networks might keep trying to perfect the synthesis of pixels, such as road pixels in a street scene, whether or not that helps the task, and therefore optimize in an ineffective direction.

It is therefore important to preserve a notion of the decision boundaries during distribution alignment.

To address this within adversarial training, the machine learning team at Apple proposes a metric based on the Wasserstein distance.

A New Metric: Sliced Wasserstein Discrepancy

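Named after the Russian mathematician Leonid Vaserstein, the Wasserstein metric is a distance function between probability distributions. Intuitively, it measures the minimum cost of moving the probability mass of one distribution so that it matches the other, which is why it is also known as the earth mover's distance.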
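In one dimension the Wasserstein-1 distance has a simple closed form: for two empirical distributions with the same number of samples, the optimal transport plan pairs the sorted samples in order, so the distance is the mean absolute difference between the sorted values. A minimal sketch of that special case (the function name is illustrative):

```python
from typing import Sequence


def wasserstein_1d(a: Sequence[float], b: Sequence[float]) -> float:
    """Wasserstein-1 distance between two equal-size 1-D empirical distributions."""
    if len(a) != len(b) or not a:
        raise ValueError("expected two non-empty samples of equal size")
    # In 1-D, the optimal transport plan pairs the i-th smallest value of `a`
    # with the i-th smallest value of `b`.
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


print(wasserstein_1d([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]))  # 1.0: every unit of mass moves by 1
print(wasserstein_1d([2.0, 0.0], [0.0, 2.0]))            # 0.0: same distribution, different order
```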

In machine learning applications such as image classification, a model's prediction for each input is a probability distribution over the classes, and how closely that distribution matches the target label decides whether the prediction is correct.

Building on this, the team at Apple define the Sliced Wasserstein Discrepancy (SWD): a 1-D variational formulation of the Wasserstein distance between the outputs of the classifiers.
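For reference, the standard sliced 1-Wasserstein distance is written out below. Here mu and nu are distributions on R^d (for SWD, the outputs of the two classifiers), sigma is the uniform measure on the unit sphere S^{d-1}, R_theta projects a point onto the direction theta, and F^{-1} denotes a quantile function. The last line is the usual Monte Carlo estimate with M random directions and n samples x_1, ..., x_n from mu and y_1, ..., y_n from nu, where the subscript (i) marks the i-th smallest projected value. This is the textbook definition; the paper's exact loss may differ in details, such as using squared rather than absolute differences.

```latex
\mathrm{SW}_1(\mu,\nu) = \int_{S^{d-1}} W_1\!\left((R_\theta)_{\#}\mu,\ (R_\theta)_{\#}\nu\right) d\sigma(\theta),
\qquad R_\theta(x) = \langle \theta, x \rangle

W_1(\alpha,\beta) = \int_0^1 \left| F_\alpha^{-1}(t) - F_\beta^{-1}(t) \right| dt

\widehat{\mathrm{SW}}_1 = \frac{1}{M} \sum_{m=1}^{M} \frac{1}{n} \sum_{i=1}^{n}
\left| \langle \theta_m, x \rangle_{(i)} - \langle \theta_m, y \rangle_{(i)} \right|,
\qquad \theta_m \sim \sigma
```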

As shown in the figure above, SWD is designed to capture the dissimilarity between the probability measures p1 and p2 produced by the task-specific classifiers C1 and C2, which both take their input from the feature generator G. This provides geometrically meaningful guidance for detecting target samples that are far from the source.
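A minimal sketch of how such a discrepancy could be estimated between the softmax outputs of two classifiers on the same batch: project the probability vectors onto random unit directions, sort the projections, and average the resulting 1-D Wasserstein-1 distances. The function names, the 128 projections, the fixed seed, and the use of absolute rather than squared differences are all assumptions for illustration, not Apple's implementation; a real training loop would use a tensor library so that gradients flow through the projections and the sort.

```python
import math
import random
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def random_direction(dim: int, rng: random.Random) -> List[float]:
    # A normalized Gaussian vector is uniformly distributed on the unit sphere.
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def sliced_wasserstein_discrepancy(
    p1: Matrix, p2: Matrix, num_projections: int = 128, seed: int = 0
) -> float:
    """Monte Carlo sliced Wasserstein-1 estimate between two batches of probability vectors.

    p1[i] and p2[i] are the outputs of classifiers C1 and C2 for the i-th sample of a batch.
    """
    if len(p1) != len(p2) or not p1:
        raise ValueError("expected two non-empty batches of equal size")
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_projections):
        theta = random_direction(len(p1[0]), rng)
        # Project every probability vector onto theta, then sort the 1-D values.
        proj1 = sorted(sum(t * x for t, x in zip(theta, row)) for row in p1)
        proj2 = sorted(sum(t * x for t, x in zip(theta, row)) for row in p2)
        # 1-D Wasserstein-1: mean absolute difference of the sorted projections.
        total += sum(abs(a - b) for a, b in zip(proj1, proj2)) / len(proj1)
    return total / num_projections


# Hypothetical outputs of two classifiers on a batch of three target samples.
c1_out = [softmax([4.0, 0.0, 0.0]), softmax([0.0, 4.0, 0.0]), softmax([0.0, 0.0, 4.0])]
c2_out = [softmax([4.0, 0.0, 0.0])] * 3  # C2 puts every sample in class 0

print(sliced_wasserstein_discrepancy(c1_out, c1_out))  # 0.0
print(sliced_wasserstein_discrepancy(c1_out, c2_out))  # positive: the output distributions differ
```

Because the projections are sorted across the batch, this sketch compares the two classifiers' output distributions over the batch rather than their per-sample agreement; for example, comparing c1_out with its rows reversed still gives 0.0.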

The whole process consists of three steps:

  • Train G, C1, and C2 on a labeled source set to shape the decision boundaries.
  • Train C1 and C2 to maximize SWD on an unlabeled target set to detect target samples that are outside the reach of the source.
  • Train G to minimize the same SWD on an unlabeled target set to generate feature representations that are inside the support of the source.
Source: Apple Machine Learning
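The three steps form an alternating min-max game over two parameter groups, and the sketch below makes those groups and the sign of each objective explicit. It is a hypothetical, framework-free skeleton: each network is reduced to named scalar parameters, gradients come from central finite differences, and source_loss and target_swd are caller-supplied placeholders for the source classification loss and the SWD on a target batch. It also assumes the three steps are repeated in this order on every iteration; a real implementation would use an autodiff library and mini-batches.

```python
from typing import Callable, Dict, List

Params = Dict[str, float]
Objective = Callable[[Params], float]


def gradient_step(
    params: Params, names: List[str], objective: Objective, lr: float = 0.1, eps: float = 1e-5
) -> None:
    """One gradient-descent step on `objective`, updating only params[name] for name in `names`.

    Gradients come from central finite differences so the sketch needs no autodiff library.
    """
    grads = {}
    for name in names:
        original = params[name]
        params[name] = original + eps
        plus = objective(params)
        params[name] = original - eps
        minus = objective(params)
        params[name] = original  # restore before probing the next parameter
        grads[name] = (plus - minus) / (2 * eps)
    for name, grad in grads.items():
        params[name] -= lr * grad


def swd_training_iteration(
    params: Params,
    g_names: List[str],
    c_names: List[str],
    source_loss: Objective,
    target_swd: Objective,
    lr: float = 0.1,
) -> None:
    # Step 1: train G, C1 and C2 on the labeled source set.
    gradient_step(params, g_names + c_names, source_loss, lr)
    # Step 2: train C1 and C2 to maximize SWD on the unlabeled target set
    # (done by minimizing -SWD); G's parameters are left untouched.
    gradient_step(params, c_names, lambda p: -target_swd(p), lr)
    # Step 3: train G to minimize the same SWD; C1 and C2 are left untouched.
    gradient_step(params, g_names, target_swd, lr)


# Toy stand-ins: one scalar per network, objectives chosen only to show update directions.
params = {"g": 0.0, "c1": 0.0, "c2": 0.0}
swd_training_iteration(
    params,
    ["g"],
    ["c1", "c2"],
    source_loss=lambda p: 0.0,
    target_swd=lambda p: p["c1"] - p["c2"] + p["g"],
)
print({k: round(v, 6) for k, v in params.items()})  # {'g': -0.1, 'c1': 0.1, 'c2': -0.1}
```

Keeping steps 2 and 3 as separate updates on disjoint parameter groups is what makes the scheme adversarial: the classifiers push the discrepancy up on target data, while the generator pulls it back down.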

When this method is applied to the SVHN and MNIST datasets discussed earlier, it generates much more discriminative feature representations than a model trained without adaptation, as the figure above shows.

Future Direction

The team behind this work is hopeful that this method of unsupervised domain adaptation will help improve the performance of machine learning models in the presence of domain shift. The method also enables training models that perform well in diverse scenarios by lowering the cost of the data capture and annotation required where ground-truth data is scarce or hard to collect, eventually enabling personalized machine learning through on-device adaptation of models for an enhanced user experience.

Learn more about the Sliced Wasserstein Discrepancy here.


Copyright Analytics India Magazine Pvt Ltd
