Ensembling Off-the-shelf Models for GAN Training

Overview

Vision-aided GAN

video (3m) | website | paper

Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?

We find that pretrained computer vision models can significantly improve GAN performance when used in an ensemble of discriminators. We propose an effective selection mechanism that probes the linear separability between real and fake samples in pretrained model embeddings, chooses the most accurate model, and progressively adds it to the discriminator ensemble. Our method improves GAN training in both limited-data and large-scale settings.

Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
arXiv 2112.09130, 2021

Quantitative Comparison


Our method outperforms recent GAN training methods by a large margin, especially in the limited-sample setting. For LSUN Cat, we achieve an FID similar to that of StyleGAN2 trained on the full dataset while using only 0.7% of the dataset. On the full dataset, our method improves FID by 1.5x to 2x on the Cat, Church, and Horse categories of LSUN.

Example Results

Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sampled latent code.

Interpolation Videos

Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).


Requirements

  • 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See https://pytorch.org/ for PyTorch install instructions.
  • CUDA toolkit 11.0 or later.
  • Python libraries: see requirements.txt.
  • The StyleGAN2 code relies heavily on custom PyTorch extensions. For details, please refer to the stylegan2-ada-pytorch repo.
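
A quick way to verify the environment after installing the requirements:

pip install -r requirements.txt
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"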

Setting up Off-the-shelf Computer Vision models

CLIP(ViT): we modify model.py to return intermediate features of the transformer model. To set it up, follow these steps:

git clone https://github.com/openai/CLIP.git
cp vision-aided-gan/training/clip_model.py CLIP/clip/model.py
cd CLIP
python setup.py install
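
To sanity-check the CLIP install, a one-liner like the following works (the ViT-B/32 variant here is illustrative; it downloads the weights on first use):

python -c "import clip; clip.load('ViT-B/32')"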

DINO(ViT): the model is automatically downloaded from torch hub.

VGG-16: the model is automatically downloaded.
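
For reference, both backbones can also be fetched manually with standard calls (the exact variants used by the training code may differ; these are illustrative):

import torch
import torchvision

# DINO ViT-S/16 from the official torch hub entry.
dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
# ImageNet-pretrained VGG-16 from torchvision.
vgg = torchvision.models.vgg16(pretrained=True)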

Swin-T(MoBY): create a pretrained-models directory and save the downloaded model there.
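
For example (the checkpoint filename is a placeholder):

mkdir -p pretrained-models
mv <downloaded-moby-checkpoint>.pth pretrained-models/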

Swin-T(Object Detection): follow the steps below for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html
python setup.py install

For more details on mmcv installation, please refer here.

Swin-T(Segmentation): follow the steps below for setup. Download the model here and save it in the pretrained-models directory.

git clone https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation.git
cd Swin-Transformer-Semantic-Segmentation
python setup.py install

Face Parsing: download the model here and save it in the pretrained-models directory.

Face Normals: download the model here and save it in the pretrained-models directory.

Pretrained Models

Our final trained models can be downloaded from this link.

To generate images:

python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 --network=<network.pkl>

The output is stored in the out directory, controlled by --outdir. Our generator architecture is the same as StyleGAN2's and can be used in Python code in the same way as described in stylegan2-ada-pytorch.
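
For example, following the stylegan2-ada-pytorch usage pattern (a minimal sketch; assumes an unconditional model and a CUDA device):

import pickle
import torch

with open('network.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda()  # exponential moving average of the generator

z = torch.randn([1, G.z_dim]).cuda()  # random latent code
c = None                              # class labels (None for unconditional models)
img = G(z, c, truncation_psi=1.0)     # NCHW float32 output in [-1, 1]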

Model evaluation:

python calc_metrics.py --network <network.pkl> --metrics fid50k_full --data <dataset> --clean 1

We use the clean-fid library to calculate the FID metric. For LSUN Church and LSUN Horse, we calculate statistics over the full real distribution. For details on calculating the real-distribution statistics, please refer to clean-fid. For the default FID evaluation of StyleGAN2-ADA, use --clean 0.
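
For reference, the full-dataset statistics can be precomputed with clean-fid along these lines (the statistics name and paths are illustrative; see the clean-fid documentation for details):

from cleanfid import fid

# One-time: compute reference statistics over the full real dataset.
fid.make_custom_stats('lsun_church_full', '<path-to>/church_images/', mode='clean')

# Later: score generated images against the stored statistics.
score = fid.compute_fid('<path-to>/generated_images/', dataset_name='lsun_church_full',
                        mode='clean', dataset_split='custom')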

Datasets

Dataset preparation is the same as in stylegan2-ada-pytorch. Example setup for LSUN Church:

LSUN Church

git clone https://github.com/fyu/lsun.git
cd lsun
python3 download.py -c church_outdoor
unzip church_outdoor_train_lmdb.zip
cd ../vision-aided-gan
python dataset_tool.py --source <path-to>/church_outdoor_train_lmdb/ --dest <path-to-datasets>/church1k.zip --max-images 1000  --transform=center-crop --width=256 --height=256

Datasets can be downloaded from their respective websites:

FFHQ, LSUN Categories, AFHQ, AnimalFace Dog, AnimalFace Cat, 100-shot Bridge-of-Sighs

Training new networks

Model selection: returns the computer vision model with the highest linear-probe accuracy for the best-FID checkpoint in a folder, or for the given network file.

python model_selection.py --data mydataset.zip --network  <mynetworkfolder or mynetworkpklfile>
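
As a rough sketch of what this step measures (hypothetical helper, not the repo's actual implementation): rank models by the held-out accuracy of a linear probe trained to separate real from generated samples in each model's embedding space.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def probe_accuracy(real_feats, fake_feats):
    # Label real embeddings 1 and generated embeddings 0.
    X = np.concatenate([real_feats, fake_feats])
    y = np.concatenate([np.ones(len(real_feats)), np.zeros(len(fake_feats))])
    X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.3, random_state=0)
    # Higher held-out accuracy means real and fake are more linearly separable
    # in this embedding, i.e. the model gives stronger discriminator feedback.
    return LogisticRegression(max_iter=1000).fit(X_tr, y_tr).score(X_val, y_val)

# embeddings: model name -> (real_feats, fake_feats), precomputed offline.
# Dummy data here just to make the sketch runnable.
rng = np.random.default_rng(0)
embeddings = {
    'clip-vit': (rng.normal(0.0, 1.0, (500, 512)), rng.normal(0.4, 1.0, (500, 512))),
    'dino-vit': (rng.normal(0.0, 1.0, (500, 384)), rng.normal(0.1, 1.0, (500, 384))),
}
best = max(embeddings, key=lambda name: probe_accuracy(*embeddings[name]))
print('selected model:', best)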

Example training command for training from scratch with a single pretrained network:

python train.py --outdir=training-models/ --data=mydataset.zip --gpus 2 --metrics fid50k_full --kimg 25000 --cfg paper256 --cv input-dino-output-conv_multi_level --cv-loss multilevel_s --augcv ada --ada-target-cv 0.3 --augpipecv bgc --batch 16 --mirror 1 --aug ada --augpipe bgc --snap 25 --warmup 1  

Training options for the vision-aided loss:

  • --cv=input-dino-output-conv_multi_level pretrained network and its configuration.
  • --warmup=0 should be enabled when training from scratch; it introduces our loss after training with 500k images.
  • --cv-loss=multilevel loss to use on the pretrained-model-based discriminator.
  • --augcv=ada performs ADA augmentation on the pretrained-model-based discriminator.
  • --augcv=diffaugment-<policy> performs DiffAugment on the pretrained-model-based discriminator with the given policy (see the example command after this list).
  • --augpipecv=bgc ADA augmentation strategy. Note: cutout is always enabled.
  • --ada-target-cv=0.3 sets the ADA target value for the pretrained-model-based discriminator.
  • --exact-resume=0 enables exact resume, including the optimizer state.
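
For example, a variant of the command above that applies DiffAugment with its standard color,translation,cutout policy to the pretrained-model discriminator (the exact flag syntax can be confirmed via python train.py --help):

python train.py --outdir=training-models/ --data=mydataset.zip --gpus 2 --metrics fid50k_full --kimg 25000 --cfg paper256 --cv input-dino-output-conv_multi_level --cv-loss multilevel_s --augcv diffaugment-color,translation,cutout --batch 16 --mirror 1 --aug ada --augpipe bgc --snap 25 --warmup 1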

Miscellaneous configurations:

  • --appendname='' additional string to append to the training directory name.
  • --wandb-log=0 enables wandb logging.
  • --clean=0 enables FID calculation using clean-fid if the real-distribution statistics are pre-calculated.

Run python train.py --help for more details and the full list of args.

References

@article{kumari2021ensembling,
  title={Ensembling Off-the-shelf Models for GAN Training},
  author={Kumari, Nupur and Zhang, Richard and Shechtman, Eli and Zhu, Jun-Yan},
  journal={arXiv preprint arXiv:2112.09130},
  year={2021}
}

Acknowledgments

We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.
