Rethinking Spatial Dimensions of Vision Transformers (2021)

Rethinking Spatial Dimensions of Vision Transformers

Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper

NAVER AI LAB

Abstract

Vision Transformer (ViT) extends the application range of transformers from language processing to computer vision tasks as an alternative architecture to existing convolutional neural networks (CNNs). Because the transformer-based architecture is a recent innovation in computer vision modeling, design conventions for an effective architecture have been little studied so far. Drawing on the successful design principles of CNNs, we investigate the role of spatial dimension conversion and its effectiveness in transformer-based architectures. We particularly attend to the dimension reduction principle of CNNs: as depth increases, a conventional CNN increases the channel dimension and decreases the spatial dimensions. We empirically show that such spatial dimension reduction is beneficial to a transformer architecture as well, and propose a novel Pooling-based Vision Transformer (PiT) built upon the original ViT model. We show that PiT achieves improved model capability and generalization performance compared with ViT. Through extensive experiments, we further show that PiT outperforms the baseline on several tasks such as image classification, object detection, and robustness evaluation.
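
To make the dimension reduction principle concrete, the sketch below traces how a token grid can shrink while the channel width grows from stage to stage. All numbers in it are illustrative assumptions rather than values stated in this README: a 224-pixel input, 16-pixel patches taken with stride 8, a stride-2 pooling step with kernel size 3 and padding 1, and channel widths of 144, 288, and 576. The function names are hypothetical as well.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_shapes(image_size, patch_size, patch_stride, channels, pool_stride=2):
    """Return (grid_side, num_tokens, channels) for each stage.

    Every argument is an illustrative assumption, not a value taken
    from the README.
    """
    # Patch embedding turns the image into a square grid of tokens.
    side = conv_out(image_size, patch_size, patch_stride)
    shapes = []
    for i, ch in enumerate(channels):
        if i > 0:
            # Pooling between stages (assumed: kernel = stride + 1,
            # padding = stride // 2) shrinks the grid.
            side = conv_out(side, pool_stride + 1, pool_stride, pool_stride // 2)
        shapes.append((side, side * side, ch))
    return shapes


if __name__ == "__main__":
    for side, tokens, ch in stage_shapes(224, 16, 8, [144, 288, 576]):
        print(f"{side}x{side} grid, {tokens} tokens, {ch} channels")
```

With these assumed values the grid goes from 27x27 to 14x14 to 7x7, so the token count drops from 729 to 196 to 49 while the channel width doubles at each step.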

Model performance

We compared the performance of PiT with DeiT models under various training settings. Throughput (imgs/sec) is measured on a machine with a single V100 GPU and a batch size of 128. The last four columns report accuracy (%) under each training setting.

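| Network | FLOPs | # params | imgs/sec | Vanilla | +CutMix | +DeiT | +Distill |
|---------|--------|----------|----------|---------|---------|-------|----------|
| DeiT-Ti | 1.3 G  | 5.7 M    | 2564     | 68.7    | 68.5    | 72.2  | 74.5     |
| PiT-Ti  | 0.71 G | 4.9 M    | 3030     | 71.3    | 72.6    | 73.0  | 74.6     |
| PiT-XS  | 1.4 G  | 10.6 M   | 2128     | 72.4    | 76.8    | 78.1  | 79.1     |
| DeiT-S  | 4.6 G  | 22.1 M   | 980      | 68.7    | 76.5    | 79.8  | 81.2     |
| PiT-S   | 2.9 G  | 23.5 M   | 1266     | 73.3    | 79.0    | 80.9  | 81.9     |
| DeiT-B  | 17.6 G | 86.6 M   | 303      | 69.3    | 75.3    | 81.8  | 83.4     |
| PiT-B   | 12.5 G | 73.8 M   | 348      | 76.1    | 79.9    | 82.0  | 84.0     |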
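
A throughput figure of this kind can be approximated with a simple timing loop. The helper below is a minimal sketch, not the script used to produce the numbers above: `measure_throughput` and its parameters are names made up for illustration, and it times an arbitrary callable instead of a specific model. When timing on a GPU, the callable should also synchronize the device (for example with `torch.cuda.synchronize()`) so that queued kernels are included in the measurement.

```python
import time


def measure_throughput(step, batch_size, warmup=5, iters=20):
    """Return images per second for `step`, a zero-argument callable
    that processes one batch of `batch_size` images.

    Hypothetical helper for illustration; not part of the PiT code base.
    """
    for _ in range(warmup):  # warm-up runs are excluded from the timing
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


if __name__ == "__main__":
    # Stand-in workload; replace with a model forward pass on a fixed batch.
    def dummy_step():
        return sum(i * i for i in range(10_000))

    rate = measure_throughput(dummy_step, batch_size=128)
    print(f"{rate:.1f} imgs/sec")
```

Keeping the batch size fixed at 128, as in the setting described above, makes results comparable between models on the same machine.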

Pretrained weights

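| Model name       | FLOPs  | accuracy | weights |
|------------------|--------|----------|---------|
| pit_ti           | 0.71 G | 73.0     | link    |
| pit_xs           | 1.4 G  | 78.1     | link    |
| pit_s            | 2.9 G  | 80.9     | link    |
| pit_b            | 12.5 G | 82.0     | link    |
| pit_ti_distilled | 0.71 G | 74.6     | link    |
| pit_xs_distilled | 1.4 G  | 79.1     | link    |
| pit_s_distilled  | 2.9 G  | 81.9     | link    |
| pit_b_distilled  | 12.5 G | 84.0     | link    |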

Dependencies

Our implementations are tested with the following libraries on Python 3.6.9 and CUDA 10.1.

torch: 1.7.1
torchvision: 0.8.2
timm: 0.3.4
einops: 0.3.0

Install other dependencies using the following command.

pip install -r requirements.txt
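
After installation, it can be useful to compare the installed versions against the tested ones listed above. The following is a minimal sketch using the standard-library `importlib.metadata` module, which requires Python 3.8 or newer (later than the Python 3.6.9 used for testing); `check_versions` is a hypothetical helper name, and the package names are assumed to match their distribution names on PyPI.

```python
from importlib import metadata

# Versions listed in this README as tested.
TESTED = {
    "torch": "1.7.1",
    "torchvision": "0.8.2",
    "timm": "0.3.4",
    "einops": "0.3.0",
}


def check_versions(tested=TESTED):
    """Map each package name to (tested_version, installed_version or None)."""
    report = {}
    for name, wanted in tested.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed
        report[name] = (wanted, installed)
    return report


if __name__ == "__main__":
    for name, (wanted, installed) in check_versions().items():
        # Drop a local build suffix such as "+cu101" before comparing.
        base = installed.split("+")[0] if installed else None
        status = "ok" if base == wanted else "differs"
        print(f"{name}: tested {wanted}, installed {installed} ({status})")
```

A differing version does not necessarily mean the code will fail; it only indicates that the combination has not been tested here.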

How to use models

You can build PiT models directly:

import torch
import pit

model = pit.pit_s(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))

Or create them through timm:

import torch
import timm
import pit

model = timm.create_model('pit_s', pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_809.pth'))
print(model(torch.randn(1, 3, 224, 224)))
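
The snippets above print the raw model output. Assuming that output is a vector of per-class scores (logits) for each image, which this README does not state explicitly, the sketch below shows how such scores can be turned into ranked class probabilities. It is written in plain Python so that it runs without a checkpoint; `softmax` and `top_k` are helper names chosen for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """Return the k most likely (class_index, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda p: p[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    # Made-up scores for a 4-class toy problem.
    for idx, prob in top_k([0.1, 2.0, -1.0, 0.5], k=2):
        print(idx, round(prob, 3))
```

On PyTorch tensors the same computation is available through the tensor methods `softmax` and `topk`. For inference it is also common to call `model.eval()` first so that layers such as dropout behave deterministically.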

To use models trained with distillation, use the _distilled model together with the matching distilled weights.

import torch
import pit

model = pit.pit_s_distilled(pretrained=False)
model.load_state_dict(torch.load('./weights/pit_s_distill_819.pth'))
print(model(torch.randn(1, 3, 224, 224)))
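
Model and weights need to match because `load_state_dict` checks by default that the parameter names on both sides agree and raises an error otherwise. The sketch below illustrates that comparison in plain Python. `diff_keys` is a hypothetical helper, and the parameter names in the demo are invented for illustration; in practice the two key sets would come from `model.state_dict().keys()` and from the dictionary returned by `torch.load`.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names, each sorted.

    missing:    expected by the model but absent from the checkpoint
    unexpected: present in the checkpoint but unknown to the model
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


if __name__ == "__main__":
    # Hypothetical parameter names, for illustration only.
    distilled_model = ["head.weight", "head.bias",
                       "head_dist.weight", "head_dist.bias"]
    plain_checkpoint = ["head.weight", "head.bias"]
    missing, unexpected = diff_keys(distilled_model, plain_checkpoint)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

Running such a check before loading gives a readable list of mismatched names when the wrong checkpoint is paired with a model.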

License

Copyright 2021-present NAVER Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Citation

@article{heo2021pit,
    title={Rethinking Spatial Dimensions of Vision Transformers},
    author={Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    journal={arXiv: 2103.16302},
    year={2021},
}