Source code for NAACL 2021 paper "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference"

TR-BERT

Source code and dataset for "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference".

The code is based on Hugging Face's transformers. Thanks to them! We will release all the source code in the future.

Requirements

Install dependencies and apex:

pip3 install -r requirement.txt
pip3 install --editable transformers

Pretrained models

Download the DistilBERT-3layer and BERT-1024 from Google Drive/Tsinghua Cloud.

Classification

Download the IMDB, Yelp, 20News datasets from Google Drive/Tsinghua Cloud.

Download the Hyperpartisan dataset, and randomly split it into train/dev/test sets: python3 split_hyperpartisan.py
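A random three-way split can be sketched as follows. This is only an illustration and not the contents of split_hyperpartisan.py: the function name split_dataset, the 80/10/10 ratio, and the fixed seed are all assumptions.

```python
import random

def split_dataset(examples, dev_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle examples and cut them into train/dev/test lists.

    The fractions and the seed are illustrative assumptions, not values
    taken from split_hyperpartisan.py.
    """
    items = list(examples)
    rng = random.Random(seed)  # local RNG so the split is reproducible
    rng.shuffle(items)
    n_dev = int(len(items) * dev_frac)
    n_test = int(len(items) * test_frac)
    dev = items[:n_dev]
    test = items[n_dev:n_dev + n_test]
    train = items[n_dev + n_test:]
    return train, dev, test

train, dev, test = split_dataset(range(100))
print(len(train), len(dev), len(test))
```

Using a local random.Random instance keeps the split stable across runs without touching the global random state.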

Train BERT/DistilBERT Model

Use flag --do_train:

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path bert-base-uncased --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 16 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 5  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval  --evaluate_during_training  --do_train

where task_name can be set to imdb/yelp_f/20news/hyperpartisan for different tasks, and model_type can be set to bert/distilbert for different models.

Compute Gradient for Residual Strategy

Use flag --do_eval_grad.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_grad

This step doesn't currently support DataParallel or DistributedDataParallel and should be run on a single GPU.

Train the policy network solely

Start from the checkpoint of the task-specific fine-tuned model. Change model_type from bert to autobert, and run with flag --do_train --train_rl:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1  --do_lower_case  --do_train --train_rl --alpha 1 --guide_rate 0.5

where alpha is the harmonic coefficient for the length penalty and guide_rate is the proportion of training steps that use imitation learning. model_type can be set to autobert/distilautobert to apply token reduction to BERT/DistilBERT.

Compute Logits for Knowledge Distillation
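To make the roles of the two options concrete, here is a minimal sketch. It is not taken from run_classification.py: the function names, the reward form (log-probability of the gold label minus alpha times the share of tokens kept), and the assumption that imitation learning covers the first guide_rate share of steps are all hypothetical.

```python
import math

def length_penalized_reward(gold_prob, kept_tokens, total_tokens, alpha):
    """Task reward minus a penalty that grows with the share of tokens kept.

    Assumed form for illustration only; the repository may normalise the
    penalty differently.
    """
    return math.log(gold_prob) - alpha * (kept_tokens / total_tokens)

def use_imitation(step, total_steps, guide_rate):
    """True while the step lies in the first guide_rate share of training (assumed schedule)."""
    return step < guide_rate * total_steps

# Keeping fewer tokens at the same accuracy yields a higher reward.
print(length_penalized_reward(0.9, 128, 512, alpha=1.0))
print(length_penalized_reward(0.9, 256, 512, alpha=1.0))
print(use_imitation(step=100, total_steps=1000, guide_rate=0.5))
```

Under this reading, a larger alpha pushes the policy to drop more tokens, while a larger guide_rate keeps the policy on heuristic guidance for longer before it explores on its own.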

Use flag --do_eval_logits.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_logits

This step doesn't currently support DataParallel or DistributedDataParallel and should be run on a single GPU.

Train the whole network with both the task-specific objective and RL objective

Start from the checkpoint produced by the --train_rl step and run with flag --do_train --train_both --train_teacher:
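For readers unfamiliar with knowledge distillation, the saved logits would typically serve as soft targets for a student model. The sketch below shows one common formulation (cross-entropy between temperature-softened teacher and student distributions) in plain Python. The function names and the temperature parameter are assumptions; the loss actually used by this repository may differ.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """Cross-entropy of the student distribution against the teacher's soft targets."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

teacher = [2.0, 0.5]
print(distillation_loss([2.0, 0.5], teacher))   # student matches the teacher
print(distillation_loss([0.5, 2.0], teacher))   # student disagrees, larger loss
```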

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1 --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_train --train_both --train_teacher --alpha 1

Evaluate

Use flag --do_eval:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1_both  --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_eval --eval_all_checkpoints

When the evaluation batch size is greater than 1, the same number of tokens is retained for every instance in a batch.
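A minimal sketch of what keeping an equal number of tokens per instance could look like. The function name, the 0.5 threshold, and the choice of the batch maximum as the shared count are assumptions made for illustration, not the repository's implementation.

```python
def select_tokens_batched(scores, threshold=0.5):
    """Keep the same number of tokens for every instance in a batch.

    scores: one list of per-token selection probabilities per instance,
    all padded to the same length. The shared count k is the largest
    number of tokens any single instance would keep (an assumption).
    """
    wanted = [sum(1 for s in row if s >= threshold) for row in scores]
    k = max(wanted)
    kept = []
    for row in scores:
        # Indices of the k highest-scoring tokens, restored to original order.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        kept.append(sorted(top))
    return kept

batch = [[0.9, 0.1, 0.8, 0.2], [0.7, 0.6, 0.9, 0.3]]
print(select_tokens_batched(batch))  # [[0, 2, 3], [0, 1, 2]]
```

With a shared count, instances that would keep fewer tokens on their own end up keeping extra ones, so the measured speedup is likely smaller than with a batch size of 1.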

Initialize

For the IMDB dataset, we find that initializing the selector with a heuristic objective before training the policy network alone gives slightly better performance. For other datasets, this step makes little difference. Run this step with flag --do_train --train_init:

python3 trans_imdb_rank.py
python3 run_classification.py  --task_name imdb  --model_type initbert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/bert_init  --do_lower_case  --do_train --train_init 

Question Answering

Download the SQuAD 2.0 dataset.

Download the MRQA dataset with our split from Google Drive/Tsinghua Cloud.

Download the HotpotQA dataset from the Transformer-XH repository, where paragraphs are retrieved for each question using TF-IDF, entity linking, and hyperlinks, and then re-ranked by a BERT re-ranker.

Download the TriviaQA dataset, where paragraphs are re-ranked by the linear passage re-ranker in DocQA.

Download the WikiHop dataset.

The whole training process for question answering models is similar to that for text classification models, using flags --do_train, --do_train --train_rl, and --do_train --train_both --train_teacher in turn. The script for each dataset:

SQuAD: run_squad.py with flag --version_2_with_negative

NewsQA / NaturalQA: run_mrqa.py

RACE: run_race_classify.py

HotpotQA: run_hotpotqa.py

TriviaQA: run_triviaqa.py

WikiHop: run_wikihop.py
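The three stages chain together, with each stage starting from the checkpoint written by the previous one. The helper below sketches that chaining. The stage names, the directory layout, and the stage_commands function are hypothetical, and the generated commands omit the data, model type, and hyperparameter arguments that a real run needs.

```python
# Stage flags as listed above; the short stage names are illustrative.
STAGES = [
    ("base", ["--do_train"]),
    ("rl", ["--do_train", "--train_rl"]),
    ("both", ["--do_train", "--train_both", "--train_teacher"]),
]

def stage_commands(script, start_model, out_root):
    """Build one argument list per stage, feeding each output into the next stage."""
    commands = []
    model = start_model
    for name, flags in STAGES:
        out_dir = f"{out_root}/{name}"
        commands.append(["python3", script,
                         "--model_name_or_path", model,
                         "--output_dir", out_dir, *flags])
        model = out_dir  # the next stage starts from this checkpoint
    return commands

for cmd in stage_commands("run_squad.py", "bert-base-uncased", "squad_models"):
    print(" ".join(cmd))
```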

Harmonic Coefficient Lambda

The example harmonic coefficients are shown as follows:

Dataset        train_rl   train_both
SQuAD 2.0      5          5
NewsQA         3          5
NaturalQA      2          2
RACE           0.5        0.1
YELP.F         2          0.5
20News         1          1
IMDB           1          1
HotpotQA       0.1        4
TriviaQA       0.5        1
Hyperpartisan  0.01       0.01
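For scripted runs, the table can be kept as a small lookup. The values below are copied from the table above; the ALPHA dictionary and the alpha_flag helper are hypothetical names, and it is an assumption that every script accepts the coefficient through --alpha as run_classification.py does.

```python
# (train_rl, train_both) coefficients per dataset, copied from the table.
ALPHA = {
    "SQuAD 2.0": (5, 5),
    "NewsQA": (3, 5),
    "NaturalQA": (2, 2),
    "RACE": (0.5, 0.1),
    "YELP.F": (2, 0.5),
    "20News": (1, 1),
    "IMDB": (1, 1),
    "HotpotQA": (0.1, 4),
    "TriviaQA": (0.5, 1),
    "Hyperpartisan": (0.01, 0.01),
}

def alpha_flag(dataset, stage):
    """Return the --alpha argument for a dataset and stage ('train_rl' or 'train_both')."""
    index = {"train_rl": 0, "train_both": 1}[stage]
    return f"--alpha {ALPHA[dataset][index]}"

print(alpha_flag("HotpotQA", "train_rl"))
print(alpha_flag("HotpotQA", "train_both"))
```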

Cite

If you use the code, please cite this paper:

@inproceedings{ye2021trbert,
  title={TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference},
  author={Deming Ye and Yankai Lin and Yufei Huang and Maosong Sun},
  booktitle={Proceedings of NAACL 2021},
  year={2021}
}
Owner
THUNLP
Natural Language Processing Lab at Tsinghua University