Ensembling Off-the-shelf Models for GAN Training

Overview

Vision-aided GAN

video (3m) | website | paper

Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism, by probing the linear separability between real and fake samples in pretrained model embeddings, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method can improve GAN training in both limited data and large-scale settings.

Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
arXiv 2112.09130, 2021

Quantitative Comparison


Our method outperforms recent GAN training methods by a large margin, especially in the limited-sample setting. For LSUN Cat, we achieve an FID similar to that of StyleGAN2 trained on the full dataset while using only 0.7% of the dataset. On the full dataset, our method improves FID by 1.5x to 2x on the cat, church, and horse categories of LSUN.

Example Results

Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sampled latent code.

Interpolation Videos

Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).


Requirements

  • 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See https://pytorch.org/ for PyTorch install instructions.
  • CUDA toolkit 11.0 or later.
  • Python libraries: see requirements.txt
  • The StyleGAN2 code relies heavily on custom PyTorch extensions. For details, please refer to the stylegan2-ada-pytorch repo.

Setting up Off-the-shelf Computer Vision models

CLIP(ViT): we modify model.py to return intermediate features of the transformer model. To set it up, follow these steps:

git clone https://github.com/openai/CLIP.git
cp vision-aided-gan/training/clip_model.py CLIP/clip/model.py
cd CLIP
python setup.py install

DINO(ViT): model is automatically downloaded from torch hub.

VGG-16: model is automatically downloaded.

Swin-T(MoBY): Create a pretrained-models directory and save the downloaded model there.

Swin-T(Object Detection): follow the steps below for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
python setup.py install

For more details on mmcv installation, please refer here.

Swin-T(Segmentation): follow the steps below for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation.git
cd Swin-Transformer-Semantic-Segmentation
python setup.py install

Face Parsing: download the model here and save it in the pretrained-models directory.

Face Normals: download the model here and save it in the pretrained-models directory.

Pretrained Models

Our final trained models can be downloaded at this link

To generate images:

python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 --network=<network.pkl>

The output is stored in the out directory, as set by --outdir. Our generator architecture is the same as StyleGAN2's and can be used from Python code in the same way, as described in stylegan2-ada-pytorch.
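As a rough illustration of using the generator from Python, here is a minimal sketch that follows the pattern documented in stylegan2-ada-pytorch. It assumes the snapshot is a pickled dict with a 'G_ema' entry (the stylegan2-ada-pytorch layout) and that the dnnlib and torch_utils packages from the repo are importable. The helper names load_generator and sample_image are hypothetical and not part of this codebase.

```python
import pickle


def load_generator(network_pkl, device="cuda"):
    """Unpickle a network snapshot and return its EMA generator on `device`.

    Assumption: the snapshot uses the stylegan2-ada-pytorch layout (a dict with
    a 'G_ema' entry), and `dnnlib` and `torch_utils` are on PYTHONPATH so that
    unpickling can find the network classes.
    """
    with open(network_pkl, "rb") as f:
        return pickle.load(f)["G_ema"].to(device)


def sample_image(G, seed=0, truncation_psi=1.0, device="cuda"):
    """Draw one latent code and return an NCHW float image tensor."""
    import torch  # imported lazily so this file can be imported without PyTorch

    torch.manual_seed(seed)
    z = torch.randn([1, G.z_dim], device=device)  # latent code
    c = None  # class labels; unused for unconditional models
    with torch.no_grad():
        return G(z, c, truncation_psi=truncation_psi)
```

A call such as sample_image(load_generator("<network.pkl>"), seed=85) would then return one image tensor. Because latents are drawn with torch.randn here, the result for a given seed may not match what generate.py produces for the same seed.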

To evaluate a model:

python calc_metrics.py --network <network.pkl> --metrics fid50k_full --data <dataset> --clean 1

We use the clean-fid library to calculate the FID metric. For LSUN Church and LSUN Horse, we calculate the full real distribution statistics. For details on calculating the real distribution statistics, please refer to clean-fid. For the default FID evaluation of StyleGAN2-ADA, use --clean 0.
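For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. The "real distribution statistics" mentioned above are the mean and covariance of the real-image features. The symbols below are notation introduced here for illustration: $\mu_r, \Sigma_r$ for real images and $\mu_g, \Sigma_g$ for generated images.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower values are better, and precomputing $\mu_r, \Sigma_r$ once avoids re-extracting features from the real dataset at every evaluation.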

Datasets

Dataset preparation is the same as in stylegan2-ada-pytorch. Example setup for LSUN Church:

LSUN Church

git clone https://github.com/fyu/lsun.git
cd lsun
python3 download.py -c church_outdoor
unzip church_outdoor_train_lmdb.zip
cd ../vision-aided-gan
python dataset_tool.py --source <path-to>/church_outdoor_train_lmdb/ --dest <path-to-datasets>/church1k.zip --max-images 1000  --transform=center-crop --width=256 --height=256

Datasets can be downloaded from their respective websites:

FFHQ, LSUN Categories, AFHQ, AnimalFace Dog, AnimalFace Cat, 100-shot Bridge-of-Sighs

Training new networks

Model selection: returns the computer vision model with the highest linear probe accuracy for the best-FID model in a folder or for the given network file.

python model_selection.py --data mydataset.zip --network <mynetworkfolder or mynetworkpklfile>
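To make the selection criterion concrete, here is a minimal standard-library sketch of one selection round: train a logistic-regression head (a linear probe) on each model's frozen real and fake embeddings, measure held-out accuracy, and pick the most accurate model. It is not the code in model_selection.py. The function names, the 50/50 train/validation split, the hyperparameters, and the toy embeddings are all assumptions made for illustration.

```python
import math
import random


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def train_linear_probe(feats, labels, epochs=200, lr=0.1):
    """Fit a logistic-regression head on frozen features with gradient descent."""
    dim, n = len(feats[0]), len(feats)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(feats, labels):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for i in range(dim):
                grad_w[i] += err * x[i]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def probe_accuracy(w, b, feats, labels):
    """Fraction of samples whose real/fake label the linear head predicts correctly."""
    hits = 0
    for x, y in zip(feats, labels):
        pred = 1 if sum(wi * xi for wi, xi in zip(w, x)) + b > 0 else 0
        hits += int(pred == y)
    return hits / len(labels)


def select_model(embeddings, seed=0):
    """embeddings maps a model name to (real_feats, fake_feats).

    Returns (best_name, scores), where scores holds held-out probe accuracy.
    """
    rng = random.Random(seed)
    scores = {}
    for name, (real, fake) in embeddings.items():
        data = [(x, 1) for x in real] + [(x, 0) for x in fake]
        rng.shuffle(data)
        half = len(data) // 2  # assumed 50/50 train/validation split
        train, val = data[:half], data[half:]
        w, b = train_linear_probe([x for x, _ in train], [y for _, y in train])
        scores[name] = probe_accuracy(w, b, [x for x, _ in val], [y for _, y in val])
    return max(scores, key=scores.get), scores


def toy_embeddings(seed=0, n=100, dim=4):
    """Hypothetical stand-in for features extracted by two pretrained models."""
    rng = random.Random(seed)

    def sample(mean):
        return [[rng.gauss(mean, 1.0) for _ in range(dim)] for _ in range(n)]

    return {
        "model_a": (sample(2.0), sample(-2.0)),  # real and fake are well separated
        "model_b": (sample(0.0), sample(0.0)),   # real and fake are indistinguishable
    }


if __name__ == "__main__":
    best, scores = select_model(toy_embeddings())
    print(best, scores)
```

On the toy data, the probe on model_a separates real from fake far better than the probe on model_b, so model_a is selected. In the progressive scheme described in the overview, the selected model would be added to the discriminator ensemble and the selection repeated over the remaining models.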

Example command for training from scratch with a single pretrained network:

python train.py --outdir=training-models/ --data=mydataset.zip --gpus 2 --metrics fid50k_full --kimg 25000 --cfg paper256 --cv input-dino-output-conv_multi_level --cv-loss multilevel_s --augcv ada --ada-target-cv 0.3 --augpipecv bgc --batch 16 --mirror 1 --aug ada --augpipe bgc --snap 25 --warmup 1

Training configuration options for the vision-aided loss:

  • --cv=input-dino-output-conv_multi_level pretrained network and its configuration.
  • --warmup=0 should be enabled (--warmup 1) when training from scratch. When enabled, it introduces our loss after training with 500k images.
  • --cv-loss=multilevel what loss to use on pretrained model based discriminator.
  • --augcv=ada performs ADA augmentation on pretrained model based discriminator.
  • --augcv=diffaugment-<policy> performs DiffAugment on pretrained model based discriminator with the given policy.
  • --augpipecv=bgc ADA augmentation strategy. Note: cutout is always enabled.
  • --ada-target-cv=0.3 adjusts ADA target value for pretrained model based discriminator.
  • --exact-resume=0 enables exact resume along with optimizer state.

Miscellaneous configurations:

  • --appendname='' additional string to append to training directory name.
  • --wandb-log=0 enables wandb logging.
  • --clean=0 enables FID calculation using clean-fid if the real distribution statistics are pre-calculated.

Run python train.py --help for more details and the full list of args.

References

@article{kumari2021ensembling,
  title={Ensembling Off-the-shelf Models for GAN Training},
  author={Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  journal={arXiv preprint arXiv:2112.09130},
  year={2021}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, and Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.
