GAN JAX - A toy project to generate images from GANs with JAX

Related tags

Deep LearningGANJax
Overview

GAN JAX - A toy project to generate images from GANs with JAX

This project aims to bring the power of JAX, a Python framework developped by Google and DeepMind to train Generative Adversarial Networks for images generation.

JAX

JAX logo

JAX is a framework developed by Deep-Mind (Google) that allows to build machine learning models in a more powerful (XLA compilation) and flexible way than its counterpart Tensorflow, using a framework almost entirely based on the nd.array of numpy (but stored on the GPU, or TPU if available). It also provides new utilities for gradient computation (per sample, jacobian with backward propagation and forward-propagation, hessian...) as well as a better seed system (for reproducibility) and a tool to batch complicated operations automatically and efficiently.

Github link: https://github.com/google/jax

GAN

GAN diagram

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, pitting one against the other (thus the adversarial) in order to generate new, synthetic instances of data that can pass for real data. They are used widely in image generation, video generation and voice generation. GANs were introduced in a paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio, in 2014. Referring to GANs, Facebook’s AI research director Yann LeCun called adversarial training the most interesting idea in the last 10 years in ML. (source)

Original paper: https://arxiv.org/abs/1406.2661

Some ideas have improved the training of the GANs by the years. For example:

Deep Convolution GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196

The goal of this project is to implement these ideas in JAX framework.

Installation

You can install JAX following the instruction on JAX - Installation

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In this case you can install JAX using the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Then you can install Tensorflow to benefit from tf.data.Dataset to handle the data and the pre-installed dataset. However, Tensorfow allocate memory of the GPU on use (which is not optimal for running calculation with JAX). Therefore, you should install Tensorflow on the CPU instead of the GPU. Visit this site Tensorflow - Installation with pip to install the CPU-only version of Tensorflow 2 depending on your OS and your Python version.

Exemple with Linux and Python 3.9:

pip install tensorflow -f https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl

Then you can install the other librairies from requirements.txt. It will install Haiku and Optax, two usefull add-on libraries to implement and optimize machine learning models with JAX.

pip install -r requirements.txt

Install CelebA dataset (optional)

To use the CelebA dataset, you need to download the dataset from Kaggle and install the images in the folder img_align_celeba/ in data/CelebA/images. It is recommended to download the dataset from this source because the faces are already cropped.

Note: the other datasets will be automatically installed with keras or tensorflow-datasets.

Quick Start

You can test a pretrained GAN model by using apps/test.py. It will download the model from pretrained models (in pre_trained/) and generate pictures. You can change the GAN to test by changing the path in the script.

You can also train your own GAN from scratch with apps/train.py. To change the parameters of the training, you can change the configs in the script. You can also change the dataset or the type of GAN by changing the imports (there is only one workd to change for each).

Example to train a GAN in celeba (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN

Then you can implement your own GAN and train/test them in your own dataset (by overriding the appropriate functions, check the examples in the repository).

Some results of pre-trained models

- Deep Convolution GAN

  • On MNIST:

DCGAN Cifar10

  • On Cifar10:

DCGAN Cifar10

  • On CelebA (64x64):

DCGAN CelebA-64

- Progressive Growing GAN

  • On MNIST:

  • On Cifar10:

  • On CelebA (64x64):

  • On CelebA (128x128):

Owner
Valentin Goldité
Student at CentraleSupelec (top french Engineer School) specialized in machine learning (Computer Vision, NLP, Audio, RL, Time Analysis).
Valentin Goldité
In this project we predict the forest cover type using the cartographic variables in the training/test datasets.

Kaggle Competition: Forest Cover Type Prediction In this project we predict the forest cover type (the predominant kind of tree cover) using the carto

Marianne Joy Leano 1 Mar 15, 2022
Pytorch implementation of "Training a 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arxiv) This is a Pytorch implementation of our te

蒋子航 383 Dec 27, 2022
Code & Data for Enhancing Photorealism Enhancement

Code & Data for Enhancing Photorealism Enhancement

Intel ISL (Intel Intelligent Systems Lab) 1.1k Jan 08, 2023
Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation

Leveraging Instance-, Image- and Dataset-Level Information for Weakly Supervised Instance Segmentation This paper has been accepted and early accessed

Yun Liu 39 Sep 20, 2022
Node-level Graph Regression with Deep Gaussian Process Models

Node-level Graph Regression with Deep Gaussian Process Models Prerequests our implementation is mainly based on tensorflow 1.x and gpflow 1.x: python

1 Jan 16, 2022
PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

StarEnhancer StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral) Abstract: Image enhancement is a subjective process w

IDKiro 133 Dec 28, 2022
A simple program for training and testing vit

Vit This is a simple program for training and testing vit. Key requirements: torch, torchvision and timm. Dataset I put 5 categories of the cub classi

xiezhenyu 2 Oct 11, 2022
UMT is a unified and flexible framework which can handle different input modality combinations, and output video moment retrieval and/or highlight detection results.

Unified Multi-modal Transformers This repository maintains the official implementation of the paper UMT: Unified Multi-modal Transformers for Joint Vi

Applied Research Center (ARC), Tencent PCG 84 Jan 04, 2023
Official Pytorch Implementation for Splicing ViT Features for Semantic Appearance Transfer presenting Splice

Splicing ViT Features for Semantic Appearance Transfer [Project Page] Splice is a method for semantic appearance transfer, as described in Splicing Vi

Omer Bar Tal 253 Jan 06, 2023
Person Re-identification

Person Re-identification Final project of Computer Vision Table of content Person Re-identification Table of content Students: Proposed method Dataset

Nguyễn Hoàng Quân 4 Jun 17, 2021
LibMTL: A PyTorch Library for Multi-Task Learning

LibMTL LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and AP

765 Jan 06, 2023
Face Alignment using python

Face Alignment Face Alignment using python Input Image Aligned Face Aligned Face Aligned Face Input Image Aligned Face Input Image Aligned Face Instal

Sajjad Aemmi 28 Nov 23, 2022
Time should be taken seer-iously

TimeSeers seers - (Noun) plural form of seer - A person who foretells future events by or as if by supernatural means TimeSeers is an hierarchical Bay

279 Dec 26, 2022
Potato Disease Classification - Training, Rest APIs, and Frontend to test.

Potato Disease Classification Setup for Python: Install Python (Setup instructions) Install Python packages pip3 install -r training/requirements.txt

codebasics 95 Dec 21, 2022
This repository contains an overview of important follow-up works based on the original Vision Transformer (ViT) by Google.

This repository contains an overview of important follow-up works based on the original Vision Transformer (ViT) by Google.

75 Dec 02, 2022
The MATH Dataset

Measuring Mathematical Problem Solving With the MATH Dataset This is the repository for Measuring Mathematical Problem Solving With the MATH Dataset b

Dan Hendrycks 267 Dec 26, 2022
N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting

N-HiTS: Neural Hierarchical Interpolation for Time Series Forecasting Recent progress in neural forecasting instigated significant improvements in the

Cristian Challu 82 Jan 04, 2023
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Phil Wang 12.6k Jan 09, 2023
This code is part of the reproducibility package for the SANER 2022 paper "Generating Clarifying Questions for Query Refinement in Source Code Search".

Clarifying Questions for Query Refinement in Source Code Search This code is part of the reproducibility package for the SANER 2022 paper "Generating

Zachary Eberhart 0 Dec 04, 2021
PyTorch implementation for View-Guided Point Cloud Completion

PyTorch implementation for View-Guided Point Cloud Completion

22 Jan 04, 2023