A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Python package for reference counting native pointers

refcount master: testing: This package is primarily for managing resources in native libraries, written for instance in C++, from Python. While it boi

CSIRO Hydroinformatics 2 Nov 03, 2022
Flask html response minifier

Flask-HTMLmin Minify flask text/html mime type responses. Just add MINIFY_HTML = True to your deployment config to minify HTML and text responses of y

Hamid Feizabadi 85 Dec 07, 2022
A country information finder module

A country information finder module

Fayas Noushad 3 Nov 28, 2021
This synchronizes my appearances with my calendar

Josh's Schedule Synchronizer Here's the "problem:" I use a Google Sheets spreadsheet to maintain all my public appearances.

Developer Advocacy 2 Oct 18, 2021
Spinning waffle from waffle shaped python code

waffle Spinning waffle from waffle shaped python code Based on a parametric curve: r(t) = 2 - 2*sin(t) + (sin(t)*abs(cos(t)))/(sin(t) + 1.4) projected

Viljar Femoen 5 Feb 14, 2022
A play store search module

A play store search module

Fayas Noushad 5 Dec 01, 2021
Exam assignment for Laboratory of Bioinformatics 2

Exam assignment for Laboratory of Bioinformatics 2 (Alma Mater University of Bologna, Master in Bioinformatics)

2 Oct 22, 2022
Anki cards generator for Leetcode

Leetcode Anki card generator Summary By running this script you'll be able to generate Anki cards with all the leetcode problems. I personally use it

Pavel Safronov 150 Dec 25, 2022
Search and Find Jobs in Ethiopia

✨ EthioJobs ✨ Search and Find Jobs in Ethiopia Easy start critical warning Use pycharm No vscode No sublime No Vim No nothing when you want to use

Abdimk 12 Nov 09, 2022
Demo scripts for the Kubernetes Security Webinar

Kubernetes Security Webinar [in Russian] YouTube video (October 13, 2021) Authors: Artem Yushkovsky (LinkedIn, GitHub) Maxim Mosharov @ Whitespots.io

Slurm 34 Dec 06, 2022
Python Repository for Bachelor Ski Sign.

BachelorSkiSign Python Repository for Bachelor Ski Sign. This application reads data from https://bachelorapi.azurewebsites.net/ It is written in Ciru

Winston 1 Jan 04, 2022
Awesome open-source alternatives to SaaS

Awesome-oss-alternatives - Awesome list of open-source startup alternatives to well-known SaaS products

Runa Capital 12.7k Jan 03, 2023
Multi-Probe Attention for Semantic Indexing

Multi-Probe Attention for Semantic Indexing About This project is developed for the topic of COVID-19 semantic indexing. Directories & files A. The di

Jinghang Gu 1 Dec 18, 2022
Starscape is a Blender add-on for adding stars to the background of a scene.

Starscape Starscape is a Blender add-on for adding stars to the background of a scene. Features The add-on provides the following features: Procedural

Marco Rossini 5 Jun 24, 2022
Fully cross-platform toolkit (and library!) for MachO+Obj-C editing/analysis

fully cross-platform toolkit (and library!) for MachO+Obj-C editing/analysis. Includes a cli kit, a curses GUI, ObjC header dumping, and much more.

cynder 301 Dec 28, 2022
Python library to decode the EU Covid-19 vaccine certificate

DCC Utils Python library to decode the EU Covid-19 vaccine certificate, as specified by the EU. Setup pip install dcc-utils Make sure zbar is installe

Developers Italia 13 Mar 11, 2022
Coinloggr - A learning resource and social platform for the coin collecting community

Coinloggr A learning resource and social platform for the coin collecting commun

John Galiszewski 1 Jan 10, 2022
MODeflattener deobfuscates control flow flattened functions obfuscated by OLLVM using Miasm.

MODeflattener deobfuscates control flow flattened functions obfuscated by OLLVM using Miasm.

Suraj Malhotra 138 Jan 07, 2023
En este repositorio pondré archivos graciositos de python que hago de vez en cuando

🐍 Apuntes de python 🐍 ¿Quién soy? 👽 Saludos,mi nombre es Carlos Lara. Pero mi nickname en internet es Hercules Kan. Soy un programador autodidacta

Carlos E. Lara 3 Nov 16, 2021
A tool to build reproducible wheels for you Python project or for all of your dependencies

asaman: Amra Saman (আমরা সমান) This is a tool to build reproducible wheels for your Python project or for all of your dependencies. What this means is

Kushal Das 14 Aug 05, 2022