Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided. For a full set of differentiable SDE solvers in PyTorch, refer to torchsde [2]; for differentiable ODE solvers, refer to torchdiffeq [3].

Continuous-depth hidden unit trajectories in a Neural ODE vs. uncertain posterior dynamics in an SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: Package versions may change; refer to the official JAX installation instructions.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers in the Ito and Stratonovich forms. Select a solver with method={euler_maruyama|milstein|euler_heun}; the strong orders are 0.5|1|0.5 in general and 1|1|1 for an additive-noise SDE. The stochastic adjoint training mode (sdeint_ito) does not work efficiently yet, so use sdeint_ito_fixed_grid for now. Trade off solver speed against precision during training or inference by adjusting --nsteps <# steps>.
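The difference in strong order between the Ito solvers, and the effect of the number of steps, can be illustrated with a minimal standard-library sketch. It does not use jaxsde: the functions `simulate_gbm` and `mean_strong_error` are hypothetical names introduced here, and the test problem (scalar geometric Brownian motion, which has a closed-form solution) is an assumption chosen only because the error can be measured exactly.

```python
import math
import random


def simulate_gbm(method, nsteps, mu=0.1, sigma=0.5, y0=1.0, t1=1.0, rng=None):
    """Integrate dy = mu*y dt + sigma*y dW on a fixed grid of `nsteps` steps.

    Returns (numerical solution, exact solution) at time t1, both driven by
    the same Brownian path.
    """
    if method not in ("euler_maruyama", "milstein"):
        raise ValueError(f"unknown method: {method}")
    if rng is None:
        rng = random.Random(0)
    dt = t1 / nsteps
    y, w = y0, 0.0
    for _ in range(nsteps):
        dw = rng.gauss(0.0, math.sqrt(dt))   # Brownian increment ~ N(0, dt)
        step = mu * y * dt + sigma * y * dw  # Euler-Maruyama update
        if method == "milstein":
            # Milstein correction 0.5 * g * dg/dy * (dW^2 - dt), with g = sigma*y
            step += 0.5 * sigma * sigma * y * (dw * dw - dt)
        y += step
        w += dw                              # accumulate W(t1) for the exact solution
    exact = y0 * math.exp((mu - 0.5 * sigma ** 2) * t1 + sigma * w)
    return y, exact


def mean_strong_error(method, nsteps, npaths=500, seed=0):
    """Average |numerical - exact| at t1 over `npaths` sampled Brownian paths."""
    rng = random.Random(seed)  # same seed -> same paths for every method
    total = 0.0
    for _ in range(npaths):
        approx, exact = simulate_gbm(method, nsteps, rng=rng)
        total += abs(approx - exact)
    return total / npaths


if __name__ == "__main__":
    for method in ("euler_maruyama", "milstein"):
        for nsteps in (20, 200):
            print(method, nsteps, mean_strong_error(method, nsteps))
```

On this multiplicative-noise problem the Milstein error should be visibly smaller than the Euler-Maruyama error at the same step count, and both should shrink as the number of steps grows, which is the speed-versus-precision trade-off that the step count controls.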

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: Uses O(1) memory by solving an adjoint SDE in the backward pass instead of storing the intermediate solver states.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")

Brax: Bayesian SDE Framework in JAX

Brax implements composable Bayesian layers in the stax API. The SDE Bayesian layers can be used with the SDEBNN block, composed with multiple parameterizations of time-dependent layers in diffeq_layers. Enable the sticking-the-landing (STL) trick during training with --stl to improve the convergence rate. Augment the inputs by a custom amount with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train with gradient accumulation (--acc_grad) and gradient checkpointing (--remat).
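As a rough illustration of why gradient accumulation helps with memory, the sketch below shows the generic pattern: gradients are computed on small micro-batches and averaged, which gives the same result as one gradient over the full batch when the micro-batches have equal size. This is not the repository's code; the scalar model, the function names `grad_mse` and `accumulated_grad`, and the assumption that --acc_grad follows this standard pattern are all illustrative.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient w.r.t. scalar w of 0.5 * mean((w*x - y)^2) over one batch."""
    n = len(xs)
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch):
    """Average per-micro-batch gradients, touching one micro-batch at a time."""
    if len(xs) % micro_batch != 0:
        raise ValueError("batch size must be a multiple of micro_batch")
    n_chunks = len(xs) // micro_batch
    total = 0.0
    for start in range(0, len(xs), micro_batch):
        stop = start + micro_batch
        total += grad_mse(w, xs[start:stop], ys[start:stop])
    return total / n_chunks


if __name__ == "__main__":
    xs = [0.5, -1.0, 2.0, 1.5, -0.25, 3.0, 0.75, -2.0]
    ys = [1.0, -2.5, 3.5, 3.0, -0.5, 6.5, 1.0, -4.0]
    full = grad_mse(0.3, xs, ys)
    accumulated = accumulated_grad(0.3, xs, ys, micro_batch=2)
    print(full, accumulated, math.isclose(full, accumulated))
```

Only one micro-batch of activations needs to be held at a time, at the cost of more sequential passes per parameter update.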

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

Any of the examples can be run with a different vision dataset swapped in. For better readability, TensorBoard logging has been excluded (see torchbnn instead).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1

To train a ResNet baseline, specify --model resnet; for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in PyTorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

Any of the examples can be run with a different vision dataset swapped in, and each includes TensorBoard logging for critical metrics.

Toy 1D regression to learn a multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with the adjoint for memory-efficient backpropagation and in adaptive mode for adaptive computation; set --adjoint_adaptive True when training with both adjoint and adaptive modes.
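To give a feel for what adaptive computation means in a solver, here is a minimal standard-library sketch of step-size control by step doubling, using the midpoint method on a deterministic ODE. It is a simplification under stated assumptions: it ignores the noise term entirely, it is not how torchsde implements adaptivity, and the names `midpoint_step` and `integrate_adaptive` as well as the tolerance values are hypothetical.

```python
import math


def midpoint_step(f, t, y, h):
    """One explicit midpoint step of size h for dy/dt = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
    return y + h * k2


def integrate_adaptive(f, y0, t0, t1, h0=0.1, rtol=1e-5, atol=1e-6):
    """Integrate from t0 to t1, shrinking or growing h from an error estimate.

    The error estimate is the gap between one step of size h and two steps
    of size h/2. Returns (y at t1, number of accepted steps).
    """
    t, y, h = t0, y0, h0
    n_accepted = 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)  # never step past the end point
        full = midpoint_step(f, t, y, h)
        half = midpoint_step(f, t, y, 0.5 * h)
        two_half = midpoint_step(f, t + 0.5 * h, half, 0.5 * h)
        err = abs(two_half - full)
        tol = atol + rtol * max(abs(y), abs(two_half))
        if err <= tol:      # accept the more accurate two-half-step result
            t += h
            y = two_half
            n_accepted += 1
        # Second-order method: local error scales like h^3, hence the cube root.
        factor = 0.9 * (tol / err) ** (1.0 / 3.0) if err > 0.0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return y, n_accepted


if __name__ == "__main__":
    y1, steps = integrate_adaptive(lambda t, y: -y, 1.0, 0.0, 1.0)
    print(y1, math.exp(-1.0), steps)
```

A fixed step size such as --dt spends the same effort everywhere, while an adaptive scheme takes more steps only where the error estimate demands it.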

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018. [arxiv]


If you found this library useful in your research, please consider citing

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}