Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

Overview

PyTorch Implementation of Differentiable ODE Solvers

This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].

As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU.

Installation

To install latest stable version:

pip install torchdiffeq

To install latest on GitHub:

pip install git+https://github.com/rtqichen/torchdiffeq

Examples

Examples are placed in the examples directory.

We encourage those who are interested in using this library to take a look at examples/ode_demo.py for understanding how to use torchdiffeq to fit a simple spiral ODE.

ODE Demo

Basic usage

This library provides one main interface odeint which contains general-purpose algorithms for solving initial value problems (IVP), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value,

dy/dt = f(t, y)    y(t_0) = y_0.

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.

To solve an IVP using the default solver:

from torchdiffeq import odeint

odeint(func, y0, t)

where func is any callable implementing the ordinary differential equation f(t, x), y0 is an any-D Tensor representing the initial values, and t is a 1-D Tensor containing the evaluation points. The initial time is taken to be t[0].

Backpropagation through odeint goes through the internals of the solver. Note that this is not numerically stable for all solvers (but should probably be fine with the default dopri5 method). Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage.

To use the adjoint method:

from torchdiffeq import odeint_adjoint as odeint

odeint(func, y0, t)

odeint_adjoint simply wraps around odeint, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.

The biggest gotcha is that func must be a nn.Module when using the adjoint method. This is used to collect parameters of the differential equation.

Differentiable event handling

We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].

This can be invoked with odeint_event:

from torchdiffeq import odeint_event
odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
  • func and y0 are the same as odeint.
  • t0 is a scalar representing the initial time value.
  • event_fn(t, y) returns a tensor, and is a required keyword argument.
  • reverse_time is a boolean specifying whether we should solve in reverse time. Default is False.
  • odeint_interface is one of odeint or odeint_adjoint, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is odeint.
  • **kwargs: any remaining keyword arguments are passed to odeint_interface.

The solve is terminated at an event time t and state y when an element of event_fn(t, y) is equal to zero. Multiple outputs from event_fn can be used to specify multiple event functions, of which the first to trigger will terminate the solve.

Both the event time and final state are returned from odeint_event, and can be differentiated. Gradients will be backpropagated through the event function.

The numerical precision for the event time is determined by the atol argument.

See example of simulating and differentiating through a bouncing ball in examples/bouncing_ball.py.

Bouncing Ball

Keyword arguments for odeint(_adjoint)

Keyword arguments:

  • rtol Relative tolerance.
  • atol Absolute tolerance.
  • method One of the solvers listed below.
  • options A dictionary of solver-specific options, see the further documentation.

List of ODE Solvers:

Adaptive-step:

  • dopri8 Runge-Kutta of order 8 of Dormand-Prince-Shampine.
  • dopri5 Runge-Kutta of order 5 of Dormand-Prince-Shampine [default].
  • bosh3 Runge-Kutta of order 3 of Bogacki-Shampine.
  • fehlberg2 Runge-Kutta-Fehlberg of order 2.
  • adaptive_heun Runge-Kutta of order 2.

Fixed-step:

  • euler Euler method.
  • midpoint Midpoint method.
  • rk4 Fourth-order Runge-Kutta with 3/8 rule.
  • explicit_adams Explicit Adams-Bashforth.
  • implicit_adams Implicit Adams-Bashforth-Moulton.

Additionally, all solvers available through SciPy are wrapped for use with scipy_solver.

For most problems, good choices are the default dopri5, or to use rk4 with options=dict(step_size=...) set appropriately small. Adjusting the tolerances (adaptive solvers) or step size (fixed solvers), will allow for trade-offs between speed and accuracy.

Frequently Asked Questions

Take a look at our FAQ for frequently asked questions.

Further documentation

For details of the adjoint-specific and solver-specific options, check out the further documentation.

References

Applications of differentiable ODE solvers and event handling are discussed in these two papers:

[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." Advances in Neural Information Processing Systems. 2018. [arxiv]

[2] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." International Conference on Learning Representations. 2021. [arxiv]


If you found this library useful in your research, please consider citing.

@article{chen2018neuralode,
  title={Neural Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
  journal={Advances in Neural Information Processing Systems},
  year={2018}
}

@article{chen2021eventfn,
  title={Learning Neural Event Functions for Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
  journal={International Conference on Learning Representations},
  year={2021}
}
Owner
Ricky Chen
Ricky Chen
PyTorch toolkit for biomedical imaging

farabio is a minimal PyTorch toolkit for out-of-the-box deep learning support in biomedical imaging. For further information, see Wikis and Docs.

San Askaruly 47 Dec 28, 2022
Fast Discounted Cumulative Sums in PyTorch

TODO: update this README! Fast Discounted Cumulative Sums in PyTorch This repository implements an efficient parallel algorithm for the computation of

Daniel Povey 7 Feb 17, 2022
An implementation of Performer, a linear attention-based transformer, in Pytorch

Performer - Pytorch An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random

Phil Wang 900 Dec 22, 2022
Pytorch bindings for Fortran

Pytorch bindings for Fortran

Dmitry Alexeev 46 Dec 29, 2022
Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Fast and Easy-to-use Distributed Graph Learning for PyTorch Geometric

Quiver Team 221 Dec 22, 2022
PyGCL: Graph Contrastive Learning Library for PyTorch

PyGCL is an open-source library for graph contrastive learning (GCL), which features modularized GCL components from published papers, standardized evaluation, and experiment management.

GCL: Graph Contrastive Learning Library for PyTorch 592 Jan 07, 2023
Over9000 optimizer

Optimizers and tests Every result is avg of 20 runs. Dataset LR Schedule Imagenette size 128, 5 epoch Imagewoof size 128, 5 epoch Adam - baseline OneC

Mikhail Grankin 405 Nov 27, 2022
torch-optimizer -- collection of optimizers for Pytorch

torch-optimizer torch-optimizer -- collection of optimizers for PyTorch compatible with optim module. Simple example import torch_optimizer as optim

Nikolay Novik 2.6k Jan 03, 2023
Code for paper "Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and Layer Input Masking"

model_based_energy_constrained_compression Code for paper "Energy-Constrained Compression for Deep Neural Networks via Weighted Sparse Projection and

Haichuan Yang 16 Jun 15, 2022
higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.

higher is a library providing support for higher-order optimization, e.g. through unrolled first-order optimization loops, of "meta" aspects of these

Facebook Research 1.5k Jan 03, 2023
Official implementations of EigenDamage: Structured Pruning in the Kronecker-Factored Eigenbasis.

EigenDamage: Structured Pruning in the Kronecker-Factored Eigenbasis This repo contains the official implementations of EigenDamage: Structured Prunin

Chaoqi Wang 107 Apr 20, 2022
Training PyTorch models with differential privacy

Opacus is a library that enables training PyTorch models with differential privacy. It supports training with minimal code changes required on the cli

1.3k Dec 29, 2022
A Closer Look at Structured Pruning for Neural Network Compression

A Closer Look at Structured Pruning for Neural Network Compression Code used to reproduce experiments in https://arxiv.org/abs/1810.04622. To prune, w

Bayesian and Neural Systems Group 140 Dec 05, 2022
Use Jax functions in Pytorch with DLPack

Use Jax functions in Pytorch with DLPack

Phil Wang 106 Dec 17, 2022
An optimizer that trains as fast as Adam and as good as SGD.

AdaBound An optimizer that trains as fast as Adam and as good as SGD, for developing state-of-the-art deep learning models on a wide variety of popula

LoLo 2.9k Dec 27, 2022
Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

PyTorch Implementation of Differentiable SDE Solvers This library provides stochastic differential equation (SDE) solvers with GPU support and efficie

Google Research 1.2k Jan 04, 2023
PyTorch Implementation of [1611.06440] Pruning Convolutional Neural Networks for Resource Efficient Inference

PyTorch implementation of [1611.06440 Pruning Convolutional Neural Networks for Resource Efficient Inference] This demonstrates pruning a VGG16 based

Jacob Gildenblat 836 Dec 26, 2022
Kaldi-compatible feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd

Kaldi-compatible feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd

Fangjun Kuang 119 Jan 03, 2023
Training RNNs as Fast as CNNs (https://arxiv.org/abs/1709.02755)

News SRU++, a new SRU variant, is released. [tech report] [blog] The experimental code and SRU++ implementation are available on the dev branch which

ASAPP Research 2.1k Jan 01, 2023
PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions

glow-pytorch PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions

Kim Seonghyeon 433 Dec 27, 2022