The Codebase for Causal Distillation for Language Models.

Overview

Python 3.7 License CC BY-NC

Causal Distillation for Language Models

Zhengxuan Wu*,Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman

The is an implementation of our preprint Causal Distillation for Language Models. The standard approach to distillation trains a student model against two objectives: a task-specific objective (e.g., language modeling) and an imitation objective that encourages the hidden states of the student model to be similar to those of the larger teacher model. In this paper, we show that it is beneficial to augment distillation with a third objective that encourages the student to imitate the causal computation process of the teacher through interchange intervention training (IIT).

We fork our main codebase from the Huggingface Distillation Interface.

Release Notes

12/02/2021 Our paper on Interchange Intervention Training (IIT) is released! Read this more formal definition of the method.
12/06/2021 Released the causal distillation codebase with the preprint.
12/06/2021 Released evaluation results on distilled tiny-BERT (3 layers) with the Wiki-Text 103M dataset.
⬜️ Released evaluation results on causal-distilled tiny-BERT (3 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released evaluation results on causal-distilled BERT (6 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Released more ablation studies.
⬜️ Released causal-distilled tiny-BERT (3 layers) model files.
⬜️ Released causal-distilled BERT (6 layers) model files.

If you experience any issues or have suggestions, please contact me either thourgh the issues page or at [email protected].

Benchmark Results

Here are the results on the dev sets of GLUE:

Model Average-score CoLA MNLI MRPC QNLI QQP RTE SST-2 STS-B WNLI
DistilBERT (3 layers) 67.81 22.8 71.6 78.2 82.1 84.3 55.4 86.5 56.7 24.2
CausalBERT (3 layers) 69.71 25.0 72.9 78.6 83.1 84.9 55.4 86.9 66.5 21.5

1 Average-score computed without WNLI.

Main Contents

Citation

If you use this repository, please cite the following two papers: paper for interchange intervention training, and paper for the our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 or 3.7 are supported.
  • Pytorch Version: 1.9.0
  • Transfermers Version: 4.11.3
  • Datasets Version: Version: 1.8.0
  • We have performed experiments on Titan V GPU. We assume 12GB of GPU memory (more memory can expedite training).
  • Since we build our codebase off the Huggingface Distillation Interface, please review their doc for requirements.

Dataset

Following the Huggingface Distillation Interface, we need to pre-process the datasets before we do distillation. You can refer to their repo for details. We adapt their pre-processing scripts, and update with a few improvements. For example, we can now binarize datasets from the Dataset Hub from huggingface directly.

# preprocessing from disk
python script/binarized_data.py \
--file_path ../../bert-mid-tuning/data-files/wikitext-15M \
--split train \
--field_name text \
--max_parsing_example 1000 \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file ./data/binarized_text

# preprocessing from huggingface.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext+bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

# helper scripts to combine two binarized data files
python scripts/data_combinator.py \
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle \
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--split train \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text

# multiprocessing preprocessor.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/ \
--fast_process \
--preprocessing_num_workers 48

After you get the datasets ready, you need to generate token counts as well.

python scripts/token_counts.py \
--data_file data/binarized_text.train.bert-base-uncased.pickle \
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle \
--vocab_size 30522

Distillation

Before training, we recommand you to initialize your student model with weights extracted from the teacher model.

python scripts/extract_distilbert.py \
--model_type bert \
--model_name bert-base-uncased \
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--num_layers 3

Now, here is an example for you to distill with our causal distillation objective or without,

CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5

CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py \
--force \
--n_gpu 4 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --alpha_causal 0.00 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 124 \
--n_epoch 6 \
--batch_size 4

Note that you can simply turn our causal distillation objective on/off through setting the arguments.

Evaluation

After you get your distilled models, you need to fine-tune them and evaluate them with downstream tasks. We provide you all the scripts you need to run.

MLM Evaluation

CUDA_VISIBLE_DEVICES=5 python run_mlm.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-15M_seed_42_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer/ \
--dataset_dir ../../bert-mid-tuning/data-files/wikitext-15M/ \
--tokenizer_name bert-base-uncased \
--do_eval \
--output_dir /tmp/test-mlm \
--cache_dir ./distill_cache/

GLUE Evaluation

CUDA_VISIBLE_DEVICES=5,7,8,9 python run_glue.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle/ \
--tokenizer_name bert-base-uncased \
--task_name sst2 \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir ./results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

CoNLL Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_ner.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name conll2003 \
--do_train \
--do_eval \
--output_dir ./ner_results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

SQuAD Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_qa.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--save_total_limit 1 \
--output_dir ./qa_results/
Keras attention models including botnet,CoaT,CoAtNet,CMT,cotnet,halonet,resnest,resnext,resnetd,volo,mlp-mixer,resmlp,gmlp,levit

Keras_cv_attention_models Keras_cv_attention_models Usage Basic Usage Layers Model surgery AotNet ResNetD ResNeXt ResNetQ BotNet VOLO ResNeSt HaloNet

319 Dec 28, 2022
VQMIVC - Vector Quantization and Mutual Information-Based Unsupervised Speech Representation Disentanglement for One-shot Voice Conversion

VQMIVC: Vector Quantization and Mutual Information-Based Unsupervised Speech Representation Disentanglement for One-shot Voice Conversion (Interspeech

Disong Wang 262 Dec 31, 2022
A PaddlePaddle implementation of STGCN with a few modifications in the model architecture in order to forecast traffic jam.

About This repository contains the code of a PaddlePaddle implementation of STGCN based on the paper Spatio-Temporal Graph Convolutional Networks: A D

Tianjian Li 1 Jan 11, 2022
Optimize Trading Strategies Using Freqtrade

Optimize trading strategy using Freqtrade Short demo on building, testing and optimizing a trading strategy using Freqtrade. The DevBootstrap YouTube

DevBootstrap 139 Jan 01, 2023
This dlib-based facial login system

Facial-Login-System This dlib-based facial login system is a technology capable of matching a human face from a digital webcam frame capture against a

Mushahid Ali 3 Apr 23, 2022
To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

To propose and implement a multi-class classification approach to disaster assessment from the given data set of post-earthquake satellite imagery.

Kunal Wadhwa 2 Jan 05, 2022
XtremeDistil framework for distilling/compressing massive multilingual neural network models to tiny and efficient models for AI at scale

XtremeDistilTransformers for Distilling Massive Multilingual Neural Networks ACL 2020 Microsoft Research [Paper] [Video] Releasing [XtremeDistilTransf

Microsoft 125 Jan 04, 2023
Pytorch implementation of Bert and Pals: Projected Attention Layers for Efficient Adaptation in Multi-Task Learning

PyTorch implementation of BERT and PALs Introduction Work by Asa Cooper Stickland and Iain Murray, University of Edinburgh. Code for BERT and PALs; mo

Asa Cooper Stickland 70 Dec 29, 2022
Learned model to estimate number of distinct values (NDV) of a population using a small sample.

Learned NDV estimator Learned model to estimate number of distinct values (NDV) of a population using a small sample. The model approximates the maxim

2 Nov 21, 2022
Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-like Documents.

Value Retrieval with Arbitrary Queries for Form-like Documents Introduction Pytorch Implementation of Value Retrieval with Arbitrary Queries for Form-

Salesforce 13 Sep 15, 2022
The first public PyTorch implementation of Attentive Recurrent Comparators

arc-pytorch PyTorch implementation of Attentive Recurrent Comparators by Shyam et al. A blog explaining Attentive Recurrent Comparators Visualizing At

Sanyam Agarwal 150 Oct 14, 2022
SimpleDepthEstimation - An unified codebase for NN-based monocular depth estimation methods

SimpleDepthEstimation Introduction This is an unified codebase for NN-based monocular depth estimation methods, the framework is based on detectron2 (

8 Dec 13, 2022
Random Erasing Data Augmentation. Experiments on CIFAR10, CIFAR100 and Fashion-MNIST

Random Erasing Data Augmentation =============================================================== black white random This code has the source code for

Zhun Zhong 654 Dec 26, 2022
Deep Learning & 3D Convolutional Neural Networks for Speaker Verification

TensorFlow implementation of 3D Convolutional Neural Networks for Speaker Verification - Official Project Page - Pytorch Implementation This repositor

Amirsina Torfi 753 Dec 17, 2022
A novel benchmark dataset for Monocular Layout prediction

AutoLay AutoLay: Benchmarking Monocular Layout Estimation Kaustubh Mani, N. Sai Shankar, J. Krishna Murthy, and K. Madhava Krishna Abstract In this pa

Kaustubh Mani 39 Apr 26, 2022
A toolset for creating Qualtrics-based IAT experiments

Qualtrics IAT Tool A web app for generating the Implicit Association Test (IAT) running on Qualtrics Online Web App The app is hosted by Streamlit, a

0 Feb 12, 2022
Hierarchical Few-Shot Generative Models

Hierarchical Few-Shot Generative Models Giorgio Giannone, Ole Winther This repo contains code and experiments for the paper Hierarchical Few-Shot Gene

Giorgio Giannone 6 Dec 12, 2022
Project page for End-to-end Recovery of Human Shape and Pose

End-to-end Recovery of Human Shape and Pose Angjoo Kanazawa, Michael J. Black, David W. Jacobs, Jitendra Malik CVPR 2018 Project Page Requirements Pyt

1.4k Dec 29, 2022
MlTr: Multi-label Classification with Transformer

MlTr: Multi-label Classification with Transformer This is official implement of "MlTr: Multi-label Classification with Transformer". Abstract The task

程星 38 Nov 08, 2022
Python inverse kinematics for your robot model based on Pinocchio.

Python inverse kinematics for your robot model based on Pinocchio.

Stéphane Caron 50 Dec 22, 2022