使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法

Related tags

Text Data & NLPSimCSE
Overview

SimCSE复现

项目描述

SimCSE是一种简单但是很巧妙的NLP对比学习方法,创新性地引入Dropout的方式,对样本添加噪声,从而达到对正样本增强的目的。 该框架的训练目的为:对于batch中的每个样本,拉近其与正样本之间的距离,拉远其与负样本之间的距离,使得模型能够在大规模无监督语料(也可以使用有监督的语料)中学习到文本相似关系。 详见论文:Simple Contrastive Learning of Sentence EmbeddingsSimCSE官方代码仓库

本项目使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法,并且在STS-B数据集上进行消融实验,评价指标为Spearman相关系数,预训练模型为Bert-base-uncased, 验证了SimCSE的有效性。在STS-B数据集上,有监督训练和无监督训练的复现效果如下表。

在无监督训练中,dropout=0.1时,复现效果比原文略差,但也比较接近。当dropout=0.2时,复现效果比原文略高。 ** 但在有监督训练中,不知是否由于batch size过小(原论文使用512),复现效果与论文的效果相差较远,后续会进行排查。 **

训练方法 learning rate batch size dropout Spearman’s correlation
原论文 无监督 3e-5 64 0.1 0.763
复现 无监督 3e-5 64 0.2 0.771
复现 无监督 3e-5 64 0.1 0.748
原论文 有监督 5e-5 512 0.1 0.816
复现 有监督 5e-5 64 0.1 0.764

运行环境

python==3.6、transformers==3.1.0、torch==1.6.0

项目结构

  • data:存放训练数据
    • stsbenchmark:STS-B数据集
      • sts-dev.csv:STS-B验证集
      • sts-test.csv:STS-B验测试集
    • nli_for_simcse.csv:数量275601为的NLI数据集
    • wiki1m_for_simcse.txt:维基百科上获取的100w的文本
  • output:输出目录
  • pretrain_model:预训练模型存放位置
  • script:脚本存放位置。
  • dataset.py
  • model.py:模型代码,包含有监督和无监督损失函数的计算方式
  • train.py:训练代码

使用方法

Quick Start

下载训练数据:

bash script/download_nli.sh
bash script/download_wiki.sh

无监督训练,运行脚本

bash script/run_unsup_train.sh

有监督训练,运行脚本

bash script/run_sup_train.sh

实验

无监督训练

从前四条实验数据中可以看到,较大的batch size在一定程度上可以增加模型的泛化性。

dropout为0.2的时候,训练效果比0.1与0.3更好,有可能dropout=0.1加入的噪声过小,而dropout=0.3加入的噪声过大,增强得到的样本与原始样本差异较大。

learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
3e-5 256 0.1 6000 0.800 0.761
3e-5 128 0.1 4200 0.799 0.747
3e-5 64 0.1 10900 0.803 0.748
3e-5 32 0.1 21300 0.787 0.714
3e-5 64 0.2 11200 0.811 0.771
3e-5 64 0.3 6300 0.781 0.745
1e-5 64 0.1 16400 0.798 0.751

有监督训练

有监督实验的复现结果未达到预期,超参数相同时,在验证集上的得分略高于无监督,但是在测试集上,得分基本没有差异。增大有监督训练的学习率,有监督的训练的得分略高于无监督训练, 但还是与论文声称的0.816相差较远,原论文使用512的batch size, 不知是否由于batch size的设置有关,后续会对有监督的训练代码进一步排查。

不过从训练曲线可以看到,有监督训练的收敛速度明显快于无监督训练,这也符合我们的认知。

训练方法 learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
无监督 3e-5 64 0.1 10900 0.803 0.748
有监督 3e-5 64 0.1 200 0.810 0.748
有监督 5e-5 64 0.1 2300 0.809 0.764
有监督 3e-5 32 0.1 200 0.808 0.743
有监督 5e-5 32 0.1 200 0.806 0.746

无监督训练过程中,验证集得分的变化曲线: avatar

有监督训练过程中,验证集得分的变化曲线: avatar

REFERENCE

TODO

  • 排查有监督学习的效果不符合预期的原因
Long text token classification using LongFormer

Long text token classification using LongFormer

abhishek thakur 161 Aug 07, 2022
Deal or No Deal? End-to-End Learning for Negotiation Dialogues

Introduction This is a PyTorch implementation of the following research papers: (1) Hierarchical Text Generation and Planning for Strategic Dialogue (

Facebook Research 1.4k Dec 29, 2022
A repository to run gpt-j-6b on low vram machines (4.2 gb minimum vram for 2000 token context, 3.5 gb for 1000 token context). Model loading takes 12gb free ram.

Basic-UI-for-GPT-J-6B-with-low-vram A repository to run GPT-J-6B on low vram systems by using both ram, vram and pinned memory. There seem to be some

90 Dec 25, 2022
HiFi DeepVariant + WhatsHap workflowHiFi DeepVariant + WhatsHap workflow

HiFi DeepVariant + WhatsHap workflow Workflow steps align HiFi reads to reference with pbmm2 call small variants with DeepVariant, using two-pass meth

William Rowell 2 May 14, 2022
Converts python code into c++ by using OpenAI CODEX.

🦾 codex_py2cpp 🤖 OpenAI Codex Python to C++ Code Generator Your Python Code is too slow? 🐌 You want to speed it up but forgot how to code in C++? ⌨

Alexander 423 Jan 01, 2023
AI-Broad-casting - AI Broad casting with python

Basic Code 1. Use The Code Configuration Environment conda create -n code_base p

Python interface for converting Penn Treebank trees to Stanford Dependencies and Universal Depenencies

PyStanfordDependencies Python interface for converting Penn Treebank trees to Universal Dependencies and Stanford Dependencies. Example usage Start by

David McClosky 64 May 08, 2022
Lattice methods in TensorFlow

TensorFlow Lattice TensorFlow Lattice is a library that implements constrained and interpretable lattice based models. It is an implementation of Mono

504 Dec 20, 2022
Sentence boundary disambiguation tool for Japanese texts (日本語文境界判定器)

Bunkai Bunkai is a sentence boundary (SB) disambiguation tool for Japanese texts. Quick Start $ pip install bunkai $ echo -e '宿を予約しました♪!まだ2ヶ月も先だけど。早すぎ

Megagon Labs 160 Dec 23, 2022
SAVI2I: Continuous and Diverse Image-to-Image Translation via Signed Attribute Vectors

SAVI2I: Continuous and Diverse Image-to-Image Translation via Signed Attribute Vectors [Paper] [Project Website] Pytorch implementation for SAVI2I. We

Qi Mao 44 Dec 30, 2022
Negative sampling for solving the unlabeled entity problem in NER. ICLR-2021 paper: Empirical Analysis of Unlabeled Entity Problem in Named Entity Recognition.

Negative Sampling for NER Unlabeled entity problem is prevalent in many NER scenarios (e.g., weakly supervised NER). Our paper in ICLR-2021 proposes u

Yangming Li 128 Dec 29, 2022
📜 GPT-2 Rhyming Limerick and Haiku models using data augmentation

Well-formed Limericks and Haikus with GPT2 📜 GPT-2 Rhyming Limerick and Haiku models using data augmentation In collaboration with Matthew Korahais &

Bardia Shahrestani 2 May 26, 2022
[ICLR 2021 Spotlight] Pytorch implementation for "Long-tailed Recognition by Routing Diverse Distribution-Aware Experts."

RIDE: Long-tailed Recognition by Routing Diverse Distribution-Aware Experts. by Xudong Wang, Long Lian, Zhongqi Miao, Ziwei Liu and Stella X. Yu at UC

Xudong (Frank) Wang 205 Dec 16, 2022
Blackstone is a spaCy model and library for processing long-form, unstructured legal text

Blackstone Blackstone is a spaCy model and library for processing long-form, unstructured legal text. Blackstone is an experimental research project f

ICLR&D 579 Jan 08, 2023
Code for Emergent Translation in Multi-Agent Communication

Emergent Translation in Multi-Agent Communication PyTorch implementation of the models described in the paper Emergent Translation in Multi-Agent Comm

Facebook Research 75 Jul 15, 2022
STT for TorchScript is a port of Coqui STT based on DeepSpeech to PyTorch.

st3 STT for TorchScript is a port of Coqui STT based on DeepSpeech to PyTorch. Currently it supports converting pbmm models to pt scripts with integra

Vlad Ki 8 Oct 18, 2021
Material for GW4SHM workshop, 16/03/2022.

GW4SHM Workshop Wednesday, 16th March 2022 (13:00 – 15:15 GMT): Presented by: Dr. Rhodri Nelson, Imperial College London Project website: https://www.

Devito Codes 1 Mar 16, 2022
NLPShala , the best IDE for all Natural language processing tasks.

The revolutionary IDE for all NLP (Natural language processing) stuffs on the internet.

Abhi 3 Aug 08, 2021
Blazing fast language detection using fastText model

Luga A blazing fast language detection using fastText's language models Luga is a Swahili word for language. fastText provides a blazing fast language

Prayson Wilfred Daniel 18 Dec 20, 2022
Web mining module for Python, with tools for scraping, natural language processing, machine learning, network analysis and visualization.

Pattern Pattern is a web mining module for Python. It has tools for: Data Mining: web services (Google, Twitter, Wikipedia), web crawler, HTML DOM par

Computational Linguistics Research Group 8.4k Dec 30, 2022