PyTorch Implementation for Deep Metric Learning Pipelines

Overview

Easily Extendable Basic Deep Metric Learning Pipeline

Karsten Roth, Biagio Brattoli

When using this repo in any academic work, please provide a reference to

@misc{roth2020revisiting,
    title={Revisiting Training Strategies and Generalization Performance in Deep Metric Learning},
    author={Karsten Roth and Timo Milbich and Samarth Sinha and Prateek Gupta and Björn Ommer and Joseph Paul Cohen},
    year={2020},
    eprint={2002.08473},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Based on an extended version of this repo, we have created a thorough comparison and evaluation of Deep Metric Learning:

https://arxiv.org/abs/2002.08473

The newly released code can be found here: https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch

It contains more criteria, miners, metrics and logging options!


For usage, see section 3; for results, see section 4.

1. Overview

This repository contains a full, easily extendable pipeline to test and implement current and new deep metric learning methods. For referencing and testing, this repo contains implementations/dataloaders for:

Loss Functions

Sampling Methods

Datasets

Architectures

NOTE: PKU Vehicle-ID is optional, because the dataset cannot be downloaded directly and requires special licensing. However, if the dataset is available (in the structure shown in section 2.2), it can be used directly.


1.1 Related Repos:


2. Repo & Dataset Structure

2.1 Repo Structure

Repository
│   ### General Files
│   README.md
│   requirements.txt
│   installer.sh
│
│   ### Main Scripts
│   Standard_Training.py    (main training script)
│   losses.py               (collection of loss and sampling impl.)
│   datasets.py             (dataloaders for all datasets)
│
│   ### Utility Scripts
│   auxiliaries.py          (set of useful utilities)
│   evaluate.py             (set of evaluation functions)
│
│   ### Network Scripts
│   netlib.py               (contains impl. for ResNet50)
│   googlenet.py            (contains impl. for GoogLeNet)
│
├───Training Results (generated during training)
│    │   e.g. cub200/Training_Run_Name
│    │   e.g. cars196/Training_Run_Name
│
└───Datasets (place datasets here if you do not want to set paths manually)
     │   cub200
     │   cars196
     │   online_products
     │   in-shop
     │   vehicle_id

2.2 Dataset Structures

CUB200-2011/CARS196

cub200/cars196
└───images
|    └───001.Black_footed_Albatross
|           │   Black_Footed_Albatross_0001_796111
|           │   ...
|    ...

Online Products

online_products
└───images
|    └───bicycle_final
|           │   111085122871_0.jpg
|    ...
|
└───Info_Files
|    │   bicycle.txt
|    │   ...

In-Shop Clothes

in-shop
└─img
|    └─MEN
|         └─Denim
|               └─id_00000080
|                  │   01_1_front.jpg
|                  │   ...
|               ...
|         ...
|    ...
|
└─Eval
|  │   list_eval_partition.txt

PKU Vehicle ID

vehicle_id
└───image
|     │   <img>.jpg
|     |   ...
|     
└───train_test_split
|     |   test_list_800.txt
|     |   ...

3. Using the Pipeline

[1.] Requirements

The pipeline is built around Python 3 (e.g. installed via Miniconda, https://conda.io/miniconda.html) and PyTorch 1.0.0/1. It has been tested with CUDA 8 and CUDA 9.

To install the required libraries, either directly check requirements.txt or create a conda environment:

conda create -n <Env_Name> python=3.6

Activate it

conda activate <Env_Name>

and run

bash installer.sh

Note that k-means and nearest-neighbour computations use the faiss library, which can move these computations to the GPU if speed is needed. In most cases, however, faiss is fast enough that computing the evaluation metrics is not a bottleneck.
NOTE: To use standard sklearn instead of faiss, simply import auxiliaries_nofaiss.py in place of auxiliaries.py.
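For reference, Recall@k counts a query as correct if at least one of its k nearest neighbours in embedding space (excluding the query itself) shares its class label. The sketch below is a brute-force, pure-Python illustration of that standard definition, assuming Euclidean distance. It is not the faiss- or sklearn-based code in auxiliaries.py, and the function name `recall_at_k` is hypothetical.

```python
import math


def recall_at_k(embeddings, labels, k_vals):
    """Fraction of queries whose k nearest neighbours (query excluded)
    contain at least one sample with the same label."""
    n = len(embeddings)
    # Brute force: sort all other samples by Euclidean distance to each query.
    neighbour_lists = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        others.sort(key=lambda j: math.dist(embeddings[i], embeddings[j]))
        neighbour_lists.append(others)
    recalls = {}
    for k in k_vals:
        hits = sum(
            any(labels[j] == labels[i] for j in neighbour_lists[i][:k])
            for i in range(n)
        )
        recalls[k] = hits / n
    return recalls
```

Brute-force search is quadratic in the number of samples, which is why a library such as faiss is preferable on the larger datasets.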

[2.] Exemplary Runs

The main script is Standard_Training.py. If run without input arguments, it trains ResNet50 on CUB200-2011 with margin loss and distance sampling.
Otherwise, the following flags suffice to train with different losses, sampling methods, architectures and datasets:

python Standard_Training.py --dataset <dataset> --loss <loss> --sampling <sampling> --arch <arch> --k_vals <k_vals> --embed_dim <embed_dim>

The following flags are available:

  • <dataset> <- cub200, cars196, online_products, in-shop, vehicle_id
  • <loss> <- marginloss, triplet, npair, proxynca
  • <sampling> <- distance, semihard, random, npair
  • <arch> <- resnet50, googlenet
  • <k_vals> <- List of Recall @ k values to evaluate on, e.g. 1 2 4 8
  • <embed_dim> <- Network embedding dimension. Default: 128 for ResNet50, 512 for GoogLeNet.
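As a sketch of how these flags fit together, the argparse subset below mirrors the documented choices. It is not the actual parser in Standard_Training.py: the defaults for dataset, loss, sampling and architecture follow the no-argument behaviour described above, the `--k_vals` default of 1 2 4 8 is an assumption taken from the example, and the helper `parse_flags` is hypothetical.

```python
import argparse


def parse_flags(argv=None):
    """Hypothetical subset of the training flags; not the repo's actual parser."""
    parser = argparse.ArgumentParser(description="Deep metric learning flags (sketch)")
    parser.add_argument("--dataset", default="cub200",
                        choices=["cub200", "cars196", "online_products", "in-shop", "vehicle_id"])
    parser.add_argument("--loss", default="marginloss",
                        choices=["marginloss", "triplet", "npair", "proxynca"])
    parser.add_argument("--sampling", default="distance",
                        choices=["distance", "semihard", "random", "npair"])
    parser.add_argument("--arch", default="resnet50", choices=["resnet50", "googlenet"])
    # Recall@k values are given as a space-separated list, e.g. --k_vals 1 2 4 8.
    parser.add_argument("--k_vals", nargs="+", type=int, default=[1, 2, 4, 8])
    # None means "use the architecture-specific default" (resolved below).
    parser.add_argument("--embed_dim", type=int, default=None)
    opt = parser.parse_args(argv)
    if opt.embed_dim is None:
        opt.embed_dim = 128 if opt.arch == "resnet50" else 512
    return opt
```

For example, `parse_flags(["--arch", "googlenet"])` yields an embedding dimension of 512, matching the GoogLeNet default listed above.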

For all other training-specific arguments (e.g. batch size, number of training epochs, ...), refer to the input arguments in Standard_Training.py.

NOTE: To use a different learning rate for the final linear embedding layer, set the flag --fc_lr_mul to a non-zero value (e.g. 10, as is done in various implementations).
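One common way to realise such a multiplier in PyTorch is with optimizer parameter groups, each carrying its own `lr`. The framework-free helper below builds those groups from `(name, parameter)` pairs such as those yielded by `model.named_parameters()`. The substring `last_linear` used to identify the embedding layer and the helper name `build_param_groups` are assumptions for illustration; check the actual layer name in netlib.py.

```python
def build_param_groups(named_params, base_lr, fc_lr_mul, fc_key="last_linear"):
    """Split (name, parameter) pairs into optimizer parameter groups.

    fc_lr_mul == 0 disables the multiplier: all parameters share base_lr.
    Otherwise, parameters whose name contains `fc_key` get base_lr * fc_lr_mul.
    """
    named_params = list(named_params)
    if fc_lr_mul == 0:
        return [{"params": [p for _, p in named_params], "lr": base_lr}]
    backbone, head = [], []
    for name, param in named_params:
        (head if fc_key in name else backbone).append(param)
    groups = [
        {"params": backbone, "lr": base_lr},
        {"params": head, "lr": base_lr * fc_lr_mul},
    ]
    # Drop empty groups so the optimizer never receives a group without parameters.
    return [g for g in groups if g["params"]]
```

The returned list can be passed to a PyTorch optimizer, e.g. `torch.optim.Adam(build_param_groups(model.named_parameters(), 1e-5, 10))`, since torch optimizers accept a list of dicts with `params` plus per-group options such as `lr`.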

Finally, to choose the GPU and the name of the training folder in which network weights, sample recoveries and metrics are stored, set:

python Standard_Training.py --gpu <gpu_id> --savename <name_of_training_run>

If --savename is not set, a default name based on the starting date will be chosen.

To simply use standard parameters and get close to literature results (more or less, depending on seeds and overall training scheduling), refer to sample_training_runs.sh, which contains a list of executable one-liners.

[3.] Implementation Notes regarding Extendability:

To extend or test other sampling or loss methods, proceed as follows:

For Batch-based Sampling:
In losses.py, add the sampling method, which should act on a batch (and the respective set of labels), e.g.:

def new_sampling(self, batch, label, **additional_parameters): ...

If it needs to work with the existing losses, this function should return a list of tuples containing indices with respect to the batch, e.g. for sampling methods returning triplets:

return [(anchor_idx, positive_idx, negative_idx) for anchor_idx, positive_idx, negative_idx in zip(anchor_idxs, positive_idxs, negative_idxs)]

Also, don't forget to add a handle in Sampler.__init__().
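The sketch below illustrates both steps with a random triplet sampler and a minimal sampler stand-in whose `__init__` registers the handle. All names here (`random_triplet_sampling`, `SamplerSketch`, the `rng` keyword) are hypothetical; in losses.py the sampling functions are methods of the actual Sampler class, so adapt the signature accordingly.

```python
import random


def random_triplet_sampling(batch, labels, rng=None):
    """Return (anchor_idx, positive_idx, negative_idx) tuples indexing into the batch.

    `batch` is unused by this purely label-based sampler; distance-based
    samplers would use the embeddings it contains.
    """
    rng = rng or random.Random()
    triplets = []
    for anchor_idx, anchor_label in enumerate(labels):
        positives = [i for i, lab in enumerate(labels) if lab == anchor_label and i != anchor_idx]
        negatives = [i for i, lab in enumerate(labels) if lab != anchor_label]
        if not positives or not negatives:
            continue  # this anchor cannot form a valid triplet
        triplets.append((anchor_idx, rng.choice(positives), rng.choice(negatives)))
    return triplets


class SamplerSketch:
    """Hypothetical stand-in for the Sampler class: __init__ maps names to handles."""

    def __init__(self, method="random"):
        # Register every new sampling function here under its --sampling name.
        self.handles = {"random": random_triplet_sampling}
        if method not in self.handles:
            raise ValueError(f"Unknown sampling method: {method}")
        self.sample = self.handles[method]

    def give(self, batch, labels, **additional_parameters):
        return self.sample(batch, labels, **additional_parameters)
```

Anchors whose class has no second sample in the batch are skipped, so the number of returned triplets can be smaller than the batch size.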

For Data-specific Sampling:
To influence which data samples are used to generate the batches, edit BaseTripletDataset in datasets.py.
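A common data-specific strategy is class-balanced batching: each batch draws a fixed number of images from each of several randomly chosen classes, so every anchor has positives in the batch. Whether BaseTripletDataset does exactly this is not stated here; the generator below is a hypothetical sketch, and `classes_per_batch` and `samples_per_class` are illustrative parameter names.

```python
import random
from collections import defaultdict


def class_balanced_batches(labels, classes_per_batch, samples_per_class, rng=None):
    """Yield lists of dataset indices, `samples_per_class` from each of
    `classes_per_batch` classes; each class is used at most once per pass."""
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with enough images can fill their slot in a batch.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= samples_per_class]
    rng.shuffle(eligible)
    for start in range(0, len(eligible) - classes_per_batch + 1, classes_per_batch):
        batch = []
        for c in eligible[start:start + classes_per_batch]:
            batch.extend(rng.sample(by_class[c], samples_per_class))
        yield batch
```

With 10 classes of 4 images each, `classes_per_batch=2` and `samples_per_class=2`, one pass yields 5 batches of 4 indices. Skipping classes with too few images is a design choice of this sketch, not a requirement.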

For New Loss Functions:
Simply add a new class inheriting from torch.nn.Module; refer to the other loss variants for examples. In general, include an instance of the Sampler class, which provides sampled data tuples during the forward() pass via self.sampler_instance.give(batch, labels, **additional_parameters).
Finally, register the loss in the loss_select() function. Parameters can be passed using dictionary notation (see other examples), and if learnable parameters are added, include them in the to_optim list.
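To show this structure without depending on PyTorch, the sketch below mimics such a loss class in plain Python: the constructor stores a sampler instance, and `forward` asks it for index triplets and averages the triplet margin loss max(0, d(a, p) - d(a, n) + margin) with Euclidean distance. A real implementation would subclass torch.nn.Module and operate on tensors so autograd can compute gradients; the class name `TripletLossSketch` and the default margin of 0.2 are assumptions.

```python
import math


class TripletLossSketch:
    """Plain-Python stand-in for a loss module built around a sampler instance."""

    def __init__(self, sampler_instance, margin=0.2):
        self.sampler_instance = sampler_instance
        self.margin = margin

    def forward(self, batch, labels):
        # The sampler returns (anchor, positive, negative) indices into the batch.
        triplets = self.sampler_instance.give(batch, labels)
        if not triplets:
            return 0.0
        losses = [
            max(0.0, math.dist(batch[a], batch[p]) - math.dist(batch[a], batch[n]) + self.margin)
            for a, p, n in triplets
        ]
        return sum(losses) / len(losses)

    # Mirror nn.Module, where calling the instance runs forward().
    __call__ = forward
```

Triplets whose negative is already farther than the positive by more than the margin contribute zero, which is why the choice of sampler matters so much for this loss.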

[4.] Stored Data:

By default, the following files are saved:

Name_of_Training_Run
|  checkpoint.pth.tar    -> Contains the network state-dict.
|  hypa.pkl              -> Contains all training parameters as a pickle.
|                           Can be used directly to recreate the network.
|  log_train_Base.csv    -> Logged training data as CSV.
|  log_val_Base.csv      -> Logged test metrics as CSV.
|  Parameter_Info.txt    -> All parameters stored as a readable text file.
|  InfoPlot_Base.svg     -> Graphical summary of training/testing metrics progression.
|  sample_recoveries.png -> Sample recoveries for the best validation weights.
|                           Acts as a sanity check.

Sample Recoveries Note: Red denotes query images, while green shows the respective nearest neighbours.

Summary Plot Note: The header in the summary plot shows the best testing metrics over the whole run.
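Stored runs can be inspected with the Python standard library alone. The loader below is a sketch that assumes log_val_Base.csv starts with a header row and that hypa.pkl unpickles with the modules available in your environment; the function name `load_run` is hypothetical. Only unpickle files you created yourself, since pickle can execute arbitrary code.

```python
import csv
import pickle
from pathlib import Path


def load_run(run_dir):
    """Return (training parameters, validation log rows) for a finished run folder."""
    run_dir = Path(run_dir)
    # Only unpickle trusted files: pickle can execute arbitrary code.
    with open(run_dir / "hypa.pkl", "rb") as f:
        params = pickle.load(f)
    # Assumes the CSV has a header row; all values are read as strings.
    with open(run_dir / "log_val_Base.csv", newline="") as f:
        val_log = list(csv.DictReader(f))
    return params, val_log
```

Because csv.DictReader returns strings, convert metric columns with `float(...)` before plotting or comparing them.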

[5.] Additional Notes:

Finally, several flags may be of interest when examining individual runs:

--dist_measure: If set, the ratio of mean intraclass distances to mean interclass distances
                (measured via center-of-mass distances) is computed after each epoch and stored/plotted.
--grad_measure: If set, the average (absolute) gradients from the embedding layer to the last
                conv. layer are stored in a pickle file. This can be used to examine how features change over iterations.

For more details, refer to the respective classes in auxiliaries.py.
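Under one plausible reading of --dist_measure, the intraclass term is the mean distance of each embedding to its class centre (centre of mass) and the interclass term is the mean distance between pairs of class centres. The sketch below computes that ratio in plain Python; the exact averaging used in auxiliaries.py may differ, and `intra_inter_ratio` is a hypothetical name. Lower values indicate tighter, better-separated classes.

```python
import math
from collections import defaultdict
from itertools import combinations


def intra_inter_ratio(embeddings, labels):
    """Mean sample-to-centre distance divided by mean centre-to-centre distance.

    Requires at least two classes.
    """
    by_class = defaultdict(list)
    for emb, label in zip(embeddings, labels):
        by_class[label].append(emb)
    # Centre of mass per class: coordinate-wise mean of its embeddings.
    centres = {
        label: [sum(coords) / len(members) for coords in zip(*members)]
        for label, members in by_class.items()
    }
    intra = [math.dist(emb, centres[label]) for emb, label in zip(embeddings, labels)]
    inter = [math.dist(centres[a], centres[b]) for a, b in combinations(centres, 2)]
    return (sum(intra) / len(intra)) / (sum(inter) / len(inter))
```

For two classes with points at 0 and 2 and at 10 and 12 on a line, each point is 1 away from its centre and the centres are 10 apart, giving a ratio of 0.1.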


4. Results

These results are performance estimates obtained by running the respective commands in sample_training_runs.sh. Note that the learning rate scheduling might not be fully optimised, so these values should serve only as a reference, not as what can ultimately be achieved with more tweaking.

Note also that the results depend noticeably on the seed used.

CUB200

Architecture  Loss/Sampling     NMI   F1    R@1   R@2   R@4   R@8
ResNet50      Margin/Distance   68.2  38.7  63.4  74.9  86.0  90.4
ResNet50      Triplet/Softhard  66.2  35.5  61.2  73.2  82.4  89.5
ResNet50      NPair/None        65.4  33.8  59.0  71.3  81.1  88.8
ResNet50      ProxyNCA/None     68.1  38.1  64.0  75.4  84.2  90.5

Cars196

Architecture  Loss/Sampling     NMI   F1    R@1   R@2   R@4   R@8
ResNet50      Margin/Distance   67.2  37.6  79.3  87.1  92.1  95.4
ResNet50      Triplet/Softhard  64.4  32.4  75.4  84.2  90.1  94.1
ResNet50      NPair/None        62.3  30.1  69.5  80.2  87.3  92.1
ResNet50      ProxyNCA/None     66.3  35.8  80.0  87.2  91.8  95.1

Online Products

Architecture  Loss/Sampling     NMI   F1    R@1   R@10  R@100 R@1000
ResNet50      Margin/Distance   89.6  34.9  76.1  88.7  95.1  98.3
ResNet50      Triplet/Softhard  89.1  33.7  74.3  87.6  94.9  98.5
ResNet50      NPair/None        88.8  31.1  70.9  85.2  93.8  98.2

In-Shop Clothes

Architecture  Loss/Sampling     NMI   F1    R@1   R@10  R@20  R@30  R@50
ResNet50      Margin/Distance   88.2  27.7  84.5  96.1  97.4  97.9  98.5
ResNet50      Triplet/Semihard  89.0  30.8  83.9  96.3  97.6  98.4  98.8
ResNet50      NPair/None        88.0  27.6  80.9  95.0  96.6  97.5  98.2

NOTE:

  1. Regarding Vehicle-ID: Due to the number of test sets, the size of the training set and its limited public availability, results are not included for the time being.
  2. Regarding ProxyNCA for Online Products and In-Shop Clothes: Due to the high number of classes, the number of proxies required is too high for useful training (>10000 proxies).
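The NMI column measures agreement between a clustering of the test embeddings (computed with k-means, per the faiss note in section 3) and the ground-truth classes. The sketch below computes NMI from two label lists using the arithmetic-mean normalisation NMI = 2 I(Y; C) / (H(Y) + H(C)); which normalisation the evaluation code uses is an assumption here, and `nmi` is a hypothetical helper.

```python
import math
from collections import Counter


def nmi(true_labels, cluster_labels):
    """Normalized mutual information: 2 * I(Y; C) / (H(Y) + H(C))."""
    n = len(true_labels)
    joint = Counter(zip(true_labels, cluster_labels))
    count_y = Counter(true_labels)
    count_c = Counter(cluster_labels)

    def entropy(counts):
        return -sum((c / n) * math.log(c / n) for c in counts.values())

    # I(Y; C) = sum over (y, c) of p(y, c) * log(p(y, c) / (p(y) * p(c))).
    mutual_info = sum(
        (n_yc / n) * math.log((n_yc * n) / (count_y[y] * count_c[c]))
        for (y, c), n_yc in joint.items()
    )
    denom = entropy(count_y) + entropy(count_c)
    # Both labelings trivial (single cluster each): treat as perfect agreement.
    return 2 * mutual_info / denom if denom > 0 else 1.0
```

NMI is invariant to permuting cluster IDs, so a clustering that recovers the classes exactly scores 1.0 regardless of how k-means numbers its clusters.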

ToDo:

  • Fix versions in requirements.txt
  • Add Results for Implementations
  • Finalize Comments
  • Add Inception-BN
  • Add Lifted Structure Loss
Owner
Karsten Roth
PhD (IMPRS-IS, ELLIS) EML Tuebingen | prev. @VectorInstitute, @mila-iqia and @aws.