Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided; refer to torchsde [2] for a full set of differentiable SDE solvers in PyTorch, and to torchdiffeq [3] for differentiable ODE solvers.

Continuous-depth hidden unit trajectories in a Neural ODE vs. uncertain posterior dynamics in an SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: package versions may change; refer to the official JAX installation instructions.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers in both Ito and Stratonovich form. Solvers of different orders can be selected with method={euler_maruyama|milstein|euler_heun}, with strong orders 0.5|1|0.5, respectively (and orders 1|1|1 in the case of an additive-noise SDE). The stochastic adjoint training mode (sdeint_ito) does not yet work efficiently; use sdeint_ito_fixed_grid for now. Trade off solver speed against precision during training or inference by adjusting --nsteps <# steps>.
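The fixed-grid schemes named above can be sketched for a scalar SDE in plain Python. This is an illustrative sketch, not jaxsde's code: the function name `sdeint_scalar_fixed_grid` and its arguments (including `dg_dy`, the derivative of g that Milstein needs) are hypothetical, and Brownian increments come from Python's `random` module rather than a JAX PRNG key.

```python
import math
import random


# Illustrative sketch only; names and signature are hypothetical, not jaxsde's API.
def sdeint_scalar_fixed_grid(f, g, y0, t0, t1, n_steps, seed=0,
                             method="euler_maruyama", dg_dy=None):
    """Integrate a scalar SDE dy = f(t, y) dt + g(t, y) dW on a fixed grid.

    euler_maruyama and milstein read the SDE in the Ito sense; euler_heun
    reads it in the Stratonovich sense. milstein needs dg_dy = dg/dy.
    """
    dt = (t1 - t0) / n_steps
    rng = random.Random(seed)  # same seed and n_steps -> same Brownian path
    t, y = t0, y0
    for _ in range(n_steps):
        dW = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        drift = f(t, y) * dt
        if method == "euler_maruyama":
            y = y + drift + g(t, y) * dW
        elif method == "milstein":
            # Ito Milstein correction: 0.5 * g * dg/dy * (dW^2 - dt)
            y = (y + drift + g(t, y) * dW
                 + 0.5 * g(t, y) * dg_dy(t, y) * (dW * dW - dt))
        elif method == "euler_heun":
            # Predict y with the diffusion only, then average g at both points
            y_bar = y + g(t, y) * dW
            y = y + drift + 0.5 * (g(t, y) + g(t, y_bar)) * dW
        else:
            raise ValueError(f"unknown method: {method}")
        t += dt
    return y
```

With additive noise (constant g), the Milstein correction vanishes and Euler-Heun's averaged diffusion reduces to g, so all three schemes coincide, which is why their orders agree in that case. With this simple seeding, changing `n_steps` also changes the sampled Brownian path; refining one fixed path as the step count grows needs a Brownian bridge or tree, as torchsde provides.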

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: uses O(1) memory by solving an adjoint SDE in the backward pass instead of backpropagating through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")
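The O(1)-memory idea can be illustrated with a discrete, pure-Python analogue for a hypothetical linear SDE dy = -theta*y dt + sigma dW discretized with Euler-Maruyama (assuming theta*dt != 1 so each step is invertible). This is not the continuous-time adjoint SDE of [2] nor jaxsde's implementation; it only shows the two ingredients a memory-free backward pass relies on: regenerating each Brownian increment from a seed instead of storing the path, and reconstructing earlier states by running the dynamics backwards. All function names are hypothetical.

```python
import math
import random


# Illustrative discrete analogue; names are hypothetical, not jaxsde's API.
def brownian_increment(seed, k, dt):
    # Counter-based noise: increment k is a pure function of (seed, k), so the
    # backward pass can regenerate it instead of storing the whole path.
    return random.Random(seed * 1_000_003 + k).gauss(0.0, math.sqrt(dt))


def ou_forward(theta, sigma, y0, t1, n_steps, seed):
    """Euler-Maruyama for dy = -theta * y dt + sigma dW; keeps only y(t1)."""
    dt = t1 / n_steps
    y = y0
    for k in range(n_steps):
        y = y * (1.0 - theta * dt) + sigma * brownian_increment(seed, k, dt)
    return y


def ou_adjoint_grad(theta, sigma, y0, t1, n_steps, seed):
    """d y(t1) / d theta using O(1) memory: walk the grid backwards,
    rebuilding y_k from y_{k+1} and the regenerated increment dW_k."""
    dt = t1 / n_steps
    y = ou_forward(theta, sigma, y0, t1, n_steps, seed)
    adj = 1.0   # d y(t1) / d y_{k+1}
    grad = 0.0
    for k in reversed(range(n_steps)):
        dW = brownian_increment(seed, k, dt)
        y_prev = (y - sigma * dW) / (1.0 - theta * dt)  # invert the step
        grad += adj * (-y_prev * dt)  # direct effect of theta on step k
        adj *= 1.0 - theta * dt       # chain rule back to y_k
        y = y_prev
    return grad
```

A central finite difference of `ou_forward` with the same seed is a convenient check. Backpropagating through the solver would instead keep all `n_steps` intermediate states in memory.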

Brax: Bayesian SDE Framework in JAX

Implementation of composable Bayesian layers in the stax API. Our SDE Bayesian layers can be used with the SDEBNN block, composed with multiple parameterizations of the time-dependent layers in diffeq_layers. The sticking-the-landing (STL) trick can be enabled during training with --stl to improve the convergence rate. Augment the inputs by a custom amount with --aug <integer>, and set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train in gradient accumulation mode (--acc_grad) and with gradient checkpointing (--remat).
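To see why sticking the landing can help, consider a toy one-dimensional Gaussian case (not the SDE-BNN's KL over weight paths): q = N(mu, sigma^2) with reparameterized sample z = mu + sigma*eps, and target p = N(m, s^2). The single-sample gradient of log q(z) - log p(z) splits into a path term (through z) and a score term (through q's own parameters at fixed z); STL keeps only the path term. The score term has zero mean under q, so both estimators are unbiased, but when q matches the target the STL estimate is exactly zero for every sample. The function names are hypothetical.

```python
import random


# Toy Gaussian illustration of STL; names are hypothetical, not Brax's API.
def kl_grad_estimates(mu, sigma, m, s, eps):
    """Single-sample gradients of log q(z) - log p(z) w.r.t. (mu, sigma), where
    z = mu + sigma * eps, q = N(mu, sigma^2) and p = N(m, s^2).
    Returns (full, stl), each a (d_mu, d_sigma) tuple."""
    z = mu + sigma * eps
    dlogq_dz = -(z - mu) / sigma**2
    dlogp_dz = -(z - m) / s**2
    # Path terms: gradient flowing through the sample z (dz/dmu = 1, dz/dsigma = eps)
    path_mu = dlogq_dz - dlogp_dz
    path_sigma = (dlogq_dz - dlogp_dz) * eps
    # Score terms: gradient of log q w.r.t. its own parameters, z held fixed
    score_mu = (z - mu) / sigma**2
    score_sigma = (z - mu) ** 2 / sigma**3 - 1.0 / sigma
    full = (path_mu + score_mu, path_sigma + score_sigma)
    stl = (path_mu, path_sigma)  # sticking the landing drops the score terms
    return full, stl


def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


# q equals the target p: compare the spread of the two mu-gradient estimators
rng = random.Random(0)
draws = [kl_grad_estimates(0.3, 1.2, 0.3, 1.2, rng.gauss(0.0, 1.0)) for _ in range(1000)]
print("full:", variance([full[0] for full, _ in draws]))
print("stl: ", variance([stl[0] for _, stl in draws]))
```

Lower gradient variance near the optimum is the usual motivation for STL's faster convergence.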

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

All examples can be run with different vision datasets. For readability, TensorBoard logging has been omitted here (see torchbnn, which includes it).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou
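Reading `--loss laplace` as a Laplace observation likelihood is an assumption, as is the scale parameter `b` below. Under that assumption, a minimal sketch of the per-point negative log-likelihood, with a Gaussian counterpart for comparison (all names hypothetical):

```python
import math


# Assumed meaning of --loss laplace; helper names are hypothetical.
def laplace_nll(y, mu, b):
    """-log p(y) for y ~ Laplace(mu, b): log(2b) + |y - mu| / b."""
    if b <= 0:
        raise ValueError("scale b must be positive")
    return math.log(2.0 * b) + abs(y - mu) / b


def gaussian_nll(y, mu, sigma):
    """-log p(y) for y ~ Normal(mu, sigma^2), for comparison."""
    return 0.5 * math.log(2.0 * math.pi * sigma**2) + (y - mu) ** 2 / (2.0 * sigma**2)


def mean_nll(nll, ys, mus, scale):
    # Average per-point loss over targets ys and predictions mus
    return sum(nll(y, mu, scale) for y, mu in zip(ys, mus)) / len(ys)
```

The Laplace loss grows linearly in the residual rather than quadratically, so a few outlying points dominate it less than they would a Gaussian loss.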

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1
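Since `--nsamples` sets how many samples are averaged over, a Bayesian classifier's prediction can be formed by averaging over sampled forward passes. The sketch below assumes the averaging happens on softmax probabilities (one common choice; the script may instead average losses or logits); `mc_predictive` and `predictive_entropy` are hypothetical helpers.

```python
import math


# Assumes probability averaging; helper names are hypothetical.
def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mc_predictive(logit_samples):
    """Average class probabilities over S sampled forward passes.
    logit_samples: S lists of logits, one per sampled network."""
    probs = [softmax(logits) for logits in logit_samples]
    n_classes = len(probs[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]


def predictive_entropy(p):
    # Entropy (in nats) of the averaged predictive distribution
    return -sum(pc * math.log(pc) for pc in p if pc > 0.0)


# Two sampled networks that confidently disagree
p = mc_predictive([[10.0, -10.0], [-10.0, 10.0]])
print(p, predictive_entropy(p))
```

When sampled networks disagree, the averaged distribution flattens and its entropy rises, which is one way to read off predictive uncertainty.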

To train a ResNet baseline, specify --model resnet; for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in Pytorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

All examples can be run with different vision datasets and include TensorBoard logging for critical metrics.

Toy 1D regression to learn a multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expressive approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with --adjoint for memory-efficient backpropagation and with --adaptive for adaptive computation (if training with both, also set --adjoint_adaptive True).
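Adaptive computation means the solver picks its own step sizes from a local error estimate. As a generic illustration (not torchsde's or the training script's adaptive scheme, and for a deterministic ODE only), here is an embedded Heun/Euler pair with a standard step-size controller; `adaptive_heun_euler` and its tolerances are hypothetical.

```python
import math


# Generic adaptive-step sketch; not the training script's solver.
def adaptive_heun_euler(f, y0, t0, t1, rtol=1e-4, atol=1e-6, h0=0.1):
    """Integrate dy/dt = f(t, y) for scalar y with an embedded Heun (order 2) /
    Euler (order 1) pair. Returns (y(t1), number of accepted steps)."""
    t, y, h = t0, y0, h0
    n_accepted = 0
    while t < t1:
        h = min(h, t1 - t)  # do not step past t1
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_high = y + 0.5 * h * (k1 + k2)  # Heun, order 2
        y_low = y + h * k1                 # Euler, order 1
        err = abs(y_high - y_low)          # local error estimate
        tol = atol + rtol * max(abs(y), abs(y_high))
        if err <= tol:                     # accept the step
            t += h
            y = y_high
            n_accepted += 1
        # Rescale h for an O(h^2) error estimate, with a safety factor and clamps
        factor = 0.9 * (tol / err) ** 0.5 if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    return y, n_accepted


y_tight, n_tight = adaptive_heun_euler(lambda t, y: -y, 1.0, 0.0, 1.0, rtol=1e-6, atol=1e-9)
y_loose, n_loose = adaptive_heun_euler(lambda t, y: -y, 1.0, 0.0, 1.0, rtol=1e-2, atol=1e-4)
print(n_tight, n_loose, abs(y_tight - math.exp(-1.0)))
```

Looser tolerances accept larger steps and therefore fewer function evaluations, which is the speed-for-accuracy trade-off an adaptive mode exposes.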

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." arXiv preprint, 2021.

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020.

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS 2018.


If you find this library useful in your research, please consider citing:

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}