GAN example for Keras. Cuz MNIST is too small and there should be something more realistic.

Overview

Keras-GAN-Animeface-Character

GAN example for Keras. Cuz MNIST is too small and there should an example on something more realistic.

Some results

Training for 22 epochs

Youtube Video, click on the image

Training for 22 epochs

Loss graph for 5000 mini-batches

Loss graph

1 mini-batch = 64 images. Dataset = 14490, hence 5000 mini-batches is approximately 22 epochs.

Some outputs of 5000th min-batch

Some ouptputs of 5000th mini-batch

Some training images

Some inputs

Useful resources, before you go on

How to run this example

Setup

  • My environment: Python 3.6 + Keras 2.0.4 + Tensorflow 1.x
    • If you are on Keras 2.0.0, you need to update it otherwise BatchNormalization() will cause bug, saying "you need to pass float to input" or something like that from Tensorflow back end.
  • Use virtualenv to initialize a similar environment (python and dependencies):
pip install virtualenv
virtualenv -p <PATH_TO_BIN_DIR>/python3.6 venv
source venv/bin/activate
pip install -r requirements.txt
  • I HATE making a program that has so many command line parameters to pass. Many of the parameters are there in the scripts. Adjust the script as you need. The "main()" function is at the bottom of the script as people do in C/C++
  • Most global parameters are defined in args.py.
    • They are defined as class variables not instance variables so you may have trouble running/training multiple instances of the GAN with different parameters. (which is very unlikely to happen)
  • Download dataset from http://www.nurs.or.jp/~nagadomi/animeface-character-dataset/
    • Extract it to this directory so that the scipt can find ./animeface-character-dataset/thumb/
    • Any dataset should work in principle but GAN is sensitive to hyperparameters and may not work on yours. I tuned the parameters for animeface-character-dataset.

Preprocessing

  • Run the preprocessing script. It saves training time to resize/scale the input than doing those tasks on the fly in the training loop.
    • ./data.py
    • The image, when loaded from PNG files, the RGB values have [0, 255]. (uint8 type). data.py will collect the images, resize the images to 64x64 and scale the RGB values so that they will be in [-1.0, 1.0] range.
    • Data.py will only sample a subset of the dataset if configured to do so. The size of the subset is determined by dataset_sz defined in args.py
    • The images will be written to data.hdf5.
      • Made it small to verify the training is working.
      • You can increase it but you need to adjust the network sizes accordingly.
    • Again, which files to read is defined in the script at the bottom, not by sys.argv.
  • You need a large enough dataset. Otherwise the discriminator will sort of "memorize" the true data and reject all that's generated.

Training

  • Open gan.py then at the bottom, uncomment train_autoenc() if you wish.
    • This is useful for seeing the generator network's capability to reproduce the input.
    • The auto-encoder will be trained on input images.
    • The output will be blurry, as the auto-encoder having mean-squared-error loss. (This is why GAN got invented in the first place!)
  • To run training, modify main() so that train_gan() is uncommented.
  • The script will dump reals.png and fakes.png every 10 epoch so that you can see how the training is going.
  • The training takes a while. For this example on Anime Face dataset, it took about 10000 mini-batches to get good results.
    • If you see only uniform color or "modern art" until 2000 then the training is not working!
  • The script also dumps weights every 10 batches. Utilize them to save training time. Weights before diverging is preferred :) Uncomment load_weights() in train_gan().

Training tips

What I experienced during my training of GAN.

  • As described in GAN Hacks, discriminator should be ahead of the generator so that the generator can be "guided" by the discriminator.
  • If you look at loss graph at https://github.com/osh/KerasGAN, they had gen loss in range of 2 to 4. Their training worked well. The discriminator loss is low, arond 0.1.
  • You'll need trial and error to get the hyper-pameters right so that the training stays in the stable, balanced zone. That includes learning rate of D and G, momentums, etc.
  • The convergence is quite sensitive with LR, beware!
  • If things go well, the discriminator loss for detecting real/fake = dloss0/dloss1 should be less than or around 0.1, which means it is good at telling whether the input is real or fake.
  • If learning rate is too high, the discriminator will diverge and one of the loss will get high and will not fall. Training fails in this case.
  • If you make LR too small, it will only slow the learning and will not prevent other issues such as oscillation. It only needs to be lower than certain threshold that is data dependent.
  • If adjusting LR doesn't work, it could be lack of complexity in the discriminator layer. Add more layers, or some other parameters. It could be anything :( Good luck!
  • On the other hand, generator loss will be relatively higher than discriminator loss. In this script, it oscillates in range 0.1 to 4.
  • If you see any of the D loss staying > 15 (when batch size is 32) the training is screwed.
  • In case of G loss > 15, see if it escapes within 30 batches. If it stays there for too long, it isn't good, I think.
  • In case you're seeing high G loss, it could mean it can't keep up with discriminator. You might need to increase LR. (Must be slower than discriminator though)
  • One final piece of the training I was missing was the parameter in BatchNormalization. I found about it in this link: https://github.com/shekkizh/neuralnetworks.thought-experiments/blob/master/Generative%20Models/GAN/Readme.md
    • Sort of interesting, in PyTorch, momentum parameter for BatchNorm is 0.1, according to the API documents, while in Keras it is 0.99. I'm not sure if 0.1 in PyTorch actually means 1 - 0.1. I didn't look into PyTorch backend implementation.
The codebase for our paper "Generative Occupancy Fields for 3D Surface-Aware Image Synthesis" (NeurIPS 2021)

Generative Occupancy Fields for 3D Surface-Aware Image Synthesis (NeurIPS 2021) Project Page | Paper Xudong Xu, Xingang Pan, Dahua Lin and Bo Dai GOF

xuxudong 97 Nov 10, 2022
iris - Open Source Photos Platform Powered by PyTorch

Open Source Photos Platform Powered by PyTorch. Submission for PyTorch Annual Hackathon 2021.

Omkar Prabhu 137 Sep 10, 2022
Data from "HateCheck: Functional Tests for Hate Speech Detection Models" (Röttger et al., ACL 2021)

In this repo, you can find the data from our ACL 2021 paper "HateCheck: Functional Tests for Hate Speech Detection Models". "test_suite_cases.csv" con

Paul Röttger 43 Nov 11, 2022
Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization

Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization This repository contains the source code for the paper (link wi

Rakuten Group, Inc. 0 Nov 19, 2021
Source code of article "Towards Toxic and Narcotic Medication Detection with Rotated Object Detector"

Towards Toxic and Narcotic Medication Detection with Rotated Object Detector Introduction This is the source code of article: Towards Toxic and Narcot

Woody. Wang 3 Oct 29, 2022
Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification tasks

Uniformer - Pytorch Implementation of Uniformer, a simple attention and 3d convolutional net that achieved SOTA in a number of video classification ta

Phil Wang 90 Nov 24, 2022
Company clustering with K-means/GMM and visualization with PCA, t-SNE, using SSAN relation extraction

RE results graph visualization and company clustering Installation pip install -r requirements.txt python -m nltk.downloader stopwords python3.7 main.

Jieun Han 1 Oct 06, 2022
Omniverse sample scripts - A guide for developing with Python scripts on NVIDIA Ominverse

Omniverse sample scripts ここでは、NVIDIA Omniverse ( https://www.nvidia.com/ja-jp/om

ft-lab (Yutaka Yoshisaka) 37 Nov 17, 2022
Source code for the paper: Variance-Aware Machine Translation Test Sets (NeurIPS 2021 Datasets and Benchmarks Track)

Variance-Aware-MT-Test-Sets Variance-Aware Machine Translation Test Sets License See LICENSE. We follow the data licensing plan as the same as the WMT

NLP2CT Lab, University of Macau 5 Dec 21, 2021
なりすまし検出(anti-spoof-mn3)のWebカメラ向けデモ

FaceDetection-Anti-Spoof-Demo なりすまし検出(anti-spoof-mn3)のWebカメラ向けデモです。 モデルはPINTO_model_zoo/191_anti-spoof-mn3からONNX形式のモデルを使用しています。 Requirement mediapipe

KazuhitoTakahashi 8 Nov 18, 2022
Code for Neural-GIF: Neural Generalized Implicit Functions for Animating People in Clothing(ICCV21)

NeuralGIF Code for Neural-GIF: Neural Generalized Implicit Functions for Animating People in Clothing(ICCV21) We present Neural Generalized Implicit F

Garvita Tiwari 104 Nov 18, 2022
Gym-TORCS is the reinforcement learning (RL) environment in TORCS domain with OpenAI-gym-like interface.

Gym-TORCS Gym-TORCS is the reinforcement learning (RL) environment in TORCS domain with OpenAI-gym-like interface. TORCS is the open-rource realistic

naoto yoshida 400 Dec 27, 2022
PyTorch implementation of Neural View Synthesis and Matching for Semi-Supervised Few-Shot Learning of 3D Pose

Neural View Synthesis and Matching for Semi-Supervised Few-Shot Learning of 3D Pose Release Notes The official PyTorch implementation of Neural View S

Angtian Wang 20 Oct 09, 2022
Libraries, tools and tasks created and used at DeepMind Robotics.

dm_robotics: Libraries, tools, and tasks created and used for Robotics research at DeepMind. Package overview Package Summary Transformations Rigid bo

DeepMind 273 Jan 06, 2023
Simultaneous NMT/MMT framework in PyTorch

This repository includes the codes, the experiment configurations and the scripts to prepare/download data for the Simultaneous Machine Translation wi

<a href=[email protected]"> 37 Sep 29, 2022
This folder contains the python code of UR5E's advanced forward kinematics model.

This folder contains the python code of UR5E's advanced forward kinematics model. By entering the angle of the joint of UR5e, the detailed coordinates of up to 48 points around the robot arm can be c

Qiang Wang 4 Sep 17, 2022
EvDistill: Asynchronous Events to End-task Learning via Bidirectional Reconstruction-guided Cross-modal Knowledge Distillation (CVPR'21)

EvDistill: Asynchronous Events to End-task Learning via Bidirectional Reconstruction-guided Cross-modal Knowledge Distillation (CVPR'21) Citation If y

addisonwang 18 Nov 11, 2022
Scaling and Benchmarking Self-Supervised Visual Representation Learning

FAIR Self-Supervision Benchmark is deprecated. Please see VISSL, a ground-up rewrite of benchmark in PyTorch. FAIR Self-Supervision Benchmark This cod

Meta Research 584 Dec 31, 2022
SalFBNet: Learning Pseudo-Saliency Distribution via Feedback Convolutional Networks

SalFBNet This repository includes Pytorch implementation for the following paper: SalFBNet: Learning Pseudo-Saliency Distribution via Feedback Convolu

12 Aug 12, 2022
This code is part of the reproducibility package for the SANER 2022 paper "Generating Clarifying Questions for Query Refinement in Source Code Search".

Clarifying Questions for Query Refinement in Source Code Search This code is part of the reproducibility package for the SANER 2022 paper "Generating

Zachary Eberhart 0 Dec 04, 2021