Overview

SE3 Transformer - Pytorch

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 results and other drug discovery applications.

Install

$ pip install se3-transformer-pytorch

Usage

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 512,
    heads = 8,
    depth = 6,
    dim_head = 64,
    num_degrees = 4,
    valid_radius = 10
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
mask  = torch.ones(1, 1024).bool()

out = model(feats, coors, mask) # (1, 1024, 512)

Potential example usage in Alphafold2, as outlined here

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atom_feats = torch.randn(2, 32, 64)
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atom_feats, coors, mask, return_type = 1) # (2, 32, 3)

You can also let the base transformer class take care of embedding the type 0 features being passed in, assuming they are atoms:

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,       # 28 unique atoms
    dim = 64,
    depth = 2,
    input_degrees = 1,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(atoms, coors, mask, return_type = 1) # (2, 32, 3)

If you think the net could further benefit from positional encoding, you can featurize your positions in space and pass them in as follows.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    dim = 64,
    depth = 2,
    input_degrees = 2,
    num_degrees = 2,
    output_degrees = 2,
    reduce_dim_out = True  # reduce out the final dimension
)

atom_feats  = torch.randn(2, 32, 64, 1) # b x n x d x type0
coors_feats = torch.randn(2, 32, 64, 3) # b x n x d x type1

# atom features are type 0, predicted coordinates are type 1
features = {'0': atom_feats, '1': coors_feats}
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

refinement = model(features, coors, mask, return_type = 1) # (2, 32, 3) - equivariant to input type 1 features and coordinates

Edges

To offer edge information to SE3 Transformers (say bond types between atoms), you just have to pass in two more keyword arguments on initialization.

import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(
    num_tokens = 28,
    dim = 64,
    num_edge_tokens = 4,       # number of edge types, say 4 bond types
    edge_dim = 16,             # dimension of edge embedding
    depth = 2,
    input_degrees = 1,
    num_degrees = 3,
    output_degrees = 1,
    reduce_dim_out = True
)

atoms = torch.randint(0, 28, (2, 32))
bonds = torch.randint(0, 4, (2, 32, 32))
coors = torch.randn(2, 32, 3)
mask  = torch.ones(2, 32).bool()

pred = model(atoms, coors, mask, edges = bonds, return_type = 0) # (2, 32, 1)

Caching

By default, the basis vectors are cached. However, if you ever need to clear the cache, simply set the environment variable CLEAR_CACHE to some value when invoking the script.

$ CLEAR_CACHE=1 python train.py

Or you can try deleting the cache directory, which should exist at

$ rm -rf ~/.cache.equivariant_attention

Testing

$ python setup.py pytest

Credit

This library is largely a port of Fabian's official repository, but without the DGL library.

Citations

@misc{fuchs2020se3transformers,
    title   = {SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks}, 
    author  = {Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling},
    year    = {2020},
    eprint  = {2006.10503},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Breaking equivariance

    Hi, Thanks a lot for your work!!

    I was running some equivariance tests on your implementation of the SE3-Transformer and found that equivariance is not always preserved. It does not break every time, and unfortunately I do not know where the bug is.

    I have appended an image with an example of equivariance not being preserved.

    To generate this image I used the following code:

        import numpy as np
        import matplotlib.pyplot as plt
        import torch
        
        from se3_transformer_pytorch import SE3Transformer
    

    Make some data

        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:].float()
        feat = torch.randint(0, 20, (1, geom.shape[1],1)).float()
        
        def rot_matrix(x):
            # Rotation matrix
            a ,b ,c = 2*np.pi*x
            return np.array([
                [np.cos(a)*np.cos(b), np.cos(a)*np.sin(b)*np.sin(c)- np.sin(a)*np.cos(c), np.cos(a)*np.sin(b)*np.cos(c)+ np.sin(a)*np.sin(c)],
                [np.sin(a)*np.cos(b), np.sin(a)*np.sin(b)*np.sin(c)+ np.cos(a)*np.cos(c), np.sin(a)*np.sin(b)*np.cos(c)- np.cos(a)*np.sin(c)],
                [-np.sin(b)         , np.cos(b)*np.sin(c)                               , np.cos(b)*np.cos(c)                               ]
            ])
    

    Initialize model

        mdl = SE3Transformer(
            dim = 1,
            depth = 3,
            input_degrees = 1,
            num_degrees = 2,
            output_degrees = 2,
            reduce_dim_out = True,
        )
        
        def model(geom,feat):
            return geom + mdl(feat,geom, return_type = 1)
    

    Check Rotation Equivariance:

        with torch.no_grad():
            
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            prerotated = model(geom @ Q, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) @ Q).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Rotated", "Post-Rotated"])
            plt.show()
    

    Check Translation Equivariance:

        with torch.no_grad():
            
            x0 = 1*torch.rand(3)
            prerotated = model(geom + x0, feat).squeeze().detach().numpy().transpose()
            posrotated = (model(geom, feat) + x0).squeeze().detach().numpy().transpose()
        
            fig = plt.figure(dpi = 200)
            ax = plt.axes(projection="3d")
        
            ax.plot3D(prerotated[0], prerotated[1], prerotated[2], "r", linewidth=1.1)
            ax.plot3D(posrotated[0], posrotated[1], posrotated[2], "b", linewidth=0.5)
        
            plt.legend(["Pre-Translated", "Post-Translated"])
            plt.show()
    

    Hope this helps!!
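
    As a complementary check, the same comparison can be done numerically rather than visually (a sketch reusing the mdl, model, geom, feat and rot_matrix defined above; torch.allclose should return True if the output is equivariant):

        with torch.no_grad():
            Q = torch.tensor(rot_matrix(np.random.random(3))).float()
            out_rotate_first = model(geom @ Q, feat)      # rotate the input, then run the model
            out_rotate_last  = model(geom, feat) @ Q      # run the model, then rotate the output
            print(torch.allclose(out_rotate_first, out_rotate_last, atol = 1e-5))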

    opened by brennanaba 18
  • How to populate input variable length data

    Hello, thank you for your work. I am using your SE3-Transformer implementation as part of my project, but I have encountered some problems when feeding variable-length data into your model. I do not know how to pad the features, coordinates and masks to meet the model's requirements.

    The processing of the input data in my Dataset is as follows: I pad the coordinates and features with zeros so that their input dimensions are (246, 3) and (246, 20) respectively, and fill the mask with True/False booleans. In this case there are 49 valid entries, so the valid features and coordinates have shapes (49, 20) and (49, 3) and the rest is padded with zeros. The first 49 entries of the mask are True and the rest are False.

    I then checked via torch.utils.data.DataLoader that the input dimensions are correct. My network is a typical binary classifier in which I use the output of the SE3 layer:

    However, the output of the SE3 layer contains NaNs: the first 48 positions have results and the rest are NaN, and it is not clear what the problem is.

    The next batch may then be all NaN due to the model weight updates, and the NaN output of the SE3 layer causes all of my losses to become NaN.

    I think these NaNs are probably caused by my incorrect padding. I have also tried padding with ones, but that was also ineffective. Maybe the mask I pass in cannot contain False values; I am confused about that. Could you please tell me how to correctly pad the coordinates, features and masks, or give me some advice on using your SE3 model to handle variable-length data? Thank you!
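
    For reference, a minimal padding sketch of the setup described above (an assumption, not a confirmed recipe; mask semantics follow the README, where True marks valid positions):

    import torch

    max_len, feat_dim, n_valid = 246, 20, 49               # sizes taken from the description above

    feats = torch.zeros(1, max_len, feat_dim)
    coors = torch.zeros(1, max_len, 3)
    mask  = torch.zeros(1, max_len, dtype = torch.bool)

    feats[0, :n_valid] = torch.randn(n_valid, feat_dim)    # real features go in the first n_valid slots
    coors[0, :n_valid] = torch.randn(n_valid, 3)           # real coordinates
    mask[0, :n_valid]  = True                              # only the first n_valid positions are valid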

    opened by zyk19981118 8
  • CPU/CUDA masking error

    Hi - nice work - I was just testing out your code and ran into the following error, but only in the backward pass:

    RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mask' in call to th_masked_scatter_bool

    When using the nightly Pytorch, the error message is:

    RuntimeError: Tensor for argument #2 'mask' is on CPU, but expected it to be on GPU (while checking arguments for masked_scatter_)

    I'm pretty sure I don't have any tensors in CPU memory, but not sure if this is a bug in your SE3 code or a Pytorch issue. My gut feeling is this is a Pytorch/autograd issue, but I just don't know these particular Pytorch ops well enough to be sure. Tried both 1.7.1 release Pytorch and the latest nightly. Seems like there has been recent work on masked_scatter according to Pytorch issues.
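
    For reference, a minimal sketch of the kind of setup described (an assumption, since the report above does not include code; the model and every tensor are explicitly moved to the GPU):

    import torch
    from se3_transformer_pytorch import SE3Transformer

    model = SE3Transformer(dim = 64, heads = 2, depth = 1, num_degrees = 2).cuda()

    feats = torch.randn(1, 32, 64).cuda()
    coors = torch.randn(1, 32, 3).cuda()
    mask  = torch.ones(1, 32).bool().cuda()

    out = model(feats, coors, mask)
    out.sum().backward()   # the reported error only appears in the backward pass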

    opened by denjots 8
  • small bug

    https://github.com/lucidrains/se3-transformer-pytorch/blob/7c79998e4d84ec6bd6b6d4b916c6bf30b870b75b/se3_transformer_pytorch/se3_transformer_pytorch.py#L301

    It should be if isinstance(m, nn.Linear).

    nbd, but thought you might wanna know.

    opened by MattMcPartlon 5
  • INTERNAL ASSERT FAILED

    Hi - I'm not sure if this is actually a bug or if I'm just expecting too much, but the following code snippet bombs with a PyTorch internal assert failure:

    RuntimeError: sub_iter.strides(0)[0] == 0INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1618643394934/work/aten/src/ATen/native/cuda/Reduce.cuh":929, please report a bug to PyTorch.

    That's using the latest PyTorch nightly. Tried with both CUDA 10.2 and 11.1.

    Looking at the PyTorch bug reports, this seems to suggest that there is massive internal memory allocation being triggered when summing over a large tensor. It seems to have been sitting open for a year and I'm not sure if they even see it as an actual bug there. Certainly it doesn't seem to be a priority.

    The problem with this is that this isn't a particularly large data input, or a large model (to say the least!), and my GPU has 40 GB of RAM, so the scalability of this particular transformer model seems very limited as it currently stands. Changing the dim from 128 to 64 does at least allow it to run.

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(dim = 128, heads = 1, depth = 1, dim_head = 1, num_degrees = 1, input_degrees=1, output_degrees=1).cuda()
    
    feats = torch.randn(1, 200, 128).cuda()
    coords = torch.randn(1, 200, 3).cuda()
    
    out = model(feats, coords)
    
    
    opened by denjots 4
  • question about non scalar output

    Hello,

    I am interested in using your implementation on my dataset. There are two things I want to check with you

    1. The property I want to predict is a 3x3 symmetric PSD matrix.
    2. There are some edge features (one categorical feature, one continuous feature) besides the coordinate difference

    I was wondering whether the current se3-transformer can work with such a scenario? Thanks!
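
    On point 1, one common parameterization that is independent of this library (just a sketch): predict the six free entries of a lower-triangular 3x3 matrix as type-0 outputs and form L Lᵀ, which is symmetric PSD by construction.

    import torch

    # hypothetical: suppose the network emits 6 scalars per example (type-0 outputs)
    raw = torch.randn(2, 6)

    def to_psd(raw):
        L = torch.zeros(raw.shape[0], 3, 3)
        tril = torch.tril_indices(3, 3)
        L[:, tril[0], tril[1]] = raw        # fill the lower triangle
        return L @ L.transpose(-1, -2)      # symmetric PSD by construction

    psd = to_psd(raw)                       # (2, 3, 3)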

    opened by Chen-Cai-OSU 3
  • Reversible flag odd results

    Sorry if it's expected behaviour again, but with a slight tweak of your new example code I get unexpected results when your new reversible option is used to return type 1 data. Here's the code I'm running...

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    model = SE3Transformer(
        num_tokens = 20,
        dim = 32,
        dim_head = 32,
        heads = 4,
        depth = 12,             # 12 layers
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True,
        reversible = True       # set reversible to True
    ).cuda()
    
    atoms = torch.randint(0, 4, (1, 50)).cuda()
    coors = torch.randn(1, 50, 3).cuda()
    mask  = torch.ones(1, 50).bool().cuda()
    
    pred = model(atoms, coors, mask = mask, return_type = 1)
    
    loss = pred.mean()
    print(loss)
    loss.backward()
    
    

    Without the reversible flag, the loss is close to zero, as might be expected since the input coords are centered on the origin. However, when reversible is set, a very high value is produced, which is going to blow up training if this were for real. Is this an expected side effect of reversible nets? Is some kind of normalization essential in that case? Or am I just doing something wrong here?

    opened by denjots 3
  • faster loop

    My small grain of sand for this project ;) : at least avoid Python appends, which are about 2x slower than a list comprehension (see the quick timing sketch after the list below).

    If this is of any help, here are some considerations (there might be misunderstandings on my side due to the decorators and so on):

    • I have my reservations about the utility of line 62 in utils.py.
    • Wouldn't it make more sense to run the for loop over i (line 122 of utils.py) in reverse order, and then use the cached calculations for lpmv()?
    • Same case (loop in reverse order) for the for loop in line 148 of basis.py?
    • If scipy.special.poch (which can deal with np arrays) is used instead of the custom Pochhammer implementation, all operations inside get_spherical_harmonics_element are vectorizable except the lpmv function call.
      • My sense is that lpmv, get_spherical_harmonics_element and get_spherical_harmonics could all be wrapped in a single function (lower reusability / extensibility... so maybe the reverse loop order and caching is enough).
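
    A quick timing sketch for the append vs. list-comprehension point above (illustrative only, not taken from the repo):

    import timeit

    append_time = timeit.timeit(
        "out = []\nfor i in range(1000):\n    out.append(i * i)", number = 10000)
    comp_time = timeit.timeit(
        "out = [i * i for i in range(1000)]", number = 10000)

    print(f"append: {append_time:.3f}s  comprehension: {comp_time:.3f}s")
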
    opened by hypnopump 1
  • Question about continuous edge features

    Thanks for all of the work!

    I am working on a simple proof of concept with your model. Ideally, I would like to perform multidimensional scaling, i.e. given a distance matrix, recover the corresponding coordinates.

    I was wondering if distance information could be passed as an edge feature (continuous rather than categorical information). Is it possible to do this with the current implementation?

    Thanks again and I appreciate the help!
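
    One workaround that stays within the documented categorical-edge API (a sketch and an assumption, not necessarily the intended way to handle continuous edge features) is to discretize the distances into bins and pass the bin indices as edge tokens:

    import torch
    from se3_transformer_pytorch import SE3Transformer

    n_bins = 16

    model = SE3Transformer(
        num_tokens = 28,
        dim = 64,
        num_edge_tokens = n_bins,   # one token per distance bin
        edge_dim = 16,
        depth = 2,
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 1,
        reduce_dim_out = True
    )

    atoms = torch.randint(0, 28, (2, 32))
    coors = torch.randn(2, 32, 3)
    mask  = torch.ones(2, 32).bool()

    # bucket the pairwise distances into n_bins categorical edge tokens
    dist  = torch.cdist(coors, coors)                                              # (2, 32, 32)
    edges = torch.bucketize(dist, torch.linspace(0., dist.max().item(), n_bins - 1))

    pred = model(atoms, coors, mask, edges = edges, return_type = 0)               # (2, 32, 1)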

    opened by MattMcPartlon 1
  • denoise.py bugfix

    Fixes issue related to the constructor

    Traceback (most recent call last):
      File "/home/jcastellanos/projects/se3-transformer-pytorch/denoise.py", line 22, in <module>
        transformer = SE3Transformer(
      File "/home/jcastellanos/projects/se3-transformer-pytorch/se3_transformer_pytorch/se3_transformer_pytorch.py", line 1072, in __init__
        self.num_degrees = num_degrees if exists(num_degrees) else (max(hidden_fiber_dict.keys()) + 1)
    AttributeError: 'NoneType' object has no attribute 'keys'
    
    opened by javierbq 0
  • CUDA out of memory

    Thanks for your great job!

    The se3-transformer is powerful, but it seems to be very memory intensive.

    I built a model with the following parameters and got a "CUDA out of memory" error when running it on the GPU (Nvidia V100 / 32 GB).

    model = SE3Transformer(
        dim = 20,
        heads = 4,
        depth = 2,
        dim_head = 5,
        num_degrees = 2,
        valid_radius = 5
    )

    num_points = 512
    feats = torch.randn(1, num_points, 20)
    coors = torch.randn(1, num_points, 3)
    mask = torch.ones(1, num_points).bool()
    

    Does this error relate to the version of PyTorch, and how can I fix it?

    opened by PengCheng-NUDT 0
  • SE3Transformer constructor hangs

    I am trying to run an example from the README. The code is:

    import torch
    from se3_transformer_pytorch import SE3Transformer
    
    print('Initialising model...')
    model = SE3Transformer(
        dim = 512,
        heads = 8,
        depth = 6,
        dim_head = 64,
        num_degrees = 4,
        valid_radius = 10
    )
    
    print('Running model...')
    feats = torch.randn(1, 1024, 512)
    coors = torch.randn(1, 1024, 3)
    mask  = torch.ones(1, 1024).bool()
    
    out = model(feats, coors, mask) # (1, 1024, 512)
    

    The output hangs on 'Initialising model...' and eventually the kernel dies.

    Any ideas why this would be happening?

    Here is my pip freeze:

    anyio==3.2.1
    argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613036642480/work
    astunparse==1.6.3
    async-generator==1.10
    attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
    axial-positional-embedding==0.2.1
    Babel==2.9.1
    backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
    biopython==1.79
    bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
    cached-property @ file:///tmp/build/80754af9/cached-property_1600785575025/work
    certifi==2021.5.30
    cffi @ file:///tmp/build/80754af9/cffi_1613246939562/work
    chardet==4.0.0
    click==8.0.1
    configparser==5.0.2
    decorator==4.4.2
    defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
    dgl-cu101==0.4.3.post2
    dgl-cu110==0.6.1
    docker-pycreds==0.4.0
    egnn-pytorch==0.2.6
    einops==0.3.0
    En-transformer==0.3.8
    entrypoints==0.3
    equivariant-attention @ file:///workspace/projects/se3-transformer-public
    filelock==3.0.12
    gitdb==4.0.7
    GitPython==3.1.18
    graph-transformer-pytorch==0.0.1
    h5py @ file:///tmp/build/80754af9/h5py_1622088444809/work
    huggingface-hub==0.0.12
    idna==2.10
    importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617877314848/work
    ipykernel @ file:///tmp/build/80754af9/ipykernel_1596206598566/work/dist/ipykernel-5.3.4-py3-none-any.whl
    ipython @ file:///tmp/build/80754af9/ipython_1617118429768/work
    ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
    ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
    jedi==0.17.0
    Jinja2 @ file:///tmp/build/80754af9/jinja2_1621238361758/work
    joblib==1.0.1
    json5==0.9.6
    jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
    jupyter==1.0.0
    jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
    jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
    jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213308260/work
    jupyter-server==1.8.0
    jupyter-tensorboard==0.2.0
    jupyterlab==3.0.16
    jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
    jupyterlab-server==2.6.0
    jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
    jupytext==1.11.3
    lie-learn @ git+https://github.com/AMLab-Amsterdam/[email protected]
    llvmlite==0.36.0
    local-attention==1.4.1
    markdown-it-py==1.1.0
    MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
    mdit-py-plugins==0.2.8
    mdtraj==1.9.6
    mistune @ file:///tmp/build/80754af9/mistune_1594373098390/work
    mkl-fft==1.3.0
    mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
    mkl-service==2.3.0
    mp-nerf==0.1.11
    nbclassic==0.3.1
    nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
    nbconvert @ file:///tmp/build/80754af9/nbconvert_1601914821128/work
    nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
    nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
    networkx==2.5.1
    notebook @ file:///tmp/build/80754af9/notebook_1621523661196/work
    numba==0.53.1
    numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
    packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
    pandas==1.2.4
    pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120451932/work
    parso @ file:///tmp/build/80754af9/parso_1617223946239/work
    pathtools==0.1.2
    performer-pytorch==1.0.11
    pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
    pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
    ProDy==2.0
    prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
    promise==2.3
    prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
    protobuf==3.17.3
    psutil==5.8.0
    ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
    py3Dmol==0.9.1
    pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
    Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
    pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
    pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141707582/work
    python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
    pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
    PyYAML==5.4.1
    pyzmq==20.0.0
    qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
    QtPy==1.9.0
    regex==2021.4.4
    requests==2.25.1
    sacremoses==0.0.45
    scipy @ file:///tmp/build/80754af9/scipy_1618852618548/work
    se3-transformer-pytorch==0.8.10
    Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
    sentry-sdk==1.1.0
    shortuuid==1.0.1
    sidechainnet==0.6.0
    six @ file:///tmp/build/80754af9/six_1623709665295/work
    smmap==4.0.0
    sniffio==1.2.0
    subprocess32==3.5.4
    terminado==0.9.4
    testpath @ file:///home/ktietz/src/ci/testpath_1611930608132/work
    tokenizers==0.10.3
    toml==0.10.2
    torch==1.9.0
    tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
    tqdm==4.61.1
    traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
    transformers==4.8.0
    typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
    urllib3==1.26.5
    wandb==0.10.32
    wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
    webencodings==0.5.1
    websocket-client==1.1.0
    widgetsnbextension==3.5.1
    zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
    

    Here is a summary of my system info (lshw -short):

    H/W path    Device  Class      Description
    ==========================================
                        system     Computer
    /0                  bus        Motherboard
    /0/0                memory     59GiB System memory
    /0/1                processor  Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
    /0/100              bridge     440FX - 82441FX PMC [Natoma]
    /0/100/1            bridge     82371SB PIIX3 ISA [Natoma/Triton II]
    /0/100/1.1          storage    82371SB PIIX3 IDE [Natoma/Triton II]
    /0/100/1.3          bridge     82371AB/EB/MB PIIX4 ACPI
    /0/100/2            display    GD 5446
    /0/100/3            network    Elastic Network Adapter (ENA)
    /0/100/1e           display    GK210GL [Tesla K80]
    /0/100/1f           generic    Xen Platform Device
    /1          eth0    network    Ethernet interface
    
    opened by mpdprot 1
  • Whether SE3 needs pre-training

    Thank you for your work. I used your SE3 implementation as part of my model, but the current test results are not very good. I guess it may be because I do not have a good understanding of your model. Here are my questions:

    1. Does your model need pre-training?
    2. Can I train the SE3 Transformer jointly with the fully connected layer that comes after it? Any good advice is also welcome.
    opened by zyk19981118 2
  • multiple molecules cases

    Hi,

    I normally use the dataloader from PyG to handle my molecule datasets.

    Can you please provide an example of a real training run over multiple epochs?

    I ran your sample code and it works, but I need to better understand the input format, as well as the output format, which for me would be x, y, z?

    thanks
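
    For reference, a minimal sketch of the batched input/output format as described in the README (an assumption, not an official answer): pad each molecule to a common length, mask the padding, and with return_type = 1 the model returns one 3-vector per atom that can be added to the input coordinates to get refined x, y, z.

    import torch
    from se3_transformer_pytorch import SE3Transformer

    model = SE3Transformer(
        num_tokens = 28,
        dim = 32,
        depth = 2,
        input_degrees = 1,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True
    )

    # two molecules of different sizes, padded to a common length and masked
    sizes, max_len = [17, 40], 40
    atoms = torch.zeros(2, max_len, dtype = torch.long)
    coors = torch.zeros(2, max_len, 3)
    mask  = torch.zeros(2, max_len, dtype = torch.bool)

    for i, n in enumerate(sizes):
        atoms[i, :n] = torch.randint(0, 28, (n,))
        coors[i, :n] = torch.randn(n, 3)
        mask[i, :n]  = True

    delta       = model(atoms, coors, mask, return_type = 1)   # (2, max_len, 3)
    refined_xyz = coors + delta                                # refined x, y, z per atom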

    opened by thegodone 2
Releases (0.9.0)
Owner
Phil Wang
Working with Attention. It's all we need.