A 1.3B text-to-image generation model trained on 14 million image-text pairs

Overview

minDALL-E on Conceptual Captions

minDALL-E, named after minGPT, is a 1.3B text-to-image generation model trained on 14 million image-text pairs for non-commercial purposes.

a painting of a bird in the style of asian painting a photo of san francisco's golden gate bridge in black and white tone

Environment Setup

  • Basic setup
PyTorch == 1.8.0
CUDA >= 10.1
  • Other packages
pip install -r requirements.txt

Model Checkpoint

  • Model structure (two-stage autoregressive model)
    • Stage1: Unlike the original DALL-E [1], we replace Discrete VAE with VQGAN [2] to generate high-quality samples effectively. We slightly fine-tune vqgan_imagenet_f16_16384, provided by the official VQGAN repository, on FFHQ [3] as well as ImageNet.
    • Stage2: We train our 1.3B transformer from scratch on 14 million image-text pairs from CC3M [4] and CC12M [5]. For the more detailed model spec, please see configs/dalle-1.3B.yaml.
  • You can download the pretrained models including the tokenizer from this link. This will require about 5GB space.

Sampling

  • Given a text prompt, the code snippet below generates candidate images and re-ranks them using OpenAI's CLIP [6].
  • This has been tested under a single V100 of 32GB memory. In the case of using GPUs with limited memory, please lower down num_candidates to avoid OOM.
from matplotlib import pyplot as plt
import clip
from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score

device = 'cuda:0'
set_seed(0)

prompt = "A painting of a monkey with sunglasses in the frame"
model = Dalle.from_pretrained('minDALL-E/1.3B')  # This will automatically download the pretrained model.
model.to(device=device)

# Sampling
images = model.sampling(prompt=prompt,
                        top_k=256, # It is recommended that top_k is set lower than 256.
                        top_p=None,
                        softmax_temperature=1.0,
                        num_candidates=96,
                        device=device).cpu().numpy()
images = np.transpose(images, (0, 2, 3, 1))

# CLIP Re-ranking
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.to(device=device)
rank = clip_score(prompt=prompt,
                  images=images,
                  model_clip=model_clip,
                  preprocess_clip=preprocess_clip,
                  device=device)

# Plot images
images = images[rank]
plt.imshow(images[0])
plt.show()

Samples (Top-K=256, Temperature=1.0)

  • "a painting of a {cat, dog} with sunglasses in the frame"

  • "a large {pink, black} elephant walking on the beach"

  • "Eiffel tower on a {desert, mountain}"

Quantitative Results

  • We have validated minDALL-E on the CC3M validation set (in-distribution evaluation) and MS-COCO (zero-shot evaluation).
  • For CC3M, we measure the cosine similarity between image and text representations from the pretrained CLIP model (ViT-B/32), referred to as CLIP-score.
  • For MS-COCO, we compute FID between 30K generated and real samples from MS-COCO 2017, where we randomly choose 30K captions from COCO as in DALL-E. We select the best out of 32 candidates by CLIP re-ranking.
Model CC3M:CLIP-score (higher is better) MS-COCO:FID-30K (lower is better)
VQGAN [2] 0.20 -
ImageBART [7] 0.23 -
DALL-E [1] - 27.5
minDALL-E 0.26 14.7

Transfer Learning Examples

  • minDALL-E, which is pre-trained on noisy text supervisions, could be transferable to class-conditional and unconditional generation tasks. To validate this, we simply fine-tune it on ImageNet over 8 epochs in the case of class-conditional generation and unconditional generation.
  • The commands below fine-tune the pretrained DALL-E. It takes about 36 hours on 8 V100 GPUs.
# unconditinoal image generation for imagenet (256x256)
python examples/transfer_learning_ex.py -d=configs/transfer-imagenet-uncond-gen.yaml
                                        -u=[MODEL_CKPT]
                                        -r=[RESULT_PATH]
                                        --n-gpus=[NUM_GPUS]

# class-conditinoal image generation for imagenet (256x256)
python examples/transfer_learning_ex.py -d=configs/transfer-imagenet-clscond-gen.yaml
                                        -u=[MODEL_CKPT]
                                        -r=[RESULT_PATH]
                                        --n-gpus=[NUM_GPUS]
  • We compute FID-50K between 50K generated samples and all ImageNet training samples, where we use top-k=256 and softmax temperature=1.0 for generation. All results are obtained without the rejection sampling. Interestingly, our model achieves very competitive performance with baselines, even though minDALL-E is fine-tuned in a few epochs.
Model Params FID-50K(class-cond.) FID-50K(uncond.)
VQ-GAN 1.4B 15.78 -
ImageBART 3.5B 21.19 -
minDALL-E 1.3B 15.55 37.58

BibTex

If you find this repository useful in your research, please cite:

@misc{kakaobrain2021minDALL-E,
  title         = {minDALL-E on Conceptual Captions},
  author        = {Saehoon Kim, Sanghun Cho, Chiheon Kim, Doyup Lee, and Woonhyuk Baek},
  year          = {2021},
  howpublished  = {\url{https://github.com/kakaobrain/minDALL-E}},
}

References

  • [1] Ramesh et al. Zero-Shot Text-to-Image Generation, ICML 2021.
  • [2] Esser et al. Taming Transformers for High-Resolution Image Synthesis, CVPR 2021.
  • [3] Karras et al. A Style-Based Generator Architecture for Generative Adversarial Networks, CVPR 2019.
  • [4] Sharma et al. Conceptual Captions: A Cleaned, Hypernymed, Image Alt-text Dataset For Automatic Image Captioning, ACL 2018.
  • [5] Changpinyo et al. Conceptual 12M: Pushing Web-Scale Image-Text Pre-Training To Recognize Long-Tail Visual Concepts, CVPR 2021.
  • [6] Radford et al. Learning Transferable Visual Models From Natural Language Supervision, ICML 2021.
  • [7] Esser et al. ImageBART: Bidirectional Context with Multinomial Diffusion for Autoregressive Image Synthesis, NeurIPS 2021.
  • [8] https://github.com/karpathy/minGPT

Licenses

  • The source codes are licensed under Apache 2.0 License.
  • The stage2 pretrained weights are licensed under CC-BY-NC-SA 4.0 License.

Contact

We hope that minDALL-E helps various projects in research-oriented institutes and startups. If you would like to collaborate with us or share a feedback, please e-mail to us, [email protected]

Limitations

Although minDALL-E is trained on a small set (14M image-text pairs), this might be vulnerable to malicious attacks from the prompt engineering to generate socially unacceptable images. If you obersve these images, please report the "prompt" and "generated images" to us.

Owner
Kakao Brain
Kakao Brain Corp.
Kakao Brain
An Intelligent Self-driving Truck System For Highway Transportation

Inceptio Intelligent Truck System An Intelligent Self-driving Truck System For Highway Transportation Note The code is still in development. OS requir

InceptioResearch 11 Jul 13, 2022
Implementation of C-RNN-GAN.

Implementation of C-RNN-GAN. Publication: Title: C-RNN-GAN: Continuous recurrent neural networks with adversarial training Information: http://mogren.

Olof Mogren 427 Dec 25, 2022
Code To Tune or Not To Tune? Zero-shot Models for Legal Case Entailment.

COLIEE 2021 - task 2: Legal Case Entailment This repository contains the code to reproduce NeuralMind's submissions to COLIEE 2021 presented in the pa

NeuralMind 13 Dec 16, 2022
A Streamlit demo demonstrating the Deep Dream technique. Adapted from the TensorFlow Deep Dream tutorial.

Streamlit Demo: Deep Dream A Streamlit demo demonstrating the Deep Dream technique. Adapted from the TensorFlow Deep Dream tutorial How to run this de

Streamlit 11 Dec 12, 2022
Open-source python package for the extraction of Radiomics features from 2D and 3D images and binary masks.

pyradiomics v3.0.1 Build Status Linux macOS Windows Radiomics feature extraction in Python This is an open-source python package for the extraction of

Artificial Intelligence in Medicine (AIM) Program 842 Dec 28, 2022
Source code for ZePHyR: Zero-shot Pose Hypothesis Rating @ ICRA 2021

ZePHyR: Zero-shot Pose Hypothesis Rating ZePHyR is a zero-shot 6D object pose estimation pipeline. The core is a learned scoring function that compare

R-Pad - Robots Perceiving and Doing 18 Aug 22, 2022
The implementation our EMNLP 2021 paper "Enhanced Language Representation with Label Knowledge for Span Extraction".

LEAR The implementation our EMNLP 2021 paper "Enhanced Language Representation with Label Knowledge for Span Extraction". See below for an overview of

杨攀 93 Jan 07, 2023
Voila - Voilà turns Jupyter notebooks into standalone web applications

Rendering of live Jupyter notebooks with interactive widgets. Introduction Voilà turns Jupyter notebooks into standalone web applications. Unlike the

Voilà Dashboards 4.5k Jan 03, 2023
An implementation of the Contrast Predictive Coding (CPC) method to train audio features in an unsupervised fashion.

CPC_audio This code implements the Contrast Predictive Coding algorithm on audio data, as described in the paper Unsupervised Pretraining Transfers we

8 Nov 14, 2022
A comprehensive list of published machine learning applications to cosmology

ml-in-cosmology This github attempts to maintain a comprehensive list of published machine learning applications to cosmology, organized by subject ma

George Stein 290 Dec 29, 2022
Source code and data from the RecSys 2020 article "Carousel Personalization in Music Streaming Apps with Contextual Bandits" by W. Bendada, G. Salha and T. Bontempelli

Carousel Personalization in Music Streaming Apps with Contextual Bandits - RecSys 2020 This repository provides Python code and data to reproduce expe

Deezer 48 Jan 02, 2023
Repo for the Tutorials of Day1-Day3 of the Nordic Probabilistic AI School 2021 (https://probabilistic.ai/)

ProbAI 2021 - Probabilistic Programming and Variational Inference Tutorial with Pryo Day 1 (June 14) Slides Notebook: students_PPLs_Intro Notebook: so

PGM-Lab 46 Nov 01, 2022
A library of extension and helper modules for Python's data analysis and machine learning libraries.

Mlxtend (machine learning extensions) is a Python library of useful tools for the day-to-day data science tasks. Sebastian Raschka 2014-2020 Links Doc

Sebastian Raschka 4.2k Jan 02, 2023
Audio2Face - Audio To Face With Python

Audio2Face Discription We create a project that transforms audio to blendshape w

FACEGOOD 724 Dec 26, 2022
pytorch implementation for PointNet

PointNet.pytorch This repo is implementation for PointNet in pytorch. The model is in pointnet/model.py. It is teste

Fei Xia 1.7k Dec 30, 2022
Spatial Transformer Nets in TensorFlow/ TensorLayer

MOVED TO HERE Spatial Transformer Networks Spatial Transformer Networks (STN) is a dynamic mechanism that produces transformations of input images (or

Hao 36 Nov 23, 2022
CountDown to New Year and shoot fireworks

CountDown and Shoot Fireworks About App This is an small application make you re

5 Dec 31, 2022
TeST: Temporal-Stable Thresholding for Semi-supervised Learning

TeST: Temporal-Stable Thresholding for Semi-supervised Learning TeST Illustration Semi-supervised learning (SSL) offers an effective method for large-

Xiong Weiyu 1 Jul 14, 2022
Repository for GNSS-based position estimation using a Deep Neural Network

Code repository accompanying our work on 'Improving GNSS Positioning using Neural Network-based Corrections'. In this paper, we present a Deep Neural

32 Dec 13, 2022
[CVPR2021] Look before you leap: learning landmark features for one-stage visual grounding.

LBYL-Net This repo implements paper Look Before You Leap: Learning Landmark Features For One-Stage Visual Grounding CVPR 2021. Getting Started Prerequ

SVIP Lab 45 Dec 12, 2022