PyTorch META-DATASET (Few-shot classification benchmark)

Overview

PyTorch META-DATASET (Few-shot classification benchmark)

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, allowing to repeat the same random series of episodes if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained using this repo thanks to an enhanced way of resizing images. More details in the paper.

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (by default set to pytorch).

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

Table of contents

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

The present code was developped and tested in Python 3.8. The list of requirements is provided in requirements.txt:

pip install -r requirements.txt

1.2 Data

To download the META-DATASET, please follow the details instructions provided at meta-dataset to obtain the .tfrecords converted data. Once done, make sure all converted dataset are in a single folder, and execute the following script to produce index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide trained Resnet-18 and WRN-2810 models on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
Finetune Resnet-18 59.8 60.5 63.5 80.6 80.9 61.5 45.2 91.1 55.1 41.8 64.0
ProtoNet Resnet-18 48.2 46.7 44.6 53.8 70.3 45.1 38.5 82.4 42.2 38.0 51.0
SimpleShot Resnet-18 60.0 54.2 55.9 78.6 77.8 57.4 49.2 90.3 49.6 44.2 61.7
Transductive methods Architecture ILSVRC Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Signs MSCOCO Mean
BD-CSPN Resnet-18 60.5 54.4 55.2 80.9 77.9 57.3 50.0 91.7 47.8 43.9 62.0
TIM-GD Resnet-18 63.6 65.6 66.4 85.6 84.7 65.8 57.5 95.6 65.2 50.9 70.1

See Sect. 1.4 and 1.5 to reproduce these results.

1.4 Train models from scratch (optional)

In order to train you model from scratch, execute scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

method is to be chosen among all method specific config files in config/, architecture in ['resnet18', 'wideres2810'] and dataset among all datasets (as named by the META-DATASET converted folders). Note that the hierarchy of arguments passed to src/train.py and src/eval.py is the following: base_config < method_config < opts arguments.

Mutiprocessing : This code supports distributed training. To leverage this feature, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once trained (or once pre-trained models downloaded), you can evaluate your model on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in results/ / where corresponds to a unique hash number of the config (you can only get the same result folder iff all hyperparameters are the same).

2. Visualization of results

2.1 Training metrics

During training, training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. Then, you can use the src/plot.py to plot these metrics (even during training).

Example 1: Plot the metrics of the standard (=non episodic) resnet-18 on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 trained on ImageNet

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method specific metrics are plotted in real-time (versus test iterations) and averaged over test epidodes, which can allow you to track unexpected behavior easily. Such metrics are implemented in src/metrics/, and the choice of which metric to plot is specificied through the eval_metrics option in the method .yaml config file. An example with TIM method is provided below.

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant optons can be found in the base.yaml file, in the EVAL-VISU section.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Add import in src/methods/__init__.py

Step 3: Add your method .yaml config file including the required options episodic_training and method (name of the class corresponding to your method). Also make sure that if your method performs test-time optimization, you also properly set the option iter that specifies the number of optimization steps performed at inference (this argument is also used to plot the inference metrics, see section 2.2).

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods/pre-trained models, do make a pull-request.

5. Citation

If you find this repo useful for your research, please consider citing the following papers:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf ([email protected]).

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code and the author of open-source TFRecord reader for open sourcing an awesome Pytorch-compatible TFRecordReader ! Also big thanks to @hkervadec for his thorough code review !

Owner
Malik Boudiaf
Malik Boudiaf
A general, feasible, and extensible framework for classification tasks.

Pytorch Classification A general, feasible and extensible framework for 2D image classification. Features Easy to configure (model, hyperparameters) T

Eugene 26 Nov 22, 2022
The first dataset on shadow generation for the foreground object in real-world scenes.

Object-Shadow-Generation-Dataset-DESOBA Object Shadow Generation is to deal with the shadow inconsistency between the foreground object and the backgr

BCMI 105 Dec 30, 2022
Implementation of "Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification"

hypergraph_reid Implementation of "Learning Multi-Granular Hypergraphs for Video-Based Person Re-Identification" If you find this help your research,

62 Dec 21, 2022
⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.

Optimized Einsum Optimized Einsum: A tensor contraction order optimizer Optimized einsum can significantly reduce the overall execution time of einsum

Daniel Smith 653 Dec 30, 2022
OSLO: Open Source framework for Large-scale transformer Optimization

O S L O Open Source framework for Large-scale transformer Optimization What's New: December 21, 2021 Released OSLO 1.0. What is OSLO about? OSLO is a

TUNiB 280 Nov 24, 2022
This package contains a PyTorch Implementation of IB-GAN of the submitted paper in AAAI 2021

The PyTorch implementation of IB-GAN model of AAAI 2021 This package contains a PyTorch implementation of IB-GAN presented in the submitted paper (IB-

Insu Jeon 9 Mar 30, 2022
1st ranked 'driver careless behavior detection' for AI Online Competition 2021, hosted by MSIT Korea.

2021AICompetition-03 본 repo 는 mAy-I Inc. 팀으로 참가한 2021 인공지능 온라인 경진대회 중 [이미지] 운전 사고 예방을 위한 운전자 부주의 행동 검출 모델] 태스크 수행을 위한 레포지토리입니다. mAy-I 는 과학기술정보통신부가 주최하

Junhyuk Park 9 Dec 01, 2022
GANTheftAuto is a fork of the Nvidia's GameGAN

Description GANTheftAuto is a fork of the Nvidia's GameGAN, which is research focused on emulating dynamic game environments. The early research done

Harrison 801 Dec 27, 2022
Retrieval.pytorch - The code we used in [2020 DIGIX]

Retrieval.pytorch - The code we used in [2020 DIGIX]

Guo-Hua Wang 2 Feb 07, 2022
CFC-Net: A Critical Feature Capturing Network for Arbitrary-Oriented Object Detection in Remote Sensing Images

CFC-Net This project hosts the official implementation for the paper: CFC-Net: A Critical Feature Capturing Network for Arbitrary-Oriented Object Dete

ming71 55 Dec 12, 2022
A modular application for performing anomaly detection in networks

Deep-Learning-Models-for-Network-Annomaly-Detection The modular app consists for mainly three annomaly detection algorithms. The system supports model

Shivam Patel 1 Dec 09, 2021
A fast implementation of bss_eval metrics for blind source separation

fast_bss_eval Do you have a zillion BSS audio files to process and it is taking days ? Is your simulation never ending ? Fear no more! fast_bss_eval i

Robin Scheibler 99 Dec 13, 2022
Pixel-Perfect Structure-from-Motion with Featuremetric Refinement (ICCV 2021, Oral)

Pixel-Perfect Structure-from-Motion (ICCV 2021 Oral) We introduce a framework that improves the accuracy of Structure-from-Motion by refining keypoint

Computer Vision and Geometry Lab 831 Dec 29, 2022
Repository for training material for the 2022 SDSC HPC/CI User Training Course

hpc-training-2022 Repository for training material for the 2022 SDSC HPC/CI Training Series HPC/CI Training Series home https://www.sdsc.edu/event_ite

sdsc-hpc-training-org 21 Jul 27, 2022
Dataset and Code for the paper "DepthTrack: Unveiling the Power of RGBD Tracking" (ICCV2021), and "Depth-only Object Tracking" (BMVC2021)

DeT and DOT Code and datasets for "DepthTrack: Unveiling the Power of RGBD Tracking" (ICCV2021) "Depth-only Object Tracking" (BMVC2021) @InProceedings

Yan Song 55 Dec 15, 2022
Company clustering with K-means/GMM and visualization with PCA, t-SNE, using SSAN relation extraction

RE results graph visualization and company clustering Installation pip install -r requirements.txt python -m nltk.downloader stopwords python3.7 main.

Jieun Han 1 Oct 06, 2022
Lightwood is Legos for Machine Learning.

Lightwood is like Legos for Machine Learning. A Pytorch based framework that breaks down machine learning problems into smaller blocks that can be glu

MindsDB Inc 312 Jan 08, 2023
A highly efficient, fast, powerful and light-weight anime downloader and streamer for your favorite anime.

AnimDL - Download & Stream Your Favorite Anime AnimDL is an incredibly powerful tool for downloading and streaming anime. Core features Abuses the dev

KR 759 Jan 08, 2023
Robustness between the worst and average case

Robustness between the worst and average case A repository that implements intermediate robustness training and evaluation from the NeurIPS 2021 paper

CMU Locus Lab 16 Dec 02, 2022
Pytorch implementation of paper: "NeurMiPs: Neural Mixture of Planar Experts for View Synthesis"

NeurMips: Neural Mixture of Planar Experts for View Synthesis This is the official repo for PyTorch implementation of paper "NeurMips: Neural Mixture

James Lin 101 Dec 13, 2022