Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)"

Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory [ICLR 2022]

Official PyTorch implementation of the paper "Likelihood Training of Schrödinger Bridge using Forward-Backward SDEs Theory (SB-FBSDE)" which introduces a new class of deep generative models that generalizes score-based models to fully nonlinear forward and backward diffusions.
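As a rough orientation, the "nonlinear forward and backward diffusions" can be written as a pair of SDEs. The notation below (base drift f, diffusion coefficient g, Schrödinger potentials Ψ and Ψ̂) is an assumption based on the standard Schrödinger bridge formulation; check the paper for the exact statement.

```latex
% Forward (data -> prior) and backward (prior -> data) SDEs of the Schrödinger bridge.
% f: base drift, g: diffusion coefficient, W_t: Wiener process.
% \Psi, \hat{\Psi}: Schrödinger potentials (notation assumed; verify against the paper).
\begin{align}
  \mathrm{d}X_t &= \left[ f + g^2 \nabla_x \log \Psi(t, X_t) \right] \mathrm{d}t + g\, \mathrm{d}W_t,
      & X_0 &\sim p_{\mathrm{data}}, \\
  \mathrm{d}X_t &= \left[ f - g^2 \nabla_x \log \hat{\Psi}(t, X_t) \right] \mathrm{d}t + g\, \mathrm{d}W_t,
      & X_T &\sim p_{\mathrm{prior}}.
\end{align}
```

Under these assumptions, setting ∇ log Ψ to zero leaves the forward SDE dX_t = f dt + g dW_t used by score-based models, and g² ∇ log Ψ̂ then coincides with the usual score term g² ∇ log p_t.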

[Figure: SB-FBSDE result]

This repo is co-maintained by Guan-Horng Liu and Tianrong Chen. Contact us if you have any questions! If you find this library useful, please cite ⬇️

@inproceedings{chen2022likelihood,
  title={Likelihood Training of Schr{\"o}dinger Bridge using Forward-Backward SDEs Theory},
  author={Chen, Tianrong and Liu, Guan-Horng and Theodorou, Evangelos A},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

Installation

This code is developed with Python 3 and PyTorch >=1.7 (we recommend 1.8.1). First, install the dependencies with Anaconda and activate the environment sb-fbsde with

conda env create --file requirements.yaml python=3
conda activate sb-fbsde
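After activating the environment, a quick check like the following can confirm that the installed PyTorch satisfies the >=1.7 requirement. This is a minimal sketch, not part of the repository; the helper name meets_minimum is hypothetical.

```python
import re


def meets_minimum(version, minimum=(1, 7)):
    """Return True if a version string such as '1.8.1+cu111' is >= minimum (major, minor)."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


try:
    import torch  # only needed for the live check below
    status = "ok" if meets_minimum(torch.__version__) else "older than 1.7"
    print("PyTorch", torch.__version__, status)
except ImportError:
    print("PyTorch is not installed in this environment")
```

Comparing (major, minor) as integers avoids the pitfall of string comparison, where "1.10" would sort before "1.7".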

Training

# Pass --num-FID-sample only for CIFAR-10.
python main.py \
  --problem-name <PROBLEM_NAME> \
  --forward-net <FORWARD_NET> \
  --backward-net <BACKWARD_NET> \
  --num-FID-sample <NUM_FID_SAMPLE> \
  --dir <DIR> \
  --log-tb

To train an SB-FBSDE from scratch, run the above command, where

  • PROBLEM_NAME is the dataset. We support gmm (2D mixture of Gaussians), checkerboard (2D toy dataset), mnist, celebA32, celebA64, and cifar10.
  • FORWARD_NET & BACKWARD_NET are the deep networks for the forward and backward drifts. We support Unet, ncsnpp, and a toy network for 2D datasets.
  • NUM_FID_SAMPLE is the number of generated images used to evaluate FID locally. We recommend 10000 for training CIFAR-10. Note that this requires first downloading the FID statistics checkpoint.
  • DIR specifies where the results (e.g. snapshots during training) are stored.
  • log-tb enables logging with TensorBoard.
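The options above can also be assembled programmatically, for example when launching several runs from a script. The sketch below only uses the flags listed in this README; the helper name build_train_command and the output directory names are hypothetical, and the Unet/Unet pairing for mnist is an illustrative assumption rather than a recommended configuration.

```python
def build_train_command(problem_name, forward_net, backward_net, out_dir,
                        num_fid_sample=None, log_tb=True):
    """Build the argv list for a training run of main.py."""
    cmd = [
        "python", "main.py",
        "--problem-name", problem_name,
        "--forward-net", forward_net,
        "--backward-net", backward_net,
        "--dir", out_dir,
    ]
    # --num-FID-sample is only meant for CIFAR-10, so drop it for other datasets.
    if problem_name == "cifar10" and num_fid_sample is not None:
        cmd += ["--num-FID-sample", str(num_fid_sample)]
    if log_tb:
        cmd.append("--log-tb")
    return cmd


# CIFAR-10 with the recommended 10000 FID samples (directory name is made up).
cifar_cmd = build_train_command("cifar10", "Unet", "ncsnpp", "my-cifar10-run",
                                num_fid_sample=10000)
print(" ".join(cifar_cmd))

# Another dataset: the FID flag is omitted even if a value is passed.
mnist_cmd = build_train_command("mnist", "Unet", "Unet", "my-mnist-run",
                                num_fid_sample=10000)
print(" ".join(mnist_cmd))
```

The returned list can be passed directly to subprocess.run, which avoids shell quoting issues with directory names.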

Additionally, use --load to restore a previous checkpoint or a pre-trained model. For CIFAR-10 training specifically, we support loading the pre-trained NCSN++ as the backward policy of the first SB training stage. This works because the first SB training stage can degenerate to denoising score matching under proper initialization; see Appendix D of our paper for details.

Other configurations are detailed in options.py. The default configurations for each dataset are provided in the configs folder.

Evaluating the CIFAR-10 Checkpoint

To evaluate SB-FBSDE on CIFAR-10 (we achieve FID 3.01 and NLL 2.96), create a folder named checkpoint, then download the model checkpoint and the FID statistics checkpoint, either from Google Drive or with the following commands.

mkdir checkpoint && cd checkpoint

# FID statistics checkpoint. This is needed whenever
# FID is computed during training or sampling.
gdown --id 1Tm_5nbUYKJiAtz2Rr_ARUY3KIFYxXQQD 

# SB-FBSDE model checkpoint for reproducing results in the paper.
gdown --id 1Kcy2IeecFK79yZDmnky36k4PR2yGpjyg 

After downloading the checkpoints, run the following commands to compute either NLL or FID. Set the batch size --samp-bs according to your hardware.

# compute NLL
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-NLL --samp-bs <BS>
# compute FID
python main.py --problem-name cifar10 --forward-net Unet --backward-net ncsnpp --dir ICLR-2022-reproduce \
  --load checkpoint/ciifar10_sbfbsde_stage_8.npz --compute-FID --samp-bs <BS> --num-FID-sample 50000 --use-corrector --snr 0.15