Pytorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)

Overview

Vision Transformer

Pytorch reimplementation of Google's repository for the ViT model that was released with the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.

This paper show that Transformers applied directly to image patches and pre-trained on large datasets work really well on image recognition task.

fig1

Vision Transformer achieve State-of-the-Art in image recognition task with standard Transformer encoder and fixed-size patches. In order to perform classification, author use the standard approach of adding an extra learnable "classification token" to the sequence.

fig2

Usage

1. Download Pre-trained model (Google's Official Checkpoint)

  • Available models: ViT-B_16(85.8M), R50+ViT-B_16(97.96M), ViT-B_32(87.5M), ViT-L_16(303.4M), ViT-L_32(305.5M), ViT-H_14(630.8M)
    • imagenet21k pre-train models
      • ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14
    • imagenet21k pre-train + imagenet2012 fine-tuned models
      • ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32
    • Hybrid Model(Resnet50 + Transformer)
      • R50-ViT-B_16
# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz

# imagenet21k pre-train + imagenet2012 fine-tuning
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz

2. Train Model

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

CIFAR-10 and CIFAR-100 are automatically download and train. In order to use a different dataset you need to customize data_utils.py.

The default batch size is 512. When GPU memory is insufficient, you can proceed with training by adjusting the value of --gradient_accumulation_steps.

Also can use Automatic Mixed Precision(Amp) to reduce memory usage and train faster

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --fp16 --fp16_opt_level O2

Results

To verify that the converted model weight is correct, we simply compare it with the author's experimental results. We trained using mixed precision, and --fp16_opt_level was set to O2.

imagenet-21k

model dataset resolution acc(official) acc(this repo) time
ViT-B_16 CIFAR-10 224x224 - 0.9908 3h 13m
ViT-B_16 CIFAR-10 384x384 0.9903 0.9906 12h 25m
ViT_B_16 CIFAR-100 224x224 - 0.923 3h 9m
ViT_B_16 CIFAR-100 384x384 0.9264 0.9228 12h 31m
R50-ViT-B_16 CIFAR-10 224x224 - 0.9892 4h 23m
R50-ViT-B_16 CIFAR-10 384x384 0.99 0.9904 15h 40m
R50-ViT-B_16 CIFAR-100 224x224 - 0.9231 4h 18m
R50-ViT-B_16 CIFAR-100 384x384 0.9231 0.9197 15h 53m
ViT_L_32 CIFAR-10 224x224 - 0.9903 2h 11m
ViT_L_32 CIFAR-100 224x224 - 0.9276 2h 9m
ViT_H_14 CIFAR-100 224x224 - WIP

imagenet-21k + imagenet2012

model dataset resolution acc
ViT-B_16-224 CIFAR-10 224x224 0.99
ViT_B_16-224 CIFAR-100 224x224 0.9245
ViT-L_32 CIFAR-10 224x224 0.9903
ViT-L_32 CIFAR-100 224x224 0.9285

shorter train

  • In the experiment below, we used a resolution size (224x224).
  • tensorboard
upstream model dataset total_steps /warmup_steps acc(official) acc(this repo)
imagenet21k ViT-B_16 CIFAR-10 500/100 0.9859 0.9859
imagenet21k ViT-B_16 CIFAR-10 1000/100 0.9886 0.9878
imagenet21k ViT-B_16 CIFAR-100 500/100 0.8917 0.9072
imagenet21k ViT-B_16 CIFAR-100 1000/100 0.9115 0.9216

Visualization

The ViT consists of a Standard Transformer Encoder, and the encoder consists of Self-Attention and MLP module. The attention map for the input image can be visualized through the attention score of self-attention.

Visualization code can be found at visualize_attention_map.

fig3

Reference

Citations

@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and  Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
Owner
Eunkwang Jeon
Eunkwang Jeon
An experiment to bait a generalized frontrunning MEV bot

Honeypot 🍯 A simple experiment that: Creates a honeypot contract Baits a generalized fronturnning bot with a unique transaction Analyze bot behaviour

0x1355 14 Nov 24, 2022
The FIRST GANs-based omics-to-omics translation framework

OmiTrans Please also have a look at our multi-omics multi-task DL freamwork 👀 : OmiEmbed The FIRST GANs-based omics-to-omics translation framework Xi

Xiaoyu Zhang 6 Dec 14, 2022
Oriented Object Detection: Oriented RepPoints + Swin Transformer/ReResNet

Oriented RepPoints for Aerial Object Detection The code for the implementation of “Oriented RepPoints + Swin Transformer/ReResNet”. Introduction Based

96 Dec 13, 2022
Image process framework based on plugin like imagej, it is esay to glue with scipy.ndimage, scikit-image, opencv, simpleitk, mayavi...and any libraries based on numpy

Introduction ImagePy is an open source image processing framework written in Python. Its UI interface, image data structure and table data structure a

ImagePy 1.2k Dec 29, 2022
LexGLUE: A Benchmark Dataset for Legal Language Understanding in English

LexGLUE: A Benchmark Dataset for Legal Language Understanding in English ⚖️ 🏆 🧑‍🎓 👩‍⚖️ Dataset Summary Inspired by the recent widespread use of th

95 Dec 08, 2022
This is the code repository for the paper "Identification of the Generalized Condorcet Winner in Multi-dueling Bandits" (NeurIPS 2021).

Code Repository for the Paper "Identification of the Generalized Condorcet Winner in Multi-dueling Bandits" (To appear in: Proceedings of NeurIPS20

1 Oct 03, 2022
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

61.4k Jan 04, 2023
Hard cater examples from Hopper ICLR paper

CATER-h Honglu Zhou*, Asim Kadav, Farley Lai, Alexandru Niculescu-Mizil, Martin Renqiang Min, Mubbasir Kapadia, Hans Peter Graf (*Contact: honglu.zhou

NECLA ML Group 6 May 11, 2021
A Pytorch implementation of "Manifold Matching via Deep Metric Learning for Generative Modeling" (ICCV 2021)

Manifold Matching via Deep Metric Learning for Generative Modeling A Pytorch implementation of "Manifold Matching via Deep Metric Learning for Generat

69 Dec 10, 2022
Training vision models with full-batch gradient descent and regularization

Stochastic Training is Not Necessary for Generalization -- Training competitive vision models without stochasticity This repository implements trainin

Jonas Geiping 32 Jan 06, 2023
AI drive app that can help user become beautiful.

爱美丽 Beauty 简体中文 Features Beauty is an AI drive app that can help user become beautiful. it contain those functions: face score cheek face beauty repor

Starved Midnight 1 Jan 30, 2022
One Million Scenes for Autonomous Driving

ONCE Benchmark This is a reproduced benchmark for 3D object detection on the ONCE (One Million Scenes) dataset. The code is mainly based on OpenPCDet.

148 Dec 28, 2022
Whisper is a file-based time-series database format for Graphite.

Whisper Overview Whisper is one of three components within the Graphite project: Graphite-Web, a Django-based web application that renders graphs and

Graphite Project 1.2k Dec 25, 2022
Classification of ecg datas for disease detection

ecg_classification Classification of ecg datas for disease detection

Atacan ÖZKAN 5 Sep 09, 2022
GeDML is an easy-to-use generalized deep metric learning library

GeDML is an easy-to-use generalized deep metric learning library

Borui Zhang 32 Dec 05, 2022
DEMix Layers for Modular Language Modeling

DEMix This repository contains modeling utilities for "DEMix Layers: Disentangling Domains for Modular Language Modeling" (Gururangan et. al, 2021). T

Suchin 43 Nov 11, 2022
KGDet: Keypoint-Guided Fashion Detection (AAAI 2021)

KGDet: Keypoint-Guided Fashion Detection (AAAI 2021) This is an official implementation of the AAAI-2021 paper "KGDet: Keypoint-Guided Fashion Detecti

Qian Shenhan 35 Dec 29, 2022
A PyTorch Library for Accelerating 3D Deep Learning Research

Kaolin: A Pytorch Library for Accelerating 3D Deep Learning Research Overview NVIDIA Kaolin library provides a PyTorch API for working with a variety

NVIDIA GameWorks 3.5k Jan 07, 2023
Code for PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning

PackNet: https://arxiv.org/abs/1711.05769 Pretrained models are available here: https://uofi.box.com/s/zap2p03tnst9dfisad4u0sfupc0y1fxt Datasets in Py

Arun Mallya 216 Jan 05, 2023
Neural Style and MSG-Net

PyTorch-Style-Transfer This repo provides PyTorch Implementation of MSG-Net (ours) and Neural Style (Gatys et al. CVPR 2016), which has been included

Hang Zhang 904 Dec 21, 2022