(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision are NOT for capturing long-range dependencies. In particular, we address the following three key questions about MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSAs and convolutions (Convs) are complementary because MSAs are low-pass filters and Convs are high-pass filters, and (3) MSAs at the end of a stage significantly improve the accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT long-range dependency 😱. Their weak inductive bias disrupts NN training. On the other hand, ViTs suffer from non-convex losses. MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, but Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise, but MSAs are not. Therefore, MSAs and Convs are complementary.

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consist of a number of CNN blocks and one (or a few) MSA blocks. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for one MSA block.


We also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, which perform poorly on small amounts of data.

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In that paper, we show that a simple (non-trainable) 2 ✕ 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. An MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!
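To make the "2 ✕ 2 box blur" idea concrete, here is a minimal pure-Python sketch that averages each feature-map value with its right, lower, and lower-right neighbors. The function name `box_blur_2x2` and the edge handling (replicating the last row and column) are assumptions for illustration; this is not the code used in either repository.

```python
def box_blur_2x2(feature_map):
    """Average every value with its right, lower, and lower-right neighbors.

    Assumption: the last row/column is replicated at the border so that the
    output keeps the input size.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    blurred = []
    for i in range(height):
        i_next = min(i + 1, height - 1)
        row = []
        for j in range(width):
            j_next = min(j + 1, width - 1)
            total = (feature_map[i][j] + feature_map[i][j_next]
                     + feature_map[i_next][j] + feature_map[i_next][j_next])
            row.append(total / 4.0)
        blurred.append(row)
    return blurred


# A checkerboard is the highest-frequency pattern a grid can hold.
checkerboard = [[1.0, -1.0, 1.0, -1.0],
                [-1.0, 1.0, -1.0, 1.0],
                [1.0, -1.0, 1.0, -1.0],
                [-1.0, 1.0, -1.0, 1.0]]
smooth = box_blur_2x2(checkerboard)
constant = box_blur_2x2([[2.0, 2.0], [2.0, 2.0]])
```

Away from the replicated corner, the checkerboard is averaged to zero, while the constant map passes through unchanged. This is the low-pass behavior that the text attributes to blur filters and, in trainable form, to MSAs.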

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use the Docker image pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
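As a rough illustration of what some of these metrics measure, the following standard-library sketch computes Acc, Acc-90, NLL, and ECE from predicted class probabilities. All function names are hypothetical and are not the repository's API. The sketch assumes that "90% certain" means a top-class confidence of at least 0.9 and that ECE uses 10 equal-width confidence bins.

```python
import math


def predictions(probs):
    """Return (predicted class, confidence) for each probability vector."""
    return [(p.index(max(p)), max(p)) for p in probs]


def accuracy(probs, labels, min_confidence=0.0):
    """Accuracy over predictions whose confidence is at least min_confidence."""
    kept = [pred == y for (pred, conf), y in zip(predictions(probs), labels)
            if conf >= min_confidence]
    return sum(kept) / len(kept) if kept else float("nan")


def negative_log_likelihood(probs, labels):
    """Mean negative log-probability assigned to the true class."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)


def expected_calibration_error(probs, labels, n_bins=10):
    """Weighted mean gap between accuracy and confidence over confidence bins."""
    bins = [[] for _ in range(n_bins)]
    for (pred, conf), y in zip(predictions(probs), labels):
        index = min(int(conf * n_bins), n_bins - 1)
        bins[index].append((conf, pred == y))
    error = 0.0
    for members in bins:
        if members:
            mean_conf = sum(c for c, _ in members) / len(members)
            mean_acc = sum(hit for _, hit in members) / len(members)
            error += len(members) / len(labels) * abs(mean_acc - mean_conf)
    return error


# Toy inputs for illustration only (not results from the paper).
probs = [[0.95, 0.05], [0.65, 0.35], [0.15, 0.85], [0.55, 0.45]]
labels = [0, 1, 1, 0]
acc = accuracy(probs, labels)
acc_90 = accuracy(probs, labels, min_confidence=0.9)
nll = negative_log_likelihood(probs, labels)
ece = expected_calibration_error(probs, labels)
```

Thresholding on confidence is what distinguishes Acc-90 from plain Acc: only the predictions the model is sure about are scored.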

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get the predictive performance of the model at each point of a grid in weight space. We provide a sample loss landscape result.
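The idea of evaluating a model over a weight-space grid can be sketched as follows: perturb the trained weights along two directions and record the loss at every grid point. This is a generic illustration with a toy quadratic loss, assuming flat weight lists and hand-picked directions; the names `loss_grid`, `dir1`, and `dir2` are hypothetical, and the notebook may choose and normalize directions differently.

```python
def loss_grid(loss_fn, weights, dir1, dir2, coords):
    """Evaluate loss_fn at weights + a * dir1 + b * dir2 for all a, b in coords."""
    grid = []
    for a in coords:
        row = []
        for b in coords:
            shifted = [w + a * d1 + b * d2
                       for w, d1, d2 in zip(weights, dir1, dir2)]
            row.append(loss_fn(shifted))
        grid.append(row)
    return grid


def toy_loss(weights):
    """Stand-in for a real model's loss: a simple quadratic bowl."""
    return sum(w * w for w in weights)


grid = loss_grid(toy_loss,
                 weights=[0.0, 0.0],   # pretend these are trained weights
                 dir1=[1.0, 0.0],
                 dir2=[0.0, 1.0],
                 coords=[-1.0, 0.0, 1.0])
```

The center of the grid corresponds to the unperturbed weights, and the resulting table of values can be drawn as a surface or contour plot.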

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluating corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get the predictive performance of the model on datasets that consist of data corrupted by 15 different corruption types with 5 levels of intensity each. We provide a sample robustness result.

How to Apply MSA to Your Own Model

We find that MSA complements Conv rather than replacing it, and that MSAs closer to the end of a stage improve predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.
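One way to read rule 1 is sketched below: within a stage, blocks are replaced starting from the last one, skipping every other block so that Conv and MSA blocks alternate. The stage is represented as a plain list of block-type strings. The function `add_msa_from_end` and this representation are assumptions for illustration, not the repository's model-building code, and rules 2 and 3 (which depend on measured performance and on head/width choices) are not modeled.

```python
def add_msa_from_end(blocks, num_msa):
    """Replace every other block with "msa", starting from the last block."""
    result = list(blocks)
    position = len(result) - 1
    replaced = 0
    while replaced < num_msa and position >= 0:
        result[position] = "msa"
        position -= 2  # skip one block so that Conv and MSA alternate
        replaced += 1
    return result


stage = ["conv"] * 6
one_msa = add_msa_from_end(stage, 1)
three_msa = add_msa_from_end(stage, 3)
print(one_msa)    # ['conv', 'conv', 'conv', 'conv', 'conv', 'msa']
print(three_msa)  # ['conv', 'msa', 'conv', 'msa', 'conv', 'msa']
```

In practice each replacement would be followed by training and evaluation, since rule 2 asks you to move to an earlier stage when an added MSA block does not help.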

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) 'without considering l2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring criteria without l2 on clean datasets would give incorrect (even opposite) results.
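The 'NLL + l2' quantity can be written out as a short sketch. It assumes the common convention in which the penalty is half the weight-decay coefficient times the squared l2 norm of the weights; the function name and argument layout are hypothetical, and the probabilities are expected to come from augmented inputs, as the caution above recommends.

```python
import math


def nll_plus_l2(probs, labels, weights, weight_decay):
    """Mean NLL of the true class plus 0.5 * weight_decay * ||weights||^2.

    Assumption: the 0.5 factor matches the usual SGD weight-decay convention.
    """
    nll = -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)
    l2 = 0.5 * weight_decay * sum(w * w for w in weights)
    return nll + l2


# Toy values for illustration only.
objective = nll_plus_l2(probs=[[0.5, 0.5]], labels=[0],
                        weights=[1.0, 2.0], weight_decay=0.1)
```

Here the NLL term is log 2 and the penalty is 0.5 × 0.1 × 5 = 0.25, so the two terms are of comparable size; dropping the penalty would change the landscape being visualized.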

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the PyTorch Image Models and Vision Transformer - Pytorch which are Apache 2.0 and MIT licensed.

Copyright the maintainers.
