On the Variance of the Adaptive Learning Rate and Beyond

Overview

RAdam

We are in an early-release beta. Expect some adventures and rough edges.

Introduction

If warmup is the answer, what is the question?

Learning rate warmup (or, alternatively, eps tuning) is a must-have trick for stable training with Adam in certain situations, but the underlying mechanism is largely unknown. In our study, we suggest that one fundamental cause is the large variance of the adaptive learning rate, and we provide both theoretical and empirical evidence in support.

In addition to explaining why we should use warmup, we also propose RAdam, a theoretically sound variant of Adam.

Motivation

In Figure 1, we assume that gradients follow a normal distribution (mean: \mu, variance: 1). The variance of the adaptive learning rate is simulated and plotted (blue curve). We observe that the adaptive learning rate has a large variance in the early stage of training.
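The following is a minimal sketch of how such a simulation could be set up. It is not the script used to produce Figure 1. It assumes the adaptive learning rate is 1/sqrt(v_hat_t), where v_hat_t is Adam's bias-corrected moving average of squared gradients, and it omits eps. The function name `adaptive_lr_variance` and its parameters are illustrative only.

```python
import math
import random


def adaptive_lr_variance(mu, steps, trials=2000, beta2=0.999, seed=0):
    """Estimate Var[1 / sqrt(v_hat_t)] for t = 1..steps, gradients ~ N(mu, 1)."""
    rng = random.Random(seed)
    # Running sums of the adaptive learning rate and its square, per step.
    sum_x = [0.0] * steps
    sum_x2 = [0.0] * steps
    for _ in range(trials):
        v = 0.0
        for t in range(1, steps + 1):
            g = rng.gauss(mu, 1.0)
            # Exponential moving average of squared gradients, as in Adam.
            v = beta2 * v + (1.0 - beta2) * g * g
            # Bias correction (assumption: eps is left out for simplicity).
            v_hat = v / (1.0 - beta2 ** t)
            lr = 1.0 / math.sqrt(v_hat)
            sum_x[t - 1] += lr
            sum_x2[t - 1] += lr * lr
    # Var[x] = E[x^2] - E[x]^2, estimated across trials.
    return [sum_x2[i] / trials - (sum_x[i] / trials) ** 2 for i in range(steps)]


if __name__ == "__main__":
    zero_mean = adaptive_lr_variance(mu=0.0, steps=100, trials=500)
    large_mean = adaptive_lr_variance(mu=10.0, steps=100, trials=500)
    print("mu=0:  step 5 ->", zero_mean[4], " step 100 ->", zero_mean[99])
    print("mu=10: step 5 ->", large_mean[4], " step 100 ->", large_mean[99])
```

For the very first steps the true variance is unbounded when \mu is 0, so the estimates there are noisy and depend on the seed. Comparing a slightly later early step (e.g., step 5) against a later one is more stable.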

When using a Transformer for NMT, a warmup stage is usually required to avoid convergence problems (e.g., Adam-vanilla converges around 500 PPL in Figure 2, while Adam-warmup successfully converges under 10 PPL). In further explorations, we notice that if we use an additional 2000 samples to estimate the adaptive learning rate, the convergence problems are avoided (Adam-2k); or, if we increase the value of eps, the convergence problems are also alleviated (Adam-eps).

Therefore, we conjecture that the large variance in the early stage causes the convergence problem, and further propose Rectified Adam (RAdam), which analytically reduces this variance. More details can be found in our paper.
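As a rough illustration of the rectification, the sketch below computes the rectification term r_t for a given step, restating the definitions from the paper as we understand them: rho_inf = 2 / (1 - beta2) - 1, rho_t = rho_inf - 2 t beta2^t / (1 - beta2^t), and a rectified update only once rho_t exceeds 4. The function name `rectification` is hypothetical and not part of the released code, and released implementations may use a slightly different threshold, so please check the paper and the code for the exact form.

```python
import math


def rectification(t, beta2=0.999):
    """Return (rho_t, r_t) for step t (1-indexed).

    r_t is None while rho_t <= 4, i.e. while the variance of the adaptive
    learning rate is treated as intractable and no adaptive step is taken.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return rho_t, None
    r_t = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    return rho_t, r_t


if __name__ == "__main__":
    for step in (1, 4, 5, 10, 100, 1000, 10000):
        print(step, rectification(step))
```

With the default beta2, the term is undefined for the first few steps, starts out small, and approaches 1 as training proceeds, which acts much like a built-in warmup on the adaptive part of the update.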

Questions and Discussions

Do I need to tune learning rate?

Yes, the robustness of RAdam is not infinite. In our experiments, it works for a broader range of learning rates, but not all learning rates.

Notes on Transformer (more discussions can be found in our Transformer Clinic project)

Choice of the Original Transformer. We choose the original Transformer as our main object of study because, without warmup, it suffers from the most serious convergence problems in our experiments. With such serious problems, our controlled experiments can better verify our hypothesis (i.e., we demonstrate that Adam-2k / Adam-eps can avoid spurious local optima with minimal changes).

Sensitivity. We observe that the Transformer is sensitive to the architecture configuration, despite its efficiency and effectiveness. For example, depending on the position of the layer norm, the model may or may not require warmup to perform well. Intuitively, since the gradient of the attention layer could be more sparse and the adaptive learning rates for smaller gradients have a larger variance, such layers are more sensitive. Nevertheless, we believe this problem deserves more in-depth analysis and is beyond the scope of our study.

Why does warmup have a bigger impact on some models than others?

Although the adaptive learning rate has a larger variance in the early stage, the exact magnitude depends on the model design. Thus, the convergence problem could be more serious for some models/tasks than others. In our experiments, we observe that RAdam achieves consistent improvements over vanilla Adam. This suggests that the variance issue exists widely (since we can get better performance by fixing it).

What if the gradient is not zero-mean?

As in Figure 1 (above), even if the gradient is not zero-mean, the original adaptive learning rate still has a larger variance in the beginning, so applying the rectification can help to stabilize training.

Another related concern is that, when the mean of the gradient is significantly larger than its variance, the magnitude of the "problematic" variance may not be very large (i.e., in Figure 1, when \mu equals 10, the adaptive learning rate variance is relatively small and may not cause problems). We think this provides a possible explanation of why warmup has a bigger impact on some models than others. Still, we suggest that, in real-world applications, neural networks usually have some parameters that meet our assumption well (i.e., their gradient variance is larger than their gradient mean) and need the rectification to stabilize training.

Why does SGD need warmup?

To the best of our knowledge, the warmup heuristic was originally designed for large minibatch SGD [0], based on the intuition that the network changes rapidly in the early stage. However, we find that this does not explain why Adam requires warmup. Notice that Adam-2k, which uses the same large learning rate but a better estimate of the adaptive learning rate, also avoids the convergence problems.
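For reference, one common form of the warmup heuristic is a linear ramp of the learning rate over the first steps. The sketch below is illustrative only: the function name `linear_warmup_lr` and its parameters are assumptions, and it is not necessarily the exact schedule used in [0] or in our experiments.

```python
def linear_warmup_lr(step, base_lr, warmup_steps):
    """Learning rate at a 1-indexed step: linear ramp, then constant base_lr."""
    if warmup_steps <= 0:
        # No warmup requested: use the base learning rate from the start.
        return base_lr
    return base_lr * min(1.0, step / warmup_steps)


if __name__ == "__main__":
    for step in (1, 50, 100, 1000):
        print(step, linear_warmup_lr(step, base_lr=0.1, warmup_steps=100))
```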

The reason why warmup sometimes also helps SGD still lacks theoretical support. FYI, when optimizing a simple 2-layer CNN with gradient descent, the theory of [1] could be used to show the benefits of warmup. Specifically, the lr must be $O(\cos \phi)$, where $\phi$ is the angle between the current weight and the ground-truth weight, and $\cos \phi$ could be very small due to the high-dimensional space and random initialization. Thus, the lr must be very small at the beginning to guarantee convergence. However, $\cos \phi$ can improve in the later stage, which allows a larger learning rate. Their theory can partly justify why warmup is needed for gradient descent on neural networks, but it is still a stretch for real scenarios.
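To give a feel for why $\cos \phi$ tends to be small at random initialization, the sketch below estimates the average |cos| of the angle between two independent random vectors in different dimensions. It assumes both vectors have i.i.d. standard normal entries, which is an illustrative assumption and not the exact setting analyzed in [1]. The function name `mean_abs_cosine` is hypothetical.

```python
import math
import random


def mean_abs_cosine(dim, trials=200, seed=0):
    """Average |cos(angle)| between pairs of independent Gaussian vectors."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        w = [rng.gauss(0.0, 1.0) for _ in range(dim)]        # "current" weight
        w_star = [rng.gauss(0.0, 1.0) for _ in range(dim)]   # "ground-truth" weight
        dot = sum(a * b for a, b in zip(w, w_star))
        norm_w = math.sqrt(sum(a * a for a in w))
        norm_w_star = math.sqrt(sum(b * b for b in w_star))
        total += abs(dot / (norm_w * norm_w_star))
    return total / trials


if __name__ == "__main__":
    for dim in (10, 100, 1000):
        print(dim, mean_abs_cosine(dim))
```

Under this assumption the average shrinks as the dimension grows, so a learning rate bounded by $O(\cos \phi)$ would have to start out small.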

[0] Goyal et al., Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, 2017

[1] Du et al., Gradient Descent Learns One-hidden-layer CNN: Don’t be Afraid of Spurious Local Minima, 2017

Quick Start Guide

  1. Directly replace the vanilla Adam with RAdam without changing any settings.
  2. Further tune hyper-parameters (including the learning rate) for better performance.

Note that the major contribution of our paper is to identify why we need warmup for Adam. Although some researchers have successfully improved their model performance (user comments), considering the difficulty of training NNs, directly plugging in RAdam may not result in an immediate performance boost. In our experience, replacing vanilla Adam with RAdam usually results in better performance; however, if warmup has already been employed and tuned in the baseline method, it is necessary to also tune hyper-parameters for RAdam.

Related Posts and Repos

Unofficial Re-Implementations

RAdam is very easy to implement. We provide PyTorch implementations here, while third-party ones can be found at:

Keras Implementation

Keras Implementation

Julia implementation in Flux.jl

Unofficial Introduction & Mentions

We provide a simple introduction in Motivation, and more details can be found in our paper. There are some unofficial introductions available (with better writing), and they are listed here for reference only (the contents and claims in our paper are more accurate):

Medium Post

related Twitter Post

CSDN Post (in Chinese)

User Comments

We are happy to see that some users have found our algorithm useful :-)

"...I tested it on ImageNette and quickly got new high accuracy scores for the 5 and 20 epoch 128px leaderboard scores, so I know it works... https://forums.fast.ai/t/meet-radam-imo-the-new-state-of-the-art-ai-optimizer/52656

— Less Wright August 15, 2019

Thought "sounds interesting, I'll give it a try" - top 5 are vanilla Adam, bottom 4 (I only have access to 4 GPUs) are RAdam... so far looking pretty promising! pic.twitter.com/irvJSeoVfx

— Hamish Dickson (@_mishy) August 16, 2019

RAdam works great for me! It’s good to several % accuracy for free, but the biggest thing I like is the training stability. RAdam is way more stable! https://medium.com/@mgrankin/radam-works-great-for-me-344d37183943

— Grankin Mikhail August 17, 2019

"... Also, I achieved higher accuracy results using the newly proposed RAdam optimization function.... https://towardsdatascience.com/optimism-is-on-the-menu-a-recession-is-not-d87cce265b10

— Sameer Ahuja August 24, 2019

"... Out-of-box RAdam implementation performs better than Adam and finetuned SGD... https://twitter.com/ukrdailo/status/1166265186920980480

— Alex Dailo August 27, 2019

Citation

Please cite the following paper if you find our model useful. Thanks!

Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). On the Variance of the Adaptive Learning Rate and Beyond. The Eighth International Conference on Learning Representations.

@inproceedings{liu2019radam,
 author = {Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},
 booktitle = {Proceedings of the Eighth International Conference on Learning Representations (ICLR 2020)},
 month = {April},
 title = {On the Variance of the Adaptive Learning Rate and Beyond},
 year = {2020}
}
