Implementation of Perceiver, General Perception with Iterative Attention in TensorFlow

Overview

This Python package implements Perceiver: General Perception with Iterative Attention by Andrew Jaegle et al. in TensorFlow. The model builds on Transformers, but the data only enters through a cross-attention mechanism (see figure), which allows it to scale to hundreds of thousands of inputs, like ConvNets. This also partly addresses the quadratic compute and memory bottleneck of Transformers.

Yannic Kilcher's video was very helpful.

Installation

Run the following to install:

pip install perceiver

Developing perceiver

To install perceiver along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Perceiver.git
# or clone your own fork

cd Perceiver
pip install -e ".[dev]"

A bit about Perceiver

The Perceiver model aims to handle arbitrary configurations of different modalities using a single Transformer-based architecture. Transformers are flexible and make few assumptions about their inputs, but they also scale quadratically with the number of inputs in both memory and computation. The Perceiver proposes a mechanism that makes it possible to deal with high-dimensional inputs while retaining the expressivity and flexibility to handle arbitrary input configurations.
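To make the scaling argument concrete, the following back-of-the-envelope sketch counts the entries of a single attention score matrix (one head, one layer) for all-to-all self-attention over the inputs versus attention through a small set of latents. The sizes are assumptions chosen for illustration: one token per pixel of a 224 x 224 image and 256 latents, matching the values in the Usage section.

```python
def attention_score_entries(num_queries: int, num_keys: int) -> int:
    """Number of entries in one attention score matrix (queries x keys)."""
    return num_queries * num_keys


# Assumed example sizes: one token per pixel of a 224 x 224 image,
# and 256 latent units.
num_inputs = 224 * 224
num_latents = 256

self_attention = attention_score_entries(num_inputs, num_inputs)
cross_attention = attention_score_entries(num_latents, num_inputs)
latent_self_attention = attention_score_entries(num_latents, num_latents)

print(f"all-to-all self-attention: {self_attention:,} entries")
print(f"latent cross-attention:    {cross_attention:,} entries")
print(f"latent self-attention:     {latent_self_attention:,} entries")
print(f"reduction factor:          {self_attention // cross_attention}x")
```

With these sizes the cross-attention score matrix is 196 times smaller than the all-to-all one, and the cost grows linearly rather than quadratically as the number of inputs increases.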

The idea here is to introduce a small set of latent units that forms an attention bottleneck through which the inputs must pass. This avoids the quadratic scaling of the all-to-all attention in a classical Transformer. The model can be seen as performing a fully end-to-end clustering of the inputs, with the latent units as the cluster centres, leveraging a highly asymmetric cross-attention layer. For spatial information, the authors compensate for the lack of explicit grid structure in the model by associating Fourier feature position encodings with the inputs.
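The bottleneck can be illustrated with a minimal single-head cross-attention sketch in NumPy, where the queries come from the latents and the keys and values come from the inputs. This is not the package's implementation: the function names, the random weights, and all sizes below are hypothetical and only meant to show the shapes involved.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the maximum for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: the latents query the inputs."""
    q = latents @ w_q                          # (num_latents, d)
    k = inputs @ w_k                           # (num_inputs, d)
    v = inputs @ w_v                           # (num_inputs, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (num_latents, num_inputs)
    weights = softmax(scores, axis=-1)         # each latent's weights sum to 1
    return weights @ v, weights


rng = np.random.default_rng(0)

# Hypothetical sizes: many inputs, few latents.
num_inputs, input_dim = 1024, 29
num_latents, latent_dim, d = 16, 32, 8

inputs = rng.normal(size=(num_inputs, input_dim))
latents = rng.normal(size=(num_latents, latent_dim))
w_q = rng.normal(size=(latent_dim, d)) * 0.1
w_k = rng.normal(size=(input_dim, d)) * 0.1
w_v = rng.normal(size=(input_dim, d)) * 0.1

out, weights = cross_attend(latents, inputs, w_q, w_k, w_v)
print(out.shape)      # (16, 8)
print(weights.shape)  # (16, 1024)
```

The score matrix has shape (num_latents, num_inputs), so doubling the number of inputs doubles its size instead of quadrupling it. Each row of weights is a soft assignment of inputs to one latent, which is the sense in which the latents act like cluster centres.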

Usage

from perceiver import Perceiver
import tensorflow as tf

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes of the input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,
    latent_dim_head = 64,
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
)

img = tf.random.normal([1, 224, 224, 3]) # replicating 1 imagenet image
model(img) # (1, 1000)
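The num_freq_bands and max_freq arguments control the Fourier position features attached to each input element. The NumPy sketch below shows one way such features could be computed, assuming positions scaled to [-1, 1], K frequencies spaced linearly from 1 to max_freq / 2, and sin, cos and raw-position features per axis (2 * K + 1 values). The function fourier_features is hypothetical, and the package's exact frequency spacing and layout may differ.

```python
import numpy as np


def fourier_features(pos, max_freq, num_bands):
    """Encode positions in [-1, 1] as [sin, cos, raw position] features.

    pos:     array of shape (..., num_axes)
    returns: array of shape (..., num_axes * (2 * num_bands + 1))
    """
    # Assumed spacing: num_bands frequencies, linear from 1 to max_freq / 2.
    freqs = np.linspace(1.0, max_freq / 2.0, num_bands)
    angles = pos[..., None] * freqs * np.pi    # (..., num_axes, num_bands)
    feats = np.concatenate(
        [np.sin(angles), np.cos(angles), pos[..., None]], axis=-1
    )                                          # (..., num_axes, 2 * num_bands + 1)
    return feats.reshape(*pos.shape[:-1], -1)


height, width, channels = 224, 224, 3
num_freq_bands, max_freq = 6, 10.0

# Pixel coordinates scaled to [-1, 1] along each of the two image axes.
ys = np.linspace(-1.0, 1.0, height)
xs = np.linspace(-1.0, 1.0, width)
grid = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)   # (224, 224, 2)

pos_enc = fourier_features(grid, max_freq, num_freq_bands)

# Concatenate the position features to the pixel channels, then flatten
# the spatial grid into a sequence of tokens.
image = np.zeros((height, width, channels))
tokens = np.concatenate([image, pos_enc], axis=-1)
tokens = tokens.reshape(-1, tokens.shape[-1])

print(pos_enc.shape)  # (224, 224, 26)
print(tokens.shape)   # (50176, 29)
```

Under these assumptions, each pixel becomes a token with 3 colour channels plus 2 * (2 * 6 + 1) = 26 position features, and the image becomes a flat sequence of 50,176 tokens for the cross-attention to read.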

About the notebooks

perceiver_example

This notebook installs the perceiver package and shows an example of running it on a single ImageNet image ([1, 224, 224, 3]) with 1000 classes to demonstrate how the model works.

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues for getting more information about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome. You can start discussions.

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • error with tf2.4.1

    Hello Rishit,

    thank you for your Perceiver implementation! I have two notes, though I am not very familiar with tf2. You define and call a tf.keras.Sequential model here: https://github.com/Rishit-dagli/Perceiver/blob/4d3b9b0514da4fb623d178e3e70df1836ebad5ba/perceiver/perceiver.py#L106. For my version of tf at least, this throws an error; I think it should be defined once in __init__ and then just called in call.

    And just above it, you compute data but then you don't pass it to self.model. Is that correct?

    bug 
    opened by abred 3
  • Training code

    Hi there,

    I've tried to set up a standard MNIST training over the last few days using the Perceiver code provided here. So far, I've not been able to come up with any solution where the model actually learns anything. A major problem so far has been the way the model is written with no support for model.fit() and the whole functional API.

    Do you happen to have any training example code for your model which you could provide here in this repo? MNIST as the default starting point would be nice, but anything would do the job as well :)

    question 
    opened by tpetri94 2
  • Create a FeedForward layer

    Create a simple FeedForward layer as a tf.keras.layers.Layer, which should essentially contain a Dense layer with the modified GELU activation (#2). Optionally, I could also include a dropout layer and another Dense layer whose number of neurons equals the dimension

    opened by Rishit-dagli 0
  • Implement a PreNorm layer

    Create a normalization layer as a tf.keras.layers.Layer. This should essentially figure out the right axis and apply layer normalization on it.

    opened by Rishit-dagli 0
  • Don't pin TensorFlow version to a specific number

    Hello,

    In setup.py you should change "tensorflow~=2.4.0" to "tensorflow>2.4.0" to ensure any version above the minimal one is used.

    bug 
    opened by ebursztein 0
Releases(v0.1.2)
Owner
Rishit Dagli
High School,TEDx,2xTED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai