PyTorch implementation of SwAV (Swapping Assignments between Views)

Overview

Unsupervised Learning of Visual Features by Contrasting Cluster Assignments

This code provides a PyTorch implementation and pretrained models for SwAV (Swapping Assignments between Views), as described in the paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.

SwAV Illustration

SwAV is an efficient and simple method for pre-training convnets without using annotations. Similarly to contrastive approaches, SwAV learns representations by comparing transformations of an image, but unlike contrastive methods, it does not require computing pairwise feature comparisons. This makes our framework more efficient, since it does not require a large memory bank or an auxiliary momentum network. Specifically, our method simultaneously clusters the data while enforcing consistency between cluster assignments produced for different augmentations (or “views”) of the same image, instead of comparing features directly. Simply put, we use a “swapped” prediction mechanism where we predict the cluster assignment of a view from the representation of another view. Our method can be trained with both large and small batches and, because assignments are computed online, can scale to potentially unlimited amounts of data.
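The swapped prediction mechanism can be sketched in a few lines of NumPy. This is a minimal illustration under simplifying assumptions (two views, no gradients, no multi-crop, illustrative function names), not the code of main_swav.py: codes are computed with the Sinkhorn-Knopp normalization described in the paper, using epsilon = 0.05 and 3 iterations, and each view's prototype scores are turned into a temperature-scaled softmax (temperature 0.1 here) that must predict the other view's code.

```python
import numpy as np

def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Soft, equipartitioned assignment of B samples to K prototypes.

    scores: (B, K) similarities between features and prototypes.
    Returns Q of shape (B, K): each row is a distribution over prototypes,
    and every prototype receives roughly B / K samples in total.
    """
    # Subtracting the global max is a constant factor that cancels in the normalization.
    q = np.exp((scores - scores.max()) / epsilon).T  # (K, B)
    q /= q.sum()
    k, b = q.shape
    for _ in range(n_iters):
        q /= q.sum(axis=1, keepdims=True)  # each prototype (row) gets total mass 1/K
        q /= k
        q /= q.sum(axis=0, keepdims=True)  # each sample (column) gets total mass 1/B
        q /= b
    return (q * b).T  # rescale so that each sample's code sums to 1

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def swapped_loss(z1, z2, prototypes, epsilon=0.05, temperature=0.1):
    """Predict the code of each view from the representation of the other view."""
    s1, s2 = z1 @ prototypes.T, z2 @ prototypes.T           # (B, K) prototype scores
    q1, q2 = sinkhorn(s1, epsilon), sinkhorn(s2, epsilon)   # codes (targets)
    p1, p2 = softmax(s1 / temperature), softmax(s2 / temperature)
    # Swap: view 2's prediction is scored against view 1's code, and vice versa.
    per_sample = np.sum(q1 * np.log(p2), axis=1) + np.sum(q2 * np.log(p1), axis=1)
    return float(-0.5 * per_sample.mean())
```

Note that the codes are computed from the scores of the same batch, which is why no memory bank of past features is required (the optional queue discussed below only adds samples to the assignment problem).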

Model Zoo

We release several models pre-trained with SwAV in the hope that other researchers might also benefit by replacing the ImageNet supervised network with a SwAV backbone. To load our best SwAV pre-trained ResNet-50 model, simply do:

import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')

We provide several baseline SwAV pre-trained models with ResNet-50 architecture in torchvision format. We also provide models pre-trained with DeepCluster-v2 and SeLa-v2 obtained by applying improvements from the self-supervised community to DeepCluster and SeLa (see details in the appendix of our paper).

method | epochs | batch-size | multi-crop | ImageNet top-1 acc. (%) | url | args
SwAV | 800 | 4096 | 2x224 + 6x96 | 75.3 | model | script
SwAV | 400 | 4096 | 2x224 + 6x96 | 74.6 | model | script
SwAV | 200 | 4096 | 2x224 + 6x96 | 73.9 | model | script
SwAV | 100 | 4096 | 2x224 + 6x96 | 72.1 | model | script
SwAV | 200 | 256 | 2x224 + 6x96 | 72.7 | model | script
SwAV | 400 | 256 | 2x224 + 6x96 | 74.3 | model | script
SwAV | 400 | 4096 | 2x224 | 70.1 | model | script
DeepCluster-v2 | 800 | 4096 | 2x224 + 6x96 | 75.2 | model | script
DeepCluster-v2 | 400 | 4096 | 2x160 + 4x96 | 74.3 | model | script
DeepCluster-v2 | 400 | 4096 | 2x224 | 70.2 | model | script
SeLa-v2 | 400 | 4096 | 2x160 + 4x96 | 71.8 | model | -
SeLa-v2 | 400 | 4096 | 2x224 | 67.2 | model | -

Larger architectures

We provide SwAV models with ResNet-50 networks whose width is multiplied by a factor of ×2, ×4 and ×5. To load the corresponding backbones, use:

import torch
rn50w2 = torch.hub.load('facebookresearch/swav:main', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav:main', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')

network | parameters | epochs | ImageNet top-1 acc. (%) | url | args
RN50-w2 | 94M | 400 | 77.3 | model | script
RN50-w4 | 375M | 400 | 77.9 | model | script
RN50-w5 | 586M | 400 | 78.5 | model | -

Running times

We provide the running times for some of our runs:

method | batch-size | multi-crop | time per epoch
SwAV | 4096 | 2x224 + 6x96 | 3min40s
SwAV | 256 | 2x224 + 6x96 | 52min10s
DeepCluster-v2 | 4096 | 2x160 + 4x96 | 3min13s
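To turn these per-epoch times into rough wall-clock budgets, multiply by the number of epochs. The plain-Python sketch below does this, assuming the time per epoch stays constant over the whole run and ignoring evaluation and checkpointing overhead; the helper names are illustrative.

```python
import re

def parse_duration(text):
    """Convert a duration written like '3min40s' into seconds."""
    match = re.fullmatch(r"(?:(\d+)min)?(?:(\d+)s)?", text)
    if match is None or not any(match.groups()):
        raise ValueError(f"unrecognised duration: {text!r}")
    minutes, seconds = (int(g) if g else 0 for g in match.groups())
    return 60 * minutes + seconds

def total_hours(time_per_epoch, epochs):
    """Estimated wall-clock hours for a full run at a constant time per epoch."""
    return parse_duration(time_per_epoch) * epochs / 3600
```

For example, 800 epochs at 3min40s come to about 48.9 hours, and 200 epochs of the batch-256 run at 52min10s to about 174 hours.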

Running SwAV unsupervised training

Requirements

Single-node training

SwAV is very simple to implement and experiment with. Our implementation consists of a main_swav.py file, which imports the dataset definition (src/multicropdataset.py), the model architecture (src/resnet50.py) and some miscellaneous training utilities (src/utils.py).

For example, to train the SwAV baseline on a single node with 8 GPUs for 400 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 main_swav.py \
--data_path /path/to/imagenet/train \
--epochs 400 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
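The flags in this command can be unpacked with a little arithmetic. In the plain-Python sketch below, the four crop flags are read as parallel lists (one entry per crop resolution) and expanded into individual crops, and `--batch_size` is treated as the per-GPU batch size (8 × 32 = 256, matching the batch-size 256 rows of the model zoo table). The helper names are hypothetical, not part of main_swav.py, and 1,281,167 is the standard size of the ImageNet-1k training set.

```python
IMAGENET_TRAIN_IMAGES = 1_281_167  # number of images in the ImageNet-1k training set

def crop_specs(size_crops, nmb_crops, min_scale_crops, max_scale_crops):
    """Expand the per-resolution multi-crop flags into one (size, scale range) entry per crop."""
    specs = []
    for size, n, lo, hi in zip(size_crops, nmb_crops, min_scale_crops, max_scale_crops):
        # n crops of `size` x `size` pixels, each covering a fraction in [lo, hi] of the image area
        specs.extend([(size, (lo, hi))] * n)
    return specs

def run_stats(batch_per_gpu, gpus, queue_length):
    """Global batch size, iterations per ImageNet epoch and queue size in global batches."""
    global_batch = batch_per_gpu * gpus
    return {
        "global_batch": global_batch,
        "iters_per_epoch": IMAGENET_TRAIN_IMAGES / global_batch,
        "queue_in_batches": queue_length / global_batch,
    }

specs = crop_specs([224, 96], [2, 6], [0.14, 0.05], [1.0, 0.14])
stats = run_stats(batch_per_gpu=32, gpus=8, queue_length=3840)
```

With these values each image yields 8 crops (2 at 224 px, 6 at 96 px), one epoch is about 5005 iterations of a 256-image batch, so `--freeze_prototypes_niters 5005` keeps the prototypes fixed for roughly the first epoch, and the 3840-entry queue holds 15 global batches.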

Multi-node training

Distributed training is available via Slurm. We provide several SBATCH scripts to reproduce our SwAV models. For example, to train SwAV on 8 nodes and 64 GPUs with a batch size of 4096 for 800 epochs, run:

sbatch ./scripts/swav_800ep_pretrain.sh

Note that you might need to remove the copyright header from the sbatch file to launch it.

Set up the dist_url parameter: we refer the user to the PyTorch distributed documentation (env, file or tcp initialization) for setting the distributed initialization method (parameter dist_url) correctly. The provided sbatch files use the tcp init method.
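With the tcp init method, every process connects to the same tcp://host:port address, which is served by the rank-0 process, and must also be given the world size and its own unique rank. The sketch below computes both for a layout of nodes × GPUs per node; the port 40000, the host name and the helper names are illustrative assumptions, not values taken from the provided sbatch files.

```python
def tcp_dist_url(master_host, port=40000):
    """dist_url for PyTorch's TCP initialization: the address of the rank-0 host."""
    return f"tcp://{master_host}:{port}"

def global_rank(node_id, gpus_per_node, local_rank):
    """Unique rank of one process when every node runs gpus_per_node processes."""
    return node_id * gpus_per_node + local_rank

# Every process would then call, with its own rank but the same url and world_size:
#   torch.distributed.init_process_group(backend="nccl", init_method=url,
#                                        world_size=world_size, rank=rank)
nodes, gpus_per_node = 8, 8
world_size = nodes * gpus_per_node
url = tcp_dist_url("node-0")  # hypothetical host name of the first node
```

Using the first node's host name for every process is what lets all 64 processes rendezvous at a single address.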

Evaluating models

Evaluate models: Linear classification on ImageNet

To train a supervised linear classifier on frozen features/weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar

The resulting linear classifier can be downloaded here.
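Linear evaluation keeps the pre-trained weights frozen and fits only a linear softmax classifier on top of the extracted features. The NumPy sketch below illustrates the idea on synthetic data, with a fixed random projection standing in for the frozen backbone. It is not eval_linear.py, which works on real ResNet-50 features with PyTorch, and its learning rate and step count are arbitrary choices for this toy problem.

```python
import numpy as np

def linear_probe(features, labels, n_classes, lr=0.05, steps=500):
    """Fit a softmax linear classifier on fixed features with full-batch gradient descent."""
    mean, std = features.mean(axis=0), features.std(axis=0) + 1e-8
    x = (features - mean) / std                   # standardize the frozen features
    y = np.eye(n_classes)[labels]                 # one-hot targets
    w = np.zeros((x.shape[1], n_classes))
    b = np.zeros(n_classes)
    for _ in range(steps):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        p /= p.sum(axis=1, keepdims=True)
        grad = (p - y) / len(x)                   # gradient of the mean cross-entropy w.r.t. logits
        w -= lr * x.T @ grad
        b -= lr * grad.sum(axis=0)
    # Only w and b are trained; the features (and the backbone) are never updated.
    return lambda f: np.argmax(((f - mean) / std) @ w + b, axis=1)

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=300)
inputs = 5.0 * np.eye(10)[labels] + rng.normal(size=(300, 10))
backbone = rng.normal(size=(10, 32))              # frozen stand-in for the pre-trained network
features = inputs @ backbone
predict = linear_probe(features, labels, n_classes=3)
```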

Evaluate models: Semi-supervised learning on ImageNet

To reproduce our results and fine-tune a network with 1% or 10% of ImageNet labels on a single node with 8 GPUs, run:

  • 10% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "10" \
--lr 0.01 \
--lr_last_layer 0.2
  • 1% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "1" \
--lr 0.02 \
--lr_last_layer 5

Evaluate models: Transferring to Detection with DETR

DETR is a recent object detection framework that achieves performance competitive with Faster R-CNN while being conceptually simpler and trainable end-to-end. We evaluate our SwAV ResNet-50 backbone on object detection on the COCO dataset using the DETR framework with full fine-tuning. Here are the instructions for reproducing our experiments:

  1. Install detr and prepare the COCO dataset following these instructions.

  2. Apply the changes highlighted in this gist to the detr backbone file in order to load the SwAV backbone instead of ImageNet supervised weights.

  3. Launch training from detr repository with run_with_submitit.py.

python run_with_submitit.py --batch_size 4 --nodes 2 --lr_backbone 5e-5
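Step 2 amounts to loading SwAV weights into a standard ResNet-50. The plain-Python sketch below shows the kind of key clean-up this typically requires. It assumes the checkpoint was saved from a DistributedDataParallel model (keys prefixed with `module.`) and that the pre-training-only layers are named `projection_head` and `prototypes`; check these names against your checkpoint, and refer to the gist for the actual DETR changes. The helper name is hypothetical.

```python
def backbone_state_dict(checkpoint_state, drop_prefixes=("projection_head", "prototypes")):
    """Map SwAV-style checkpoint keys onto the parameter names of a plain ResNet-50."""
    cleaned = {}
    for key, value in checkpoint_state.items():
        if key.startswith("module."):
            key = key[len("module."):]  # remove the DistributedDataParallel wrapper prefix
        if key.startswith(drop_prefixes):
            continue  # pre-training heads have no counterpart in the detection backbone
        cleaned[key] = value
    return cleaned
```

The cleaned dictionary can then be passed to a torchvision ResNet-50 with `load_state_dict(..., strict=False)`, whose return value lists any keys that still do not match (for example, the `fc` classifier if the checkpoint has none).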

Common Issues

For help or issues using SwAV, please submit a GitHub issue.

The loss does not decrease and is stuck at ln(nmb_prototypes) (8.006 for 3000 prototypes).

Training sometimes collapses at the very beginning and does not manage to converge. We have found the following empirical workarounds to improve convergence and avoid collapse at the beginning:

  • use a lower epsilon value (--epsilon 0.03 instead of the default 0.05)
  • carefully tune the hyper-parameters
  • freeze the prototypes during first iterations (freeze_prototypes_niters argument)
  • switch to hard assignment
  • remove batch-normalization layer from the projection head
  • reduce the difficulty of the problem (fewer crops or softer data augmentation)

We now analyze the collapsing problem: it happens when all examples are mapped to the same unique representation. In other words, the convnet always produces the same output regardless of its input: it is a constant function. All examples then get the same cluster assignment because they are identical, and the only valid assignment that satisfies the equipartition constraint in this case is the uniform assignment (1/K, where K is the number of prototypes). In turn, this uniform assignment is trivial to predict since it is the same for all examples, and since the targets are uniform, the swapped-prediction cross-entropy cannot fall below ln(K), which is the plateau described above. Reducing the epsilon parameter (see Eq. (3) of our paper) encourages the assignments Q to be sharper (i.e. less uniform), which strongly helps avoid collapse. However, using too low a value for epsilon may lead to numerical instability.
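The ln(K) plateau can be checked numerically. Once the codes are uniform, each example's loss is the cross-entropy between a uniform target and the prediction p, that is -(1/K) Σ_k log p_k, which by Gibbs' inequality is never lower than ln(K) and equals it exactly when p is uniform too. The plain-Python check below reproduces the 8.006 value for 3000 prototypes; the skewed prediction is an arbitrary example.

```python
import math

def cross_entropy(target, pred):
    """-sum_k target_k * log(pred_k) for a single example."""
    return -sum(t * math.log(p) for t, p in zip(target, pred) if t > 0)

K = 3000
uniform = [1.0 / K] * K
collapsed_loss = cross_entropy(uniform, uniform)  # equals ln(K)

# Any non-uniform prediction does worse against a uniform target.
skewed = [0.5] + [0.5 / (K - 1)] * (K - 1)
skewed_loss = cross_entropy(uniform, skewed)
```

Since nothing can be gained below ln(K) once the codes are uniform, the gradient signal that would break the symmetry is weak, which is consistent with the loss staying stuck at this value.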

Training gets unstable when using the queue.

The queue is composed of feature representations from previous batches. At each iteration, the oldest feature representations are discarded from the queue and the newest ones (i.e. from the current batch) are saved, following a round-robin mechanism. This way, the assignment problem is solved over more samples: without the queue, we assign B examples to num_prototypes clusters, where B is the total batch size, whereas with the queue we assign (B + queue_length) examples to num_prototypes clusters. This is especially useful when working with small batches because it improves the precision of the assignment.
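A minimal NumPy sketch of the queue mechanics, assuming the queue is a fixed-size array of past embeddings with the newest batch stored at the front: codes are computed on the queued rows stacked with the current batch, only the rows of the current batch are kept as targets, and the queue is then updated. The function names are illustrative.

```python
import numpy as np

def assignment_inputs(queue, batch):
    """Stack queued and current embeddings; codes are computed on all
    (queue_length + B) rows and only the last B rows (the current batch) are kept."""
    return np.concatenate([queue, batch], axis=0)

def push_to_queue(queue, batch):
    """Newest embeddings enter at the front; the oldest rows fall off the end."""
    b = len(batch)
    queue[b:] = queue[:-b].copy()  # shift older entries back by one batch
    queue[:b] = batch
    return queue

queue = np.zeros((6, 2))                   # queue_length = 6, embedding dim = 2
push_to_queue(queue, np.ones((2, 2)))      # first batch
stacked = assignment_inputs(queue, 2.0 * np.ones((2, 2)))
push_to_queue(queue, 2.0 * np.ones((2, 2)))  # second batch
```

With the single-node command above (global batch 256, `--queue_length 3840`), the assignment is solved over 256 + 3840 = 4096 samples, the same number as one batch of the 4096-batch runs.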

Starting to use the queue too early, or using a queue that is too large, can considerably disturb training, because the queued features are then too inconsistent with each other (they were computed by earlier versions of the network). After the queue is introduced, the loss should be lower than it was without the queue. On the following loss curve (first 30 epochs of this script), we introduced the queue at epoch 15 and observe that it made the loss decrease further.

SwAV training loss batch_size=256 during the first 30 epochs

If the loss goes up when the queue is introduced and does not decrease afterwards, stop your training and change the queue parameters. We recommend (i) using a smaller queue or (ii) starting the queue later in training.

License

See the LICENSE file for more details.

See also

PyTorch Lightning Bolts: Implementation by the Lightning team.

SwAV-TF: A TensorFlow re-implementation.

Citation

If you find this repository useful in your research, please cite:

@inproceedings{caron2020unsupervised,
  title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
  author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
  booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
  year={2020}
}
Owner
Meta Research