Teaching Neural Networks When to Stop

If you finished a running course in five steps, you would not keep marching at the finish line just to reach 10 steps in total. Yet this absurd behavior is a key problem that plagues deep neural networks, a type of machine learning model that powers a wide range of applications.

Typically, a neural network must pass every input through a predetermined number of layers, even when a given task could be completed with fewer (or more) layers.

Now, a team of researchers from Georgia Tech, Google Brain, and King Abdullah University of Science and Technology has created a steerable architecture that lets a neural network decide sequentially, for each input, whether to stop at an intermediate layer or continue to the next one.

This new approach combines a feed-forward deep model with a variational stopping policy, allowing the network to adaptively stop at an earlier layer rather than waste energy on layers it does not need. In experiments, the model equipped with the stopping policy improved performance on a diverse set of tasks, such as image denoising and multitask learning.
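
To make the idea concrete, here is a minimal Python sketch contrasting a fixed-depth forward pass with one that consults a stopping policy after each layer. It is an illustration under simplifying assumptions, not the researchers' implementation: the "layers" are toy functions on a single number, the function names (`forward_fixed_depth`, `forward_with_stopping`, `halve`, `close_enough`) are hypothetical, and the policy's stop probability is simply thresholded at 0.5. In the actual work, both the layers and the stopping policy are learned.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins: a "layer" refines a scalar estimate, and a stopping
# policy maps (current state, layer index) to a stop probability in [0, 1].
Layer = Callable[[float], float]
StopPolicy = Callable[[float, int], float]


def forward_fixed_depth(x: float, layers: List[Layer]) -> Tuple[float, int]:
    """Conventional pass: every input goes through every layer."""
    for layer in layers:
        x = layer(x)
    return x, len(layers)


def forward_with_stopping(
    x: float, layers: List[Layer], policy: StopPolicy, threshold: float = 0.5
) -> Tuple[float, int]:
    """Adaptive pass: after each layer, ask the policy whether to stop there."""
    for t, layer in enumerate(layers, start=1):
        x = layer(x)
        if policy(x, t) >= threshold:
            return x, t  # output taken from intermediate layer t
    return x, len(layers)  # policy never fired: use the last layer's output


def halve(v: float) -> float:
    """Toy 'denoising' layer: moves the estimate halfway toward the clean value 0."""
    return v / 2


def close_enough(v: float, t: int) -> float:
    """Toy policy: stop (probability 1) once the estimate is within 0.1 of 0."""
    return 1.0 if abs(v) < 0.1 else 0.0


layers = [halve] * 10
print(forward_fixed_depth(1.0, layers))                     # (0.0009765625, 10)
print(forward_with_stopping(1.0, layers, close_enough))     # (0.0625, 4)
print(forward_with_stopping(0.125, layers, close_enough))   # (0.0625, 1)
```

The fixed-depth pass always spends 10 layers, while the adaptive pass spends 4 layers on the "noisier" input and only 1 on the input that is already nearly clean.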

“Recently, there have been many efforts to bridge traditional algorithms with deep neural networks by combining the interpretability of the former and flexibility of the latter. Inspired by traditional algorithms which have certain stopping criteria for outputting results at different iterations, we design a variational stopping policy to decide which layer to stop for each input in the neural network,” said Xinshi Chen, a Ph.D. student in the School of Computational Science and Engineering and a researcher on the project.
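
As an example of the kind of stopping criterion Chen describes, consider Newton's method for square roots, a classic iterative algorithm. The sketch below is an illustration chosen for this article, not taken from the paper: it stops as soon as an update changes the estimate by less than a tolerance, so different inputs need different numbers of iterations. The name `newton_sqrt`, the starting guess, and the default tolerance are arbitrary assumptions.

```python
from typing import Tuple


def newton_sqrt(a: float, tol: float = 1e-10, max_iter: int = 100) -> Tuple[float, int]:
    """Square root of a > 0 by Newton's method, with a data-dependent stopping rule.

    Returns (estimate, number of iterations actually used).
    """
    if a <= 0:
        raise ValueError("a must be positive")
    x = a if a > 1 else 1.0  # simple starting guess
    for k in range(1, max_iter + 1):
        x_next = 0.5 * (x + a / x)  # Newton update for f(x) = x^2 - a
        if abs(x_next - x) < tol:   # stopping criterion: the update is tiny
            return x_next, k
        x = x_next
    return x, max_iter  # fallback cap, analogous to a network's last layer


for a in (1.0, 4.0, 1e6):
    root, iters = newton_sqrt(a)
    print(f"sqrt({a}) ~ {root:.6f} after {iters} iterations")
```

Easy inputs finish in a few iterations and harder ones take more. A fixed-depth network, by contrast, behaves like running this loop for `max_iter` steps on every input.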

According to Chen, jointly training the neural network and the stopping policy is very challenging, and solving that problem is one of the most important contributions of this research.

“Notably, our paper proposes a principled and efficient algorithm to jointly train these two components together. This algorithm can be mathematically explained from the variational Bayes perspective and can be generally applied to many problems,” she said.
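
The paper's training algorithm is not reproduced here. As a rough sketch of one ingredient, assume the policy outputs a stop probability π_t at each layer t. Those probabilities then define a distribution over the layer at which the network stops, q(t) = π_t × (1 − π_1) × … × (1 − π_{t−1}), with the final layer forced to stop so the probabilities sum to 1. A naive joint objective would weight each layer's prediction loss by q(t). The helper names below are hypothetical, and this simple weighted loss is a stand-in, not the variational Bayes method the researchers propose.

```python
from typing import List


def stopping_distribution(stop_probs: List[float]) -> List[float]:
    """q(t) = pi_t * prod_{tau < t} (1 - pi_tau), forcing a stop at the last layer."""
    if not stop_probs:
        raise ValueError("need at least one layer")
    probs = list(stop_probs[:-1]) + [1.0]  # the final layer always stops
    q: List[float] = []
    still_running = 1.0  # probability the network has not stopped yet
    for p in probs:
        q.append(still_running * p)
        still_running *= 1.0 - p
    return q


def expected_loss(stop_probs: List[float], layer_losses: List[float]) -> float:
    """Naive surrogate: each layer's loss weighted by the chance of stopping there."""
    if len(stop_probs) != len(layer_losses):
        raise ValueError("one loss per layer is required")
    q = stopping_distribution(stop_probs)
    return sum(qt * lt for qt, lt in zip(q, layer_losses))


print(stopping_distribution([0.5, 0.5, 0.5]))           # [0.5, 0.25, 0.25]
print(expected_loss([0.5, 0.5, 0.5], [4.0, 2.0, 1.0]))  # 2.75
```

Because q(t) depends on every earlier stop probability, changing the policy at one layer shifts probability mass across all later layers, which is one reason training the predictor and the policy together can be delicate.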

What’s more, deep neural networks are typically considered black boxes: researchers cannot explain mathematically why a network produces a given output, no matter how accurate that output is. By tying the network’s layers to the steps of traditional algorithms, the new approach opens up that black box, yielding a system that is inherently more interpretable and, therefore, more accountable.

The findings are published in the paper Learning to Stop While Learning to Predict, which is set to be presented at the virtual Thirty-seventh International Conference on Machine Learning on July 14, 1:00-1:45, and July 15, 12:00-12:45 EDT.

Contact: 

Kristen Perez

Communications Officer