Code for the paper "Adversarial Generator-Encoder Networks"

Overview

This repository contains code for the paper "Adversarial Generator-Encoder Networks" (AAAI'18) by Dmitry Ulyanov, Andrea Vedaldi and Victor Lempitsky.

Pretrained models

Follow these steps to access the models used to generate the figures in the paper.

  1. First, install the development version of PyTorch 0.2 and make sure Jupyter Notebook is available.

  2. Then download the models with the script:

bash download_pretrained.sh

  3. Run Jupyter Notebook and go through evaluate.ipynb.

Below are example samples and reconstructions for the ImageNet, CelebA and CIFAR-10 datasets, generated with evaluate.ipynb.

Celeba

Figure: CelebA samples and reconstructions.

Cifar10

Figure: CIFAR-10 samples and reconstructions.

Tiny ImageNet

Figure: Tiny ImageNet samples and reconstructions.

Training

Use the age.py script to train a model. The most important parameters are:

  • --dataset: one of [celeba, cifar10, imagenet, svhn, mnist]
  • --dataroot: for datasets included in torchvision, the directory the data will be downloaded to; for the imagenet and celeba datasets, the path to a directory containing train and val folders.
  • --image_size: image resolution in pixels (64 or 32 in the example commands below).
  • --save_dir: path to a folder where checkpoints will be stored.
  • --nz: dimensionality of the latent space.
  • --batch_size: batch size. Default 64.
  • --netG: .py file with the generator definition, searched for in the models directory.
  • --netE: .py file with the encoder definition, searched for in the models directory.
  • --netG_chp: path to a generator checkpoint to load from
  • --netE_chp: path to an encoder checkpoint to load from
  • --nepoch: number of epochs to run.
  • --start_epoch: epoch number to start from. Useful for finetuning.
  • --e_updates: Update plan for encoder. <num steps>;KL_fake:<weight>,KL_real:<weight>,match_z:<weight>,match_x:<weight>.
  • --g_updates: Update plan for generator. <num steps>;KL_fake:<weight>,match_z:<weight>,match_x:<weight>.
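
The update plans are plain strings: the number before ";" is the number of steps, followed by comma-separated name:weight pairs. The sketch below shows one way to read them, assuming each weight multiplies the loss term of the same name and the weighted terms are summed. parse_update_plan and weighted_loss are hypothetical helpers, not functions from age.py.

```python
def parse_update_plan(plan):
    """Split a plan such as '3;KL_fake:1,match_z:1000,match_x:0' into the
    number of steps and a dict of loss weights (hypothetical helper)."""
    steps_part, weights_part = plan.split(';', 1)
    weights = {}
    for item in weights_part.split(','):
        name, value = item.split(':', 1)
        weights[name.strip()] = float(value)
    return int(steps_part), weights


def weighted_loss(weights, losses):
    """Assumption: each weight multiplies the loss term with the same name
    and the weighted terms are summed. Zero-weight terms are skipped, so
    their losses do not need to be computed."""
    return sum(w * losses[name] for name, w in weights.items() if w != 0)


if __name__ == '__main__':
    steps, weights = parse_update_plan('3;KL_fake:1,match_z:1000,match_x:0')
    print(steps, weights)
    print(weighted_loss(weights, {'KL_fake': 0.5, 'match_z': 0.002}))
```

Skipping zero weights mirrors plans such as match_x:0, where that term plays no part in the update.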

Miscellaneous arguments:

  • --workers: number of dataloader workers.
  • --ngf: controls the number of channels in the generator.
  • --ndf: controls the number of channels in the encoder.
  • --beta1: beta1 parameter of the Adam optimizer.
  • --cpu: do not use the GPU.
  • --criterion: how to compute the KL divergence: parametric (param) or non-parametric (nonparam). The parametric criterion fits a Gaussian to the data; the non-parametric one is based on nearest neighbors. Default: param.
  • --KL: which KL divergence to compute: qp or pq. Default: qp.
  • --noise: distribution of the input noise: sphere (uniform on the sphere) or gaussian. Default: sphere.
  • --match_z: reconstruction loss in the latent space: L1, L2 or cos. Default: cos.
  • --match_x: reconstruction loss in the data space: L1, L2 or cos. Default: L1.
  • --drop_lr: the learning rate is dropped every drop_lr epochs.
  • --save_every: controls how often intermediate results are stored. Default 50.
  • --manual_seed: random seed. Default 123.
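
To make the --criterion param, --KL and --noise options more concrete, here is a small pure-Python sketch. parametric_kl fits an independent Gaussian to each latent dimension and evaluates the closed-form KL divergence to a standard normal, averaged over dimensions. Reading qp as KL(fitted Gaussian || N(0, 1)) and pq as the reverse direction, and the averaging over dimensions, are assumptions about how age.py uses these options. sample_sphere draws points uniformly on the unit sphere by normalizing standard Gaussian vectors, a standard construction. Both function names are hypothetical.

```python
import math
import random


def sample_sphere(n, dim, rng=random):
    """Draw n points uniformly on the unit sphere in R^dim by normalizing
    standard Gaussian vectors (hypothetical helper)."""
    points = []
    for _ in range(n):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in v))
        points.append([x / norm for x in v])
    return points


def parametric_kl(samples, direction='qp'):
    """Fit N(m_j, s_j^2) to each dimension j of `samples` (a list of rows)
    and return the KL divergence to N(0, 1), averaged over dimensions.

    qp: KL(N(m, s^2) || N(0, 1)) = (s^2 + m^2) / 2 - log s - 1/2
    pq: KL(N(0, 1) || N(m, s^2)) = (1 + m^2) / (2 s^2) + log s - 1/2
    Assumes every dimension has nonzero variance.
    """
    if direction not in ('qp', 'pq'):
        raise ValueError("direction must be 'qp' or 'pq'")
    n, dim = len(samples), len(samples[0])
    total = 0.0
    for j in range(dim):
        column = [row[j] for row in samples]
        m = sum(column) / n
        s = math.sqrt(sum((x - m) ** 2 for x in column) / n)  # population std
        if direction == 'qp':
            total += (s * s + m * m) / 2 - math.log(s) - 0.5
        else:
            total += (1 + m * m) / (2 * s * s) + math.log(s) - 0.5
    return total / dim
```

Both directions are zero only when the fitted mean is 0 and the fitted standard deviation is 1, but they penalize deviations differently, which is why the choice of direction matters.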

Here are some commands you can start with:

Celeba

Let <data_root> be a directory with two folders, train and val, each containing the images for the corresponding split.

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --lr 0.0002 --nz 64 --batch_size 64 --netG dcgan64px --netE dcgan64px --nepoch 5 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '3;KL_fake:1,match_z:1000,match_x:0'
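
Before a long run it can help to confirm that <data_root> has the layout described above. This is a minimal sketch using only the standard library; check_dataroot is a hypothetical helper, and it only checks that the train and val folders exist, not how age.py reads the images inside them.

```python
import os


def check_dataroot(data_root, splits=('train', 'val')):
    """Return the split folders missing from data_root, in the order given
    (hypothetical helper; an empty list means the layout looks right)."""
    return [s for s in splits if not os.path.isdir(os.path.join(data_root, s))]
```

For example, check_dataroot('/path/to/celeba') returns an empty list when both folders are present, and ['val'] when only train exists.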

It is then beneficial to fine-tune the model with a larger batch size and a stronger matching weight:

python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --start_epoch 5 --lr 0.0002 --nz 64 --batch_size 256 --netG dcgan64px --netE dcgan64px --nepoch 6 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' --g_updates '3;KL_fake:1,match_z:1000,match_x:0' --netE_chp <save_dir>/netE_epoch_5.pth --netG_chp <save_dir>/netG_epoch_5.pth
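
The fine-tuning command resumes from the checkpoints written after epoch 5, named netE_epoch_5.pth and netG_epoch_5.pth inside <save_dir>. If you script several training stages, a helper along these lines can build the resume arguments. finetune_args is hypothetical, and the file-name pattern is taken from the command above rather than from age.py itself.

```python
import os


def finetune_args(save_dir, epoch):
    """Return the age.py arguments that resume from the checkpoints saved
    after `epoch`, using the netE_epoch_<N>.pth / netG_epoch_<N>.pth names
    from the README command (hypothetical helper)."""
    return [
        '--start_epoch', str(epoch),
        '--netE_chp', os.path.join(save_dir, 'netE_epoch_%d.pth' % epoch),
        '--netG_chp', os.path.join(save_dir, 'netG_epoch_%d.pth' % epoch),
    ]


if __name__ == '__main__':
    # Other age.py arguments are omitted here for brevity.
    base = ['python', 'age.py', '--dataset', 'celeba', '--batch_size', '256']
    print(' '.join(base + finetune_args('runs/celeba', 5)))
```

Returning a list rather than a string keeps paths with spaces intact if the command is later passed to subprocess.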

Imagenet

python age.py --dataset imagenet --dataroot /path/to/imagenet_dir/ --save_dir <save_dir> --image_size 32 --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 6 --drop_lr 3 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:2000,match_x:0' --workers 12

It can be beneficial to switch to a batch size of 256 after several epochs.

Cifar10

python age.py --dataset cifar10 --image_size 32 --save_dir <save_dir> --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 150 --drop_lr 40 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:1000,match_x:0'
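
With --nepoch 150 and --drop_lr 40, a step schedule drops the learning rate three times during the run. The sketch below shows such a schedule. The drop factor is not stated in this README, so factor is an assumed parameter (check age.py for the value it actually uses), epochs are counted from 0, and lr_at_epoch and drop_epochs are hypothetical names.

```python
def lr_at_epoch(base_lr, epoch, drop_lr, factor):
    """Step schedule: the rate is multiplied by `factor` (an assumed value)
    once every `drop_lr` epochs; `epoch` is counted from 0."""
    return base_lr * factor ** (epoch // drop_lr)


def drop_epochs(nepoch, drop_lr):
    """Epochs (0-based) at which the learning rate changes."""
    return list(range(drop_lr, nepoch, drop_lr))


if __name__ == '__main__':
    print(drop_epochs(150, 40))
    print(lr_at_epoch(0.0002, 149, 40, factor=0.5))
```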

Tested with Python 2.7.

The implementation is based on the PyTorch DCGAN code.

Citation

If you found this code useful, please cite our paper:

@inproceedings{DBLP:conf/aaai/UlyanovVL18,
  author    = {Dmitry Ulyanov and
               Andrea Vedaldi and
               Victor S. Lempitsky},
  title     = {It Takes (Only) Two: Adversarial Generator-Encoder Networks},
  booktitle = {{AAAI}},
  publisher = {{AAAI} Press},
  year      = {2018}
}