Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly

License: MIT

Code for the paper Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly. [Preprint]

Tianlong Chen, Yu Cheng, Zhe Gan, Jingjing Liu, Zhangyang Wang.

Overview

Training generative adversarial networks (GANs) with limited data generally results in deteriorated performance and collapsed models. To conquer this challenge, we are inspired by the latest observation of Kalibhat et al. (2020) and Chen et al. (2021d) that one can discover independently trainable and highly sparse subnetworks (a.k.a. lottery tickets) from GANs. Treating this as an inductive prior, we decompose the data-hungry GAN training into two sequential sub-problems:

  • (i) identifying the lottery ticket from the original GAN;
  • (ii) then training the found sparse subnetwork with aggressive data and feature augmentations.

Both sub-problems re-use the same small training set of real images. Such a coordinated framework enables us to focus on lower-complexity and more data-efficient sub-problems, effectively stabilizing training and improving convergence.
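Sub-problem (i) is commonly solved with iterative magnitude pruning (IMP) plus rewinding, as in the lottery-ticket literature. The framework-free sketch below illustrates that procedure only; it is not the code of this repo or of the ticket-finding repo. It assumes global pruning of 20% of the surviving weights per round (consistent with the 0.8 ^ mask_round rule of the BigGAN commands) and rewinding to the initial weights before every round. `prune_round`, `find_ticket`, `remaining_fraction` and the `train_fn` callback are hypothetical names.

```python
import random
from typing import Callable, Dict, List

Weights = Dict[str, List[float]]
Masks = Dict[str, List[int]]


def prune_round(weights: Weights, masks: Masks, rate: float = 0.2) -> Masks:
    """Globally remove `rate` of the still-unpruned weights with the smallest magnitude."""
    alive = sorted(
        (abs(w), name, i)
        for name, ws in weights.items()
        for i, (w, m) in enumerate(zip(ws, masks[name]))
        if m
    )
    n_prune = int(len(alive) * rate)
    new_masks = {name: list(m) for name, m in masks.items()}
    for _, name, i in alive[:n_prune]:
        new_masks[name][i] = 0
    return new_masks


def find_ticket(init: Weights,
                train_fn: Callable[[Weights, Masks], Weights],
                rounds: int, rate: float = 0.2) -> Masks:
    """IMP with rewinding: every round retrains from `init` under the current mask."""
    masks = {name: [1] * len(ws) for name, ws in init.items()}
    for _ in range(rounds):
        trained = train_fn(init, masks)  # hypothetical: train the masked GAN from init
        masks = prune_round(trained, masks, rate)
    return masks


def remaining_fraction(masks: Masks) -> float:
    """Share of weights whose mask entry is 1."""
    total = sum(len(m) for m in masks.values())
    return sum(sum(m) for m in masks.values()) / total


rng = random.Random(0)
init = {"G.fc": [rng.gauss(0, 1) for _ in range(60)],
        "D.fc": [rng.gauss(0, 1) for _ in range(40)]}
# Stand-in for real GAN training: simply zero out the masked weights.
fake_train = lambda w, m: {n: [x * k for x, k in zip(w[n], m[n])] for n in w}
print(remaining_fraction(find_ticket(init, fake_train, rounds=2)))  # 100 -> 80 -> 64 weights kept
```

Sub-problem (ii) then trains only the surviving weights, with the aggressive augmentations described above.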

Methodology

Experiment Results

More experiments can be found in our paper.

Implementation

For the first step (finding lottery tickets in GANs), please refer to this repo.

For the second step (training the GAN tickets toughly), the commands are provided as follows:

Environment for SNGAN

conda install python=3.6
conda install pytorch=1.4.0 -c pytorch
pip install tensorflow-gpu==1.13
pip install imageio
pip install tensorboardx

R.K. Download the FID statistics from Fid_Stat.
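For reference, FID compares Gaussian fits of Inception features of real and generated images. Assuming the downloaded files store the real-data mean \mu_r and covariance \Sigma_r (as in the standard TTUR-style statistics files), only \mu_g and \Sigma_g have to be computed from generated samples; lower is better:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```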

Commands for SNGAN

R.K. Limited data training for SNGAN

  • Dataset: CIFAR-10

Example for full model training on 20% limited data (--ratio 0.2):

python train_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --ratio 0.2
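The exact sampling done by train_less.py is not documented in this README. As an illustrative assumption, --ratio 0.2 can be read as keeping a fixed random 20% of the training set, i.e. 10,000 of CIFAR-10's 50,000 training images. `limited_subset_indices` is a hypothetical helper:

```python
import random
from typing import List


def limited_subset_indices(n_total: int, ratio: float, seed: int = 0) -> List[int]:
    """Hypothetical helper: fixed, seeded random subset with `ratio` of the samples."""
    if not 0.0 < ratio <= 1.0:
        raise ValueError("ratio must be in (0, 1]")
    n_keep = int(n_total * ratio)
    rng = random.Random(seed)  # same seed -> same subset across runs
    return sorted(rng.sample(range(n_total), n_keep))


idx = limited_subset_indices(50_000, 0.2, seed=42)  # CIFAR-10 train split: 50,000 images
print(len(idx))  # 10000
```

With PyTorch, such indices could be wrapped via torch.utils.data.Subset(dataset, idx). Fixing the seed keeps the comparison fair: the full model, the AdvAug run and the ticket run would all see the same limited subset.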

Example for full model training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2
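AdvAug perturbs features adversarially during training. How --gamma and --step are consumed by train_adv_gd_less.py is not spelled out here; this sketch assumes --gamma is the per-step size and --step the number of sign-gradient ascent steps (PGD-style, without projection) on an intermediate feature. `adversarial_feature` and its `grad_fn` argument are hypothetical; in a real GAN the gradient would come from autograd on the G or D loss.

```python
from typing import Callable, List

Vector = List[float]


def _sign(x: float) -> int:
    return (x > 0) - (x < 0)


def adversarial_feature(feature: Vector,
                        grad_fn: Callable[[Vector], Vector],
                        gamma: float = 0.01,
                        steps: int = 1) -> Vector:
    """Push `feature` uphill on the loss with `steps` sign-gradient steps of size `gamma`.

    grad_fn(h) must return d(loss)/dh at h (hypothetical stand-in for autograd).
    """
    delta = [0.0] * len(feature)
    for _ in range(steps):
        perturbed = [f + d for f, d in zip(feature, delta)]
        grad = grad_fn(perturbed)
        delta = [d + gamma * _sign(g) for d, g in zip(delta, grad)]
    return [f + d for f, d in zip(feature, delta)]


# Toy loss 0.5 * ||h||^2, whose gradient is h itself:
# every entry moves 0.01 further away from zero.
h = [1.0, -2.0, 0.5]
print(adversarial_feature(h, lambda v: list(v), gamma=0.01, steps=1))
```

Under this reading, gamma = 0 yields a zero perturbation, which matches the remark in the BigGAN section that setting --gamma to 0 switches AdvAug off.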

Example for sparse model (i.e., GAN tickets) training on 20% limited data (--ratio 0.2) with AdvAug on G and D:

python train_with_masks_adv_gd_less.py -gen_bs 128 -dis_bs 64 --dataset cifar10 --img_size 32 --max_iter 50000 --model sngan_cifar10 --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --exp_name sngan_cifar10_adv_gd_less_0.2 --init-path initial_weights --gamma 0.01 --step 1 --ratio 0.2 --rewind-path <>
  • --rewind-path: the stored path of identified sparse masks
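How train_with_masks_adv_gd_less.py applies the masks loaded from --rewind-path is not shown in this README. A common approach, assumed here, is to re-apply the binary mask after every optimizer update so that pruned weights stay exactly zero; in PyTorch, torch.nn.utils.prune.custom_from_mask achieves a similar effect by masking the parameter on each forward pass. The helpers below are hypothetical, framework-free illustrations:

```python
from typing import Dict, List

Weights = Dict[str, List[float]]
Masks = Dict[str, List[int]]


def apply_masks(weights: Weights, masks: Masks) -> Weights:
    """Zero out pruned entries (mask == 0); keep the rest unchanged."""
    return {name: [w if m else 0.0 for w, m in zip(ws, masks[name])]
            for name, ws in weights.items()}


def masked_sgd_step(weights: Weights, grads: Weights, masks: Masks, lr: float) -> Weights:
    """Plain SGD update followed by re-masking, so the sparse structure is preserved."""
    updated = {name: [w - lr * g for w, g in zip(ws, grads[name])]
               for name, ws in weights.items()}
    return apply_masks(updated, masks)


w = {"G.fc": [0.5, -0.3, 0.8]}
m = {"G.fc": [1, 0, 1]}
g = {"G.fc": [0.1, 0.2, -0.1]}
# The middle weight stays 0.0 even though its gradient is nonzero.
print(masked_sgd_step(apply_masks(w, m), g, m, lr=0.1))
```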

Environment for BigGAN

conda env create -f environment.yml -n studiogan

Commands for BigGAN

R.K. Limited data training for BigGAN

  • Dataset: TINY ILSVRC

Example:

python main_ompg.py -t -e -c ./configs/TINY_ILSVRC2012/BigGAN_adv.json --eval_type valid --seed 42 --mask_path checkpoints/BigGAN-train-0.1 --mask_round 2 --reduce_train_dataset 0.1 --gamma 0.01 
  • --mask_path: the stored path of identified sparse masks
  • --mask_round: the pruning round; the fraction of remaining weights = 0.8 ^ mask_round
  • --reduce_train_dataset: the fraction of the training data used (e.g., 0.1 = 10%)
  • --gamma: hyperparameter for AdvAug. Set it to 0 to disable AdvAug
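Because each pruning round keeps 80% of the surviving weights, 0.8 ^ mask_round is the fraction of weights that remain, and the sparsity is 1 minus that. A quick check reproduces the percentages listed under Pre-trained Models (mask_round 2 as in the Tiny-ImageNet command, 9 as in the CIFAR100 command); reading the SNGAN ticket (10.74%) as a tenth round is an inference from the same formula. `remaining_weights` is a hypothetical helper:

```python
def remaining_weights(mask_round: int, keep_per_round: float = 0.8) -> float:
    """Fraction of weights left after `mask_round` rounds of 20% pruning."""
    return keep_per_round ** mask_round


for r in (2, 9, 10):
    kept = remaining_weights(r)
    print(f"mask_round={r}: {kept:.2%} remaining, {1 - kept:.2%} sparsity")
# mask_round=2: 64.00% remaining, 36.00% sparsity
# mask_round=9: 13.42% remaining, 86.58% sparsity
# mask_round=10: 10.74% remaining, 89.26% sparsity
```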

  • Dataset: CIFAR100

Example:

python main_ompg.py -t -e -c ./configs/CIFAR100_less/DiffAugGAN_adv.json --ratio 0.2 --mask_path checkpoints/diffauggan_cifar100_0.2 --mask_round 9 --seed 42 --gamma 0.01
  • DiffAugGAN_adv.json: indicates that this configuration uses DiffAug.
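DiffAug (DiffAugment, from the data-efficient-gans repo) applies the same kind of differentiable augmentation to both real and generated images before they reach D, in both the D and G updates. The framework-free sketch below shows only random translation with zero padding on single-channel images stored as nested lists; the helper names are hypothetical, and the real implementation works on batched tensors with several augmentation types (color, translation, cutout).

```python
import random
from typing import List

Image = List[List[float]]


def translate(img: Image, dx: int, dy: int) -> Image:
    """Shift the image by (dx, dy) pixels, filling uncovered pixels with 0."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sy, sx = y - dy, x - dx
            if 0 <= sy < h and 0 <= sx < w:
                out[y][x] = img[sy][sx]
    return out


def diff_augment(batch: List[Image], rng: random.Random, max_shift: int = 1) -> List[Image]:
    """Apply an independent random translation to each image in the batch."""
    return [translate(img, rng.randint(-max_shift, max_shift),
                      rng.randint(-max_shift, max_shift)) for img in batch]


rng = random.Random(0)
real, fake = [[[1.0, 2.0], [3.0, 4.0]]], [[[0.5, 0.5], [0.5, 0.5]]]
# Key idea: D always sees augmented reals AND augmented fakes.
d_inputs = diff_augment(real, rng) + diff_augment(fake, rng)
print(translate(real[0], 1, 0))  # [[0.0, 1.0], [0.0, 3.0]]
```

Translation only moves pixel values around, so gradients can flow through it back to the generator; that is what makes the augmentation usable on fake samples during the G update.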

Pre-trained Models

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights

https://www.dropbox.com/sh/7v8hn2859cvm7jj/AACyN8FOkMjgMwy5ibVj61IPa?dl=0

  • SNGAN / CIFAR-10 / 10% Training Data / 10.74% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/gsklrdcjzogqzcd/AAALlIYcWOZuERLcocKIqlEya?dl=0

  • BigGAN / CIFAR-10 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/epuajb1iqn5xma6/AAAD0zwehky1wvV3M3-uesHsa?dl=0

  • BigGAN / CIFAR-100 / 10% Training Data / 13.42% Remaining Weights + DiffAug + AdvAug on G and D

https://www.dropbox.com/sh/y3pqdqee39jpct4/AAAsSebqHwkWmjO_O8Hp0hcEa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model

https://www.dropbox.com/sh/2rmvqwgcjir1p2l/AABNEo0B-0V9ZSnLnKF_OUA3a?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / Full model + AdvAug on G and D

https://www.dropbox.com/sh/pbwjphualzdy2oe/AACZ7VYJctNBKz3E9b8fgj_Ia?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights

https://www.dropbox.com/sh/82i9z44uuczj3u3/AAARsfNzOgd1R9sKuh1OqUdoa?dl=0

  • BigGAN / Tiny-ImageNet / 10% Training Data / 64% Remaining Weights + AdvAug on G and D

https://www.dropbox.com/sh/yknk1joigx0ufbo/AAChMvzCsedejFjY1XxGcaUta?dl=0

Citation

@misc{chen2021ultradataefficient,
      title={Ultra-Data-Efficient GAN Training: Drawing A Lottery Ticket First, Then Training It Toughly}, 
      author={Tianlong Chen and Yu Cheng and Zhe Gan and Jingjing Liu and Zhangyang Wang},
      year={2021},
      eprint={2103.00397},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Acknowledgement

https://github.com/VITA-Group/GAN-LTH

https://github.com/GongXinyuu/sngan.pytorch

https://github.com/VITA-Group/AutoGAN

https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

https://github.com/mit-han-lab/data-efficient-gans

https://github.com/lucidrains/stylegan2-pytorch
