This repository contains the code for the ACL 2020 paper `Dice Loss for Data-imbalanced NLP Tasks`.

Overview

Dice Loss for NLP Tasks


Setup

  • Install Package Dependencies

The code was tested with Python 3.6.9+ and PyTorch 1.7.1. If you are working on an Ubuntu GPU machine with CUDA 10.1, run the following commands to set up the environment:

$ virtualenv -p /usr/bin/python3.6 venv
$ source venv/bin/activate
$ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
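To confirm the installation before training, a small check like the following can help. It is a minimal sketch, not a script shipped with the repository; the function name environment_report is hypothetical. It only reads standard attributes (torch.__version__, torch.version.cuda, torch.cuda.is_available()) and still reports the Python version if PyTorch is not installed.

```python
import sys


def environment_report():
    """Collect interpreter and (optional) PyTorch details for a quick sanity check."""
    report = {
        "python": "%d.%d.%d" % tuple(sys.version_info[:3]),
        # The README states the code was tested with Python 3.6.9+.
        "python_ok": sys.version_info[:3] >= (3, 6, 9),
        "torch": None,
        "torch_cuda": None,
        "cuda_available": False,
    }
    try:
        import torch  # available only after the pip install step
    except ImportError:
        return report
    report["torch"] = str(torch.__version__)  # expected to start with "1.7.1"
    report["torch_cuda"] = torch.version.cuda  # CUDA version the wheel was built against
    report["cuda_available"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print("%s: %s" % (key, value))
```

With the CUDA 10.1 wheels above, torch_cuda would be expected to read 10.1; cuda_available is False when no usable GPU driver is found.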
  • Download BERT Model Checkpoints

Before running the code, you must download the BERT-Base and BERT-Large checkpoints from here and unzip them to a directory $BERT_DIR. Then convert the original TensorFlow BERT checkpoints into PyTorch saved files by running bash scripts/prepare_ckpt.sh <path-to-unzip-tf-bert-checkpoints>.
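Before running the conversion script, it may be worth checking that the unzipped directory is complete. The sketch below assumes the standard file layout of Google's original TensorFlow BERT releases; the helper name missing_bert_files is hypothetical, and this check is not part of scripts/prepare_ckpt.sh.

```python
import os
import sys

# Files in Google's original TensorFlow BERT archives (assumed layout).
EXPECTED_TF_BERT_FILES = (
    "bert_config.json",
    "vocab.txt",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "bert_model.ckpt.data-00000-of-00001",
)


def missing_bert_files(filenames):
    """Return the expected checkpoint files that are absent from a directory listing."""
    present = set(filenames)
    return [name for name in EXPECTED_TF_BERT_FILES if name not in present]


if __name__ == "__main__":
    # Usage (hypothetical script name): python check_bert_dir.py <path-to-unzip-tf-bert-checkpoints>
    missing = missing_bert_files(os.listdir(sys.argv[1]))
    if missing:
        sys.exit("Missing files: " + ", ".join(missing))
    print("Checkpoint directory looks complete.")
```

Calling missing_bert_files(os.listdir(bert_dir)) returns an empty list when all five files are present.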

Apply Dice-Loss to NLP Tasks

In this repository, we apply dice loss to four NLP tasks:

  1. machine reading comprehension
  2. paraphrase identification
  3. named entity recognition
  4. text classification

1. Machine Reading Comprehension

Datasets

We take SQuAD 1.1 as an example. Before training, download a copy of the data from here,
then move the SQuAD 1.1 training file train-v1.1.json and dev file dev-v1.1.json to the directory $DATA_DIR.
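A quick way to confirm that the downloaded files are intact is to walk the SQuAD 1.1 nesting (data, paragraphs, qas, answers). The sketch assumes the public SQuAD 1.1 schema; summarize_squad is a hypothetical helper, not part of the repository. It also counts answers whose answer_start offset does not reproduce the answer text, which would point to a corrupted or modified copy.

```python
import json
import sys


def summarize_squad(dataset):
    """Count articles, paragraphs and questions in parsed SQuAD v1.1 JSON, plus
    answers whose answer_start offset does not reproduce the answer text."""
    stats = {"articles": 0, "paragraphs": 0, "questions": 0, "misaligned_answers": 0}
    for article in dataset["data"]:
        stats["articles"] += 1
        for paragraph in article["paragraphs"]:
            stats["paragraphs"] += 1
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                stats["questions"] += 1
                for answer in qa["answers"]:
                    start = answer["answer_start"]
                    if context[start:start + len(answer["text"])] != answer["text"]:
                        stats["misaligned_answers"] += 1
    return stats


if __name__ == "__main__":
    # e.g. python summarize_squad.py $DATA_DIR/train-v1.1.json
    with open(sys.argv[1], encoding="utf-8") as f:
        print(summarize_squad(json.load(f)))
```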

Train

We choose BERT as the backbone. During training, the task trainer BertForQA automatically evaluates on the dev set every $val_check_interval epochs and saves the dev predictions to files named $OUTPUT_DIR/predictions_<train-epoch>_<total-train-step>.json and $OUTPUT_DIR/nbest_predictions_<train-epoch>_<total-train-step>.json.
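To pick the most recent predictions file out of $OUTPUT_DIR, the epoch and step numbers should be compared numerically: if they are not zero-padded, a string sort places predictions_10_... before predictions_9_.... A minimal sketch assuming exactly the naming pattern described above; latest_prediction_file is a hypothetical helper.

```python
import re

# Matches predictions_<train-epoch>_<total-train-step>.json but not the nbest_ files.
PRED_FILE = re.compile(r"^predictions_(\d+)_(\d+)\.json$")


def latest_prediction_file(filenames):
    """Return the predictions file with the largest (epoch, step), or None."""
    best_key, best_name = None, None
    for name in filenames:
        match = PRED_FILE.match(name)
        if match is None:
            continue  # skip nbest files, checkpoints and anything else
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name
```

Typical use would be latest_prediction_file(os.listdir(output_dir)).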

Run scripts/squad1/bert_<model-scale>_<loss-type>.sh to reproduce our experimental results.
The variable <model-scale> takes one of the values [base, large].
The variable <loss-type> takes one of the values [bce, focal, dice], denoting fine-tuning BERT with binary cross-entropy loss, focal loss, or dice loss, respectively.

  • For focal loss, run bash scripts/squad1/bert_base_focal.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR to evaluate the predictions.

  • For dice loss, run bash scripts/squad1/bert_base_dice.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR to evaluate the predictions.
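For intuition about the three <loss-type> options, the sketch below computes each loss for binary probabilities in plain Python. It is an illustration under simplifying assumptions, not the repository's implementation: the focal loss omits the optional alpha class weight, and the dice loss uses the common smoothed form 1 - (2 * sum(p * y) + s) / (sum(p) + sum(y) + s). The repository's dice loss may use a different denominator or additional terms, and all function names here are hypothetical.

```python
import math


def bce_loss(probs, labels, eps=1e-7):
    """Mean binary cross-entropy over examples."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


def focal_loss(probs, labels, gamma=2.0, eps=1e-7):
    """Mean binary focal loss: -(1 - p_t) ** gamma * log(p_t), without alpha weighting."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        p_t = p if y == 1 else 1.0 - p  # probability assigned to the gold label
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(probs)


def dice_loss(probs, labels, smooth=1.0):
    """Smoothed soft dice loss over the positive class."""
    intersection = sum(p * y for p, y in zip(probs, labels))
    denominator = sum(probs) + sum(labels)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
```

On a batch such as probs = [0.9, 0.2, 0.1, 0.05] with labels = [1, 0, 0, 0], cross-entropy and focal loss average a per-example term, so many easy negatives still contribute (focal loss only down-weights them), while the dice ratio contains no term that rewards correctly predicted negatives. With gamma = 0 the focal loss reduces to cross-entropy.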

Evaluate

To evaluate a model checkpoint, please run

python3 tasks/squad/evaluate_models.py \
--gpus="1" \
--path_to_model_checkpoint  $OUTPUT_DIR/epoch=2.ckpt \
--eval_batch_size <evaluate-batch-size>

After evaluation, the prediction files predictions_dev.json and nbest_predictions_dev.json can be found in $OUTPUT_DIR.

To evaluate saved predictions, please run

python3 tasks/squad/evaluate_predictions.py <path-to-dev-v1.1.json> <directory-to-prediction-files>
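SQuAD EM and F1 are usually computed with the official SQuAD v1.1 evaluation logic: answers are normalized (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace), each prediction is scored against every gold answer, and the best match is kept. The following is a re-implementation sketch of that logic for dictionaries shaped like predictions_dev.json ({question id: answer text}); it is not the code in tasks/squad/evaluate_predictions.py.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, remove punctuation and the articles a/an/the, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def token_f1(prediction, ground_truth):
    """Token-overlap F1 between the normalized prediction and gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def score(predictions, gold_answers):
    """predictions: {qid: text}; gold_answers: {qid: [text, ...]}. Returns percentages.
    A missing prediction is scored as an empty answer."""
    em = f1 = 0.0
    for qid, golds in gold_answers.items():
        pred = predictions.get(qid, "")
        em += max(float(exact_match(pred, g)) for g in golds)
        f1 += max(token_f1(pred, g) for g in golds)
    n = len(gold_answers)
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f1 / n}
```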

2. Paraphrase Identification Task

Datasets

We use MRPC (GLUE Version) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Run bash scripts/prepare_mrpc_data.sh $DATA_DIR to download and process the datasets for the MRPC (GLUE Version) task.
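To inspect the processed data, for example to see how balanced the paraphrase labels are, a reader like the one below may help. It assumes the standard GLUE MRPC TSV layout with the header Quality, #1 ID, #2 ID, #1 String, #2 String (the unlabeled GLUE test file uses a different first column); the files produced by scripts/prepare_mrpc_data.sh may differ, and read_mrpc_tsv is a hypothetical helper. Quoting is disabled because sentences may contain literal quote characters.

```python
import collections
import csv
import sys


def read_mrpc_tsv(lines):
    """Yield (label, sentence1, sentence2) from GLUE-style MRPC TSV lines.

    Assumes the header row: Quality, #1 ID, #2 ID, #1 String, #2 String.
    """
    # QUOTE_NONE keeps literal quote characters inside sentences intact.
    reader = csv.reader(lines, delimiter="\t", quoting=csv.QUOTE_NONE)
    next(reader, None)  # skip the header row
    for row in reader:
        if len(row) != 5:
            continue  # skip blank or malformed lines
        yield int(row[0]), row[3], row[4]


if __name__ == "__main__":
    # e.g. python read_mrpc.py $DATA_DIR/train.tsv  (file name is an assumption)
    with open(sys.argv[1], encoding="utf-8", newline="") as f:
        print(collections.Counter(label for label, _, _ in read_mrpc_tsv(f)))
```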

Train

Please run scripts/glue_mrpc/bert_<model-scale>_<loss-type>.sh to train and to evaluate on the dev set every $val_check_interval epochs. After training, the task trainer evaluates on the test set using the checkpoint that achieves the highest F1-score on the dev set.
The variable <model-scale> takes one of the values [base, large].
The variable <loss-type> takes one of the values [focal, dice], denoting fine-tuning BERT with focal loss or dice loss, respectively.

  • Run bash scripts/glue_mrpc/bert_large_focal.sh for focal loss.

  • Run bash scripts/glue_mrpc/bert_large_dice.sh for dice loss.
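The dev-set selection described above compares checkpoints by F1-score. A minimal sketch of that metric and of picking the best checkpoint, assuming F1 is computed for the positive (paraphrase) class, predictions and labels are 0/1 integers, and dev scores are collected into a dictionary; both helper names are hypothetical and this is not the repository's trainer code.

```python
def binary_f1(predictions, labels, positive=1):
    """F1-score of the positive class."""
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p == positive and y == positive)
    fp = sum(1 for p, y in pairs if p == positive and y != positive)
    fn = sum(1 for p, y in pairs if p != positive and y == positive)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is zero
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_checkpoint(dev_f1_by_checkpoint):
    """Return the checkpoint name with the highest dev F1."""
    return max(dev_f1_by_checkpoint, key=dev_f1_by_checkpoint.get)
```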

The evaluation results on the dev and test sets are saved to $OUTPUT_DIR/eval_result_log.txt.
At most $max_keep_ckpt intermediate model checkpoints are kept.
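How the repository decides which checkpoints to keep is not described here. One common policy, used for example by PyTorch Lightning's ModelCheckpoint with save_top_k, is to keep the k best checkpoints according to a monitored dev metric. The sketch below implements that policy with a min-heap, with k playing the role of $max_keep_ckpt; the TopKCheckpoints class is hypothetical and may not match the repository's behavior.

```python
import heapq


class TopKCheckpoints:
    """Keep at most k checkpoints, evicting the lowest-scoring one (hypothetical helper)."""

    def __init__(self, k):
        self.k = k
        self._heap = []  # min-heap of (score, name); the worst kept checkpoint is at index 0

    def add(self, name, score):
        """Register a new checkpoint; return the name that should be deleted, if any."""
        if self.k <= 0:
            return name
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (score, name))
            return None
        if score <= self._heap[0][0]:
            return name  # not better than the worst kept checkpoint: discard it
        _, evicted = heapq.heapreplace(self._heap, (score, name))
        return evicted

    def kept(self):
        """Names of the kept checkpoints, best first."""
        return [name for _, name in sorted(self._heap, reverse=True)]
```

The heap makes each update O(log k), and the caller stays responsible for actually deleting the returned file.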

Evaluate

To evaluate a model checkpoint on the test set, please run

bash scripts/glue_mrpc/eval.sh \
$OUTPUT_DIR \
epoch=*.ckpt
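The pattern epoch=*.ckpt matches checkpoint files named by epoch. To list which checkpoints exist in $OUTPUT_DIR in epoch order before evaluating, a sketch like the following may help; it assumes the epoch number appears as epoch=<n> in the file name, and both helpers are hypothetical.

```python
import glob
import os
import re

EPOCH = re.compile(r"epoch=(\d+)")


def sort_by_epoch(paths):
    """Order checkpoint paths numerically by their epoch=<n> component."""
    def epoch_of(path):
        match = EPOCH.search(os.path.basename(path))
        return int(match.group(1)) if match else -1
    return sorted(paths, key=epoch_of)


def list_epoch_checkpoints(output_dir):
    """Return files matching epoch=*.ckpt in output_dir, ordered by epoch number."""
    # glob.escape protects directory names that contain glob metacharacters.
    pattern = os.path.join(glob.escape(output_dir), "epoch=*.ckpt")
    return sort_by_epoch(glob.glob(pattern))
```

A plain string sort would put epoch=10.ckpt before epoch=2.ckpt, which is why the epoch number is parsed.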

3. Named Entity Recognition

For NER, we use the MRC-NER model as the backbone.
Processed datasets and model architecture can be found here.
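MRC-NER recasts NER as machine reading comprehension: for each entity type, a natural-language query is paired with the sentence, and the model predicts the start and end positions of matching spans. A simplified sketch of that conversion for one sentence follows; the function name, field names and queries are illustrative assumptions and may not match the processed dataset files.

```python
def to_mrc_examples(tokens, spans, queries):
    """Convert one tagged sentence into MRC-style examples, one per entity type.

    tokens:  list of words in the sentence
    spans:   list of (start, end, entity_type) with inclusive token indices
    queries: {entity_type: natural-language question}
    """
    examples = []
    for entity_type, query in queries.items():
        matched = [(s, e) for s, e, t in spans if t == entity_type]
        examples.append({
            "query": query,
            "context": " ".join(tokens),
            "entity_label": entity_type,
            # Empty lists when the sentence has no entity of this type.
            "start_position": [s for s, _ in matched],
            "end_position": [e for _, e in matched],
        })
    return examples
```

Because every entity type yields an example, including types with no entity in the sentence, almost all start and end labels are negative; this is the kind of imbalance the dice loss is meant to address.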

Train

Please run scripts/<ner-dataset-name>/bert_<loss-type>.sh to train and to evaluate on the dev set every $val_check_interval epochs. After training, the task trainer evaluates on the test set using the best checkpoint.
The variable <ner-dataset-name> takes one of the values [ner_enontonotes5, ner_zhmsra, ner_zhonto4].
The variable <loss-type> takes one of the values [focal, dice], denoting fine-tuning BERT with focal loss or dice loss, respectively.

For Chinese MSRA,

  • Run scripts/ner_zhmsra/bert_focal.sh for focal loss.

  • Run scripts/ner_zhmsra/bert_dice.sh for dice loss.

For Chinese OntoNotes4,

  • Run scripts/ner_zhonto4/bert_focal.sh for focal loss.

  • Run scripts/ner_zhonto4/bert_dice.sh for dice loss.

For English OntoNotes5,

  • Run scripts/ner_enontonotes5/bert_focal.sh. After training, you will get 91.12 Span-F1 on the test set.

  • Run scripts/ner_enontonotes5/bert_dice.sh. After training, you will get 92.01 Span-F1 on the test set.
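Span-F1 counts a predicted entity as correct only when its start, end and type all match a gold entity. A minimal micro-averaged sketch, assuming entities are represented per sentence as (start, end, type) tuples; span_f1 is a hypothetical helper, not the repository's metric code.

```python
def span_f1(gold_spans, pred_spans):
    """Micro span-level F1. Each argument is a list (one entry per sentence)
    of collections of (start, end, type) tuples."""
    tp = fp = fn = 0
    for gold, pred in zip(gold_spans, pred_spans):
        gold, pred = set(gold), set(pred)
        tp += len(gold & pred)   # exact boundary and type match
        fp += len(pred - gold)
        fn += len(gold - pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```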

Evaluate

To evaluate a model checkpoint, please run

CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
--gpus="1" \
--path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt

4. Text Classification

Datasets

We use TNews (Chinese Text Classification) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.
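Since the dice loss targets data imbalance, it can be useful to check the label distribution of the processed TNews files first. The sketch assumes JSON-lines files in the CLUE TNEWS style, where each line is a JSON object with a label_desc field; the key is a parameter in case the processed files use a different field. Both helper names are hypothetical.

```python
import collections
import json


def label_distribution(lines, key="label_desc"):
    """Count examples per label in JSON-lines data (one JSON object per line)."""
    counts = collections.Counter()
    for line in lines:
        line = line.strip()
        if line:  # ignore blank lines
            counts[json.loads(line)[key]] += 1
    return counts


def imbalance_ratio(counts):
    """Size of the largest class divided by the size of the smallest."""
    return max(counts.values()) / min(counts.values())
```

An open file handle can be passed directly as lines, since iterating over it yields one line at a time.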

Train

We choose BERT as the backbone.
Please run scripts/tnews/bert_<loss-type>.sh to train and to evaluate on the dev set every $val_check_interval epochs. The variable <loss-type> takes one of the values [focal, dice], denoting fine-tuning BERT with focal loss or dice loss, respectively.

  • Run bash scripts/tnews/bert_focal.sh for focal loss.

  • Run bash scripts/tnews/bert_dice.sh for dice loss.
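For multi-class tasks such as TNews, one straightforward way to extend a binary soft dice loss is one-vs-rest: compute a dice loss per class on that class's probability column and average over classes. This is a sketch of that generalization under a smoothed-dice assumption (1 - (2 * sum(p * y) + s) / (sum(p) + sum(y) + s) per class); the function name is hypothetical, and it is not necessarily how the repository's dice loss handles multiple classes.

```python
def multiclass_dice_loss(prob_rows, gold_labels, smooth=1.0):
    """Macro-average of per-class smoothed soft dice losses (one-vs-rest).

    prob_rows:   list of per-example probability lists (e.g. softmax outputs)
    gold_labels: list of gold class indices
    """
    num_classes = len(prob_rows[0])
    losses = []
    for c in range(num_classes):
        probs = [row[c] for row in prob_rows]
        targets = [1.0 if y == c else 0.0 for y in gold_labels]
        intersection = sum(p * t for p, t in zip(probs, targets))
        denominator = sum(probs) + sum(targets)
        losses.append(1.0 - (2.0 * intersection + smooth) / (denominator + smooth))
    return sum(losses) / num_classes
```

The smoothing term keeps the loss defined for a class that is absent from a batch: with no gold examples and no predicted mass, that class contributes 1 - s/s = 0.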

At most $max_keep_ckpt intermediate model checkpoints are kept.

Citation

If you find this repository useful, please cite the following:

@article{li2019dice,
  title={Dice loss for data-imbalanced NLP tasks},
  author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei},
  journal={arXiv preprint arXiv:1911.02855},
  year={2019}
}

Contact

xiaoyalixy AT gmail.com OR xiaoya_li AT shannonai.com

Any discussions, suggestions and questions are welcome!
