(Preprint) Official PyTorch implementation of "How Do Vision Transformers Work?"

Overview

How Do Vision Transformers Work?

This repository provides a PyTorch implementation of "How Do Vision Transformers Work?" In the paper, we show that multi-head self-attentions (MSAs) for computer vision are NOT for capturing long-range dependencies. In particular, we address the following three key questions about MSAs and Vision Transformers (ViTs):

  1. What properties of MSAs do we need to better optimize NNs? Do the long-range dependencies of MSAs help NNs learn?
  2. Do MSAs act like Convs? If not, how are they different?
  3. How can we harmonize MSAs with Convs? Can we just leverage their advantages?

We demonstrate that (1) MSAs flatten the loss landscapes, (2) MSAs and convolutions (Convs) are complementary because MSAs are low-pass filters whereas Convs are high-pass filters, and (3) MSAs at the end of a stage significantly improve the accuracy.

Let's find the detailed answers below!

I. What Properties of MSAs Do We Need to Improve Optimization?

MSAs improve not only accuracy but also generalization by flattening the loss landscapes. Such improvement is primarily attributable to their data specificity, NOT long-range dependency 😱 Their weak inductive bias disrupts NN training. On the other hand, ViTs suffer from non-convex losses. MSAs allow negative Hessian eigenvalues in small data regimes. Large datasets and loss landscape smoothing methods alleviate this problem.

II. Do MSAs Act Like Convs?

MSAs and Convs exhibit opposite behaviors. For example, MSAs are low-pass filters, but Convs are high-pass filters. In addition, Convs are vulnerable to high-frequency noise, but MSAs are not. Therefore, MSAs and Convs are complementary.
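To make the low-pass versus high-pass contrast concrete, here is a minimal sketch in plain Python. The two filters are toy stand-ins chosen for illustration (a neighbor average and a neighbor difference). They are not the MSA or Conv layers from this repository, and the function names are hypothetical.

```python
def mean_filter(signal):
    """Toy low-pass filter: average each sample with its right neighbor."""
    return [(a + b) / 2 for a, b in zip(signal, signal[1:])]


def diff_filter(signal):
    """Toy high-pass filter: difference between each sample and its right neighbor."""
    return [b - a for a, b in zip(signal, signal[1:])]


# A constant signal contains only the lowest frequency.
smooth = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
# An alternating signal contains only the highest frequency.
noisy = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]

low_smooth = mean_filter(smooth)   # the constant signal passes through unchanged
low_noisy = mean_filter(noisy)     # the alternating signal is averaged away
high_smooth = diff_filter(smooth)  # the constant signal is removed
high_noisy = diff_filter(noisy)    # the alternating signal is amplified
```

The averaging filter keeps the constant signal and cancels the alternating one, while the difference filter does the opposite. This is the sense in which the two kinds of filter are complementary.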

III. How Can We Harmonize MSAs With Convs?

Multi-stage neural networks behave like a series connection of small individual models. In addition, MSAs at the end of a stage play a key role in prediction. Based on these insights, we propose design rules to harmonize MSAs with Convs. NN stages using this design pattern consist of a number of CNN blocks and one (or a few) MSA blocks. The design pattern naturally derives the structure of the canonical Transformer, which has one MLP block for one MSA block.


In addition, we also introduce AlterNet, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. Surprisingly, AlterNet outperforms CNNs not only in large data regimes but also in small data regimes. This contrasts with canonical ViTs, models that perform poorly on small amounts of data.

This repository is based on the official implementation of "Blurs Make Results Clearer: Spatial Smoothings to Improve Accuracy, Uncertainty, and Robustness". In that paper, we show that a simple (non-trainable) 2 ✕ 2 box blur filter improves accuracy, uncertainty, and robustness simultaneously by ensembling spatially nearby feature maps of CNNs. An MSA is not simply a generalized Conv, but rather a generalized (trainable) blur filter that complements Conv. Please check it out!
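As an illustration of what a non-trainable 2 ✕ 2 box blur does to a feature map, the following plain-Python sketch averages every 2 ✕ 2 window. The function name, the stride of 1, and the absence of padding are assumptions made for this example. The actual layer in the referenced implementation may handle borders differently.

```python
def box_blur_2x2(fmap):
    """Average every 2x2 window of a 2D feature map (stride 1, no padding)."""
    rows, cols = len(fmap), len(fmap[0])
    return [
        [
            (fmap[i][j] + fmap[i][j + 1] + fmap[i + 1][j] + fmap[i + 1][j + 1]) / 4
            for j in range(cols - 1)
        ]
        for i in range(rows - 1)
    ]


# A checkerboard-like map: neighboring activations disagree strongly.
fmap = [
    [0.0, 4.0, 0.0],
    [4.0, 0.0, 4.0],
    [0.0, 4.0, 0.0],
]

# Every 2x2 window contains two 0.0 and two 4.0 entries, so each output is 2.0.
blurred = box_blur_2x2(fmap)
```

Each output value is the mean of four spatially adjacent activations, which is the "ensembling of nearby feature map points" idea in its simplest form.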

Getting Started

The following packages are required:

  • pytorch
  • matplotlib
  • notebook
  • ipywidgets
  • timm
  • einops
  • tensorboard
  • seaborn (optional)

We mainly use the Docker image pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime for the code.

See classification.ipynb for image classification. Run all cells to train and test models on CIFAR-10, CIFAR-100, and ImageNet.

Metrics. We provide several metrics for measuring accuracy and uncertainty: Accuracy (Acc, ↑) and Acc for 90% certain results (Acc-90, ↑), negative log-likelihood (NLL, ↓), Expected Calibration Error (ECE, ↓), Intersection-over-Union (IoU, ↑) and IoU for certain results (IoU-90, ↑), Unconfidence (Unc-90, ↑), and Frequency for certain results (Freq-90, ↑). We also define a method to plot a reliability diagram for visualization.
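For readers unfamiliar with these metrics, here is a minimal plain-Python sketch of ECE and Acc-90. It is not the repository's implementation. The function names, the choice of 10 equal-width bins, and the reading of "90% certain" as "confidence of at least 0.9" are assumptions for illustration.

```python
def expected_calibration_error(confidences, corrects, n_bins=10):
    """ECE: bin-size-weighted mean of |accuracy - mean confidence| per bin."""
    total = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, correct in zip(confidences, corrects):
        # Equal-width bins over [0, 1]; a confidence of 1.0 goes into the last bin.
        index = min(int(conf * n_bins), n_bins - 1)
        bins[index].append((conf, correct))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        mean_conf = sum(conf for conf, _ in members) / len(members)
        acc = sum(1 for _, correct in members if correct) / len(members)
        ece += len(members) / total * abs(acc - mean_conf)
    return ece


def accuracy_at_confidence(confidences, corrects, threshold=0.9):
    """Accuracy over predictions whose confidence is at least `threshold`."""
    kept = [correct for conf, correct in zip(confidences, corrects) if conf >= threshold]
    return sum(kept) / len(kept) if kept else float("nan")


# Top-class confidence and correctness for four hypothetical predictions.
confidences = [0.95, 0.95, 0.55, 0.55]
corrects = [True, True, True, False]

ece = expected_calibration_error(confidences, corrects)
acc_90 = accuracy_at_confidence(confidences, corrects)
```

In this toy input each occupied bin is off by 0.05 between accuracy and mean confidence, so the ECE is approximately 0.05, and both confident predictions are correct, so Acc-90 is 1.0.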

Models. We provide AlexNet, VGG, pre-activation VGG, ResNet, pre-activation ResNet, ResNeXt, WideResNet, ViT, PiT, Swin, MLP-Mixer, and Alter-ResNet by default.

Visualizing the Loss Landscapes

Refer to losslandscape.ipynb for exploring the loss landscapes. It requires a trained model. Run all cells to get the predictive performance of the model on a grid of points in weight space. We provide a sample loss landscape result.
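The idea of evaluating a model on a weight-space grid can be sketched as follows. This is a generic illustration of a two-direction slice through weight space, not the notebook's code. `loss_on_grid`, `toy_loss`, and the choice of directions are hypothetical, and a real run would evaluate a trained network instead of a quadratic.

```python
def loss_on_grid(loss_fn, weights, dir_x, dir_y, coords):
    """Evaluate loss_fn at weights + a * dir_x + b * dir_y for every (a, b) pair."""
    grid = {}
    for a in coords:
        for b in coords:
            shifted = [w + a * dx + b * dy for w, dx, dy in zip(weights, dir_x, dir_y)]
            grid[(a, b)] = loss_fn(shifted)
    return grid


def toy_loss(weights):
    """Stand-in for a model's loss: a simple quadratic bowl."""
    return sum(w * w for w in weights)


grid = loss_on_grid(
    toy_loss,
    weights=[0.0, 0.0],         # the "trained" weights (the minimum of the toy loss)
    dir_x=[1.0, 0.0],           # first direction in weight space
    dir_y=[0.0, 1.0],           # second direction in weight space
    coords=[-1.0, 0.0, 1.0],    # grid coordinates along each direction
)
```

Plotting the values in `grid` against `(a, b)` gives a surface whose flatness around `(0.0, 0.0)` can be compared across models.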

Evaluating Robustness on Corrupted Datasets

Refer to robustness.ipynb for evaluating corruption robustness on corrupted datasets such as CIFAR-10-C and CIFAR-100-C. It requires a trained model. Run all cells to get the predictive performance of the model on datasets that consist of data corrupted by 15 different corruption types with 5 levels of intensity each. We provide a sample robustness result.

How to Apply MSA to Your Own Model

We find that MSAs complement Convs (rather than replace them), and that MSAs closer to the end of a stage improve predictive performance significantly. Based on these insights, we propose the following build-up rules:

  1. Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
  2. If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
  3. Use more heads and higher hidden dimensions for MSA blocks in late stages.
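The first rule can be sketched on a list of block names. This is an illustrative reading of "alternately ... from the end", in which the last block is replaced first, then the third from last, and so on. The function `alternate_msa` and the string labels are hypothetical and do not come from this repository. Rules 2 and 3 depend on measured performance and on model hyperparameters, so they are not modeled here.

```python
def alternate_msa(blocks, num_msa):
    """Replace every other block with "msa", starting from the last block."""
    result = list(blocks)  # copy so the baseline stays unchanged
    index = len(result) - 1
    while num_msa > 0 and index >= 0:
        result[index] = "msa"
        index -= 2  # skip one block so that Conv and MSA blocks alternate
        num_msa -= 1
    return result


baseline_stage = ["conv", "conv", "conv", "conv"]

one_msa = alternate_msa(baseline_stage, 1)  # only the last block is replaced
two_msa = alternate_msa(baseline_stage, 2)  # the last and third-from-last blocks are replaced
```

In practice, each candidate returned by such a helper would be trained and evaluated before deciding whether to add another MSA block or move to an earlier stage.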

In the animation above, we replace Convs of ResNet with MSAs one by one according to the build-up rules. Note that several MSAs in c3 harm the accuracy, but the MSA at the end of c2 improves it. As a result, surprisingly, the model with MSAs following the appropriate build-up rule outperforms CNNs even in the small data regime, e.g., CIFAR!

Caution: Investigate Loss Landscapes and Hessians With l2 Regularization on Augmented Datasets

Two common mistakes ⚠️ are investigating loss landscapes and Hessians (1) 'without considering l2 regularization' on (2) 'clean datasets'. However, note that NNs are optimized with l2 regularization on augmented datasets. Therefore, it is appropriate to visualize 'NLL + l2' on 'augmented datasets'. Measuring criteria without l2 on clean datasets would give incorrect (even opposite) results.
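The 'NLL + l2' quantity can be sketched in plain Python as below. The function names are hypothetical, and the factor of 1/2 in the penalty is an assumption (it matches the common convention in which weight decay adds `weight_decay * w` to the gradient). The predicted probabilities are assumed to come from a batch of augmented data, not clean data.

```python
import math


def nll(probs, labels):
    """Mean negative log-likelihood of the true labels."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)


def l2_penalty(weights, weight_decay):
    """(weight_decay / 2) * sum of squared weights; the 1/2 factor is a convention."""
    return 0.5 * weight_decay * sum(w * w for w in weights)


def regularized_loss(probs, labels, weights, weight_decay):
    """The training objective to visualize: NLL plus the l2 term."""
    return nll(probs, labels) + l2_penalty(weights, weight_decay)


# Predicted class probabilities for two (augmented) samples and their labels.
probs = [[0.5, 0.5], [0.25, 0.75]]
labels = [0, 1]
# Flattened model weights and the weight decay used during training.
weights = [1.0, 2.0]
weight_decay = 0.1

plain = nll(probs, labels)
total = regularized_loss(probs, labels, weights, weight_decay)
```

The difference between `total` and `plain` is the l2 term that is lost when a landscape or Hessian is computed from the NLL alone.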

Citation

If you find this useful, please consider citing 📑 the paper and starring 🌟 this repository. Please do not hesitate to contact Namuk Park (email: namuk.park at gmail dot com, twitter: xxxnell) with any comments or feedback.

BibTex is TBD.

License

All code is available to you under Apache License 2.0. CNN models build off the torchvision models which are BSD licensed. ViTs build off the PyTorch Image Models and Vision Transformer - Pytorch which are Apache 2.0 and MIT licensed.

Copyright the maintainers.
