[AAAI 21] Curriculum Labeling: Revisiting Pseudo-Labeling for Semi-Supervised Learning

Paola Cascante-Bonilla, Fuwen Tan, Yanjun Qi, Vicente Ordonez.

In the 35th AAAI Conference on Artificial Intelligence. AAAI 2021.

About | Requirements | Train/Eval | Bibtex

About

In this paper, we revisit the idea of pseudo-labeling in the context of semi-supervised learning, where a learning algorithm has access to a small set of labeled samples and a large set of unlabeled samples. Pseudo-labeling works by applying pseudo-labels to samples in the unlabeled set using a model trained on the combination of the labeled samples and any previously pseudo-labeled samples, and iteratively repeating this process in a self-training cycle. Current methods seem to have abandoned this approach in favor of consistency regularization methods that train models under a combination of different styles of self-supervised losses on the unlabeled samples and standard supervised losses on the labeled samples. We empirically demonstrate that pseudo-labeling can in fact be competitive with the state of the art, while being more resilient to out-of-distribution samples in the unlabeled set. We identify two key factors that allow pseudo-labeling to achieve these results: (1) applying curriculum learning principles and (2) avoiding concept drift by restarting model parameters before each self-training cycle. We obtain 94.91% accuracy on CIFAR-10 using only 4,000 labeled samples, and 68.87% top-1 accuracy on ImageNet-ILSVRC using only 10% of the labeled samples.


Figure: Curriculum Labeling (CL) algorithm.
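The self-training cycle described above can be sketched in a few lines of plain Python. This is only an illustrative sketch, not the code in this repository: `train_fn` and `predict_fn` are hypothetical callbacks, the percentile helper is a generic linear-interpolation percentile, and the schedule (start at the 80th percentile and step down by 20) is an assumption based on the default value of `--percentiles_holder`.

```python
import math
from typing import Any, Callable, List, Sequence, Tuple


def percentile(values: Sequence[float], q: float) -> float:
    """Linear-interpolation percentile of `values`, with q in [0, 100]."""
    if not values:
        raise ValueError("values must not be empty")
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(rank), math.ceil(rank)
    frac = rank - lo
    return ordered[lo] * (1.0 - frac) + ordered[hi] * frac


def select_confident(confidences: Sequence[float], q: float) -> List[int]:
    """Indices of samples whose confidence reaches the q-th percentile."""
    threshold = percentile(confidences, q)
    return [i for i, c in enumerate(confidences) if c >= threshold]


def curriculum_labeling(
    labeled: List[Tuple[Any, int]],
    unlabeled: List[Any],
    train_fn: Callable[[List[Tuple[Any, int]]], Any],
    predict_fn: Callable[[Any, Any], Tuple[int, float]],
    step: int = 20,
) -> Any:
    """Hypothetical self-training loop with a percentile curriculum."""
    # Round 0: train on the labeled samples only.
    model = train_fn(labeled)
    # The threshold percentile is lowered each round: 80, 60, 40, 20, 0.
    for q in range(100 - step, -1, -step):
        # Pseudo-label the whole unlabeled set with the current model.
        preds = [predict_fn(model, x) for x in unlabeled]
        keep = select_confident([conf for _, conf in preds], q)
        pseudo = [(unlabeled[i], preds[i][0]) for i in keep]
        # train_fn builds a fresh model each time, i.e. the parameters
        # are restarted before every self-training cycle.
        model = train_fn(labeled + pseudo)
    return model
```

In this sketch the pseudo-labeled set is re-selected from scratch in every round, so early mistakes are not carried forward, and the final round (percentile 0) uses every unlabeled sample.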


Requirements

  • python >= 3.7.7
  • pytorch > 1.5.0
  • torchvision
  • tensorflow-gpu==1.14
  • torchcontrib
  • pytest
  • Download both zca_components.npy and zca_mean.npy. Save them in the main folder (Curriculum-Labeling).

Train

TL;DR

Run the command below to reproduce one of our experiments on CIFAR-10 with WideResNet-28-2:

python main.py --doParallel --seed 821 --nesterov --weight-decay 0.0005 --arch WRN28_2 --batch_size 512 --epochs 700 --lr_rampdown_epochs 750 --add_name WRN28_CIFAR10_AUG_MIX_SWA --mixup --swa
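The command sets `--lr_rampdown_epochs` (750) slightly above `--epochs` (700), so the learning rate gets close to zero by the end of training without reaching it. The snippet below sketches one common form of cosine rampdown. It is an assumption about the schedule, not code taken from main.py; the function names are hypothetical and the default `max_lr=0.1` simply mirrors the documented default learning rate.

```python
import math


def cosine_rampdown(epoch: float, rampdown_epochs: float) -> float:
    """Multiplier that decays from 1.0 at epoch 0 to 0.0 at rampdown_epochs."""
    if not 0 <= epoch <= rampdown_epochs:
        raise ValueError("epoch must lie in [0, rampdown_epochs]")
    return 0.5 * (math.cos(math.pi * epoch / rampdown_epochs) + 1.0)


def learning_rate(epoch: float, max_lr: float = 0.1,
                  rampdown_epochs: float = 750) -> float:
    """Learning rate at `epoch` under the assumed cosine schedule."""
    return max_lr * cosine_rampdown(epoch, rampdown_epochs)


if __name__ == "__main__":
    # Print the assumed schedule at a few points of a 700-epoch run.
    for epoch in (0, 350, 700):
        print(epoch, learning_rate(epoch))
```

Because the rampdown length must be at least the training length, the last training epoch always falls inside the valid range of the schedule.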

Everything you need to run and evaluate Curriculum Labeling is in main.py. The Wrapper class contains all the main functions to create the model, prepare the dataset, and train the model. The arguments you pass are handled by the Wrapper. For example, to activate debug mode and sneak a peek at the test-set scores, add the --debug argument when running python main.py.

The code below shows how to set every step and get ready to train:

import wrapper as super_glue

# all possible parameters are passed to the wrapper as a dictionary
# (args_dict holds the arguments you pass on the command line)
wrapper = super_glue.Wrapper(args_dict)
# one line to prepare datasets
wrapper.prepare_datasets()
# create the model
wrapper.create_network()
# set the hyperparameters
wrapper.set_model_hyperparameters()
# set optimizer (SGD or Adam)
wrapper.set_model_optimizer()
# voilà! really? sure, print the model!
print(wrapper.model)

Then you just have to call the train and evaluate functions:

# train cl
wrapper.train_cl()
# evaluate cl 
wrapper.eval_cl()
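The `args_dict` handed to the Wrapper is a plain dictionary of arguments. One way such a dictionary could be produced is with `argparse` and `vars()`. The sketch below is an assumption for illustration only: it declares a small, hypothetical subset of the flags listed in this README, and the real parser in main.py may be defined differently.

```python
import argparse
from typing import Any, Dict, List, Optional


def build_args_dict(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Parse a hypothetical subset of the README flags into a dictionary."""
    parser = argparse.ArgumentParser(description="Curriculum Labeling (sketch)")
    parser.add_argument("--dataset", default="cifar10")
    parser.add_argument("--arch", default="cnn13")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--mixup", action="store_true")
    parser.add_argument("--swa", action="store_true")
    parser.add_argument("--debug", action="store_true")
    # vars() turns the argparse Namespace into a regular dict.
    return vars(parser.parse_args(argv))


# Example: a few of the flags from the TL;DR command.
args_dict = build_args_dict(
    ["--arch", "WRN28_2", "--epochs", "700", "--mixup", "--swa"]
)
print(args_dict)
```

Passing an explicit list to `build_args_dict` keeps the example independent of the real command line; calling it with no argument would read `sys.argv` instead.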

Some Arguments and Usage

usage: main.py [-h] [--dataset DATASET] [--num_labeled L]
               [--num_valid_samples V] [--arch ARCH] [--dropout DO]
               [--optimizer OPTIMIZER] [--epochs N] [--start_epoch N] [-b N]
               [--lr LR] [--initial_lr LR] [--lr_rampup EPOCHS]
               [--lr_rampdown_epochs EPOCHS] [--momentum M] [--nesterov]
               [--weight-decay W] [--checkpoint_epochs EPOCHS]
               [--print_freq N] [--pretrained] [--root_dir ROOT_DIR]
               [--data_dir DATA_DIR] [--n_cpus N_CPUS] [--add_name ADD_NAME]
               [--doParallel] [--use_zca] [--pretrainedEval]
               [--pretrainedFrom PATH] [-e] [-evaluateLabeled]
               [-getLabeledResults]
               [--set_labeled_classes SET_LABELED_CLASSES]
               [--set_unlabeled_classes SET_UNLABELED_CLASSES]
               [--percentiles_holder PERCENTILES_HOLDER] [--static_threshold]
               [--seed SEED] [--augPolicy AUGPOLICY] [--swa]
               [--swa_start SWA_START] [--swa_freq SWA_FREQ] [--mixup]
               [--alpha ALPHA] [--debug]

Detailed list of Arguments

| Argument | Default | Description |
|---|---|---|
| --help | | show this help message and exit |
| --dataset | cifar10 | dataset: cifar10, svhn or imagenet |
| --num_labeled | 400 | number of labeled samples per class |
| --num_valid_samples | 500 | number of validation samples per class |
| --arch | cnn13 | one of cnn13, WRN28_2, resnet50 |
| --dropout | 0.0 | dropout rate |
| --optimizer | sgd | optimizer to use: either adam or sgd |
| --epochs | 100 | number of total epochs to run |
| --start_epoch | 0 | manual epoch number (useful on restarts) |
| --batch_size | 100 | mini-batch size (default: 100) |
| --learning-rate | 0.1 | max learning rate |
| --initial_lr | 0.0 | initial learning rate when using linear rampup |
| --lr_rampup | 0 | length of learning rate rampup in the beginning |
| --lr_rampdown_epochs | 150 | length of learning rate cosine rampdown (>= length of training): the epoch at which the learning rate reaches zero |
| --momentum | 0.9 | momentum |
| --nesterov | | use nesterov momentum |
| --wd | 0.0001 | weight decay (default: 1e-4) |
| --checkpoint_epochs | 500 | checkpoint frequency (by epoch) |
| --print_freq | 100 | print frequency (default: 10) |
| --pretrained | | use pre-trained model |
| --root_dir | experiments | folder where results are to be stored |
| --data_dir | /data/cifar10/ | folder where data is stored |
| --n_cpus | 12 | number of cpus for data loading |
| --add_name | SSL_Test | name of the folder in which to store the experiment results |
| --doParallel | | use DataParallel |
| --use_zca | | use zca whitening |
| --pretrainedEval | | use pre-trained model |
| --pretrainedFrom | /full/path/ | path to pretrained results (default: none) |
| --set_labeled_classes | 0,1,2,3,4,5,6,7,8,9 | set the classes to treat as the labeled set |
| --set_unlabeled_classes | 0,1,2,3,4,5,6,7,8,9 | set the classes to treat as the unlabeled set |
| --percentiles_holder | 20 | mu parameter: sets the stepping percentile for thresholding after each iteration |
| --static_threshold | | use static threshold |
| --seed | 0 | define seed for random distribution of dataset |
| --augPolicy | 2 | augmentation policy: 0 for none, 1 for moderate, 2 for heavy (random-augment) |
| --swa | | apply SWA |
| --swa_start | 200 | start SWA |
| --swa_freq | 5 | frequency |
| --mixup | | apply Mixup to inputs |
| --alpha | 1.0 | mixup interpolation coefficient (default: 1) |
| --debug | | track the testing accuracy, only for debugging purposes |
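The `--mixup` and `--alpha` options refer to mixup, where each input is blended with another input from the same batch using a coefficient drawn from a Beta(alpha, alpha) distribution. The sketch below shows the usual formulation on plain Python lists. It is a generic illustration under that assumption, not the implementation used in this repository, and `mixup_batch` and `mixup_loss` are hypothetical names.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def mixup_batch(
    inputs: Sequence[Sequence[float]],
    targets: Sequence[int],
    alpha: float = 1.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[List[float]], List[int], List[int], float]:
    """Blend each input with a randomly paired input from the same batch."""
    rng = rng or random.Random()
    # lam is the interpolation coefficient; alpha <= 0 disables mixing.
    lam = rng.betavariate(alpha, alpha) if alpha > 0 else 1.0
    perm = list(range(len(inputs)))
    rng.shuffle(perm)
    mixed = [
        [lam * a + (1.0 - lam) * b for a, b in zip(inputs[i], inputs[j])]
        for i, j in enumerate(perm)
    ]
    targets_b = [targets[j] for j in perm]
    return mixed, list(targets), targets_b, lam


def mixup_loss(
    loss_fn: Callable[[float, float], float],
    prediction: float,
    target_a: float,
    target_b: float,
    lam: float,
) -> float:
    """Combine the losses against both targets with the same coefficient."""
    return lam * loss_fn(prediction, target_a) + (1.0 - lam) * loss_fn(
        prediction, target_b
    )
```

With `alpha=1.0` (the documented default) the coefficient is drawn uniformly from [0, 1]; smaller values of alpha push it toward 0 or 1, so the mixed samples stay closer to the originals.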

Bibtex

If you use Curriculum Labeling for your research or projects, please cite Curriculum Labeling: Revisiting Pseudo-Labeling for Semi-Supervised Learning.

@misc{cascantebonilla2020curriculum,
    title={Curriculum Labeling: Revisiting Pseudo-Labeling for Semi-Supervised Learning},
    author={Paola Cascante-Bonilla and Fuwen Tan and Yanjun Qi and Vicente Ordonez},
    year={2020},
    eprint={2001.06001},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}