Code for the paper "Not All Unlabeled Data are Equal: Learning to Weight Data in Semi-supervised Learning" by Zhongzheng Ren*, Raymond A. Yeh*, Alexander G. Schwing.

Not All Unlabeled Data are Equal:
Learning to Weight Data in Semi-supervised Learning

Overview

This is the code for the paper: Not All Unlabeled Data are Equal: Learning to Weight Data in Semi-supervised Learning. Zhongzheng Ren*, Raymond A. Yeh*, Alexander G. Schwing. NeurIPS'20. (*equal contribution)

Setup

Important: ML_DATA is a shell environment variable that should point to the location where the datasets are installed. See the Install datasets section for more details.
Environment: this code is tested with python-3.7, anaconda3-5.0.1, cuda-10.0, cudnn-v7.6, and tensorflow-1.15.

Install dependencies

conda create -n semi-sup python=3.7
conda activate semi-sup
pip install -r requirements.txt

After installation, make sure tf.test.is_gpu_available() returns True. Otherwise the GPUs will not be used.

Install datasets

export ML_DATA="path to where you want the datasets saved"
export PYTHONPATH=$PYTHONPATH:"path to this repo"

# Download datasets
CUDA_VISIBLE_DEVICES= ./scripts/create_datasets.py
cp $ML_DATA/svhn-test.tfrecord $ML_DATA/svhn_noextra-test.tfrecord

# Create unlabeled datasets
CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord
CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord
CUDA_VISIBLE_DEVICES= scripts/create_unlabeled.py $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord

# Create semi-supervised subsets
for seed in 0 1 2 3 4 5; do
    for size in 250 1000 4000; do
        CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/cifar10 $ML_DATA/cifar10-train.tfrecord
        CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn $ML_DATA/svhn-train.tfrecord $ML_DATA/svhn-extra.tfrecord
        CUDA_VISIBLE_DEVICES= scripts/create_split.py --seed=$seed --size=$size $ML_DATA/SSL2/svhn_noextra $ML_DATA/svhn-train.tfrecord
    done
done
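Before training, you may want to confirm that the files written by the commands above are where the code expects them. The snippet below is a minimal Python sketch and is not part of the repository. The helper name missing_tfrecords is hypothetical. It checks only the tfrecord file names that appear in the commands above. The SSL2 split files are not checked, because their exact names are not listed here.

```python
import os
from pathlib import Path

# File names taken from the commands above; illustrative, not an exhaustive list.
EXPECTED = [
    "cifar10-train.tfrecord",
    "svhn-train.tfrecord",
    "svhn-extra.tfrecord",
    "svhn-test.tfrecord",
    "svhn_noextra-test.tfrecord",
]


def missing_tfrecords(ml_data=None, names=EXPECTED):
    """Return the expected files that are absent under ML_DATA (hypothetical helper).

    If `ml_data` is None, the ML_DATA environment variable is used instead.
    """
    root = ml_data if ml_data is not None else os.environ.get("ML_DATA")
    if not root:
        raise RuntimeError("ML_DATA is not set; export it before running the scripts.")
    # Keep the expected order so the report is easy to compare with the commands.
    return [name for name in names if not (Path(root) / name).is_file()]
```

After exporting ML_DATA, calling missing_tfrecords() returns an empty list when every listed file is present. Otherwise it returns the names that still need to be generated.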

Running

Setup

All commands must be run from the project root. The following environment variables must be defined:

export ML_DATA="path to where you want the datasets saved"
export PYTHONPATH=$PYTHONPATH:"path to this repo"

Example

For example, to train a model with 32 filters on cifar10 shuffled with seed=1, 250 labeled samples and 1000 validation samples:

# single-gpu
CUDA_VISIBLE_DEVICES=0 python main.py --filters=32 --dataset=cifar10.1@250-1000 --train_dir ./experiments

# multi-gpu: pass more GPUs and the model automatically scales to them; here GPUs 0 and 1 are assigned to the program:
CUDA_VISIBLE_DEVICES=0,1 python main.py --filters=32 --dataset=cifar10.1@250-1000 --train_dir ./experiments

Naming rule: ${dataset}.${seed}@${size}-${valid}
Available labeled sizes are 250, 1000, 4000.
For validation, available sizes are 1000, 5000.
Possible shuffling seeds are 1, 2, 3, 4, 5, and 0 for no shuffling. Seed 0 is not used in practice, since the data must be shuffled for gradient descent to work properly.
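The naming rule can be captured in a small helper. The following is a hypothetical sketch, not code from the repository. make_dataset_name and parse_dataset_name are illustrative names, and the allowed values are copied from the lines above.

```python
import re

# Allowed values as listed in the naming rule above.
LABELED_SIZES = {250, 1000, 4000}
VALID_SIZES = {1000, 5000}
SEEDS = {0, 1, 2, 3, 4, 5}

# ${dataset}.${seed}@${size}-${valid}
_PATTERN = re.compile(
    r"^(?P<dataset>[A-Za-z0-9_]+)\.(?P<seed>\d+)@(?P<size>\d+)-(?P<valid>\d+)$"
)


def make_dataset_name(dataset, seed, size, valid):
    """Build a --dataset value, rejecting combinations not listed above."""
    if seed not in SEEDS or size not in LABELED_SIZES or valid not in VALID_SIZES:
        raise ValueError(f"unsupported combination: seed={seed}, size={size}, valid={valid}")
    return f"{dataset}.{seed}@{size}-{valid}"


def parse_dataset_name(name):
    """Split a --dataset value back into (dataset, seed, size, valid)."""
    match = _PATTERN.match(name)
    if match is None:
        raise ValueError(f"does not follow the naming rule: {name!r}")
    return (match["dataset"], int(match["seed"]), int(match["size"]), int(match["valid"]))
```

For example, make_dataset_name("cifar10", 1, 250, 1000) returns cifar10.1@250-1000, which matches the single-run example described above: cifar10, seed 1, 250 labeled samples, and 1000 validation samples.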

Image classification

The hyper-parameters used in the paper:

# 2GPU setting is recommended
for seed in 1 2 3 4 5; do
    for size in 250 1000 4000; do
        CUDA_VISIBLE_DEVICES=0,1 python main.py --filters=32 \
            --dataset=cifar10.${seed}@${size}-1000 \
            --train_dir ./experiments --alpha 0.01 --inner_steps 512
    done
done

Flags

python main.py --help
# The following option might be too slow to be really practical.
# python main.py --helpfull
# So instead I use this hack to find the flags:
fgrep -R flags.DEFINE libml main.py

Monitoring training progress

You can point tensorboard to the training folder (by default it is --train_dir=./experiments) to monitor the training process:

tensorboard.sh --port 6007 --logdir ./experiments

Checkpoint accuracy

In the paper, we report the median accuracy over the last 20 checkpoints, computed with the following script:

# Following the previous example, in which we trained cifar10.1@250-1000, extract the accuracy:
./scripts/extract_accuracy.py ./experiments/cifar10.1@250-1000/CTAugment_depth2_th0.80_decay0.990/FixMatch_alpha0.01_archresnet_batch64_confidence0.95_filters32_inf_warm0_inner_steps100_lr0.03_nclass10_repeat4_scales3_size_unlabeled49000_uratio7_wd0.0005_wu1.0
# The command above will create a stats/accuracy.json file in the model folder.
# The format is JSON so you can either see its content as a text file or process it to your liking.
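The reported number is the median over the last 20 checkpoints, and that reduction is simple to express. The sketch below assumes you already have per-checkpoint test accuracies in chronological order. For example, these could be read from stats/accuracy.json, whose exact layout is not described here. median_of_last is a hypothetical helper and is not the repository's extract_accuracy.py.

```python
from statistics import median


def median_of_last(accuracies, k=20):
    """Median of the last k checkpoint accuracies (all of them if fewer than k).

    `accuracies` must be in chronological order, oldest checkpoint first.
    """
    if not accuracies:
        raise ValueError("no checkpoint accuracies given")
    # Slicing with -k keeps the whole list when it has fewer than k entries.
    return median(list(accuracies)[-k:])
```

Taking the median over several late checkpoints, rather than reporting only the final one, makes the number less sensitive to checkpoint-to-checkpoint fluctuations.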

Use your own data

  1. You first need to create *.tfrecord files for the labeled and unlabeled data; see scripts/create_datasets.py and scripts/create_unlabeled.py for examples.
  2. Then create the splits for semi-supervised learning; see scripts/create_split.py.
  3. Modify libml/data.py to support the new dataset. Specifically, check this function and this class.
  4. Tune hyper-parameters (e.g., learning rate, num_epochs, etc.) to achieve the best results.
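For step 2, the core of a labeled split is choosing a reproducible, seeded subset of training indices. The sketch below draws an equal number of examples per class. We assume this mirrors what scripts/create_split.py does, but check that script for the actual behavior. balanced_split is a hypothetical name, and it works on a plain list of labels rather than on tfrecords.

```python
import random
from collections import defaultdict


def balanced_split(labels, size, seed):
    """Pick `size` labeled indices with an equal count per class, reproducible per seed.

    `labels` holds one class label per training example; returns sorted indices.
    """
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    num_classes = len(by_class)
    if size % num_classes:
        raise ValueError("size must be divisible by the number of classes")
    per_class = size // num_classes
    rng = random.Random(seed)  # same seed -> same split
    picked = []
    for label in sorted(by_class):  # fixed class order keeps results deterministic
        indices = by_class[label]
        if len(indices) < per_class:
            raise ValueError(f"class {label} has fewer than {per_class} examples")
        picked.extend(rng.sample(indices, per_class))
    return sorted(picked)
```

With 10 classes, balanced_split(labels, 250, seed=1) selects 25 examples per class. Rerunning it with the same seed returns the same indices.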

Note: our algorithm involves approximating the inverse Hessian and computing per-example gradients. Therefore, running on a dataset with a large number of classes will be computationally heavy in terms of both speed and memory.
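To make the cost remark more concrete: for a neural network, the inverse Hessian is not formed explicitly. Its approximation is usually done through inverse-Hessian-vector products H^{-1}v. One standard way to approximate these is a truncated Neumann series, H^{-1}v ≈ α Σ_{j=0}^{J} (I - αH)^j v. The series converges when every eigenvalue of H lies in (0, 2/α). It is shown here only as an illustration and is not necessarily the approximation this repository uses.

The sketch uses a small explicit matrix in pure Python. hvp stands in for the Hessian-vector product that a real implementation would obtain from automatic differentiation. neumann_inverse_hvp, alpha, and steps are hypothetical names and are not meant to correspond to the repository's --alpha or --inner_steps flags.

```python
def matvec(matrix, vector):
    """Dense matrix-vector product on nested lists (stand-in for an autodiff HVP)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def neumann_inverse_hvp(hvp, v, alpha, steps):
    """Approximate H^{-1} v by alpha * sum_{j=0}^{steps} (I - alpha H)^j v.

    `hvp(x)` must return H @ x. Convergence requires all eigenvalues of H
    to lie in (0, 2 / alpha); each step costs one Hessian-vector product.
    """
    term = list(v)   # current term (I - alpha H)^j v, starting at j = 0
    total = list(v)  # running sum of the terms
    for _ in range(steps):
        h_term = hvp(term)
        term = [t - alpha * h for t, h in zip(term, h_term)]
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]
```

For example, take H = [[2, 0], [0, 4]], v = [1, 1], alpha = 0.1, and 200 steps. The result is very close to the exact H^{-1}v = [0.5, 0.25]. Raising the number of steps improves accuracy at the price of one extra Hessian-vector product per step.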

License

Please check LICENSE

Citing this work

If you use this code for your research, please cite our paper.

@inproceedings{ren-ssl2020,
  title = {Not All Unlabeled Data are Equal: Learning to Weight Data in Semi-supervised Learning},
  author = {Zhongzheng Ren$^\ast$ and Raymond A. Yeh$^\ast$ and Alexander G. Schwing},
  booktitle = {Neural Information Processing Systems (NeurIPS)},
  year = {2020},
  note = {$^\ast$ equal contribution},
}

Acknowledgement

This code is built on FixMatch (commit: 08d9b83).

FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. Kihyuk Sohn, David Berthelot, Chun-Liang Li, Zizhao Zhang, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Han Zhang, and Colin Raffel.

Contact

GitHub issues and pull requests are preferred. Feel free to contact Jason Ren (zr5 AT illinois.edu) with any questions!
