DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

Overview

Created by Yongming Rao, Wenliang Zhao, Benlin Liu, Jiwen Lu, Jie Zhou, Cho-Jui Hsieh

This repository contains the PyTorch implementation of DynamicViT.

We introduce a dynamic token sparsification framework to prune redundant tokens in vision transformers progressively and dynamically based on the input:

[Figure: intro]
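
As a rough illustration of the idea, the sketch below keeps the highest-scoring fraction of tokens at a single pruning stage. It is plain Python rather than this repository's code: the function name `keep_top_tokens`, the example scores, and the round-down rule are assumptions made for illustration, and how the scores are produced is left out entirely.

```python
def keep_top_tokens(scores, keep_ratio):
    """Return the indices of the tokens to keep, in their original order.

    scores: one importance score per token (hypothetical values).
    keep_ratio: fraction of tokens to keep, in (0, 1].
    """
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError("keep_ratio must be in (0, 1]")
    # Assumption: round down, but always keep at least one token.
    num_keep = max(1, int(len(scores) * keep_ratio))
    # Rank token indices by score, highest first.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Restore the original ordering among the kept tokens.
    return sorted(ranked[:num_keep])


# Hypothetical scores for six tokens; keep half of them.
scores = [0.9, 0.1, 0.4, 0.8, 0.3, 0.7]
print(keep_top_tokens(scores, keep_ratio=0.5))
```

Because the kept set depends on the scores, two different inputs can retain different token positions even when the keep ratio is the same.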

Our code is based on pytorch-image-models, DeiT, and LV-ViT.

[Project Page] [arXiv]

Model Zoo

We provide our DynamicViT models pretrained on ImageNet:

name                 arch        rho  acc@1   acc@5   FLOPs  url
DynamicViT-256/0.7   deit_256    0.7  76.532  93.118  1.3G   Google Drive / Tsinghua Cloud
DynamicViT-384/0.7   deit_small  0.7  79.316  94.676  2.9G   Google Drive / Tsinghua Cloud
DynamicViT-LV-S/0.5  lvvit_s     0.5  81.970  95.756  3.7G   Google Drive / Tsinghua Cloud
DynamicViT-LV-S/0.7  lvvit_s     0.7  83.076  96.252  4.6G   Google Drive / Tsinghua Cloud
DynamicViT-LV-M/0.7  lvvit_m     0.7  83.816  96.584  8.5G   Google Drive / Tsinghua Cloud

Usage

Requirements

  • torch>=1.7.0
  • torchvision>=0.8.1
  • timm==0.4.5

Data preparation: download and extract ImageNet images from http://image-net.org/. The directory structure should be:

│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
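
Before launching a long run, it can help to confirm that the extracted data matches the layout above. The following is a minimal sketch using only the standard library; the helper name `check_imagenet_layout` is hypothetical and not part of this repository, and it only checks that `train/` and `val/` exist and that each class folder holds at least one `.JPEG` file.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found under root; an empty list means OK."""
    root = Path(root)
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_dir}")
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        if not class_dirs:
            problems.append(f"no class folders in: {split_dir}")
        for class_dir in class_dirs:
            # Each class folder (e.g. n01440764) should hold .JPEG files.
            if not any(class_dir.glob("*.JPEG")):
                problems.append(f"no .JPEG files in: {class_dir}")
    return problems


# Replace the placeholder path with the actual dataset location.
for problem in check_imagenet_layout("/path/to/ILSVRC2012/"):
    print(problem)
```

An empty result only means the folder structure looks plausible; it does not verify image counts or file integrity.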

Model preparation: download pre-trained DeiT and LV-ViT models for training DynamicViT:

sh download_pretrain.sh

Demo

We provide a Jupyter notebook where you can run the visualization of DynamicViT.

To run the demo, you need to install matplotlib.

[Figure: demo]

Evaluation

To evaluate a pre-trained DynamicViT model on ImageNet val with a single GPU, run:

python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --model-path /path/to/model --base_rate 0.7 

Training

To train DynamicViT models on ImageNet, run:

DeiT-small

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_deit-small --arch deit_small --input-size 224 --batch-size 96 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-S

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-s --arch lvvit_s --input-size 224 --batch-size 64 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

LV-ViT-M

python -m torch.distributed.launch --nproc_per_node=8 --use_env main_dynamic_vit.py  --output_dir logs/dynamic-vit_lvvit-m --arch lvvit_m --input-size 224 --batch-size 48 --data-path /path/to/ILSVRC2012/ --epochs 30 --dist-eval --distill --base_rate 0.7

You can train models with different keeping ratios by adjusting base_rate. DynamicViT can also achieve comparable performance with only 15 epochs of training (around 0.1% lower accuracy).
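
To make the effect of base_rate more concrete, the sketch below estimates how many patch tokens would remain after each pruning stage. It rests on assumptions that are not stated in this README: a geometric schedule in which stage s keeps a fraction rho^s of the original tokens (as described in the DynamicViT paper), three pruning stages, and 196 patch tokens for a 224x224 input split into 16x16 patches. The function name `tokens_per_stage` and the round-down rule are illustrative only.

```python
def tokens_per_stage(base_rate, num_tokens=196, num_stages=3):
    """Number of patch tokens kept after each pruning stage.

    Assumes stage s (1-based) keeps a fraction base_rate ** s of the
    original tokens, rounded down.
    """
    if not 0.0 < base_rate <= 1.0:
        raise ValueError("base_rate must be in (0, 1]")
    return [int(num_tokens * base_rate ** s) for s in range(1, num_stages + 1)]


# Compare the two keeping ratios that appear in the model zoo.
for rate in (0.5, 0.7):
    print(rate, tokens_per_stage(rate))
```

Under these assumptions, a lower base_rate removes tokens much faster in later stages, which is consistent with the lower FLOPs reported for the rho = 0.5 model in the model zoo.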

License

MIT License

Citation

If you find our work useful in your research, please consider citing:

@article{rao2021dynamicvit,
  title={DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification},
  author={Rao, Yongming and Zhao, Wenliang and Liu, Benlin and Lu, Jiwen and Zhou, Jie and Hsieh, Cho-Jui},
  journal={arXiv preprint arXiv:2106.02034},
  year={2021}
}