Fast, general, and tested differentiable structured prediction in PyTorch

Overview

Torch-Struct: Structured Prediction Library

A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

Tutorial paper describing methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])

# Compute marginals
show(dist.marginals[0])

# Compute argmax
show(dist.argmax.detach()[0])

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

# Padding/Masking built into library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1,9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
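As a rough illustration of what semiring dynamic programming means here, the following minimal sketch runs one chain recursion under two different semirings. It is pure Python and is not the library's implementation; the names `LogSemiring`, `MaxSemiring`, `chain_dp`, and `path_scores`, as well as the `edge_scores[t][i][j]` layout, are assumptions made for this example.

```python
import math
from itertools import product


class LogSemiring:
    """(logsumexp, +): summing over structures gives the log-partition."""
    zero = -math.inf
    one = 0.0

    @staticmethod
    def plus(a, b):
        if a == -math.inf:
            return b
        if b == -math.inf:
            return a
        m = max(a, b)
        return m + math.log(math.exp(a - m) + math.exp(b - m))

    @staticmethod
    def times(a, b):
        return a + b


class MaxSemiring:
    """(max, +): summing over structures gives the best path score."""
    zero = -math.inf
    one = 0.0
    plus = staticmethod(max)

    @staticmethod
    def times(a, b):
        return a + b


def chain_dp(semiring, edge_scores):
    """Forward recursion over a chain, written only in terms of the semiring.

    Assumed layout: edge_scores[t][i][j] is the log-score of being in
    state i at position t and state j at position t + 1.
    """
    num_states = len(edge_scores[0])
    alpha = [semiring.one] * num_states
    for step in edge_scores:
        new_alpha = []
        for nxt in range(num_states):
            total = semiring.zero
            for prev in range(num_states):
                total = semiring.plus(
                    total, semiring.times(alpha[prev], step[prev][nxt])
                )
            new_alpha.append(total)
        alpha = new_alpha
    result = semiring.zero
    for value in alpha:
        result = semiring.plus(result, value)
    return result


def path_scores(edge_scores):
    """Brute-force enumeration of every path score, for checking only."""
    num_states = len(edge_scores[0])
    length = len(edge_scores) + 1
    return [
        sum(edge_scores[t][path[t]][path[t + 1]] for t in range(length - 1))
        for path in product(range(num_states), repeat=length)
    ]
```

Calling `chain_dp(LogSemiring, scores)` and `chain_dp(MaxSemiring, scores)` on the same input runs identical code; only the semiring changes, which is the design idea the list above refers to.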

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
Comments
  • add tests for CKY

    This PR fixes several bugs in k-best parsing with dist.topk() and includes a simple test to test the function.

    I made incremental changes so that existing modules relying on the CKY will not be affected.

    opened by zhaoyanpeng 8
  • 1st order cky implementation

    Hi,

    I'd like to contribute this implementation of a first-order CKY-style CRF with anchored rule potentials: $\phi[i,j,k,A,B,C] := \phi(A_{i,j} \rightarrow B_{i,k}, C_{k+1,j})$.

    I also added code to the _Struct class that allows calculating marginals even if input tensors don't require a gradient (i.e., after model.eval())

    Please let me know if you'd like to see any changes.

    Thanks, Tom

    opened by teffland 6
  • Mini-batch setting with Semi Markov CRF

    I encounter learning instability when using a batch size > 1 with the semi-Markov CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking. The model trains well with a batch size of 1.

    opened by urchade 5
  • Release on PyPI?

    Is there any interest in releasing pytorch-struct (and genbmm) on the official Python Package Index?

    I ran into this because I distribute my constituency parser on PyPI, and I just recently pushed a new version that depends on pytorch-struct: https://pypi.org/project/benepar/0.2.0a0/

    It turns out that packages on PyPI aren't allowed to depend on packages only hosted on github, so users of my parser can't just pip install benepar and have it work right away.

    opened by nikitakit 5
  • up sweep and down sweep

    I'm interested in the parallel scan algorithm for the linear-chain CRF.

    I read the related paper in the tutorial and found that there are two steps: up sweep and down sweep in order to obtain all-prefix-sum.

    I think in this case we use that algorithm to obtain all Z(x) for the different lengths in a batch. However, I couldn't find the down sweep code in the repo. Could you point me to it?

    opened by allanj 5
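For readers following this question, the two phases can be sketched in plain Python as a work-efficient exclusive scan in the style of Blelloch. This is a sequential illustration under stated assumptions (a power-of-two input length and a hypothetical helper name `exclusive_scan`). It is not code from the repository and does not claim to show how torch-struct organizes its own scan.

```python
def exclusive_scan(values, op, identity):
    """Exclusive prefix scan with an up-sweep and a down-sweep.

    `op` must be associative (it need not be commutative) and
    `identity` must be its identity element. The length is assumed
    to be a power of two to keep the sketch short.
    """
    n = len(values)
    if n == 0:
        return []
    if n & (n - 1):
        raise ValueError("length must be a power of two in this sketch")
    a = list(values)

    # Up-sweep (reduce): build partial results up a balanced tree.
    stride = 1
    while stride < n:
        for right in range(2 * stride - 1, n, 2 * stride):
            left = right - stride
            a[right] = op(a[left], a[right])
        stride *= 2

    # Down-sweep: push prefixes back down the tree.
    a[n - 1] = identity
    stride = n // 2
    while stride >= 1:
        for right in range(2 * stride - 1, n, 2 * stride):
            left = right - stride
            left_total = a[left]
            a[left] = a[right]                  # prefix before the left subtree
            a[right] = op(a[right], left_total)  # prefix, then the left subtree
        stride //= 2
    return a
```

Each inner `for` loop touches disjoint positions, which is what makes the two sweeps parallelizable. In a chain model, `op` could be a semiring matrix product, in which case the prefixes would correspond to partial chains of each length.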
  • [Bug] Implementation of Eisner's algorithm does not restrict the root number to 1

    Hey, I found that your implementation of Eisner's algorithm admits an arbitrary number of roots, which is a very severe bug, since a dependency parse usually has only one root token.

    In your DepTree.dp() method, you apply a conversion that treats the root token as the first token in the sentence. Imagine that the root x_{0} attaches to word x_{i}: I_{0,0} + C_{1,i} = I_{0,i} and I_{0,i} + C_{i,j} = C_{0,j} for some j < L, where L is the length of the sentence. Now the complete span C_{0,j} still has the opportunity to attach a new word x_{k} for j < k <= L, making multiple root attachments possible.

    Fortunately, I made some changes to your code to restrict the number of roots to 1.

    `
    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
        if arc_scores_in.dim() not in (3, 4):
            raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

        labeled = arc_scores_in.dim() == 4
        semiring = self.semiring
        # arc_scores_in = _convert(arc_scores_in)
        arc_scores_in, batch, N, lengths = self._check_potentials(
            arc_scores_in, lengths
        )
        arc_scores_in.requires_grad_(True)
        arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
        alpha = [
            [
                [
                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
                    for _ in range(2)
                ]
                for _ in range(2)
            ]
            for _ in range(2)
        ]
    
        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
    
    
        for k in range(1, N):
            f = torch.arange(N - k), torch.arange(k, N)
            ACL = alpha[A][C][L][: N - k, :k]
            ACR = alpha[A][C][R][: N - k, :k]
            BCL = alpha[B][C][L][k:, N - k :]
            BCR = alpha[B][C][R][k:, N - k :]
            x = semiring.dot(ACR, BCL)
            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
            alpha[A][I][L][: N - k, k] = arcs_l
            alpha[B][I][L][k:N, N - k - 1] = arcs_l
            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
            alpha[A][I][R][:N - k, k] = arcs_r
            alpha[B][I][R][k:N, N - k - 1] = arcs_r
            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
            new = semiring.dot(ACL, BIL)
            alpha[A][C][L][: N - k, k] = new
            alpha[B][C][L][k:N, N - k - 1] = new
            new = semiring.dot(AIR, BCR)
            alpha[A][C][R][: N - k, k] = new
            alpha[B][C][R][k:N, N - k - 1] = new
    
        root_incomplete_span = semiring.times(alpha[A][C][L][0, :], arc_scores[:, :, torch.arange(N), torch.arange(N)])
        root =  [ Chart((batch,), arc_scores, semiring, cache=cache) for _ in range(N)]
        for k in range(N):
            AIR = root_incomplete_span[:, :, :k+1]
            BCR = alpha[B][C][R][k, N - (k+1):]
            root[k] = semiring.dot(AIR, BCR)
        v = torch.stack([root[l-1][:,i] for i, l in enumerate(lengths)], dim=1)
        return v, [arc_scores_in], alpha
    

    `

    Basically, I no longer treat the first token as the root. I handle the root token just after the for-loop, so you may need to adjust the length variable (length = length - 1, since the root is no longer treated as part of the sentence). I tested the modified code and found it bug-free.

    opened by sustcsonglin 4
  • Inference for the HMM model

    Hello! I was playing with the HMM distribution and I obtained some results that I don't really understand. More precisely, I've set the following parameters

    t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()
    i = torch.tensor(np.array([0.99, 0.01])).log()
    x = torch.randint(0, 2, size=(1, 8))
    

    and I was expecting the model to stay in the hidden state 0 regardless of the observed data x – it starts in state 0 and the transition matrix makes it very likely to maintain it. But when plotting the argmax, it appears that the model jumps from one state to the other:

    def show_chain(chain):
        plt.imshow(chain.detach().sum(-1).transpose(0, 1))
    
    dist = torch_struct.HMM(t, e, i, x)
    show_chain(dist.argmax[0])
    

    I must be missing something obvious; but shouldn't dist.argmax correspond to argmax_z p(z | x, Θ)? Thank you!

    opened by danoneata 4
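The expectation stated in this question can be checked independently of the library with a small Viterbi sketch in pure Python. The function name `viterbi` and the axis conventions below (`log_trans[i][j]` for moving from state i to state j, `log_emit[z][x]` for emitting x from state z) are assumptions of this example; the conventions expected by `torch_struct.HMM` may differ and should be checked against its documentation.

```python
import math


def viterbi(log_init, log_trans, log_emit, observations):
    """Most likely hidden state sequence of an HMM.

    Assumed conventions (hypothetical, for this sketch only):
      log_init[z]     = log p(z_1 = z)
      log_trans[i][j] = log p(z_t = j | z_{t-1} = i)
      log_emit[z][x]  = log p(x_t = x | z_t = z)
    """
    num_states = len(log_init)
    score = [log_init[z] + log_emit[z][observations[0]] for z in range(num_states)]
    backpointers = []
    for x in observations[1:]:
        new_score, pointers = [], []
        for z in range(num_states):
            # Best previous state for landing in state z at this step.
            best_prev = max(range(num_states), key=lambda p: score[p] + log_trans[p][z])
            new_score.append(score[best_prev] + log_trans[best_prev][z] + log_emit[z][x])
            pointers.append(best_prev)
        score = new_score
        backpointers.append(pointers)

    # Follow the back-pointers from the best final state.
    state = max(range(num_states), key=lambda z: score[z])
    path = [state]
    for pointers in reversed(backpointers):
        state = pointers[state]
        path.append(state)
    path.reverse()
    return path


# Parameters from the question above, as plain Python lists.
log_trans = [[math.log(0.99), math.log(0.01)], [math.log(0.01), math.log(0.99)]]
log_emit = [[math.log(0.5), math.log(0.5)], [math.log(0.5), math.log(0.5)]]
log_init = [math.log(0.99), math.log(0.01)]
```

Under these conventions, uniform emissions make the observations irrelevant, so the best path stays in state 0 for any observation sequence.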
  • DependencyCRF partition function broken

    Getting the following in-place operation error when using the DependencyCRF:

    B,N = 3,50
    phi = torch.randn(B,N,N)
    DependencyCRF(phi).partition
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/deptree.py in _check_potentials(self, arc_scores, lengths)
        121         arc_scores = semiring.convert(arc_scores)
        122         for b in range(batch):
    --> 123             semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
        124             semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
        125 
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/semirings/semirings.py in zero_(xs)
        124     @staticmethod
        125     def zero_(xs):
    --> 126         return xs.fill_(-1e5)
        127 
        128     @staticmethod
    
    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    
    opened by teffland 3
  • [Question] How to compute a marginal probability over a (contiguous) set of nodes?

    Hi.

    Thank you for the great library. I have one question that I hope you could help with.

    How can I compute a marginal probability over a (contiguous) set of nodes? Right now, I am using your LinearChain-CRF to do NER. In addition to the best sequence itself, I also need to compute the model’s confidence in its predicted labeling over a segment of the input. For example, what is the probability that a span of tokens constitutes a person name?

    I read your example and see how you get the marginal prob for each individual node. But I was not quite sure how to compute the marginal prob over a subset of nodes. If you could give any hint, it would be great.

    Thank you.

    opened by kimdev95 3
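One common way to approach this kind of question for a linear-chain model is to divide a constrained partition function (only labelings that agree with the span) by the unconstrained one. The sketch below is pure Python, not a torch-struct API; the function names and the `unary[t][y]` / `pairwise[a][b]` score layout are assumptions made for this example.

```python
import math
from itertools import product


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(unary, pairwise, allowed=None):
    """Forward algorithm; `allowed` maps a position to its permitted labels.

    Assumed layout: unary[t][y] scores label y at position t and
    pairwise[a][b] scores label a followed by label b.
    """
    allowed = allowed or {}
    num_labels = len(unary[0])

    def masked(t, y, score):
        return score if y in allowed.get(t, range(num_labels)) else -math.inf

    alpha = [masked(0, y, unary[0][y]) for y in range(num_labels)]
    for t in range(1, len(unary)):
        alpha = [
            masked(
                t, y,
                logsumexp([alpha[p] + pairwise[p][y] for p in range(num_labels)])
                + unary[t][y],
            )
            for y in range(num_labels)
        ]
    return logsumexp(alpha)


def span_log_prob(unary, pairwise, start, end, label):
    """log p(y_start = ... = y_end = label), with `end` inclusive."""
    allowed = {t: {label} for t in range(start, end + 1)}
    return log_partition(unary, pairwise, allowed) - log_partition(unary, pairwise)


def brute_force_span_prob(unary, pairwise, start, end, label):
    """Enumerate every labeling; only for checking small cases."""
    n, k = len(unary), len(unary[0])
    total = matched = 0.0
    for path in product(range(k), repeat=n):
        score = sum(unary[t][path[t]] for t in range(n))
        score += sum(pairwise[path[t]][path[t + 1]] for t in range(n - 1))
        weight = math.exp(score)
        total += weight
        if all(path[t] == label for t in range(start, end + 1)):
            matched += weight
    return matched / total
```

Because `allowed` accepts a set of labels per position, the same idea extends to tag schemes in which a span corresponds to a pattern of labels rather than one repeated label.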
  • Get the score of dist.topk()

    The topk() function returns the top k predictions from the distribution. How can I easily get the corresponding score of each prediction?

    By the way, when sentences are short and the k value of topk is large, how can I know how many of the predictions are valid? For the example in DependencyCRF, when the sentence length is 2 and k is 5, I think only the top 3 predictions are valid.

    opened by wangxinyu0922 3
  • Labeled projective dependency CRF

    This is work in progress and isn't ready to merge yet.

    This seems to work for partition, but argmax and marginals don't return as I expect. Both return tensor of shape (B, N, N); I'd expect them to return (B, N, N, L) tensors instead. Any advice?

    opened by kmkurn 3
  • [Question] How to apply pytorch-struct for 2 dimensional data?

    I could find examples of pytorch-struct usage for 1D sequence data such as text or video frames, but I'm trying to parse table structure in PDF documents.

    Could you provide some hints where to start?

    opened by YuriyPryyma 4
  • end_class support for Autoregressive

    end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

    opened by urchade 1
  • Update examples to use newer torchtext APIs

    opened by erip 2
  • Instable learning with SemiMarkov CRF

    Hi,

    First, thank you for fixing #110 (@da03), the SemiCRF works better now, I was able to get good results on span extraction tasks. However, I still encounter a learning instability where the loss (neg logprob) gets negative after several steps (and the accuracy starts to drop). The same problem occurs with batch_size = 1. Below I put the learning curve (f1_score and log loss).

    (Maybe the bug comes from the masking of spans where (length, length + span_with) and length + span_with > length, but I am not sure.)

    Edit: I created a test and it seems that the masking is fine. Maybe the problem is in the log_prob computation or the to_parts function?

    [Images: train_loss and score curves]

    opened by urchade 0
  • fix bug- missing assignment of spans from sentCFG in documentation

    Noticed a small bug in the documentation and example of SentCFG. dist.argmax returns (terms, rules, init, spans), but the example in the documentation only assigns (term, rules, init), which causes a dimension mismatch, so the example breaks when run. This fix resolves the issue.

    opened by jdegange 0
Releases(v0.5)