Computing Shapley values using VAEAC

Overview

Shapley values and the VAEAC method

This GitHub repository contains the implementation of the VAEAC approach from our paper "Using Shapley Values and Variational Autoencoders to Explain Predictive Models with Dependent Mixed Features" (Olsen et al., 2021).

The variational autoencoder with arbitrary conditioning (VAEAC) approach is based on the work of Ivanov et al. (2019). The VAEAC is an extension of the regular variational autoencoder (Kingma and Welling, 2014). Instead of giving a probabilistic representation of the distribution $p(\boldsymbol{x})$, it gives a representation of the conditional distribution $p(\boldsymbol{x}_{\bar{\mathcal{S}}} \mid \boldsymbol{x}_{\mathcal{S}})$ for all possible feature subsets $\mathcal{S} \subseteq \mathcal{M}$ simultaneously, where $\mathcal{M}$ is the set of all features and $\bar{\mathcal{S}} = \mathcal{M} \setminus \mathcal{S}$.
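
For context, the conditional distribution enters the Shapley value framework through the contribution function. The block below is a sketch in the standard conditional-expectation formulation, where $\mathcal{M}$ is the set of all features and $\mathcal{S}$ a feature subset. The symbols $f$ (the predictive model), $\boldsymbol{x}^*$ (the observation being explained), $v$ (the contribution function), and $K$ (the number of Monte Carlo samples) are introduced here for illustration and do not appear in the text above.

```latex
\begin{align}
\phi_j &= \sum_{\mathcal{S} \subseteq \mathcal{M} \setminus \{j\}}
          \frac{|\mathcal{S}|!\,(|\mathcal{M}| - |\mathcal{S}| - 1)!}{|\mathcal{M}|!}
          \bigl( v(\mathcal{S} \cup \{j\}) - v(\mathcal{S}) \bigr), \\
v(\mathcal{S}) &= \mathbb{E}\bigl[ f(\boldsymbol{x}) \mid \boldsymbol{x}_{\mathcal{S}} = \boldsymbol{x}^*_{\mathcal{S}} \bigr]
          \approx \frac{1}{K} \sum_{k=1}^{K} f\bigl( \boldsymbol{x}^{(k)}_{\bar{\mathcal{S}}}, \boldsymbol{x}^*_{\mathcal{S}} \bigr),
          \qquad \boldsymbol{x}^{(k)}_{\bar{\mathcal{S}}} \sim p\bigl( \boldsymbol{x}_{\bar{\mathcal{S}}} \mid \boldsymbol{x}_{\mathcal{S}} = \boldsymbol{x}^*_{\mathcal{S}} \bigr).
\end{align}
```

A model of the conditional distribution for every subset $\mathcal{S}$ is therefore what is needed to generate the samples in the Monte Carlo estimate.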

To make the VAEAC methodology work in the Shapley value framework established in the R package shapr (Sellereite and Jullum, 2019), we have altered the original implementation of Ivanov et al. (2019).
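
How a conditional sampler plugs into a Shapley value computation can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the code in this repository: `shapley_values`, `make_contribution_function`, `independent_sampler`, and `toy_model` are hypothetical names, and the sampler draws the unobserved features independently from a standard normal as a stand-in for samples from a trained VAEAC model.

```python
from itertools import combinations
from math import factorial
import random


def shapley_values(v, n_features):
    """Exact Shapley values for a contribution function v on frozensets of feature indices."""
    phi = []
    for j in range(n_features):
        others = [i for i in range(n_features) if i != j]
        total = 0.0
        for size in range(len(others) + 1):
            # Shapley weight |S|! (M - |S| - 1)! / M! for coalitions of this size.
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                s = frozenset(subset)
                total += weight * (v(s | {j}) - v(s))
        phi.append(total)
    return phi


def make_contribution_function(model, x_star, sample_conditional, n_samples=250):
    """Monte Carlo estimate of v(S) = E[f(x) | x_S = x_S*], cached per coalition."""
    cache = {}

    def v(s):
        if s not in cache:
            samples = sample_conditional(x_star, s, n_samples)
            cache[s] = sum(model(x) for x in samples) / n_samples
        return cache[s]

    return v


rng = random.Random(0)


def independent_sampler(x_star, s, n_samples):
    # ASSUMPTION: stand-in sampler that ignores feature dependence.
    # Features in s keep their observed values; the rest are drawn from N(0, 1).
    return [
        [x_star[i] if i in s else rng.gauss(0.0, 1.0) for i in range(len(x_star))]
        for _ in range(n_samples)
    ]


def toy_model(x):
    # Hypothetical linear model used only for this illustration.
    return 1.0 + 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[2]


x_star = [0.5, -1.0, 2.0]
v = make_contribution_function(toy_model, x_star, independent_sampler)
phi = shapley_values(v, n_features=3)
phi_0 = v(frozenset())  # estimated expected prediction without any features
print(phi_0, phi)
```

With a trained VAEAC model, `independent_sampler` would be replaced by a function that returns samples of the unobserved features conditional on the observed ones; the rest of the sketch would stay the same. Exact enumeration over all coalitions is only feasible for a small number of features.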

The VAEAC model is implemented in PyTorch, so that portion of the repository is written in Python. To compute the Shapley values, we have written the R code needed to run the VAEAC approach on top of the R package shapr.

Setup

In addition to the prerequisites required by the implementation of Ivanov et al. (2019), several R packages are needed. All prerequisites are specified in requirements.txt.

This code was tested on Linux and macOS with Python 3.6.4, PyTorch 1.0, and R 4.0.2. It should also work on Windows.

The user has to specify the system path to the Python environment and the system path to the downloaded repository in Source_Shapr_VAEAC.R.

Example

The following example shows how a random forest model is trained on the Abalone data set from the UCI machine learning repository, and how shapr explains the individual predictions.

Note that we only use Diameter (continuous), ShuckedWeight (continuous), and Sex (categorical) as features and let the response be Rings, that is, the age of the abalone.

# Import libraries
library(shapr)
library(ranger)
library(data.table)

# Load the R files needed for computing Shapley values using VAEAC.
source("/Users/larsolsen/Desktop/PhD/R_Codes/Source_Shapr_VAEAC.R")

# Set the working directory to be the root folder of the GitHub repository. 
setwd("~/PhD/Paper1/Code_for_GitHub")

# Read in the Abalone data set.
abalone = readRDS("data/Abalone.data")
str(abalone)

# Predict Rings based on Diameter, ShuckedWeight, and Sex (categorical), using a random forest model.
model = ranger(Rings ~ Diameter + ShuckedWeight + Sex, data = abalone[abalone$test_instance == FALSE,])

# Specify phi_0, i.e., the expected prediction without any features.
phi_0 <- mean(abalone$Rings[abalone$test_instance == FALSE])

# Prepare the data for explanation. Diameter, ShuckedWeight, and Sex correspond to 3,6,9.
explainer <- shapr(abalone[abalone$test_instance == FALSE, c(3,6,9)], model)
#> The specified model provides feature classes that are NA. The classes of data are taken as the truth.

# Train the VAEAC model with specified parameters and add it to the explainer
explainer_added_vaeac = add_vaeac_to_explainer(
  explainer, 
  epochs = 30L,
  width = 32L,
  depth = 3L,
  latent_dim = 8L,
  lr = 0.002,
  num_different_vaeac_initiate = 2L,
  epochs_initiation_phase = 2L,
  validation_iwae_num_samples = 25L,
  verbose_summary = TRUE)

# Computing the actual Shapley values with kernelSHAP accounting for feature dependence using
# the VAEAC distribution approach with parameters defined above
explanation = explain.vaeac(abalone[abalone$test_instance == TRUE][1:8,c(3,6,9)],
                            approach = "vaeac",
                            explainer = explainer_added_vaeac,
                            prediction_zero = phi_0,
                            which_vaeac_model = "best")

# Printing the Shapley values for the test data.
# For more information about the interpretation of the values in the table, see ?shapr::explain.
print(explanation$dt)
#>        none   Diameter  ShuckedWeight        Sex
#> 1: 9.927152  0.63282471     0.4175608  0.4499676
#> 2: 9.927152 -0.79836795    -0.6419839  1.5737014
#> 3: 9.927152 -0.93500891    -1.1925897 -0.9140548
#> 4: 9.927152  0.57225851     0.5306906 -1.3036202
#> 5: 9.927152 -1.24280895    -1.1766845  1.2437640
#> 6: 9.927152 -0.77290507    -0.5976597  1.5194251
#> 7: 9.927152 -0.05275627     0.1306941 -1.1755597
#> 8: 9.927153  0.44593977     0.1788577  0.6895557

# Finally, we plot the resulting explanations.
plot(explanation, plot_phi0 = FALSE)
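
As a sanity check on the table above, Shapley values satisfy the efficiency property: the `none` column (phi_0) plus the feature contributions in a row should add up to the model's prediction for that test observation. A minimal Python sketch, with the numbers copied from the printed output and the variable names chosen here for illustration:

```python
# Rows copied from the printed table: (none, Diameter, ShuckedWeight, Sex).
shapley_table = [
    (9.927152, 0.63282471, 0.4175608, 0.4499676),
    (9.927152, -0.79836795, -0.6419839, 1.5737014),
    (9.927152, -0.93500891, -1.1925897, -0.9140548),
    (9.927152, 0.57225851, 0.5306906, -1.3036202),
    (9.927152, -1.24280895, -1.1766845, 1.2437640),
    (9.927152, -0.77290507, -0.5976597, 1.5194251),
    (9.927152, -0.05275627, 0.1306941, -1.1755597),
    (9.927153, 0.44593977, 0.1788577, 0.6895557),
]

# By efficiency, each row sum reconstructs the predicted number of rings.
predictions = [sum(row) for row in shapley_table]

for index, prediction in enumerate(predictions, start=1):
    print(f"test observation {index}: reconstructed prediction = {prediction:.4f}")
```

Comparing these sums with the output of the fitted model on the same eight test observations is a quick way to confirm that the explanation and the model agree.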

Citation

If you find this code useful in your research, please consider citing our paper:

@misc{Olsen2021Shapley,
      title={Using Shapley Values and Variational Autoencoders to Explain Predictive Models with Dependent Mixed Features}, 
      author={Lars Henry Berge Olsen and Ingrid Kristine Glad and Martin Jullum and Kjersti Aas},
      year={2021},
      eprint={2111.13507},
      archivePrefix={arXiv},
      primaryClass={stat.ML},
      url={https://arxiv.org/abs/2111.13507}
}

References

Ivanov, O., Figurnov, M., and Vetrov, D. (2019). "Variational Autoencoder with Arbitrary Conditioning". In: International Conference on Learning Representations.

Kingma, D. P. and Welling, M. (2014). "Auto-Encoding Variational Bayes". In: 2nd International Conference on Learning Representations, ICLR 2014.

Olsen, L. H. B., Glad, I. K., Jullum, M., and Aas, K. (2021). "Using Shapley Values and Variational Autoencoders to Explain Predictive Models with Dependent Mixed Features".

Sellereite, N. and Jullum, M. (2019). "shapr: An R-package for explaining machine learning models with dependence-aware Shapley values". In: Journal of Open Source Software, vol. 5, no. 46, p. 2027.
