[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Overview

Graph Stochastic Attention (GSAT)

The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism, to appear in ICML 2022.

Introduction

Commonly used attention mechanisms impose no constraints during training (besides normalization) and thus may lack interpretability. GSAT is a novel attention mechanism for building interpretable graph learning models. It injects stochasticity into attention learning, where a higher attention weight means a higher probability of the corresponding edge being kept during training. This mechanism pushes the model to learn higher attention weights for the edges that matter for prediction accuracy, which provides interpretability. To further improve interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps the model generalize. Fig. 1 shows the architecture of GSAT.
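For intuition, here is a minimal sketch of the core mechanism in PyTorch. It is illustrative only, not the exact code in this repo: the names StochasticEdgeAttention and kl_regularizer, the edge MLP, and the hyperparameters are all assumptions. Each edge receives an attention logit, a differentiable edge-keep sample is drawn via the binary concrete (Gumbel-softmax) relaxation, and a KL term regularizes the attention toward Bern(r).

import torch
import torch.nn as nn

class StochasticEdgeAttention(nn.Module):
    # Illustrative GSAT-style stochastic attention over edges.
    def __init__(self, hidden_dim, temperature=1.0):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.temperature = temperature

    def forward(self, node_emb, edge_index):
        # One attention logit per edge, computed from the two endpoint embeddings.
        src, dst = edge_index
        logits = self.edge_mlp(torch.cat([node_emb[src], node_emb[dst]], dim=-1)).squeeze(-1)
        if self.training:
            # Binary concrete relaxation of Bern(sigmoid(logits)): differentiable sampling.
            u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
            att = torch.sigmoid((logits + torch.log(u) - torch.log(1 - u)) / self.temperature)
        else:
            att = torch.sigmoid(logits)
        return att  # in (0, 1); used to reweight edges in the downstream GNN

def kl_regularizer(att, r=0.7, eps=1e-6):
    # KL(Bern(att) || Bern(r)) per edge, averaged; pushes non-critical edges toward r.
    kl = att * torch.log((att + eps) / (r + eps)) \
         + (1 - att) * torch.log((1 - att + eps) / (1 - r + eps))
    return kl.mean()

During training, att reweights the messages of the GNN and kl_regularizer is added to the prediction loss: critical edges are pushed toward attention close to 1, while the rest drift toward r.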

Figure 1. The architecture of GSAT.

Installation

We have tested our code on Python 3.9 with PyTorch 1.10.0, PyG 2.0.3 and CUDA 11.3. Please follow the steps below to create a virtual environment and install the required packages.

Create a virtual environment:

conda create --name gsat python=3.9
conda activate gsat

Install dependencies:

conda install -y pytorch==1.10.0 torchvision cudatoolkit=11.3 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
pip install -r requirements.txt

If a lower CUDA version is required, use the following commands to install dependencies instead:

conda install -y pytorch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 cudatoolkit=10.2 -c pytorch
pip install torch-scatter==2.0.9 torch-sparse==0.6.12 torch-cluster==1.5.9 torch-spline-conv==1.2.1 torch-geometric==2.0.3 -f https://data.pyg.org/whl/torch-1.9.0+cu102.html
pip install -r requirements.txt

Run Examples

We provide examples with minimal code to run GSAT in ./example/example.ipynb. We have tested the provided examples on Ba-2Motifs (GIN), Mutag (GIN) and OGBG-Molhiv (PNA). Note that to run GSAT*, a pre-trained model needs to be loaded first in the provided example.

The examples should run on other datasets as well, but some hard-coded hyperparameters may need to be adjusted accordingly. To reproduce the results for other datasets, please follow the instructions in the next section.

Reproduce Results

We provide the source code to reproduce the results in our paper. The results of GSAT can be reproduced by running run_gsat.py. To reproduce GSAT*, run pretrain_clf.py first and then set from_scratch: false in the corresponding configuration file.

To pre-train a classifier:

cd ./src
python pretrain_clf.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

To train GSAT:

cd ./src
python run_gsat.py --dataset [dataset_name] --backbone [model_name] --cuda [GPU_id]

dataset_name can be chosen from ba_2motifs, mutag, mnist, Graph-SST2, spmotif_0.5, spmotif_0.7, spmotif_0.9, ogbg_molhiv, ogbg_moltox21, ogbg_molbace, ogbg_molbbbp, ogbg_molclintox, ogbg_molsider.

model_name can be chosen from GIN, PNA.

GPU_id is the id of the GPU to use. To use CPU, please set it to -1.
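For example, to train GSAT on ba_2motifs with a GIN backbone on GPU 0:

cd ./src
python run_gsat.py --dataset ba_2motifs --backbone GIN --cuda 0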

Training Logs

Standard output provides basic training logs, while more detailed logs and interpretation visualizations are available in TensorBoard:

tensorboard --logdir=./data/[dataset_name]/logs

Hyperparameter Settings

All settings can be found in ./src/configs.

Instructions on Acquiring Datasets

  • Ba_2Motifs

    • Raw data files can be downloaded automatically, provided by PGExplainer and DIG.
  • Spurious-Motif

    • Raw data files can be generated automatically, provided by DIR.
  • OGBG-Mol

    • Raw data files can be downloaded automatically, provided by OGB.
  • Mutag

    • Raw data files need to be downloaded here, provided by PGExplainer.
    • Unzip Mutagenicity.zip and Mutagenicity.pkl.zip.
    • Put the raw data files in ./data/mutag/raw.
  • Graph-SST2

    • Raw data files need to be downloaded here, provided by DIG.
    • Unzip the downloaded Graph-SST2.zip.
    • Put the raw data files in ./data/Graph-SST2/raw.
  • MNIST-75sp

    • Raw data files need to be generated following the instructions here.
    • Put the generated files in ./data/mnist/raw.

FAQ

Does GSAT encourage sparsity?

No, GSAT does not encourage generating sparse subgraphs. We find that r = 0.7 (Eq. (9) in our paper) generally works well for all datasets in our experiments, which means roughly 70% of edges are kept during training (still fairly large). This is because GSAT does not provide interpretability by finding a small/sparse subgraph of the original input graph, which is what previous works normally do and which significantly hurts performance for inherently interpretable models (as shown in Fig. 7 in the paper). By contrast, GSAT provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

How to choose the value of r?

A grid search in [0.5, 0.6, 0.7, 0.8, 0.9] is recommended, but r = 0.7 is a good starting point. Note that in practice we would decay the value of r gradually during training from 0.9 to the chosen value.
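For illustration, a simple step-decay schedule can be sketched as below (the names init_r, final_r, decay_rate and decay_interval are illustrative; see ./src/configs for the exact schedule used per dataset):

def get_r(epoch, init_r=0.9, final_r=0.7, decay_rate=0.1, decay_interval=10):
    # Step r down by decay_rate every decay_interval epochs, from init_r to final_r.
    return max(init_r - (epoch // decay_interval) * decay_rate, final_r)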

p or α to implement Eq.(9)?

Recall from Fig. 1 that p is the probability of dropping an edge, while α is the sampled result from Bern(p). In our provided implementation, as an empirical choice, α is used to implement Eq. (9) (the Gumbel-softmax trick makes α essentially continuous in practice). We find that using α may provide more regularization and make the model more robust to hyperparameters. Using p can achieve the same performance, but it may need some more tuning.
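The difference between the two options can be sketched as follows (a hedged illustration, not this repo's exact API; concrete_sample and bern_kl are made-up names):

import torch

def concrete_sample(logits, temperature=1.0):
    # Differentiable (relaxed) sample from Bern(sigmoid(logits)) via the Gumbel-softmax trick.
    u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
    return torch.sigmoid((logits + torch.log(u) - torch.log(1 - u)) / temperature)

def bern_kl(x, r, eps=1e-6):
    # Elementwise KL(Bern(x) || Bern(r)), the regularizer in Eq. (9).
    return (x * torch.log((x + eps) / (r + eps))
            + (1 - x) * torch.log((1 - x + eps) / (1 - r + eps)))

logits = torch.randn(8)           # toy per-edge logits
p = torch.sigmoid(logits)         # p: the Bernoulli parameter
alpha = concrete_sample(logits)   # alpha: the relaxed sample from Bern(p)

reg_with_p = bern_kl(p, r=0.7).mean()          # option 1: implement Eq. (9) with p
reg_with_alpha = bern_kl(alpha, r=0.7).mean()  # option 2: implement Eq. (9) with alpha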

Can you show an example of how GSAT works?

Below we show an example from the ba_2motifs dataset, where the task is to distinguish five-node cycle motifs (left) from house motifs (right). To make good predictions (minimize the cross-entropy loss), GSAT pushes the attention weights of the critical edges to be relatively large (ideally close to 1); otherwise those edges would be dropped too frequently, resulting in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq. (9) of the paper), GSAT pushes the attention weights of the non-critical edges to be close to r, which is set to 0.7 in this example. This mechanism of injecting stochasticity makes the learned attention weights directly interpretable: the more critical an edge is, the larger its attention weight (the less likely it is to be dropped). Note that ba_2motifs satisfies our Thm. 4.1 with no noise, and GSAT achieves perfect interpretation performance on it.

Figure 2. An example of the learned attention weights.

Reference

If you find our paper and repo useful, please cite our paper:

@article{miao2022interpretable,
  title={Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism},
  author={Miao, Siqi and Liu, Miaoyuan and Li, Pan},
  journal={arXiv preprint arXiv:2201.12987},
  year={2022}
}