TensorFlow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

Overview

HiT-GAN Official TensorFlow Implementation

HiT-GAN is a Transformer-based generator trained within the Generative Adversarial Network (GAN) framework. It achieves state-of-the-art performance on high-resolution image synthesis. Please see our NeurIPS 2021 paper "Improved Transformer for High-Resolution GANs" for more details.

This implementation is based on TensorFlow 2.x. We use tf.keras layers for building the model and use tf.data for our input pipeline. The model is trained using a custom training loop with tf.distribute on multiple TPUs/GPUs.

Environment setup

We recommend training the model with distributed training on TPUs and evaluating it on GPUs. The code is compatible with TensorFlow 2.x. See requirements.txt for all prerequisites; you can install them with the following command.

pip install -r requirements.txt

ImageNet

Before first use, download ImageNet by following the tensorflow_datasets instructions in the official guide.

Train on ImageNet

To train the model on ImageNet with Cloud TPUs, first read the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.

Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:

TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>

The following command trains a model on ImageNet on a TPUv2 4x4, using the default hyperparameters from our paper:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME
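The README does not spell out how --image_crop_size and --image_crop_proportion interact. A common convention (assumed here, not confirmed by this repository) is to take a square central crop covering image_crop_proportion of the shorter image side and then resize that crop to image_crop_size pixels. The hypothetical helper central_crop_box below is a minimal sketch of that crop-window arithmetic:

```python
def central_crop_box(height, width, crop_proportion):
    """Return (top, left, side) of a square central crop.

    Hypothetical sketch: assumes the crop keeps `crop_proportion` of the
    shorter image side and is centred in the image. This is not taken
    from the repository's input pipeline.
    """
    if not 0.0 < crop_proportion <= 1.0:
        raise ValueError("crop_proportion must be in (0, 1]")
    side = int(min(height, width) * crop_proportion)
    top = (height - side) // 2
    left = (width - side) // 2
    return top, left, side


# Example: a 375x500 image with --image_crop_proportion=0.875.
top, left, side = central_crop_box(375, 500, 0.875)
print(top, left, side)  # 23 86 328; the crop would then be resized to 128x128
```

Under this assumption, --image_crop_proportion=1.0 (as in the CelebA-HQ and FFHQ commands) keeps the whole shorter side, which for square inputs is the entire image.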

To train the model on ImageNet with multiple GPUs, try the following command:

python run.py --mode=train --dataset=imagenet2012 \
  --train_batch_size=256 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --save_every_n_steps=2000 \
  --latent_dim=256 --generator_lr=0.0001 \
  --discriminator_lr=0.0001 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=False

Set train_batch_size according to the number of GPUs used for training. Note that storing Exponential Moving Average (EMA) models is currently not supported on GPUs (hence --use_ema_model=False), so training on GPUs leads to a slight performance drop.
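An EMA model keeps a running average of the generator weights, which is typically smoother than the raw weights at the last step. The sketch below shows only the update rule in pure Python on a flat list of floats; the decay value is illustrative, and the repository's own EMA code (which operates on TensorFlow variables) is not reproduced here. TensorFlow users can get the same rule from tf.train.ExponentialMovingAverage.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Return the exponential moving average after one training step.

    Applies ema <- decay * ema + (1 - decay) * w element-wise.
    `decay=0.999` is an illustrative value, not the repository's setting.
    """
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Toy example: two generator "weights" that stay fixed for three steps.
ema = [0.0, 0.0]
for step_weights in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    ema = ema_update(ema, step_weights, decay=0.5)
print(ema)  # [0.875, 1.75]: the average approaches the current weights
```

A decay close to 1 averages over many steps, which is why the EMA copy is usually the one evaluated rather than the raw generator.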

Evaluate on ImageNet

Run the following command to evaluate the model on GPUs:

python run.py --mode=eval --dataset=imagenet2012 \
  --eval_batch_size=128 --train_steps=1000000 \
  --image_crop_size=128 --image_crop_proportion=0.875 \
  --latent_dim=256 --channel_multiplier=1 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

This command runs evaluation on 8 P100 GPUs. Set eval_batch_size according to the number of GPUs used for evaluation. Also note that train_steps and use_ema_model must match the values used for training.
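Because train_steps and use_ema_model must match training, it can help to derive the evaluation command from the training configuration instead of retyping it. The hypothetical helper eval_command below reuses only flags that appear in this README, and it assumes eval_batch_size is a global batch split evenly across GPUs (so 128 on 8 GPUs means 16 images per GPU); check run.py for the actual semantics.

```python
# Flags shared by the train and eval commands in this README.
SHARED_FLAGS = ("dataset", "train_steps", "image_crop_size",
                "image_crop_proportion", "latent_dim", "channel_multiplier",
                "data_dir", "model_dir")


def eval_command(train_flags, eval_batch_size, num_gpus, use_ema_model):
    """Build a run.py eval command from the flags used for training.

    Hypothetical helper: assumes eval_batch_size is a global batch that
    must divide evenly across `num_gpus`.
    """
    if eval_batch_size % num_gpus != 0:
        raise ValueError(
            f"eval_batch_size={eval_batch_size} is not divisible by {num_gpus} GPUs")
    flags = {"mode": "eval", "eval_batch_size": eval_batch_size}
    flags.update({k: train_flags[k] for k in SHARED_FLAGS})
    flags["use_tpu"] = False
    flags["use_ema_model"] = use_ema_model  # must match the training run
    return " ".join(["python", "run.py"] + [f"--{k}={v}" for k, v in flags.items()])


train_flags = {
    "dataset": "imagenet2012", "train_steps": 1000000,
    "image_crop_size": 128, "image_crop_proportion": 0.875,
    "latent_dim": 256, "channel_multiplier": 1,
    "data_dir": "$DATA_DIR", "model_dir": "$MODEL_DIR",
}
print(eval_command(train_flags, eval_batch_size=128, num_gpus=8,
                   use_ema_model=True))
```

The printed command matches the ImageNet evaluation command above, with the flags in a different order.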

CelebA-HQ

Before first use, download CelebA-HQ by following the tensorflow_datasets instructions in the official guide.

Train on CelebA-HQ

The following command trains a model on CelebA-HQ on a TPUv2 4x4, using the default hyperparameters from our paper for 256x256 resolution:

python run.py --mode=train --dataset=celeb_a_hq/256 \
  --train_batch_size=256 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME
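The README does not describe what --use_consistency_regularization computes. Assuming it refers to consistency regularization in the sense of CR-GAN (Zhang et al., ICLR 2020), which penalizes the discriminator when its output on a real image changes under a semantics-preserving augmentation, the penalty can be sketched in pure Python as below. The weight and the toy logits are illustrative, not values from this repository.

```python
def consistency_regularization(d_real, d_augmented, weight=10.0):
    """CR-GAN style penalty on discriminator outputs.

    Returns weight * mean((D(x) - D(T(x)))**2), where T is an augmentation
    such as a horizontal flip. In CR-GAN this term is added to the
    discriminator loss. `weight=10.0` is an illustrative value.
    """
    if len(d_real) != len(d_augmented):
        raise ValueError("expected one augmented output per real output")
    squared = [(a - b) ** 2 for a, b in zip(d_real, d_augmented)]
    return weight * sum(squared) / len(squared)


# Toy discriminator logits for 4 real images and their augmented copies.
d_real = [0.5, -1.0, 2.0, 0.0]
d_aug = [0.5, -0.5, 1.0, 0.0]
print(consistency_regularization(d_real, d_aug, weight=10.0))  # 3.125
```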

Evaluate on CelebA-HQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=celeb_a_hq/256 \
  --eval_batch_size=128 --train_steps=250000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

FFHQ

Before first use, download the FFHQ TFRecords from the official site and place them in $DATA_DIR.

Train on FFHQ

The following command trains a model on FFHQ on a TPUv2 4x4, using the default hyperparameters from our paper for 256x256 resolution:

python run.py --mode=train --dataset=ffhq/256 \
  --train_batch_size=256 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --save_every_n_steps=1000 \
  --latent_dim=512 --generator_lr=0.00005 \
  --discriminator_lr=0.00005 --channel_multiplier=2 \
  --use_consistency_regularization=True \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=True --master=$TPU_NAME

Evaluate on FFHQ

Run the following command to evaluate the model on 8 P100 GPUs:

python run.py --mode=eval --dataset=ffhq/256 \
  --eval_batch_size=128 --train_steps=500000 \
  --image_crop_size=256 --image_crop_proportion=1.0 \
  --latent_dim=512 --channel_multiplier=2 \
  --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --use_tpu=False --use_ema_model=True

Cite

@inproceedings{zhao2021improved,
  title = {Improved Transformer for High-Resolution {GANs}},
  author = {Long Zhao and Zizhao Zhang and Ting Chen and Dimitris Metaxas and Han Zhang},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2021}
}

Disclaimer

This is not an officially supported Google product.
