TensorFlow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

Overview

HiT-GAN Official TensorFlow Implementation

HiT-GAN presents a Transformer-based generator that is trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance for high-resolution image synthesis. See our NeurIPS 2021 paper "Improved Transformer for High-Resolution GANs" for more details.

This implementation is based on TensorFlow 2.x. We use tf.keras layers for building the model and use tf.data for our input pipeline. The model is trained using a custom training loop with tf.distribute on multiple TPUs/GPUs.

Environment setup

We recommend training the model with distributed training on TPUs and evaluating it on GPUs. The code is compatible with TensorFlow 2.x. All prerequisites are listed in requirements.txt and can be installed with the following command:

pip install -r requirements.txt

ImageNet

Before the first run, download ImageNet by following the tensorflow_datasets instructions in the official guide.

Train on ImageNet

To train the model on ImageNet with Cloud TPUs, first check out the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.

Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
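
Before launching a long run, it can help to confirm that these variables are actually visible to a Python process. The following is a minimal sketch and not part of the repository: the helper name `missing_env_vars` is hypothetical, and it only checks that each variable is set and non-empty.

```python
import os

# Variables that the commands in this README expect to be set.
REQUIRED_VARS = ("TPU_NAME", "STORAGE_BUCKET", "DATA_DIR", "MODEL_DIR")


def missing_env_vars(environ=None, required=REQUIRED_VARS):
    """Return the names in `required` that are unset or empty in `environ`."""
    env = os.environ if environ is None else environ
    return [name for name in required if not env.get(name)]


# Example with an explicit mapping instead of the real environment.
example_env = {"TPU_NAME": "my-tpu", "DATA_DIR": ""}
print(missing_env_vars(example_env))
# ['STORAGE_BUCKET', 'DATA_DIR', 'MODEL_DIR']
```

Note that variables assigned without `export` are expanded by the shell when they appear on a command line, but they are not inherited by child processes, so a check like this one only sees them if they are exported.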

The following command can be used to train a model on ImageNet (which reflects the default hyperparameters in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

To train the model on ImageNet with multiple GPUs, try the following command:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=False

Set train_batch_size according to the number of GPUs used for training. Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (hence --use_ema_model=False), so training on GPUs leads to a slight performance drop.
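
For readers unfamiliar with EMA models: the idea is to keep a second copy of the generator weights that is a running average of the trained weights, and to use that copy for evaluation. The sketch below illustrates the update rule on plain Python lists. The function name `ema_update` and the decay values are assumptions chosen for illustration; they are not taken from this repository.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Return the EMA weights after one step: ema = decay * ema + (1 - decay) * w."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Toy example: the averaged weight moves gradually toward the trained weight.
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
print(ema)  # [0.875]
```

A decay close to 1 makes the averaged weights change slowly, which smooths out step-to-step fluctuations of the trained weights.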

Evaluate on ImageNet

Run the following command to evaluate the model on GPUs:

python run.py --mode=eval --dataset=imagenet2012 \
  --eval_batch_size=128 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --latent_dim=256 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

This command evaluates the model on 8 P100 GPUs. Set eval_batch_size according to the number of GPUs used for evaluation. Note also that train_steps and use_ema_model should match the values used for training.
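
As a quick check of the batch-size arithmetic, the sketch below computes the per-device batch size. It assumes that the batch-size flags are global batch sizes that are divided evenly across devices, which is the usual tf.distribute convention but is not stated explicitly here; the helper name `per_replica_batch_size` is hypothetical.

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split a global batch evenly across replicas, rejecting uneven splits."""
    if num_replicas <= 0:
        raise ValueError("num_replicas must be positive")
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible "
            f"by {num_replicas} replicas"
        )
    return global_batch_size // num_replicas


# eval_batch_size=128 on 8 GPUs gives 16 images per GPU.
print(per_replica_batch_size(128, 8))  # 16
```

Under this assumption, changing the number of GPUs while keeping the per-device load constant means scaling the global batch size by the same factor.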

CelebA-HQ

Before the first run, download CelebA-HQ by following the tensorflow_datasets instructions in the official guide.

Train on CelebA-HQ

The following command can be used to train a model on CelebA-HQ (which reflects the default hyperparameters used for the resolution of 256 in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=celeb_a_hq/256 \
  --train_batch_size=256 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on CelebA-HQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=celeb_a_hq/256 \
  --eval_batch_size=128 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

FFHQ

Before the first run, download the FFHQ tfrecords from the official site and place them in $DATA_DIR.

Train on FFHQ

The following command can be used to train a model on FFHQ (which reflects the default hyperparameters used for the resolution of 256 in our paper) on TPUv2 4x4:

python run.py --mode=train --dataset=ffhq/256 \
  --train_batch_size=256 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on FFHQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=ffhq/256 \
  --eval_batch_size=128 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True
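
Because each evaluation command reuses the dataset, crop, and architecture settings of the corresponding training command, one way to avoid mismatches is to derive the evaluation arguments from the training flags. The sketch below is a hypothetical helper and not part of the repository. It uses only flag names that appear in the commands in this README and builds an argument list of the kind accepted by `subprocess.run`.

```python
# Flags that the train and eval commands in this README have in common.
SHARED_FLAGS = (
    "dataset", "train_steps", "image_crop_size", "image_crop_proportion",
    "latent_dim", "channel_multiplier", "data_dir", "model_dir",
)


def eval_args_from_train(train_flags, eval_batch_size, use_ema_model=True):
    """Build the argument list for an eval run that matches a training run."""
    missing = [name for name in SHARED_FLAGS if name not in train_flags]
    if missing:
        raise KeyError("missing training flags: " + ", ".join(missing))
    args = ["python", "run.py", "--mode=eval"]
    # Copy every shared flag unchanged so eval sees the same model settings.
    args += [f"--{name}={train_flags[name]}" for name in SHARED_FLAGS]
    args += [
        f"--eval_batch_size={eval_batch_size}",
        "--use_tpu=False",
        f"--use_ema_model={use_ema_model}",
    ]
    return args


# Training flags from the FFHQ command; the two paths are placeholders that
# are only printed here and would need real values before being executed.
ffhq_train = {
    "dataset": "ffhq/256", "train_batch_size": 256, "train_steps": 500000,
    "image_crop_size": 256, "image_crop_proportion": 1.0,
    "latent_dim": 512, "channel_multiplier": 2,
    "data_dir": "$DATA_DIR", "model_dir": "$MODEL_DIR",
}
print(" ".join(eval_args_from_train(ffhq_train, eval_batch_size=128)))
```

Training-only flags such as train_batch_size are deliberately left out of the result, and use_ema_model stays an explicit argument because it depends on whether the model was trained with EMA enabled.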

Cite

@inproceedings{zhao2021improved,
  title = {Improved Transformer for High-Resolution {GANs}},
  author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}

Disclaimer

This is not an officially supported Google product.
