Pytorch implementation of MLP-Mixer with loading pre-trained models.

Overview

MLP-Mixer-Pytorch

PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained parameters.

Usage

import torch
import numpy as np
from mlp_mixer import MlpMixer

pretrain_model='./pretrain_models/imagenet21k_Mixer-B_16.npz'

model = MlpMixer(num_classes=10, 
                 num_blocks=12, 
                 patch_size=16, 
                 hidden_dim=768, 
                 tokens_mlp_dim=384, 
                 channels_mlp_dim=3072, 
                 image_size=224
                 )

# load official ImageNet pre-trained model:
model.load_from(np.load(pretrain_model))
print ('Finish loading the pre-trained model!')

num_param = sum(p.numel() for p in model.parameters()) / 1e6
print ('Total params.: %f M'%num_param)

pred = model(img)

Fine-tuning

Download the official pre-trained models at https://console.cloud.google.com/storage/mixer_models/.

Hypyer-parameters setting for better fine-tuning:

optim = torch.optim.SGD(param_list, 
                        lr=5e-4, 
                        weight_decay=1e-7,
                        momentum=0.9, 
                        nesterov=True
                        )
lr_schdlr = WarmupCosineLrScheduler(optim, 
                                    n_iters_all, 
                                    warmup_iter=0
                                    )

Using the pre-trained model to fine-tune MLP-Mixer can obtain remarkable improvements (e.g., +10% accuracy on a small dataset).

Note that we can also change the patch_size (e.g., patch_size=8) for inputs with different resolutions, but smaller patch_size may not always bring performance improvements.

Citation

@misc{tolstikhin2021mlpmixer,
      title={MLP-Mixer: An all-MLP Architecture for Vision}, 
      author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
      year={2021},
      eprint={2105.01601},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

  1. The implementation is based on the original paper and the official Tensorflow repo: https://github.com/google-research/vision_transformer.
  2. It also refers to the re-implementation repo: https://github.com/d-li14/mlp-mixer.pytorch.
Owner
Qiushi Yang
Student in CityU, Hong Kong. Interested in computer vision and machine learning.
Qiushi Yang
This package is for running the semantic SLAM algorithm using extracted planar surfaces from the received detection

Semantic SLAM This package can perform optimization of pose estimated from VO/VIO methods which tend to drift over time. It uses planar surfaces extra

Hriday Bavle 125 Dec 02, 2022
Code for ACM MM2021 paper "Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection"

CTDNet The PyTorch code for ACM MM2021 paper "Complementary Trilateral Decoder for Fast and Accurate Salient Object Detection" Requirements Python 3.6

CVTEAM 28 Oct 20, 2022
PyTorch implementation of our ICCV2021 paper: StructDepth: Leveraging the structural regularities for self-supervised indoor depth estimation

StructDepth PyTorch implementation of our ICCV2021 paper: StructDepth: Leveraging the structural regularities for self-supervised indoor depth estimat

SJTU-ViSYS 112 Nov 28, 2022
Imposter-detector-2022 - HackED 2022 Team 3IQ - 2022 Imposter Detector

HackED 2022 Team 3IQ - 2022 Imposter Detector By Aneeljyot Alagh, Curtis Kan, Jo

Joshua Ji 3 Aug 20, 2022
Projects of Andfun Yangon

AndFunYangon Projects of Andfun Yangon First Commit We can use gsearch.py to sea

Htin Aung Lu 1 Dec 28, 2021
Lightweight, Python library for fast and reproducible experimentation :microscope:

Steppy What is Steppy? Steppy is a lightweight, open-source, Python 3 library for fast and reproducible experimentation. Steppy lets data scientist fo

minerva.ml 134 Jul 10, 2022
Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch

DALL-E in Pytorch Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch. It will also contain CLIP for ranking the ge

Phil Wang 5k Jan 04, 2023
Learning hidden low dimensional dyanmics using a Generalized Onsager Principle and neural networks

OnsagerNet Learning hidden low dimensional dyanmics using a Generalized Onsager Principle and neural networks This is the original pyTorch implemenati

Haijun.Yu 3 Aug 24, 2022
Pytorch implementation of forward and inverse Haar Wavelets 2D

Pytorch implementation of forward and inverse Haar Wavelets 2D

Sergei Belousov 9 Oct 30, 2022
Structured Edge Detection Toolbox

################################################################### # # # Structure

Piotr Dollar 779 Jan 02, 2023
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Swin Transformer 1.4k Dec 30, 2022
Code Release for the paper "TriBERT: Full-body Human-centric Audio-visual Representation Learning for Visual Sound Separation"

TriBERT This repository contains the code for the NeurIPS 2021 paper titled "TriBERT: Full-body Human-centric Audio-visual Representation Learning for

UBC Computer Vision Group 8 Aug 31, 2022
Zero-Cost Proxies for Lightweight NAS

Zero-Cost-NAS Companion code for the ICLR2021 paper: Zero-Cost Proxies for Lightweight NAS tl;dr A single minibatch of data is used to score neural ne

SamsungLabs 108 Dec 20, 2022
Compute descriptors for 3D point cloud registration using a multi scale sparse voxel architecture

MS-SVConv : 3D Point Cloud Registration with Multi-Scale Architecture and Self-supervised Fine-tuning Compute features for 3D point cloud registration

42 Jul 25, 2022
Implementation of Restricted Boltzmann Machine (RBM) and its variants in Tensorflow

xRBM Library Implementation of Restricted Boltzmann Machine (RBM) and its variants in Tensorflow Installation Using pip: pip install xrbm Examples Tut

Omid Alemi 55 Dec 29, 2022
A unified framework to jointly model images, text, and human attention traces.

connect-caption-and-trace This repository contains the reference code for our paper Connecting What to Say With Where to Look by Modeling Human Attent

Meta Research 73 Oct 24, 2022
Implementation of paper: "Image Super-Resolution Using Dense Skip Connections" in PyTorch

SRDenseNet-pytorch Implementation of paper: "Image Super-Resolution Using Dense Skip Connections" in PyTorch (http://openaccess.thecvf.com/content_ICC

wxy 114 Nov 26, 2022
Supporting code for the Neograd algorithm

Neograd This repo supports the paper Neograd: Gradient Descent with a Near-Ideal Learning Rate, which introduces the algorithm "Neograd". The paper an

Michael Zimmer 12 May 01, 2022
InsightFace: 2D and 3D Face Analysis Project on MXNet and PyTorch

InsightFace: 2D and 3D Face Analysis Project on MXNet and PyTorch

Deep Insight 13.2k Jan 06, 2023