GAN JAX - A toy project to generate images from GANs with JAX

This project aims to bring the power of JAX, a Python framework developed by Google and DeepMind, to training Generative Adversarial Networks (GANs) for image generation.

JAX

JAX is a framework developed by Google and DeepMind that lets you build machine learning models in a more powerful (thanks to XLA compilation) and more flexible way than its counterpart TensorFlow, through an API almost entirely based on NumPy's ndarray (but with arrays stored on the GPU, or TPU if available). It also provides new utilities for gradient computation (per-sample gradients, Jacobians with reverse-mode (backward) and forward-mode differentiation, Hessians, ...), an explicit PRNG key system for random numbers (better for reproducibility), and a tool (vmap) to batch complicated operations automatically and efficiently.
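A minimal sketch of these utilities on a toy linear model (the functions loss and f below are illustrative and not part of this project): jax.random keys for reproducible randomness, jax.grad combined with jax.vmap for per-sample gradients, jax.jit for XLA compilation, and jax.jacrev / jax.jacfwd / jax.hessian for higher-order derivatives.

```python
import jax
import jax.numpy as jnp

# Toy per-example squared error of a linear model (illustrative only).
def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2

# Explicit PRNG keys: the same key always produces the same numbers.
key = jax.random.PRNGKey(0)
key_w, key_x = jax.random.split(key)
w = jax.random.normal(key_w, (3,))
xs = jax.random.normal(key_x, (8, 3))
ys = jnp.zeros(8)

# Gradient of the loss with respect to w, for a single example.
grad_fn = jax.grad(loss)

# vmap maps grad_fn over the batch axis of xs and ys (w is shared): per-sample gradients.
per_example_grads = jax.vmap(grad_fn, in_axes=(None, 0, 0))(w, xs, ys)

# jit compiles the vectorized gradient function with XLA.
fast_per_example_grads = jax.jit(jax.vmap(grad_fn, in_axes=(None, 0, 0)))

# Jacobians (reverse and forward mode) and Hessian.
def f(w):
    return jnp.tanh(w) * w

jac_rev = jax.jacrev(f)(w)   # (3, 3), reverse mode
jac_fwd = jax.jacfwd(f)(w)   # (3, 3), forward mode, same values
hess = jax.hessian(lambda v: jnp.sum(f(v)))(w)  # (3, 3)

print(per_example_grads.shape)  # (8, 3)
```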

GitHub link: https://github.com/google/jax

GAN

Generative adversarial networks (GANs) are algorithmic architectures that use two neural networks, pitting one against the other (hence "adversarial"), in order to generate new, synthetic instances of data that can pass for real data. They are widely used for image, video and voice generation. GANs were introduced in a 2014 paper by Ian Goodfellow and other researchers at the University of Montreal, including Yoshua Bengio. Referring to GANs, Facebook's AI research director Yann LeCun called adversarial training "the most interesting idea in the last 10 years in ML". (source)
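To make the adversarial idea concrete, here is a toy sketch in JAX (not the code used in this repository): a 1-D generator and discriminator, both hypothetical affine maps, trained with alternating gradient steps. The discriminator minimizes the negated objective log D(x) + log(1 - D(G(z))) from the original paper, and the generator uses the non-saturating loss -log D(G(z)) that the paper suggests in practice. Losses are written with logits and jax.nn.log_sigmoid for numerical stability.

```python
import jax
import jax.numpy as jnp

# Toy 1-D setup: "real" data are samples from N(3, 1).
# Both networks are hypothetical affine maps, just to show the adversarial updates.

def generator(g_params, z):
    # Maps noise z to fake samples.
    return g_params["a"] * z + g_params["b"]

def discriminator(d_params, x):
    # Returns a logit: large and positive means "looks real".
    return d_params["w"] * x + d_params["c"]

def d_loss(d_params, g_params, x_real, z):
    real_logits = discriminator(d_params, x_real)
    fake_logits = discriminator(d_params, generator(g_params, z))
    # -[log D(x) + log(1 - D(G(z)))], using log(1 - sigmoid(l)) = log_sigmoid(-l).
    return -jnp.mean(jax.nn.log_sigmoid(real_logits) + jax.nn.log_sigmoid(-fake_logits))

def g_loss(g_params, d_params, z):
    fake_logits = discriminator(d_params, generator(g_params, z))
    # Non-saturating generator loss: -log D(G(z)).
    return -jnp.mean(jax.nn.log_sigmoid(fake_logits))

@jax.jit
def train_step(d_params, g_params, key, lr=0.05):
    k_real, k_z = jax.random.split(key)
    x_real = 3.0 + jax.random.normal(k_real, (128,))
    z = jax.random.normal(k_z, (128,))
    # 1) Discriminator: one gradient-descent step on d_loss (G fixed).
    d_grads = jax.grad(d_loss)(d_params, g_params, x_real, z)
    d_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, d_params, d_grads)
    # 2) Generator: one step on g_loss against the updated discriminator.
    g_grads = jax.grad(g_loss)(g_params, d_params, z)
    g_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, g_params, g_grads)
    return d_params, g_params

key = jax.random.PRNGKey(0)
d_params = {"w": jnp.array(0.0), "c": jnp.array(0.0)}
g_params = {"a": jnp.array(1.0), "b": jnp.array(0.0)}
for _ in range(200):
    key, sub = jax.random.split(key)
    d_params, g_params = train_step(d_params, g_params, sub)
```

Real GANs replace the affine maps with neural networks (for instance built with Haiku) and plain gradient descent with an optimizer such as Adam (for instance from Optax); the alternating structure of the updates stays the same.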

Original paper: https://arxiv.org/abs/1406.2661

Over the years, several ideas have improved GAN training. For example:

Deep Convolutional GAN (DCGAN) paper: https://arxiv.org/abs/1511.06434

Progressive Growing GAN (ProGAN) paper: https://arxiv.org/abs/1710.10196
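The core building block of a DCGAN generator is a fractionally-strided (transposed) convolution followed by batch normalization and ReLU (the output layer uses tanh instead), with weights initialized from a zero-centered normal distribution of standard deviation 0.02, as described in the DCGAN paper. Below is a minimal sketch of one such upsampling block with jax.lax.conv_transpose; the function names, the shapes and the simplified batch norm without learned scale/shift are assumptions, not this repository's implementation.

```python
import jax
import jax.numpy as jnp

def init_kernel(key, shape):
    # DCGAN paper: weights drawn from a zero-centered normal with std 0.02.
    return 0.02 * jax.random.normal(key, shape)

def upsample_block(x, kernel, eps=1e-5):
    # x: (batch, H, W, C_in); kernel: (kh, kw, C_in, C_out) in HWIO layout.
    # A stride-2 transposed ("fractionally-strided") convolution doubles H and W.
    y = jax.lax.conv_transpose(
        x, kernel, strides=(2, 2), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    # Batch normalization over batch and spatial axes (learned scale/shift omitted).
    mean = jnp.mean(y, axis=(0, 1, 2), keepdims=True)
    var = jnp.var(y, axis=(0, 1, 2), keepdims=True)
    y = (y - mean) / jnp.sqrt(var + eps)
    return jax.nn.relu(y)

key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4, 4, 64))   # e.g. a reshaped projection of the latent vector
kernel = init_kernel(key_k, (4, 4, 64, 32))
out = upsample_block(x, kernel)
print(out.shape)  # (2, 8, 8, 32)
```

Stacking such blocks (4x4 -> 8x8 -> 16x16 -> ...) and ending with a tanh layer gives the usual DCGAN generator shape.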

The goal of this project is to implement these ideas in JAX.
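In ProGAN, each new higher-resolution block is faded in smoothly: during the transition, the output is a blend of the upsampled output of the previous stage and the output of the new block, weighted by a factor alpha that grows linearly from 0 to 1. A hedged sketch of that blending (the helper names are hypothetical; upsampling is nearest-neighbour, as in the paper):

```python
import jax.numpy as jnp

def upsample_nearest(x, factor=2):
    # (batch, H, W, C) -> (batch, H*factor, W*factor, C) by repeating pixels.
    x = jnp.repeat(x, factor, axis=1)
    return jnp.repeat(x, factor, axis=2)

def fade_in(old_rgb, new_rgb, alpha):
    # old_rgb: output of the previous (lower-resolution) stage, e.g. 32x32.
    # new_rgb: output of the newly added stage, e.g. 64x64.
    return (1.0 - alpha) * upsample_nearest(old_rgb) + alpha * new_rgb

def alpha_schedule(step, fade_steps):
    # Linear ramp from 0 to 1 over the transition phase, then stays at 1.
    return jnp.clip(step / fade_steps, 0.0, 1.0)
```

At alpha = 0 the network behaves exactly like the previous, smaller model; at alpha = 1 only the new block contributes.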

Installation

You can install JAX by following the instructions on the JAX installation page.

It is strongly recommended to run JAX on Linux with CUDA available (Windows has no stable support yet). In this case you can install JAX using the following command:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
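Once installed, a quick way to check that JAX actually sees the GPU (a sketch; the exact device names printed depend on the JAX version):

```python
import jax
import jax.numpy as jnp

# Devices JAX can see: GPU devices if the CUDA build works, otherwise only the CPU.
print(jax.devices())
# Platform used by default, e.g. "gpu" or "cpu".
print(jax.default_backend())

# Tiny computation on the default device to check that everything runs.
x = jnp.ones((4, 4))
print((x @ x)[0, 0])  # 4.0
```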

You can then install TensorFlow to benefit from tf.data.Dataset for data handling and from its pre-packaged datasets. However, by default TensorFlow allocates most of the GPU memory as soon as it uses the GPU, which is not optimal when running computations with JAX. You should therefore install the CPU-only version of TensorFlow 2: see the TensorFlow pip installation page for the package matching your OS and Python version.

Example with Linux and Python 3.9:

pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.6.0-cp39-cp39-manylinux2010_x86_64.whl

Then install the other libraries from requirements.txt. This installs Haiku and Optax, two useful add-on libraries for building and optimizing machine learning models with JAX.

pip install -r requirements.txt

Install CelebA dataset (optional)

To use the CelebA dataset, download it from Kaggle and place the img_align_celeba/ folder in data/CelebA/images/. Downloading from this source is recommended because the faces are already cropped.
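A small sanity check for the expected folder layout, as a sketch: check_celeba is a hypothetical helper (not part of the repository) that assumes the images are .jpg files placed directly inside data/CelebA/images/img_align_celeba/.

```python
from pathlib import Path

def check_celeba(root="data/CelebA/images"):
    """Hypothetical helper: verify that the CelebA images sit in
    <root>/img_align_celeba/ and return how many .jpg files were found."""
    folder = Path(root) / "img_align_celeba"
    if not folder.is_dir():
        raise FileNotFoundError(f"Expected CelebA images in {folder}/")
    return sum(1 for _ in folder.glob("*.jpg"))

# Example: print(check_celeba()) from the repository root.
```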

Note: the other datasets are downloaded automatically with Keras or tensorflow-datasets.

Quick Start

You can test a pre-trained GAN model with apps/test.py. It loads a model from the pre-trained models (in pre_trained/) and generates pictures. To test a different GAN, change the path in the script.
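Generated samples are usually saved as a single image grid. A minimal sketch of such a helper (make_grid is hypothetical, and it assumes the generator outputs images in [-1, 1], as with a tanh output layer); random arrays stand in for the output of a trained generator:

```python
import numpy as np
import jax

def make_grid(images, n_rows):
    # images: (N, H, W, C) floats in [-1, 1]; N must be divisible by n_rows.
    images = np.asarray(images)
    n, h, w, c = images.shape
    n_cols = n // n_rows
    grid = images.reshape(n_rows, n_cols, h, w, c)
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = grid.transpose(0, 2, 1, 3, 4).reshape(n_rows * h, n_cols * w, c)
    # Map [-1, 1] back to displayable uint8 pixels.
    return ((grid + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

# Stand-in for generator output: 16 random 64x64 RGB images in [-1, 1].
fake_images = np.tanh(np.asarray(jax.random.normal(jax.random.PRNGKey(0), (16, 64, 64, 3))))
grid = make_grid(fake_images, n_rows=4)
print(grid.shape)  # (256, 256, 3)
```

The resulting uint8 array can then be saved, for instance with PIL.Image.fromarray(grid).save("samples.png").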

You can also train your own GAN from scratch with apps/train.py. To change the training parameters, edit the configs in the script. You can also change the dataset or the type of GAN by changing the imports (only one word needs to change for each).

Example to train a GAN on CelebA (64x64):

from utils.data import load_images_celeba_64 as load_images

To train a DCGAN:

from gan.dcgan import DCGAN as GAN
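The loaders in utils.data are not reproduced here. As a hedged illustration of the kind of preprocessing a 64x64 GAN pipeline usually needs, the sketch below (hypothetical helpers to_model_range and iterate_batches) rescales uint8 images to [-1, 1], the range of the tanh output used by DCGAN generators, and shuffles them into batches with an explicit JAX key.

```python
import numpy as np
import jax
import jax.numpy as jnp

def to_model_range(images_uint8):
    # uint8 pixels in [0, 255] -> float32 in [-1, 1].
    return images_uint8.astype(np.float32) / 127.5 - 1.0

def iterate_batches(images, batch_size, key):
    # Shuffle with an explicit JAX key (same key -> same order), drop the incomplete last batch.
    perm = np.asarray(jax.random.permutation(key, len(images)))
    for start in range(0, len(images) - batch_size + 1, batch_size):
        yield jnp.asarray(images[perm[start:start + batch_size]])

# Hypothetical stand-in for 64x64 CelebA images loaded as uint8 arrays.
images = np.random.default_rng(0).integers(0, 256, size=(10, 64, 64, 3), dtype=np.uint8)
batches = list(iterate_batches(to_model_range(images), batch_size=4, key=jax.random.PRNGKey(0)))
print(len(batches), batches[0].shape)  # 2 (4, 64, 64, 3)
```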

You can then implement your own GAN and train/test it on your own dataset by overriding the appropriate functions (see the examples in the repository).

Some results of pre-trained models

- Deep Convolutional GAN

  • On MNIST:

DCGAN MNIST

  • On Cifar10:

DCGAN Cifar10

  • On CelebA (64x64):

DCGAN CelebA-64

- Progressive Growing GAN

  • On MNIST:

  • On Cifar10:

  • On CelebA (64x64):

  • On CelebA (128x128):

Owner
Valentin Goldité
Student at CentraleSupelec (a top French engineering school) specialized in machine learning (Computer Vision, NLP, Audio, RL, Time Analysis).