Prototypical Networks for Few-shot Learning in PyTorch

A simple alternative implementation of Prototypical Networks for Few-shot Learning (paper, code) in PyTorch.

Prototypical Networks

As described in the reference paper, Prototypical Networks are trained to embed sample features in a vector space. At each episode (iteration), a number of samples from a subset of classes are selected and sent through the model. For each class c in the subset, the features of n_support samples are used to estimate that class's prototype (their barycentre in the vector space); the distances between the remaining n_query samples and their class barycentre are then minimized.
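The episode described above can be written compactly as follows. This is a sketch in the notation of the reference paper: f_φ is the embedding model, S_k is the support set of class k, and d is a distance function (squared Euclidean distance is assumed here).

```latex
% Prototype of class k: mean of the embedded support samples S_k
c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i)

% Probability that a query sample x belongs to class k
p_\phi(y = k \mid x) = \frac{\exp\left(-d\left(f_\phi(x), c_k\right)\right)}{\sum_{k'} \exp\left(-d\left(f_\phi(x), c_{k'}\right)\right)}

% Loss minimized for a query sample x with true class k
J(\phi) = -\log p_\phi(y = k \mid x)
```

Minimizing J pulls each query embedding towards the barycentre of its own class and away from the other prototypes of the episode.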

t-SNE

After training, you can compute the t-SNE embedding of the features generated by the model (not done in this repo; more information about t-SNE here). The following is a sample as shown in the paper.

(Figure: t-SNE visualization from the reference paper.)

Omniglot Dataset

Kudos to @ludc for his contribution: https://github.com/pytorch/vision/pull/46. We will switch to the official dataset once it is added to torchvision, provided that does not require big changes to the code.

Dataset splits

We implemented the Vinyals splitting method as in [Matching Networks for One Shot Learning]. That should be the same method used in the paper (in fact, the split files were downloaded from the "official" repo). We then apply the same rotations described there. This way, results obtained by running this code should be comparable with those reported in the reference paper.

Prototypical Batch Sampler

As described in its PyDoc, this class generates the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of labels for the dataset; the sampler then infers the total number of classes and creates a set of indexes for each class in the dataset. At each episode, the sampler selects n_classes random classes and returns n_support + n_query sample indexes for each of the selected classes.
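The selection step can be sketched in plain Python as follows. This is a minimal illustration rather than the repository's sampler class: the function name `episode_indexes` and its parameters are hypothetical, and it uses only the standard library instead of PyTorch.

```python
import random
from collections import defaultdict


def episode_indexes(labels, n_classes, n_samples, rng=random):
    """Return the dataset indexes for one episode (hypothetical helper).

    labels    -- list with the label of every dataset sample
    n_classes -- number of classes drawn for the episode
    n_samples -- samples drawn per class (n_support + n_query)
    """
    # Group the dataset indexes by class label.
    indexes_per_class = defaultdict(list)
    for index, label in enumerate(labels):
        indexes_per_class[label].append(index)

    # Pick the episode classes, then n_samples distinct indexes for each one.
    episode_classes = rng.sample(sorted(indexes_per_class), n_classes)
    batch = []
    for label in episode_classes:
        batch.extend(rng.sample(indexes_per_class[label], n_samples))
    return batch


# Toy dataset: 4 classes with 6 samples each.
labels = [c for c in range(4) for _ in range(6)]
# One episode with 2 classes and n_support + n_query = 3 + 2 samples per class.
batch = episode_indexes(labels, n_classes=2, n_samples=3 + 2,
                        rng=random.Random(7))
```

Each call yields one batch of indexes; a batch sampler would repeat this once per episode.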

Prototypical Loss

Computes the loss as in the cited paper; the implementation is mostly inspired by this code by one of its authors.

prototypical_loss.py implements both a loss function and a PyTorch-style loss class.

The function takes as input the model output for a batch, the samples' ground truths, and the number n_support of samples to use as support samples. The episode classes are inferred from the target list, and n_support samples are randomly extracted for each class to compute the class barycentres. The distance from each remaining sample's embedding to each class barycentre is then computed, and from these the probability of each sample belonging to each episode class. Finally, the loss is computed from the probabilities assigned to wrong predictions (for the query samples), as usual in classification problems.
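The computation can be sketched in dependency-free Python as follows. This is not the code in prototypical_loss.py: the name `prototypical_loss_sketch` is hypothetical, squared Euclidean distance is assumed, and for simplicity the first n_support samples of each class are taken as support instead of a random selection.

```python
import math


def prototypical_loss_sketch(embeddings, targets, n_support):
    """Mean negative log-probability of the true class over query samples."""
    classes = sorted(set(targets))
    dim = len(embeddings[0])

    prototypes, queries = {}, []
    for c in classes:
        idx = [i for i, t in enumerate(targets) if t == c]
        support, query = idx[:n_support], idx[n_support:]
        # Prototype = barycentre (mean) of the support embeddings.
        prototypes[c] = [
            sum(embeddings[i][d] for i in support) / len(support)
            for d in range(dim)
        ]
        queries.extend((i, c) for i in query)

    total = 0.0
    for i, true_class in queries:
        # Squared Euclidean distance from the query to every prototype.
        dists = {
            c: sum((x - p) ** 2 for x, p in zip(embeddings[i], prototypes[c]))
            for c in classes
        }
        # Log-softmax over the negative distances (log-sum-exp for stability).
        m = max(-d for d in dists.values())
        log_norm = m + math.log(sum(math.exp(-d - m) for d in dists.values()))
        log_prob_true = -dists[true_class] - log_norm
        total -= log_prob_true
    return total / len(queries)


# Two 1-D classes, one support and one query sample each.
loss = prototypical_loss_sketch([[0.0], [0.0], [10.0], [10.0]],
                                [0, 0, 1, 1], n_support=1)
```

In the toy call each query coincides with its own class prototype and lies far from the other one, so the loss is close to zero.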

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: the root directory where the dataset is stored, defaults to '../dataset'

  • nepochs: number of epochs to train for, defaults to 100

  • learning_rate: learning rate for the model, defaults to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step, defaults to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma, defaults to 0.5

  • iterations: number of episodes per epoch, defaults to 100

  • classes_per_it_tr: number of random classes per episode for training, defaults to 60

  • num_support_tr: number of samples per class to use as support for training, defaults to 5

  • num_query_tr: number of samples per class to use as query for training, defaults to 5

  • classes_per_it_val: number of random classes per episode for validation, defaults to 5

  • num_support_val: number of samples per class to use as support for validation, defaults to 5

  • num_query_val: number of samples per class to use as query for validation, defaults to 15

  • manual_seed: input for the manual seed initializations, defaults to 7

  • cuda: enables cuda (store True)

Running the command without arguments will train the model with the default hyperparameter values (producing the default-parameter results reported in the Performances section).
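To get a feel for the defaults, the sketch below works out the learning rate schedule and the training episode size they imply. It assumes the StepLR scheduler is stepped once per epoch (StepLR multiplies the learning rate by gamma every step-size epochs); the helper `step_lr` is hypothetical and only mirrors that rule.

```python
def step_lr(epoch, learning_rate=0.001, lr_scheduler_step=20,
            lr_scheduler_gamma=0.5):
    """Learning rate in effect during a given (0-based) epoch."""
    return learning_rate * lr_scheduler_gamma ** (epoch // lr_scheduler_step)


# Samples per training episode with the default options:
# classes_per_it_tr * (num_support_tr + num_query_tr)
train_episode_size = 60 * (5 + 5)

# The rate halves every 20 epochs, so it is halved four times by epoch 99.
for epoch in (0, 19, 20, 40, 99):
    print(epoch, step_lr(epoch))
```

With these defaults, each training episode therefore sends 600 samples through the model.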

Performances

We are trying to reproduce the reference paper's performance; we will update our best results here.

Model            1-shot (5-way Acc.)   5-shot (5-way Acc.)   1-shot (20-way Acc.)   5-shot (20-way Acc.)
Reference Paper  98.8%                 99.7%                 96.0%                  98.9%
This repo        98.5%**               99.6%*                95.1%°                 98.6%°°

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

Helpful links

.bib citation

Cite the paper as follows (copied from arXiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License.

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).
