Deformable Attention

Implementation of Deformable Attention in Pytorch, from the paper "Vision Transformer with Deformable Attention". It appears to be an improvement on the deformable attention proposed in Deformable DETR. The relative positional embedding has also been modified for better extrapolation, using the continuous positional embedding proposed in SwinV2.
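
For intuition, here is a minimal single-head sketch of the mechanism the paper describes: a small convolutional network predicts offsets, keys and values are bilinearly resampled at the shifted reference points, and attention runs over the resampled features. The names, and the omission of multi-head logic, key/value downsampling, and positional bias, are ours for illustration, not this library's internals.

import torch
import torch.nn.functional as F
from torch import nn, einsum

dim = 64
x = torch.randn(1, dim, 16, 16)

to_q = nn.Conv2d(dim, dim, 1, bias = False)
to_k = nn.Conv2d(dim, dim, 1, bias = False)
to_v = nn.Conv2d(dim, dim, 1, bias = False)
to_offsets = nn.Conv2d(dim, 2, 7, padding = 3)   # predicts an (x, y) offset per position

b, _, h, w = x.shape
offsets = to_offsets(x).tanh()                   # bounded offsets, here in [-1, 1]

# base reference grid in the normalized [-1, 1] coordinates grid_sample expects
# (meshgrid with indexing = 'ij' needs PyTorch >= 1.10)
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, h),
    torch.linspace(-1, 1, w),
    indexing = 'ij')
grid = torch.stack((xs, ys), dim = -1).expand(b, -1, -1, -1)  # (b, h, w, 2)
grid = grid + offsets.permute(0, 2, 3, 1)        # shift the reference points

deformed = F.grid_sample(x, grid, align_corners = False)      # bilinear resampling

q = to_q(x).flatten(2)          # queries come from the original features
k = to_k(deformed).flatten(2)   # keys and values come from the deformed samples
v = to_v(deformed).flatten(2)

attn = einsum('b d i, b d j -> b i j', q, k).mul(dim ** -0.5).softmax(dim = -1)
out = einsum('b i j, b d j -> b d i', attn, v).reshape(b, dim, h, w)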

Install

$ pip install deformable-attention

Usage

import torch
from deformable_attention import DeformableAttention

attn = DeformableAttention(
    dim = 512,                   # feature dimensions
    dim_head = 64,               # dimension per head
    heads = 8,                   # attention heads
    dropout = 0.,                # dropout
    downsample_factor = 4,       # downsample factor (r in paper)
    offset_scale = 4,            # scale of offset, maximum offset
    offset_groups = None,        # number of offset groups; should evenly divide the number of heads
    offset_kernel_size = 6,      # offset kernel size
)

x = torch.randn(1, 512, 64, 64)
attn(x) # (1, 512, 64, 64)
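
The module maps (batch, dim, height, width) to the same shape, so it drops into a residual block. A hypothetical wrapper (the block and feedforward are ours; we assume the omitted constructor arguments fall back to the defaults shown above):

import torch
from torch import nn
from deformable_attention import DeformableAttention

class DeformableBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = DeformableAttention(
            dim = dim,
            downsample_factor = 4,
            offset_scale = 4,
            offset_kernel_size = 6
        )
        self.ff = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1),
            nn.GELU(),
            nn.Conv2d(dim * 4, dim, 1)
        )

    def forward(self, x):
        x = self.attn(x) + x   # attention keeps the (b, c, h, w) shape
        return self.ff(x) + x

block = DeformableBlock(512)
x = torch.randn(1, 512, 64, 64)
assert block(x).shape == x.shape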

3d deformable attention

import torch
from deformable_attention import DeformableAttention3D

attn = DeformableAttention3D(
    dim = 512,                          # feature dimensions
    dim_head = 64,                      # dimension per head
    heads = 8,                          # attention heads
    dropout = 0.,                       # dropout
    downsample_factor = (2, 8, 8),      # downsample factor (r in paper)
    offset_scale = (2, 8, 8),           # scale of offset, maximum offset
    offset_kernel_size = (4, 10, 10),   # offset kernel size
)

x = torch.randn(1, 512, 10, 32, 32) # (batch, dimension, frames, height, width)
attn(x) # (1, 512, 10, 32, 32)
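
The per-axis downsample factor controls how coarse the sampled key/value grid is. A quick sanity check of the example above (our arithmetic, assuming the key/value grid is the input grid divided by r per axis, as in the paper):

frames, height, width = 10, 32, 32
downsample_factor = (2, 8, 8)

kv_grid = tuple(s // r for s, r in zip((frames, height, width), downsample_factor))
print(kv_grid)                               # (5, 4, 4)
print(kv_grid[0] * kv_grid[1] * kv_grid[2])  # 80 sampled key/value positions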

1d deformable attention for good measure

import torch
from deformable_attention import DeformableAttention1D

attn = DeformableAttention1D(
    dim = 128,
    downsample_factor = 4,
    offset_scale = 2,
    offset_kernel_size = 6
)

x = torch.randn(1, 128, 512)
attn(x) # (1, 128, 512)
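
Note the channels-first layout: all three modules take (batch, dim, ...spatial). If your sequences live in the more common (batch, seq, dim) layout, a transpose on the way in and out suffices (a sketch using the same constructor arguments as above):

import torch
from deformable_attention import DeformableAttention1D

attn = DeformableAttention1D(
    dim = 128,
    downsample_factor = 4,
    offset_scale = 2,
    offset_kernel_size = 6
)

tokens = torch.randn(1, 512, 128)     # (batch, seq, dim)
out = attn(tokens.transpose(1, 2))    # module expects (batch, dim, seq)
out = out.transpose(1, 2)             # back to (batch, seq, dim)
assert out.shape == tokens.shape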

Citation

@misc{xia2022vision,
    title   = {Vision Transformer with Deformable Attention}, 
    author  = {Zhuofan Xia and Xuran Pan and Shiji Song and Li Erran Li and Gao Huang},
    year    = {2022},
    eprint  = {2201.00520},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments
  • The relationship between 'dim' and 'inner_dim'

    Hi, I have a question about the DeformableAttention module.

    I calculated the output volumes of the forward pass step by step, and according to my calculation, the code works only when 'dim' and 'inner_dim' are the same.

    Is there any reason why you implemented it this way?

    Best regards, Hankyu

    opened by hanq0212 4
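
    For context, a sketch of the arithmetic behind the question (assuming the conventional inner_dim = heads * dim_head; the names are ours): in the README example the two happen to coincide, which can make them look coupled.

    dim = 512
    heads = 8
    dim_head = 64

    inner_dim = heads * dim_head   # conventional definition in attention modules
    print(inner_dim == dim)        # True: 8 * 64 == 512, equal here by construction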
  • TypeError: meshgrid() got an unexpected keyword argument 'indexing'

    @lucidrains while trying to run

    import torch
    from deformable_attention import DeformableAttention

    attn = DeformableAttention(
        dim = 512,                   # feature dimensions
        dim_head = 64,               # dimension per head
        heads = 8,                   # attention heads
        dropout = 0.,                # dropout
        downsample_factor = 4,       # downsample factor (r in paper)
        offset_scale = 4,            # scale of offset, maximum offset
        offset_groups = None,        # number of offset groups
        offset_kernel_size = 6,      # offset kernel size
    )

    x = torch.randn(1, 512, 64, 64)
    attn(x)

    I got the error below from this line. Kindly help:

    https://github.com/lucidrains/deformable-attention/blob/9f3c0ae35652ce54687e0db409921018bfca3310/deformable_attention/deformable_attention_2d.py#L26

    opened by ChidanandKumarKS 1
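
    For anyone hitting this: torch.meshgrid only gained the indexing keyword in PyTorch 1.10, so the simplest fix is to upgrade. On older versions, a shim like this (ours, not the library's) restores the intended 'ij' behavior:

    import torch

    def meshgrid_ij(*tensors):
        try:
            return torch.meshgrid(*tensors, indexing = 'ij')
        except TypeError:                     # PyTorch < 1.10 lacks the kwarg;
            return torch.meshgrid(*tensors)   # its legacy behavior already matches 'ij'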