A set of tests for evaluating large-scale algorithms that compute Wasserstein-2 transport maps.

Overview

Continuous Wasserstein-2 Benchmark

This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark (paper on arxiv) by Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon, Alexander Filippov and Evgeny Burnaev.

The repository contains a set of continuous benchmark pairs of measures for testing optimal transport solvers with the quadratic cost (Wasserstein-2 distance), together with the code of the solvers and of their evaluation.
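Why a benchmark pair has a known answer: by Brenier's theorem, if P is absolutely continuous and ψ is convex, then ∇ψ is the W2-optimal map from P to the push-forward Q = ∇ψ#P. The minimal NumPy sketch below uses the simplest such pair, assuming a Gaussian input P = N(0, I) and a quadratic potential ψ(x) = ½ xᵀAx with A symmetric positive definite. It is a hypothetical toy pair, not one of the repository's benchmark measures. Here the true map is T*(x) = Ax and W2²(P, Q) = E‖x − Ax‖² = tr((I − A)²), which a Monte Carlo estimate should reproduce.

```python
import numpy as np


def toy_gaussian_pair(dim, seed=0):
    """Hypothetical toy pair: P = N(0, I) and T*(x) = A x.

    A is symmetric positive definite, so T* is the gradient of the convex
    potential psi(x) = 0.5 * x^T A x and is W2-optimal from P to Q = N(0, A^2).
    """
    rng = np.random.default_rng(seed)
    M = rng.standard_normal((dim, dim))
    A = M @ M.T / dim + np.eye(dim)  # symmetric, eigenvalues >= 1
    return A


def true_map(A, X):
    # Row-wise T*(x) = A x for X of shape (n, dim); valid because A is symmetric.
    return X @ A


def w2_squared_exact(A):
    # E ||x - A x||^2 for x ~ N(0, I) equals trace((I - A)^2).
    D = np.eye(A.shape[0]) - A
    return float(np.trace(D @ D))


def w2_squared_monte_carlo(A, n=200_000, seed=1):
    # Estimate E ||x - T*(x)||^2 from samples of P.
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, A.shape[0]))
    return float(np.mean(np.sum((X - true_map(A, X)) ** 2, axis=1)))


A = toy_gaussian_pair(dim=4)
print(w2_squared_exact(A), w2_squared_monte_carlo(A))
```

The two printed numbers should agree up to Monte Carlo error; the repository's pairs replace the quadratic potential with trained, non-quadratic convex potentials, for which no closed form is available.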

Citation

@article{korotin2021neural,
  title={Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark},
  author={Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin and Filippov, Alexander and Burnaev, Evgeny},
  journal={arXiv preprint arXiv:2106.01954},
  year={2021}
}

Pre-requisites

The implementation is GPU-based. A single GPU (~GTX 1080 Ti) is enough to run each experiment. Tested with

torch==1.3.0 torchvision==0.4.1

The code might not run as intended in newer torch versions.
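Because the code was only tested with the versions above, it can help to check the installed versions before running a notebook. Below is a small sketch using the standard-library importlib.metadata. The helper name untested_packages is hypothetical, and the check only compares version strings; it does not detect GPU or CUDA problems.

```python
import warnings
from importlib import metadata

# Versions the README reports as tested
TESTED = {"torch": "1.3.0", "torchvision": "0.4.1"}


def untested_packages(tested=TESTED):
    """Return {package: installed version or None} for every package whose
    installed version differs from the tested one (None = not installed)."""
    mismatches = {}
    for pkg, wanted in tested.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local version labels such as "+cu100"
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[pkg] = installed
    return mismatches


if __name__ == "__main__":
    for pkg, installed in untested_packages().items():
        warnings.warn(f"{pkg}: installed {installed}, tested with {TESTED[pkg]}")
```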

Related repositories

Loading Benchmark Pairs

from src import map_benchmark as mbm

# Load benchmark pair for dimension 16 (2, 4, ..., 256)
benchmark = mbm.Mix3ToMix10Benchmark(16)
# OR load 'Early' images benchmark pair ('Early', 'Mid', 'Late')
# benchmark = mbm.CelebA64Benchmark('Early')

# Sample 32 random points from the benchmark measures
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)

# Compute the true forward map for points X
# X must track gradients: the true map is obtained by differentiating a potential
X.requires_grad_(True)
Y_true = benchmark.map_fwd(X, nograd=True)
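To compare a solver's fitted map with Y_true quantitatively, the paper uses the unexplained variance percentage, L2-UVP(T̂) = 100 · ‖T̂ − T*‖²_{L²(P)} / Var(Q) %, where Var(Q) is the total variance of the output measure. Below is a minimal NumPy sketch, assuming the tensors have already been converted to arrays of shape (n, dim). l2_uvp is a hypothetical helper, not a function from src/.

```python
import numpy as np


def l2_uvp(Y_pred, Y_true, Y_output):
    """Unexplained variance percentage of a fitted map (hypothetical helper).

    Y_pred:   fitted map applied to samples X of the input measure, shape (n, dim)
    Y_true:   true map applied to the same X, shape (n, dim)
    Y_output: samples of the output measure, used to estimate Var(Q), shape (m, dim)
    """
    # Monte Carlo estimate of ||T_hat - T*||^2 in L2(P)
    mse = np.mean(np.sum((Y_pred - Y_true) ** 2, axis=1))
    # Var(Q) = E ||y - E y||^2, i.e. the trace of the covariance of Q
    var_q = np.mean(np.sum((Y_output - Y_output.mean(axis=0)) ** 2, axis=1))
    return 100.0 * mse / var_q
```

With a fitted map producing a tensor Y_pred from X (hypothetical), one could call l2_uvp(Y_pred.detach().cpu().numpy(), Y_true.detach().cpu().numpy(), Y.detach().cpu().numpy()). A map that always outputs the mean of Q scores about 100%, so useful solvers should land well below that.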

Repository structure

All the experiments are provided as self-explanatory Jupyter notebooks (notebooks/). Auxiliary source code is in .py modules (src/). Continuous benchmark pairs are stored as .pt checkpoints (benchmarks/).

Evaluation of Existing Solvers

We provide all the code to evaluate existing dual OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper.

Testing Existing Solvers On High-Dimensional Benchmarks

  • notebooks/MM_test_hd_benchmark.ipynb -- testing [MM], [MMv2] solvers and their reversed versions
  • notebooks/MMv1_test_hd_benchmark.ipynb -- testing [MMv1] solver
  • notebooks/MM-B_test_hd_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/W2_test_hd_benchmark.ipynb -- testing [W2] solver and its reversed version
  • notebooks/QC_test_hd_benchmark.ipynb -- testing [QC] solver
  • notebooks/LS_test_hd_benchmark.ipynb -- testing [LS] solver
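When comparing solvers, the paper also reports how well the direction of the fitted displacement matches the true one, via a cosine similarity between T̂ − id and T* − id in L²(P). Below is a minimal NumPy sketch that follows our reading of that definition. The function name and the (n, dim) array layout are assumptions, not the notebooks' code.

```python
import numpy as np


def cos_similarity(X, Y_pred, Y_true):
    """Cosine similarity between displacements T_hat(x) - x and T*(x) - x,
    estimated over samples X of the input measure (hypothetical helper)."""
    d_pred = Y_pred - X
    d_true = Y_true - X
    # L2(P) inner product and norms, estimated by sample means
    inner = np.mean(np.sum(d_pred * d_true, axis=1))
    norm_pred = np.sqrt(np.mean(np.sum(d_pred ** 2, axis=1)))
    norm_true = np.sqrt(np.mean(np.sum(d_true ** 2, axis=1)))
    return inner / (norm_pred * norm_true)
```

Unlike L2-UVP, this score ignores the magnitude of the displacement: a map that moves points in the right direction but by the wrong amount still scores 1.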

Testing Existing Solvers On Images Benchmark Pairs (CelebA 64x64 Aligned Faces)

  • notebooks/MM_test_images_benchmark.ipynb -- testing [MM] solver and its reversed version
  • notebooks/W2_test_images_benchmark.ipynb -- testing [W2] solver
  • notebooks/MM-B_test_images_benchmark.ipynb -- testing [MM-B] solver
  • notebooks/QC_test_images_benchmark.ipynb -- testing [QC] solver

[LS], [MMv2], [MMv1] solvers are not considered in this experiment.

Generative Modeling by Using Existing Solvers to Compute Loss

Warning: training may take several days before achieving reasonable FID scores!

  • notebooks/MM_test_image_generation.ipynb -- generative modeling by [MM] solver or its reversed version
  • notebooks/W2_test_image_generation.ipynb -- generative modeling by [W2] solver

For [QC] solver we used the code from the official WGAN-QC repo.

Training Benchmark Pairs From Scratch

This code is provided for completeness. It is not intended for retraining the existing benchmark pairs, but it can serve as a basis for training new pairs on new datasets. High-dimensional benchmark pairs can be trained from scratch. Training images benchmark pairs requires generator network checkpoints; we used the WGAN-QC model to provide such checkpoints.

  • notebooks/W2_train_hd_benchmark.ipynb -- training high-dimensional benchmark pairs by [W2] solver
  • notebooks/W2_train_images_benchmark.ipynb -- training images benchmark pairs by [W2] solver
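The idea behind a benchmark pair is that the output measure is the push-forward of the input measure by the gradient of a convex potential, so the optimal map is known by construction. The minimal NumPy sketch below illustrates that construction with a hand-written convex potential ψ(x) = ½‖x‖² + logsumexp(Wx + b), whose gradient is ∇ψ(x) = x + Wᵀ softmax(Wx + b). This potential, W, b and n_components are hypothetical stand-ins for the trained networks used by the notebooks, and N(0, I) is an assumed input measure.

```python
import numpy as np


def make_potential(dim, n_components=8, seed=0):
    # Hypothetical toy convex potential: psi(x) = 0.5 ||x||^2 + logsumexp(W x + b)
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((n_components, dim))
    b = rng.standard_normal(n_components)
    return W, b


def psi(W, b, X):
    # Potential values for each row of X, with a numerically stable logsumexp
    Z = X @ W.T + b
    zmax = Z.max(axis=1, keepdims=True)
    lse = zmax[:, 0] + np.log(np.exp(Z - zmax).sum(axis=1))
    return 0.5 * np.sum(X ** 2, axis=1) + lse


def grad_psi(W, b, X):
    # grad psi(x) = x + W^T softmax(W x + b), applied row-wise
    Z = X @ W.T + b
    S = np.exp(Z - Z.max(axis=1, keepdims=True))
    S /= S.sum(axis=1, keepdims=True)
    return X + S @ W


def sample_pair(W, b, n, seed=1):
    # X ~ P = N(0, I); Y = grad psi(X) ~ Q, and grad psi is W2-optimal from P to Q
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, W.shape[1]))
    return X, grad_psi(W, b, X)
```

Since ψ is convex, ∇ψ is monotone (here even strongly monotone, because of the ½‖x‖² term), which is what makes it the W2-optimal map between the sampled measures.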

Credits

Owner: Alexander, PhD Student (Computer Science) at Skolkovo University of Science and Technology (Moscow, Russia)