PyTorch implementation of SwAV (Swapping Assignments between Views)

Unsupervised Learning of Visual Features by Contrasting Cluster Assignments

This code provides a PyTorch implementation and pretrained models for SwAV (Swapping Assignments between Views), as described in the paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.

SwAV Illustration

SwAV is an efficient and simple method for pre-training convnets without using annotations. Like contrastive approaches, SwAV learns representations by comparing transformations of an image, but unlike contrastive methods, it does not require computing pairwise feature comparisons. This makes our framework more efficient, since it does not require a large memory bank or an auxiliary momentum network. Specifically, our method simultaneously clusters the data while enforcing consistency between cluster assignments produced for different augmentations (or “views”) of the same image, instead of comparing features directly. Simply put, we use a “swapped” prediction mechanism where we predict the cluster assignment of a view from the representation of another view. Our method can be trained with large and small batches and can scale to unlimited amounts of data.
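The swapped prediction idea can be illustrated with a small NumPy sketch. This is not the repository's PyTorch code: the function names, the temperature of 0.1, and the three Sinkhorn-Knopp iterations are assumptions made for illustration, and gradients are ignored entirely. It only shows how the cluster assignment computed from one view serves as the target for the prediction made from the other view.

```python
import numpy as np


def log_softmax(x):
    """Row-wise log-softmax, computed in a numerically stable way."""
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))


def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Soft assignments of shape (B, K) that spread mass evenly over prototypes."""
    q = np.exp((scores - scores.max()) / epsilon).T  # shape (K, B)
    q /= q.sum()
    n_protos, n_samples = q.shape
    for _ in range(n_iters):
        q /= q.sum(axis=1, keepdims=True) * n_protos   # each prototype gets mass 1/K
        q /= q.sum(axis=0, keepdims=True) * n_samples  # each sample gets mass 1/B
    return (q * n_samples).T  # each row now sums to 1


def swapped_loss(z1, z2, prototypes, temperature=0.1):
    """Predict the assignment of one view from the scores of the other view."""
    s1, s2 = z1 @ prototypes.T, z2 @ prototypes.T  # (B, K) similarity scores
    q1, q2 = sinkhorn(s1), sinkhorn(s2)            # targets, treated as constants
    log_p1 = log_softmax(s1 / temperature)
    log_p2 = log_softmax(s2 / temperature)
    # Swap: view 1 predicts the code of view 2, and view 2 predicts the code of view 1.
    return -0.5 * np.mean(np.sum(q2 * log_p1 + q1 * log_p2, axis=1))


def unit_rows(a):
    """L2-normalize each row."""
    return a / np.linalg.norm(a, axis=1, keepdims=True)


# Toy data (hypothetical sizes): 16 samples, 32-dim features, 10 prototypes.
rng = np.random.default_rng(0)
prototypes = unit_rows(rng.normal(size=(10, 32)))
view1 = unit_rows(rng.normal(size=(16, 32)))
view2 = unit_rows(view1 + 0.1 * rng.normal(size=(16, 32)))
print(swapped_loss(view1, view2, prototypes))
```

Note that no feature of one image is ever compared with a feature of another image: each view is only compared with the prototypes, which is why no memory bank of negatives is needed.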

Model Zoo

We release several models pre-trained with SwAV in the hope that other researchers might also benefit from replacing the ImageNet supervised network with a SwAV backbone. To load our best SwAV pre-trained ResNet-50 model, simply do:

import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')

We provide several baseline SwAV pre-trained models with the ResNet-50 architecture in torchvision format. We also provide models pre-trained with DeepCluster-v2 and SeLa-v2, obtained by applying improvements from the self-supervised community to DeepCluster and SeLa (see details in the appendix of our paper).

| method | epochs | batch-size | multi-crop | ImageNet top-1 acc. | url | args |
|---|---|---|---|---|---|---|
| SwAV | 800 | 4096 | 2x224 + 6x96 | 75.3 | model | script |
| SwAV | 400 | 4096 | 2x224 + 6x96 | 74.6 | model | script |
| SwAV | 200 | 4096 | 2x224 + 6x96 | 73.9 | model | script |
| SwAV | 100 | 4096 | 2x224 + 6x96 | 72.1 | model | script |
| SwAV | 200 | 256 | 2x224 + 6x96 | 72.7 | model | script |
| SwAV | 400 | 256 | 2x224 + 6x96 | 74.3 | model | script |
| SwAV | 400 | 4096 | 2x224 | 70.1 | model | script |
| DeepCluster-v2 | 800 | 4096 | 2x224 + 6x96 | 75.2 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x160 + 4x96 | 74.3 | model | script |
| DeepCluster-v2 | 400 | 4096 | 2x224 | 70.2 | model | script |
| SeLa-v2 | 400 | 4096 | 2x160 + 4x96 | 71.8 | model | - |
| SeLa-v2 | 400 | 4096 | 2x224 | 67.2 | model | - |

Larger architectures

We provide SwAV models with ResNet-50 networks where we multiply the width by a factor ×2, ×4, and ×5. To load the corresponding backbone you can use:

import torch
rn50w2 = torch.hub.load('facebookresearch/swav:main', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav:main', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')

| network | parameters | epochs | ImageNet top-1 acc. | url | args |
|---|---|---|---|---|---|
| RN50-w2 | 94M | 400 | 77.3 | model | script |
| RN50-w4 | 375M | 400 | 77.9 | model | script |
| RN50-w5 | 586M | 400 | 78.5 | model | - |

Running times

We provide the running times for some of our runs:

| method | batch-size | multi-crop | scripts | time per epoch |
|---|---|---|---|---|
| SwAV | 4096 | 2x224 + 6x96 | * * * * | 3min40s |
| SwAV | 256 | 2x224 + 6x96 | * * | 52min10s |
| DeepCluster-v2 | 4096 | 2x160 + 4x96 | * | 3min13s |
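To get a rough feel for the total cost of a run, the per-epoch times above can be multiplied by the number of epochs. The sketch below does this arithmetic; the helper names are hypothetical, the pairing of a running time with an epoch count is an assumption taken from the Model Zoo rows, and the estimate ignores evaluation, checkpointing, and job-restart overhead.

```python
import re


def parse_duration(text):
    """Convert a string such as '3min40s' into a number of seconds."""
    match = re.fullmatch(r"(\d+)min(\d+)s", text)
    if match is None:
        raise ValueError(f"unrecognized duration: {text!r}")
    return int(match.group(1)) * 60 + int(match.group(2))


def total_hours(time_per_epoch, epochs):
    """Lower-bound estimate of the training time in hours."""
    return parse_duration(time_per_epoch) * epochs / 3600


# SwAV, batch size 4096, assuming the 800-epoch schedule.
print(round(total_hours("3min40s", 800), 1))   # 48.9
# SwAV, batch size 256, assuming the 400-epoch schedule.
print(round(total_hours("52min10s", 400), 1))  # 347.8
```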

Running SwAV unsupervised training

Requirements

Single-node training

SwAV is very simple to implement and experiment with. Our implementation consists of a main_swav.py file that imports the dataset definition (src/multicropdataset.py), the model architecture (src/resnet50.py), and some miscellaneous training utilities (src/utils.py).

For example, to train the SwAV baseline on a single node with 8 GPUs for 400 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 main_swav.py \
--data_path /path/to/imagenet/train \
--epochs 400 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
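The values in this command are related to one another, and the short sketch below makes the arithmetic explicit. It assumes that --batch_size is the per-GPU batch size and that the ImageNet-1k training set has 1,281,167 images; neither is stated above, and the function name is hypothetical.

```python
import math


def summarize_run(nproc_per_node, batch_size_per_gpu, nmb_crops, queue_length,
                  dataset_size=1_281_167):  # assumed ImageNet-1k train set size
    """Derive a few quantities implied by the single-node command line."""
    global_batch = nproc_per_node * batch_size_per_gpu
    return {
        "global_batch": global_batch,
        # Number of crops (views) produced per image by multi-crop.
        "views_per_image": sum(nmb_crops),
        # How many past global batches fit in the queue.
        "queue_in_batches": queue_length / global_batch,
        # Approximate number of iterations needed to see the dataset once.
        "iters_per_epoch": math.ceil(dataset_size / global_batch),
    }


info = summarize_run(nproc_per_node=8, batch_size_per_gpu=32,
                     nmb_crops=[2, 6], queue_length=3840)
print(info)
```

Under these assumptions the global batch size is 256, the queue holds 15 global batches, and --freeze_prototypes_niters 5005 corresponds to roughly one epoch of training.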

Multi-node training

Distributed training is available via Slurm. We provide several SBATCH scripts to reproduce our SwAV models. For example, to train SwAV on 8 nodes and 64 GPUs with a batch size of 4096 for 800 epochs, run:

sbatch ./scripts/swav_800ep_pretrain.sh

Note that you might need to remove the copyright header from the sbatch file to launch it.

Set up the dist_url parameter: we refer the user to the PyTorch distributed documentation (env, file, or tcp) for setting the distributed initialization method (parameter dist_url) correctly. In the provided sbatch files, we use the tcp init method.

Evaluating models

Evaluate models: Linear classification on ImageNet

To train a supervised linear classifier on frozen features/weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar

The resulting linear classifier can be downloaded here.

Evaluate models: Semi-supervised learning on ImageNet

To reproduce our results and fine-tune a network with 1% or 10% of ImageNet labels on a single node with 8 GPUs, run:

  • 10% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "10" \
--lr 0.01 \
--lr_last_layer 0.2
  • 1% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "1" \
--lr 0.02 \
--lr_last_layer 5

Evaluate models: Transferring to Detection with DETR

DETR is a recent object detection framework that reaches competitive performance with Faster R-CNN while being conceptually simpler and trainable end-to-end. We evaluate our SwAV ResNet-50 backbone on object detection on the COCO dataset using the DETR framework with full fine-tuning. Here are the instructions for reproducing our experiments:

  1. Install detr and prepare COCO dataset following these instructions.

  2. Apply the changes highlighted in this gist to the detr backbone file in order to load the SwAV backbone instead of ImageNet supervised weights.

  3. Launch training from detr repository with run_with_submitit.py.

python run_with_submitit.py --batch_size 4 --nodes 2 --lr_backbone 5e-5

Common Issues

For help or issues using SwAV, please submit a GitHub issue.

The loss does not decrease and is stuck at ln(nmb_prototypes) (8.006 for 3000 prototypes).

It sometimes happens that the system collapses at the beginning and does not manage to converge. We have found the following empirical workarounds to improve convergence and avoid collapsing at the beginning:

  • use a lower epsilon value (--epsilon 0.03 instead of the default 0.05)
  • carefully tune the hyper-parameters
  • freeze the prototypes during the first iterations (freeze_prototypes_niters argument)
  • switch to hard assignment
  • remove the batch-normalization layer from the projection head
  • reduce the difficulty of the problem (fewer crops or softer data augmentation)

We now analyze the collapsing problem: it happens when all examples are mapped to the same unique representation. In other words, the convnet always produces the same output regardless of its input: it is a constant function. All examples get the same cluster assignment because they are identical, and the only valid assignment that satisfies the equipartition constraint in this case is the uniform assignment (1/K, where K is the number of prototypes). In turn, this uniform assignment is trivial to predict since it is the same for all examples. Reducing the epsilon parameter (see Eq. (3) of our paper) encourages the assignments Q to be sharper (i.e. less uniform), which strongly helps avoid collapse. However, using too low a value for epsilon may lead to numerical instability.
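Two parts of this analysis can be checked numerically. The NumPy sketch below first shows that predicting a uniform assignment with a uniform prediction gives a cross-entropy of ln(K), which is the 8.006 plateau for 3000 prototypes. It then shows that dividing scores by a smaller epsilon before exponentiating gives a sharper (lower-entropy) distribution. The second part is a simplification: it normalizes a single sample's scores and omits the equipartition step, and all function names and the random scores are hypothetical.

```python
import math

import numpy as np


def cross_entropy(target, prediction):
    """Cross-entropy between two discrete distributions."""
    return float(-np.sum(target * np.log(prediction)))


def sharpen(scores, epsilon):
    """Normalize exp(scores / epsilon) for one sample (no equipartition step)."""
    e = np.exp((scores - scores.max()) / epsilon)
    return e / e.sum()


def entropy(p):
    """Shannon entropy in nats, skipping zero entries."""
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))


# Collapse: the target and the prediction are both uniform over K prototypes.
K = 3000
uniform = np.full(K, 1.0 / K)
collapsed_loss = cross_entropy(uniform, uniform)
print(round(collapsed_loss, 3))  # 8.006, i.e. ln(3000)

# Effect of epsilon on sharpness, using random scores in [-1, 1].
rng = np.random.default_rng(0)
scores = rng.uniform(-1.0, 1.0, size=K)
for eps in (0.05, 0.03):
    print(eps, entropy(sharpen(scores, eps)))
```

The entropy printed for epsilon 0.03 is lower than the one for 0.05, which matches the intuition that a smaller epsilon pushes assignments away from the uniform solution.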

Training gets unstable when using the queue.

The queue is composed of feature representations from the previous batches. These lines discard the oldest feature representations from the queue and save the newest ones (i.e. from the current batch) through a round-robin mechanism. This way, the assignment problem is solved on more samples: without the queue we assign B examples to num_prototypes clusters, where B is the total batch size, while with the queue we assign (B + queue_length) examples to num_prototypes clusters. This is especially useful when working with small batches because it improves the precision of the assignment.
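The queue update can be pictured with the following NumPy sketch. It is an illustration of the mechanism described above rather than the repository's code: the function name and the toy sizes are hypothetical, and the real implementation works on PyTorch tensors.

```python
import numpy as np


def update_queue(queue, new_features):
    """Drop the oldest batch from the queue and store the newest one at the front."""
    bs = new_features.shape[0]
    # Shift existing entries back by one batch; the last `bs` rows fall off the end.
    queue[bs:] = queue[:-bs].copy()
    # Write the current batch into the freed slots.
    queue[:bs] = new_features
    return queue


# Toy example: a queue of 6 feature vectors of dimension 1, batches of size 2.
queue = np.zeros((6, 1))
for step in range(1, 5):
    batch = np.full((2, 1), float(step))  # features of batch number `step`
    queue = update_queue(queue, batch)

print(queue.ravel())  # batches 4, 3, 2 remain; batch 1 was discarded
```

With the single-node settings shown earlier (a total batch of 256 and a queue length of 3840), the assignment would be computed over 256 + 3840 = 4096 samples.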

If you start using the queue too early or if you use too large a queue, this can considerably disturb training because the queue members are too inconsistent. After introducing the queue, the loss should be lower than what it was without the queue. In the following loss curve (first 30 epochs of this script), we introduced the queue at epoch 15. We observe that this made the loss decrease further.

SwAV training loss batch_size=256 during the first 30 epochs

If, when introducing the queue, the loss goes up and does not decrease afterwards, you should stop your training and change the queue parameters. We recommend (i) using a smaller queue and (ii) starting the queue later in training.

License

See the LICENSE file for more details.

See also

PyTorch Lightning Bolts: Implementation by the Lightning team.

SwAV-TF: A TensorFlow re-implementation.

Citation

If you find this repository useful in your research, please cite:

@inproceedings{caron2020unsupervised,
  title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
  author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
  booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Owner: Meta Research