Pytorch implementation of Compressive Transformers, from Deepmind

Overview

Compressive Transformer in Pytorch

Pytorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modelling. I will also combine this with an idea from another paper, 'Stabilizing Transformers for Reinforcement Learning', which adds gating at the residual intersection. The memory and the gating may be synergistic and could lead to further improvements in both language modeling and reinforcement learning.
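The core memory mechanism can be sketched in a few lines. Like Transformer-XL, the model keeps a FIFO memory of past hidden states. However, instead of discarding the oldest states when the memory overflows, it compresses them by `cmem_ratio` into a secondary compressed memory of length `cmem_len`. This minimal plain-Python sketch rests on several assumptions: it models a single layer, represents each state as a short list of floats, and uses mean pooling in place of the learned compression function (mean pooling is one of the options the paper compares). It is not this repository's implementation.

```python
from typing import List, Tuple

Vector = List[float]


def mean_pool(states: List[Vector], ratio: int) -> List[Vector]:
    """Compress each run of `ratio` consecutive states into their element-wise mean."""
    assert len(states) % ratio == 0, "number of evicted states must be divisible by ratio"
    return [
        [sum(column) / ratio for column in zip(*states[i:i + ratio])]
        for i in range(0, len(states), ratio)
    ]


def update_memories(
    mem: List[Vector], cmem: List[Vector], hidden: List[Vector],
    mem_len: int, cmem_len: int, ratio: int,
) -> Tuple[List[Vector], List[Vector]]:
    # 1. Append the newest hidden states to the FIFO memory.
    mem = mem + hidden
    # 2. Evict the oldest states that no longer fit in `mem_len`.
    overflow = max(0, len(mem) - mem_len)
    evicted, mem = mem[:overflow], mem[overflow:]
    # 3. Compress the evicted states (Transformer-XL would simply drop them)
    #    and keep only the newest `cmem_len` compressed slots.
    if evicted:
        cmem = cmem + mean_pool(evicted, ratio)
        cmem = cmem[max(0, len(cmem) - cmem_len):]
    return mem, cmem
```

With `mem_len = 4`, `cmem_len = 2`, and `ratio = 2`, the second segment pushes states `0..3` out of memory. These are then stored as two compressed slots: `[0.5]` and `[2.5]`.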

Install

$ pip install compressive_transformer_pytorch

Usage

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                 # embedding dimensions, embedding factorization from Albert paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                # memory length
    cmem_len = 1024 // 4,          # compressed memory buffer length
    cmem_ratio = 4,                # compressed memory ratio, 4 was recommended in paper
    reconstruction_loss_weight = 1,# weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,            # dropout post-attention
    ff_dropout = 0.1,              # dropout in feedforward
    attn_layer_dropout = 0.1,      # dropout for attention layer output
    gru_gated_residual = True,     # whether to gate the residual intersection, from 'Stabilizing Transformer for RL' paper
    mogrify_gru = False,           # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),  # specify which layers to use long-range memory, from 'Do Transformers Need LR Memory' paper
    ff_glu = True                  # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits,        _, aux_loss = model(segments[1], mask = masks[1], memories = memories)

# memories is a named tuple that contains the memory (mem) and the compressed memory (cmem)
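The `aux_loss` returned above is presumably the compressed-memory reconstruction loss, scaled by `reconstruction_loss_weight`. In the paper, this attention-reconstruction loss measures how differently a layer's hidden states attend over the old memories compared with their compressed version. The sketch below is a simplified illustration that makes several assumptions: a single head, no learned query/key/value projections, mean pooling as the compression function, and squared error summed over all positions.

```python
import math
from typing import List

Matrix = List[List[float]]


def attend(queries: Matrix, keys_values: Matrix) -> Matrix:
    """Single-head dot-product attention with keys == values and no projections."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys_values]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, keys_values)) / total for j in range(d)])
    return out


def mean_pool(states: Matrix, ratio: int) -> Matrix:
    """Stand-in compression function: average each run of `ratio` states."""
    assert len(states) % ratio == 0, "memory length must be divisible by ratio"
    return [[sum(col) / ratio for col in zip(*states[i:i + ratio])] for i in range(0, len(states), ratio)]


def reconstruction_loss(hidden: Matrix, old_mem: Matrix, ratio: int) -> float:
    # Attention over the full old memories vs. attention over their compressed form.
    full = attend(hidden, old_mem)
    approx = attend(hidden, mean_pool(old_mem, ratio))
    return sum((a - b) ** 2 for ra, rb in zip(full, approx) for a, b in zip(ra, rb))
```

A compression that loses no information the queries care about gives zero loss. For example, `ratio = 1` yields zero loss, and so do memories whose rows are all identical.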

When training, you can use the AutoregressiveWrapper to have memory management across segments taken care of for you. As easy as it gets.

import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()
    # optimizer step and zero grad

# ... after much training ...

# generation is also greatly simplified and automated away
# just pass in the prime, which can be 1 start token or any length
# all is taken care of for you

prime = torch.ones(1, 1, dtype = torch.long).cuda()  # assume 1 is start token; token ids must be integers for the embedding
sample = model.generate(prime, 4096)
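The `2048 + 1` in the training example is presumably there so that inputs and next-token targets each span 2048 positions, which is exactly two segments of `seq_len = 1024`. The sketch below shows what the wrapper's per-segment loop conceptually does. It is not the wrapper's actual code, and the `model(inputs, targets, memories)` interface is hypothetical. It shifts the tokens by one, cuts them into segments, threads memories from one segment into the next, and yields one loss per segment. The third value, unpacked as `_` above, appears to be an is-last-segment flag.

```python
from typing import Callable, Iterator, List, Optional, Tuple

# Hypothetical per-segment model: (inputs, targets, memories) -> (loss, aux_loss, new_memories)
SegmentModel = Callable[[List[int], List[int], Optional[object]], Tuple[float, float, object]]


def segment_losses(model: SegmentModel, tokens: List[int], seq_len: int) -> Iterator[Tuple[float, float, bool]]:
    # Shift by one: position t is trained to predict tokens[t + 1].
    inputs, targets = tokens[:-1], tokens[1:]
    starts = range(0, len(inputs), seq_len)
    memories = None
    for i, start in enumerate(starts):
        x = inputs[start:start + seq_len]
        y = targets[start:start + seq_len]
        loss, aux_loss, memories = model(x, y, memories)  # carry memories into the next segment
        yield loss, aux_loss, i == len(starts) - 1
```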

Citations

@misc{rae2019compressive,
    title   = {Compressive Transformers for Long-Range Sequence Modelling},
    author  = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year    = {2019},
    eprint  = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{parisotto2019stabilizing,
    title   = {Stabilizing Transformers for Reinforcement Learning},
    author  = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year    = {2019},
    eprint  = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{rae-razavi-2020-transformers,
    title   = {Do Transformers Need Deep Long-Range Memory?},
    author  = {Jack Rae and Ali Razavi},
    booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
    month   = jul,
    year    = {2020},
    address = {Online},
    publisher = {Association for Computational Linguistics},
    url     = {https://www.aclweb.org/anthology/2020.acl-main.672}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{lan2019albert,
    title   = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author  = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year    = {2019},
    url     = {https://arxiv.org/abs/1909.11942}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • aux_loss does not update any weight

    Hi lucidrains, thanks for your implementation; it is very elegant and helped me a lot with my dissertation. However, there is one detail I can't understand: aux_loss seems unrelated to any weight because of the detaching in the last part of the SelfAttention layer. With the following code, for example, I find that no layer is optimized by aux_loss:

    import torch
    from compressive_transformer_pytorch import CompressiveTransformer
    from compressive_transformer_pytorch import AutoregressiveWrapper
    
    model = CompressiveTransformer(
        num_tokens = 20000,
        dim = 512,
        depth = 6,
        seq_len = 1024,
        mem_len = 1024,
        cmem_len = 256,
        cmem_ratio = 4,
        memory_layers = [5,6]
    ).cuda()
    
    model = AutoregressiveWrapper(model)
    
    inputs = torch.randint(0, 20000, (1, 1024)).cuda()
    
    optimizer = torch.optim.Adam(model.parameters())
    
    for loss, aux_loss, _ in model(inputs, return_loss = True):
        optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        print("OPTIMIZED BY LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
        optimizer.zero_grad(set_to_none=True)
        aux_loss.backward(retain_graph=True)
        print("OPTIMIZED BY AUX_LOSS ************************************************************")
        for module_name, parameter in model.named_parameters():
            if parameter.grad is not None:
                print(module_name)
    

    I am not an expert on PyTorch's internals, so maybe I am getting something wrong. Again, thank you.

    opened by StefanoBerti
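Whether an auxiliary loss can train anything depends on where `.detach()` is applied. As we read the paper, the intent is to train the compression function with the reconstruction loss while stopping gradients from reaching the main network. The following plain-PyTorch sketch contrasts two cases. It is not the library's code; the convolution, shapes, and target are assumptions.

- Case 1 detaches only the input to the compression function, so its weights still receive gradients.
- Case 2 also detaches its output, which cuts the loss off from every parameter.

If no parameter receives a gradient from aux_loss, as reported here, one possible explanation is a detach like the one in case 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

dim, ratio = 8, 4
# Hypothetical compression function: a 1D convolution whose kernel size and
# stride both equal `ratio`, so 16 memory slots become 4 compressed slots.
compress = nn.Conv1d(dim, dim, kernel_size=ratio, stride=ratio)

# Stand-in for old memories produced by the main network: (batch, dim, length).
old_mem = torch.randn(1, dim, 16, requires_grad=True)
target = torch.randn(1, dim, 16 // ratio)  # stand-in reconstruction target

# Case 1: detach only the INPUT. The conv's own weights still get gradients,
# but nothing flows back into old_mem (i.e. into the main network).
loss = (compress(old_mem.detach()) - target).pow(2).mean()
loss.backward()
weights_got_grad = compress.weight.grad is not None
main_net_got_grad = old_mem.grad is not None
print(weights_got_grad, main_net_got_grad)  # True False

# Case 2: detach the OUTPUT as well. The loss no longer depends on any
# parameter, and calling loss2.backward() would raise a RuntimeError.
loss2 = (compress(old_mem.detach()).detach() - target).pow(2).mean()
print(loss2.requires_grad)  # False
```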
  • How to use this for speech/audio generation?

    Great work, Phil! In their paper, the authors applied this model to speech modeling. What would you advise me to change to use it for speech? In speech the data are signals, so we have neither num_tokens nor emb_dim; our input is simply [batch, channel, time]. Any advice?

    opened by jinglescode
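One common way to turn raw audio into tokens is 8-bit mu-law quantization. This technique comes from WaveNet; it is not taken from this repository, nor confirmed as the paper's speech setup. Each sample in [-1, 1] is companded and mapped to one of 256 classes. The model could then presumably be configured with `num_tokens = 256` and fed the waveform as a token sequence, with `emb_dim` left at its default as in the second usage example. The sketch handles a single channel; multi-channel audio would first need to be mixed down or interleaved, which is an assumption.

```python
import math
from typing import List

MU = 255  # 8-bit mu-law: 256 discrete classes, i.e. num_tokens = 256


def mu_law_encode(samples: List[float], mu: int = MU) -> List[int]:
    """Map float samples in [-1, 1] to integer tokens in [0, mu]."""
    tokens = []
    for x in samples:
        x = max(-1.0, min(1.0, x))  # assume audio is normalized to [-1, 1]
        companded = math.copysign(math.log1p(mu * abs(x)) / math.log1p(mu), x)
        tokens.append(int((companded + 1) / 2 * mu + 0.5))  # round to the nearest class
    return tokens


def mu_law_decode(tokens: List[int], mu: int = MU) -> List[float]:
    """Invert the companding to recover approximate float samples."""
    samples = []
    for q in tokens:
        y = 2 * q / mu - 1
        samples.append(math.copysign(((1 + mu) ** abs(y) - 1) / mu, y))
    return samples
```

Because companding spends more classes on quiet samples, small amplitudes are reconstructed more precisely than loud ones.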
  • [Error] NameError: name 'math' is not defined in compressive_transformer_pytorch.py

    Hello, I am running the code in "examples/enwik8_simple" and I get the following error:

    train.py:65: DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead
      X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    training:   0%|          | 0/100000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 101, in <module>
        for mlm_loss, aux_loss, is_last in model(next(train_loader), max_batch_size = MAX_BATCH_SIZE, return_loss = True):
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/autoregressive_wrapper.py", line 151, in forward
        logits, new_mem, aux_loss = self.net(xi_seg_b, mask = mask_seg_b, memories = mem, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 338, in forward
        x, = ff(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 84, in forward
        out = self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 106, in forward
        return self.fn(x, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 140, in forward
        x = self.act(x)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/donghyun/donghyun/anaconda3/envs/pytorch/lib/python3.7/site-packages/compressive_transformer_pytorch/compressive_transformer_pytorch.py", line 122, in forward
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    NameError: name 'math' is not defined

    So I inserted "import math" into compressive_transformer_pytorch.py and it works well. I hope you can fix the code in compressive_transformer_pytorch.py.

    opened by dinoSpeech
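The failing line computes the tanh approximation of GELU, which needs the `math` module for its constant `sqrt(2 / pi)`. Below is a scalar sketch of the same formula with `import math` in place. It sits alongside the exact erf-based GELU for comparison. In PyTorch itself, `torch.nn.GELU` could be used instead of a hand-written activation; recent releases also accept `approximate='tanh'`.

```python
import math


def gelu_tanh(x: float) -> float:
    # Same tanh approximation as in the traceback, written for Python floats.
    return 0.5 * x * (1 + math.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1 + math.erf(x / math.sqrt(2)))


for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"{x:+.1f}  tanh approx {gelu_tanh(x):+.5f}  exact {gelu_exact(x):+.5f}")
```

On this range, the two versions agree to within about 2e-4.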
  • Training on enwik8, but the loss fails to converge

    Hi lucidrains, I appreciate your implementation very much; it has helped me a lot in understanding the Compressive Transformer. However, when I ran your code on enwik8 (exactly the same code as on GitHub), the loss failed to converge after 100 epochs. Is this expected? Or should I make additional efforts to improve it, for example by tokenizing the raw enwik8 data and removing all the XML tags? The figures below show the training and validation loss when training on enwik8 with the same code as on GitHub.

    [Screenshots: training and validation loss curves, captured 2021-03-26]

    Thanks, and I look forward to your reply!

    opened by KaiPoChang
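To judge whether a run like this is learning, it can help to convert the cross-entropy loss into bits per character (bpc), the unit in which enwik8 results are usually reported. PyTorch's cross-entropy uses natural logarithms, so the conversion divides by ln 2. This minimal sketch assumes one token per byte of enwik8.

```python
import math


def nats_to_bits_per_char(loss_nats: float) -> float:
    """Convert a per-character cross-entropy in nats to bits per character."""
    return loss_nats / math.log(2)


# Reference point: guessing uniformly over 256 byte values costs ln(256) nats, i.e. 8 bpc.
UNIFORM_BYTE_LOSS = math.log(256)

for loss in (UNIFORM_BYTE_LOSS, 1.0, 0.7):
    print(f"loss {loss:.3f} nats -> {nats_to_bits_per_char(loss):.3f} bpc")
```

A loss that plateaus near ln(256) ≈ 5.55 nats means the model is doing no better than uniform guessing over bytes.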
  • Details about text generation

    Hi lucidrains, thank you for your excellent code. I am curious about the generation scripts. Could you tell me how to generate text with the Compressive Transformer? Because it has the compressive memory, maybe we cannot use the currently predicted word as the sole input for the next generation step (input length == 1). In addition, if the prompt has 100 words and we use tokens[0:100], tokens[1:101], tokens[2:102], ... as the input for the following timesteps, tokens[1:100] may overlap with the memory, because the memory already contains hidden states for tokens[1:100].

    I would really appreciate it if you could provide the generation scripts!

    Thank you

    opened by theseventhflow
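One way to avoid the overlap described here is to re-run the model on the current, partially filled segment at every step, and commit that segment's memories only once it has reached `seq_len` tokens. A token is then never in both the input window and the memory, and the window never exceeds `seq_len`. The sketch below makes several assumptions: it uses a hypothetical `model(segment, memories)` interface that returns a next-token prediction for each position plus the memories produced by that segment, and it decodes greedily. It is not the library's `generate`.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical model interface (an assumption, not the library's API):
#   model(segment, memories) -> (next-token prediction at each position,
#                                memories produced by this segment)
ModelFn = Callable[[Sequence[int], Optional[object]], Tuple[List[int], object]]


def generate(model: ModelFn, prime: List[int], steps: int, seq_len: int) -> List[int]:
    assert prime, "need at least one prime token"
    out = list(prime)
    memories: Optional[object] = None
    seg_start = 0  # where the current, not-yet-memorized segment begins in `out`

    # Commit complete segments of a long prime to memory, leaving 1..seq_len tokens in the window.
    while len(out) - seg_start > seq_len:
        _, memories = model(out[seg_start:seg_start + seq_len], memories)
        seg_start += seq_len

    for _ in range(steps):
        window = out[seg_start:]
        preds, new_memories = model(window, memories)
        out.append(preds[-1])  # greedy decoding: keep the prediction at the last position
        if len(window) == seq_len:
            # The window was a full segment: commit its memories and start a new
            # segment at the token just generated, so no token is in both places.
            memories = new_memories
            seg_start += seq_len
    return out
```

Recomputing the whole window at each step keeps the logic simple but costs up to `seq_len` positions per generated token. Caching keys and values would avoid this, but it is not sketched here.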
  • Links to original tf code - fyi

    After reading the DeepMind blog post, I was looking forward to downloading the model, but no luck. Looking forward to your implementation.

    You may already be aware of this post and link, but if not, this is the author's original TF implementation. Hope it helps.

    Copy of comment to original model request:

    https://github.com/huggingface/transformers/issues/4688

    I am interested in the model weights too, but they are currently not available. The author does mention releasing TF code here:

    https://news.ycombinator.com/item?id=22290227

    It requires TF 1.15+ and deepmind/sonnet version 1.36. Link to the Python script:

    https://github.com/deepmind/sonnet/blob/cd5b5fa48e15e4d020f744968f5209949ebe750f/sonnet/python/modules/nets/transformer.py#L915

    I have tried running it as-is, but it doesn't appear to offer options for training on custom data or on the datasets described in the paper.

    opened by GenTxt
Releases: 0.4.0
Owner: Phil Wang (Working with Attention. It's all we need)