Data-Efficient GANs with DiffAugment

Generated using only 100 images of Obama, grumpy cats, pandas, the Bridge of Sighs, the Medici Fountain, the Temple of Heaven, without pre-training.

[NEW!] PyTorch training with DiffAugment-stylegan2-pytorch is now available!

[NEW!] Our Colab tutorial is released!

[NEW!] FFHQ training is supported! See the DiffAugment-stylegan2 README.

[NEW!] Time to generate 100-shot interpolation videos with generate_gif.py!

[NEW!] Our DiffAugment-biggan-imagenet repo (for TPU training) is released!

[NEW!] Our DiffAugment-biggan-cifar PyTorch repo is released!

This repository contains our implementation of Differentiable Augmentation (DiffAugment) in both PyTorch and TensorFlow. It can be used to significantly improve the data efficiency of GAN training. We provide DiffAugment-stylegan2 (TensorFlow), DiffAugment-stylegan2-pytorch (PyTorch), and DiffAugment-biggan-cifar (PyTorch) for GPU training, and DiffAugment-biggan-imagenet (TensorFlow) for TPU training.

Low-shot generation without pre-training. With DiffAugment, our model can generate high-fidelity images at 256×256 resolution using only 100 Obama portraits, grumpy cats, or pandas from our collected 100-shot datasets, or 160 cats or 389 dogs from the AnimalFace dataset.

Unconditional generation results on CIFAR-10. StyleGAN2’s performance degrades drastically when it is given less training data. With DiffAugment, we are able to roughly match its FID and outperform its Inception Score (IS) using only 20% of the training data.

Differentiable Augmentation for Data-Efficient GAN Training
Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
MIT, Tsinghua University, Adobe Research, CMU
arXiv

Overview

Overview of DiffAugment for updating D (left) and G (right). DiffAugment applies the augmentation T to both the real sample x and the generated output G(z). When we update G, gradients need to be back-propagated through T (iii), which requires T to be differentiable w.r.t. the input.
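To make the differentiability requirement concrete, here is a toy one-dimensional sketch in plain Python. It is not the repository's code: the linear `discriminator`, the `brightness` and `quantize` augmentations, and the finite-difference helper are all hypothetical names introduced for illustration. It compares how much signal reaches the generator output through a differentiable augmentation versus a non-differentiable one.

```python
# Toy 1-D illustration (not the repository's code). The generator output is a
# single pixel value g, T is an augmentation, and D is a fixed linear score.

def discriminator(x, w=2.0):
    """Hypothetical linear discriminator score D(x) = w * x."""
    return w * x

def brightness(x, shift=0.3):
    """Differentiable augmentation: add a (pre-sampled) brightness offset."""
    return x + shift

def quantize(x, levels=4):
    """Non-differentiable augmentation: round to a few discrete levels."""
    return round(x * levels) / levels

def grad_wrt_generator_output(augment, g, eps=1e-4):
    """Central finite-difference estimate of d D(T(g)) / d g."""
    hi = discriminator(augment(g + eps))
    lo = discriminator(augment(g - eps))
    return (hi - lo) / (2 * eps)

g = 0.4  # pretend this is G(z) for some latent z
grad_brightness = grad_wrt_generator_output(brightness, g)
grad_quantize = grad_wrt_generator_output(quantize, g)

print(grad_brightness)  # approximately w = 2.0: the signal from D reaches G
print(grad_quantize)    # 0.0: rounding blocks the signal
```

With the brightness shift, the derivative of D(T(g)) with respect to g equals the discriminator weight, so G still receives a useful update direction. With rounding, the estimate is zero at this point, so G would receive no learning signal through T.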

Training and Generation with 100 Images

To generate an interpolation video using our pre-trained models:

cd DiffAugment-stylegan2
python generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o obama.gif

or to train a new model:

python run_low_shot.py --dataset=100-shot-obama --num-gpus=4

You may also try out 100-shot-grumpy_cat, 100-shot-panda, 100-shot-bridge_of_sighs, 100-shot-medici_fountain, 100-shot-temple_of_heaven, 100-shot-wuzhen, or pass the folder containing your own training images as the dataset. Please refer to the DiffAugment-stylegan2 README for the dependencies and details.

[NEW!] PyTorch training is now available:

cd DiffAugment-stylegan2-pytorch
python train.py --outdir=training-runs --data=https://data-efficient-gans.mit.edu/datasets/100-shot-obama.zip --gpus=1

DiffAugment for StyleGAN2

To run StyleGAN2 + DiffAugment for unconditional generation on the 100-shot datasets, CIFAR, FFHQ, or LSUN, please refer to the DiffAugment-stylegan2 README or DiffAugment-stylegan2-pytorch for the PyTorch version.

DiffAugment for BigGAN

Please refer to the DiffAugment-biggan-cifar README to run BigGAN + DiffAugment for conditional generation on CIFAR (using GPUs), and the DiffAugment-biggan-imagenet README to run on ImageNet (using TPUs).

Using DiffAugment for Your Own Training

To help you use DiffAugment in your own codebase, we provide portable DiffAugment operations for both TensorFlow and PyTorch in DiffAugment_tf.py and DiffAugment_pytorch.py. In general, DiffAugment can be adopted in any GAN by replacing every D(x) with D(T(x)), where x is a batch of real or fake images, D is the discriminator, and T is the DiffAugment operation. For example,

from DiffAugment_pytorch import DiffAugment
# from DiffAugment_tf import DiffAugment
# If your dataset is as small as ours (e.g., hundreds of images), we recommend
# the strongest policy: Color + Translation + Cutout.
# For large datasets, try a subset of the transformations in ['color', 'translation', 'cutout'].
# You are welcome to explore additional DiffAugment transformations.
policy = 'color,translation,cutout'

...
# Training loop: update D
reals = sample_real_images() # a batch of real images
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
real_scores = Discriminator(DiffAugment(reals, policy=policy))
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating D's loss based on real_scores and fake_scores...
...

...
# Training loop: update G
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating G's loss based on fake_scores...
...
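The schematic loop above can be turned into a small runnable PyTorch sketch. Everything here is an assumption made for illustration: the tiny `G` and `D` networks, the random tensor used in place of real images, the non-saturating logistic loss, the optimizer settings, and the `diff_augment` stand-in, which implements only a brightness shift so that the example runs without the repository. In real use you would import `DiffAugment` from DiffAugment_pytorch.py instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def diff_augment(x, policy=""):
    """Hypothetical stand-in for DiffAugment: a differentiable per-sample
    brightness shift only. Use the function from DiffAugment_pytorch.py in practice."""
    if "color" in policy.split(","):
        shift = torch.rand(x.size(0), 1, 1, 1, device=x.device, dtype=x.dtype) - 0.5
        x = x + shift
    return x

torch.manual_seed(0)
latent_dim, batch = 8, 4
# Toy networks, far smaller than StyleGAN2 or BigGAN.
G = nn.Sequential(nn.Linear(latent_dim, 3 * 8 * 8), nn.Tanh(), nn.Unflatten(1, (3, 8, 8)))
D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
policy = "color"

reals = torch.rand(batch, 3, 8, 8) * 2 - 1  # placeholder for a batch of real images

# Update D: augment both reals and fakes; detach fakes so G is not updated here.
fakes = G(torch.randn(batch, latent_dim)).detach()
real_scores = D(diff_augment(reals, policy))
fake_scores = D(diff_augment(fakes, policy))
d_loss = F.softplus(-real_scores).mean() + F.softplus(fake_scores).mean()
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Update G: gradients flow from D through the augmentation back into G.
fakes = G(torch.randn(batch, latent_dim))
fake_scores = D(diff_augment(fakes, policy))
g_loss = F.softplus(-fake_scores).mean()
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
```

Note that the augmentation is applied in all three places where the discriminator is called, and that the fakes are not detached in the generator update, since that is where gradients must pass through T.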

We have implemented three DiffAugment transformations: Color, Translation, and Cutout.
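To give a rough idea of what such a transformation involves, here is a simplified Cutout-style operation in PyTorch. It is a sketch under stated assumptions rather than the code in DiffAugment_pytorch.py: the function name `rand_cutout_sketch`, the `ratio` default, and the choice to keep the masked square fully inside the image are all illustrative. The key point is that the output is the input multiplied by a mask, so gradients pass through unchanged wherever the mask is 1.

```python
import torch

def rand_cutout_sketch(x, ratio=0.5):
    """Zero out one random square per image. x has shape (N, C, H, W).
    Simplified illustration, not the repository's implementation."""
    n, _, h, w = x.shape
    cut_h, cut_w = int(h * ratio + 0.5), int(w * ratio + 0.5)
    # Random top-left corner per image, chosen so the square stays inside the image.
    top = torch.randint(0, h - cut_h + 1, (n, 1, 1), device=x.device)
    left = torch.randint(0, w - cut_w + 1, (n, 1, 1), device=x.device)
    rows = torch.arange(h, device=x.device).view(1, h, 1)
    cols = torch.arange(w, device=x.device).view(1, 1, w)
    inside = (rows >= top) & (rows < top + cut_h) & (cols >= left) & (cols < left + cut_w)
    mask = (~inside).to(x.dtype).unsqueeze(1)  # shape (N, 1, H, W), shared across channels
    return x * mask

torch.manual_seed(0)
x = torch.ones(2, 3, 8, 8, requires_grad=True)
y = rand_cutout_sketch(x)
y.sum().backward()  # x.grad is 1 outside the cut-out square and 0 inside it
```

Because the random position is sampled independently for each image, every sample in a batch is masked differently, while the operation remains differentiable with respect to `x`.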

Citation

If you find this code helpful, please cite our paper:

@inproceedings{zhao2020diffaugment,
  title={Differentiable Augmentation for Data-Efficient GAN Training},
  author={Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Acknowledgements

We thank NSF Career Award #1943349, MIT-IBM Watson AI Lab, Google, Adobe, and Sony for supporting this research. Research supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC). We thank William S. Peebles and Yijun Li for helpful comments.
