PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and return the loss. If labels is None or not passed to the it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...

Comparison

Results on CIFAR-10:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 95.0
SupContrast ResNet50 Supervised Contrastive 96.0
SimCLR ResNet50 Unsupervised Contrastive 93.6

Results on CIFAR-100:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 75.3
SupContrast ResNet50 Supervised Contrastive 76.5
SimCLR ResNet50 Unsupervised Contrastive 70.7

Results on ImageNet (Stay tuned):

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy -
SupContrast ResNet50 Supervised Contrastive 79.1 (MoCo trick)
SimCLR ResNet50 Unsupervised Contrastive -

Running

You might use CUDA_VISIBLE_DEVICES to set proper number of GPUs, and/or switch to CIFAR100 by --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN \

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN but I found it not crucial for SupContrast (syncBN 95.9% v.s. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to SupConLoss criterion. Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth

On custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5  \ 
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR

The --data_folder must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Yonglong Tian
PECOS - Prediction for Enormous and Correlated Spaces

PECOS - Predictions for Enormous and Correlated Output Spaces PECOS is a versatile and modular machine learning (ML) framework for fast learning and i

Amazon 387 Jan 04, 2023
[NeurIPS 2021] Introspective Distillation for Robust Question Answering

Introspective Distillation (IntroD) This repository is the Pytorch implementation of our paper "Introspective Distillation for Robust Question Answeri

Yulei Niu 13 Jul 26, 2022
Code for "Human Pose Regression with Residual Log-likelihood Estimation", ICCV 2021 Oral

Human Pose Regression with Residual Log-likelihood Estimation [Paper] [arXiv] [Project Page] Human Pose Regression with Residual Log-likelihood Estima

JeffLi 347 Dec 24, 2022
The official pytorch implementation of our paper "Is Space-Time Attention All You Need for Video Understanding?"

TimeSformer This is an official pytorch implementation of Is Space-Time Attention All You Need for Video Understanding?. In this repository, we provid

Facebook Research 1k Dec 31, 2022
A ssl analyzer which could analyzer target domain's certificate.

ssl_analyzer A ssl analyzer which could analyzer target domain's certificate. Analyze the domain name ssl certificate information according to the inp

vincent 17 Dec 12, 2022
MTA:SA Server Configer.

MTAConfiger MTA:SA Server Configer. Hi 👋 , I'm Alireza A Python Developer Boy 🔭 I’m currently working on my C# projects 🌱 I’m currently Learning CS

3 Jun 07, 2022
Rendering Point Clouds with Compute Shaders

Compute Shader Based Point Cloud Rendering This repository contains the source code to our techreport: Rendering Point Clouds with Compute Shaders and

Markus Schütz 460 Jan 05, 2023
repro_eval is a collection of measures to evaluate the reproducibility/replicability of system-oriented IR experiments

repro_eval repro_eval is a collection of measures to evaluate the reproducibility/replicability of system-oriented IR experiments. The measures were d

IR Group at Technische Hochschule Köln 9 May 25, 2022
RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation (CIKM'17)

RATE: Overcoming Noise and Sparsity of Textual Features in Real-Time Location Estimation This is the implementation of RATE: Overcoming Noise and Spar

Yu Zhang 5 Feb 10, 2022
Invert and perturb GAN images for test-time ensembling

GAN Ensembling Project Page | Paper | Bibtex Ensembling with Deep Generative Views. Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhan

Lucy Chai 93 Dec 08, 2022
A proof of concept ai-powered Recaptcha v2 solver

Recaptcha Fullauto I've decided to open source my old Recaptcha v2 solver. My latest version will be opened sourced this summer. I am hoping this proj

Nate 60 Dec 20, 2022
Caffe models in TensorFlow

Caffe to TensorFlow Convert Caffe models to TensorFlow. Usage Run convert.py to convert an existing Caffe model to TensorFlow. Make sure you're using

Saumitro Dasgupta 2.8k Dec 31, 2022
Balancing Principle for Unsupervised Domain Adaptation

Blancing Principle for Domain Adaptation NeurIPS 2021 Paper Abstract We address the unsolved algorithm design problem of choosing a justified regulari

Marius-Constantin Dinu 4 Dec 15, 2022
PFFDTD is an open-source FDTD simulator for 3D room acoustics

PFFDTD is an open-source FDTD simulator for 3D room acoustics

Brian Hamilton 34 Nov 24, 2022
Differentiable architecture search for convolutional and recurrent networks

Differentiable Architecture Search Code accompanying the paper DARTS: Differentiable Architecture Search Hanxiao Liu, Karen Simonyan, Yiming Yang. arX

Hanxiao Liu 3.7k Jan 09, 2023
Convert Mission Planner (ArduCopter) Waypoint Missions to Litchi CSV Format to execute on DJI Drones

Mission Planner to Litchi Convert Mission Planner (ArduCopter) Waypoint Surveys to Litchi CSV Format to execute on DJI Drones Litchi doesn't support S

Yaros 24 Dec 09, 2022
Centroid-UNet is deep neural network model to detect centroids from satellite images.

Centroid UNet - Locating Object Centroids in Aerial/Serial Images Introduction Centroid-UNet is deep neural network model to detect centroids from Aer

GIC-AIT 19 Dec 08, 2022
Time-Optimal Planning for Quadrotor Waypoint Flight

Time-Optimal Planning for Quadrotor Waypoint Flight This is an example implementation of the paper "Time-Optimal Planning for Quadrotor Waypoint Fligh

Robotics and Perception Group 38 Dec 02, 2022
PyTorch implementation of an end-to-end Handwritten Text Recognition (HTR) system based on attention encoder-decoder networks

AttentionHTR PyTorch implementation of an end-to-end Handwritten Text Recognition (HTR) system based on attention encoder-decoder networks. Scene Text

Dmitrijs Kass 31 Dec 22, 2022