Attention Probe: Vision Transformer Distillation in the Wild

License: MIT

Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang
In ICASSP 2022

This repository contains the PyTorch implementation of the ICASSP 2022 paper "Attention Probe: Vision Transformer Distillation in the Wild".

Overview

  • We propose the Attention Probe, a special section of the attention map that makes use of large amounts of unlabeled data in the wild to perform data-free distillation of vision transformers. Instead of generating images from the teacher network with a series of priors, we identify the images most relevant to the given pre-trained network and task from a large unlabeled dataset (e.g., Flickr) and use them for knowledge distillation.
  • We propose a simple yet efficient distillation algorithm based on the Attention Probe, called probe distillation, which distills the student model using intermediate features of the teacher model.
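
The selection idea in the first bullet can be pictured as ranking unlabeled images by a relevance score and keeping the best ones. The sketch below is only an illustration under that assumption: `select_top_k` and the example scores are hypothetical, and the relevance measure actually used by the Attention Probe is defined in the paper and the repository code, not here.

```python
import heapq


def select_top_k(scores, k):
    """Return the indices of the k highest-scoring images, best first.

    `scores` holds one hypothetical relevance score per unlabeled image.
    How the real Attention Probe computes relevance is not reproduced here.
    """
    if k < 0:
        raise ValueError("k must be non-negative")
    # Rank image indices by their score and keep the k largest.
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    wild_scores = [0.12, 0.87, 0.45, 0.91, 0.30]
    print(select_top_k(wild_scores, 2))
```

In this picture, `k` plays a role similar to a selection budget such as the `--num_select` argument used in the commands of the Usage section.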

Prerequisite

We use PyTorch 1.7.1 and CUDA 11.0. You can install them with:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

The code should also work with other PyTorch and CUDA versions.
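
To confirm which versions are actually installed, a short check along these lines may help. `torch.__version__`, `torch.version.cuda`, and `torch.cuda.is_available()` are standard PyTorch attributes; the `describe_environment` helper is just an illustrative name and not part of this repository.

```python
import importlib.util


def describe_environment():
    """Report the installed PyTorch and CUDA versions, if PyTorch is present."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed in this interpreter.
        return {"torch": None, "cuda": None, "cuda_available": False}
    import torch

    return {
        "torch": torch.__version__,  # e.g. "1.7.1+cu110"
        "cuda": torch.version.cuda,  # CUDA version PyTorch was built with
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(describe_environment())
```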

Usage

Data Preparation

First, convert the CIFAR-10/100 and Tiny-ImageNet datasets to the ImageNet-style directory layout. For CIFAR-10, run:

python process_cifar10.py

For CIFAR-100, run:

python process_cifar100.py

For Tiny-ImageNet, run:

python process_tinyimagenet.py
python process_move_file.py

The dataset dir should have the following structure:

dir/
  train/
    ...
  val/
    n01440764/
      ILSVRC2012_val_00000293.JPEG
      ...
    ...
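
Before training, it can be useful to verify that a converted dataset matches the layout above. The following is a minimal sketch, assuming each split directory contains one sub-directory per class with at least one file in it; `check_imagenet_layout` is a hypothetical helper and not a script shipped with this repository.

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Return a list of problems found under `root`; an empty list means OK.

    Assumes the ImageNet-style layout: root/{train,val}/<class>/<image>.
    """
    problems = []
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split_dir}")
            continue
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        if not class_dirs:
            problems.append(f"no class directories in: {split_dir}")
        for class_dir in class_dirs:
            # A class directory without any file is almost certainly a mistake.
            if not any(p.is_file() for p in class_dir.iterdir()):
                problems.append(f"empty class directory: {class_dir}")
    return problems
```

For example, `check_imagenet_layout("dir")` returns an empty list when both `dir/train/` and `dir/val/` exist and every class directory holds at least one file.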

Train a normal teacher network

In this step, train standard teacher transformer models, which are later used to select valuable data from the wild. We train the teacher models with the timm PyTorch library:

timm

Our pretrained teacher models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here:

Pretrained teacher models

Select valuable data from the wild

Then, you can use the Attention Probe method to select valuable data in the wild dataset.

To select valuable data for CIFAR-10 or CIFAR-100, run:

bash training.sh
(CIFAR-10 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist)

(CIFAR-100 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

Once "class_weights.pth, pred_out.pth, value_blk3.pth, value_blk7.pth, value_out.pth" appear in the '/selected/cifar10/' or '/selected/cifar100/' directory, the data selection is complete.
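
A quick way to confirm that the selection step finished is to check for these five files. The sketch below only tests for their presence and makes no assumption about their contents; `missing_selection_files` is a hypothetical helper name.

```python
from pathlib import Path

# File names taken from the sentence above.
EXPECTED_FILES = (
    "class_weights.pth",
    "pred_out.pth",
    "value_blk3.pth",
    "value_blk7.pth",
    "value_out.pth",
)


def missing_selection_files(selected_dir):
    """Return the expected selection outputs that are absent from `selected_dir`."""
    selected_dir = Path(selected_dir)
    return [name for name in EXPECTED_FILES if not (selected_dir / name).is_file()]
```

An empty result from `missing_selection_files("/selected/cifar10/")` means all five files are present.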

Probe Knowledge Distillation for Student networks

Then you can distill the student model using intermediate features of the teacher model based on the selected data.
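
For readers unfamiliar with feature-based distillation, the sketch below shows one generic form of such an objective: a temperature-softened KL term on the logits plus a mean-squared-error term on an intermediate feature. This is an assumption made for illustration and is not the probe distillation loss of the paper; the function names, `temperature`, and `beta` are hypothetical, and the code is written in plain Python so that it runs without PyTorch.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return kl * temperature ** 2


def feature_loss(student_feat, teacher_feat):
    """Mean squared error between two equally sized feature vectors."""
    if len(student_feat) != len(teacher_feat):
        raise ValueError("feature vectors must have the same length")
    return sum((s - t) ** 2 for s, t in zip(student_feat, teacher_feat)) / len(student_feat)


def distillation_loss(student_logits, teacher_logits,
                      student_feat, teacher_feat,
                      temperature=2.0, beta=1.0):
    """Generic objective: logit KD term plus beta times the feature term."""
    return (kd_loss(student_logits, teacher_logits, temperature)
            + beta * feature_loss(student_feat, teacher_feat))


if __name__ == "__main__":
    print(distillation_loss([1.0, 0.5, -0.2], [2.0, 0.1, -1.0],
                            [0.3, 0.7], [0.5, 0.4]))
```

The loss is zero when the student reproduces the teacher's logits and features exactly, and positive otherwise.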

bash training.sh
(CIFAR-10 run: CUDA_VISIBLE_DEVICES=0 python DFND_DeiT-train.py --dataset cifar10 --data_cifar $root_cifar10 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar10 --selected_file $selected_cifar10 --output_dir $output_student_cifar10 --nb_classes 10 --lr_S 7.5e-4 --attnprobe_sel --attnprobe_dist)

(CIFAR-100 run: CUDA_VISIBLE_DEVICES=0,1,2,3 python DFND_DeiT-train.py --dataset cifar100 --data_cifar $root_cifar100 --data_imagenet $root_wild --num_select 650000 --teacher_dir $teacher_cifar100 --selected_file $selected_cifar100 --output_dir $output_student_cifar100 --nb_classes 100 --lr_S 8.5e-4 --attnprobe_sel --attnprobe_dist)

You will find the student transformer model in the '/output/cifar10/student/' or '/output/cifar100/student/' directory.

Our distilled student models (CIFAR-10, CIFAR-100, ImageNet, Tiny-ImageNet, MNIST) can be downloaded from here: Distilled student models

Results

Citation

@inproceedings{
wang2022attention,
title={Attention Probe: Vision Transformer Distillation in the Wild},
author={Jiahao Wang and Mingdeng Cao and Shuwei Shi and Baoyuan Wu and Yujiu Yang},
booktitle={International Conference on Acoustics, Speech and Signal Processing},
year={2022},
url={https://2022.ieeeicassp.org/}
}

Acknowledgement
