
E(n)-Equivariant Transformer (wip)

Implementation of the E(n)-Equivariant Transformer, which extends the ideas of Satorras, Hoogeboom and Welling's E(n)-Equivariant Graph Neural Network to attention.
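As a rough illustration of the underlying idea, the NumPy sketch below implements one EGNN-style layer (edge messages, coordinate update and node update, as described in the EGNN paper cited in this README) with the neighbour sums replaced by softmax attention weights. The parameter names (W_e, w_att, w_x, W_h) and the way attention is wired in are assumptions made for illustration; this is not En-transformer's actual architecture. Coordinates enter only through relative vectors x_i - x_j and squared distances, which is what keeps the update equivariant.

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


def init_params(dim, hidden, rng):
    # Hypothetical random weights for the edge, attention, coordinate and node maps.
    return {
        "W_e": rng.normal(scale=0.1, size=(2 * dim + 1, hidden)),
        "w_att": rng.normal(scale=0.1, size=(hidden,)),
        "w_x": rng.normal(scale=0.1, size=(hidden,)),
        "W_h": rng.normal(scale=0.1, size=(dim + hidden, dim)),
    }


def egnn_attention_layer(h, x, p):
    """One sketch layer. h: (n, dim) invariant features, x: (n, 3) coordinates."""
    n = h.shape[0]
    rel = x[:, None, :] - x[None, :, :]              # (n, n, 3): x_i - x_j
    dist2 = (rel ** 2).sum(-1, keepdims=True)        # (n, n, 1): rotation invariant
    pair = np.concatenate(
        [np.repeat(h[:, None], n, 1),                # h_i
         np.repeat(h[None, :], n, 0),                # h_j
         dist2],
        axis=-1,
    )                                                # (n, n, 2 * dim + 1)
    m = silu(pair @ p["W_e"])                        # edge messages m_ij
    logits = m @ p["w_att"]                          # (n, n) attention logits
    logits -= logits.max(-1, keepdims=True)
    att = np.exp(logits)
    att /= att.sum(-1, keepdims=True)                # softmax over neighbours j
    coord_w = att * np.tanh(m @ p["w_x"])            # scalar weight per pair
    x_new = x + (coord_w[..., None] * rel).sum(1)    # equivariant coordinate update
    m_agg = (att[..., None] * m).sum(1)              # attention-weighted messages
    h_new = h + silu(np.concatenate([h, m_agg], -1) @ p["W_h"])
    return h_new, x_new
```

A quick sanity check: feeding `x @ Q.T + t` for an orthogonal `Q` and translation `t` should return the same `h_new` and a correspondingly transformed `x_new @ Q.T + t`.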

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,
    dim_head = 64,
    heads = 8,
    edge_dim = 4,
    fourier_features = 2
)

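feats = torch.randn(1, 16, 512)     # (batch, nodes, dim) node features
coors = torch.randn(1, 16, 3)       # (batch, nodes, 3) coordinates
edges = torch.randn(1, 16, 16, 4)   # (batch, nodes, nodes, edge_dim) pairwise edge features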

feats, coors = model(feats, coors, edges)  # (1, 16, 512), (1, 16, 3)
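The README does not document what fourier_features controls. A common choice for this kind of option, assumed here rather than taken from the library, is to expand each pairwise distance d into sin(d / 2^k) and cos(d / 2^k) for k = 0..K-1 and append d itself. The hypothetical fourier_encode_dist below sketches that in NumPy for coordinates shaped like coors in the example above.

```python
import numpy as np


def fourier_encode_dist(coors, num_encodings=2):
    """Hypothetical sin/cos encoding of pairwise distances (not the library's code).

    coors: (b, n, 3) -> returns (b, n, n, 2 * num_encodings + 1)
    """
    rel = coors[:, :, None, :] - coors[:, None, :, :]   # (b, n, n, 3)
    dist = np.sqrt((rel ** 2).sum(-1))[..., None]        # (b, n, n, 1)
    scales = 2.0 ** np.arange(num_encodings)             # 1, 2, 4, ...
    scaled = dist / scales                               # (b, n, n, K)
    # sin and cos at each scale, plus the raw distance
    return np.concatenate([np.sin(scaled), np.cos(scaled), dist], axis=-1)
```

Because the result depends only on distances, it is invariant to rotations, reflections and translations of the coordinates, so an encoding like this can be mixed into attention logits or edge features without breaking equivariance.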

Todo

  • masking
  • neighborhoods by radius
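Both Todo items amount to restricting which node pairs may attend to each other. Below is a minimal NumPy sketch using hypothetical helpers that are not part of en_transformer: padded nodes are excluded through a boolean node_mask (one way to batch graphs with different numbers of nodes, padded to a common n), pairs farther apart than radius are excluded, and fully masked rows yield zero weights instead of NaNs.

```python
import numpy as np


def radius_attention_mask(coors, radius, node_mask=None):
    """Hypothetical helper. coors: (b, n, 3); node_mask: (b, n) bool, True = real node.

    Returns a (b, n, n) bool mask, True where node i may attend to node j.
    """
    rel = coors[:, :, None, :] - coors[:, None, :, :]
    dist = np.sqrt((rel ** 2).sum(-1))               # (b, n, n)
    mask = dist <= radius                            # neighbourhood by radius
    if node_mask is not None:
        # both endpoints must be real (non-padded) nodes
        mask &= node_mask[:, :, None] & node_mask[:, None, :]
    return mask


def masked_softmax(logits, mask):
    """Softmax over the last axis that ignores masked entries."""
    logits = np.where(mask, logits, -1e9)
    logits = logits - logits.max(-1, keepdims=True)
    w = np.exp(logits) * mask                        # zero out masked pairs exactly
    denom = w.sum(-1, keepdims=True)
    # rows with no allowed neighbours (padding) stay all zeros
    return np.divide(w, denom, out=np.zeros_like(w), where=denom > 0)
```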

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Checkpoint sequential segments should equal number of layers instead of 1?

    https://github.com/lucidrains/En-transformer/blob/a37e635d93a322cafdaaf829397c601350b23e5b/en_transformer/en_transformer.py#L527

    Looking at the source code here: https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint_sequential

    opened by aced125 2
  • On rotary embeddings

    Hi @lucidrains, thank you for your amazing work; big fan! I had a quick question on the usage of this repository.

    Based on my understanding, rotary embeddings are a drop-in replacement for the original sinusoidal or learned positional encodings (PEs) in Transformers for sequential data, as in NLP or other temporal applications. If my application does not involve sequential data, is there a reason I should still use rotary embeddings?

    For example, for molecular datasets such as QM9 (from the E(n)-GNN paper), would it make sense to use rotary embeddings?

    opened by chaitjo 1
  • Is this line required?

    https://github.com/lucidrains/En-transformer/blob/7247e258fab953b2a8b5a73b8dfdfb72910711f8/en_transformer/en_transformer.py#L159

    Is this line required? Does line 157, two lines above, make this line redundant?

    opened by aced125 1
  • Performance drop with checkpointing update

    I see a drop in performance (higher loss) when I update checkpointing from checkpoint_sequential(self.layers, 1, inp) to checkpoint_sequential(self.layers, len(self.layers), inp). Is this expected?

    opened by heiidii 0
  • varying number of nodes

    @lucidrains Thank you for your efficient implementation. How can it be used on datasets where the number of nodes differs from graph to graph, for example small-molecule datasets?

    opened by mohaiminul2810 1
  • Edge model/rep

    Hi,

    Thank you for providing this version of the EnGNN model. This is not really an issue, just a query. The original model as implemented here (https://github.com/vgsatorras/egnn) has 3 main steps per layer:

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)

    I am interested in edge_feat and was wondering what the equivalent edge representation would be in your implementation. Line 335 in EnTransformer.py, qk = self.edge_mlp(qk), seems like the best candidate.

    Thanks,
    Pooja

    opened by heiidii 1
  • efficient implementation

    Hi, I wonder whether relative distances and coordinates can be handled more efficiently using memory-efficient attention, as in "Self-attention Does Not Need O(n^2) Memory". It is straightforward for the scalar part.

    opened by amrhamedp 2
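The first and fourth comments turn on what the segments argument of torch.utils.checkpoint.checkpoint_sequential does. The pure-Python sketch below mirrors the partitioning loop in the PyTorch source linked in the first comment (confirm it against your installed version): the layers are split into chunks of len(layers) // segments, every chunk except the last is wrapped in checkpoint, and the last chunk runs as an ordinary forward pass. Under that logic, segments=1 checkpoints nothing, while segments=len(self.layers) checkpoints every layer but the last. The function name checkpoint_plan is hypothetical.

```python
def checkpoint_plan(num_layers, segments):
    """Describe how checkpoint_sequential would partition num_layers layers.

    Returns a list of (kind, layer_indices) with kind "checkpointed" or "plain".
    Mirrors PyTorch's partitioning loop; does not import torch.
    """
    segment_size = num_layers // segments
    plan = []
    end = -1
    # every segment except the last is wrapped in checkpoint(...)
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        plan.append(("checkpointed", list(range(start, end + 1))))
    # the remaining layers run as a normal forward pass
    plan.append(("plain", list(range(end + 1, num_layers))))
    return plan
```

Recomputation during checkpointing is meant to reproduce the same forward values, so in principle the choice of segments should change memory use and speed rather than the loss.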
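On the last comment: one way the coordinate part could fit a memory-efficient scheme, assuming the coordinate weights are the softmax attention weights a_ij themselves (which may not match how En-transformer weights coordinates), is to note that sum_j a_ij (x_i - x_j) = x_i - sum_j a_ij x_j, because the weights sum to 1. The coordinates then behave like an extra value vector, and a distance penalty -gamma * ||x_i - x_j||^2 in the logits only needs the current key chunk. The NumPy sketch below processes keys in chunks with a running (online) softmax; full_attention is the O(n^2) reference it should match. gamma and both function names are hypothetical.

```python
import numpy as np


def full_attention(q, k, v, x, gamma):
    """Reference: materialises the full (n, n) attention matrix."""
    rel = x[:, None, :] - x[None, :, :]
    s = q @ k.T - gamma * (rel ** 2).sum(-1)
    a = np.exp(s - s.max(-1, keepdims=True))
    a /= a.sum(-1, keepdims=True)
    return a @ v, (a[..., None] * rel).sum(1)         # values, sum_j a_ij (x_i - x_j)


def chunked_attention(q, k, v, x, gamma, chunk=4):
    """Same result, iterating over key chunks with an online softmax."""
    n = q.shape[0]
    run_max = np.full((n, 1), -np.inf)                # running max of logits
    denom = np.zeros((n, 1))                          # running softmax denominator
    acc_v = np.zeros((n, v.shape[1]))                 # running sum of weights * v_j
    acc_x = np.zeros((n, x.shape[1]))                 # running sum of weights * x_j
    for start in range(0, n, chunk):
        kc = k[start:start + chunk]
        vc = v[start:start + chunk]
        xc = x[start:start + chunk]
        d2 = ((x[:, None, :] - xc[None, :, :]) ** 2).sum(-1)   # (n, c) only
        s = q @ kc.T - gamma * d2
        new_max = np.maximum(run_max, s.max(-1, keepdims=True))
        scale = np.exp(run_max - new_max)             # rescale earlier chunks
        p = np.exp(s - new_max)
        denom = denom * scale + p.sum(-1, keepdims=True)
        acc_v = acc_v * scale + p @ vc
        acc_x = acc_x * scale + p @ xc
        run_max = new_max
    # weights sum to 1, so sum_j a_ij (x_i - x_j) = x_i - sum_j a_ij x_j
    return acc_v / denom, x - acc_x / denom
```

Only the key dimension is chunked here, so peak memory for the logits is O(n * chunk) rather than O(n^2); chunking queries as well would reduce it further.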
Releases (1.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.