Code for the Shortformer model, from the paper "Shortformer: Better Language Modeling using Shorter Inputs" by Ofir Press, Noah A. Smith, and Mike Lewis.

Shortformer

This repository contains the code and the final checkpoint of the Shortformer model. This file explains how to run our experiments on the WikiText-103 dataset. Read the full paper here.

The Shortformer is a combination of two methods:

  1. Staged Training: We first train the model on short input subsequences and then train it on longer ones. This improves both train speed and evaluation perplexity.
  2. Position-Infused Attention + Caching: We cache previously computed subsequence representations and attend to them using Position-Infused Attention. Position-Infused Attention modifies the model so that position embeddings are not added to the word embeddings at the bottom of the network, but instead, they are added to the keys and queries in the attention sublayer (but not to the values). We show that PIA + caching vastly speeds up generation and also improves perplexity.

Staged training requires no modification to the original code. To see how we implemented the Position-Infused Attention and caching, click here. Implementing PIA and caching is very easy, and we've provided detailed comments in the code explaining how we did it.
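
As a rough illustration of PIA with caching (this is not the fairseq code in this repository; the single-head formulation, tensor shapes, names, and the sinusoidal helper are simplifying assumptions, and the causal mask is omitted), here is a minimal PyTorch sketch:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sinusoidal_positions(length, dim):
    # Fixed sinusoidal position embeddings (assumes an even embedding dimension).
    pos = torch.arange(length, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
    emb = torch.zeros(length, dim)
    emb[:, 0::2] = torch.sin(pos * inv_freq)
    emb[:, 1::2] = torch.cos(pos * inv_freq)
    return emb

class PIASelfAttention(nn.Module):
    # Single-head self-attention where position embeddings are added to the
    # queries and keys only, never to the values. Because the cached outputs of
    # the previous subsequence are position-free, they can be reused without
    # being recomputed.
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def forward(self, x, cache=None):
        # x: (seq_len, dim) current subsequence, with NO position embeddings added at the bottom.
        # cache: (cache_len, dim) position-free representations of the previous subsequence.
        mem = x if cache is None else torch.cat([cache, x], dim=0)
        pos = sinusoidal_positions(mem.size(0), mem.size(1)).to(x)
        # Positions are infused into the queries and keys; the current tokens take
        # the positions that come after the cached tokens.
        q = self.q_proj(x + pos[-x.size(0):])
        k = self.k_proj(mem + pos)
        v = self.v_proj(mem)  # no position embeddings added to the values
        # Causal masking is omitted for brevity.
        attn = F.softmax((q @ k.t()) * self.scale, dim=-1)
        return attn @ v  # position-free output, suitable for caching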

If you use this code or results from our paper, please cite:

@misc{press2020shortformer,
      title={Shortformer: Better Language Modeling using Shorter Inputs}, 
      author={Ofir Press and Noah A. Smith and Mike Lewis},
      year={2020},
      eprint={2012.15832},
}

Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..


TEXT=examples/language_model/wikitext-103
python preprocess.py \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Train/Inference for the different models

Shortformer

Our Shortformer model takes the baseline and adds caching, Position-Infused Attention, and Staged Training.

To train the first stage:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir checkpoints128e100/ --arch transformer_lm_wiki103 \
    --max-update 140100 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 128 --max-tokens-valid 128 --tokens-from-prev 128 --curriculum 1000 \
    --required-batch-size-multiple 1 --save-interval 100

If your GPUs don't have enough memory to execute that command, you can set --update-freq to 2 and --max-tokens to 4608, or set --update-freq to 3 and --max-tokens to 3072 to run the model under even tighter memory constraints. This splits each batch into 2 or 3 chunks and computes them sequentially (instead of in parallel), so it uses less memory but runs slower; see the example below.
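
For example, the first-stage command with the batch split into two parts would look like this (only --max-tokens and --update-freq change; everything else is identical to the command above):

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir checkpoints128e100/ --arch transformer_lm_wiki103 \
    --max-update 140100 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 4608 --update-freq 2 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 128 --max-tokens-valid 128 --tokens-from-prev 128 --curriculum 1000 \
    --required-batch-size-multiple 1 --save-interval 100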

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir shortformer/ --restore-file checkpoints128e100/checkpoint100.pt \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 \
    --required-batch-size-multiple 1 --no-epoch-checkpoints

Again, if you run out of memory, you can use the --update-freq/--max-tokens trick described above.

Saved Checkpoint

If you'd like to download the Shortformer instead of training it, it is available here. Rename that file to checkpoint_best.pt if you'd like to follow the directions below.
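
For example, assuming the checkpoint was downloaded as shortformer.pt (a hypothetical filename), you could place it where the inference commands below expect to find it:

mkdir -p shortformer/
mv shortformer.pt shortformer/checkpoint_best.pt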

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 --path shortformer/checkpoint_best.pt --sample-break-mode none --gen-subset valid --max-sentences 1

For token-by-token generation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 --path shortformer/checkpoint_best.pt --sample-break-mode none --gen-subset valid --max-sentences 1 --sliding-inf 1 --context-window 511 --max-tokens 512

(Note that --context-window is a fairseq argument and doesn't have the exact meaning that the term "context window" has in our paper.)

Shortformer (without Staged Training)

Staged training improves the model's perplexity and speeds up training, so there's no reason not to use it. If you would nevertheless like to train the Shortformer without it, run:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir shortformer-no-st/ --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 \
    --required-batch-size-multiple 1 --no-epoch-checkpoints

For inference, use the same commands as the ones for the Shortformer (above).

Baseline with Staged Training

Our Shortformer model is fast to train and fast at token-by-token generation, but if speed is not an issue, slightly better perplexity can be achieved by applying just Staged Training to the Baevski & Auli baseline LM. This model is very slow but achieves the best perplexity.

To train the first stage, download the unmodified fairseq repository and then run:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir checkpoints-st-128e50/ --arch transformer_lm_wiki103 \
    --max-update 70050 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 128 --required-batch-size-multiple 1 --save-interval 50

After that, to train the model with the second (and final) stage:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir st/ --restore-file checkpoints-st-128e50/checkpoint50.pt \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 3072 --no-epoch-checkpoints

Inference

For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103 --path st/checkpoint_best.pt --sample-break-mode none --gen-subset valid --max-sentences 1

For sliding window evaluation of the validation set, with a stride of 2,560, run:

fairseq-eval-lm data-bin/wikitext-103 --path st/checkpoint_best.pt --sample-break-mode none --gen-subset valid --max-sentences 1 --context-window 2560

Baseline - Baevski & Auli

To train the baseline, download the unmodified fairseq repository and then run:

python train.py --task language_modeling data-bin/wikitext-103 \
    --save-dir baseline/ --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 3072 --no-epoch-checkpoints

Inference

Use the same commands as in the 'Baseline with Staged Training' inference subsection.
