Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch

Overview

🦩 Flamingo - Pytorch

Implementation of Flamingo, a state-of-the-art few-shot visual question answering attention net, in Pytorch. It includes the perceiver resampler (including the scheme where the learned queries contribute keys / values to be attended to, in addition to media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks.
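
As a rough illustration of the tanh gating mentioned above, the sketch below uses plain Python lists in place of tensors. It assumes the gate is a single learned scalar initialized to zero, so that a gated block starts out as an identity mapping and the pretrained language model's activations pass through unchanged. The names `gated_residual`, `gate`, and `block_out` are made up for this example and are not part of the library's API.

```python
import math

def gated_residual(x, block_out, gate):
    """Return x + tanh(gate) * block_out, elementwise over two equal-length lists."""
    scale = math.tanh(gate)
    return [xi + scale * bi for xi, bi in zip(x, block_out)]

x = [0.5, -1.0, 2.0]          # hidden state of one text token (hypothetical values)
block_out = [3.0, 3.0, 3.0]   # stand-in for a cross attention or feedforward output

# gate = 0 (assumed initial value): tanh(0) = 0, so the block contributes nothing
at_init = gated_residual(x, block_out, gate = 0.0)

# as training moves the gate away from zero, the block output is mixed in
after_training = gated_residual(x, block_out, gate = 1.0)
```

With the gate at zero, `at_init` equals `x`, which is why such blocks can be inserted into a pretrained model without disturbing it at the start of training.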

Install

$ pip install flamingo-pytorch

Usage

import torch
from flamingo_pytorch import PerceiverResampler

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,    # the number of latents to shrink your media sequence to, perceiver style
    num_time_embeds = 4  # say you have 4 images maximum in your dialogue
)

medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)

Then you insert the GatedCrossAttentionBlock at different intervals in your giant language model. Your text will then attend to the perceived media from above.

The recommended way to derive the media_locations boolean tensor is to allocate a special token id to the media, and then, at the start of your large language model, do media_locations = text_id == media_token_id.
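
A minimal sketch of that comparison, using nested Python lists instead of tensors so it runs without any dependencies. The token ids and the choice of `3` for `media_token_id` are arbitrary example values. With a real tensor of token ids, the same result comes from the single expression `text_id == media_token_id`.

```python
media_token_id = 3  # hypothetical id reserved for the [media] token

# a batch with one sequence: two media tokens, each followed by some text tokens
text_id = [[3, 17, 42, 8, 3, 99, 5]]

# True wherever a media token sits, False elsewhere
media_locations = [[tok == media_token_id for tok in row] for row in text_id]
```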

import torch
from flamingo_pytorch import GatedCrossAttentionBlock

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)

media_locations = torch.randint(0, 2, (1, 512)).bool()

text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)

That's it!

Attention is all you need.
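
For readers curious about what `media_locations` is used for inside the cross attention block, here is a small dependency-free sketch of how a text-to-media mask can be derived from it. It assumes a running count of media tokens assigns each text position the index of the most recent media, which is then compared against each media's index. `immediate_only` is a hypothetical flag for this example that switches between attending to only the latest media and attending to all preceding media. This illustrates the idea and is not the library's actual code.

```python
from itertools import accumulate

def text_to_media_mask(media_locations, num_media, immediate_only = True):
    """Build a (text position x media index) boolean mask for one sequence."""
    # running count of media tokens seen so far: 0 means no media yet
    text_time = list(accumulate(int(is_media) for is_media in media_locations))
    media_time = range(1, num_media + 1)  # media are numbered from 1

    if immediate_only:
        return [[t == m for m in media_time] for t in text_time]
    return [[t >= m for m in media_time] for t in text_time]

media_locations = [False, True, False, False, True, False]

latest = text_to_media_mask(media_locations, num_media = 2)
all_preceding = text_to_media_mask(media_locations, num_media = 2, immediate_only = False)
```

The first position comes before any media, so its row is all False in both variants. Rows like that have nothing to attend to and need to be handled explicitly, for example by zeroing their attention output.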

Full working example with Flamingo + PaLM 🌴 🦩 🌴

Integration with PaLM

First install vit-pytorch for the vision encoder

$ pip install vit-pytorch

Then

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library

vit = Extractor(vit, return_embeddings_only = True)

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM-style language model (the original PaLM from Google has 540 billion parameters; the configuration below is far smaller)

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plug in your image encoder (this can be optional if you pass in the image embeddings separately, but you probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training on the regular PaLM logits
# you are now ready to train Flamingo + PaLM
# when images are passed in, everything but the perceiver and cross attention blocks is frozen automatically, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss
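
For reference, a dependency-free sketch of the next-token cross entropy that the last comment refers to. It assumes the logits at position `i` are scored against the token at position `i + 1`, which is the usual autoregressive setup. Whether media tokens should be masked out of the loss depends on your training loop. The function name and the tiny vocabulary are made up for illustration. In practice `torch.nn.functional.cross_entropy` would compute this on tensors.

```python
import math

def next_token_cross_entropy(logits, token_ids):
    """Mean negative log-likelihood of each next token, for one sequence.

    logits: list of per-position score lists, shape (seq_len, vocab_size)
    token_ids: list of ints, shape (seq_len,)
    """
    losses = []
    # position i predicts token i + 1, so the last position has no target
    for scores, target in zip(logits[:-1], token_ids[1:]):
        # log-sum-exp with the max subtracted for numerical stability
        peak = max(scores)
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        losses.append(log_norm - scores[target])
    return sum(losses) / len(losses)

# three positions, vocabulary of four tokens, uniform (all-zero) logits
logits = [[0.0] * 4 for _ in range(3)]
token_ids = [1, 3, 2]

loss = next_token_cross_entropy(logits, token_ids)
```

With uniform logits over four tokens the loss is log(4), which makes a handy sanity check for an untrained model.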

It is quite evident where this is all headed if you think beyond just images.

Inception

For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.

Citations

@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • PerceiverResampler missing some LayerNorms?

    Hey, it feels like PerceiverResampler is missing some LayerNorms. It seems to me we should layer-norm x before sending it to the attention loop, and maybe add a layer-norm to ff(latents) + latents?

    opened by inspirit 7
  • Missing flatten op in PerceiverResampler?

    Hi, it seems that Flamingo does "x_f = flatten(x_f) # [T, S, d] -> [T * S, d]" (batch size == 1) before passing x_f to attention.

    So, it should be like:

    medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
    perceived = perceive(medias) # (1, 64, 1024) - (batch, num latents, dimension)

    ??

    opened by zengyan-97 6
  • wrong attention masks?

    https://github.com/lucidrains/flamingo-pytorch/blob/44920f4191ba3c280ff84c6ebc76025656d1dab5/flamingo_pytorch/flamingo_pytorch.py#L159

    In the Flamingo paper, the language features in the gated cross-attention layers only attend to the visual features from the immediately preceding image. I believe your attention masks are created in such a way that they attend to the visual features from all preceding images. Can you confirm? If so, a fix would be to simply change the '>=' to '=='.

    opened by dhansmair 4
  • zeroing out attention not working

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_pytorch.py#L179

    you are not using the inplace version of the function: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

    so I think this line does not have an effect.

    Best, David

    opened by dhansmair 2
  • Applying parallel attn with ff to existing pretrained model?

    Hi - awesome work! I am trying to understand ? I couldn't find a paper - only a reference to https://github.com/kingoflolz/mesh-transformer-jax. Is this right? Am I understanding that it is basically applying multiple operations of for qkv and ff at once? Is it possible to use this trick to modify an existing pretrained model?

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_palm.py#L90

    Many thanks in advance!

    Huu

    opened by ontocord 1
  • How to use Flamingo for VQA task?

    Hi, thanks for sharing this awesome implementation. I am very interested in using the Flamingo model for my use case. How can I use this implementation to run inference on my dataset for a VQA task? I have images of products and want to extract information about each product by asking questions about its image. How can I do it?

    Please help.

    thanks

    opened by karndeepsingh 0
  • Fine-tuning of a model

    Hi, thank you for this great work. How can I fine-tune this model on my dataset for a downstream task like image captioning or image classification? If possible, could you also please share the code?

    opened by ans92 0
  • Need a sample ipython notebook

    Hello, @lucidrains,

    Thank you for providing this.

    For demo purposes, could you please provide a sample demo in Jupyter notebook?🫠

    Best, LITDataScience

    opened by LITDataScience 0
Releases(0.1.2)
Owner
Phil Wang
Working with Attention. It's all we need