[NeurIPS 2020] Semi-Supervision (Unlabeled Data) & Self-Supervision Improve Class-Imbalanced / Long-Tailed Learning

Rethinking the Value of Labels for Improving Class-Imbalanced Learning

This repository contains the implementation code for the paper:
Rethinking the Value of Labels for Improving Class-Imbalanced Learning
Yuzhe Yang and Zhi Xu
34th Conference on Neural Information Processing Systems (NeurIPS), 2020
[Website] [arXiv] [Paper] [Slides] [Video]

If you find this code or idea useful, please consider citing our work:

@inproceedings{yang2020rethinking,
  title={Rethinking the Value of Labels for Improving Class-Imbalanced Learning},
  author={Yang, Yuzhe and Xu, Zhi},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Overview

In this work, we show theoretically and empirically that both semi-supervised learning (using unlabeled data) and self-supervised pre-training (first pre-training the model with self-supervision) can substantially improve performance on imbalanced (long-tailed) datasets, regardless of the degree of imbalance in the labeled/unlabeled data and of the base training technique.

Semi-Supervised Imbalanced Learning: Using unlabeled data helps to shape clearer class boundaries and results in better class separation, especially for the tail classes.

Self-Supervised Imbalanced Learning: Self-supervised pre-training (SSP) helps mitigate tail-class leakage during testing, which results in better learned boundaries and representations.

Installation

Prerequisites

Dependencies

  • PyTorch (>= 1.2, tested on 1.4)
  • yaml
  • scikit-learn
  • TensorboardX

Code Overview

Main Files

Main Arguments

  • --dataset: name of chosen long-tailed dataset
  • --imb_factor: imbalance factor (inverse value of imbalance ratio \rho in the paper)
  • --imb_factor_unlabel: imbalance factor for unlabeled data (inverse value of unlabel imbalance ratio \rho_U)
  • --pretrained_model: path to self-supervised pre-trained models
  • --resume: path to resume checkpoint (also for evaluation)
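The relationship between the imbalance ratio \rho and the --imb_factor flag can be sketched as follows. This is a minimal illustration, assuming the exponentially decaying class-size profile commonly used for long-tailed CIFAR benchmarks; the function names and the per-class maximum of 5000 images (CIFAR-10's training size per class) are illustrative and not taken from this repository.

```python
def imb_factor_from_rho(rho):
    """--imb_factor is the inverse of the imbalance ratio rho."""
    return 1.0 / rho


def long_tailed_class_counts(n_max, num_classes, imb_factor):
    """Per-class sample counts under an assumed exponential profile.

    The first class keeps n_max samples and the last class keeps
    roughly n_max * imb_factor samples.
    """
    counts = []
    for cls_idx in range(num_classes):
        frac = cls_idx / (num_classes - 1.0)
        counts.append(int(n_max * imb_factor ** frac))
    return counts


if __name__ == "__main__":
    # rho = 100 corresponds to --imb_factor 0.01
    factor = imb_factor_from_rho(100)
    print(factor, long_tailed_class_counts(5000, 10, factor))
```

The same conversion applies to --imb_factor_unlabel and \rho_U.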

Getting Started

Semi-Supervised Imbalanced Learning

Unlabeled data sourcing

CIFAR-10-LT: CIFAR-10 unlabeled data is prepared following this repo using the 80M TinyImages. In short, a data sourcing model is trained to distinguish CIFAR-10 classes and a "non-CIFAR" class. For each class, images are then ranked based on the prediction confidence, and unlabeled (imbalanced) datasets are constructed accordingly. Use the following link to download the prepared unlabeled data, and place it in your data_path:

SVHN-LT: Since the SVHN dataset itself contains an extra part with 531.1K additional (labeled) samples, these are used directly to simulate the unlabeled dataset.

Note that the class imbalance in unlabeled data is also considered, which is controlled by --imb_factor_unlabel (\rho_U in the paper). See imbalance_cifar.py and imbalance_svhn.py for details.
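The confidence-based ranking described above can be sketched roughly as below. This is only an illustration under stated assumptions: the function name, the input layout, and the per-class budgets are hypothetical, and the actual selection logic lives in imbalance_cifar.py and imbalance_svhn.py.

```python
def select_unlabeled(confidences, predicted, per_class_budget):
    """Pick the most confident images for each class.

    confidences:      sourcing-model confidence for each candidate image
    predicted:        predicted class index for each candidate image
    per_class_budget: number of images to keep for each class
                      (an imbalanced budget yields an imbalanced unlabeled set)
    Returns the sorted indices of the selected images.
    """
    selected = []
    for cls, budget in enumerate(per_class_budget):
        candidates = [i for i, p in enumerate(predicted) if p == cls]
        # Rank this class's candidates from most to least confident.
        candidates.sort(key=lambda i: confidences[i], reverse=True)
        selected.extend(candidates[:budget])
    return sorted(selected)


if __name__ == "__main__":
    conf = [0.9, 0.6, 0.8, 0.7, 0.95, 0.5]
    pred = [0, 0, 0, 1, 1, 1]
    print(select_unlabeled(conf, pred, per_class_budget=[2, 1]))
```

Setting the budgets according to \rho_U gives an unlabeled set whose class imbalance is controlled in the same way as the labeled set.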

Semi-supervised learning with pseudo-labeling

To perform pseudo-labeling (self-training), a base classifier is first trained on the original imbalanced dataset. With the trained base classifier, pseudo-labels can be generated using

python gen_pseudolabels.py --resume <ckpt-path> --data_dir <data_path> --output_dir <output_path> --output_filename <save_name>
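Conceptually, pseudo-labeling assigns each unlabeled sample the class that the base classifier scores highest. The sketch below shows only that step on plain Python lists; the function name and the input format are hypothetical, and it does not reproduce what gen_pseudolabels.py actually reads or writes.

```python
def generate_pseudo_labels(scores_batch):
    """Assign each unlabeled sample the index of its highest class score.

    scores_batch: one list of per-class scores (e.g. logits from the
                  base classifier) for each unlabeled sample
    """
    return [
        max(range(len(scores)), key=scores.__getitem__)
        for scores in scores_batch
    ]


if __name__ == "__main__":
    scores = [[0.1, 2.3, -0.5], [1.2, 0.3, 0.9]]
    print(generate_pseudo_labels(scores))
```

The resulting labels are then treated as targets for the unlabeled images during semi-supervised training.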

We provide generated pseudo label files for CIFAR-10-LT & SVHN-LT with \rho=50, using base models trained with standard cross-entropy (CE) loss:

To train with unlabeled data, for example, on CIFAR-10-LT with \rho=50 and \rho_U=50

python train_semi.py --dataset cifar10 --imb_factor 0.02 --imb_factor_unlabel 0.02

Self-Supervised Imbalanced Learning

Self-supervised pre-training (SSP)

To perform Rotation SSP on CIFAR-10-LT with \rho=100

python pretrain_rot.py --dataset cifar10 --imb_factor 0.01
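Rotation-based self-supervision trains the network to predict which of four rotations was applied to an image, so no class labels are needed. A minimal sketch of how such pretext samples can be built is shown below, using nested lists in place of image tensors; the helper names, the rotation direction, and the label encoding (0 to 3 for 0, 90, 180, and 270 degrees) are assumptions and may differ from pretrain_rot.py.

```python
def rot90(image):
    """Rotate a 2D image (list of rows) by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*image)][::-1]


def rotation_pretext_samples(image):
    """Return (rotated_image, label) pairs for the four rotations.

    label k means the image was rotated by k * 90 degrees.
    """
    samples = []
    current = [list(row) for row in image]
    for label in range(4):
        samples.append((current, label))
        current = rot90(current)
    return samples


if __name__ == "__main__":
    for rotated, label in rotation_pretext_samples([[1, 2], [3, 4]]):
        print(label, rotated)
```

The pre-trained backbone is then reused when training the classifier on the imbalanced labels.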

To perform MoCo SSP on ImageNet-LT

python pretrain_moco.py --dataset imagenet --data <data_path>
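MoCo is a contrastive method: a query embedding is pulled toward its positive key and pushed away from a set of negative keys. The sketch below shows only the contrastive (InfoNCE) loss for a single query in plain Python, under the assumption that embeddings are already normalized; the function name is hypothetical, the default temperature of 0.07 follows the MoCo paper, and the momentum encoder and key queue used by MoCo are omitted. It is not the code in pretrain_moco.py.

```python
import math


def info_nce_loss(query, positive_key, negative_keys, temperature=0.07):
    """Contrastive loss for one query against one positive and K negatives."""

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # The positive logit comes first, followed by the negative logits.
    logits = [dot(query, positive_key) / temperature]
    logits += [dot(query, key) / temperature for key in negative_keys]

    # Numerically stable log-sum-exp over all logits.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))

    # Cross-entropy with the positive key as the target class.
    return log_sum_exp - logits[0]


if __name__ == "__main__":
    q = [1.0, 0.0]
    print(info_nce_loss(q, [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]]))
```

The loss is small when the query matches its positive key and grows when a negative key is more similar to the query than the positive one.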

Network training with SSP models

Train on CIFAR-10-LT with \rho=100

python train.py --dataset cifar10 --imb_factor 0.01 --pretrained_model <path_to_ssp_model>

Train on ImageNet-LT / iNaturalist 2018

python -m imagenet_inat.main --cfg <path_to_ssp_config> --model_dir <path_to_ssp_model>

Results and Models

All related data and checkpoints can be found via this link. Individual results and checkpoints are detailed as follows.

Semi-Supervised Imbalanced Learning

CIFAR-10-LT

Model | Top-1 Error | Download
CE + D_U@5x (\rho=50 and \rho_U=1) | 16.79 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=25) | 16.88 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=50) | 18.36 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=100) | 19.94 | ResNet-32

SVHN-LT

Model | Top-1 Error | Download
CE + D_U@5x (\rho=50 and \rho_U=1) | 13.07 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=25) | 13.36 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=50) | 13.16 | ResNet-32
CE + D_U@5x (\rho=50 and \rho_U=100) | 14.54 | ResNet-32

Test a pretrained checkpoint

python train_semi.py --dataset cifar10 --resume <ckpt-path> -e

Self-Supervised Imbalanced Learning

CIFAR-10-LT

  • Self-supervised pre-trained models (Rotation)

    Dataset Setting | \rho=100 | \rho=50 | \rho=10
    Download | ResNet-32 | ResNet-32 | ResNet-32
  • Final models (200 epochs)

    Model | \rho | Top-1 Error | Download
    CE(Uniform) + SSP | 10 | 12.28 | ResNet-32
    CE(Uniform) + SSP | 50 | 21.80 | ResNet-32
    CE(Uniform) + SSP | 100 | 26.50 | ResNet-32
    CE(Balanced) + SSP | 10 | 11.57 | ResNet-32
    CE(Balanced) + SSP | 50 | 19.60 | ResNet-32
    CE(Balanced) + SSP | 100 | 23.47 | ResNet-32

CIFAR-100-LT

  • Self-supervised pre-trained models (Rotation)

    Dataset Setting | \rho=100 | \rho=50 | \rho=10
    Download | ResNet-32 | ResNet-32 | ResNet-32
  • Final models (200 epochs)

    Model | \rho | Top-1 Error | Download
    CE(Uniform) + SSP | 10 | 42.93 | ResNet-32
    CE(Uniform) + SSP | 50 | 54.96 | ResNet-32
    CE(Uniform) + SSP | 100 | 59.60 | ResNet-32
    CE(Balanced) + SSP | 10 | 41.94 | ResNet-32
    CE(Balanced) + SSP | 50 | 52.91 | ResNet-32
    CE(Balanced) + SSP | 100 | 56.94 | ResNet-32

ImageNet-LT

  • Self-supervised pre-trained models (MoCo)
    [ResNet-50]

  • Final models (90 epochs)

    Model | Top-1 Error | Download
    CE(Uniform) + SSP | 54.4 | ResNet-50
    CE(Balanced) + SSP | 52.4 | ResNet-50
    cRT + SSP | 48.7 | ResNet-50

iNaturalist 2018

  • Self-supervised pre-trained models (MoCo)
    [ResNet-50]

  • Final models (90 epochs)

    Model | Top-1 Error | Download
    CE(Uniform) + SSP | 35.6 | ResNet-50
    CE(Balanced) + SSP | 34.1 | ResNet-50
    cRT + SSP | 31.9 | ResNet-50

Test a pretrained checkpoint

# test on CIFAR-10 / CIFAR-100
python train.py --dataset cifar10 --resume <ckpt-path> -e

# test on ImageNet-LT / iNaturalist 2018
python -m imagenet_inat.main --cfg <path_to_ssp_config> --model_dir <path_to_model> --test

Acknowledgements

This code is partly based on the open-source implementations from the following sources: OpenLongTailRecognition, classifier-balancing, LDAM-DRW, MoCo, and semisup-adv.

Contact

If you have any questions, feel free to contact us by email or through GitHub issues. Enjoy!
