Implementation of a memory efficient multi-head attention as proposed in the paper, "Self-attention Does Not Need O(n²) Memory"

Overview

Memory Efficient Attention Pytorch

Implementation of memory-efficient multi-head attention as proposed in the paper "Self-attention Does Not Need O(n²) Memory". In addition, the module takes care of masking, causal masking, and cross attention.
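The core idea of the paper can be illustrated without any framework. What follows is a minimal, dependency-free sketch in plain Python (one query vector, no batching, no heads), not this library's implementation; the function names `attention_one_query` and `attention_reference` are made up for illustration. Keys and values are processed one bucket at a time, and only a running maximum, a running sum of exponentials, and a running weighted sum of values are kept, so the full row of attention scores never has to be held at once.

```python
import math


def attention_one_query(q, keys, values, k_bucket_size):
    """Softmax attention for one query, streaming over key/value buckets."""
    scale = len(q) ** -0.5
    running_max = float("-inf")            # largest score seen so far
    running_sum = 0.0                      # sum of exp(score - running_max)
    running_out = [0.0] * len(values[0])   # sum of exp(score - running_max) * value

    for start in range(0, len(keys), k_bucket_size):
        k_chunk = keys[start:start + k_bucket_size]
        v_chunk = values[start:start + k_bucket_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_chunk]

        # Re-express the accumulators relative to the new maximum.
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # exp(-inf) == 0.0 on the first bucket
        running_sum *= rescale
        running_out = [o * rescale for o in running_out]

        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            running_out = [o + w * x for o, x in zip(running_out, v)]
        running_max = new_max

    return [o / running_sum for o in running_out]


def attention_reference(q, keys, values):
    """Ordinary softmax attention that materializes every score at once."""
    scale = len(q) ** -0.5
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]
```

Both functions should agree up to floating point error for any bucket size, which is the property that lets the bucket sizes be tuned purely for memory and speed.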

Install

$ pip install memory-efficient-attention-pytorch

Usage

For an autoregressive language model

import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
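To get a feel for what the bucket sizes buy at this sequence length, here is a back-of-envelope calculation. It is a rough sketch under stated assumptions, not a measurement: scores are assumed to be float32, the batch size is 1, and one (query bucket, key bucket) block of scores is assumed to be held for all heads at a time. Activations, checkpointing overhead, and the causal structure are ignored.

```python
import math

seq_len, heads = 65536, 8
q_bucket_size, k_bucket_size = 1024, 2048
bytes_per_score = 4  # assumption: float32 attention scores

# Full attention: one seq_len x seq_len score matrix per head.
full_bytes = seq_len * seq_len * heads * bytes_per_score

# Bucketed attention: one q_bucket_size x k_bucket_size block per head at a time.
block_bytes = q_bucket_size * k_bucket_size * heads * bytes_per_score

# Number of blocks that have to be visited to cover the whole score matrix.
num_blocks = math.ceil(seq_len / q_bucket_size) * math.ceil(seq_len / k_bucket_size)

print(full_bytes / 2**30)   # 128.0 (GiB)
print(block_bytes / 2**20)  # 64.0 (MiB)
print(num_blocks)           # 2048
```

Under these assumptions the score matrix alone would need 128 GiB if materialized, while a single block needs 64 MiB. The price is that 2048 blocks are processed instead of one large matrix multiplication.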

Cross attention

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
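The mask in the example above is all ones, so it has no effect. The sketch below shows one way a boolean mask over the context can be folded into a bucketed softmax: masked positions receive a large negative finite score, so their weight underflows to zero once any unmasked key has been seen. This is a plain Python illustration for a single query, not the library's code; `MASK_VALUE` and both function names are assumptions made for the example, and it assumes at least one key is unmasked.

```python
import math

MASK_VALUE = -1e30  # assumption: finite stand-in for -inf, avoids inf - inf = nan


def masked_attention_one_query(q, keys, values, mask, k_bucket_size):
    """Streaming softmax attention for one query with a boolean key mask."""
    scale = len(q) ** -0.5
    running_max, running_sum = MASK_VALUE, 0.0
    running_out = [0.0] * len(values[0])

    for start in range(0, len(keys), k_bucket_size):
        end = start + k_bucket_size
        scores = [
            scale * sum(a * b for a, b in zip(q, k)) if keep else MASK_VALUE
            for k, keep in zip(keys[start:end], mask[start:end])
        ]
        new_max = max(running_max, max(scores))
        # If everything so far was masked, this factor underflows to 0.0 as soon
        # as a real score arrives, wiping out what the masked keys contributed.
        rescale = math.exp(running_max - new_max)
        running_sum *= rescale
        running_out = [o * rescale for o in running_out]
        for s, v in zip(scores, values[start:end]):
            w = math.exp(s - new_max)
            running_sum += w
            running_out = [o + w * x for o, x in zip(running_out, v)]
        running_max = new_max

    return [o / running_sum for o in running_out]


def masked_attention_reference(q, keys, values, mask):
    """Direct softmax over the unmasked keys only."""
    scale = len(q) ** -0.5
    kept = [(k, v) for k, v, keep in zip(keys, values, mask) if keep]
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k, _ in kept]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, (_, v) in zip(weights, kept)) / total for i in range(dim)]
```

A finite fill value is used rather than negative infinity because a bucket that is entirely masked would otherwise produce `inf - inf` when its scores are shifted by the running maximum.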

Todo

  • benchmark and see how much torch jit helps
  • look at Triton and KeOps and see if either can be a fit

Citations

@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments
  • [feature request] Combining with flash attention?

    There is a new algorithm for optimizing QKV attention: https://github.com/HazyResearch/flash-attention (paper: https://arxiv.org/abs/2205.14135). Maybe you could look into integrating it with this.

    opened by Vbansal21 15
  • i did this, we could build on top

    Hi there!

    It seems I already wrote some of the code: https://github.com/CHARM-Tx/linear_mem_attention_pytorch. Could we build on top of this? I talked to https://github.com/Chillee about an experimental feature in functorch (https://github.com/pytorch/functorch) that would allow for increased speed (mainly I want to match JAX performance, but that is difficult with PyTorch's imperative style).

    I would love to collaborate on this if you want!

    opened by hypnopump 5
  • Added dropout support to memory efficient variant

    Hey Phil,

    I have been using this repository for a project and wanted to add dropout for completeness. I checked consistency with the perceiver-ar implementation. I hope this is helpful.

    -Matt

    opened by usryokousha 2
  • Making this work with relative position bias from XTransformers

    Is there a way to make this work with RelativePositionBias? Currently it produces an attention bias of size $BHN^2$, where B is the batch size, H is the number of heads, and N is the input size. Can this be chunked and computed per chunk?

    opened by pfeatherstone 5
  • save_for_backward can only save variables, but argument 5 is of type bool

    Hi,

    Thank you for your indescribable work. I was trying to test your method specifically for cross-attention, but I get the error "save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong. I tried your own examples too and get the same error.

    Can you please help me out?

    Code:

    import torch
    from memory_efficient_attention_pytorch import Attention

    cross_attn = Attention(
        dim = 512,
        dim_head = 64,
        heads = 8,
        memory_efficient = True,
        q_bucket_size = 1024,
        k_bucket_size = 2048
    ).cuda()

    # out = sm_mod(inp1)
    x = torch.randn(1, 65536, 512).cuda()
    context = torch.randn(1, 65536, 512).cuda()
    # mask = torch.ones(1, 65536).bool().cuda()
    out = cross_attn(x

    ERROR:

    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in <module>
      cli.main()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
      run()
    File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
      runpy.run_path(target_as_str, run_name=compat.force_str("main"))
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path
      return _run_module_code(code, init_globals, run_name,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code
      _run_code(code, mod_globals, init_globals,
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in <module>
      out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out)
    File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
      result = self.forward(*input, **kwargs)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward
      out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
    File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention
      exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
    File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
      return CheckpointFunction.apply(function, preserve, *args)
    TypeError: save_for_backward can only save variables, but argument 5 is of type bool

    opened by aliabid2243 1
  • Checkpointing is not compatible with .grad() or when an `inputs` parameter is passed to .backward()

    https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/35559a05572f9d4eb982a8e2e399b40a2d61b85c/memory_efficient_attention_pytorch/memory_efficient_attention.py#L95

    Should this be:

    summarize_qkv_fn = summarize_qkv_chunk if needs_backwards else checkpointed_summarize_qkv_chunk

    instead of:

    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    opened by vrobot 0
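
On the relative position bias question raised in the comments above: a bias that depends only on the offset between key and query positions can, in principle, be produced one block at a time instead of as a full N × N matrix. The sketch below illustrates this in plain Python for a single head. It is a hypothetical stand-in, not the bucketing scheme used by x-transformers and not part of this library: `relative_bias_block` and the clipped lookup `table` are assumptions made for the example.

```python
def relative_bias_block(table, q_start, q_len, k_start, k_len):
    """Bias values for queries [q_start, q_start + q_len) and keys [k_start, k_start + k_len).

    `table` holds one bias value per clipped relative offset, from -max_dist to +max_dist.
    """
    max_dist = (len(table) - 1) // 2
    block = []
    for i in range(q_start, q_start + q_len):
        row = []
        for j in range(k_start, k_start + k_len):
            rel = max(-max_dist, min(max_dist, j - i))  # clip far-away offsets
            row.append(table[rel + max_dist])
        block.append(row)
    return block


# Usage: a 2 x 3 block for queries 4..5 and keys 0..2, without building the full matrix.
table = [0.1 * t for t in range(7)]  # offsets -3..3
print(relative_bias_block(table, q_start=4, q_len=2, k_start=0, k_len=3))
```

Each block only needs `q_len * k_len` values, so the bias for a (query bucket, key bucket) pair could be generated right before that pair's scores are computed and discarded afterwards.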
Releases(0.1.1)
Owner
Phil Wang
Working with Attention. It's all we need