A PyTorch implementation of Mugs proposed by our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework".

Related tags

Deep Learningmugs
Overview

Mugs: A Multi-Granular Self-Supervised Learning Framework

This is a PyTorch implementation of Mugs proposed by our paper "Mugs: A Multi-Granular Self-Supervised Learning Framework". arXiv

PWC

Overall framework of Mugs.

Fig 1. Overall framework of Mugs. In (a), for each image, two random crops of one image are fed into backbones of student and teacher. Three granular supervisions: 1) instance discrimination supervision, 2) local-group discrimination supervision, and 3) group discrimination supervision, are adopted to learn multi-granular representation. In (b), local-group modules in student/teacher averages all patch tokens, and finds top-k neighbors from memory buffer to aggregate them with the average for obtaining a local-group feature.

Pretrained models on ImageNet-1K

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks.

Table 1. KNN and linear probing performance with their corresponding hyper-parameters, logs and model weights.

arch params pretraining epochs k-nn linear download
ViT-S/16 21M 100 72.3% 76.4% backbone only full ckpt args logs eval logs
ViT-S/16 21M 300 74.8% 78.2% backbone only full ckpt args logs eval logs
ViT-S/16 21M 800 75.6% 78.9% backbone only full ckpt args logs eval logs
ViT-B/16 85M 400 78.0% 80.6% backbone only full ckpt args logs eval logs
ViT-L/16 307M 250 80.3% 82.1% backbone only full ckpt args logs eval logs
Comparison of linear probing accuracy on ImageNet-1K.

Fig 2. Comparison of linear probing accuracy on ImageNet-1K.

Pretraining Settings

Environment

For reproducing, please install PyTorch and download the ImageNet dataset. This codebase has been developed with python version 3.8, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. For the full environment, please refer to our Dockerfile file.

ViT pretraining 🍺

To pretraining each model, please find the exact hyper-parameter settings at the args column of Table 1. For training log and linear probing log, please refer to the log and eval logs column of Table 1.

ViT-Small pretraining:

To run ViT-small for 100 epochs, we use two nodes of total 8 A100 GPUs (total 512 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=8 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.04 --group_warmup_teacher_temp_epochs 0 --weight_decay_end 0.2 --norm_last_layer false --epochs 100

To run ViT-small for 300 epochs, we use two nodes of total 16 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 300

To run ViT-small for 800 epochs, we use two nodes of total 16 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 800

ViT-Base pretraining:

To run ViT-base for 400 epochs, we use two nodes of total 24 A100 GPUs (total 1024 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=24 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_base 
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --min_lr 2e-06 --weight_decay_end 0.1 --freeze_last_layer 3 --norm_last_layer 
false --epochs 400

ViT-Large pretraining:

To run ViT-large for 250 epochs, we use two nodes of total 40 A100 GPUs (total 640 minibatch size) by using following command:

python -m torch.distributed.launch --nproc_per_node=40 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_large 
--lr 0.0015 --min_lr 1.5e-4 --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --weight_decay 0.025 
--weight_decay_end 0.08 --norm_last_layer true --drop_path_rate 0.3 --freeze_last_layer 3 --epochs 250

Evaluation

We are cleaning up the evalutation code and will release them when they are ready.

Self-attention visualization

Here we provide the self-attention map of the [CLS] token on the heads of the last layer

Self-attention from a ViT-Base/16 trained with Mugs

Fig 3. Self-attention from a ViT-Base/16 trained with Mugs.

T-SNE visualization

Here we provide the T-SNE visualization of the learned feature by ViT-B/16. We show the fish classes in ImageNet-1K, i.e., the first six classes, including tench, goldfish, white shark, tiger shark, hammerhead, electric ray. See more examples in Appendix.

T-SNE visualization of the learned feature by ViT-B/16.

Fig 4. T-SNE visualization of the learned feature by ViT-B/16.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving a star and citation 🍺 :

@inproceedings{mugs2022SSL,
  title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
  author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
  booktitle={arXiv preprint arXiv:2203.14415},
  year={2022}
}
Owner
Sea AI Lab
Sea AI Lab
A python library for face detection and features extraction based on mediapipe library

FaceAnalyzer A python library for face detection and features extraction based on mediapipe library Introduction FaceAnalyzer is a library based on me

Saifeddine ALOUI 14 Dec 30, 2022
This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

ICCV Workshop 2021 VTGAN This code is for our paper "VTGAN: Semi-supervised Retinal Image Synthesis and Disease Prediction using Vision Transformers"

Sharif Amit Kamran 25 Dec 08, 2022
PyTorch Implementation of ByteDance's Cross-speaker Emotion Transfer Based on Speaker Condition Layer Normalization and Semi-Supervised Training in Text-To-Speech

Cross-Speaker-Emotion-Transfer - PyTorch Implementation PyTorch Implementation of ByteDance's Cross-speaker Emotion Transfer Based on Speaker Conditio

Keon Lee 114 Jan 08, 2023
Curvlearn, a Tensorflow based non-Euclidean deep learning framework.

English | 简体中文 Why Non-Euclidean Geometry Considering these simple graph structures shown below. Nodes with same color has 2-hop distance whereas 1-ho

Alibaba 123 Dec 12, 2022
Poisson Surface Reconstruction for LiDAR Odometry and Mapping

Poisson Surface Reconstruction for LiDAR Odometry and Mapping Surfels TSDF Our Approach Table: Qualitative comparison between the different mapping te

Photogrammetry & Robotics Bonn 305 Dec 21, 2022
Implementing DropPath/StochasticDepth in PyTorch

%load_ext memory_profiler Implementing Stochastic Depth/Drop Path In PyTorch DropPath is available on glasses my computer vision library! Introduction

Francesco Saverio Zuppichini 13 Jan 05, 2023
PyTorch Live is an easy to use library of tools for creating on-device ML demos on Android and iOS.

PyTorch Live is an easy to use library of tools for creating on-device ML demos on Android and iOS. With Live, you can build a working mobile app ML demo in minutes.

559 Jan 01, 2023
AlgoVision - A Framework for Differentiable Algorithms and Algorithmic Supervision

NeurIPS 2021 Paper "Learning with Algorithmic Supervision via Continuous Relaxations"

Felix Petersen 76 Jan 01, 2023
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning The predictive learning of spatiotemporal sequences aims to generate future

THUML: Machine Learning Group @ THSS 243 Dec 26, 2022
Official implementation for the paper: Generating Smooth Pose Sequences for Diverse Human Motion Prediction

Generating Smooth Pose Sequences for Diverse Human Motion Prediction This is official implementation for the paper Generating Smooth Pose Sequences fo

Wei Mao 28 Dec 10, 2022
Pytorch implementation of various High Dynamic Range (HDR) Imaging algorithms

Deep High Dynamic Range Imaging Benchmark This repository is the pytorch impleme

Tianhong Dai 5 Nov 16, 2022
Non-Imaging Transient Reconstruction And TEmporal Search (NITRATES)

Non-Imaging Transient Reconstruction And TEmporal Search (NITRATES) This repo contains the full NITRATES pipeline for maximum likelihood-driven discov

13 Nov 08, 2022
GraphLily: A Graph Linear Algebra Overlay on HBM-Equipped FPGAs

GraphLily: A Graph Linear Algebra Overlay on HBM-Equipped FPGAs GraphLily is the first FPGA overlay for graph processing. GraphLily supports a rich se

Cornell Zhang Research Group 39 Dec 13, 2022
Official implementation of our neural-network-based fast diffuse room impulse response generator (FAST-RIR)

This is the official implementation of our neural-network-based fast diffuse room impulse response generator (FAST-RIR) for generating room impulse responses (RIRs) for a given acoustic environment.

12 Jan 13, 2022
A Tensorflow based library for Time Series Modelling with Gaussian Processes

Markovflow Documentation | Tutorials | API reference | Slack What does Markovflow do? Markovflow is a Python library for time-series analysis via prob

Secondmind Labs 24 Dec 12, 2022
Human Activity Recognition example using TensorFlow on smartphone sensors dataset and an LSTM RNN. Classifying the type of movement amongst six activity categories - Guillaume Chevalier

LSTMs for Human Activity Recognition Human Activity Recognition (HAR) using smartphones dataset and an LSTM RNN. Classifying the type of movement amon

Guillaume Chevalier 3.1k Dec 30, 2022
Graph Convolutional Neural Networks with Data-driven Graph Filter (GCNN-DDGF)

Graph Convolutional Gated Recurrent Neural Network (GCGRNN) Improved from Graph Convolutional Neural Networks with Data-driven Graph Filter (GCNN-DDGF

Lei Lin 21 Dec 18, 2022
Transfer Learning for Pose Estimation of Illustrated Characters

bizarre-pose-estimator Transfer Learning for Pose Estimation of Illustrated Characters Shuhong Chen *, Matthias Zwicker * WACV2022 [arxiv] [video] [po

Shuhong Chen 142 Dec 28, 2022
Official PyTorch implementation of Less is More: Pay Less Attention in Vision Transformers.

Less is More: Pay Less Attention in Vision Transformers Official PyTorch implementation of Less is More: Pay Less Attention in Vision Transformers. By

73 Jan 01, 2023
WormMovementSimulation - 3D Simulation of Worm Body Movement with Neurons attached to its body

Generate 3D Locomotion Data This module is intended to create 2D video trajector

1 Aug 09, 2022