Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch

Overview

Omninet - Pytorch

Implementation of OmniNet, Omnidirectional Representations from Transformers, in Pytorch. The authors propose that we should be attending to all the tokens of the previous layers, leveraging recent efficient attention advances to achieve this goal.

Install

$ pip install omninet-pytorch

Usage

import torch
from omninet_pytorch import Omninet

omninet = Omninet(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1,              # feedforward dropout
    feature_redraw_interval = 1000 # how often to redraw the projection matrix for omni attention net - Performer
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Causal case, just use the class OmninetCausal. At the moment, it isn't faithful to the paper (I am using layer axial attention with layer positional embeddings to draw up information), but will fix this once I rework the linear attention CUDA kernel.

import torch
from omninet_pytorch import OmninetCausal

omninet = OmninetCausal(
    dim = 512,                     # model dimension
    depth = 6,                     # depth
    dim_head = 64,                 # dimension per head
    heads = 8,                     # number of heads
    pool_layer_tokens_every = 3,   # key to this paper - every N layers, omni attend to all tokens of all layers
    attn_dropout = 0.1,            # attention dropout
    ff_dropout = 0.1               # feedforward dropout
)

x = torch.randn(1, 1024, 512)
mask = torch.ones(1, 1024).bool()

omninet(x, mask = mask) # (1, 1024, 512)

Citations

@misc{tay2021omninet,
    title   = {OmniNet: Omnidirectional Representations from Transformers}, 
    author  = {Yi Tay and Mostafa Dehghani and Vamsi Aribandi and Jai Gupta and Philip Pham and Zhen Qin and Dara Bahri and Da-Cheng Juan and Donald Metzler},
    year    = {2021},
    eprint  = {2103.01075},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
PyTorch implementation for paper StARformer: Transformer with State-Action-Reward Representations.
PyTorch implementation for paper StARformer: Transformer with State-Action-Reward Representations.

StARformer This repository contains the PyTorch implementation for our paper titled StARformer: Transformer with State-Action-Reward Representations.

Official PyTorch implementation of BlobGAN: Spatially Disentangled Scene Representations

BlobGAN: Spatially Disentangled Scene Representations Official PyTorch Implementation Paper | Project Page | Video | Interactive Demo BlobGAN.mp4 This

ChatBot-Pytorch - A GPT-2 ChatBot implemented using Pytorch and Huggingface-transformers

ChatBot-Pytorch A GPT-2 ChatBot implemented using Pytorch and Huggingface-transf

Here is the implementation of our paper S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations.
Here is the implementation of our paper S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations.

S2VC Here is the implementation of our paper S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations. In thi

[ICCV'21] Official implementation for the paper  Social NCE: Contrastive Learning of Socially-aware Motion Representations
[ICCV'21] Official implementation for the paper Social NCE: Contrastive Learning of Socially-aware Motion Representations

CrowdNav with Social-NCE This is an official implementation for the paper Social NCE: Contrastive Learning of Socially-aware Motion Representations by

Tensorflow 2 implementation of the paper: Learning and Evaluating Representations for Deep One-class Classification published at ICLR 2021

Deep Representation One-class Classification (DROC). This is not an officially supported Google product. Tensorflow 2 implementation of the paper: Lea

Public Implementation of ChIRo from
Public Implementation of ChIRo from "Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations"

Learning 3D Representations of Molecular Chirality with Invariance to Bond Rotations This directory contains the model architectures and experimental

Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP

Wav2CLIP 🚧 WIP 🚧 Official implementation of the paper WAV2CLIP: LEARNING ROBUST AUDIO REPRESENTATIONS FROM CLIP πŸ“„ πŸ”— Ho-Hsiang Wu, Prem Seetharaman

Implementation of the method proposed in the paper
Implementation of the method proposed in the paper "Neural Descriptor Fields: SE(3)-Equivariant Object Representations for Manipulation"

Neural Descriptor Fields (NDF) PyTorch implementation for training continuous 3D neural fields to represent dense correspondence across objects, and u

Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
[PyTorch] Official implementation of CVPR2021 paper "PointDSC: Robust Point Cloud Registration using Deep Spatial Consistency". https://arxiv.org/abs/2103.05465

PointDSC repository PyTorch implementation of PointDSC for CVPR'2021 paper "PointDSC: Robust Point Cloud Registration using Deep Spatial Consistency",

153 Dec 14, 2022
ThunderSVM: A Fast SVM Library on GPUs and CPUs

What's new We have recently released ThunderGBM, a fast GBDT and Random Forest library on GPUs. add scikit-learn interface, see here Overview The miss

Xtra Computing Group 1.4k Dec 22, 2022
This repo holds codes of the ICCV21 paper: Visual Alignment Constraint for Continuous Sign Language Recognition.

VAC_CSLR This repo holds codes of the paper: Visual Alignment Constraint for Continuous Sign Language Recognition.(ICCV 2021) [paper] Prerequisites Th

Yuecong Min 64 Dec 19, 2022
Second Order Optimization and Curvature Estimation with K-FAC in JAX.

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX Installation | Quickstart | Documentation | Examples | Citing KFAC-JAX KFAC-JAX

DeepMind 90 Dec 22, 2022
The VeriNet toolkit for verification of neural networks

VeriNet The VeriNet toolkit is a state-of-the-art sound and complete symbolic interval propagation based toolkit for verification of neural networks.

9 Dec 21, 2022
Pytorch Lightning code guideline for conferences

Deep learning project seed Use this seed to start new deep learning / ML projects. Built in setup.py Built in requirements Examples with MNIST Badges

Pytorch Lightning 1k Jan 02, 2023
Real-time Joint Semantic Reasoning for Autonomous Driving

MultiNet MultiNet is able to jointly perform road segmentation, car detection and street classification. The model achieves real-time speed and state-

Marvin Teichmann 518 Dec 12, 2022
DeepConsensus uses gap-aware sequence transformers to correct errors in Pacific Biosciences (PacBio) Circular Consensus Sequencing (CCS) data.

DeepConsensus DeepConsensus uses gap-aware sequence transformers to correct errors in Pacific Biosciences (PacBio) Circular Consensus Sequencing (CCS)

Google 149 Dec 19, 2022
Qlib is an AI-oriented quantitative investment platform

Qlib is an AI-oriented quantitative investment platform, which aims to realize the potential, empower the research, and create the value of AI technologies in quantitative investment.

Microsoft 10.1k Dec 30, 2022
OptNet: Differentiable Optimization as a Layer in Neural Networks

OptNet: Differentiable Optimization as a Layer in Neural Networks This repository is by Brandon Amos and J. Zico Kolter and contains the PyTorch sourc

CMU Locus Lab 428 Dec 24, 2022
Project code for weakly supervised 3D object detectors using wide-baseline multi-view traffic camera data: WIBAM.

WIBAM (Work in progress) Weakly Supervised Training of Monocular 3D Object Detectors Using Wide Baseline Multi-view Traffic Camera Data 3D object dete

Matthew Howe 10 Aug 24, 2022
Easy to use Python camera interface for NVIDIA Jetson

JetCam JetCam is an easy to use Python camera interface for NVIDIA Jetson. Works with various USB and CSI cameras using Jetson's Accelerated GStreamer

NVIDIA AI IOT 358 Jan 02, 2023
Official Implementation of PCT

Official Implementation of PCT Prerequisites python == 3.8.5 Please make sure you have the following libraries installed: numpy torch=1.4.0 torchvisi

32 Nov 21, 2022
Official implementation of "Watermarking Images in Self-Supervised Latent-Spaces"

πŸ” Watermarking Images in Self-Supervised Latent-Spaces PyTorch implementation and pretrained models for the paper. For details, see Watermarking Imag

Meta Research 32 Dec 13, 2022
Management Dashboard for Torchserve

Torchserve Dashboard Torchserve Dashboard using Streamlit Related blog post Usage Additional Requirement: torchserve (recommended:v0.5.2) Simply run:

Ceyda Cinarel 103 Dec 10, 2022
Generating Videos with Scene Dynamics

Generating Videos with Scene Dynamics This repository contains an implementation of Generating Videos with Scene Dynamics by Carl Vondrick, Hamed Pirs

Carl Vondrick 706 Jan 04, 2023
Jihye Back 520 Jan 04, 2023
Learning-Augmented Dynamic Power Management

Learning-Augmented Dynamic Power Management This repository contains source code accompanying paper Learning-Augmented Dynamic Power Management with M

Adam 0 Feb 22, 2022
Offical implementation of Shunted Self-Attention via Multi-Scale Token Aggregation

Shunted Transformer This is the offical implementation of Shunted Self-Attention via Multi-Scale Token Aggregation by Sucheng Ren, Daquan Zhou, Shengf

156 Dec 27, 2022
βšΎπŸ€–βšΎ Automatic baseball pitching overlay in realtime

⚾ Automatically overlaying pitch motion and trajectory with machine learning! This project takes your baseball pitching clips and automatically genera

Tony Chou 240 Dec 05, 2022