Implementation of "GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings" in PyTorch

Overview

PyGAS: Auto-Scaling GNNs in PyG


PyGAS is the practical realization of our GNNAutoScale (GAS) framework, which scales arbitrary message-passing GNNs to large graphs, as described in our paper:

Matthias Fey, Jan E. Lenssen, Frank Weichert, Jure Leskovec: GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings (ICML 2021)

GAS prunes entire sub-trees of the computation graph by utilizing historical embeddings from prior training iterations, leading to constant GPU memory consumption in respect to input mini-batch size, and maximally expressivity.

PyGAS is implemented in PyTorch and utilizes the PyTorch Geometric (PyG) library. It provides an easy-to-use interface to convert a common or custom GNN from PyG into its scalable variant:

from torch_geometric.nn import SAGEConv
from torch_geometric_autoscale import ScalableGNN
from torch_geometric_autoscale import metis, permute, SubgraphLoader

class GNN(ScalableGNN):
    def __init__(self, num_nodes, in_channels, hidden_channels,
                 out_channels, num_layers):
        # * pool_size determines the number of pinned CPU buffers
        # * buffer_size determines the size of pinned CPU buffers,
        #   i.e. the maximum number of out-of-mini-batch nodes

        super().__init__(num_nodes, hidden_channels, num_layers,
                         pool_size=2, buffer_size=5000)

        self.convs = ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, adj_t, *args):
        for conv, history in zip(self.convs[:-1], self.histories):
            x = conv(x, adj_t).relu_()
            x = self.push_and_pull(history, x, *args)
        return self.convs[-1](x, adj_t)

perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GNN(...)
for batch, *args in loader:
    out = model(batch.x, batch.adj_t, *args)

A detailed description of ScalableGNN can be found in its implementation.

Requirements

pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric

where ${TORCH} should be replaced by either 1.7.0 or 1.8.0, and ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, cu110 or cu111, depending on your PyTorch installation.

Installation

pip install git+https://github.com/rusty1s/pyg_autoscale.git

or

python setup.py install

Project Structure

  • torch_geometric_autoscale/ contains the source code of PyGAS
  • examples/ contains examples to demonstrate how to apply GAS in practice
  • small_benchmark/ includes experiments to evaluate GAS performance on small-scale graphs
  • large_benchmark/ includes experiments to evaluate GAS performance on large-scale graphs

We use Hydra to manage hyperparameter configurations.

Cite

Please cite our paper if you use this code in your own work:

@inproceedings{Fey/etal/2021,
  title={{GNNAutoScale}: Scalable and Expressive Graph Neural Networks via Historical Embeddings},
  author={Fey, M. and Lenssen, J. E. and Weichert, F. and Leskovec, J.},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2021},
}
Owner
Matthias Fey
PhD student @ TU Dortmund University - Interested in Representation Learning on Graphs and Manifolds; PyTorch, CUDA, Vim and macOS Enthusiast
Matthias Fey
Fre-GAN: Adversarial Frequency-consistent Audio Synthesis

Fre-GAN Vocoder Fre-GAN: Adversarial Frequency-consistent Audio Synthesis Training: python train.py --config config.json Citation: @misc{kim2021frega

Rishikesh (ऋषिकेश) 93 Dec 17, 2022
Direct Multi-view Multi-person 3D Human Pose Estimation

Implementation of NeurIPS-2021 paper: Direct Multi-view Multi-person 3D Human Pose Estimation [paper] [video-YouTube, video-Bilibili] [slides] This is

Sea AI Lab 251 Dec 30, 2022
Prior-Guided Multi-View 3D Head Reconstruction

Prior-Guided Head MVS This repository includes some reconstruction results of our IEEE TMM 2021 paper, Prior-Guided Multi-View 3D Head Reconstruction.

11 Aug 17, 2022
yolox_backbone is a deep-learning library and is a collection of YOLOX Backbone models.

YOLOX-Backbone yolox-backbone is a deep-learning library and is a collection of YOLOX backbone models. Install pip install yolox-backbone Load a Pret

Yonghye Kwon 21 Dec 28, 2022
U-Time: A Fully Convolutional Network for Time Series Segmentation

U-Time & U-Sleep Official implementation of The U-Time [1] model for general-purpose time-series segmentation. The U-Sleep [2] model for resilient hig

Mathias Perslev 176 Dec 19, 2022
Lazy, a tool for running things in idle time

Lazy, a tool for running things in idle time Mostly used to stop CUDA ML model training from making my desktop unusable. Simply monitors keyboard/mous

N Shepperd 46 Nov 06, 2022
Pytorch implementation of "A simple neural network module for relational reasoning" (Relational Networks)

Pytorch implementation of Relational Networks - A simple neural network module for relational reasoning Implemented & tested on Sort-of-CLEVR task. So

Kim Heecheol 800 Dec 05, 2022
Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network

Efficient and Accurate Arbitrary-Shaped Text Detection with Pixel Aggregation Network Paddle-PANet 目录 结果对比 论文介绍 快速安装 结果对比 CTW1500 Method Backbone Fine

7 Aug 08, 2022
[ICCV2021] IICNet: A Generic Framework for Reversible Image Conversion

IICNet - Invertible Image Conversion Net Official PyTorch Implementation for IICNet: A Generic Framework for Reversible Image Conversion (ICCV2021). D

felixcheng97 55 Dec 06, 2022
OpenMMLab 3D Human Parametric Model Toolbox and Benchmark

Introduction English | 简体中文 MMHuman3D is an open source PyTorch-based codebase for the use of 3D human parametric models in computer vision and comput

OpenMMLab 782 Jan 04, 2023
Run Keras models in the browser, with GPU support using WebGL

**This project is no longer active. Please check out TensorFlow.js.** The Keras.js demos still work but is no longer updated. Run Keras models in the

Leon Chen 4.9k Dec 29, 2022
This is a collection of our NAS and Vision Transformer work.

AutoML - Neural Architecture Search This is a collection of our AutoML-NAS work iRPE (NEW): Rethinking and Improving Relative Position Encoding for Vi

Microsoft 828 Dec 28, 2022
Official repository for Jia, Raghunathan, Göksel, and Liang, "Certified Robustness to Adversarial Word Substitutions" (EMNLP 2019)

Certified Robustness to Adversarial Word Substitutions This is the official GitHub repository for the following paper: Certified Robustness to Adversa

Robin Jia 38 Oct 16, 2022
A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022)

DFC2022 Baseline A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022) This repository uses TorchGeo, PyTorch Lightning, and Segmenta

isaac 24 Nov 28, 2022
Implementation of Segnet, FCN, UNet , PSPNet and other models in Keras.

Image Segmentation Keras : Implementation of Segnet, FCN, UNet, PSPNet and other models in Keras. Implementation of various Deep Image Segmentation mo

Divam Gupta 2.6k Jan 05, 2023
Yet another video caption

Yet another video caption

Fan Zhimin 5 May 26, 2022
A user-friendly research and development tool built to standardize RL competency assessment for custom agents and environments.

Built with ❤️ by Sam Showalter Contents Overview Installation Dependencies Usage Scripts Standard Execution Environment Development Environment Benchm

SRI-AIC 1 Nov 18, 2021
Code for SentiBERT: A Transferable Transformer-Based Architecture for Compositional Sentiment Semantics (ACL'2020).

SentiBERT Code for SentiBERT: A Transferable Transformer-Based Architecture for Compositional Sentiment Semantics (ACL'2020). https://arxiv.org/abs/20

Da Yin 66 Aug 13, 2022
Dilated Convolution for Semantic Image Segmentation

Multi-Scale Context Aggregation by Dilated Convolutions Introduction Properties of dilated convolution are discussed in our ICLR 2016 conference paper

Fisher Yu 764 Dec 26, 2022
Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease

Heart_Disease_Classification Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease Dataset

Ashish 1 Jan 30, 2022