Codebase for Diffusion Models Beat GANS on Image Synthesis.

Overview

guided-diffusion

This is the codebase for Diffusion Models Beat GANS on Image Synthesis.

This repository is based on openai/improved-diffusion, with modifications for classifier conditioning and architecture improvements.

Download pre-trained models

We have released checkpoints for the main models in the paper. Before using these models, please review the corresponding model card to understand the intended use and limitations of these models.

Here are the download links for each model checkpoint:

Sampling from pre-trained models

To sample from these models, you can use the classifier_sample.py, image_sample.py, and super_res_sample.py scripts. Here, we provide flags for sampling from all of these models. We assume that you have downloaded the relevant model checkpoints into a folder called models/.

For these examples, we will generate 100 samples with batch size 4. Feel free to change these values.

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 250"

Classifier guidance

Note for these sampling runs that you can set --classifier_scale 0 to sample from the base diffusion model. You may also use the image_sample.py script instead of classifier_sample.py in that case.

  • 64x64 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/64x64_classifier.pt --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS
  • 128x128 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 0.5 --classifier_path models/128x128_classifier.pt --model_path models/128x128_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 1.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS
  • 256x256 model (unconditional):
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion.pt $SAMPLE_FLAGS
  • 512x512 model:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 512 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True"
python classifier_sample.py $MODEL_FLAGS --classifier_scale 4.0 --classifier_path models/512x512_classifier.pt --model_path models/512x512_diffusion.pt $SAMPLE_FLAGS

Upsampling

For these runs, we assume you have some base samples in a file 64_samples.npz or 128_samples.npz for the two respective models.

  • 64 -> 256:
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --large_size 256  --small_size 64 --learn_sigma True --noise_schedule linear --num_channels 192 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/64_256_upsampler.pt --base_samples 64_samples.npz $SAMPLE_FLAGS
  • 128 -> 512:
MODEL_FLAGS="--attention_resolutions 32,16 --class_cond True --diffusion_steps 1000 --large_size 512 --small_size 128 --learn_sigma True --noise_schedule linear --num_channels 192 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python super_res_sample.py $MODEL_FLAGS --model_path models/128_512_upsampler.pt $SAMPLE_FLAGS --base_samples 128_samples.npz

LSUN models

These models are class-unconditional and correspond to a single LSUN class. Here, we show how to sample from lsun_bedroom.pt, but the other two LSUN checkpoints should work as well:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_bedroom.pt $SAMPLE_FLAGS

You can sample from lsun_horse_nodropout.pt by changing the dropout flag:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
python image_sample.py $MODEL_FLAGS --model_path models/lsun_horse_nodropout.pt $SAMPLE_FLAGS

Note that for these models, the best samples result from using 1000 timesteps:

SAMPLE_FLAGS="--batch_size 4 --num_samples 100 --timestep_respacing 1000"

Results

This table summarizes our ImageNet results for pure guided diffusion models:

Dataset FID Precision Recall
ImageNet 64x64 2.07 0.74 0.63
ImageNet 128x128 2.97 0.78 0.59
ImageNet 256x256 4.59 0.82 0.52
ImageNet 512x512 7.72 0.87 0.42

This table shows the best results for high resolutions when using upsampling and guidance together:

Dataset FID Precision Recall
ImageNet 256x256 3.94 0.83 0.53
ImageNet 512x512 3.85 0.84 0.53

Finally, here are the unguided results on individual LSUN classes:

Dataset FID Precision Recall
LSUN Bedroom 1.90 0.66 0.51
LSUN Cat 5.57 0.63 0.52
LSUN Horse 2.57 0.71 0.55

Training models

Training diffusion models is described in the parent repository. Training a classifier is similar. We assume you have put training hyperparameters into a TRAIN_FLAGS variable, and classifier hyperparameters into a CLASSIFIER_FLAGS variable. Then you can run:

mpiexec -n N python scripts/classifier_train.py --data_dir path/to/imagenet $TRAIN_FLAGS $CLASSIFIER_FLAGS

Make sure to divide the batch size in TRAIN_FLAGS by the number of MPI processes you are using.

Here are flags for training the 128x128 classifier. You can modify these for training classifiers at other resolutions:

TRAIN_FLAGS="--iterations 300000 --anneal_lr True --batch_size 256 --lr 3e-4 --save_interval 10000 --weight_decay 0.05"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True"

For sampling from a 128x128 classifier-guided model, 25 step DDIM:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --image_size 128 --learn_sigma True --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_scale 1.0 --classifier_use_fp16 True"
SAMPLE_FLAGS="--batch_size 4 --num_samples 50000 --timestep_respacing ddim25 --use_ddim True"
mpiexec -n N python scripts/classifier_sample.py \
    --model_path /path/to/model.pt \
    --classifier_path path/to/classifier.pt \
    $MODEL_FLAGS $CLASSIFIER_FLAGS $SAMPLE_FLAGS

To sample for 250 timesteps without DDIM, replace --timestep_respacing ddim25 to --timestep_respacing 250, and replace --use_ddim True with --use_ddim False.

Owner
Katherine Crowson
AI/generative artist.
Katherine Crowson
Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch

Perceiver - Pytorch Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch Install $ pip install perceiver-pytorch Usage

Phil Wang 876 Dec 29, 2022
Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction

GraviCap Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction. Gravity-Aware Monocular 3D Human-Object

Rishabh Dabral 15 Dec 09, 2022
Face uncertainty quantification or estimation using PyTorch.

Face-uncertainty-pytorch This is a demo code of face uncertainty quantification or estimation using PyTorch. The uncertainty of face recognition is af

Kaen 3 Sep 16, 2022
Official implementation of "A Shared Representation for Photorealistic Driving Simulators" in PyTorch.

A Shared Representation for Photorealistic Driving Simulators The official code for the paper: "A Shared Representation for Photorealistic Driving Sim

VITA lab at EPFL 7 Oct 13, 2022
A fast Evolution Strategy implementation in Python

Evostra: Evolution Strategy for Python Evolution Strategy (ES) is an optimization technique based on ideas of adaptation and evolution. You can learn

Mika 251 Dec 08, 2022
Python with OpenCV - MediaPip Framework Hand Detection

Python HandDetection Python with OpenCV - MediaPip Framework Hand Detection Explore the docs ยป Contact Me About The Project It is a Computer vision pa

2 Jan 07, 2022
STARCH compuets regional extreme storm physical characteristics and moisture balance based on spatiotemporal precipitation data from reanalysis or climate model data.

STARCH (Storm Tracking And Regional CHaracterization) STARCH computes regional extreme storm physical and moisture balance characteristics based on sp

Onosama 7 Oct 20, 2022
Official project repository for 'Normality-Calibrated Autoencoder for Unsupervised Anomaly Detection on Data Contamination'

NCAE_UAD Official project repository of 'Normality-Calibrated Autoencoder for Unsupervised Anomaly Detection on Data Contamination' Abstract In this p

Jongmin Andrew Yu 2 Feb 10, 2022
Hough Transform and Hough Line Transform Using OpenCV

Hough transform is a feature extraction method for detecting simple shapes such as circles, lines, etc in an image. Hough Transform and Hough Line Transform is implemented in OpenCV with two methods;

Happy N. Monday 3 Feb 15, 2022
Python script that takes an Impulse response .wav and a input .wav to demonstrate audio convolution.

convolver Python script that takes an Impulse response .wav and a input .wav to demonstrate audio convolution. Created by Sean Higley

Sean Higley 1 Feb 23, 2022
KoCLIP: Korean port of OpenAI CLIP, in Flax

KoCLIP This repository contains code for KoCLIP, a Korean port of OpenAI's CLIP. This project was conducted as part of Hugging Face's Flax/JAX communi

Jake Tae 100 Jan 02, 2023
Text and code for the forthcoming second edition of Think Bayes, by Allen Downey.

Think Bayes 2 by Allen B. Downey The HTML version of this book is here. Think Bayes is an introduction to Bayesian statistics using computational meth

Allen Downey 1.5k Jan 08, 2023
A unified framework to jointly model images, text, and human attention traces.

connect-caption-and-trace This repository contains the reference code for our paper Connecting What to Say With Where to Look by Modeling Human Attent

Meta Research 73 Oct 24, 2022
๐Ÿ’Š A 3D Generative Model for Structure-Based Drug Design (NeurIPS 2021)

A 3D Generative Model for Structure-Based Drug Design Coming soon... Citation @inproceedings{luo2021sbdd, title={A 3D Generative Model for Structu

Shitong Luo 118 Jan 05, 2023
Demos of essentia classifiers hosted on replicate.ai

essentia-replicate-demos Demos of Essentia models hosted on replicate.ai's MTG site. The models Check our site for a complete list of the models avail

Music Technology Group - Universitat Pompeu Fabra 12 Nov 14, 2022
Paper: Cross-View Kernel Similarity Metric Learning Using Pairwise Constraints for Person Re-identification

Cross-View Kernel Similarity Metric Learning Using Pairwise Constraints for Person Re-identification T M Feroz Ali, Subhasis Chaudhuri, ICVGIP-20-21

T M Feroz Ali 3 Jun 17, 2022
A Pytorch implementation of MoveNet from Google. Include training code and pre-train model.

Movenet.Pytorch Intro MoveNet is an ultra fast and accurate model that detects 17 keypoints of a body. This is A Pytorch implementation of MoveNet fro

Mr.Fire 241 Dec 26, 2022
CCPD: a diverse and well-annotated dataset for license plate detection and recognition

CCPD (Chinese City Parking Dataset, ECCV) UPdate on 10/03/2019. CCPD Dataset is now updated. We are confident that images in subsets of CCPD is much m

detectRecog 1.8k Dec 30, 2022
Official implementation for the paper: Multi-label Classification with Partial Annotations using Class-aware Selective Loss

Multi-label Classification with Partial Annotations using Class-aware Selective Loss Paper | Pretrained models Official PyTorch Implementation Emanuel

99 Dec 27, 2022
mmfewshot is an open source few shot learning toolbox based on PyTorch

OpenMMLab FewShot Learning Toolbox and Benchmark

OpenMMLab 514 Dec 28, 2022