PyTorch implementation of "Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet"

Overview

Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet (arXiv)

This is a PyTorch implementation of our technical report.

(Figure: Comparison between the proposed LV-ViT and other recent transformer-based works. Only models with fewer than 100M parameters are shown.)

Training Pipeline

(Figure: Pipeline)

Our code is based on pytorch-image-models by Ross Wightman.

LV-ViT Models

Model     Layers  Dim  Resolution  Params   Top-1 (%)  Download
LV-ViT-S  16      384  224         26.15M   83.3       link
LV-ViT-S  16      384  384         26.30M   84.4       link
LV-ViT-M  20      512  224         55.83M   84.0       link
LV-ViT-M  20      512  384         56.03M   85.4       link
LV-ViT-L  24      768  448         150.47M  86.2       link

Requirements

torch>=1.4.0 torchvision>=0.5.0 pyyaml timm==0.4.5

Data preparation: organize ImageNet in the following folder structure (you can extract ImageNet with this script):

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
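Before running the scripts, it can help to sanity-check the layout. The sketch below is a hypothetical helper (check_imagenet_layout is not part of this repo) that walks root/<split>/<class folder>/ and fails early on a missing split, an empty class folder, or mismatched class counts between train and val.

```python
from pathlib import Path


def check_imagenet_layout(root, splits=("train", "val")):
    """Return {split: number of class folders}, raising if the layout looks wrong.

    Hypothetical helper: expects root/<split>/<class folder>/<image>.JPEG,
    as in the tree above.
    """
    root = Path(root)
    counts = {}
    for split in splits:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        if not class_dirs:
            raise ValueError(f"{split_dir} has no class sub-folders")
        for class_dir in class_dirs:
            # Each class folder should hold at least one .JPEG image.
            if not any(f.suffix.lower() == ".jpeg" for f in class_dir.iterdir()):
                raise ValueError(f"{class_dir} contains no .JPEG files")
        counts[split] = len(class_dirs)
    # train and val should contain the same set size of class folders.
    if len(set(counts.values())) > 1:
        raise ValueError(f"class counts differ between splits: {counts}")
    return counts
```

For a correctly extracted ImageNet-1k, check_imagenet_layout("/path/to/imagenet") should report 1000 class folders for both splits.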

Validation

Replace /path/to/imagenet/val with the path to your ImageNet validation set and /path/to/checkpoint with the checkpoint path:

CUDA_VISIBLE_DEVICES=0 bash eval.sh /path/to/imagenet/val /path/to/checkpoint
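The Top-1 column in the model table is standard top-1 accuracy. As a rough illustration of what an evaluation script reports, here is a minimal sketch of top-k accuracy from a batch of logits; topk_accuracy is a hypothetical helper, not the repo's eval code.

```python
import torch


def topk_accuracy(logits, target, ks=(1, 5)):
    """Percentage of samples whose true class is among the top-k logits, for each k.

    logits: (N, num_classes) float tensor; target: (N,) integer class labels.
    """
    maxk = min(max(ks), logits.size(1))  # k cannot exceed the number of classes
    pred = logits.topk(maxk, dim=1).indices  # (N, maxk), highest score first
    hits = pred.eq(target.unsqueeze(1))  # (N, maxk) boolean: is column j the true class?
    return [hits[:, :min(k, maxk)].any(dim=1).float().mean().item() * 100.0 for k in ks]
```

Because k is clipped to the number of classes, top-5 accuracy on a dataset with five or fewer classes is always 100%.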

Label data

We provide the dense label maps generated by NFNet-F6 here. Since NFNet-F6 is trained on ImageNet data only, no extra training data is involved.
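A dense label map assigns a class-score vector to every spatial location. One compact way to store it is to keep only the k highest-scoring classes per location; the file name label_top5_train_nfnet mentioned in one of the issues below suggests top-5 storage, but the exact on-disk format is an assumption here. A minimal sketch of compressing a (num_classes, H, W) score map and expanding it back (both helper names are hypothetical):

```python
import torch


def compress_label_map(scores, k=5):
    """Keep the top-k class scores at each location.

    scores: (num_classes, H, W) float tensor (assumed layout).
    Returns (values, indices), each of shape (k, H, W).
    """
    values, indices = scores.topk(k, dim=0)
    return values, indices


def expand_label_map(values, indices, num_classes):
    """Rebuild a dense (num_classes, H, W) map; classes outside the top-k get 0."""
    dense = torch.zeros((num_classes,) + tuple(values.shape[1:]),
                        dtype=values.dtype, device=values.device)
    # scatter_ writes values[j, y, x] into dense[indices[j, y, x], y, x]
    dense.scatter_(0, indices, values)
    return dense
```

With k equal to num_classes the round trip is lossless; with a small k, all but the k strongest classes at each location are dropped.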

Training

Coming soon

Reference

If you use this repo or find it useful, please consider citing:

@misc{jiang2021token,
      title={Token Labeling: Training an 85.4% Top-1 Accuracy Vision Transformer with 56M Parameters on ImageNet}, 
      author={Zihang Jiang and Qibin Hou and Li Yuan and Daquan Zhou and Xiaojie Jin and Anran Wang and Jiashi Feng},
      year={2021},
      eprint={2104.10858},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Related projects

T2T-ViT, Re-labeling ImageNet.

Comments
  • error: download the pretrained model but couldn't be unzipped

    tar -xvf lvvit_s-26M-384-84-4.pth.tar
    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    opened by Williamlizl 10
  • The accuracy of the validation set is 0,and the loss is always around 13

    Hello! I use ILSVRC2012_img_train and ILSVRC2012_img_val, and use the provided label_top5_train_nfnet from Google Drive. I train lv-vit-s with batch_size 64 without apex for one epoch. Thanks for your advice.

    opened by yifanQi98 7
  • Pretrained weights for LV-ViT-T

    Hi,

    Thanks for sharing your work. Could you also provide the pre-trained weights for the LV-ViT-T model variant, the one that achieves 79.1% top1-acc. as mentioned in Table 1 of your paper?

    All the best, Marc

    opened by marc345 5
  • train error: AttributeError: 'tuple' object has no attribute 'log_softmax'

    Hi, thanks for your great work. When I run the training script, an error occurs: AttributeError: 'tuple' object has no attribute 'log_softmax'

    with amp_autocast():
        output = model(input)
        loss = loss_fn(output, target)  # error occurs

    and the loss function is train_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.0).cuda()

    By the way, could you please tell me why we need to specify smoothing=0.0?

    opened by lxy5513 5
  • RuntimeError: CUDA error: device-side assert triggered

    I am a beginner in deep learning. When I run the VOLO code with tlt on a single GPU or multiple GPUs, I get the following error:

    /pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [25,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
    Traceback (most recent call last):
      File "main.py", line 949, in <module>
        main()
      File "main.py", line 664, in main
        optimizers=optimizers)
      File "main.py", line 773, in train_one_epoch
        label_size=args.token_label_size)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target
        y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords
        num_classes=num_classes,device=device)
      File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 16, in get_featuremaps
        _label_topk[1][:, :, :].long(),
    RuntimeError: CUDA error: device-side assert triggered

    I can't fix this problem right now.

    opened by JIAOJIAYUASD 4
  • Generating label for custom dataset

    Hello,

    Thank you for sharing your work. I am currently trying to generate token labels for a custom dataset for the model lvvit_s, but the loss stays close to 7 and the accuracy is 0 (not pretrained, using 1 GPU in Google Colab). I also tried using the pretrained model with --transfer but got 0 for both loss and accuracy. What options should I use for a custom dataset?

    opened by AleMaiaF 2
  • generate_label.py unable to find model lvvit_s

    Hi,

    When I tried to run the label generation script for the model lvvit_s it returned an error "RuntimeError: Unknown model".

    Solution: It worked when I added the line "import tlt.models" in the file generate_label.py.

    opened by AleMaiaF 2
  • Can Token labeling reach higher than annotator model?

    Greetings,

    Thank you for this incredible research.

    I would like to know whether it is possible to use Token Labeling to achieve scores higher than those of the annotator model. I believe this was the case with the VOLO D5 model, which achieved a higher score than NFNet, the model used for annotation.

    opened by ErenBalatkan 1
  • label_map does not do the same augmentation (random crop) as the input image

    Hi, thanks so much for the nice work! I am curious whether you could share some insight into the processing of the label_map. If I understand correctly, after we load the image and the corresponding label map, we should apply the same cropping/flip/resize to both, but in https://github.com/zihangJiang/TokenLabeling/blob/aa438eff9b9fc2daa8c8b4cc6bfaa6e3721f995e/tlt/data/label_transforms_factory.py#L58-L73 it seems only the image is cropped and the label map is not, which would make the label map not match the image?

    Should we do:

        return torchvision_F.resized_crop(
            img, i, j, h, w, self.size, interpolation
        ), torchvision_F.resized_crop(
            label_map, i / ratio, j / ratio, h / ratio, w / ratio, self.size, interpolation
        )

    Thanks

    opened by haooooooqi 1
  • Python3.6, ok; Python3.8, error

    Test: [ 0/1] Time: 11.293 (11.293) Loss: 0.7043 (0.7043) Acc@1: 42.1875 (42.1875) Acc@5: 100.0000 (100.0000)
    Test: [ 1/1] Time: 0.108 (5.701) Loss: 0.5847 (0.6689) Acc@1: 89.8148 (56.3187) Acc@5: 100.0000 (100.0000)
    free(): invalid pointer
    free(): invalid pointer
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 303, in <module>
        main()
      File "/opt/conda/lib/python3.8/site-packages/torch/distributed/launch.py", line 294, in main
        raise subprocess.CalledProcessError(returncode=process.returncode,
    subprocess.CalledProcessError: Command '['/opt/conda/bin/python3.8', '-u', 'main.py', '--local_rank=1', './dataset/c/c', '--model', 'lvvit_s', '-b', '128', '--apex-amp', '--img-size', '224', '--drop-path', '0.1', '--token-label', '--token-label-size', '14', '--dense-weight', '0.0', '--num-classes', '2', '--finetune', './pretrained/lvvit_s-26M-384-84-4.pth.tar']' died with <Signals.SIGABRT: 6>.

    Command:

    CUDA_VISIBLE_DEVICES=0,1 bash ./distributed_train.sh 2 ./dataset/c/c --model lvvit_s -b 128 --apex-amp --img-size 224 --drop-path 0.1 --token-label --token-label-size 14 --dense-weight 0.0 --num-classes 2 --finetune ./pretrained/lvvit_s-26M-384-84-4.pth.tar

    opened by Williamlizl 1
  • A Bag of Training Techniques for ViT

    Hi, thanks for your wonderful work. I have a question: can the training techniques mentioned in LV-ViT be used in other downstream tasks such as object detection? In your paper, I see that many of these techniques are applied on ImageNet. Thanks!

    opened by qdd1234 1
  • how to apply token labeling to CNN ?

    Hello, I'm interested in your token labeling technique, so I want to apply it to a CNN-based model, because ViTs are very heavy to train.

    Could I get your code for token labeling with CNNs? If not, could you give me some details on how to implement it?

    thank you.

    opened by HoJ00n2 0
  • Model settings for Cifar10

    Is there any LV-ViT model setup you have tested on CIFAR-10? I would like to know the proper configuration of all blocks when training without pretrained weights.

    opened by Aminullah6264 0
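Regarding the tar -xvf error in the first issue: the .pth.tar suffix is a naming convention inherited from PyTorch and timm training scripts for files written by torch.save, so the checkpoint is most likely not a tar archive and should be loaded directly rather than extracted. A minimal sketch, assuming the file holds either a bare state_dict or one wrapped under a 'state_dict' or 'model' key (load_state_dict here is a hypothetical helper):

```python
import torch


def load_state_dict(path):
    """Load a .pth.tar checkpoint with torch.load; no tar extraction is needed.

    Assumption: the file holds either a bare state_dict or a dict that wraps
    it under a 'state_dict' or 'model' key.
    """
    ckpt = torch.load(path, map_location="cpu")  # load onto CPU regardless of saving device
    if isinstance(ckpt, dict):
        for key in ("state_dict", "model"):
            if isinstance(ckpt.get(key), dict):
                return ckpt[key]
    return ckpt
```

The returned dictionary can then be passed to model.load_state_dict(...) on a model created with the matching architecture.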
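Regarding the AttributeError: 'tuple' object has no attribute 'log_softmax' issue: the error indicates that the model returned a tuple (in training mode with token labeling, presumably class logits plus per-token outputs) while LabelSmoothingCrossEntropy expects a single logits tensor. The intended fix is likely a token-labeling-aware loss; if only the plain classification loss is wanted, a minimal workaround, assuming element 0 of the tuple is the class logits (check this against the model's forward), is:

```python
import torch
import torch.nn.functional as F


def classification_loss(output, target):
    """Cross-entropy on the class logits, unwrapping tuple outputs first.

    Assumption: if the model returns a tuple/list, element 0 is the
    (N, num_classes) class logits; the remaining elements are ignored.
    """
    if isinstance(output, (tuple, list)):
        output = output[0]
    return F.cross_entropy(output, target)
```

With smoothing=0.0, label-smoothed cross-entropy reduces to ordinary cross-entropy, which is what F.cross_entropy computes here.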
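Regarding the device-side assert issue: the message Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" from ScatterGatherKernel.cu means a scatter/gather index fell outside the target dimension. Since the failing call in get_featuremaps uses class indices taken from the label map, one plausible cause is label maps containing class indices at or above the configured number of classes (for example, ImageNet-1k label maps used with a smaller num_classes). Setting CUDA_LAUNCH_BLOCKING=1 makes CUDA errors surface at the call that caused them, and a check like the hypothetical helper below, run on the label indices before the scatter, turns the assert into a readable error:

```python
import torch


def check_class_indices(indices, num_classes):
    """Raise a readable ValueError if any class index is outside [0, num_classes).

    Hypothetical helper: on CUDA, an out-of-range scatter/gather index only
    surfaces as 'device-side assert triggered', so checking first names the problem.
    """
    lo, hi = int(indices.min()), int(indices.max())
    if lo < 0 or hi >= num_classes:
        raise ValueError(
            f"class indices span [{lo}, {hi}] but num_classes={num_classes}; "
            "the label maps may come from a different label space"
        )
```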
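Regarding the Unknown model error in generate_label.py: timm's create_model only knows models whose defining modules have been imported, because registration presumably happens as a side effect of a @register_model decorator when the module loads. That would explain why adding import tlt.models fixes it. The pure-Python toy registry below illustrates the mechanism; it is not timm's code, and lvvit_s here is a stand-in that returns a dict instead of a network.

```python
# Toy registry mimicking the decorator pattern (not timm's implementation).
_MODEL_REGISTRY = {}


def register_model(fn):
    """Record a model constructor under its function name when its module is imported."""
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn


def create_model(name, **kwargs):
    """Look up a registered constructor by name and call it."""
    if name not in _MODEL_REGISTRY:
        raise RuntimeError(f"Unknown model ({name})")
    return _MODEL_REGISTRY[name](**kwargs)


# In a real package this definition would live in a module such as tlt/models;
# the decorator only runs, and so the name is only registered, once that module is imported.
@register_model
def lvvit_s(num_classes=1000):
    return {"arch": "lvvit_s", "num_classes": num_classes}
```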
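Regarding the label_map cropping question: the traceback in the device-side assert issue mentions get_labelmaps_with_coords, which suggests the repo may carry crop coordinates alongside the full label map and crop it later rather than in the transform; that is an inference, not something confirmed here. If one does crop both in the transform, the key point is that the crop box must be rescaled to the label map's resolution. A minimal sketch with plain tensor slicing and F.interpolate (paired_resized_crop is hypothetical; it rounds the label box outward to whole cells, so it is approximate at cell boundaries):

```python
import math

import torch
import torch.nn.functional as F


def paired_resized_crop(img, label_map, box, out_size, label_size):
    """Crop an image and its lower-resolution label map with the same box, then resize.

    img: (C, H, W) float tensor; label_map: (K, h, w) float tensor covering the
    same field of view. box: (top, left, height, width) in image pixels.
    """
    top, left, bh, bw = box
    H, W = img.shape[-2:]
    h, w = label_map.shape[-2:]
    sy, sx = h / H, w / W  # label-map cells per image pixel
    img_crop = img[:, top:top + bh, left:left + bw]
    # Map the box into label-map coordinates, rounding outward to keep >= 1 cell.
    t0, l0 = int(top * sy), int(left * sx)
    t1 = max(t0 + 1, math.ceil((top + bh) * sy))
    l1 = max(l0 + 1, math.ceil((left + bw) * sx))
    lab_crop = label_map[:, t0:t1, l0:l1]
    img_out = F.interpolate(img_crop[None], size=out_size, mode="bilinear", align_corners=False)[0]
    lab_out = F.interpolate(lab_crop[None], size=label_size, mode="bilinear", align_corners=False)[0]
    return img_out, lab_out
```

A horizontal flip would likewise have to be applied to both tensors, e.g. with torch.flip(x, dims=[-1]).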
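Regarding applying token labeling to a CNN: one possible adaptation, not the authors' code, is to treat every spatial location of the final feature map as a token, add a 1x1-convolution classifier that predicts a class distribution per location, and supervise it with the dense label map resized to the feature-map grid, alongside the usual pooled classification loss. In the sketch below, TokenLabelHead, token_label_loss, and the default dense_weight=0.5 are all hypothetical; the weight is presumably similar in spirit to the --dense-weight flag in the training command quoted above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenLabelHead(nn.Module):
    """Hypothetical head: pooled class logits plus per-location ("token") logits."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls = nn.Linear(in_channels, num_classes)  # classifier on the pooled feature
        self.dense = nn.Conv2d(in_channels, num_classes, kernel_size=1)  # one classifier per location

    def forward(self, feats):  # feats: (N, C, H, W) from the CNN backbone
        global_logits = self.cls(feats.mean(dim=(2, 3)))
        dense_logits = self.dense(feats)  # (N, num_classes, H, W)
        return global_logits, dense_logits


def token_label_loss(global_logits, dense_logits, target, label_map, dense_weight=0.5):
    """Hard-label cross-entropy on pooled logits plus weighted dense soft cross-entropy.

    label_map: (N, num_classes, h, w) per-location class distributions; bilinear
    resizing to the feature grid keeps each location's distribution summing to 1.
    dense_weight=0.5 is an assumed default, not a value taken from the repo.
    """
    label_map = F.interpolate(label_map, size=dense_logits.shape[-2:], mode="bilinear", align_corners=False)
    global_loss = F.cross_entropy(global_logits, target)
    # Soft cross-entropy at every location, averaged over batch and positions.
    dense_loss = -(label_map * F.log_softmax(dense_logits, dim=1)).sum(dim=1).mean()
    return global_loss + dense_weight * dense_loss
```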
Owner
蒋子航 (Zihang Jiang)
Now a Ph.D. student supervised by Prof. Feng Jiashi in ECE, NUS.