Official code for HH-VAEM

Overview

This repository contains the official PyTorch implementation of the Hierarchical Hamiltonian VAE for Mixed-type Data (HH-VAEM) model and of the sampling-based feature acquisition technique presented in the paper "Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo". HH-VAEM is a hierarchical VAE for mixed-type incomplete data that uses Hamiltonian Monte Carlo with automatic hyper-parameter tuning to improve approximate inference. The repository includes both the implementation and the experiments reported in the paper.
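For readers unfamiliar with Hamiltonian Monte Carlo, the following is a minimal, generic sketch of a single-variable HMC sampler with a leapfrog integrator and a Metropolis correction. It is not the repository's implementation: the toy Gaussian target, the fixed step size and number of leapfrog steps, and all function names are assumptions made for illustration, whereas HH-VAEM tunes the sampler's hyper-parameters automatically and samples the latent variables of a hierarchical VAE.

```python
import math
import random


def log_prob(z, mean=2.0, std=0.5):
    """Unnormalized log-density of a toy 1D Gaussian target (illustrative only)."""
    return -0.5 * ((z - mean) / std) ** 2


def grad_log_prob(z, mean=2.0, std=0.5):
    """Gradient of log_prob with respect to z."""
    return -(z - mean) / std ** 2


def leapfrog(z, p, step_size, n_steps):
    """Simulate Hamiltonian dynamics with the leapfrog integrator."""
    p = p + 0.5 * step_size * grad_log_prob(z)      # initial half step for momentum
    for i in range(n_steps):
        z = z + step_size * p                        # full step for position
        if i < n_steps - 1:
            p = p + step_size * grad_log_prob(z)     # full step for momentum
    p = p + 0.5 * step_size * grad_log_prob(z)      # final half step for momentum
    return z, p


def hmc_step(z, step_size, n_steps, rng):
    """Propose a new state and accept or reject it with a Metropolis test."""
    p = rng.gauss(0.0, 1.0)                          # resample the momentum
    z_new, p_new = leapfrog(z, p, step_size, n_steps)
    h_old = -log_prob(z) + 0.5 * p ** 2              # Hamiltonian before the move
    h_new = -log_prob(z_new) + 0.5 * p_new ** 2      # Hamiltonian after the move
    accept_prob = math.exp(min(0.0, h_old - h_new))
    if rng.random() < accept_prob:
        return z_new, True
    return z, False


def run_chain(n_samples=2000, burn_in=200, step_size=0.1, n_steps=10, seed=0):
    """Run a chain with fixed (hand-picked, not tuned) hyper-parameters."""
    rng = random.Random(seed)
    z, samples, accepted = 0.0, [], 0
    for i in range(n_samples + burn_in):
        z, ok = hmc_step(z, step_size, n_steps, rng)
        if i >= burn_in:
            samples.append(z)
            accepted += ok
    return samples, accepted / n_samples


if __name__ == "__main__":
    samples, rate = run_chain()
    print("sample mean:", sum(samples) / len(samples))
    print("acceptance rate:", rate)
```

In this sketch the step size and trajectory length are chosen by hand. Choosing them well is exactly the part that the automatic tuning mentioned above is meant to handle.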

If you use this code, please cite the preprint:

@article{peis2022missing,
  title={Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo},
  author={Peis, Ignacio and Ma, Chao and Hern{\'a}ndez-Lobato, Jos{\'e} Miguel},
  journal={arXiv preprint arXiv:2202.04599},
  year={2022}
}

Installation

To install, run the following command, which creates a conda virtual environment named HH-VAEM from the provided environment.yml file:

conda env create -f environment.yml

Usage

Training

The project is built on the PyTorch Lightning framework. The HH-VAEM model is implemented as a LightningModule and trained with a Trainer. To train a model, run:

# Example for training HH-VAEM on Boston dataset
python train.py --model HHVAEM --dataset boston --split 0

This will automatically download the boston dataset, divide it into 10 train/test splits, and train HH-VAEM on training split 0. Two folders will be created: data/ for storing the datasets and logs/ for model checkpoints and TensorBoard logs. To change the directory where these folders are created, modify the variable LOGDIR in src/configs.py (this can help avoid overloading network file systems).
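Since results are later averaged over all splits, it can be convenient to generate the per-split commands programmatically. The sketch below is a hypothetical helper, not part of the repository: the function name build_commands is invented for illustration, and it assumes that the 10 splits are indexed 0 through 9, which the text above does not state explicitly. By default it only prints the commands (a dry run).

```python
import shlex

N_SPLITS = 10  # number of train/test splits stated above


def build_commands(script, model, dataset, n_splits=N_SPLITS, extra_args=()):
    """Return one argument list per split for the given script.

    Assumption: splits are indexed 0..n_splits-1.
    """
    commands = []
    for split in range(n_splits):
        cmd = ["python", script, "--model", model,
               "--dataset", dataset, "--split", str(split)]
        cmd.extend(extra_args)
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    for cmd in build_commands("train.py", "HHVAEM", "boston"):
        # Dry run: print the command line instead of executing it.
        # To execute, use subprocess.run(cmd, check=True) instead.
        print(shlex.join(cmd))
```

The same helper can build commands for the other per-split scripts by changing the script name and passing any additional flags through extra_args.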

The following datasets are available:

  • Ten UCI datasets: avocado, boston, energy, wine, diabetes, concrete, naval, yatch, bank and insurance.
  • The MNIST datasets: mnist or fashion_mnist.
  • More datasets can be easily added to src/datasets.py.

For each dataset, the corresponding parameter configuration must be added to src/configs.py.

The following models are also available (implemented in src/models/):

  • HHVAEM: the proposed model in the paper.
  • VAEM: the VAEM strategy presented in (Ma et al., 2020) with a Gaussian encoder (without including the Partial VAE).
  • HVAEM: a hierarchical VAEM with two layers of latent variables and a Gaussian encoder.
  • HMCVAEM: a VAEM that includes a tuned HMC sampler for the true posterior.
  • For the MNIST datasets (non-heterogeneous data), use HHVAE, VAE, HVAE and HMCVAE.
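The pairing between model names and dataset families listed above can be summarized in a small lookup. This is a hypothetical helper written for illustration, not code from the repository: the function name is_recommended_pair is invented, and the sets only encode the names given in the two lists above. Whether the scripts themselves reject other combinations is not stated here.

```python
# Dataset and model names exactly as listed above.
UCI_DATASETS = {"avocado", "boston", "energy", "wine", "diabetes",
                "concrete", "naval", "yatch", "bank", "insurance"}
MNIST_DATASETS = {"mnist", "fashion_mnist"}

MIXED_TYPE_MODELS = {"HHVAEM", "VAEM", "HVAEM", "HMCVAEM"}
MNIST_MODELS = {"HHVAE", "VAE", "HVAE", "HMCVAE"}


def is_recommended_pair(model, dataset):
    """Return True if the model/dataset pair follows the recommendation above."""
    if dataset in MNIST_DATASETS:
        return model in MNIST_MODELS
    if dataset in UCI_DATASETS:
        return model in MIXED_TYPE_MODELS
    raise ValueError(f"unknown dataset: {dataset}")


if __name__ == "__main__":
    print(is_recommended_pair("HHVAEM", "boston"))
    print(is_recommended_pair("HHVAEM", "mnist"))
```

A check like this can be run before launching a long training job to catch a mismatched --model and --dataset combination early.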

By default, the test stage runs at the end of the training stage. Pass --test 0 to disable this and run the test manually using:

# Example for testing HH-VAEM on Boston dataset
python test.py --model HHVAEM --dataset boston --split 0

This loads the trained model and tests it on boston test split 0. Once all the splits are tested, the average results can be obtained using the script in the run/ folder:

# Example for obtaining the average test results with HH-VAEM on Boston dataset
python test_splits.py --model HHVAEM --dataset boston

Experiments

The experiments in the paper can be executed using:

# Example for running the SAIA experiment with HH-VAEM on Boston dataset
python active_learning.py --model HHVAEM --dataset boston --method mi --split 0

# Example for running the OoD experiment with MNIST as the dataset and Fashion-MNIST as the OoD dataset:
python ood.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist --split 0

Once this is executed on all the splits, you can plot the SAIA error curves or obtain the average OoD metrics using the scripts in the run/ folder:

# Example for plotting the SAIA error curves of VAEM and HH-VAEM on the Boston dataset
python active_learning_plots.py --models VAEM HHVAEM --dataset boston

# Example for obtaining the average OoD metrics with MNIST as the dataset and Fashion-MNIST as the OoD dataset:
python ood_splits.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist


Help

Use the --help option for documentation on the usage of any of the mentioned scripts.

Contributors

Ignacio Peis
Chao Ma
José Miguel Hernández-Lobato

Contact

For further information: [email protected]
