SimplEx - Explaining Latent Representations with a Corpus of Examples

Overview

Code Author: Jonathan Crabbé ([email protected])

This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.

Installation

  1. Clone the repository
  2. Create a new virtual environment with Python 3.8
  3. Run the following command from the repository folder:
    pip install -r requirements.txt #install requirements

Once the packages are installed, SimplEx can be used directly.

Toy example

Below, you can find a toy demonstration where we compute a corpus decomposition of the latent representations of test examples. All the relevant code can be found in the file simplex.

from explainers.simplex import Simplex
from models.base import BlackBox

# Get the model and the examples
model = BlackBox() # Model should have the BlackBox interface
corpus_inputs = get_corpus() # A tensor of corpus inputs
test_inputs = get_test() # A set of inputs to explain

# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs) 
test_latents = model.latent_representation(test_inputs)

# Initialize SimplEx and fit it on the test examples
simplex = Simplex(corpus_examples=corpus_inputs, 
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs, 
            test_latent_reps=test_latents,
            reg_factor=0)

# Get the weights of each corpus decomposition
weights = simplex.weights

We get a tensor weights that can be interpreted as follows: weights[i,c] = weight of corpus example c in the decomposition of example i.
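
As a minimal sketch of this interpretation (plain Python lists instead of tensors, made-up numbers, and a hypothetical helper reconstruct that is not part of the SimplEx API), each row of weights is assumed to be non-negative and to sum to 1, so the approximation of test example i is the weighted average of the corpus latent representations:

```python
# Illustrative values only: 2 test examples, 3 corpus examples, 2 latent dimensions.
weights = [
    [0.5, 0.5, 0.0],
    [0.0, 0.25, 0.75],
]
corpus_latents = [
    [1.0, 0.0],
    [0.0, 1.0],
    [2.0, 2.0],
]


def reconstruct(weights_row, corpus_latents):
    """Return the weighted average of the corpus latents for one test example."""
    latent_dim = len(corpus_latents[0])
    return [
        sum(w * h[d] for w, h in zip(weights_row, corpus_latents))
        for d in range(latent_dim)
    ]


# Each row is a set of convex-combination weights over the corpus.
row_sums = [sum(row) for row in weights]
approx_latents = [reconstruct(row, corpus_latents) for row in weights]
```

With real tensors, comparing approx_latents with test_latents gives a rough idea of how well the corpus explains each test example.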

We can get the importance of each corpus feature for the decomposition of a given example i in the following way:

# Compute the Integrated Jacobian for a particular example
i = 42
input_baseline = get_baseline() # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projections(test_id=i, model=model,
                             input_baseline=input_baseline)

result = simplex.decompose(i)

We get a list result where each element of the list corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:

w_c, x_c, proj_jacobian_c = result[c]

Here, w_c corresponds to the weight weights[i,c], x_c corresponds to corpus_inputs[c] and proj_jacobian_c is a tensor such that proj_jacobian_c[k] is the Projected Jacobian of feature k from corpus example c.
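
The structure of result can be explored as in the sketch below. It uses a hand-written stand-in for the output of simplex.decompose(i), with lists in place of tensors and made-up values. The helper summarize is hypothetical, and ranking features by the magnitude of their Projected Jacobian is an assumption made for illustration:

```python
# Stand-in for simplex.decompose(i): (w_c, x_c, proj_jacobian_c) tuples,
# sorted by decreasing weight. Values are made up; lists replace tensors.
result = [
    (0.6, [0.1, 0.9, 0.3], [0.05, 0.40, -0.10]),
    (0.3, [0.7, 0.2, 0.5], [-0.20, 0.02, 0.08]),
    (0.1, [0.4, 0.4, 0.4], [0.01, 0.01, 0.03]),
]


def summarize(result, top_k=2):
    """For the top_k corpus examples, return (weight, index of the top feature)."""
    summary = []
    for w_c, x_c, proj_jacobian_c in result[:top_k]:
        # Assumption: a larger absolute Projected Jacobian means a more important feature.
        top_feature = max(range(len(proj_jacobian_c)),
                          key=lambda k: abs(proj_jacobian_c[k]))
        summary.append((w_c, top_feature))
    return summary


summary = summarize(result)
for w_c, top_feature in summary:
    print(f"weight={w_c:.2f}, most important feature index={top_feature}")
```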

Reproducing the paper results

Reproducing MNIST Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.mnist -experiment "approximation_quality" -cv CV
  2. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.
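
The steps above can be scripted instead of typed by hand. The sketch below is one possible way to do it and is not part of the repository: build_commands and run_all are hypothetical helpers, and the only assumption about the repository is the command-line syntax shown above.

```python
import subprocess
import sys


def build_commands(cv_values, experiment="approximation_quality",
                   run_module="experiments.mnist",
                   plot_module="experiments.results.mnist.quality.plot_results"):
    """Return one run command per CV value, followed by the plotting command."""
    commands = [
        [sys.executable, "-m", run_module, "-experiment", experiment, "-cv", str(cv)]
        for cv in cv_values
    ]
    commands.append(
        [sys.executable, "-m", plot_module, "-cv_list", *map(str, cv_values)]
    )
    return commands


def run_all(commands):
    """Run the commands one after another; call this from the repository folder."""
    for command in commands:
        subprocess.run(command, check=True)


commands = build_commands(range(10))  # CV = 0, 1, ..., 9, as in the paper
# run_all(commands)  # uncomment to launch the experiments
```

The same helper should cover the other experiments below by changing experiment, run_module and plot_module to the values given in each section.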

Reproducing Prostate Cancer Approximation Quality Experiment

This experiment requires access to the private datasets CUTRACT and SEER described in the paper.

  1. Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.prostate_cancer -experiment "approximation_quality" -cv CV
  3. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
  4. The resulting plots are saved here.

Reproducing Prostate Cancer Outlier Experiment

This experiment requires access to the private datasets CUTRACT and SEER described in the paper.

  1. Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.prostate_cancer -experiment "outlier_detection" -cv CV
  3. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  4. The resulting plots are saved here.

Reproducing MNIST Jacobian Projection Significance Experiment

  1. Run the following script:
    python -m experiments.mnist -experiment "jacobian_corruption"
  2. The resulting plots and data are saved here.

Reproducing MNIST Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9):
    python -m experiments.mnist -experiment "outlier_detection" -cv CV
  2. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Reproducing MNIST Influence Function Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.mnist -experiment "influence" -cv CV
  2. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Note: the package Pytorch Influence Functions can cause some problems. If this is the case, please change the following lines of calc_influence_function.py:

343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list

Reproducing AR Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.time_series -experiment "approximation_quality" -cv CV
  2. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Reproducing AR Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4):
    python -m experiments.time_series -experiment "outlier_detection" -cv CV
  2. Run the following script by adding all the values of CV from the previous step:
    python -m experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  3. The resulting plots and data are saved here.

Citing

If you use this code, please cite the associated paper:

Put citation here when ready