Refactoring dalle-pytorch and taming-transformers for TPU VM

Overview

Text-to-Image Translation (DALL-E) for TPU in Pytorch

Refactoring Taming Transformers and DALLE-pytorch for TPU VM with Pytorch Lightning

Requirements

pip install -r requirements.txt

Data Preparation

Place any image dataset with ImageNet-style directory structure (at least 1 subfolder) to fit the dataset into pytorch ImageFolder.

Training VQVAEs

You can easily test main.py with randomly generated fake data.

python train_vae.py --use_tpus --fake_data

For actual training provide specific directory for train_dir, val_dir, log_dir:

python train_vae.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results]

Training DALL-E

python train_dalle.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results] --vae_path [pretrained vae] --bpe_path [pretrained bpe(optional)]

TODO

  • Refactor Encoder and Decoder modules for better readability
  • Refactor VQVAE2
  • Add Net2Net Conditional Transformer for conditional image generation
  • Refactor, optimize, and merge DALL-E with Net2Net Conditional Transformer
  • Add Guided Diffusion + CLIP for image refinement
  • Add VAE converter for JAX to support dalle-mini
  • Add DALL-E colab notebook
  • Add RBGumbelQuantizer
  • Add HiT

ON-GOING

  • Test large dataset loading on TPU Pods
  • Change current DALL-E code to fully support latest updates from DALLE-pytorch

DONE

  • Add VQVAE, VQGAN, and Gumbel VQVAE(Discrete VAE), Gumbel VQGAN
  • Add VQVAE2
  • Add EMA update for Vector Quantization
  • Debug VAEs (Single TPU Node, TPU Pods, GPUs)
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3028
  • Add DALL-E
  • Debug DALL-E (Single TPU Node, TPU Pods, GPUs)
  • Add WebDataset support
  • Add VAE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add DALLE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add automatic checkpoint saver and resume for sudden (which happens a lot) TPU restart
  • Reimplement EMA VectorQuantizer with nn.Embedding
  • Add DALL-E colab notebook by afiaka87
  • Add Normed Vector Quantizer by GallagherCommaJack
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3068
  • Debug WebDataset functionality

BibTeX

@misc{oord2018neural,
      title={Neural Discrete Representation Learning}, 
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{razavi2019generating,
      title={Generating Diverse High-Fidelity Images with VQ-VAE-2}, 
      author={Ali Razavi and Aaron van den Oord and Oriol Vinyals},
      year={2019},
      eprint={1906.00446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{esser2020taming,
      title={Taming Transformers for High-Resolution Image Synthesis}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.09841},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{ramesh2021zeroshot,
    title   = {Zero-Shot Text-to-Image Generation}, 
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
    year    = {2021},
    eprint  = {2102.12092},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Owner
Kim, Taehoon
Research Scientist & Machine Learning Engineer.
Kim, Taehoon
CRLT: A Unified Contrastive Learning Toolkit for Unsupervised Text Representation Learning

CRLT: A Unified Contrastive Learning Toolkit for Unsupervised Text Representation Learning This repository contains the code and relevant instructions

XiaoMing 5 Aug 19, 2022
Awesome Remote Sensing Toolkit based on PaddlePaddle.

基于飞桨框架开发的高性能遥感图像处理开发套件,端到端地完成从训练到部署的全流程遥感深度学习应用。 最新动态 PaddleRS 即将发布alpha版本!欢迎大家试用 简介 PaddleRS是遥感科研院所、相关高校共同基于飞桨开发的遥感处理平台,支持遥感图像分类,目标检测,图像分割,以及变化检测等常用遥

146 Dec 11, 2022
My implementation of transformers related papers for computer vision in pytorch

vision_transformers This is my personnal repo to implement new transofrmers based and other computer vision DL models I am currenlty working without a

samsja 1 Nov 10, 2021
ACL'2021: LM-BFF: Better Few-shot Fine-tuning of Language Models

LM-BFF (Better Few-shot Fine-tuning of Language Models) This is the implementation of the paper Making Pre-trained Language Models Better Few-shot Lea

Princeton Natural Language Processing 607 Jan 07, 2023
PyTorch Implementation of the paper Learning to Reweight Examples for Robust Deep Learning

Learning to Reweight Examples for Robust Deep Learning Unofficial PyTorch implementation of Learning to Reweight Examples for Robust Deep Learning. Th

Daniel Stanley Tan 325 Dec 28, 2022
Project page for End-to-end Recovery of Human Shape and Pose

End-to-end Recovery of Human Shape and Pose Angjoo Kanazawa, Michael J. Black, David W. Jacobs, Jitendra Malik CVPR 2018 Project Page Requirements Pyt

1.4k Dec 29, 2022
Multi-Objective Reinforced Active Learning

Multi-Objective Reinforced Active Learning Dependencies wandb tqdm pytorch = 1.7.0 numpy = 1.20.0 scipy = 1.1.0 pycolab == 1.2 Weights and Biases O

Markus Peschl 6 Nov 19, 2022
Python implementation of Project Fluent

Project Fluent This is a collection of Python packages to use the Fluent localization system. python-fluent consists of these packages: fluent.syntax

Project Fluent 155 Dec 28, 2022
Keras Image Embeddings using Contrastive Loss

Keras-Image-Embeddings-using-Contrastive-Loss Image to Embedding projection in vector space. Implementation in keras and tensorflow for custom data. B

Shravan Anand K 5 Mar 21, 2022
Unsupervised Foreground Extraction via Deep Region Competition

Unsupervised Foreground Extraction via Deep Region Competition [Paper] [Code] The official code repository for NeurIPS 2021 paper "Unsupervised Foregr

28 Nov 06, 2022
Exploring Versatile Prior for Human Motion via Motion Frequency Guidance (3DV2021)

Exploring Versatile Prior for Human Motion via Motion Frequency Guidance [Video Demo] [Paper] Installation Requirements Python 3.6 PyTorch 1.1.0 Pleas

Jiachen Xu 19 Oct 28, 2022
gtfs2vec - Learning GTFS Embeddings for comparing PublicTransport Offer in Microregions

gtfs2vec This is a companion repository for a gtfs2vec - Learning GTFS Embeddings for comparing PublicTransport Offer in Microregions publication. Vis

Politechnika Wrocławska - repozytorium dla informatyków 5 Oct 10, 2022
Official PyTorch implementation of "Synthesis of Screentone Patterns of Manga Characters"

Manga Character Screentone Synthesis Official PyTorch implementation of "Synthesis of Screentone Patterns of Manga Characters" presented in IEEE ISM 2

Tsubota 2 Nov 20, 2021
Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral)

DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral) This repo is the official imp

如今我已剑指天涯 46 Dec 21, 2022
[NeurIPS 2019] Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss

Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss Kaidi Cao, Colin Wei, Adrien Gaidon, Nikos Arechiga, Tengyu Ma This is the offi

Kaidi Cao 528 Jan 01, 2023
A PyTorch Implementation of Gated Graph Sequence Neural Networks (GGNN)

A PyTorch Implementation of GGNN This is a PyTorch implementation of the Gated Graph Sequence Neural Networks (GGNN) as described in the paper Gated G

Ching-Yao Chuang 427 Dec 13, 2022
Syllabus del curso IIC2115 - Programación como Herramienta para la Ingeniería 2022/I

IIC2115 - Programación como Herramienta para la Ingeniería Videos y tutoriales Tutorial CMD Tutorial Instalación Python y Jupyter Tutorial de git-GitH

21 Nov 09, 2022
The final project of "Applying AI to 2D Medical Imaging Data" of "AI for Healthcare" nanodegree - Udacity.

Pneumonia Detection from X-Rays Project Overview In this project, you will apply the skills that you have acquired in this 2D medical imaging course t

Omar Laham 1 Jan 14, 2022
Source code for the paper "PLOME: Pre-training with Misspelled Knowledge for Chinese Spelling Correction" in ACL2021

PLOME:Pre-training with Misspelled Knowledge for Chinese Spelling Correction (ACL2021) This repository provides the code and data of the work in ACL20

197 Nov 26, 2022
Diffusion Probabilistic Models for 3D Point Cloud Generation (CVPR 2021)

Diffusion Probabilistic Models for 3D Point Cloud Generation [Paper] [Code] The official code repository for our CVPR 2021 paper "Diffusion Probabilis

Shitong Luo 323 Jan 05, 2023