Official code for Score-Based Generative Modeling through Stochastic Differential Equations

Overview

This repo contains the official implementation for the paper Score-Based Generative Modeling through Stochastic Differential Equations

by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole


We propose a unified framework that generalizes and improves previous work on score-based generative models through the lens of stochastic differential equations (SDEs). In particular, we can transform data to a simple noise distribution with a continuous-time stochastic process described by an SDE. This SDE can be reversed for sample generation if we know the score of the marginal distributions at each intermediate time step, which can be estimated with score matching. The basic idea is captured in the figure below:

[Figure: schematic]
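The forward and reverse processes can be illustrated with a minimal sketch on a toy problem. It assumes a 1-D Gaussian data distribution and a forward SDE dx = -0.5 * BETA * x dt + sqrt(BETA) dw with a constant noise rate, so the score of every marginal is available in closed form and no network needs to be trained. The constants, function names, and the plain Euler-Maruyama integrator are illustrative assumptions, not the samplers shipped in this repo.

```python
import numpy as np

# Toy 1-D setup (all values are illustrative assumptions, not taken from the paper).
MU, SIGMA = 2.0, 0.5      # data distribution: N(MU, SIGMA^2)
BETA, T = 10.0, 1.0       # constant noise rate and time horizon
N_STEPS = 1000            # number of discretization steps


def marginal(t):
    """Mean and variance of p_t under dx = -0.5*BETA*x dt + sqrt(BETA) dw."""
    decay = np.exp(-BETA * t)
    mean = MU * np.sqrt(decay)
    var = SIGMA**2 * decay + (1.0 - decay)
    return mean, var


def score(x, t):
    """Exact score of the Gaussian marginal p_t.

    In practice a network trained with score matching would replace this.
    """
    mean, var = marginal(t)
    return -(x - mean) / var


def reverse_sample(n_samples, rng):
    """Integrate the reverse-time SDE from t=T down to t=0 with Euler-Maruyama."""
    dt = T / N_STEPS
    x = rng.standard_normal(n_samples)  # prior N(0, 1), approximately p_T
    for i in range(N_STEPS):
        t = T - i * dt
        # Reverse-time drift: f(x, t) - g(t)^2 * score(x, t)
        drift = -0.5 * BETA * x - BETA * score(x, t)
        noise = rng.standard_normal(n_samples)
        # Step backward in time, hence the minus sign on the drift term.
        x = x - drift * dt + np.sqrt(BETA * dt) * noise
    return x


samples = reverse_sample(20_000, np.random.default_rng(0))
print(f"mean={samples.mean():.3f}, std={samples.std():.3f}")
```

If the reversal works, the printed mean and standard deviation should land close to MU and SIGMA, up to discretization and sampling error.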

Our work enables a better understanding of existing approaches, new sampling algorithms, exact likelihood computation, uniquely identifiable encoding, and latent code manipulation, and it brings new conditional generation abilities to the family of score-based generative models.

All combined, we achieved an FID of 2.20 and an Inception score of 9.89 for unconditional generation on CIFAR-10, as well as high-fidelity generation of 1024px CelebA-HQ images. In addition, we obtained a likelihood value of 2.99 bits/dim on uniformly dequantized CIFAR-10 images.
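For readers unfamiliar with the unit, bits/dim is the negative log-likelihood in base 2 divided by the number of data dimensions. The helper below is a minimal sketch of that conversion, not code from this repo. It assumes the negative log-likelihood is given in nats and is already measured on the scale at which the result is reported, so no extra offset for data rescaling is applied.

```python
import math


def nats_to_bits_per_dim(nll_nats, num_dims):
    """Convert a negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))


# CIFAR-10 images have 3 channels of 32 x 32 pixels.
CIFAR10_DIMS = 3 * 32 * 32

# Round trip: the total NLL (in nats) that corresponds to 2.99 bits/dim.
nll = 2.99 * CIFAR10_DIMS * math.log(2.0)
bpd = nats_to_bits_per_dim(nll, CIFAR10_DIMS)
print(f"{nll:.1f} nats over {CIFAR10_DIMS} dims = {bpd:.2f} bits/dim")
```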

What does this code do?

Aside from the NCSN++ and DDPM++ models in our paper, this codebase also re-implements many previous score-based models all in one place, including NCSN from Generative Modeling by Estimating Gradients of the Data Distribution, NCSNv2 from Improved Techniques for Training Score-Based Generative Models, and DDPM from Denoising Diffusion Probabilistic Models.

It supports training new models and evaluating the sample quality and likelihoods of existing models. We carefully designed the code to be modular and easily extensible to new SDEs, predictors, or correctors.

How to run the code

Dependencies

Run the following to install a subset of the Python packages required by our code:

pip install -r requirements.txt

Usage

Train and evaluate our models through main.py.

main.py:
  --config: Training configuration.
    (default: 'None')
  --eval_folder: The folder name for storing evaluation results
    (default: 'eval')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
  • config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.

  • workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.

  • eval_folder is the name of a subfolder in workdir that stores all artifacts of the evaluation process, like meta checkpoints for pre-emption prevention, image samples, and numpy dumps of quantitative results.

  • mode is either "train" or "eval". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in workdir. When set to "eval", it can do an arbitrary combination of the following:

    • Evaluate the loss function on the test / validation dataset.

    • Generate a fixed number of samples and compute their Inception score, FID, or KID.

    • Compute the log-likelihood on the training or test dataset.

    These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
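As an illustration of how the flags above combine, here is a small helper that assembles an evaluation command line. It is a sketch only: build_eval_command is a hypothetical helper that does not exist in this repo, and the config path and working directory in the usage lines are placeholders to substitute with your own. Only flags named in this README are used.

```python
import shlex


def build_eval_command(config_path, workdir, eval_folder="eval",
                       enable_sampling=False, enable_bpd=False, dataset=None):
    """Assemble the argv list for an evaluation run of main.py.

    Hypothetical helper for illustration; it only builds the command and
    does not execute it.
    """
    cmd = [
        "python", "main.py",
        f"--config={config_path}",
        f"--workdir={workdir}",
        "--mode=eval",
        f"--eval_folder={eval_folder}",
    ]
    if enable_sampling:
        # Generate samples and evaluate sample quality.
        cmd.append("--config.eval.enable_sampling")
    if enable_bpd:
        # Compute log-likelihoods, optionally choosing the dataset split.
        cmd.append("--config.eval.enable_bpd")
        if dataset is not None:
            if dataset not in ("train", "test"):
                raise ValueError("dataset must be 'train' or 'test'")
            cmd.append(f"--config.eval.dataset={dataset}")
    return cmd


# Placeholder paths: substitute a real file from configs/ and your own workdir.
command = build_eval_command("configs/my_config.py", "runs/my_experiment",
                             enable_bpd=True, dataset="test")
print(shlex.join(command))
```

Building the command as a list keeps each flag a separate argument, so it can be passed to subprocess.run without shell quoting problems.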

How to extend the code

  • New SDEs: inherit from the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional and defaults to Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
  • New predictors: inherit from the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be used directly in sampling.get_pc_sampler for Predictor-Corrector sampling, and in all other controllable generation methods in controllable_generation.py.
  • New correctors: inherit from the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be used directly in sampling.get_pc_sampler, and in all other controllable generation methods in controllable_generation.py.

Pretrained checkpoints

Link: https://drive.google.com/drive/folders/10pQygNzF7hOOLwP3q8GiNxSnFRpArUxQ?usp=sharing

You may find two checkpoints for some models. The first checkpoint (with a smaller number) is the one for which we reported FID scores in Table 3. The second checkpoint (with a larger number) is the one for which we reported likelihood values and FIDs of black-box ODE samplers in Table 2. The former corresponds to the smallest FID during the course of training (evaluated every 50k iterations). The latter is the last checkpoint during training.

Demonstrations and tutorials

  • Load our pretrained checkpoints and play with sampling, likelihood computation, and controllable synthesis

  • Tutorial of score-based generative models in JAX + FLAX

  • Tutorial of score-based generative models in PyTorch

References

If you find the code useful for your research, please consider citing

@inproceedings{
  song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}

This work is built upon some previous papers which might also interest you:

  • Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
  • Song, Yang, and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
  • Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.