CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped

Overview

CSWin-Transformer

PWC PWC

This repo is the official implementation of "CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows". The code and models for downstream tasks are coming soon.

Introduction

CSWin Transformer (the name CSWin stands for Cross-Shaped Window) is introduced in arxiv, which is a new general-purpose backbone for computer vision. It is a hierarchical Transformer and replaces the traditional full attention with our newly proposed cross-shaped window self-attention. The cross-shaped window self-attention mechanism computes self-attention in the horizontal and vertical stripes in parallel that from a cross-shaped window, with each stripe obtained by splitting the input feature into stripes of equal width. With CSWin, we could realize global attention with a limited computation cost.

CSWin Transformer achieves strong performance on ImageNet classification (87.5 on val with only 97G flops) and ADE20K semantic segmentation (55.7 mIoU on val), surpassing previous models by a large margin.

teaser

Main Results on ImageNet

model pretrain resolution [email protected] #params FLOPs 22K model 1K model
CSWin-T ImageNet-1K 224x224 82.8 23M 4.3G - model
CSWin-S ImageNet-1k 224x224 83.6 35M 6.9G - model
CSWin-B ImageNet-1k 224x224 84.2 78M 15.0G - model
CSWin-B ImageNet-1k 384x384 85.5 78M 47.0G - model
CSWin-L ImageNet-22k 224x224 86.5 173M 31.5G model model
CSWin-L ImageNet-22k 384x384 87.5 173M 96.8G - model

Main Results on Downstream Tasks

COCO Object Detection

backbone Method pretrain lr Schd box mAP mask mAP #params FLOPS
CSwin-T Mask R-CNN ImageNet-1K 3x 49.0 43.6 42M 279G
CSwin-S Mask R-CNN ImageNet-1K 3x 50.0 44.5 54M 342G
CSwin-B Mask R-CNN ImageNet-1K 3x 50.8 44.9 97M 526G
CSwin-T Cascade Mask R-CNN ImageNet-1K 3x 52.5 45.3 80M 757G
CSwin-S Cascade Mask R-CNN ImageNet-1K 3x 53.7 46.4 92M 820G
CSwin-B Cascade Mask R-CNN ImageNet-1K 3x 53.9 46.4 135M 1004G

ADE20K Semantic Segmentation (val)

Backbone Method pretrain Crop Size Lr Schd mIoU mIoU (ms+flip) #params FLOPs
CSwin-T Semantic FPN ImageNet-1K 512x512 80K 48.2 - 26M 202G
CSwin-S Semantic FPN ImageNet-1K 512x512 80K 49.2 - 39M 271G
CSwin-B Semantic FPN ImageNet-1K 512x512 80K 49.9 - 81M 464G
CSwin-T UPerNet ImageNet-1K 512x512 160K 49.3 50.4 60M 959G
CSwin-S UperNet ImageNet-1K 512x512 160K 50.0 50.8 65M 1027G
CSwin-B UperNet ImageNet-1K 512x512 160K 50.8 51.7 109M 1222G
CSwin-B UPerNet ImageNet-22K 640x640 160K 51.8 52.6 109M 1941G
CSwin-L UperNet ImageNet-22K 640x640 160K 53.4 55.7 208M 2745G

Requirements

timm==0.3.4, pytorch>=1.4, opencv, ... , run:

bash install_req.sh

Apex for mixed precision training is used for finetuning. To install apex, run:

git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Data prepare: ImageNet with the following folder structure, you can extract imagenet by this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......

Train

Train the three lite variants: CSWin-Tiny, CSWin-Small and CSWin-Base:

bash train.sh 8 --data <data path> --model CSWin_64_12211_tiny_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.2
bash train.sh 8 --data <data path> --model CSWin_64_24322_small_224 -b 256 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99984 --drop-path 0.4
bash train.sh 8 --data <data path> --model CSWin_96_24322_base_224 -b 128 --lr 1e-3 --weight-decay .1 --amp --img-size 224 --warmup-epochs 20 --model-ema-decay 0.99992 --drop-path 0.5

If you want to train our CSWin on images with 384x384 resolution, please use '--img-size 384'.

If the GPU memory is not enough, please use '-b 128 --lr 1e-3 --model-ema-decay 0.99992' or use checkpoint '--use-chk'.

Finetune

Finetune CSWin-Base with 384x384 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_96_24322_base_384 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 384 --warmup-epochs 0 --model-ema-decay 0.9998 --finetune <pretrained 224 model> --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1

Finetune ImageNet-22K pretrained CSWin-Large with 224x224 resolution:

bash finetune.sh 8 --data <data path> --model CSWin_144_24322_large_224 -b 64 --lr 2.5e-4 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9996 --finetune <22k-pretrained model> --epochs 30 --mixup 0.01 --cooldown-epochs 10 --interpolation bicubic  --lr-scale 0.05 --drop-path 0.2 --cutmix 0.3 --use-chk --fine-22k --ema-finetune

If the GPU memory is not enough, please use checkpoint '--use-chk'.

Cite CSWin Transformer

@misc{dong2021cswin,
      title={CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows}, 
        author={Xiaoyi Dong and Jianmin Bao and Dongdong Chen and Weiming Zhang and Nenghai Yu and Lu Yuan and Dong Chen and Baining Guo},
        year={2021},
        eprint={2107.00652},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
}

Acknowledgement

This repository is built using the timm library and the DeiT repository.

License

This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

Microsoft Open Source Code of Conduct

Contact Information

For help or issues using CSWin Transformer, please submit a GitHub issue.

For other communications related to CSWin Transformer, please contact Jianmin Bao ([email protected]), Dong Chen ([email protected]).

Comments
  • About the patches_resolution of the segmentation model

    About the patches_resolution of the segmentation model

    Hello, this work is interesting but I have some questions about the 'patches_resolution' of the segmentation model. I notice that the long side of the cross-shaped windows is the 'patches_resolution' rather than the real feature resoulution. For example, in the stage-3, the long side is 224 / 16 = 14. Do I understand it correctly? Does that make it impossible to exchannge information outside the 'patches_resolution' ?

    opened by danczs 2
  • The results of downstream task by my realization are poor

    The results of downstream task by my realization are poor

    There are some questions:

    1. the split size is still [1 2 7 7]?
    2. last stage branch_num is 2 or 1 ? The downstream task image resolution in last stages cannot equal to 7(split size). If not 1, the pretrained weights size is not matched
    3. pading is right in my realization ? pad_l = pad_t = 0 pad_r = (W_sp - W % W_sp) % W_sp pad_b = (H_sp - H % H_sp) % H_sp q = q.transpose(-2,-1).contiguous().view(B, H, W, C) k = q.transpose(-2,-1).contiguous().view(B, H, W, C) v = q.transpose(-2,-1).contiguous().view(B, H, W, C) if pad_r > 0 or pad_b > 0: q = F.pad(q, (0, 0, pad_l, pad_r, pad_t, pad_b)) k = F.pad(k, (0, 0, pad_l, pad_r, pad_t, pad_b)) v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b)) _, Hp, Wp, _ = q.shape
    opened by Sunting78 2
  • Experiment setting for semantic segmentation

    Experiment setting for semantic segmentation

    Hi, thank you for the code. I implemented CSwin-T with FPN for semantic segmentation in ADE20K but couldn't reach the mIoU value of 48.2 as mentioned by you in the table. The maximum I could get was 39.9 mIoU, it will be great if you could share the exact experiment settings you used? Thanks

    opened by AnukritiSinghh 2
  • CSWin significantly slower than Swin?

    CSWin significantly slower than Swin?

    Greetings,

    From my benchmarks I have noticed that CSwin seems to be significantly slower than Swin when it comes to inference times, is this the expected behavior? While I can get predictions as fast as 20 miliseconds on Swin Large 384 it takes above 900 milisecond on CSWin_144_24322_large_384.

    I performed tests using FP16, torchscript, optimize_for_inference and torch.inference_mode

    opened by ErenBalatkan 2
  • Pretrained settings for object detection

    Pretrained settings for object detection

    Hi, I'm impressed by your excellent work.

    I have a question.

    I wonder which type of the pre-trained weights (224x224 or 384x384 finetuned) is used for object detection.

    I know both 224x224 and 384x384 are pre-trained on ImageNet-1k.

    opened by youngwanLEE 2
  • Error about building 384 models

    Error about building 384 models

    The code: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L407 shoud be:

    model = CSWinTransformer(img_size=384, patch_size=4, embed_dim=96, depth=[2,4,32,2],
    

    And as the same: https://github.com/microsoft/CSWin-Transformer/blob/d8be74a7833898f7bd9c77eb8c051d1b8bd5d753/models/cswin.py#L414

    opened by TingquanGao 2
  • The problem is shown in the figure

    The problem is shown in the figure

    model = CSWinTransformer(patch_size=4, embed_dim=96, depth=[2,4,32,2], split_size=[1,2,12,12], num_heads=[4,8,16,32], mlp_ratio=4.).cuda().eval() inp = torch.rand(1, 3, 224, 224).cuda() outs = model(inp) for out in outs: print(out.shape)

    RuntimeError: shape '[1, 192, 1, 14, 1, 12]' is invalid for input of size 37632

    why? image

    opened by rui-cf 2
  • about the test result on imagenet-2012

    about the test result on imagenet-2012

    Hi! I've test the CSwin-Tiny-224 released pretrained weight, this is my data transforms during testing:

    DEFAULT_CROP_SIZE = 0.9
    scale_size = int(math.floor(image_size / DEFAULT_CROP_SIZE))
    transform = transforms.Compose(
            [
                transforms.Resize(scale_size, interpolation=3)  # 3: bibubic
                if image_size == 224
                else transforms.Resize(image_size, interpolation=3),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )
    

    I can only get 80.5% on imagenet2012 dataset which is inconsistent with the results as you mentioned in this repo, did I miss some details about the data-augmentation during testing?

    opened by rentainhe 1
  • CSwin Code for Segmentation with MMSegmentation

    CSwin Code for Segmentation with MMSegmentation

    Hi, Thank you for your work! I wonder if you plan to release the mmsegmentation code you used for the downstream segmentation task, just like in Swin-Transformer repo.

    Best,

    opened by WalBouss 1
  • How do you produce Table 9  (ablation on different attention mecahnisms) in the  paper?

    How do you produce Table 9 (ablation on different attention mecahnisms) in the paper?

    Hi, thanks for your nice work. I'm doing some comparison on different attention mechanisms, and want to follow your experimental settings. I meet two problems:

    1. Why the reported mIoU is 41.9 for Swin-T in Table 9, while it is 46.1 in Swin Paper?
    2. Can you provide detailed experimental settings for semantic segmentation and object detection in table 9 ?
    opened by rayleizhu 0
  • about the setting of --use_chk

    about the setting of --use_chk

    give the parameter of --use_chk can launch the torch.utils.checkpoint to save the GPU memory, and I wonder if this could hurt the final performance, thanks a lot!

    opened by go-ahead-maker 0
  • Input image normalization parameters for semantic segmentation.

    Input image normalization parameters for semantic segmentation.

    Hey, guys, cool work!

    Unfortunately, for me it is not quite obvious, which normalization parameters (mean, std) did you use, when trained semantic segmentation model on ADE20K. Are they still IMAGENET_DEFAULT_MEAN or you used another values?

    opened by NikAleksFed 0
  • Using transfer to train over food101

    Using transfer to train over food101

    Hi all! I'm trying to train a model for food101 using the using the CSWin_64_12211_tiny_224 model with its pretrained values. The thing is, during execution it looks like its training from 0 rather than reusing the pretrained weights. By this I mean the initial top5 accuracy is around 5% but my initial thoughts is that it should be higher than this.

    For this I loaded the pretrained model and changed it's classification layer in a separate script and saved it for use as follows

    model = create_model( 'CSWin_64_12211_tiny_224', pretrained=True, num_classes=1000, drop_rate=0.0, drop_connect_rate=None, # DEPRECATED, use drop_path drop_path_rate=0.2, drop_block_rate=None, global_pool=None, bn_tf=False, bn_momentum=None, bn_eps=None, checkpoint_path='', img_size=224, use_chk=True) chk_path = './pretrained/cswin_tiny_224.pth' load_checkpoint(model, chk_path) model.reset_classifier(101, 'max')

    These are some of the runs I tried ` bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 5e-6 --min-lr 5e-7 --weight-decay 1e-8 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --mixup 0.1 --cooldown-epochs 10 --drop-path 0.7 --ema-finetune --lr-scale 1 --cutmix 0.1 --use-chk --num-classes 101 --pretrained --finetune ./pretrained/CSWin_64_12211_tiny_224101.pth

    bash finetune.sh 1 --data ../food-101 --model CSWin_64_12211_tiny_224 -b 32 --lr 2e-3 --weight-decay .05 --amp --img-size 224 --warmup-epochs 0 --model-ema-decay 0.9998 --epochs 20 --cooldown-epochs 10 --drop-path 0.2 --ema-finetune --cutmix 0.1 --use-chk --num-classes 101 --initial-checkpoint ./pretrained/CSWin_64_12211_tiny_224101.pth --lr-scale 1.0 --output ./full_base `

    Is there something I'm missing or a proper way I should try this?

    Thanks in advance for any help! :)

    opened by andreynz691 0
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
Code repo for realtime multi-person pose estimation in CVPR'17 (Oral)

Realtime Multi-Person Pose Estimation By Zhe Cao, Tomas Simon, Shih-En Wei, Yaser Sheikh. Introduction Code repo for winning 2016 MSCOCO Keypoints Cha

Zhe Cao 4.9k Dec 31, 2022
TResNet: High Performance GPU-Dedicated Architecture

TResNet: High Performance GPU-Dedicated Architecture paperV2 | pretrained models Official PyTorch Implementation Tal Ridnik, Hussam Lawen, Asaf Noy, I

426 Dec 28, 2022
VQMIVC - Vector Quantization and Mutual Information-Based Unsupervised Speech Representation Disentanglement for One-shot Voice Conversion

VQMIVC: Vector Quantization and Mutual Information-Based Unsupervised Speech Representation Disentanglement for One-shot Voice Conversion (Interspeech

Disong Wang 262 Dec 31, 2022
Python KNN model: Predicting a probability of getting a work visa. Tableau: Non-immigrant visas over the years.

The value of international students to the United States. Probability of getting a non-immigrant visa. Project timeline: Jan 2021 - April 2021 Project

Zinaida Dvoskina 2 Nov 21, 2021
Deep Probabilistic Programming Course @ DIKU

Deep Probabilistic Programming Course @ DIKU

52 May 14, 2022
A setup script to generate ITK Python Wheels

ITK Python Package This project provides a setup.py script to build ITK Python binary packages and infrastructure to build ITK external module Python

Insight Software Consortium 59 Dec 14, 2022
A collection of educational notebooks on multi-view geometry and computer vision.

Multiview notebooks This is a collection of educational notebooks on multi-view geometry and computer vision. Subjects covered in these notebooks incl

Max 65 Dec 09, 2022
deep-table implements various state-of-the-art deep learning and self-supervised learning algorithms for tabular data using PyTorch.

deep-table implements various state-of-the-art deep learning and self-supervised learning algorithms for tabular data using PyTorch.

63 Oct 17, 2022
Official code for the publication "HyFactor: Hydrogen-count labelled graph-based defactorization Autoencoder".

HyFactor Graph-based architectures are becoming increasingly popular as a tool for structure generation. Here, we introduce a novel open-source archit

Laboratoire-de-Chemoinformatique 11 Oct 10, 2022
Dashboard for the COVID19 spread

COVID-19 Data Explorer App A streamlit Dashboard for the COVID-19 spread. The app is live at: [https://covid19.cwerner.ai]. New data is queried from G

Christian Werner 22 Sep 29, 2022
FAVD: Featherweight Assisted Vulnerability Discovery

FAVD: Featherweight Assisted Vulnerability Discovery This repository contains the replication package for the paper "Featherweight Assisted Vulnerabil

secureIT 4 Sep 16, 2022
Python with OpenCV - MediaPip Framework Hand Detection

Python HandDetection Python with OpenCV - MediaPip Framework Hand Detection Explore the docs » Contact Me About The Project It is a Computer vision pa

2 Jan 07, 2022
Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation

Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation Woncheol Shin1, Gyubok Lee1, Jiyoung Lee1, Joonseok Lee2,3, Edward Ch

Woncheol Shin 7 Sep 26, 2022
Make a surveillance camera from your raspberry pi!

rpi-surveillance Make a surveillance camera from your Raspberry Pi 4! The surveillance is built as following: the camera records 10 seconds video and

Vladyslav 62 Feb 03, 2022
Interactive Visualization to empower domain experts to align ML model behaviors with their knowledge.

An interactive visualization system designed to helps domain experts responsibly edit Generalized Additive Models (GAMs). For more information, check

InterpretML 83 Jan 04, 2023
Official Pytorch implementation of Meta Internal Learning

Official Pytorch implementation of Meta Internal Learning

10 Aug 24, 2022
MixRNet(Using mixup as regularization and tuning hyper-parameters for ResNets)

MixRNet(Using mixup as regularization and tuning hyper-parameters for ResNets) Using mixup data augmentation as reguliraztion and tuning the hyper par

Bhanu 2 Jan 16, 2022
Repository aimed at compiling code, papers, demos etc.. related to my PhD on 3D vision and machine learning for fruit detection and shape estimation at the university of Lincoln

PhD_3DPerception Repository aimed at compiling code, papers, demos etc.. related to my PhD on 3D vision and machine learning for fruit detection and s

lelouedec 2 Oct 06, 2022
Pytorch implementation for "Adversarial Robustness under Long-Tailed Distribution" (CVPR 2021 Oral)

Adversarial Long-Tail This repository contains the PyTorch implementation of the paper: Adversarial Robustness under Long-Tailed Distribution, CVPR 20

Tong WU 89 Dec 15, 2022
State-of-the-art data augmentation search algorithms in PyTorch

MuarAugment Description MuarAugment is a package providing the easiest way to a state-of-the-art data augmentation pipeline. How to use You can instal

43 Dec 12, 2022