Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) - PyTorch Implementation

Overview

Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding

PyTorch implementation for the Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding (AAAI 2020) paper.

Method Description

Distilled Sentence Embedding (DSE) distills knowledge from a finetuned state-of-the-art transformer model (BERT) to create high quality sentence embeddings. For a complete description, as well as implementation details and hyperparameters, please refer to the paper.

Usage

Follow the instructions below in order to run the training procedure of the Distilled Sentence Embedding (DSE) method. The python scripts below can be run with the -h parameter to get more information.

1. Install Requirements

Tested with Python 3.6+.

pip install -r requirements.txt

2. Download GLUE Datasets

Run the download_glue_data.py script to download the GLUE datasets.

python download_glue_data.py

3. Finetune BERT on a Specific Task

Finetune a standard BERT model on a specific task (e.g., MRPC, MNLI, etc.). Below is an example for the MRPC dataset.

python finetune_bert.py \
--bert_model bert-large-uncased-whole-word-masking \
--task_name mrpc \
--do_train \
--do_eval \
--do_lower_case \
--data_dir glue_data/MRPC \
--max_seq_length 128 \
--train_batch_size 32 \
--gradient_accumulation_steps 2 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir outputs/large_uncased_finetuned_mrpc \
--overwrite_output_dir \
--no_parallel

Note: For our code to work with the AllNLI dataset (a combination of the MNLI and SNLI datasets), you simply need to create a folder where the downloaded GLUE datasets are and copy the MNLI and SNLI datasets into it.

4. Create Logits for Distillation from the Finetuned BERT

Execute the following command to create the logits which will be used for the distillation training objective. Note that the bert_checkpoint_dir parameter has to match the output_dir from the previous command.

python run_distillation_logits_creator.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--do_lower_case \
--train_features_path glue_data/MRPC/train_bert-large-uncased-whole-word-masking_128_mrpc \
--bert_checkpoint_dir outputs/large_uncased_finetuned_mrpc

5. Train the DSE Model using the Finetuned BERT Logits

Train the DSE model using the extracted logits. Notice that the distillation_logits_path parameter needs to be changed according to the task.

python dse_train_runner.py \
--task_name mrpc \
--data_dir glue_data/MRPC \
--distillation_logits_path outputs/logits/large_uncased_finetuned_mrpc_logits.pt \
--do_lower_case \
--file_log \
--epochs 8 \
--store_checkpoints \
--fc_dims 512 1

Important Notes:

  • To store checkpoints for the model make sure that the --store_checkpoints flag is passed as shown above.
  • The fc_dims parameter accepts a list of space separated integers, and is the dimensions of the fully connected classifier that is put on top of the extracted features from the Siamese DSE model. The output dimension (in this case 1) needs to be changed according to the wanted output dimensionality. For example, for the MNLI dataset the fc_dims parameter should be 512 3 since it is a 3 class classification task.

6. Loading the Trained DSE Model

During training, checkpoints of the Trainer object which contains the model will be saved. You can load these checkpoints and extract the model state dictionary from them. Then you can load the state into a created DSESiameseClassifier model. The load_dse_checkpoint_example.py script contains an example of how to do that.

To load the model trained with the example commands above you can use:

python load_dse_checkpoint_example.py \
--task_name mrpc \
--trainer_checkpoint <path_to_saved_checkpoint> \
--do_lower_case \
--fc_dims 512 1

Acknowledgments

Citation

If you find this code useful, please cite the following paper:

@inproceedings{barkan2020scalable,
  title={Scalable Attentive Sentence-Pair Modeling via Distilled Sentence Embedding},
  author={Barkan, Oren and Razin, Noam and Malkiel, Itzik and Katz, Ori and Caciularu, Avi and Koenigstein, Noam},
  booktitle={AAAI Conference on Artificial Intelligence (AAAI)},
  year={2020}
}
Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
8-week curriculum for AI Builders

curriculum 8-week curriculum for AI Builders สารบัญ บทที่ 1 - Machine Learning คืออะไร บทที่ 2 - ชุดข้อมูลมหัศจรรย์และถิ่นที่อยู่ บทที่ 3 - Stochastic

AI Builders 134 Jan 03, 2023
Release of the ConditionalQA dataset

ConditionalQA Datasets accompanying the paper ConditionalQA: A Complex Reading Comprehension Dataset with Conditional Answers. Disclaimer This dataset

14 Oct 17, 2022
SCU OlympicsRunning Baseline

Competition 1v1 running Environment check details in Jidi Competition RLChina2021智能体竞赛 做出的修改: 奖励重塑:修改了环境,重新设置了奖励的分配,使得奖励组成不只有零和博弈,还有探索环境的奖励。 算法微调:修改了官

ZiSeoi Wong 2 Nov 23, 2021
A flexible and extensible framework for gait recognition.

A flexible and extensible framework for gait recognition. You can focus on designing your own models and comparing with state-of-the-arts easily with the help of OpenGait.

Shiqi Yu 335 Dec 22, 2022
SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals

SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals Abstract Sleep apnea (SA) is a common slee

9 Dec 21, 2022
This is the pytorch implementation for the paper: *Learning Accurate Performance Predictors for Ultrafast Automated Model Compression*, which is in submission to TPAMI

SeerNet This is the pytorch implementation for the paper: Learning Accurate Performance Predictors for Ultrafast Automated Model Compression, which is

3 May 01, 2022
Rapid experimentation and scaling of deep learning models on molecular and crystal graphs.

LitMatter A template for rapid experimentation and scaling deep learning models on molecular and crystal graphs. How to use Clone this repository and

Nathan Frey 32 Dec 06, 2022
Walk with fastai

Shield: This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. Walk with fastai What is this p

Walk with fastai 124 Dec 10, 2022
This repository contains the code for the paper in EMNLP 2021: "HRKD: Hierarchical Relational Knowledge Distillation for Cross-domain Language Model Compression".

HRKD: Hierarchical Relational Knowledge Distillation for Cross-domain Language Model Compression This repository contains the code for the paper in EM

Chenhe Dong 2 Mar 24, 2022
Quadruped-command-tracking-controller - Quadruped command tracking controller (flat terrain)

Quadruped command tracking controller (flat terrain) Prepare Install RAISIM link

Yunho Kim 4 Oct 20, 2022
Simultaneous Demand Prediction and Planning

Simultaneous Demand Prediction and Planning Dependencies Python packages: Pytorch, scikit-learn, Pandas, Numpy, PyYAML Data POI: data/poi Road network

Yizong Wang 1 Sep 01, 2022
Fairness Metrics: All you need to know

Fairness Metrics: All you need to know Testing machine learning software for ethical bias has become a pressing current concern. Recent research has p

Anonymous2020 1 Jan 17, 2022
Code for the paper: Sketch Your Own GAN

Sketch Your Own GAN Project | Paper | Youtube Our method takes in one or a few hand-drawn sketches and customizes an off-the-shelf GAN to match the in

677 Dec 28, 2022
zeus is a Python implementation of the Ensemble Slice Sampling method.

zeus is a Python implementation of the Ensemble Slice Sampling method. Fast & Robust Bayesian Inference, Efficient Markov Chain Monte Carlo (MCMC), Bl

Minas Karamanis 197 Dec 04, 2022
Official PyTorch implementation of RobustNet (CVPR 2021 Oral)

RobustNet (CVPR 2021 Oral): Official Project Webpage Codes and pretrained models will be released soon. This repository provides the official PyTorch

Sungha Choi 173 Dec 21, 2022
ADB-IP-ROTATION - Use your mobile phone to gain a temporary IP address using ADB and data tethering

ADB IP ROTATE This an Python script based on Android Debug Bridge (adb) shell sc

Dor Bismuth 2 Jul 12, 2022
(ICCV 2021) PyTorch implementation of Paper "Progressive Correspondence Pruning by Consensus Learning"

CLNet (ICCV 2021) PyTorch implementation of Paper "Progressive Correspondence Pruning by Consensus Learning" [project page] [paper] Citing CLNet If yo

Chen Zhao 22 Aug 26, 2022
One-line your code easily but still with the fun of doing so!

One-liner-iser One-line your code easily but still with the fun of doing so! Have YOU ever wanted to write one-line Python code, but don't have the sa

5 May 04, 2022
This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation

This is a GUI interface which can process forest fire detection, smoke detection and fire segmentation. Yolov5 is used to detect fire and smoke and unet is used to segment fire.

7 Jan 08, 2023
rliable is an open-source Python library for reliable evaluation, even with a handful of runs, on reinforcement learning and machine learnings benchmarks.

Open-source library for reliable evaluation on reinforcement learning and machine learning benchmarks. See NeurIPS 2021 oral for details.

Google Research 529 Jan 01, 2023