Researchers from MIT and Facebook AI have introduced projUNN, an efficient method for training deep networks with unitary matrices. In their paper, the authors, who include Yann LeCun, introduce two variants, Direct (projUNN-D) and Tangent (projUNN-T), that parameterise full N-dimensional unitary or orthogonal matrices with a training runtime that scales as O(kN²), where k is the rank of the low-rank approximation used for the gradient.
Vanishing and exploding gradient problem
In deep networks, or when the inputs are long sequences of data, learning can become unstable. In recurrent neural networks (RNNs), for example, the hidden state evolves through repeated application of a linear transformation followed by a pointwise nonlinearity, and this process can become unstable when the eigenvalues of the linear transformation do not have unit magnitude. Unitary matrices, whose eigenvalues all have magnitude 1, avoid this problem, which is why they are commonly used to counter vanishing and exploding gradients.
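A minimal NumPy sketch of this instability, assuming a purely linear recurrence (the pointwise nonlinearity is dropped for clarity) and a random orthogonal matrix `Q` obtained from a QR decomposition. Scaling `Q` by 0.9 or 1.1 sets the magnitude of every eigenvalue to that value, so the hidden-state norm shrinks or grows geometrically, while `Q` itself preserves the norm exactly. The helper `evolve_norms` is hypothetical, written only for this illustration.

```python
import numpy as np


def evolve_norms(W, h0, steps):
    """Apply the linear recurrence h <- W @ h `steps` times and record ||h|| after each step."""
    h = h0.copy()
    norms = []
    for _ in range(steps):
        h = W @ h
        norms.append(np.linalg.norm(h))
    return np.array(norms)


rng = np.random.default_rng(0)
n, steps = 8, 100
h0 = rng.standard_normal(n)

# Random orthogonal matrix: every eigenvalue has magnitude exactly 1.
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))

shrinking = evolve_norms(0.9 * Q, h0, steps)  # eigenvalue magnitudes 0.9: norm vanishes
growing = evolve_norms(1.1 * Q, h0, steps)    # eigenvalue magnitudes 1.1: norm explodes
stable = evolve_norms(Q, h0, steps)           # eigenvalue magnitudes 1.0: norm preserved

print(f"initial norm         : {np.linalg.norm(h0):.4f}")
print(f"0.9 * Q, {steps} steps : {shrinking[-1]:.3e}")
print(f"1.1 * Q, {steps} steps : {growing[-1]:.3e}")
print(f"Q,       {steps} steps : {stable[-1]:.4f}")
```

Because `Q` is orthogonal, the norm after `k` steps is exactly `c**k * ||h0||` for a scale factor `c`, which is why only `c = 1` keeps long sequences stable.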
For the uninitiated, a gradient is the derivative of the loss function with respect to the weights. During backpropagation, it is used to update the weights so as to minimise the loss. A vanishing gradient occurs when the derivative, or slope, grows steadily smaller with every layer as we move backwards through the network. When the weight updates become exponentially small, training takes far too long and, in the worst case, may stop altogether. Exploding gradients are the opposite: the slope grows larger with every layer during backpropagation. The resulting weight updates are so large that training fails to converge, oscillating around a minimum instead of settling into it.
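The sketch below makes the exponential behaviour concrete for a hypothetical chain of scalar linear layers `h_i = w_i * h_{i-1}` with loss `0.5 * y**2`. Backpropagating by hand shows that the gradient reaching the first weight is a product of per-layer factors; when every weight equals `c`, it equals `c**(2*depth - 1) * x**2`, so it collapses towards zero for `c < 1` and blows up for `c > 1`. The function name `first_layer_grad` is an illustrative assumption, not part of any library.

```python
def first_layer_grad(weights, x=1.0):
    """dL/dw_1 for the chain h_i = w_i * h_{i-1}, h_0 = x, with loss L = 0.5 * h_L**2."""
    # Forward pass: keep every intermediate activation for the backward pass.
    hs = [x]
    for w in weights:
        hs.append(w * hs[-1])

    # Backward pass (chain rule), starting from dL/dh_L = h_L.
    grad_h = hs[-1]
    grad_w = [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad_w[i] = hs[i] * grad_h      # dL/dw for the layer mapping hs[i] -> hs[i + 1]
        grad_h = weights[i] * grad_h    # propagate the gradient back to hs[i]
    return grad_w[0]


depth = 30
for c in (0.5, 1.0, 1.5):
    print(f"c = {c}: dL/dw_1 = {first_layer_grad([c] * depth):.3e}")
```

Each layer multiplies the backward signal by its weight, so even a modest deviation from 1 compounds over 30 layers into a gradient that is either negligible or enormous.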
Unitary neural networks were originally developed to address vanishing and exploding gradients in RNNs, so that information in long sequences of data could be learned more efficiently than with existing architectures such as LSTMs. In previous studies, unitarity is maintained by composing a series of parameterised unitary transformations. One popular method is the efficient unitary recurrent neural network (EUNN), which parameterises unitary matrices by composing unitary transformations such as Givens rotations and Fourier transforms.
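As a rough illustration of this compositional approach (a simplified stand-in, not the EUNN construction itself), the sketch below builds a real orthogonal matrix by composing Givens rotations, one angle per coordinate pair, giving n(n-1)/2 angles, which matches the dimension of the rotation group SO(n). Each rotation updates only two rows, and the product stays orthogonal for any choice of angles. The helpers `apply_givens` and `compose_givens` are hypothetical.

```python
import numpy as np


def apply_givens(W, i, j, theta):
    """Left-multiply W in place by a Givens rotation acting on rows i and j."""
    c, s = np.cos(theta), np.sin(theta)
    row_i, row_j = W[i].copy(), W[j].copy()
    W[i] = c * row_i - s * row_j
    W[j] = s * row_i + c * row_j


def compose_givens(n, angles):
    """Build an n x n orthogonal matrix from one rotation angle per coordinate pair (i < j)."""
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    if len(angles) != len(pairs):
        raise ValueError(f"expected {len(pairs)} angles, got {len(angles)}")
    W = np.eye(n)
    for (i, j), theta in zip(pairs, angles):
        apply_givens(W, i, j, theta)  # touches only two rows: O(n) work per rotation
    return W


rng = np.random.default_rng(0)
n = 5
angles = rng.uniform(-np.pi, np.pi, size=n * (n - 1) // 2)
W = compose_givens(n, angles)

print("orthogonal:", np.allclose(W.T @ W, np.eye(n)))
print("determinant:", round(float(np.linalg.det(W)), 6))
```

Orthogonality is guaranteed by construction rather than by the optimiser, which is the appeal of this style of parameterisation; the cost is that many such layers are needed to express a general matrix.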
While employing unitary matrices in each layer is effective, maintaining long-range stability by restricting network parameters to be strictly unitary comes at the cost of expensive parameterisations or increased training runtime.
What is projUNN?
RNNs are ‘notoriously difficult’ to train. When the eigenvalues of the hidden-to-hidden weight matrix deviate from 1 in absolute value, optimisation becomes difficult, especially when the network must learn long-term dependencies.
In the RNN setting, earlier algorithms for learning N×N unitary matrices parameterised them as layers of unitary or orthogonal transformations. This layer-wise approach enforces unitarity for all parameter values, but many layers are needed to form a composition that can represent an arbitrary unitary matrix.
The authors propose projUNN, in which matrices are updated directly via gradient-based optimisation and then either projected back onto the closest unitary matrix (projUNN-D) or transported in the direction of the gradient (projUNN-T). The authors claim that projUNN is especially effective in the extreme case where gradients are approximated by rank-one matrices. In RNN benchmarks, projUNN matches or exceeds the results of state-of-the-art unitary neural network algorithms.
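The two update rules can be sketched with dense linear algebra. This is a naive O(N³) illustration under stated assumptions, not the authors' implementation, which exploits low-rank gradients to reach O(kN²). `step_direct` takes an ordinary gradient step and projects back using the polar factor `U @ Vh` from the SVD, which is the closest unitary matrix in Frobenius norm. `step_tangent` instead moves along the set of unitary matrices using the matrix exponential of the skew-Hermitian part of `W^H G`. The toy loss `0.5 * ||W - T||_F^2` and all function names are hypothetical.

```python
import numpy as np


def project_unitary(A):
    """Closest unitary (orthogonal, if A is real) matrix to A in Frobenius norm: U @ Vh."""
    U, _, Vh = np.linalg.svd(A)
    return U @ Vh


def skew(X):
    """Skew-Hermitian part of X."""
    return 0.5 * (X - X.conj().T)


def expm_skew(A):
    """Matrix exponential of a skew-Hermitian A, via eigh of the Hermitian matrix -1j * A."""
    lam, V = np.linalg.eigh(-1j * A)          # A = 1j * V @ diag(lam) @ V^H
    E = (V * np.exp(1j * lam)) @ V.conj().T   # exp(A) = V @ diag(exp(1j * lam)) @ V^H
    return E.real if np.isrealobj(A) else E


def step_direct(W, G, lr):
    """projUNN-D-style update (dense sketch): gradient step, then project back."""
    return project_unitary(W - lr * G)


def step_tangent(W, G, lr):
    """projUNN-T-style update (dense sketch): move along the manifold via the exponential map."""
    omega = skew(W.conj().T @ G)              # Riemannian gradient is W @ omega
    return W @ expm_skew(-lr * omega)


# Toy problem (hypothetical): fit an orthogonal W to an arbitrary target T.
rng = np.random.default_rng(0)
n, lr, steps = 6, 0.05, 200
T = rng.standard_normal((n, n))


def loss(W):
    return 0.5 * np.sum(np.abs(W - T) ** 2)


def grad(W):
    return W - T


W0 = project_unitary(rng.standard_normal((n, n)))
W_d, W_t = W0.copy(), W0.copy()
for _ in range(steps):
    W_d = step_direct(W_d, grad(W_d), lr)
    W_t = step_tangent(W_t, grad(W_t), lr)

for name, W in [("start", W0), ("projUNN-D-style", W_d), ("projUNN-T-style", W_t)]:
    err = np.linalg.norm(W.conj().T @ W - np.eye(n))
    print(f"{name:16s} loss={loss(W):.4f}  ||W^H W - I||={err:.1e}")
```

Both variants keep `W` orthogonal to rounding error while reducing the loss; the difference is whether unitarity is restored after the step (projection) or never left in the first place (exponential map).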
“Our PROJUNN shows that one need not sacrifice performance or runtime in training unitary neural network architectures,” the authors wrote. They added that the method exploits the approximately low-rank structure of parameter gradients to perform updates in nearly optimal runtime.
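Why would gradients be low rank? For a linear map `y = W x`, the gradient of the loss with respect to `W` for a single example is the outer product of the output error and the input, which has rank one, so a batch of `B` examples yields rank at most `B`. (In an RNN, the recurrent-weight gradient is likewise a sum of outer products over time steps.) Adding a rank-k update to an N×N matrix costs on the order of kN² operations, consistent with the scaling quoted above. The sketch below checks the rank numerically and forms a rank-k approximation with a truncated SVD; the truncated SVD is a stand-in chosen for simplicity, and the helper names are hypothetical.

```python
import numpy as np


def low_rank_approx(G, k):
    """Best rank-k approximation of G in Frobenius norm, via a truncated SVD."""
    U, s, Vh = np.linalg.svd(G, full_matrices=False)
    return (U[:, :k] * s[:k]) @ Vh[:k]


rng = np.random.default_rng(0)
n, batch = 16, 3
W = rng.standard_normal((n, n))
X = rng.standard_normal((n, batch))   # one input vector per column
Y = rng.standard_normal((n, batch))   # matching targets

# Loss L = 0.5 * ||W X - Y||_F^2 for the linear layer y = W x.
delta = W @ X - Y                     # dL/dy for each example, shape (n, batch)
G = delta @ X.T                       # dL/dW: a sum of `batch` rank-one outer products

G1 = low_rank_approx(G, 1)            # the rank-one extreme case
rel_err = np.linalg.norm(G - G1) / np.linalg.norm(G)

print("rank of dL/dW:", np.linalg.matrix_rank(G))
print(f"relative error of rank-1 approximation: {rel_err:.3f}")
```

Here the 16×16 gradient has rank 3, so it can be stored and applied as two thin 16×3 factors instead of a full matrix.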
Read the full paper here.