PyTorch implementation of DSB (Diffusion Schrödinger Bridge) for score-based generative modeling. Experiments are managed using Hydra.

Overview

Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling

This repository contains the implementation for the paper Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling.

If using this code, please cite the paper:

    @article{de2021diffusion,
              title={Diffusion Schr{\"o}dinger Bridge with Applications to Score-Based Generative Modeling},
              author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
              journal={arXiv preprint arXiv:2106.01357},
              year={2021}
            }

Contributors

  • Valentin De Bortoli
  • James Thornton
  • Jeremy Heng
  • Arnaud Doucet

What is a Schrödinger bridge?

The Schrödinger Bridge (SB) problem is a classical problem appearing in applied mathematics, optimal control and probability; see [1, 2, 3]. In the discrete-time setting, it takes the following (dynamic) form. Consider as reference a density p(x_{0:N}) describing the process that adds noise to the data. We aim to find the density p*(x_{0:N}) that minimizes the Kullback-Leibler divergence KL(p* || p) subject to the marginal constraints p*(x_0) = p_data(x_0) and p*(x_N) = p_prior(x_N). In this work we introduce Diffusion Schrödinger Bridge (DSB), a new algorithm which uses score-matching approaches [4] to approximate the Iterative Proportional Fitting (IPF) algorithm, an iterative method for finding the solution of the SB problem. DSB can be seen as a refinement of existing score-based generative modeling methods [5, 6].
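To make the IPF idea concrete, here is a minimal NumPy sketch of its static, finite-state analogue (equivalent to the Sinkhorn algorithm). It assumes a small discrete state space and a strictly positive reference joint `ref_joint[i, j]` standing in for p(x_0 = i, x_N = j); the helper name `ipf_static_bridge` and the Gaussian-like reference kernel in the example are hypothetical illustrations, not code from this repository. Each half-step is a KL projection of the current coupling onto the set of couplings with one prescribed marginal, alternating between p_data and p_prior. This is not DSB itself: DSB performs the analogous projections on path measures over x_{0:N} and approximates them with score-matching approaches rather than exact rescaling.

```python
import numpy as np


def ipf_static_bridge(ref_joint, mu, nu, num_iters=1000, tol=1e-10):
    """Iterative Proportional Fitting on a finite state space (hypothetical helper).

    ref_joint[i, j] stands in for the reference p(x0 = i, xN = j) and must be
    strictly positive. mu is the data marginal on x0, nu the prior marginal on xN.
    Returns a coupling pi with row sums mu and column sums nu.
    """
    pi = ref_joint / ref_joint.sum()
    for _ in range(num_iters):
        # Half-step 1: KL projection onto couplings whose x0-marginal is mu.
        pi = pi * (mu / pi.sum(axis=1))[:, None]
        # Half-step 2: KL projection onto couplings whose xN-marginal is nu.
        pi = pi * (nu / pi.sum(axis=0))[None, :]
        # Columns now match nu; stop once the rows also match mu.
        if np.abs(pi.sum(axis=1) - mu).max() < tol:
            break
    return pi


if __name__ == "__main__":
    states = np.arange(5)
    # Hypothetical reference: a Gaussian-like kernel between grid points,
    # playing the role of a noising step from x0 to xN.
    ref = np.exp(-((states[:, None] - states[None, :]) ** 2) / 8.0)
    mu = np.array([0.4, 0.3, 0.15, 0.1, 0.05])  # toy "data" marginal
    nu = np.full(5, 0.2)                        # toy "prior" marginal
    pi = ipf_static_bridge(ref, mu, nu)
    print(np.round(pi, 3))
    print("row sums:", np.round(pi.sum(axis=1), 6))
    print("col sums:", np.round(pi.sum(axis=0), 6))
```

Because every half-step only rescales rows or columns, the result keeps the form pi[i, j] = u[i] * ref[i, j] * v[j]. For the dynamic problem, the SB solution keeps the reference bridge p(x_{1:N-1} | x_0, x_N) and replaces the endpoint coupling with this IPF limit.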


Installation

This project can be installed from its git repository.

  1. Obtain the sources by:

    git clone https://github.com/anon284/schrodinger_bridge.git

or, if git is unavailable, download as a ZIP from GitHub https://github.com/.

  2. Install:

    conda env create -f conda.yaml

    conda activate bridge

  3. Download the example datasets:

    • CelebA: python data.py --data celeba --data_dir './data/'
    • MNIST: python data.py --data mnist --data_dir './data/'
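After `conda activate bridge`, one quick way to confirm the environment is usable is to check that the main packages this README relies on (PyTorch and Hydra) can be found. The sketch below uses only the standard library; the package list is an assumption drawn from this README, and `conda.yaml` remains the authoritative list of dependencies.

```python
import importlib.util

# Assumed top-level import names: PyTorch imports as `torch`,
# Hydra (pip package hydra-core) as `hydra`.
# conda.yaml in the repository is the authoritative dependency list.
REQUIRED = ("torch", "hydra")


def missing_packages(names=REQUIRED):
    """Return the names in `names` that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing), "(is the 'bridge' environment active?)")
    else:
        print("torch and hydra are importable.")
```

`find_spec` locates a module without importing it, so the check stays fast even for heavy packages such as PyTorch.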

How to use this code?

  1. Train networks:
  • 2D: python main.py dataset=2d model=Basic num_steps=20 num_iter=5000
  • MNIST: python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<local path>/data/
  • CelebA: python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<local path>/data/
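These runs differ only in their Hydra overrides, which are passed to `main.py` as `key=value` arguments. If you script several runs, a small helper like the one below keeps the overrides in one place. It is a hypothetical convenience wrapper, not part of the repository: the `RUNS` table copies the values from the commands above, and `DATA_DIR` assumes the `./data/` directory used with `data.py`; adjust it to your own path.

```python
import shlex
import subprocess

# Hypothetical: assumes the datasets were downloaded with data.py --data_dir './data/'.
DATA_DIR = "./data/"

# Overrides copied from the example commands in this README.
RUNS = {
    "2d": {"dataset": "2d", "model": "Basic", "num_steps": 20, "num_iter": 5000},
    "mnist": {"dataset": "stackedmnist", "num_steps": 30, "model": "UNET",
              "num_iter": 5000, "data_dir": DATA_DIR},
    "celeba": {"dataset": "celeba", "num_steps": 50, "model": "UNET",
               "num_iter": 5000, "data_dir": DATA_DIR},
}


def build_command(overrides):
    """Turn a dict of config keys into Hydra-style `key=value` command-line overrides."""
    return ["python", "main.py"] + [f"{key}={value}" for key, value in overrides.items()]


def launch(name, dry_run=True):
    """Print the command for run `name`; execute it only when dry_run is False."""
    cmd = build_command(RUNS[name])
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)
    return cmd


if __name__ == "__main__":
    for run_name in RUNS:
        launch(run_name)  # dry run: only prints the commands
```

With `dry_run=False`, `subprocess.run(..., check=True)` raises `CalledProcessError` if training exits with an error, so a failed run does not pass silently.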

Checkpoints and sampled images are saved to a newly created directory. If the GPU has insufficient memory, reduce the cache size. The 2D dataset should train on CPU. The MNIST and CelebA experiments were run on 2 high-memory V100 GPUs.

References

[1] Hans Föllmer. Random fields and diffusion processes. In: École d'été de Probabilités de Saint-Flour, 1985-1987.

[2] Christian Léonard. A survey of the Schrödinger problem and some of its connections with optimal transport. In: Discrete & Continuous Dynamical Systems-A, 2014.

[3] Yongxin Chen, Tryphon Georgiou and Michele Pavon. Optimal transport in systems and control. In: Annual Review of Control, Robotics, and Autonomous Systems, 2020.

[4] Aapo Hyvärinen. Estimation of non-normalized statistical models by score matching. In: Journal of Machine Learning Research, 2005.

[5] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. In: Advances in Neural Information Processing Systems, 2019.

[6] Jonathan Ho, Ajay Jain and Pieter Abbeel. Denoising diffusion probabilistic models. In: Advances in Neural Information Processing Systems, 2020.
