Code for "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation" (long paper at EMNLP-2021)

Overview

We propose to conduct scheduled sampling based on decoding steps instead of the original training steps. We observe that our proposal can more realistically simulate the distribution of real translation errors, thus better bridging the gap between training and inference. The paper has been accepted to the main conference of EMNLP-2021.

Background

We conduct scheduled sampling for the Transformer with a two-pass decoder. An example of pseudo-code is as follows:

# first-pass: the same as the standard Transformer decoder
first_decoder_outputs = decoder(first_decoder_inputs)

# sampling tokens between model predictions and ground-truth tokens
second_decoder_inputs = sampling_function(first_decoder_outputs, first_decoder_inputs)

# second-pass: computing the decoder again with the above sampled tokens
second_decoder_outputs = decoder(second_decoder_inputs)
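
The two-pass procedure above can be made concrete with a minimal, runnable sketch. Everything named here is hypothetical and only for illustration: `ToyDecoder` is a stand-in for a real Transformer decoder, `two_pass_inputs` and `keep_gold_prob` are not part of this repository, and the sketch assumes the usual teacher-forcing layout in which the output at position i predicts the token fed at position i + 1. Running the first pass without gradients is also an assumption, not something the pseudo-code specifies.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a Transformer decoder: embed tokens, project to the vocabulary."""

    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens):
        # tokens: [batch_size, seq_len] -> logits: [batch_size, seq_len, vocab_size]
        return self.proj(self.embed(tokens))


def two_pass_inputs(decoder, gold_inputs, keep_gold_prob):
    """Mix ground-truth tokens with first-pass predictions to build second-pass inputs.

    keep_gold_prob: [seq_len], probability of keeping the ground-truth token at each decoding step.
    """
    with torch.no_grad():
        # first pass: standard teacher forcing (no gradients here is an assumption)
        predictions = decoder(gold_inputs).argmax(dim=-1)
    # shift right so that position i holds the model's guess for the token fed at position i
    shifted = torch.cat([gold_inputs[:, :1], predictions[:, :-1]], dim=1)
    # per-token coin flip; keep_gold_prob broadcasts over the batch dimension
    keep_gold = torch.rand(gold_inputs.shape) < keep_gold_prob
    return torch.where(keep_gold, gold_inputs, shifted)


torch.manual_seed(0)
decoder = ToyDecoder()
gold_inputs = torch.randint(0, 16, (2, 5))
# the chance of keeping the ground truth decays with the decoding step
keep_gold_prob = 0.9 ** torch.arange(5, dtype=torch.float)
second_inputs = two_pass_inputs(decoder, gold_inputs, keep_gold_prob)
# second pass: these logits would feed the training loss
second_logits = decoder(second_inputs)
```

Later positions are more likely to receive model predictions than earlier ones, which is the decoding-step-based behaviour the overview describes.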

Quick to Use

Our approach is suitable for most autoregressive tasks. Please try the following pseudo-code when conducting scheduled sampling:

import torch

def sampling_function(first_decoder_outputs, first_decoder_inputs, max_seq_len, tgt_lengths,
                      sampling_strategy="exponential", exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    '''
    conduct scheduled sampling based on the index of decoded tokens
    param first_decoder_outputs: [batch_size, seq_len, hidden_size], model predictions
    param first_decoder_inputs: [batch_size, seq_len, hidden_size], ground-truth target tokens
    param max_seq_len: scalar, the max length of target sequence (must be larger than every entry of tgt_lengths)
    param tgt_lengths: [batch_size], LongTensor, the lengths of target sequences in a mini-batch
    param sampling_strategy, exp_radix, sigmoid_k, epsilon: schedule hyperparameters;
        the default values are illustrative placeholders, not values taken from the paper
    '''
    batch_size, seq_len = first_decoder_inputs.shape[:2]
    device = first_decoder_inputs.device

    # indices of decoding steps
    t = torch.arange(max_seq_len, dtype=torch.float, device=device)

    # different sampling strategies based on decoding steps
    if sampling_strategy == "exponential":
        threshold_table = exp_radix ** t
    elif sampling_strategy == "sigmoid":
        threshold_table = sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    elif sampling_strategy == "linear":
        threshold_table = torch.clamp(1 - t / max_seq_len, min=epsilon)
    else:
        raise ValueError("Unknown sampling_strategy %s" % sampling_strategy)

    # convert threshold_table to [batch_size, seq_len]:
    # row i of the lower-triangular table keeps the thresholds of steps 0..i and zeros elsewhere
    threshold_table = threshold_table.unsqueeze(0).repeat(max_seq_len, 1).tril()
    thresholds = threshold_table[tgt_lengths].view(-1, max_seq_len)
    thresholds = thresholds[:, :seq_len]

    # conduct sampling based on the above thresholds
    random_select_seed = torch.rand([batch_size, seq_len], device=device)
    use_ground_truth = random_select_seed < thresholds
    if first_decoder_inputs.dim() == 3:
        # broadcast the per-token decision over the hidden dimension
        use_ground_truth = use_ground_truth.unsqueeze(-1)
    second_decoder_inputs = torch.where(use_ground_truth, first_decoder_inputs, first_decoder_outputs)

    return second_decoder_inputs
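
To get a feel for how the three strategies behave, the sketch below evaluates each threshold formula on its own and prints a few values along the decoding steps. The helper name `threshold_schedule` and the hyperparameter values (`exp_radix=0.99`, `sigmoid_k=20.0`, `epsilon=0.1`, `max_seq_len=50`) are assumptions chosen for illustration, not settings from the paper. The linear case uses `torch.clamp` to enforce the `epsilon` floor.

```python
import torch


def threshold_schedule(strategy, max_seq_len, exp_radix=0.99, sigmoid_k=20.0, epsilon=0.1):
    """Probability of feeding the ground-truth token at each decoding step (hypothetical helper)."""
    t = torch.arange(max_seq_len, dtype=torch.float)
    if strategy == "exponential":
        return exp_radix ** t
    if strategy == "sigmoid":
        return sigmoid_k / (sigmoid_k + torch.exp(t / sigmoid_k))
    if strategy == "linear":
        # never drop below epsilon
        return torch.clamp(1 - t / max_seq_len, min=epsilon)
    raise ValueError("Unknown strategy %s" % strategy)


# print every 10th threshold for each strategy
for name in ("exponential", "sigmoid", "linear"):
    schedule = threshold_schedule(name, max_seq_len=50)
    print(name, [round(v, 3) for v in schedule[::10].tolist()])
```

All three schedules are non-increasing in the decoding step, so tokens later in the sequence are more likely to be replaced by model predictions.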
    

Further Usage

Error accumulation is a common phenomenon in NLP tasks. Whenever you want to simulate the accumulation of errors, our method may come in handy. For example:

# sampling tokens between noisy target tokens and ground-truth tokens
decoder_inputs = sampling_function(noisy_decoder_inputs, golden_decoder_inputs, max_seq_len, tgt_lengths)

# computing the decoder with the above sampled tokens
decoder_outputs = decoder(decoder_inputs)

# sampling utterances from model predictions and ground-truth utterances
contexts = sampling_function(predicted_utterences, golden_utterences, max_turns, current_turns)

model_predictions = dialogue_model(contexts, target_inputs)

Experiments

We provide scripts to reproduce the results in this paper (NMT and text summarization).

Citation

Please cite this paper if you find this repo useful.

@inproceedings{liu_ss_decoding_2021,
    title = "Scheduled Sampling Based on Decoding Steps for Neural Machine Translation",
    author = "Liu, Yijin  and
      Meng, Fandong  and
      Chen, Yufeng  and
      Xu, Jinan  and
      Zhou, Jie",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    year = "2021",
    address = "Online"
}

Contact

Please feel free to contact us ([email protected]) for any further questions.
