Second-Order Neural ODE Optimizer, NeurIPS 2021 spotlight

Second-order Neural ODE Optimizer
(NeurIPS 2021 Spotlight) [arXiv]

✔️ faster convergence in wall-clock time | ✔️ O(1) memory cost |
✔️ better test-time performance | ✔️ architecture co-optimization

This repo provides a PyTorch implementation of the Second-order Neural ODE Optimizer (SNOpt), a second-order optimizer for training Neural ODEs that retains O(1) memory cost while improving convergence and test-time performance.

[Figure: SNOpt result]

Installation

This code was developed with Python 3. It requires PyTorch >= 1.7 (we recommend 1.8.1) and torchdiffeq >= 0.2.0.

  1. Install the dependencies with Anaconda and activate the environment snopt with
    conda env create --file requirements.yaml python=3
    conda activate snopt
  2. [Optional] This repo provides a modification (with 15 lines!) of torchdiffeq that allows SNOpt to collect 2nd-order information during adjoint-based training. If you wish to run torchdiffeq at a different commit, copy that torchdiffeq folder into this directory and then apply the provided snopt_integration.patch.
    cp -r <path_to_your_torchdiffeq_folder> .
    git apply snopt_integration.patch
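
To confirm that an environment meets the version requirements stated above, a small check along the following lines may help. This is only a sketch: the helper names parse_version, meets_minimum, and check are hypothetical and not part of this repo, and the lookup only sees pip/conda-installed packages, so a torchdiffeq folder copied into this directory would be reported as not installed.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn '1.8.1+cu111' into (1, 8, 1); local and non-numeric suffixes are ignored."""
    core = text.split("+")[0]
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare two version strings after padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


def check(package, minimum):
    """Return a one-line status string for an installed distribution."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package}: not installed"
    status = "ok" if meets_minimum(installed, minimum) else f"needs >= {minimum}"
    return f"{package} {installed}: {status}"


if __name__ == "__main__":
    # Minimum versions taken from the Installation section above.
    for name, minimum in [("torch", "1.7"), ("torchdiffeq", "0.2.0")]:
        print(check(name, minimum))
```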

Run the code

We provide example code for 8 datasets across image classification (main_img_clf.py), time-series prediction (main_time_series.py), and continuous normalizing flow (main_cnf.py). The commands that generate results similar to those shown in our paper are detailed in the scripts folder. Datasets are downloaded automatically to the data folder on the first call, and all results are saved to the result folder.

bash scripts/run_img_clf.sh     <dataset> # dataset can be {mnist, svhn, cifar10}
bash scripts/run_time_series.sh <dataset> # dataset can be {char-traj, art-wr, spo-ad}
bash scripts/run_cnf.sh         <dataset> # dataset can be {miniboone, gas}

For architecture (specifically integration time) co-optimization, run

bash scripts/run_img_clf.sh cifar10-t1-optimize

Integration with your workflow

snopt can be integrated seamlessly into an existing training workflow. Below we provide a handy checklist and pseudo-code to help with your integration. For more complex examples, please refer to main_*.py in this repo.

  • Import a torchdiffeq that has been patched with the snopt integration, or simply use the torchdiffeq in this repo.
  • Subclass snopt.ODEFuncBase to define your vector field; implement the forward pass in F rather than forward.
  • Create the Neural ODE with ODE layer(s) using snopt.ODEBlock; implement the properties odes and ode_mods.
  • Initialize snopt.SNOpt as the preconditioner; call train_itr_setup() and step() before the standard optim.zero_grad() and optim.step(), respectively (see the code below).
  • That's it 🤓 ! Enjoy your second-order training 🚂 🚅 !

import torch
from torchdiffeq import odeint_adjoint as odesolve
from snopt import SNOpt, ODEFuncBase, ODEBlock
from easydict import EasyDict as dict  # note: this alias shadows the built-in dict

input_dim = 2  # placeholder feature dimension; set this to match your data

class ODEFunc(ODEFuncBase):
    def __init__(self, opt):
        super(ODEFunc, self).__init__(opt)
        self.linear = torch.nn.Linear(input_dim, input_dim)

    def F(self, t, z):
        return self.linear(z)

class NeuralODE(torch.nn.Module):
    def __init__(self, ode):
        super(NeuralODE, self).__init__()
        self.ode = ode

    def forward(self, z):
        return self.ode(z)

    @property
    def odes(self): # in case we have multiple odes, collect them in a list
        return [self.ode]

    @property
    def ode_mods(self): # modules of all ode(s)
        return [mod for mod in self.ode.odefunc.modules()]

# Create Neural ODE
opt = dict(
    optimizer='SNOpt',
    tol=1e-3,
    ode_solver='dopri5',
    use_adaptive_t1=False,
    snopt_step_size=0.01,
)
odefunc = ODEFunc(opt)
integration_time = torch.tensor([0.0, 1.0]).float()
ode = ODEBlock(opt, odefunc, odesolve, integration_time)
net = NeuralODE(ode)

# Create SNOpt optimizer
precond = SNOpt(net, eps=0.05, update_freq=100)
optim = torch.optim.SGD(net.parameters(), lr=0.001)

# Training loop (training_loader and loss_function come from your own pipeline)
for x, y in training_loader:
    precond.train_itr_setup() # <--- additional step for precond
    optim.zero_grad()

    loss = loss_function(net(x), y)
    loss.backward()

    # Run SNOpt optimizer
    precond.step()            # <--- additional step for precond
    optim.step()
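
The call order in the loop above (gradients are computed, the preconditioner steps, then the base optimizer steps) matches a common pattern in which a preconditioner rescales the stored gradients in place and the base optimizer then applies them. The toy sketch below illustrates only that pattern on a two-parameter quadratic. Every class in it (ToyParam, ToyPreconditioner, ToySGD) is a hypothetical stand-in, the curvatures are assumed to be known exactly, and none of it reflects how snopt computes its 2nd-order matrices.

```python
# Toy stand-ins only: none of these classes come from snopt or PyTorch.

class ToyParam:
    """A scalar parameter with a gradient slot, mimicking tensor.grad."""
    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyPreconditioner:
    """Rescales stored gradients in place by 1 / (curvature + eps)."""
    def __init__(self, params, curvatures, eps=0.05):
        self.params = params
        self.curvatures = curvatures
        self.eps = eps

    def step(self):
        for p, c in zip(self.params, self.curvatures):
            p.grad = p.grad / (c + self.eps)


class ToySGD:
    """Plain gradient descent that applies whatever gradient is stored."""
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad


def train(use_precond, curvatures=(10.0, 0.1), lr=0.1, steps=50):
    """Minimize 0.5 * sum(c_i * x_i**2) and return the final loss."""
    params = [ToyParam(1.0) for _ in curvatures]
    precond = ToyPreconditioner(params, curvatures)
    optim = ToySGD(params, lr)
    for _ in range(steps):
        optim.zero_grad()
        for p, c in zip(params, curvatures):   # "backward": fill in gradients
            p.grad = c * p.value
        if use_precond:
            precond.step()                     # rescale gradients first
        optim.step()                           # then apply the update
    return sum(0.5 * c * p.value ** 2 for p, c in zip(params, curvatures))


plain_loss = train(use_precond=False)
precond_loss = train(use_precond=True)
print(f"plain SGD loss: {plain_loss:.2e}, preconditioned loss: {precond_loss:.2e}")
```

In this toy problem the low-curvature direction barely moves under plain gradient descent, while the rescaled gradients shrink both directions at a similar rate. The eps term keeps the rescaling bounded when a curvature estimate is close to zero.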

What the library actually contains

This snopt library implements the following objects for efficient 2nd-order adjoint-based training of Neural ODEs.

  • ODEFuncBase: Defines the vector field (inherits torch.nn.Module) of a Neural ODE.
  • CNFFuncBase: Serves the same purpose as ODEFuncBase, but for CNF applications.
  • ODEBlock: A Neural-ODE module (torch.nn.Module) that solves the initial value problem (given the vector field, integration time, and an ODE solver) and handles integration time co-optimization with a feedback policy.
  • SNOpt: Our primary 2nd-order optimizer (torch.optim.Optimizer), implemented as a "preconditioner" (see example code above). It takes the following arguments.
    • net is the Neural ODE. Note that the entire network (rather than net.parameters()) is required.
    • eps is the regularization that stabilizes preconditioning. We recommend a value in [0.05, 0.1].
    • update_freq is the frequency at which the 2nd-order information is refreshed. We recommend a value between 100 and 200.
    • alpha controls the running averages of eigenvalues. We recommend fixing the value to 0.75.
    • full_precond decides whether to also precondition layers outside the Neural ODEs.
  • SNOptAdjointCollector: A helper to collect information from torchdiffeq to construct 2nd-order matrices.
  • IntegrationTimeOptimizer: Our 2nd-order method that co-optimizes the integration time (i.e., t1). This is done by calling t1_train_itr_setup(train_it) and update_t1() together with optim.zero_grad() and optim.step() (see trainer.py).
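
As a rough illustration of how the alpha and update_freq arguments of SNOpt could interact, the sketch below uses a standard exponential moving average and a modulo-based refresh schedule. Both are assumptions made for illustration: the convention that alpha weights the old estimate, the reading of update_freq as a number of training iterations, and the function names running_average and should_refresh are not taken from the snopt source.

```python
def running_average(old, new, alpha=0.75):
    """Exponential moving average; assumes alpha weights the OLD estimate."""
    return alpha * old + (1.0 - alpha) * new


def should_refresh(iteration, update_freq=100):
    """True on iterations where the expensive statistics would be recomputed."""
    return iteration % update_freq == 0


# Hypothetical stream of eigenvalue estimates for a single direction.
estimates = [1.0, 2.0, 2.0, 2.0]
smoothed = estimates[0]
for value in estimates[1:]:
    smoothed = running_average(smoothed, value)

# Iterations (out of the first 301) on which a refresh would happen.
refresh_iterations = [it for it in range(301) if should_refresh(it)]
print(smoothed, refresh_iterations)
```

Under these assumptions, a larger alpha makes the smoothed eigenvalues react more slowly to new estimates, and a larger update_freq trades fresher curvature information for less computation per iteration.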

The options are passed in as opt and contain the following fields (see options.py for full descriptions).

  • optimizer is the training method. Use "SNOpt" to enable our method.
  • ode_solver specifies the ODE solver (default is "dopri5") with the absolute/relative tolerance tol.
  • For CNF applications, use divergence_type to specify how divergence should be computed.
  • snopt_step_size determines the step size at which SNOpt samples points along the integration to compute 2nd-order matrices. We recommend the value 0.01 for integration time [0,1], which yields around 100 sampled points.
  • For integration time (t1) co-optimization, enable the flag use_adaptive_t1 and set up the following options.
    • adaptive_t1 specifies the t1 optimization method. Choices are "baseline" and "feedback" (ours).
    • t1_lr is the learning rate. We recommend a value in [0.05, 0.1].
    • t1_reg is the coefficient of the quadratic penalty imposed on t1. The performance is quite sensitive to this value. We recommend a value in [1e-4, 1e-3].
    • t1_update_freq is the frequency to update t1. We recommend a value between 50 and 100.
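
The relation between snopt_step_size and the number of sampled points can be checked with a short calculation. The sketch below assumes a uniform grid that includes both endpoints; the function sample_times is hypothetical, and the sampling scheme actually used inside snopt may differ.

```python
def sample_times(t0=0.0, t1=1.0, step_size=0.01):
    """Evenly spaced times from t0 to t1 (inclusive), assuming a uniform grid."""
    num_intervals = int(round((t1 - t0) / step_size))
    return [t0 + i * step_size for i in range(num_intervals + 1)]


times = sample_times()
print(len(times))  # 101 grid points, i.e. 100 intervals of width 0.01
```

Halving snopt_step_size roughly doubles the number of sampled points, so the recommended value would need to scale with the integration time if t1 is much larger than 1.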

Remarks & Citation

The current library only supports adjoint-based training, yet it can be extended to the standard odeint method (stay tuned!). The pre-processing of the tabular and UEA datasets is adopted from ffjord and NeuralCDE, and the eigenvalue-regularized preconditioning is adopted from EKFAC-pytorch.

If you find this library useful, please cite ⬇️ . Contact me ([email protected]) if you have any questions!

@inproceedings{liu2021second,
  title={Second-order Neural ODE Optimizer},
  author={Liu, Guan-Horng and Chen, Tianrong and Theodorou, Evangelos A},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021},
}
Owner
Guan-Horng Liu
CMU RI → Uber ATG → GaTech ML