PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

Overview

MAE for Self-supervised ViT

Introduction

This is an unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners for self-supervised ViT.

This repo is mainly based on moco-v3, pytorch-image-models, and BEiT.
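For intuition, here is a minimal sketch of MAE's core step, random masking of patch tokens, following the paper rather than this repo's exact code; the encoder only sees the small visible subset.

import torch

def random_masking(tokens: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of patch tokens; the encoder sees only these."""
    B, N, D = tokens.shape
    len_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)                       # per-patch random scores
    ids_keep = noise.argsort(dim=1)[:, :len_keep]  # lowest scores are kept
    return torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

tokens = torch.randn(2, 196, 768)                  # (batch, patches, dim)
print(random_masking(tokens).shape)                # torch.Size([2, 49, 768])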

TODO

  • visualization of reconstructed images
  • linear probing
  • more results
  • transfer learning
  • ...

Main Results

The following results are based on ImageNet-1k self-supervised pre-training, followed by ImageNet-1k supervised training for linear evaluation or end-to-end fine-tuning.

ViT-Base

pretrain epochs   with pixel-norm   linear acc   fine-tuning acc
100               False             --           75.58 [1]
100               True              --           77.19
800               True              --           --

On 8 NVIDIA GeForce RTX 3090 GPUs, pre-training for 100 epochs takes about 9 hours; a batch size of 4096 needs about 24 GB of GPU memory.

[1] Fine-tuned for 50 epochs.

ViT-Large

pretrain epochs   with pixel-norm   linear acc   fine-tuning acc
100               False             --           --
100               True              --           --

On 8 NVIDIA A40 GPUs, pre-training for 100 epochs takes about 34 hours; a batch size of 4096 needs about xx GB of GPU memory.
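The "with pixel-norm" column refers to normalizing each target patch by its own mean and variance before computing the reconstruction loss. A minimal sketch of that normalization, based on the MAE paper rather than this repo's exact code:

import torch

patches = torch.randn(8, 196, 768)               # (batch, num_patches, patch_dim)
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
target = (patches - mean) / (var + 1e-6).sqrt()  # per-patch normalized target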

Usage: Preparation

The code has been tested with CUDA 11.4 and PyTorch 1.8.2.

Notes:

  1. The batch size specified by -b is the total batch size across all GPUs on all nodes.
  2. The learning rate specified by --lr is the base lr (corresponding to a batch size of 256), and is adjusted by the linear lr scaling rule; see the sketch after this list.
  3. Only multi-GPU DistributedDataParallel training is supported in this repo; single-GPU and DataParallel training are not. The code is tailored to the multi-node setting and uses automatic mixed precision for pre-training by default.
  4. Only pre-training and fine-tuning have been tested.
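A minimal sketch of notes 1 and 2, where base_lr stands in for a hypothetical config value and 256 is the reference batch size:

def scaled_lr(base_lr: float, total_batch_size: int) -> float:
    # Linear lr scaling rule: the base lr corresponds to a batch size of 256,
    # so the effective lr grows linearly with the total batch size from -b.
    return base_lr * total_batch_size / 256

# With -b 4096 (the total across all GPUs on all nodes, per note 1):
print(scaled_lr(1.5e-4, 4096))  # 0.0024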

Usage: Self-supervised Pre-Training

Below are examples of MAE pre-training.

ViT-Base with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, batch 4096

python main_mae.py \
  -c cfgs/ViT-B16_ImageNet1K_pretrain.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

or

sh train_mae.sh

ViT-Large with 1-node (8-GPU, NVIDIA A40) pre-training, batch 2048

python main_mae.py \
  -c cfgs/ViT-L16_ImageNet1K_pretrain.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

Usage: End-to-End Fine-tuning ViT

Below are examples of MAE fine-tuning.

ViT-Base with 1-node (8-GPU, NVIDIA GeForce RTX 3090) training, batch 1024

python main_fintune.py \
  -c cfgs/ViT-B16_ImageNet1K_finetune.yaml \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  [your imagenet-folder with train and val folders]

ViT-Large with 2-node (16-GPU, 8 NVIDIA GeForce RTX 3090 + 8 NVIDIA A40) training, batch 512

python main_fintune.py \
  -c cfgs/ViT-L16_ImageNet1K_finetune.yaml \
  --multiprocessing-distributed --world-size 2 --rank 0 \
  [your imagenet-folder with train and val folders]

On another node, run the same command with --rank 1.

Note:

  1. We use --resume rather than the --finetune option from the DeiT repo, since the latter trains in eval mode. When loading the pre-trained model, call model_without_ddp.load_state_dict(checkpoint['model'], strict=False), as in the sketch below.
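A hedged sketch of that loading step; the torchvision backbone and checkpoint path are illustrative stand-ins, not this repo's exact code:

import torch
import torchvision

model = torchvision.models.vit_b_16()                         # stand-in ViT backbone
ckpt = torch.load('checkpoint.pth.tar', map_location='cpu')   # hypothetical file
# strict=False skips pre-training-only keys (e.g. the MAE decoder).
msg = model.load_state_dict(ckpt['model'], strict=False)
print(msg.missing_keys, msg.unexpected_keys)                  # inspect what was skipped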

[TODO] Usage: Linear Classification

By default, we use momentum-SGD and a batch size of 1024 for linear classification on frozen features/weights. This can be done with a single 8-GPU node.

python main_lincls.py \
  -a [architecture] --lr [learning rate] \
  --dist-url 'tcp://localhost:10001' \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --pretrained [your checkpoint path]/[your checkpoint file].pth.tar \
  [your imagenet-folder with train and val folders]
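For intuition, a minimal sketch of linear probing under the setup above (frozen backbone, momentum-SGD on a fresh linear head); the torchvision backbone and hyperparameters are illustrative stand-ins:

import torch
import torchvision

model = torchvision.models.vit_b_16()        # stand-in for the pre-trained ViT
for p in model.parameters():                 # freeze all backbone weights
    p.requires_grad = False
model.heads = torch.nn.Linear(768, 1000)     # fresh, trainable linear classifier

optimizer = torch.optim.SGD(model.heads.parameters(), lr=0.1, momentum=0.9)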

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Citation

If you use the code of this repo, please cite the original paper and this repo:

@Article{he2021mae,
  author  = {Kaiming He* and Xinlei Chen* and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
  title   = {Masked Autoencoders Are Scalable Vision Learners},
  journal = {arXiv preprint arXiv:2111.06377},
  year    = {2021},
}
@misc{yang2021maepriv,
  author       = {Lu Yang* and Pu Cao* and Yang Nie and Qing Song},
  title        = {MAE-priv},
  howpublished = {\url{https://github.com/BUPT-PRIV/MAE-priv}},
  year         = {2021},
}