EGNN - Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch

Overview

EGNN - Pytorch

An implementation of E(n)-Equivariant Graph Neural Networks in PyTorch. It may eventually be used for an AlphaFold2 replication. The technique relies on simple invariant features and reportedly beat all previous methods (including SE(3)-Transformer and LieConv) in both accuracy and performance, with state-of-the-art results on dynamical system models, molecular activity prediction tasks, and more.

Install

$ pip install egnn-pytorch

Usage

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • training batch size

    Dear authors,

    Thanks for your great work! I saw your example, which is easy to understand. However, I notice that during training each iteration seems to support a batch size > 1 only when all graphs in the batch share the same adj_mat. Do you have a better solution for that? Thanks.

    opened by futianfan 6
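One common workaround for batches whose graphs differ in size and topology is to pad every adjacency matrix to the largest graph and carry a node mask alongside it. The sketch below is dependency-free and purely illustrative: `pad_graphs` is a hypothetical helper, not part of egnn-pytorch, and whether the layer accepts a per-graph adjacency of shape (batch, n, n) together with a mask is an assumption to verify against the installed version.

```python
def pad_graphs(adjs):
    """Pad square adjacency matrices (nested lists) to a common size.

    Returns (adj_batch, mask):
      adj_batch[b][i][j] is True only for real edges of graph b,
      mask[b][i] is True only for real (non-padding) nodes of graph b.
    """
    n_max = max(len(a) for a in adjs)
    adj_batch, mask = [], []
    for a in adjs:
        n = len(a)
        padded = [
            [bool(a[i][j]) if i < n and j < n else False for j in range(n_max)]
            for i in range(n_max)
        ]
        adj_batch.append(padded)
        mask.append([i < n for i in range(n_max)])
    return adj_batch, mask


# Two graphs with different topologies: a 3-node path and a single 2-node edge.
path3 = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
edge2 = [[0, 1], [1, 0]]
adj_batch, mask = pad_graphs([path3, edge2])
```

Converting `adj_batch` and `mask` to boolean tensors would give one adjacency per graph plus a mask marking the padded nodes, so padded entries can be excluded from message passing.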
  • Import Error when torch_geometric is not available

    https://github.com/lucidrains/egnn-pytorch/blob/e35510e1be94ee9f540bf2ffea49cd63578fe473/egnn_pytorch/egnn_pytorch.py#L413

    A small problem: Tensor is not defined at the line linked above.

    Thanks for your work.

    opened by zrt 4
  • About aggregations in EGNN_sparse

    Hi, thanks for your great work!

    I have a question about how the aggregations are computed for the node embedding and the coordinate embedding. In the paper, the aggregation for the node embedding is computed over a node's neighbors, while the aggregation for the coordinate embedding is computed over all other nodes. However, in EGNN_sparse, I didn't notice such a difference in the aggregations.

    I guess it is because computing all-pair messages for coordinate embedding makes 'sparse' meaningless, but I would like to double-check to see if I get this correctly. So anyway, did you do this intentionally? Or did I miss something?

    My appreciation.

    opened by simon1727 4
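To make the question concrete, the sketch below shows the "ij -> i" reduction that an edge-list layer performs: per-edge messages are summed into the node each edge points to, so only pairs present in the edge list contribute. It is a dependency-free illustration; `scatter_sum` is a hypothetical helper, not the function used inside EGNN_sparse.

```python
def scatter_sum(messages, index, n_nodes):
    """Sum per-edge messages into per-node totals (the 'ij -> i' reduction).

    messages[e] is the message carried by edge e,
    index[e] is the node that receives it.
    """
    dim = len(messages[0])
    out = [[0.0] * dim for _ in range(n_nodes)]
    for msg, i in zip(messages, index):
        for k, value in enumerate(msg):
            out[i][k] += value
    return out


# Four edges pointing at nodes 0, 0, 2 and 1 of a 3-node graph.
edge_targets = [0, 0, 2, 1]
messages = [[1.0, 2.0], [0.5, 0.5], [3.0, 0.0], [0.0, 1.0]]
agg = scatter_sum(messages, edge_targets, n_nodes=3)
```

An all-pairs coordinate aggregation would correspond to calling the same reduction with an edge list containing every ordered pair (i, j) with i != j, which is what makes it expensive on large sparse graphs.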
  • Few queries on the implementation

    Hi - fast work coding these things up, as usual! Looking at the paper and your code, you're not using squared distance for the edge weighting. Is that intentional? Also, it looks like you are adding the old feature vectors to the new ones rather than taking the new vectors directly from the fully connected net - is that also an intentional change from the paper?

    opened by denjots 3
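For comparison with the code, the layer in the cited paper (Satorras et al., 2021) can be summarized as below. This is transcribed from the paper rather than derived from this repository, so details such as the normalization constant C and the optional restriction of the sums to a neighborhood N(i) should be checked against the arXiv version.

```latex
\begin{aligned}
m_{ij} &= \phi_e\left(h_i^{l},\, h_j^{l},\, \lVert x_i^{l} - x_j^{l} \rVert^{2},\, a_{ij}\right) \\
x_i^{l+1} &= x_i^{l} + C \sum_{j \neq i} \left(x_i^{l} - x_j^{l}\right) \phi_x\left(m_{ij}\right) \\
m_i &= \sum_{j \neq i} m_{ij} \\
h_i^{l+1} &= \phi_h\left(h_i^{l},\, m_i\right)
\end{aligned}
```

As written, the edge function receives the squared distance and the node update has no explicit residual term, which are the two points the question raises.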
  • Fix PyG problems. Add example for point cloud denoising

    • Fixed some tiny errors in data flows for the PyG layers (dimensions and slices mainly)
    • fixed the EGNN_Sparse_Network so now it works
    • provides example for point cloud denoising (from gaussian masked coordinates), and showcases potential issues:
      • unstable (could be due to nature of data, not sure, but gvp does well on it)
      • not able to beat baseline (in contrast, gvp gets to 0.8 RMSD while this gets to the baseline 1 RMSD but not below it)
    opened by hypnopump 2
  • EGNN_sparse incorrect positional encoding output

    Hi, many thanks for the implementation!

    I was quickly checking the code for the pytorch geometric implementation of the EGNN_sparse layer, and I noticed that it expects the first 3 columns in the features to be the coordinates. However, in the update method, features and coordinates are passed in the wrong order.

    https://github.com/lucidrains/egnn-pytorch/blob/375d686c749a685886874baba8c9e0752db5f5be/egnn_pytorch/egnn_pytorch.py#L192

    This may cause problems during learning (think of concatenating several of these layers), as they expect coordinate and feature order to be consistent.

    One can reproduce this behaviour in the following snippet:

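    import torch
    # Import paths below are assumed from the package layout at the time of the issue
    from egnn_pytorch import EGNN_sparse
    from egnn_pytorch.utils import rot

    layer = EGNN_sparse(feats_dim=1, pos_dim=3, m_dim=16, fourier_features=0)

    # Random rotation R and translation T, applied to the coordinates only
    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)

    feats = torch.randn(16, 1)
    coors = torch.randn(16, 3)
    # The layer expects the first 3 columns of x to hold the coordinates
    x1 = torch.cat([coors, feats], dim=-1)
    x2 = torch.cat([(coors @ R + T).squeeze(), feats], dim=-1)
    edge_idxs = (torch.rand(2, 20) * 16).long()

    out1 = layer(x=x1, edge_index=edge_idxs)
    out2 = layer(x=x2, edge_index=edge_idxs)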
    

    After fixing the order of these arguments in the update method, the layer behaves as expected: the output features are invariant, and the output coordinates are equivariant, under SE(3) transformations.

    opened by josejimenezluna 2
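The property being tested in this report can be illustrated without the library. The sketch below uses a toy coordinate update whose weight depends only on squared distances and checks that "transform, then update" matches "update, then transform". The weight function is a hypothetical stand-in for the learned MLP, and the rotation is restricted to the z axis for brevity.

```python
import math


def coord_update(coors):
    """Toy EGNN-style update: x_i + mean_j (x_i - x_j) * w(d_ij^2)."""
    n = len(coors)
    out = []
    for i, xi in enumerate(coors):
        acc = [0.0, 0.0, 0.0]
        for j, xj in enumerate(coors):
            if i == j:
                continue
            delta = [a - b for a, b in zip(xi, xj)]
            d2 = sum(d * d for d in delta)
            w = 1.0 / (1.0 + d2)  # hypothetical stand-in for the learned weight
            acc = [s + w * d for s, d in zip(acc, delta)]
        out.append([x + s / (n - 1) for x, s in zip(xi, acc)])
    return out


def rotate_z_and_shift(coors, angle, shift):
    """Rotate every point about the z axis, then translate it."""
    c, s = math.cos(angle), math.sin(angle)
    return [
        [c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2]]
        for x, y, z in coors
    ]


coors = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.5, 0.5, 1.5]]
angle, shift = 0.7, [1.0, -2.0, 0.5]

a = coord_update(rotate_z_and_shift(coors, angle, shift))  # transform, then update
b = rotate_z_and_shift(coord_update(coors), angle, shift)  # update, then transform
max_err = max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))
```

For an equivariant update `max_err` stays at floating-point noise level. Swapping the coordinate and feature columns, as described in the report, would break this agreement.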
  • NaN values after stacking multiple layers

    Hi Lucid!!

    I find that when stacking multiple layers, the output of the model rapidly goes to NaN. I suspect it may be related to the weights used for initialization.

    Here is a minimal working example:

    Make some data:

        import numpy as np
        import torch
        from egnn_pytorch import EGNN
        
        torch.set_default_dtype(torch.double)
    
        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:]
        feat = torch.randint(0, 20, (1, geom.shape[1],1))
    

    Make a model:

        class ResEGNN(torch.nn.Module):
            def __init__(self, depth = 2, dims_in = 1):
                super().__init__()
                self.layers = torch.nn.ModuleList([EGNN(dim = dims_in) for i in range(depth)])
            
            def forward(self, geom, feat):
                for layer in self.layers:
                    feat, geom = layer(feat, geom)
                return geom
    

    Run model for varying depths:

        for i in range(10):
            model = ResEGNN(depth = i)
            pred = model(geom, feat)
            mean_absolute_value  = torch.abs(pred).mean()
            print("Order of predictions {:.2f}".format(np.log(mean_absolute_value.detach().numpy())))
    

    Output:

        Order of predictions -0.29
        Order of predictions 0.05
        Order of predictions 6.65
        Order of predictions 21.38
        Order of predictions 78.25
        Order of predictions 302.71
        Order of predictions 277.38
        Order of predictions nan
        Order of predictions nan
        Order of predictions nan

    opened by brennanaba 2
  • Edge features thrown out

    Hi, thanks for this implementation!

    I was wondering if the pytorch-geometric implementation of this architecture is throwing the edge features out by mistake, as seen here

    https://github.com/lucidrains/egnn-pytorch/blob/1b8320ade1a89748e4042ae448626652f1c659a1/egnn_pytorch/egnn_pytorch.py#L148-L151

    Or maybe my understanding is wrong? Cheers,

    opened by josejimenezluna 2
  • solve ij -> i bottleneck in sparse version

    I don't recommend normalizing the weights nor the coords.

    • The weights are the coefficients that multiply the deltas in the i->j direction.
    • The coords are the deltas in the i->j direction.

    I can't see the advantage of normalizing them beyond a naive stabilization, which might hurt convergence by requiring more layers, since each layer would be limited in the transformation it can apply.

    It works fine for denoising without normalization (the instability might come from huge outliers, but then tuning the learning rate or clipping the gradients might help).

    opened by hypnopump 0
  • Questions about the EGNN code

    Recently, I've been reading the EGNN paper and studying your EGNN code. I had a hard time understanding both the paper and the code because my major is not computer science. While studying your code, I realized that the shape of hidden_out and the shape of kwargs["x"] must be the same for the addition (the residual connection) in the forward method of the EGNN_sparse class. How can I increase or decrease the hidden dimension size of x?

    I would like to get some advice.

    Thanks for your consideration in this regard.

    opened by Byun-jinyoung 0
  • Wrong edge_index size hint in class EGNN_Sparse of PyG version

    Hi, I found what may be a small mistake. In the input hint of class EGNN_Sparse in the PyG version, the size of edge_index is given as (n_edges, 2). However, it should be (2, n_edges); otherwise, the distance calculation will not be correct.

        """ Inputs:
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (n_edges, 2)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies cloud belonging for each point
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
            * size: None
        """

    opened by Layne-Huang 2
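The layout matters because the first row of edge_index is read as one endpoint of every edge and the second row as the other. The dependency-free sketch below converts a list of (src, dst) pairs, i.e. an (n_edges, 2) layout, into the (2, n_edges) layout and computes per-edge squared distances from it. Both helpers are hypothetical illustrations, not functions from egnn-pytorch or PyTorch Geometric.

```python
def to_coo(edge_pairs):
    """Convert [(src, dst), ...] with shape (n_edges, 2) to shape (2, n_edges)."""
    srcs = [s for s, _ in edge_pairs]
    dsts = [d for _, d in edge_pairs]
    return [srcs, dsts]


def edge_sq_dists(coors, edge_index):
    """Squared distance between the two endpoints of every edge."""
    src, dst = edge_index  # row 0: one endpoint, row 1: the other
    return [
        sum((a - b) ** 2 for a, b in zip(coors[i], coors[j]))
        for i, j in zip(src, dst)
    ]


coors = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
pairs = [(0, 1), (1, 2), (2, 0)]

edge_index = to_coo(pairs)
dists = edge_sq_dists(coors, edge_index)
```

With tensors, the same conversion is a transpose of the (n_edges, 2) tensor. Passing the untransposed layout would pair up the wrong indices, or fail outright once n_edges differs from 2.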
  • Exploding Gradients With 4 Layers

    I'm using EGNN with 4 layers (where I also do global attention after each layer), and I'm seeing exploding gradients after 90 epochs or so. I'm using techniques discussed earlier (sparse attention matrix, coor_weights_clamp_value, norm_coors), but I'm not sure if there's anything else I should be doing. I'm also not updating the coordinates, so the fix in the pull request doesn't apply.

    opened by cutecows 0
  • Added optional tanh to coors_mlp

    This removes the NaN bug completely (must also use norm_coors otherwise performance dies)

    The NaN bug comes from the coors_mlp exploding, so forcing values between -1 and 1 prevents this. If coordinates are normalised then performance should not be adversely affected.

    opened by jscant 1
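The reasoning behind combining a tanh with normalized coordinates can be sketched as follows: if every relative vector has unit length and every weight lies in [-1, 1], the averaged update can move a point by at most 1 per layer, however large the raw network output is. The code is a dependency-free toy, and `bounded_step` and `huge` are hypothetical names rather than the library's coors_mlp.

```python
import math


def bounded_step(coors, raw_weight):
    """One coordinate update with unit-length deltas and tanh-bounded weights."""
    n = len(coors)
    out = []
    for i, xi in enumerate(coors):
        acc = [0.0] * len(xi)
        for j, xj in enumerate(coors):
            if i == j:
                continue
            delta = [a - b for a, b in zip(xi, xj)]
            norm = math.sqrt(sum(d * d for d in delta)) or 1.0  # guard against zero length
            w = math.tanh(raw_weight(norm))  # squashed into [-1, 1]
            acc = [s + w * d / norm for s, d in zip(acc, delta)]
        out.append([x + s / (n - 1) for x, s in zip(xi, acc)])
    return out


def max_displacement(before, after):
    """Largest Euclidean distance any point moved."""
    return max(
        math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))
        for a, b in zip(before, after)
    )


def huge(distance):
    """Hypothetical exploding network output."""
    return 1e6 * distance


coors = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.5, 0.5, 1.5]]
step = max_displacement(coors, bounded_step(coors, huge))
```

Without the normalization the deltas themselves are unbounded, which is consistent with the note that the tanh needs norm_coors to be effective.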
Releases(0.2.6)
Owner
Phil Wang
Working with Attention. It's all we need.