Reproducing the Linear Multihead Attention introduced in Linformer paper (Linformer: Self-Attention with Linear Complexity)

Overview

Linear Multihead Attention (Linformer)

PyTorch Implementation of reproducing the Linear Multihead Attention introduced in Linformer paper (Linformer: Self-Attention with Linear Complexity), which demonstrates that the self-attention mechanism can be approximated by a low-rank matrix and reduces the overall self-attention complexity from O(n^2) to O(n) in both time and space.

Implementation

This is an efficient implementation followed with the PyTorch official torch.nn.MultiheadAttention class and F.multi_head_attention_forward function.

Three additional argments defined in LinearMultiheadAttention: sequence length, the projected dimention k and the parameter sharing.

seq_len: the sequence length. Default: 100.
proj_k: the projected dimention `k` in Linformer paper. Default: 128.
param_sharing: parameter sharing mode: layerwise, none. headwise is not implemented. Default: none.

Usage

Examples of using torch.nn.MultiheadAttention:

>>> import torch
>>> multihead_attn = torch.nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

Examples of using LinearMultiheadAttention:

>>> from linear_multihead_attention import LinearMultiheadAttention
>>> multihead_attn = LinearMultiheadAttention(embed_dim, num_heads) 
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

Examples of using LinearMultiheadAttention with the sequence length of 512 and :

>>> from linear_multihead_attention import LinearMultiheadAttention
>>> multihead_attn = LinearMultiheadAttention(embed_dim, num_heads, seq_len=512, proj_k=256, param_sharing='layerwise') 
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)

Linear-DETR: Replace torch.nn.MultiheadAttention in DETR with LinearMultiheadAttention in three lines in models/transformer.py, it saved much more memory and space, hope to have a comparable performance:

from linear_multihead_attention import LinearMultiheadAttention

# TransformerEncoderLayer
# self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, seq_len=w*h, proj_k=64) # where w, h are from `bs, c, h, w = src.shape`


# TransformerDecoderLayer
# self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

self.self_attn = LinearMultiheadAttention(d_model, nhead, dropout=dropout, seq_len=num_queries, proj_k=64) # where num_queries = args.num_queries
self.multihead_attn = LinearMultiheadAttention(d_model, nhead, dropout=dropout, seq_len=w*h, proj_k=64) # where w, h are from `bs, c, h, w = src.shape`

Results on DETR

TODO

Citation

@misc{wang2020linformer,
    title={Linformer: Self-Attention with Linear Complexity},
    author={Sinong Wang and Belinda Z. Li and Madian Khabsa and Han Fang and Hao Ma},
    year={2020},
    eprint={2006.04768},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Owner
Kui Xu
Researcher, interested in Computational Biology, and 3D Computer Vision.
Kui Xu
🍊 PAUSE (Positive and Annealed Unlabeled Sentence Embedding), accepted by EMNLP'2021 🌴

PAUSE: Positive and Annealed Unlabeled Sentence Embedding Sentence embedding refers to a set of effective and versatile techniques for converting raw

EQT 21 Dec 15, 2022
This is a MD5 password/passphrase brute force tool

CROWES-PASS-CRACK-TOOl This is a MD5 password/passphrase brute force tool How to install: Do 'git clone https://github.com/CROW31/CROWES-PASS-CRACK-TO

9 Mar 02, 2022
Code for the project carried out fulfilling the course requirements for Fall 2021 NLP at NYU

Introduction Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization,

Sai Himal Allu 1 Apr 25, 2022
Contains links to publicly available datasets for modeling health outcomes using speech and language.

speech-nlp-datasets Contains links to publicly available datasets for modeling various health outcomes using speech and language. Speech-based Corpora

Tuka Alhanai 77 Dec 07, 2022
Autoregressive Entity Retrieval

The GENRE (Generative ENtity REtrieval) system as presented in Autoregressive Entity Retrieval implemented in pytorch. @inproceedings{decao2020autoreg

Meta Research 611 Dec 16, 2022
Wrapper to display a script output or a text file content on the desktop in sway or other wlroots-based compositors

nwg-wrapper This program is a part of the nwg-shell project. This program is a GTK3-based wrapper to display a script output, or a text file content o

Piotr Miller 94 Dec 27, 2022
Guide to using pre-trained large language models of source code

Large Models of Source Code I occasionally train and publicly release large neural language models on programs, including PolyCoder. Here, I describe

Vincent Hellendoorn 947 Dec 28, 2022
GrammarTagger — A Neural Multilingual Grammar Profiler for Language Learning

GrammarTagger — A Neural Multilingual Grammar Profiler for Language Learning GrammarTagger is an open-source toolkit for grammatical profiling for lan

Octanove Labs 27 Jan 05, 2023
A Domain Specific Language (DSL) for building language patterns. These can be later compiled into spaCy patterns, pure regex, or any other format

RITA DSL This is a language, loosely based on language Apache UIMA RUTA, focused on writing manual language rules, which compiles into either spaCy co

Šarūnas Navickas 60 Sep 26, 2022
端到端的长本文摘要模型(法研杯2020司法摘要赛道)

端到端的长文本摘要模型(法研杯2020司法摘要赛道)

苏剑林(Jianlin Su) 334 Jan 08, 2023
Model parallel transformers in JAX and Haiku

Table of contents Mesh Transformer JAX Updates Pretrained Models GPT-J-6B Links Acknowledgments License Model Details Zero-Shot Evaluations Architectu

Ben Wang 4.9k Jan 04, 2023
Weaviate demo with the text2vec-openai module

Weaviate demo with the text2vec-openai module This repository contains an example of how to use the Weaviate text2vec-openai module. When using this d

SeMI Technologies 11 Nov 11, 2022
Longformer: The Long-Document Transformer

Longformer Longformer and LongformerEncoderDecoder (LED) are pretrained transformer models for long documents. ***** New December 1st, 2020: Longforme

AI2 1.6k Dec 29, 2022
基于Transformer的单模型、多尺度的VAE模型

UniVAE 基于Transformer的单模型、多尺度的VAE模型 介绍 https://kexue.fm/archives/8475 依赖 需要大于0.10.6版本的bert4keras(当前还没有推到pypi上,可以直接从GitHub上clone最新版)。 引用 @misc{univae,

苏剑林(Jianlin Su) 49 Aug 24, 2022
Index different CKAN entities in Solr, not just datasets

ckanext-sitesearch Index different CKAN entities in Solr, not just datasets Requirements This extension requires CKAN 2.9 or higher and Python 3 Featu

Open Knowledge Foundation 3 Dec 02, 2022
Stuff related to Ben Eater's 8bit breadboard computer

8bit breadboard computer simulator This is an assembler + simulator/emulator of Ben Eater's 8bit breadboard computer. For a version with its RAM upgra

Marijn van Vliet 29 Dec 29, 2022
Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch

COCO LM Pretraining (wip) Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were a

Phil Wang 44 Jul 28, 2022
Repository to hold code for the cap-bot varient that is being presented at the SIIC Defence Hackathon 2021.

capbot-siic Repository to hold code for the cap-bot varient that is being presented at the SIIC Defence Hackathon 2021. Problem Inspiration A plethora

Aryan Kargwal 19 Feb 17, 2022
Japanese synonym library

chikkarpy chikkarpyはchikkarのPython版です。 chikkarpy is a Python version of chikkar. chikkarpy は Sudachi 同義語辞書を利用し、SudachiPyの出力に同義語展開を追加するために開発されたライブラリです。

Works Applications 48 Dec 14, 2022
Python library for interactive topic model visualization. Port of the R LDAvis package.

pyLDAvis Python library for interactive topic model visualization. This is a port of the fabulous R package by Carson Sievert and Kenny Shirley. pyLDA

Ben Mabey 1.7k Dec 20, 2022