Fast, general, and tested differentiable structured prediction in PyTorch

Overview

Torch-Struct: Structured Prediction Library


A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.
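
As a batched loss layer, the pattern looks like this (a minimal sketch; the random scores and the use of argmax as a stand-in for gold parts are illustrative):

import torch
from torch_struct import LinearChainCRF

scores = torch.randn(8, 9, 5, 5, requires_grad=True)  # (batch, N-1, C, C), e.g. from a network
dist = LinearChainCRF(scores)
gold = dist.argmax.detach()          # stand-in for gold-standard parts
loss = -dist.log_prob(gold).mean()   # negative log-likelihood of the structure
loss.backward()                      # gradients flow through the dynamic program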

See the tutorial paper describing the methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])


# Compute marginals
show(dist.marginals[0])


# Compute argmax
show(dist.argmax.detach()[0])


# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])


# Padding/Masking built into library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])


# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

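The same kind of chain can also be built with the HMM convenience wrapper (its usage appears in the issues below; the table shapes here are assumptions based on that usage):

import torch
from torch_struct import HMM

transition = torch.tensor([[0.9, 0.1], [0.1, 0.9]]).log()  # state-to-state table
emission = torch.tensor([[0.5, 0.5], [0.5, 0.5]]).log()    # observation/state table
init = torch.tensor([0.9, 0.1]).log()                      # initial state distribution
obs = torch.randint(0, 2, (1, 8))                          # (batch, N) observed symbols
dist = HMM(transition, emission, init, obs)
show(dist.argmax[0].sum(-1))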

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
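
A quick sketch of this shared interface on a random linear-chain model (shapes assumed):

import torch
from torch_struct import LinearChainCRF

dist = LinearChainCRF(torch.randn(2, 9, 5, 5), lengths=torch.tensor([10, 7]))
dist.argmax                 # highest-scoring structure (one-hot parts)
dist.partition              # log-partition function
dist.entropy                # entropy of the distribution
dist.log_prob(dist.argmax)  # log-probability of a given structure
dist.sample((3,))           # 3 samples per batch element
dist.topk(5)                # 5-best structures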

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree structured parameterizations TreeLSTM / SpanLSTM

Low-level API:

Everything is implemented through semiring dynamic programming; a minimal sketch follows the list below.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
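
As a toy sketch (not the library's code), the semiring idea for a linear chain: one dynamic program, different "sums". Swapping logsumexp for max turns the log-partition into the MAP score:

import torch

def chain_dp(log_potentials, reduce_):
    # log_potentials: (N-1, C, C) edge scores; reduce_ is the semiring "sum".
    alpha = reduce_(log_potentials[0], 0)  # fold out the first tag: (C,)
    for t in range(1, log_potentials.shape[0]):
        alpha = reduce_(alpha.unsqueeze(-1) + log_potentials[t], 0)
    return reduce_(alpha, 0)

phi = torch.randn(5, 4, 4)                               # chain of 6 tokens, 4 tags
log_Z = chain_dp(phi, torch.logsumexp)                   # log semiring: log-partition
map_score = chain_dp(phi, lambda x, d: x.max(d).values)  # max semiring: MAP score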

Examples

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
Comments
  • add tests for CKY


    This PR fixes several bugs in k-best parsing with dist.topk() and includes a simple test for the function.

    I made incremental changes so that existing modules relying on the CKY will not be affected.

    opened by zhaoyanpeng 8
  • 1st order cky implementation


    Hi,

    I'd like to contribute this implementation of a first-order CKY-style CRF with anchored rule potentials: $\phi[i,j,k,A,B,C] := \phi(A_{i,j} \rightarrow B_{i,k}, C_{k+1,j})$.

    I also added code to the _Struct class that allows calculating marginals even if input tensors don't require a gradient (i.e., after model.eval()).

    Please let me know if you'd like to see any changes.

    Thanks, Tom

    opened by teffland 6
  • Mini-batch setting with Semi Markov CRF


    I encounter learning instability when using a batch size > 1 with the semi-Markovian CRF (the loss goes to a very large negative number), even when explicitly providing "lengths". I think the bug comes from the masking. The model trains well with batch size 1.

    opened by urchade 5
  • Release on PyPI?


    Is there any interest in releasing pytorch-struct (and genbmm) on the official Python Package Index?

    I ran into this because I distribute my constituency parser on PyPI, and I just recently pushed a new version that depends on pytorch-struct: https://pypi.org/project/benepar/0.2.0a0/

    It turns out that packages on PyPI aren't allowed to depend on packages hosted only on GitHub, so users of my parser can't just pip install benepar and have it work right away.

    opened by nikitakit 5
  • up sweep and down sweep


    I'm interested in the parallel scan algorithm for the linear-chain CRF.

    I read the related paper in the tutorial and found that there are two steps, an up sweep and a down sweep, to obtain the all-prefix-sums.

    I think in this case, we use that algorithm to obtain all Z(x) with different lengths in a batch. But it seems I couldn't find the down-sweep code in the repo. Can you point me to it?
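
    (For context, a rough sketch of the up-sweep idea, not the repository's implementation: the partition function is an associative product of transition matrices in the log semiring, so a single up sweep already yields Z(x); a down sweep would only be needed to obtain all prefix partitions at once.)

    import torch

    def logmatmul(a, b):
        # matrix product in the log semiring: logsumexp over the shared index
        return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

    def log_partition(log_potentials):  # (N-1, C, C) edge potentials
        mats = list(log_potentials.unbind(0))
        while len(mats) > 1:  # the "up sweep": O(log N) rounds of pairwise products
            if len(mats) % 2:  # pad odd counts with the semiring identity matrix
                mats.append(torch.full_like(mats[0], float("-inf")).fill_diagonal_(0.0))
            mats = [logmatmul(mats[i], mats[i + 1]) for i in range(0, len(mats), 2)]
        return torch.logsumexp(mats[0], dim=(-2, -1))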

    opened by allanj 5
  • [Bug] Implementation of Eisner's algorithm does not restrict the root number to 1


    Hey, I found that your implementation of Eisner's algorithm admits an arbitrary number of roots, which is a very severe bug since dependency parsing usually allows only one root token.

    In your DepTree.dp() method, you make a conversion to let the root token be the first token in the sentence. Imagine that the root $x_0$ attaches to word $x_i$: $I_{0,0} + C_{1,i} = I_{0,i}$ and $I_{0,i} + C_{i,j} = C_{0,j}$ for some $j < L$, where $L$ is the length of the sentence. Now the complete span $C_{0,j}$ still has the opportunity to attach a new word $x_k$ for $j < k \le L$, making multiple root attachments possible.

    Fortunately, I made some changes to your codes to restrict the root number to 1.

    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
        if arc_scores_in.dim() not in (3, 4):
            raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")

        labeled = arc_scores_in.dim() == 4
        semiring = self.semiring
        # arc_scores_in = _convert(arc_scores_in)
        arc_scores_in, batch, N, lengths = self._check_potentials(
            arc_scores_in, lengths
        )
        arc_scores_in.requires_grad_(True)
        arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
        alpha = [
            [
                [
                    Chart((batch, N, N), arc_scores, semiring, cache=cache)
                    for _ in range(2)
                ]
                for _ in range(2)
            ]
            for _ in range(2)
        ]

        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

        for k in range(1, N):
            f = torch.arange(N - k), torch.arange(k, N)
            ACL = alpha[A][C][L][: N - k, :k]
            ACR = alpha[A][C][R][: N - k, :k]
            BCL = alpha[B][C][L][k:, N - k :]
            BCR = alpha[B][C][R][k:, N - k :]
            x = semiring.dot(ACR, BCL)
            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
            alpha[A][I][L][: N - k, k] = arcs_l
            alpha[B][I][L][k:N, N - k - 1] = arcs_l
            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
            alpha[A][I][R][: N - k, k] = arcs_r
            alpha[B][I][R][k:N, N - k - 1] = arcs_r
            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
            new = semiring.dot(ACL, BIL)
            alpha[A][C][L][: N - k, k] = new
            alpha[B][C][L][k:N, N - k - 1] = new
            new = semiring.dot(AIR, BCR)
            alpha[A][C][R][: N - k, k] = new
            alpha[B][C][R][k:N, N - k - 1] = new

        # handle the root attachment after the main loop
        root_incomplete_span = semiring.times(
            alpha[A][C][L][0, :],
            arc_scores[:, :, torch.arange(N), torch.arange(N)],
        )
        root = [Chart((batch,), arc_scores, semiring, cache=cache) for _ in range(N)]
        for k in range(N):
            AIR = root_incomplete_span[:, :, : k + 1]
            BCR = alpha[B][C][R][k, N - (k + 1) :]
            root[k] = semiring.dot(AIR, BCR)
        v = torch.stack([root[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
        return v, [arc_scores_in], alpha

    Basically, I no longer treat the first token as the root. I handle the root token just after the for-loop, so you may need to adjust the length variable (length = length - 1, since the root is no longer treated as part of the sentence). I tested the modified code and found it bug-free.

    opened by sustcsonglin 4
  • Inference for the HMM model


    Hello! I was playing with the HMM distribution and I obtained some results that I don't really understand. More precisely, I've set the following parameters

    import numpy as np
    import torch

    t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()
    e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()
    i = torch.tensor(np.array([0.99, 0.01])).log()
    x = torch.randint(0, 2, size=(1, 8))
    

    and I was expecting the model to stay in hidden state 0 regardless of the observed data x: it starts in state 0, and the transition matrix makes it very likely to stay there. But when plotting the argmax, it appears that the model jumps from one state to the other:

    import matplotlib.pyplot as plt
    import torch_struct

    def show_chain(chain):
        plt.imshow(chain.detach().sum(-1).transpose(0, 1))

    dist = torch_struct.HMM(t, e, i, x)
    show_chain(dist.argmax[0])
    


    I must be missing something obvious, but shouldn't dist.argmax correspond to $\arg\max_z p(z \mid x, \Theta)$? Thank you!

    opened by danoneata 4
  • DependencyCRF partition function broken


    Getting the following in-place operation error when using the DependencyCRF:

    import torch
    from torch_struct import DependencyCRF

    B, N = 3, 50
    phi = torch.randn(B, N, N)
    DependencyCRF(phi).partition
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/deptree.py in _check_potentials(self, arc_scores, lengths)
        121         arc_scores = semiring.convert(arc_scores)
        122         for b in range(batch):
    --> 123             semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
        124             semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])
        125 
    
    /usr/local/lib/python3.7/dist-packages/torch_struct/semirings/semirings.py in zero_(xs)
        124     @staticmethod
        125     def zero_(xs):
    --> 126         return xs.fill_(-1e5)
        127 
        128     @staticmethod
    
    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    
    opened by teffland 3
  • [Question] How to compute a marginal probability over a (contiguous) set of nodes?


    Hi.

    Thank you for the great library. I have one question that I hope you could help with.

    How can I compute a marginal probability over a (contiguous) set of nodes? Right now, I am using your LinearChain-CRF to do NER. In addition to the best sequence itself, I also need to compute the model's confidence in its predicted labeling over a segment of the input. For example, what is the probability that a span of tokens constitutes a person name?

    I read your example and see how you get the marginal probability for each individual node, but I was not quite sure how to compute the marginal probability over a subset of nodes. If you could give any hint, it would be great.

    Thank you.
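
    (One standard trick, sketched below under an assumed potential layout: the span probability is the difference between a constrained and the unconstrained log-partition, obtained by clamping every tag in the span to the target label. The helper name and index convention are illustrative, not part of the library API.)

    import torch
    from torch_struct import LinearChainCRF

    def span_log_prob(log_potentials, i, j, label):
        # log P(y_i = ... = y_j = label) = constrained minus full log-partition.
        # Assumes the (batch, N-1, C, C) layout phi[b, n, y_{n+1}, y_n].
        full = LinearChainCRF(log_potentials).partition
        phi = log_potentials.clone()
        C = phi.shape[-1]
        mask = torch.full((C,), -1e5)  # the library's semiring "zero"
        mask[label] = 0.0
        for n in range(phi.shape[1]):
            if i <= n + 1 <= j:  # constrain y_{n+1} on edge n
                phi[:, n] += mask.view(C, 1)
            if i <= n <= j:      # constrain y_n on edge n
                phi[:, n] += mask.view(1, C)
        return LinearChainCRF(phi).partition - full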

    opened by kimdev95 3
  • Get the score of dist.topk()


    The topk() function returns the top k predictions from the distribution; how can one easily get the corresponding score of each prediction?

    By the way, when sentence lengths are short and the k value of topk is large, how can one know how many of the predictions are valid? For the example in DependencyCRF, when the sentence length is 2 and k is 5, I think only the top 3 predictions are valid.

    opened by wangxinyu0922 3
  • Labeled projective dependency CRF


    This is work in progress and isn't ready to merge yet.

    This seems to work for the partition, but argmax and marginals don't return what I expect. Both return tensors of shape (B, N, N); I'd expect (B, N, N, L) tensors instead. Any advice?

    opened by kmkurn 3
  • [Question] How to apply pytorch-struct for 2 dimensional data?


    I could find examples of pytorch-struct usage for 1D sequence data like text or video frames. But I'm trying to parse table structure in PDF documents.

    Could you provide some hints where to start?

    opened by YuriyPryyma 4
  • end_class support for Autoregressive


    end_class is not used for the Autoregressive module: https://github.com/harvardnlp/pytorch-struct/blob/7146de5659ff17ad7be53023c025ffd099866412/torch_struct/autoregressive.py#L49

    opened by urchade 1
  • Update examples to use newer torchtext APIs


    opened by erip 2
  • Unstable learning with SemiMarkov CRF


    Hi,

    First, thank you for fixing #110 (@da03). The SemiCRF works better now, and I was able to get good results on span extraction tasks. However, I still encounter a learning instability where the loss (negative log-prob) becomes negative after several steps (and the accuracy starts to drop). The same problem occurs with batch_size = 1. Below I put the learning curves (f1_score and log loss).

    (Maybe the bug comes from the masking of spans (length, length + span_width) where length + span_width exceeds the sequence length, but I am not sure.)

    Edit: I created a test and it seems that the masking is fine. Maybe the issue is in the log_prob computation or the to_parts function?

    [figures: train_loss and score curves]

    opened by urchade 0
  • fix bug: missing assignment of spans from SentCFG in documentation


    Noticed a small bug in the documentation and example of SentCFG. The return of dist.argmax is (terms, rules, init, spans), but the example in the documentation only assigns (term, rules, init), which gives a dimension mismatch; running the example breaks. This fix resolves the issue.

    opened by jdegange 0
Releases (v0.5)