A minimal yet resourceful implementation of diffusion models (along with pretrained models + synthetic images for nine datasets)

Overview

Minimal implementation of diffusion models

A minimal implementation of diffusion models with the goal of democratizing the use of synthetic data from these models.

Check out the experimental results section for quantitative numbers on the quality of synthetic data and the FAQs for a broader discussion. We experiment with nine commonly used datasets and release all assets, including models and synthetic data, for each of them.

Requirements: pip install scipy opencv-python. We assume torch and torchvision are already installed.

Structure

main.py  - Train or sample from a diffusion model.
unets.py - UNet-based network architecture for diffusion models.
data.py  - Common datasets and their metadata.
scripts
 ├── train.sh  - Training scripts for all datasets.
 └── sample.sh - Sampling scripts for all datasets.

Training

Use the following command to train the diffusion model on four GPUs.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --epochs 500

We provide the exact script used for training in ./scripts/train.sh.

Sampling

We reuse main.py for sampling, but with the --sampling-only flag. Use the following command to sample 50K images from a pretrained diffusion model.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py \
  --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 \
  --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

We provide the exact script used for sampling in ./scripts/sample.sh.
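To train a downstream classifier on the sampled images, you first need to load them. Below is a minimal sketch of a PyTorch dataset over sampled images, assuming the samples and their class labels are stored together in a NumPy .npz archive; the array names arr_0 (images, uint8, NHWC) and arr_1 (labels) are assumptions for illustration, not a guarantee of this repository's on-disk format.

import numpy as np
import torch
from torch.utils.data import Dataset

class SyntheticImages(Dataset):
    # Minimal dataset over synthetic samples stored in a .npz archive.
    # Assumes uint8 images of shape (N, H, W, 3) under "arr_0" and integer
    # class labels of shape (N,) under "arr_1" (hypothetical array names).
    def __init__(self, npz_path, transform=None):
        data = np.load(npz_path)
        self.images, self.labels = data["arr_0"], data["arr_1"]
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert HWC uint8 to CHW float in [0, 1].
        img = torch.from_numpy(self.images[idx]).permute(2, 0, 1).float() / 255.0
        if self.transform is not None:
            img = self.transform(img)
        return img, int(self.labels[idx])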

How useful is synthetic data from diffusion models? 🤔

Takeaway: Across all datasets, training only on synthetic data suffices to achieve a competitive classification score on real data.

Goal: Our goal is not only to measure the photo-realism of synthetic images but also to measure how well synthetic images cover the data distribution, i.e., how diverse the synthetic data is. Note that a generative model, most commonly a GAN, can generate high-quality images yet still fail to generate diverse ones.

Choice of datasets: We use nine datasets commonly used in image recognition. The goal of using multiple datasets was to capture enough diversity in terms of the number of samples, the number of classes, and coarse- vs. fine-grained classification. In addition, by using a common setup across datasets, we can test the success of diffusion models without any assumptions about the dataset.

Diffusion model: For each dataset, we train a class-conditional diffusion model. We choose a modestly sized network and train it for a limited number of hours on a 4xA4000 cluster, as reflected in the training times in the table below. Next, we sample 50,000 synthetic images from the diffusion model.

Metric to measure synthetic data quality: We train one ResNet-50 classifier on only real images and another on only synthetic images, and measure their accuracy on the validation set of real images. This metric is also referred to as the classification accuracy score, and it provides a unified way to measure both the quality and the diversity of synthetic data across datasets.
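As a concrete sketch of this metric: train a classifier on synthetic images only, then evaluate it on the real validation set. The helper below computes that accuracy; real_val_loader is a placeholder name for a DataLoader over the real validation images.

import torch

@torch.no_grad()
def classification_accuracy_score(model, real_val_loader, device="cuda"):
    # Accuracy on real validation images of a classifier that was
    # trained on synthetic (or real) images only.
    model.eval()
    correct = total = 0
    for images, labels in real_val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return 100.0 * correct / total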

Released assets for each dataset: pre-trained diffusion models, 50,000 synthetic images per dataset, and downstream classifiers trained on real-only or synthetic-only data.

Table 1: Training images and Classes refer to the number of training images and the number of classes in the dataset. Training time is the time taken to train the diffusion model. Real only is the test-set accuracy of a ResNet-50 model trained on only real training images. Synthetic only is the test accuracy of a ResNet-50 model trained on only the 50K synthetic images.

Dataset          Training images  Classes  Training time (hours)  Real only  Synthetic only
MNIST            60,000           10       2.1                    99.6       99.0
MNIST-M          60,000           10       5.3                    99.3       97.3
CIFAR-10         50,000           10       10.7                   93.8       87.3
Skin Cancer*     33,126           2        19.1                   69.7       64.1
AFHQ             14,630           3        8.6                    97.9       98.7
CelebA           109,036          4        12.8                   90.1       88.9
Stanford Cars    8,144            196      7.4                    33.7       76.6
Oxford Flowers   2,040            102      6.0                    29.3       76.3
Traffic signs    39,252           43       8.3                    96.6       96.1

* Due to heavy class imbalance, we use AUROC to measure classification performance.

Note: Except for CIFAR-10, MNIST, MNIST-M, and GTSRB, we use 64x64 image resolution for all datasets. The key reason for using a lower resolution was to reduce the computational resources needed to train the diffusion models.

Discussion: Across most datasets, training only on synthetic data achieves performance competitive with training on real data. This shows that the synthetic data 1) consists of high-quality images, otherwise the classifier would not have learned much from them, and 2) has high coverage of the data distribution, otherwise the model trained on synthetic data would not perform well across the whole test set. Moreover, synthetic data has a unique advantage: we can easily generate a very large amount of it. This difference is clearly visible in the low-data regime (the Flowers and Cars datasets), where training on synthetic data (50K images) achieves much better performance than training on real data, which has fewer than 10K images. A more principled investigation of sample complexity, i.e., performance vs. number of synthetic images, is available in one of my previous papers (Fig. 9).

FAQs

Q. Why use diffusion models?
A. This question is quite broad and has multiple answers. 1) Diffusion models are easy to train: unlike GANs, there are no instabilities in the optimization process. 2) Their mode coverage is excellent, while at the same time the generated images remain quite photorealistic. 3) The training pipeline is consistent across datasets, i.e., it makes no assumptions about the data. For all datasets above, the only parameter we changed was the amount of training time.

Q. Is synthetic data from diffusion models much different from other generative models, in particular GANs?
A. As mentioned in the previous answer, synthetic data from diffusion models has much higher coverage than that from GANs, while having similar image quality. Check out this previous paper by Prafulla Dhariwal and Alex Nichol, where they provide extensive results supporting this claim. In the regime of robust training, you can find a more quantitative comparison of diffusion models with multiple GANs in one of my previous papers.

Q. Why is the classification accuracy on some datasets so low (e.g., Flowers), even when training with real data?
A. For several reasons, the current classification numbers aren't meant to be competitive with the state of the art. 1) We don't tune any hyperparameters across datasets. For each dataset, we train a ResNet-50 model with a 0.1 learning rate, 1e-4 weight decay, 0.9 momentum, and cosine learning rate decay. 2) Instead of full resolution (commonly 224x224), we use low-resolution images (64x64), which makes classification harder.
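For reference, this fixed recipe maps to a standard SGD setup in PyTorch. The sketch below assumes torchvision's stock ResNet-50 and an illustrative epoch count; it is a sketch of the recipe described above, not the exact training script used here.

import torch
import torchvision

epochs, num_classes = 100, 10  # illustrative values, not tuned per dataset

# Fixed recipe from the answer above: SGD with lr=0.1,
# momentum=0.9, weight decay 1e-4, and cosine learning rate decay.
model = torchvision.models.resnet50(num_classes=num_classes)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
# Call scheduler.step() once per epoch after the training loop body.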

Q. Using only synthetic data, how can we further improve the test accuracy on real data?
A. Diffusion models benefit tremendously from scaling up the training setup. One can do so by increasing the network width (base_width) and training the network for more epochs (2-4x).

References

This implementation was originally motivated by the original implementation of diffusion models by Jonathan Ho. I followed the recent PyTorch implementation by OpenAI for common design choices in diffusion models.

The experiments testing the potential of synthetic data from diffusion models are inspired by one of my previous papers. We found that using synthetic data from a diffusion model alone surpasses the benefits of multiple algorithmic innovations in robust training, which is one of the simply stated yet extremely hard problems for neural networks. The next step is to repeat the Table 1 experiments, but this time with robust training.

Visualizing real and synthetic images

For each dataset, we plot real images on the left and synthetic images on the right. Each row corresponds to a unique class, and the classes for real and synthetic data are identical.

MNIST

MNIST-M

CIFAR-10

GTSRB

Celeb-A

AFHQ

Cars

Flowers

Melanoma (Skin cancer)

Note: Real images for each dataset follow the same license as their respective dataset.
