My implementation of DeepMind's Perceiver

Overview

DeepMind Perceiver (in PyTorch)

Disclaimer: This is not official and I'm not affiliated with DeepMind.

My implementation of the Perceiver: General Perception with Iterative Attention. You can read more about the model on DeepMind's website.

I trained an MNIST model which you can find in models/mnist.pkl or by using perceiver.load_mnist_model(). It gets 96.02% on the test-data.

Getting started

To run this you need PyTorch installed:

pip3 install torch

From perceiver you can import Perceiver or PerceiverLogits.

Then you can use it as such (or look in examples.ipynb):

from perceiver import Perceiver

model = Perceiver(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How much latent self-attention for each cross attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)

The above model outputs the latents after the final layer. If you want logits instead, use the following model:

from perceiver import PerceiverLogits

model = PerceiverLogits(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    output_features, # <- How many different classes? E.g. 10 for MNIST.
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How much latent self-attention for each cross attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)

To use my pre-trained MNIST model (not very good):

from perceiver import load_mnist_model

model = load_mnist_model()

TODO:

  • Positional embedding generalized to n dimensions (with fourier features)
  • Train other models (like CIFAR-100 or something not in the image domain)
  • Type indication
  • Unit tests for components of model
  • Package
Owner
Louis Arge
Experienced full-stack developer. Self-studying machine learning.
Louis Arge
App customer segmentation cohort rfm clustering

CUSTOMER SEGMENTATION COHORT RFM CLUSTERING TỔNG QUAN VỀ HỆ THỐNG DỮ LIỆU Nên chuyển qua theme màu dark thì sẽ nhìn đẹp hơn https://customer-segmentat

hieulmsc 3 Dec 18, 2021
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

61.4k Jan 04, 2023
Capstone-Project-2 - A game program written in the Python language

Capstone-Project-2 My Pygame Game Information: Description This Pygame project i

Nhlakanipho Khulekani Hlophe 1 Jan 04, 2022
source code for 'Finding Valid Adjustments under Non-ignorability with Minimal DAG Knowledge' by A. Shah, K. Shanmugam, K. Ahuja

Source code for "Finding Valid Adjustments under Non-ignorability with Minimal DAG Knowledge" Reference: Abhin Shah, Karthikeyan Shanmugam, Kartik Ahu

Abhin Shah 1 Jun 03, 2022
Julia package for contraction of tensor networks, based on the sweep line algorithm outlined in the paper General tensor network decoding of 2D Pauli codes

Julia package for contraction of tensor networks, based on the sweep line algorithm outlined in the paper General tensor network decoding of 2D Pauli codes

Christopher T. Chubb 35 Dec 21, 2022
Code for "Unsupervised Source Separation via Bayesian inference in the latent domain"

LQVAE-separation Code for "Unsupervised Source Separation via Bayesian inference in the latent domain" Paper Samples GT Compressed Separated Drums GT

Michele Mancusi 30 Oct 25, 2022
[SIGGRAPH 2020] Attribute2Font: Creating Fonts You Want From Attributes

Attr2Font Introduction This is the official PyTorch implementation of the Attribute2Font: Creating Fonts You Want From Attributes. Paper: arXiv | Rese

Yue Gao 200 Dec 15, 2022
Change Detection in SAR Images Based on Multiscale Capsule Network

SAR_CD_MS_CapsNet Code for the paper "Change Detection in SAR Images Based on Multiscale Capsule Network" , IEEE Geoscience and Remote Sensing Letters

Feng Gao 21 Nov 29, 2022
TDmatch is a Python library developed to perform matching tasks in three categories:

TDmatch TDmatch is a Python library developed to perform matching tasks in three categories: Text to Data which matches tuples of a table to text docu

Naser Ahmadi 5 Aug 11, 2022
🎁 3,000,000+ Unsplash images made available for research and machine learning

The Unsplash Dataset The Unsplash Dataset is made up of over 250,000+ contributing global photographers and data sourced from hundreds of millions of

Unsplash 2k Jan 03, 2023
Data & Code for ACCENTOR Adding Chit-Chat to Enhance Task-Oriented Dialogues

ACCENTOR: Adding Chit-Chat to Enhance Task-Oriented Dialogues Overview ACCENTOR consists of the human-annotated chit-chat additions to the 23.8K dialo

Facebook Research 69 Dec 29, 2022
Applying curriculum to meta-learning for few shot classification

Curriculum Meta-Learning for Few-shot Classification We propose an adaptation of the curriculum training framework, applicable to state-of-the-art met

Stergiadis Manos 3 Oct 25, 2022
🛰️ List of earth observation companies and job sites

Earth Observation Companies & Jobs source Portals & Jobs Geospatial Geospatial jobs newsletter: ~biweekly newsletter with geospatial jobs by Ali Ahmad

Dahn 64 Dec 27, 2022
This repository contains the data and code for the paper "Diverse Text Generation via Variational Encoder-Decoder Models with Gaussian Process Priors" ([email protected])

GP-VAE This repository provides datasets and code for preprocessing, training and testing models for the paper: Diverse Text Generation via Variationa

Wanyu Du 18 Dec 29, 2022
[ICCV21] Self-Calibrating Neural Radiance Fields

Self-Calibrating Neural Radiance Fields, ICCV, 2021 Project Page | Paper | Video Author Information Yoonwoo Jeong [Google Scholar] Seokjun Ahn [Google

381 Dec 30, 2022
OCR Post Correction for Endangered Language Texts

📌 Coming soon: an update to the software including features from our paper on semi-supervised OCR post-correction, to be published in the Transaction

Shruti Rijhwani 96 Dec 31, 2022
SSL_SLAM2: Lightweight 3-D Localization and Mapping for Solid-State LiDAR (mapping and localization separated) ICRA 2021

SSL_SLAM2 Lightweight 3-D Localization and Mapping for Solid-State LiDAR (Intel Realsense L515 as an example) This repo is an extension work of SSL_SL

Wang Han 王晗 1.3k Jan 08, 2023
Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

TianYuan 27 Nov 07, 2022
The original weights of some Caffe models, ported to PyTorch.

pytorch-caffe-models This repo contains the original weights of some Caffe models, ported to PyTorch. Currently there are: GoogLeNet (Going Deeper wit

Katherine Crowson 9 Nov 04, 2022
PyContinual (An Easy and Extendible Framework for Continual Learning)

PyContinual (An Easy and Extendible Framework for Continual Learning) Easy to Use You can sumply change the baseline, backbone and task, and then read

176 Jan 05, 2023