Cold Brew: Distilling Graph Node Representations with Incomplete or Missing Neighborhoods

Overview

Cold Brew: Distilling Graph Node Representations with Incomplete or Missing Neighborhoods

Introduction

Graph Neural Networks (GNNs) have demonstrated superior performance in node classification or regression tasks, and have emerged as the state of the art in several applications. However, (inductive) GNNs require the edge connectivity structure of nodes to be known beforehand to work well. This is often not the case in several practical applications where the node degrees have power-law distributions, and nodes with a few connections might have noisy edges. An extreme case is the strict cold start (SCS) problem, where there is no neighborhood information available, forcing prediction models to rely completely on node features only. To study the viability of using inductive GNNs to solve the SCS problem, we introduce feature-contribution ratio (FCR), a metric to quantify the contribution of a node's features and that of its neighborhood in predicting node labels, and use this new metric as a model selection reward. We then propose Cold Brew, a new method that generalizes GNNs better in the SCS setting compared to pointwise and graph-based models, via a distillation approach. We show experimentally how FCR allows us to disentangle the contributions of various components of graph datasets, and demonstrate the superior performance of Cold Brew on several public benchmarks

Motivation

Long tail distribution is ubiquitously existed in large scale graph mining tasks. In some applications, some cold start nodes have too few or no neighborhood in the graph, which make graph based methods sub-optimal due to insufficient high quality edges to perform message passing.

gnns

gnns

Method

We improve teacher GNN with Structural Embedding, and propose student MLP model with latent neighborhood discovery step. We also propose a metric called FCR to judge the difficulty in cold start generalization.

gnns

coldbrew

Installation Guide

The following commands are used for installing key dependencies; other can be directly installed via pip or conda. A full redundant dependency list is in requirements.txt

pip install dgl
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
pip install torch-geometric

Training Guide

In options/base_options.py, a full list of useable args is present, with default arguments and candidates initialized.

Comparing between traditional GCN (optimized with Initial/Jumping/Dense/PairNorm/NodeNorm/GroupNorm/Dropouts) and Cold Brew's GNN (optimized with Structural Embedding)

Train optimized traditional GNN:

python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 84.15

python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 71.00

python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 78.2

Training Cold Brew's Teacher GNN:

python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=32 --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 85.10

python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 71.40

python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='111' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1 Result: 78.2

Comparing between MLP models:

Training naive MLP:

python main.py --dataset='Cora' --train_which='StudentBaseMLP' Result on isolation split: 63.92

Training GraphMLP:

python main.py --dataset='Cora' --train_which='GraphMLP' Result on isolation split: 68.63

Training Cold Brew's MLP:

python main.py --dataset='Cora' --train_which="SEMLP" --SEMLP_topK_2_replace=3 --SEMLP_part1_arch="2layer" --dropout_MLP=0.5 --studentMLP__opt_lr='torch.optim.Adam&0.005' Result on isolation split: 69.57

Hyperparameter meanings

--whetherHasSE: whether cold brew's TeacherGNN has structural embedding. The first ‘1’ means structural embedding exist in first layer; second ‘1’ means structural embedding exist in every middle layers; third ‘1’ means last layer.

--se_reg: regularization coefficient for cold brew teacher model's structural embedding.

--SEMLP_topK_2_replace: the number of top K best virtual neighbor nodes.

--manual_assign_GPU: set the GPU ID to train on. default=-9999, which means to dynamically choose GPU with most remaining memory.

Adaptation Guide

How to leverage this repo to train on other datasets:

In trainer.py, put any new graph dataset (node classification) under load_data() and return it.

what to load: return a dataset, which is a namespace, called 'data', data.x: 2D tensor, on cpu; shape = [N_nodes, dim_feature]. data.y: 1D tensor, on cpu; shape = [N_nodes]; values are integers, indicating the class of nodes. data.edge_index: tensor: [2, N_edge], cpu; edges contain self loop. data.train_mask: bool tensor, shape = [N_nodes], indicating the training node set. Template class for the 'data':

class MyDataset(torch_geometric.data.data.Data):
    def __init__(self):
        super().__init__()

Citation

comming soon.
Dense Prediction Transformers

Vision Transformers for Dense Prediction This repository contains code and models for our paper: Vision Transformers for Dense Prediction René Ranftl,

Intelligent Systems Lab Org 1.3k Jan 02, 2023
Doubly Robust Off-Policy Evaluation for Ranking Policies under the Cascade Behavior Model

Doubly Robust Off-Policy Evaluation for Ranking Policies under the Cascade Behavior Model About This repository contains the code to replicate the syn

Haruka Kiyohara 12 Dec 07, 2022
State of the Art Neural Networks for Deep Learning

pyradox This python library helps you with implementing various state of the art neural networks in a totally customizable fashion using Tensorflow 2

Ritvik Rastogi 60 May 29, 2022
Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR

UniSpeech The family of UniSpeech: UniSpeech (ICML 2021): Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR UniSpeech-

Microsoft 282 Jan 09, 2023
Pre-trained NFNets with 99% of the accuracy of the official paper

NFNet Pytorch Implementation This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper High-Performance Large-Scale

Benjamin Schmidt 133 Dec 09, 2022
PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

StarEnhancer StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral) Abstract: Image enhancement is a subjective process w

IDKiro 133 Dec 28, 2022
B-cos Networks: Attention is All we Need for Interpretability

Convolutional Dynamic Alignment Networks for Interpretable Classifications M. Böhle, M. Fritz, B. Schiele. B-cos Networks: Alignment is All we Need fo

58 Dec 23, 2022
JASS: Japanese-specific Sequence to Sequence Pre-training for Neural Machine Translation

JASS: Japanese-specific Sequence to Sequence Pre-training for Neural Machine Translation This the repository for this paper. Find extensions of this w

Zhuoyuan Mao 14 Oct 26, 2022
PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020).

NHDRRNet-PyTorch This is the PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020). 0. Differences between Original Paper and

Yutong Zhang 1 Mar 01, 2022
🌊 Online machine learning in Python

In a nutshell River is a Python library for online machine learning. It is the result of a merger between creme and scikit-multiflow. River's ambition

OnlineML 4k Jan 02, 2023
Using Random Effects to Account for High-Cardinality Categorical Features and Repeated Measures in Deep Neural Networks

LMMNN Using Random Effects to Account for High-Cardinality Categorical Features and Repeated Measures in Deep Neural Networks This is the working dire

Giora Simchoni 10 Nov 02, 2022
🔀 Visual Room Rearrangement

AI2-THOR Rearrangement Challenge Welcome to the 2021 AI2-THOR Rearrangement Challenge hosted at the CVPR'21 Embodied-AI Workshop. The goal of this cha

AI2 55 Dec 22, 2022
Measure WWjj polarization fraction

WlWl Polarization Measure WWjj polarization fraction Paper: arXiv:2109.09924 Notice: This code can only be used for the inference process, if you want

4 Apr 10, 2022
DeepLab-ResNet rebuilt in TensorFlow

DeepLab-ResNet-TensorFlow This is an (re-)implementation of DeepLab-ResNet in TensorFlow for semantic image segmentation on the PASCAL VOC dataset. Fr

Vladimir 1.2k Nov 04, 2022
Implementation of the state-of-the-art vision transformers with tensorflow

ViT Tensorflow This repository contains the tensorflow implementation of the state-of-the-art vision transformers (a category of computer vision model

Mohammadmahdi NouriBorji 2 Mar 16, 2022
This repository introduces a short project about Transfer Learning for Classification of MRI Images.

Transfer Learning for MRI Images Classification This repository introduces a short project made during my stay at Neuromatch Summer School 2021. This

Oscar Guarnizo 3 Nov 15, 2022
[CVPR'21] Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild

IVOS-W Paper Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild Zhaoyun Yin, Jia Zheng, Weixin Luo, Shenhan Qian, Hanli

SVIP Lab 38 Dec 12, 2022
Where2Act: From Pixels to Actions for Articulated 3D Objects

Where2Act: From Pixels to Actions for Articulated 3D Objects The Proposed Where2Act Task. Given as input an articulated 3D object, we learn to propose

Kaichun Mo 69 Nov 28, 2022
TACTO: A Fast, Flexible and Open-source Simulator for High-Resolution Vision-based Tactile Sensors

TACTO: A Fast, Flexible and Open-source Simulator for High-Resolution Vision-based Tactile Sensors This package provides a simulator for vision-based

Facebook Research 255 Dec 27, 2022