PAWS 🐾 Predicting View-Assignments with Support Samples

This repo provides a PyTorch implementation of PAWS (predicting view assignments with support samples), as described in the paper Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples.

[Figure: PAWS flowchart]

PAWS is a method for semi-supervised learning that builds on the principles of self-supervised distance-metric learning. PAWS pre-trains a model to minimize a consistency loss, which ensures that different views of the same unlabeled image are assigned similar pseudo-labels. The pseudo-labels are generated non-parametrically, by comparing the representations of the image views to those of a set of randomly sampled labeled images. The distance between the view representations and labeled representations is used to provide a weighting over class labels, which we interpret as a soft pseudo-label. By non-parametrically incorporating labeled samples in this way, PAWS extends the distance-metric loss used in self-supervised methods such as BYOL and SwAV to the semi-supervised setting.
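The non-parametric pseudo-labelling step described above can be sketched in a few lines. This is a minimal NumPy illustration, not the repo's PyTorch code. It assumes cosine similarity between L2-normalised embeddings, a softmax with a temperature tau to turn similarities into weights over the support samples, and a cross-entropy between the pseudo-labels of two views as the consistency loss. The helper names (snn_pseudo_labels, consistency_loss) and the value of tau are hypothetical, and any further target processing or regularisation the paper may use is omitted.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def l2_normalize(x, eps=1e-8):
    # Project each row onto the unit sphere so dot products become cosine similarities.
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)


def snn_pseudo_labels(view_emb, support_emb, support_labels, tau=0.1):
    """Soft pseudo-labels for unlabeled views (hypothetical helper).

    view_emb:       (B, D) representations of unlabeled image views
    support_emb:    (S, D) representations of the labeled support samples
    support_labels: (S, C) one-hot class labels of the support samples
    """
    sims = l2_normalize(view_emb) @ l2_normalize(support_emb).T  # (B, S) cosine similarities
    weights = softmax(sims / tau, axis=1)                         # weighting over support samples
    return weights @ support_labels                               # (B, C) weighting over class labels


def consistency_loss(p_anchor, p_target, eps=1e-8):
    # Cross-entropy H(p_target, p_anchor) averaged over the batch.
    # In a training framework, p_target would be held fixed (no gradient).
    return float(-(p_target * np.log(p_anchor + eps)).sum(axis=1).mean())
```

In use, both views of an unlabeled image are compared against the same sampled support set, and the pseudo-label of one view serves as the target for the other, so the loss only involves class weightings and never a parametric classifier.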

Also provided in this repo is a PyTorch implementation of the semi-supervised SimCLR+CT method described in the paper Supervision Accelerates Pretraining in Contrastive Semi-Supervised Learning of Visual Representations. SimCLR+CT combines the SimCLR self-supervised loss with the SuNCEt (supervised noise contrastive estimation) loss for semi-supervised learning.
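As a rough illustration of the supervised term, the sketch below computes a SuNCEt-style loss for a labeled batch: for each anchor, the summed similarity to other samples of the same class is contrasted against the summed similarity to all other samples. It is written in NumPy rather than PyTorch, assumes cosine similarity with a temperature tau, and the function name suncet_loss is hypothetical; how it is weighted against the SimCLR loss is not shown.

```python
import numpy as np


def suncet_loss(embeddings, labels, tau=0.1):
    """Sketch of a SuNCEt-style supervised contrastive loss (hypothetical helper).

    embeddings: (N, D) array, labels: (N,) integer class ids.
    Each anchor should have at least one other sample of its class in the batch.
    """
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    logits = (z @ z.T) / tau
    n = len(labels)
    not_self = ~np.eye(n, dtype=bool)
    same_class = (labels[:, None] == labels[None, :]) & not_self
    # Subtracting the row max improves stability; it cancels in the ratio below.
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits) * not_self        # drop each anchor's similarity to itself
    positives = (exp * same_class).sum(axis=1)
    denominator = exp.sum(axis=1)
    return float(-np.log(positives / denominator).mean())
```

Because the numerator is a subset of the denominator, the loss is non-negative and shrinks as same-class samples cluster together.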

Pretrained models

We provide checkpoints for the PAWS models both before and after fine-tuning. The pre-trained checkpoints are full checkpoints containing the backbone, projection-head, and prediction-head weights; the fine-tuned checkpoints include only the backbone and linear-classifier-head weights. Top-1 classification accuracy is reported using a nearest-neighbour classifier for the pre-trained models, and using the class labels predicted by the network's last linear layer for the fine-tuned models.

epochs | network | 1% labels: pretrained (NN) | 1% labels: finetuned | 10% labels: pretrained (NN) | 10% labels: finetuned
300    | RN50    | 65.4%                      | 66.5%                | 73.1%                       | 75.5%
200    | RN50    | 64.6%                      | 66.1%                | 71.9%                       | 75.0%
100    | RN50    | 62.6%                      | 63.8%                | 71.0%                       | 73.9%
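For the fine-tuned columns, top-1 accuracy counts a prediction as correct when the highest-scoring output of the last linear layer matches the ground-truth class. A minimal NumPy sketch of that metric (top1_accuracy is a hypothetical name, not a function from this repo):

```python
import numpy as np


def top1_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the target label.

    logits:  (N, C) outputs of the final linear classifier
    targets: (N,) integer ground-truth labels
    """
    predictions = np.argmax(logits, axis=1)
    return float((predictions == targets).mean())
```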

Running PAWS semi-supervised pre-training and fine-tuning

Config files

All experiment parameters are specified in config files (as opposed to command-line arguments). Config files make it easier to keep track of different experiments and to launch batches of jobs at once. See the configs/ directory for example config files.

Requirements

  • Python 3.8
  • PyTorch 1.7.1
  • torchvision
  • CUDA 11.0
  • Apex with CUDA extension
  • Other dependencies: PyYAML, numpy, OpenCV, submitit

Labeled Training Splits

For reproducibility, we have pre-specified the labeled training images as .txt files in the imagenet_subsets/ and cifar10_subsets/ directories. Based on the settings in your experiment's config file, our implementation automatically uses the images listed in one of these .txt files as the set of labeled images. On ImageNet, if you request a split of the data that is not contained in imagenet_subsets/ (for example, if you set unlabeled_frac to a value other than 0.9 or 0.99, i.e., neither the 10%-labeled nor the 1%-labeled setting), then at the start of training the code independently flips a coin for each training image and keeps the image's label with probability 1 - unlabeled_frac.
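The coin-flip rule for non-standard splits can be sketched as follows. This is an illustration of the rule as described, not the repo's implementation: sample_labeled_subset is a hypothetical helper, and the seeded random.Random generator is an assumption made so that the same split is drawn on every run.

```python
import random


def sample_labeled_subset(image_paths, unlabeled_frac, seed=0):
    """Keep each image's label independently with probability 1 - unlabeled_frac.

    Hypothetical helper mirroring the coin-flip rule above; a fixed seed
    makes the sampled split reproducible.
    """
    rng = random.Random(seed)
    keep_prob = 1.0 - unlabeled_frac
    # rng.random() is uniform on [0, 1), so each image is kept with probability keep_prob.
    return [path for path in image_paths if rng.random() < keep_prob]
```

With unlabeled_frac = 0.9, roughly 10% of the images keep their labels, but the exact count varies from seed to seed, which is one reason the fixed .txt splits are preferable for reproducing the reported numbers.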

Single-GPU training

PAWS is very simple to implement and experiment with. Our implementation starts from main.py, which parses the experiment config file and runs the desired script (e.g., PAWS pre-training or fine-tuning) locally on a single GPU.
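The --sel / --fname pattern used in the commands below can be sketched with argparse: --sel picks an entry point and --fname points at the config file it should read. The registry and the placeholder functions here are hypothetical stand-ins for the repo's scripts, and the actual main.py may be structured differently.

```python
import argparse


def paws_train(fname):
    # Placeholder for a PAWS pre-training entry point (hypothetical).
    return f"paws_train:{fname}"


def snn_fine_tune(fname):
    # Placeholder for SuNCEt fine-tuning (hypothetical).
    return f"snn_fine_tune:{fname}"


def fine_tune(fname):
    # Placeholder for linear fine-tuning (hypothetical).
    return f"fine_tune:{fname}"


# Hypothetical registry mapping --sel values to entry points.
SCRIPTS = {"paws_train": paws_train, "snn_fine_tune": snn_fine_tune, "fine_tune": fine_tune}


def main(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--sel", required=True, choices=sorted(SCRIPTS))
    parser.add_argument("--fname", required=True, help="path to the experiment config file")
    args = parser.parse_args(argv)
    # A real launcher would load the YAML config here before dispatching.
    return SCRIPTS[args.sel](args.fname)
```

Restricting --sel to the registry's keys makes argparse reject unknown script names before any work starts.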

CIFAR10 pre-training

For example, to pre-train with PAWS on CIFAR10 locally on a single GPU, using the pre-training experiment configs specified inside configs/paws/cifar10_train.yaml, run:

python main.py \
  --sel paws_train \
  --fname configs/paws/cifar10_train.yaml

CIFAR10 evaluation

To fine-tune the pre-trained model for a few optimization steps with the SuNCEt (supervised noise contrastive estimation) loss on a single GPU, using the experiment configs specified inside configs/paws/cifar10_snn.yaml, run:

python main.py \
  --sel snn_fine_tune \
  --fname configs/paws/cifar10_snn.yaml

Then, to evaluate the nearest-neighbour performance of the model locally on a single GPU, run:

python snn_eval.py \
  --model-name wide_resnet28w2 --use-pred \
  --pretrained $path_to_pretrained_model \
  --unlabeled_frac $1.-fraction_of_labeled_train_data_to_support_nearest_neighbour_classification \
  --root-path $path_to_root_datasets_directory \
  --image-folder $image_directory_inside_root_path \
  --dataset-name cifar10_fine_tune \
  --split-seed $which_prespecified_seed_to_split_labeled_data

Multi-GPU training

Running PAWS across multiple GPUs on a cluster is also very simple. In the multi-GPU setting, the implementation starts from main_distributed.py, which, in addition to parsing the config file and launching the desired script, also allows for specifying details about distributed training. For distributed training, we use the popular open-source submitit tool and provide examples for a SLURM cluster, but feel free to edit main_distributed.py for your purposes to specify a different approach to launching a multi-GPU job on a cluster.
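When a job is launched on SLURM with one task per GPU, each process has to work out its place in the job. SLURM exports this through the SLURM_PROCID (global rank), SLURM_NTASKS (total number of tasks) and SLURM_LOCALID (task index on the node) environment variables. The helper below is a hypothetical sketch of reading them; it is not taken from main_distributed.py, which may obtain the same information through submitit instead.

```python
import os


def slurm_distributed_info(env=None):
    """Return (global_rank, world_size, local_rank) for the current SLURM task.

    Hypothetical helper assuming one task per GPU. Falls back to
    single-process defaults when the variables are absent (e.g. local runs).
    """
    env = os.environ if env is None else env
    global_rank = int(env.get("SLURM_PROCID", 0))
    world_size = int(env.get("SLURM_NTASKS", 1))
    local_rank = int(env.get("SLURM_LOCALID", 0))
    return global_rank, world_size, local_rank
```

For instance, with --nodes 8 --tasks-per-node 8 the world size is 8 × 8 = 64, and under SLURM's default block distribution the fourth task on the third node has global rank 2 × 8 + 3 = 19 and local rank 3. The local rank is what would typically select the GPU on each node.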

ImageNet pre-training

For example, to pre-train with PAWS on 64 GPUs (8 nodes with 8 tasks each), using the pre-training experiment configs specified inside configs/paws/imgnt_train.yaml, run:

python main_distributed.py \
  --sel paws_train \
  --fname configs/paws/imgnt_train.yaml \
  --partition $slurm_partition \
  --nodes 8 --tasks-per-node 8 \
  --time 1000 \
  --device volta16gb

ImageNet fine-tuning

To fine-tune a pre-trained model on 4 GPUs using the fine-tuning experiment configs specified inside configs/paws/fine_tune.yaml, run:

python main_distributed.py \
  --sel fine_tune \
  --fname configs/paws/fine_tune.yaml \
  --partition $slurm_partition \
  --nodes 1 --tasks-per-node 4 \
  --time 1000 \
  --device volta16gb

To evaluate the fine-tuned model locally on a single GPU, use the same config file, configs/paws/fine_tune.yaml, but change training: true to training: false. Then run:

python main.py \
  --sel fine_tune \
  --fname configs/paws/fine_tune.yaml
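If you prefer not to edit the training config by hand, a small script can write an evaluation copy with the flag switched. This sketch assumes the flag appears as `training: true` on its own line (at any indentation) and leaves every other line untouched; make_eval_config and the output filename are hypothetical. Passing the same path for both arguments edits the file in place.

```python
import re
from pathlib import Path


def make_eval_config(train_cfg_path, eval_cfg_path):
    """Copy a fine-tuning config, switching `training: true` to `training: false`.

    Hypothetical helper; assumes the flag is written on its own line.
    """
    text = Path(train_cfg_path).read_text()
    # Match only a key named exactly `training` at the start of a line.
    new_text, count = re.subn(
        r"^([ \t]*training:[ \t]*)true\b", r"\g<1>false", text, flags=re.MULTILINE
    )
    if count == 0:
        raise ValueError("no `training: true` entry found")
    Path(eval_cfg_path).write_text(new_text)
    return eval_cfg_path
```

For example, make_eval_config("configs/paws/fine_tune.yaml", "configs/paws/fine_tune_eval.yaml") writes the evaluation copy, which you would then pass to main.py via --fname.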

Soft Nearest Neighbours evaluation

To evaluate the nearest-neighbours performance of a pre-trained ResNet50 model on a single GPU, run:

python snn_eval.py \
  --model-name resnet50 --use-pred \
  --pretrained $path_to_pretrained_model \
  --unlabeled_frac $1.-fraction_of_labeled_train_data_to_support_nearest_neighbour_classification \
  --root-path $path_to_root_datasets_directory \
  --image-folder $image_directory_inside_root_path \
  --dataset-name $one_of:[imagenet_fine_tune, cifar10_fine_tune]
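Conceptually, soft nearest-neighbour evaluation classifies each query image by a similarity-weighted vote over the labeled support images and then measures top-1 accuracy. The NumPy sketch below illustrates the idea; it is not snn_eval.py itself. It assumes cosine similarity with a softmax temperature tau, and the names snn_classify and snn_top1 are hypothetical.

```python
import numpy as np


def snn_classify(query_emb, support_emb, support_labels, num_classes, tau=0.1):
    """Predict classes by a softmax-weighted vote over labeled support samples."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    s = support_emb / np.linalg.norm(support_emb, axis=1, keepdims=True)
    logits = (q @ s.T) / tau
    logits -= logits.max(axis=1, keepdims=True)          # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)        # (Q, S) weights over supports
    one_hot = np.eye(num_classes)[support_labels]        # (S, C) support labels
    probs = weights @ one_hot                            # (Q, C) soft class scores
    return probs.argmax(axis=1)


def snn_top1(query_emb, query_labels, support_emb, support_labels, num_classes, tau=0.1):
    """Top-1 accuracy of the soft nearest-neighbour classifier."""
    preds = snn_classify(query_emb, support_emb, support_labels, num_classes, tau)
    return float((preds == query_labels).mean())
```

Because the classifier has no trainable parameters, the result depends only on the embeddings and on which labeled images form the support set.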

License

See the LICENSE file for details about the license under which this code is made available.

Citation

If you find this repository useful in your research, please consider giving a star and a citation 🐾

@article{assran2021semisupervised,
  title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
  author={Assran, Mahmoud and Caron, Mathilde and Misra, Ishan and Bojanowski, Piotr and Joulin, Armand and Ballas, Nicolas and Rabbat, Michael},
  journal={arXiv preprint arXiv:2104.13963},
  year={2021}
}
@article{assran2020supervision,
  title={Supervision Accelerates Pretraining in Contrastive Semi-Supervised Learning of Visual Representations},
  author={Assran, Mahmoud and Ballas, Nicolas and Castrejon, Lluis and Rabbat, Michael},
  journal={arXiv preprint arXiv:2006.10803},
  year={2020}
}
Owner
Facebook Research