Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch

Overview

Enformer - Pytorch (wip)

Implementation of Enformer, DeepMind's attention network for predicting gene expression, in PyTorch. The original TensorFlow Sonnet code can be found here.

Citations

@article {Avsec2021.04.07.438649,
    author  = {Avsec, {\v Z}iga and Agarwal, Vikram and Visentin, Daniel and Ledsam, Joseph R. and Grabska-Barwinska, Agnieszka and Taylor, Kyle R. and Assael, Yannis and Jumper, John and Kohli, Pushmeet and Kelley, David R.},
    title   = {Effective gene expression prediction from sequence by integrating long-range interactions},
    elocation-id = {2021.04.07.438649},
    year    = {2021},
    doi     = {10.1101/2021.04.07.438649},
    publisher = {Cold Spring Harbor Laboratory},
    URL     = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649},
    eprint  = {https://www.biorxiv.org/content/early/2021/04/08/2021.04.07.438649.full.pdf},
    journal = {bioRxiv}
}
Comments
  • Using EleutherAI/enformer-official-rough PyTorch implementation to just get human output head

    Hi @lucidrains,

    Thank you so much for your efforts in releasing the PyTorch version of the Enformer model! I am really excited to use it for my particular implementation.

    I was wondering if it is possible to use the pre-trained Hugging Face model to get just the human output head. The reason is that inference takes a few minutes, and since I only need human data, this would make my implementation a bit smoother. Is there a way to do this elegantly with the current codebase, or would I need to rewrite some functions to allow for this? From what I have seen so far, this kind of modularity doesn't seem to be possible yet.

    The way I have set up my inference currently is as follows:

    import pandas as pd
    import torch
    from enformer_pytorch import Enformer


    class EnformerInference:
        def __init__(self, data_path: str, model_path: str = "EleutherAI/enformer-official-rough"):
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # put the weights on the same device as the inputs and switch to eval mode
            self.model = Enformer.from_pretrained(model_path).to(self.device).eval()
            # EnformerDataLoader is my own class: it returns a one-hot encoded
            # torch.Tensor representation of the sequence of interest
            self.data = EnformerDataLoader(pd.read_csv(data_path, sep="\t"))

        @torch.no_grad()  # gradients are not needed for inference
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.model(x.to(self.device))
    

    Any guidance on this would be greatly appreciated, thank you!

    opened by aaronwtr 4
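On the question above: when a model is a shared trunk feeding several small per-species heads, nearly all of the compute sits in the trunk, and every head needs it. Restricting the output to one head therefore only skips the other heads' final projections. Placing the model on the GPU and running under `torch.no_grad()` are likely to matter more for speed. The sketch below illustrates the pattern of running the trunk once and applying only the requested head. `TrunkWithHeads` and its `head` argument are a hypothetical toy, not the enformer-pytorch API; check the real `Enformer.forward` source for an equivalent option.

```python
from typing import Dict, Optional, Union

import torch
import torch.nn as nn


class TrunkWithHeads(nn.Module):
    """Hypothetical toy model: a shared trunk feeding per-species output heads."""

    def __init__(self, dim: int = 8, head_sizes: Optional[Dict[str, int]] = None):
        super().__init__()
        head_sizes = head_sizes or {"human": 5, "mouse": 3}
        # shared trunk: in a real model this is where most of the compute lives
        self.trunk = nn.Sequential(nn.Linear(4, dim), nn.GELU(), nn.Linear(dim, dim))
        # one small projection per species
        self.heads = nn.ModuleDict(
            {name: nn.Sequential(nn.Linear(dim, n), nn.Softplus()) for name, n in head_sizes.items()}
        )

    def forward(
        self, x: torch.Tensor, head: Optional[str] = None
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        emb = self.trunk(x)  # computed once, whichever head is requested
        if head is not None:
            return self.heads[head](emb)  # run only the requested head
        return {name: fn(emb) for name, fn in self.heads.items()}
```

With the model in `eval()` mode and inside `torch.no_grad()`, `model(x, head="human")` returns only the human tracks and gives the same values as the `"human"` entry of the full output.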
  • Host weights on HuggingFace hub

    Hi Phil Wang,

    I created a little demo showing how you can easily load pre-trained weights from the HuggingFace hub into your Enformer model. I've basically followed this guide, which Sylvain (@sgugger) wrote recently. It's a new feature that lets you push model weights to the hub and load them into any custom PyTorch/TF/Flax model.

    With this PR, you can do (after pip install enformer-pytorch):

    from enformer_pytorch import Enformer
    
    model = Enformer.from_pretrained("nielsr/enformer-preview")
    

    If you consent, then I'll transfer all weights to the eleutherai organization on the hub, such that you can do from_pretrained("eleutherai/enformer-preview").

    The weights are hosted here: https://huggingface.co/nielsr/enformer-preview. As you can see in the "files and versions" tab, it contains a pytorch_model.bin file, which has a size of about 1GB. You can also load the other variant, as follows:

    model = Enformer.from_pretrained("nielsr/enformer-corr_coef_obj")
    

    To make it work, the only thing that is required is encapsulating all hyperparameters regarding the model architecture into a separate EnformerConfig object (which I've defined in config_enformer.py). It can be instantiated as follows:

    from enformer_pytorch import EnformerConfig
    
    config = EnformerConfig(
        dim = 1536,
        depth = 11,
        heads = 8,
        output_heads = dict(human = 5313, mouse = 1643),
        target_length = 896,
    )
    

    To initialize an Enformer model with randomly initialized weights, you can do:

    from enformer_pytorch import Enformer
    
    model = Enformer(config)
    

    There's no need for the config.yml and model_loader.py files anymore, as these are now handled by HuggingFace :)

    Let me know what you think about it :)

    Kind regards,

    Niels

    To do:

    • [x] upload remaining checkpoints to the hub
    • [x] transfer checkpoints to the eleutherai organization
    • [x] remove config.yml and model_loading.py scripts
    • [x] update README
    opened by NielsRogge 4
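A separate config object works because the architecture hyperparameters are saved next to the weights. A loader can then rebuild the exact module before loading the state dict into it. The following is a minimal plain-PyTorch sketch of that idea. `TinyConfig`, `TinyModel`, `save_sketch`, `load_sketch` and the file names are hypothetical, and this is not how `transformers` implements `from_pretrained`.

```python
import json
import os
from dataclasses import asdict, dataclass

import torch
import torch.nn as nn


@dataclass
class TinyConfig:
    """Hypothetical config holding every architecture hyperparameter."""
    dim: int = 16
    depth: int = 2
    num_tracks: int = 5


class TinyModel(nn.Module):
    def __init__(self, config: TinyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.Sequential(*[nn.Linear(config.dim, config.dim) for _ in range(config.depth)])
        self.head = nn.Linear(config.dim, config.num_tracks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.layers(x))


def save_sketch(model: TinyModel, directory: str) -> None:
    os.makedirs(directory, exist_ok=True)
    # hyperparameters as JSON, weights as a state dict
    with open(os.path.join(directory, "config.json"), "w") as f:
        json.dump(asdict(model.config), f)
    torch.save(model.state_dict(), os.path.join(directory, "weights.pt"))


def load_sketch(directory: str) -> TinyModel:
    with open(os.path.join(directory, "config.json")) as f:
        config = TinyConfig(**json.load(f))
    model = TinyModel(config)  # rebuild the architecture from the config first
    state = torch.load(os.path.join(directory, "weights.pt"), map_location="cpu")
    model.load_state_dict(state)  # then fill in the weights
    return model
```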
  • Minor potential typo in `FastaInterval` class

    Hello, first off thanks so much for this incredible repository, it's greatly accelerating a project I am working on!

    I've been using the GenomeIntervalDataset class and noticed a minor potential typo in the FastaInterval class when I tried to fetch a sequence with a negative start position and got an empty tensor back. There is logic for clipping the start position at 0 and padding the sequence here https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#L137-L143, but it wasn't being used in my case because it sits inside the if clause above it, which I wasn't triggering https://github.com/lucidrains/enformer-pytorch/blob/ab29196d535802c8a04929534c5860fb55d06056/enformer_pytorch/data.py#LL128C9-L128C82. When I unindented that logic, everything worked fine for me.

    If it was unintentional to have the clipping inside that if clause I'd be happy to submit a trivial PR to fix the indentation.

    Thanks again for all your work

    opened by sofroniewn 2
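The empty result described above matches how Python slicing treats a negative start: it counts from the end, so `seq[-2:3]` on an 8-base string is empty. Below is a minimal sketch of clip-and-pad logic applied unconditionally. It works on a plain string rather than the repository's `FastaInterval`. The function name `fetch_padded` and the choice of `N` as the pad character are assumptions for illustration.

```python
def fetch_padded(chrom_seq: str, start: int, end: int, pad_char: str = "N") -> str:
    """Return bases [start, end) of chrom_seq, padding with pad_char wherever the
    interval runs off either end, so the result always has length end - start."""
    if end < start:
        raise ValueError("end must be >= start")
    n = len(chrom_seq)
    # clip to the chromosome before slicing, so a negative start is never
    # interpreted by Python as an offset from the end
    lo, hi = max(0, start), min(n, end)
    seq = chrom_seq[lo:hi] if hi > lo else ""
    left_pad = max(0, min(end, 0) - start)   # positions before base 0
    right_pad = max(0, end - max(start, n))  # positions past the last base
    return pad_char * left_pad + seq + pad_char * right_pad
```

For example, `fetch_padded("ACGTACGT", -2, 3)` gives `"NNACG"` instead of an empty string.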
  • example data files

    Hi, the README says to use the sequences.bed and hg38.ml.fa files to build the GenomeIntervalDataset, but I can't find these example data files. Could you provide links to them? Thanks!

    opened by yingyuan830 2
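The files themselves aren't linked here, but the BED format can still be described. A BED file lists genomic intervals as tab-separated `chrom`, `start`, `end` columns with 0-based, half-open coordinates. The parser below is a minimal sketch for reading such a file by hand. Treating an optional fourth column as a split label (such as train/valid/test) is an assumption about `sequences.bed`. `Interval` and `parse_bed` are hypothetical names, not part of enformer-pytorch.

```python
from dataclasses import dataclass
from typing import Iterable, List, Optional


@dataclass
class Interval:
    chrom: str
    start: int                   # 0-based, inclusive
    end: int                     # 0-based, exclusive
    label: Optional[str] = None  # assumed: optional 4th column, e.g. a split name


def parse_bed(lines: Iterable[str]) -> List[Interval]:
    intervals = []
    for line in lines:
        line = line.strip()
        # skip blank lines and BED header/comment lines
        if not line or line.startswith(("#", "track", "browser")):
            continue
        fields = line.split("\t")
        label = fields[3] if len(fields) > 3 else None
        intervals.append(Interval(fields[0], int(fields[1]), int(fields[2]), label))
    return intervals
```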
  • Why do we need Residual here when there is already a residual connection inside the conv block?

    We wrap the conv block inside Residual here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L318

    while the conv block already has a residual connection inside it here: https://github.com/lucidrains/enformer-pytorch/blob/1cbbe860bbd3ce8c26cee3de149d4fcdba508d95/enformer_pytorch/modeling_enformer.py#L226

    opened by inspirit 2
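Here is a concrete version of the question. If a block already computes `g(x) + x` and is then wrapped in a `Residual` that adds `x` again, the identity path is counted twice and the output is `g(x) + 2x`. The toy modules below are hypothetical stand-ins written for this illustration. Whether the two linked lines actually nest this way is what the issue asks, and the sketch does not settle it.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Wraps a module and adds its input back: fn(x) + x."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x


class BlockWithInternalSkip(nn.Module):
    """Hypothetical block that already applies its own skip: g(x) + x."""

    def __init__(self, g: nn.Module):
        super().__init__()
        self.g = g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.g(x) + x
```

Nesting adds no parameters, but it doubles the weight of the identity path relative to `g(x)`. That is the practical difference to check for.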
  • Add base_model_prefix

    This PR fixes the from_pretrained method by adding base_model_prefix, as this makes sure weights are properly loaded from the hub.

    Kudos to @sgugger for finding the bug.

    opened by NielsRogge 2
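A prefix matters because state-dict keys are dotted module paths. Weights saved from a wrapper that stores the network under an attribute (say `enformer`) carry keys like `enformer.weight`, and a bare model expecting `weight` will reject them. The helper below is a hypothetical plain-PyTorch sketch that reconciles the mismatch by stripping the prefix. It is not how `transformers` implements `base_model_prefix`, and the prefix name `enformer` is assumed.

```python
from typing import Dict

import torch
import torch.nn as nn


def strip_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    """Drop a leading '<prefix>.' from every key that has it."""
    head = prefix + "."
    return {(k[len(head):] if k.startswith(head) else k): v for k, v in state_dict.items()}


def load_with_optional_prefix(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], prefix: str
) -> nn.Module:
    """Load state_dict into model, stripping `prefix` if the keys don't match as-is."""
    expected = set(model.state_dict().keys())
    if not expected.issubset(state_dict.keys()):
        state_dict = strip_prefix(state_dict, prefix)
    model.load_state_dict(state_dict, strict=True)  # raises if keys still disagree
    return model
```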
  • How to load the pre-trained Enformer model?

    Hi, I encountered a problem when trying to load the pre-trained enformer model.

    from enformer_pytorch import Enformer
    model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError                            Traceback (most recent call last)
    Input In [3], in
          1 from enformer_pytorch import Enformer
    ----> 2 model = Enformer.from_pretrained("EleutherAI/enformer-preview")

    AttributeError: type object 'Enformer' has no attribute 'from_pretrained'

    opened by yzJiang9 2
  • enformer TF pretrained weights

    Hello!

    Thanks for this wonderful resource. I was wondering whether you could point me to how to obtain the model weights for the original TF version of Enformer, or to the actual weights if they are stored somewhere easily accessible. I see the model on TF Hub but am not sure exactly how to extract the weights. I seem to be running into issues, possibly because the original code is Sonnet-based and the model is always loaded as a custom user object.

    Much appreciated!

    opened by naumanjaved 1
  • AttentionPool bug?

    Looking at the AttentionPool class, did you mean to have

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = self.pool_size)
    

    instead of

    self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
    

    Here's the full class

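    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops.layers.torch import Rearrange

    class AttentionPool(nn.Module):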
        def __init__(self, dim, pool_size = 2):
            super().__init__()
            self.pool_size = pool_size
            self.pool_fn = Rearrange('b d (n p) -> b d n p', p = 2)
            self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias = False)
    
        def forward(self, x):
            b, _, n = x.shape
            remainder = n % self.pool_size
            needs_padding = remainder > 0
    
            if needs_padding:
                x = F.pad(x, (0, remainder), value = 0)
                mask = torch.zeros((b, 1, n), dtype = torch.bool, device = x.device)
                mask = F.pad(mask, (0, remainder), value = True)
    
            x = self.pool_fn(x)
            logits = self.to_attn_logits(x)
    
            if needs_padding:
                mask_value = -torch.finfo(logits.dtype).max
                logits = logits.masked_fill(self.pool_fn(mask), mask_value)
    
            attn = logits.softmax(dim = -1)
    
            return (x * attn).sum(dim = -1)
    
    opened by cmlakhan 1
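The hard-coded `p = 2` is not the only problem. The padding step also looks off for pool sizes other than 2. It pads by `remainder`, but reaching the next multiple of `pool_size` needs `pool_size - remainder`. For example, with `n = 7` and `pool_size = 3`, padding by 1 gives 8, which still isn't divisible by 3. The two coincide when `pool_size = 2`, which may be why this doesn't surface with the default. The sketch below applies both changes. It is a suggested revision, not the repository's code. It also swaps the einops `Rearrange` for an equivalent `reshape` and builds the mask directly, so it depends only on PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPool(nn.Module):
    """Softmax-weighted pooling over non-overlapping windows of size pool_size."""

    def __init__(self, dim: int, pool_size: int = 2):
        super().__init__()
        self.pool_size = pool_size
        # 1x1 conv producing one attention logit per channel and position
        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias=False)

    def _pool(self, t: torch.Tensor) -> torch.Tensor:
        # (b, d, n * p) -> (b, d, n, p), i.e. Rearrange('b d (n p) -> b d n p', p = pool_size)
        b, d, n = t.shape
        return t.reshape(b, d, n // self.pool_size, self.pool_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, n = x.shape
        remainder = n % self.pool_size
        needs_padding = remainder > 0

        if needs_padding:
            pad = self.pool_size - remainder  # pad up to the next multiple of pool_size
            x = F.pad(x, (0, pad), value=0.0)
            mask = torch.zeros((b, 1, n + pad), dtype=torch.bool, device=x.device)
            mask[..., n:] = True  # True marks padded positions

        x = self._pool(x)
        logits = self.to_attn_logits(x)

        if needs_padding:
            # padded positions get (effectively) zero attention weight
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self._pool(mask), mask_value)

        attn = logits.softmax(dim=-1)
        return (x * attn).sum(dim=-1)
```

A quick sanity check: for an all-ones input, every pooled value should be 1, because the padded zeros receive no weight.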
  • Colab notebook for computing the correlation across different basenji2 dataset splits.

    New features:

    1. Colab notebook for computing correlations across the different basenji2 dataset splits.
    2. Pytorch metric for computing the mean of per-channel correlations properly aggregated across a region set.
    opened by jstjohn 0
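The notebook itself isn't reproduced here. One way to aggregate per-channel correlation across a region set, in the spirit of item 2, is to accumulate sufficient statistics over all batches: the per-channel sums of x, y, x², y² and xy. Pearson r is then computed once at the end. Averaging per-batch correlations instead weights batches unequally and is not the same quantity. `MeanPearsonR` is a hypothetical name, and this is a sketch rather than the PR's metric.

```python
import torch


class MeanPearsonR:
    """Per-channel Pearson r accumulated over many batches (hypothetical helper).

    Sums are pooled over every position of every region passed to update(), so
    each channel's r equals Pearson r over the concatenated data.
    """

    def __init__(self, num_channels: int):
        def zeros() -> torch.Tensor:
            return torch.zeros(num_channels, dtype=torch.float64)

        self.n = 0
        self.sx, self.sy = zeros(), zeros()
        self.sxx, self.syy, self.sxy = zeros(), zeros(), zeros()

    def update(self, pred: torch.Tensor, target: torch.Tensor) -> None:
        # pred, target: (..., num_channels); all leading dims are pooled together
        x = pred.detach().reshape(-1, pred.shape[-1]).double()
        y = target.detach().reshape(-1, target.shape[-1]).double()
        self.n += x.shape[0]
        self.sx += x.sum(0)
        self.sy += y.sum(0)
        self.sxx += (x * x).sum(0)
        self.syy += (y * y).sum(0)
        self.sxy += (x * y).sum(0)

    def per_channel(self) -> torch.Tensor:
        # r = (n*Sxy - Sx*Sy) / sqrt((n*Sxx - Sx^2) * (n*Syy - Sy^2))
        cov = self.n * self.sxy - self.sx * self.sy
        var_x = self.n * self.sxx - self.sx ** 2
        var_y = self.n * self.syy - self.sy ** 2
        return cov / torch.sqrt(var_x * var_y)

    def compute(self) -> torch.Tensor:
        # mean of the per-channel correlations
        return self.per_channel().mean()
```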
  • Computing Contribution Scores

    From the paper:

    To better understand what sequence elements Enformer is utilizing when making predictions, we computed two different gene expression contribution scores: input gradients (gradient × input) and attention weights.

    I was just wondering how to compute input gradients and fetch the attention matrix for the given input. I'm not well versed with PyTorch, so I'm sorry if this is a noob question.

    opened by Prakash2403 0
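For the input-gradient part of the question, gradient × input can be computed with `torch.autograd.grad` on a one-hot input. Reduce the model output to a scalar (for example, one track summed over some output bins) and differentiate it with respect to the input. Multiply the gradient elementwise by the input, then sum over the nucleotide axis so each position gets one score. The sketch is generic. `gradient_x_input` and its `target_fn` argument, which picks that scalar, are hypothetical names. A toy linear model stands in for Enformer in the check below. Fetching attention matrices is not covered.

```python
from typing import Callable

import torch
import torch.nn as nn


def gradient_x_input(
    model: nn.Module,
    one_hot: torch.Tensor,  # (batch, length, 4)
    target_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Return per-position gradient x input scores of shape (batch, length)."""
    x = one_hot.detach().clone().float().requires_grad_(True)
    out = model(x)
    target = target_fn(out)  # must reduce the output to a scalar
    (grad,) = torch.autograd.grad(target, x)
    # with a one-hot input this keeps only the gradient at the observed base
    return (grad * x).sum(dim=-1).detach()
```

For a linear model, the score at each position is simply the weight on the observed base, which makes a convenient sanity check.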
  • Models in training splits

    Hey,

    Is there a way of getting the models trained in each training set, as mentioned in the "Model training and evaluation" paragraph of the Enformer paper?

    Thanks!

    opened by luciabarb 0
  • metric for enformer

    Hello, can I ask how you found that the human Pearson R is 0.625 for validation and 0.65 for test? I couldn't find this information in the paper. Is there anywhere else that records it?

    opened by Rachel66666 0
  • error loading enformer package

    I am trying to install the enformer package but seem to be getting the following error:

    >>> import torch
    >>> from enformer_pytorch import Enformer
    Traceback (most recent call last):
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 905, in _get_module
        return importlib.import_module("." + module_name, self.__name__)
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/importlib/__init__.py", line 127, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/modeling_utils.py", line 76, in <module>
        from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
    ImportError: cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/__init__.py", line 2, in <module>
        from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/enformer_pytorch/modeling_enformer.py", line 14, in <module>
        from transformers import PreTrainedModel
      File "<frozen importlib._bootstrap>", line 1055, in _handle_fromlist
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 895, in __getattr__
        module = self._get_module(self._class_to_module[name])
      File "/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/transformers/utils/import_utils.py", line 907, in _get_module
        raise RuntimeError(
    RuntimeError: Failed to import transformers.modeling_utils because of the following error (look up to see its traceback):
    cannot import name 'dispatch_model' from 'accelerate' (/sc/arion/projects/ad-omics/clakhani/conda/envs/enformer_lightning/lib/python3.9/site-packages/accelerate/__init__.py)
    

    I simply cloned an existing PyTorch environment with Conda (using CUDA 11.1 and torch 1.10) and then pip-installed the Hugging Face packages and the enformer package:

    pip install transformers
    pip install datasets
    pip install accelerate
    pip install tokenizers
    pip install enformer-pytorch
    

    Any idea why I'm getting this error?

    opened by cmlakhan 1
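The final error says `transformers` tried to import `dispatch_model` from `accelerate`, and the installed `accelerate` did not provide it. One plausible reading, not confirmed in the thread, is that the environment ended up with an `accelerate` release older than the one this `transformers` build expects. A quick first check is to print what is actually installed. The helper below uses only the standard library (`importlib.metadata`, Python 3.8+); its name `installed_versions` is hypothetical.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(dists: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, ver in installed_versions(
        ["torch", "transformers", "accelerate", "enformer-pytorch"]
    ).items():
        print(f"{name}: {ver or 'not installed'}")
```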
Releases(0.5.6)
Owner
Phil Wang
Working with Attention. It's all we need