simpleT5 is built on top of PyTorch-lightning⚡️ and Transformers🤗 that lets you quickly train your T5 models.

Overview

simplet5

Quickly train T5 models in just 3 lines of code + ONNX support

PyPI version License

simpleT5 is built on top of PyTorch-lightning ⚡️ and Transformers 🤗 that lets you quickly train your T5 models.

T5 models can be used for several NLP tasks such as summarization, QA, QG, translation, text generation, and more.

Here's a link to Medium article along with an example colab notebook

Install

pip install --upgrade simplet5

Usage

simpleT5 for summarization task Open In Collab

# import
from simplet5 import SimpleT5

# instantiate
model = SimpleT5()

# load
model.from_pretrained("t5","t5-base")

# train
model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
            eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
            source_max_token_len = 512, 
            target_max_token_len = 128,
            batch_size = 8,
            max_epochs = 5,
            use_gpu = True,
            outputdir = "outputs",
            early_stopping_patience_epochs = 0
            )

# load trained T5 model
model.load_model("t5","path/to/trained/model/directory", use_gpu=False)

# predict
model.predict("input text for prediction")

# need faster inference on CPU, get ONNX support
model.convert_and_load_onnx_model("path/to/T5 model/directory")
model.onnx_predict("input text for prediction")
Comments
  • Suppress the Output Models

    Suppress the Output Models

    Hello there!

    I'd like to ask if there is any possible way to eliminate all models, except for the last trained one. When I fine tune a model, it gives me X different models if I fine tune the model X epochs. I just need the last model and couldn't find a way to prevent writing those models to disk.

    Thanks!

    opened by bayismet 6
  • TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask In onnx_predict function

    TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask In onnx_predict function

    Hello, when I run the fine-tuned mt5 model under onnx, I get the following error:

    `TypeError Traceback (most recent call last) in ----> 1 model.onnx_predict(text)

    ~\AppData\Roaming\Python\Python38\site-packages\simplet5\simplet5.py in onnx_predict(self, source_text) 469 """ generates prediction from ONNX model """ 470 token = self.onnx_tokenizer(source_text, return_tensors="pt") --> 471 tokens = self.onnx_model.generate( 472 input_ids=token["input_ids"], 473 attention_mask=token["attention_mask"],

    C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\grad_mode.py in decorate_context(*args, **kwargs) 26 def decorate_context(*args, **kwargs): 27 with self.class(): ---> 28 return func(*args, **kwargs) 29 return cast(F, decorate_context) 30

    C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, max_new_tokens, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, synced_gpus, **model_kwargs) 1051 input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs 1052 ) -> 1053 return self.beam_search( 1054 input_ids, 1055 beam_scorer,

    C:\ProgramData\Anaconda3\lib\site-packages\transformers\generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs) 1788 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 1789 -> 1790 outputs = self( 1791 **model_inputs, 1792 return_dict=True,

    C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs) 1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1050 or _global_forward_hooks or _global_forward_pre_hooks): -> 1051 return forward_call(*input, **kwargs) 1052 # Do not call functions when jit is used 1053 full_backward_hooks, non_full_backward_hooks = [], []

    TypeError: forward() got an unexpected keyword argument 'cross_attn_head_mask'`

    I tried to downgrade transformers and onnxruntime but the error still remains.

    opened by farshadfiruzi 6
  • colab error

    colab error

    When running google colab, the line below produced error:

    model.load_model("t5","outputs/SimpleT5-epoch-2-train-loss-0.9526", use_gpu=True)

    Error: 404 Client Error: Not Found for url: https://huggingface.co/outputs/SimpleT5-epoch-2-train-loss-0.9526/resolve/main/config.json

    Please help. Thanks a lot.

    opened by yzhang-github-pub 6
  • shows object has no attribute 'convert_and_load_onnx_model'

    shows object has no attribute 'convert_and_load_onnx_model'

    Shows the error AttributeError: 'SimpleT5' object has no attribute 'convert_and_load_onnx_model' while running the example notebook provided in the repository

    https://github.com/Shivanandroy/simpleT5/blob/main/examples/simpleT5-summarization.ipynb

    opened by pradeepdev-1995 4
  • Adding logger

    Adding logger

    With this change, users can provide a PyTorch Lightning logger object to the .train() method:

    from pytorch_lightning.loggers import WandbLogger
    
    wandb_logger = WandbLogger(project="my-project", name="run-name")
    
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        logger=wandb_logger
    )
    
    opened by versae 3
  • codet5 support added

    codet5 support added

    Needed to use this package for my experiemnts,

    Salesforce/codet5-base uses a roberta tokenizer hence this pull request.

    users can now specify : model.from_pretrained("codet5","Salesforce/codet5-base")

    If you want to read through codet5

    Here are the links: https://huggingface.co/Salesforce/codet5-base

    Kind regards, Mosh

    opened by mosh98 3
  • byT5 with version 0.1.2

    byT5 with version 0.1.2

    hi there, it seems that the newest version of simpleT5 does no longer work with byT5. The line elif model_type == "byt5": is commented out. The newest version of transformers seems to use a new type of tokenizer T5TokenizerFast and ByT5TokenizerFast does not exist. Any ideas about how to fix that?

    opened by kimgerdes 2
  • Is there any option for fine-tuning mt5 models instead of training from scratch?

    Is there any option for fine-tuning mt5 models instead of training from scratch?

    Hi, Thanks for the amazing simpleT5 package. I use the following script to train a mt5 model for summarization task.

    from simplet5 import SimpleT5

    model = SimpleT5()

    model.from_pretrained(model_type="mt5", model_name="google/mt5-small")

    model.train(train_df=train_df, eval_df=test_df, source_max_token_len=256, target_max_token_len=64, batch_size=8, max_epochs=3, use_gpu=True, outputdir = "outputs", early_stopping_patience_epochs = 0, )

    When I run this code, training start from scratch. My question is that is there any flag to fine-tune the mt5 model instead of training from scratch?

    opened by farshadfiruzi 2
  • Kernel dies every time when I start training the model

    Kernel dies every time when I start training the model

    Hi Shiva, Thank you very much for a such clean and neat wrapper for training ML models. I am using t5(precisely t5-small) as the base to train my model for summarization. I use the dataset using datasets from huggingface. However, everytime when I initiate the training code, the kernel dies and restarts. Any help here is much appreciated!

    Following is my code.

    Import dependencies

    %%capture
    !pip install --user simplet5==0.1.4
    !pip install transformers
    !pip install wandb
    !pip install pandas
    !pip install datasets
    !pip install --user simpletransformers
    

    Load data using datasets from huggingface

    import pandas as pd
    import warnings
    warnings.filterwarnings("ignore")
    from datasets import load_dataset
    dataset = load_dataset("scitldr")
    

    Preparing the train and eval data

    train_df = dataset["train"].to_pandas().copy()
    train_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
    train_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
    train_df.count() ## No NaN found - zero 1992 dataset
    
    train_df['source_text'] = train_df['source_text'].astype('str').str.rstrip(']\'')
    train_df['source_text'] = train_df['source_text'].astype('str').str.lstrip('[\'')
    train_df['target_text'] = train_df['target_text'].astype('str').str.rstrip(']\'')
    train_df['target_text'] = train_df['target_text'].astype('str').str.lstrip('[\'')
    
    train_df["source_text"]=train_df["source_text"].str.replace('\'','')
    train_df["target_text"]=train_df["target_text"].str.replace('\'','')
    train_df["source_text"]="summarize: "+train_df["source_text"]
    train_df.to_csv("train.csv")
    
    eval_df = dataset["validation"].to_pandas().copy()
    eval_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
    eval_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
    eval_df.count() ## No NaN found - zero 1992 dataset
    
    eval_df['source_text'] = eval_df['source_text'].astype('str').str.rstrip(']\'')
    eval_df['source_text'] = eval_df['source_text'].astype('str').str.lstrip('[\'')
    eval_df['target_text'] = eval_df['target_text'].astype('str').str.rstrip(']\'')
    eval_df['target_text'] = eval_df['target_text'].astype('str').str.lstrip('[\'')
    
    eval_df["source_text"]=train_df["source_text"].str.replace('\'','')
    eval_df["target_text"]=train_df["target_text"].str.replace('\'','')
    eval_df["source_text"]="summarize: "+train_df["source_text"]
    eval_df.to_csv("eval.csv")
    

    Loading simpleT5 and wandb_logger and finally loading the model and training code

    from simplet5 import SimpleT5
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(project="ask-poc-logger")
    model = SimpleT5()
    model.from_pretrained("t5","t5-small")
    model.train(train_df=train_df[0:100], 
                eval_df=eval_df[0:100],
                source_max_token_len = 512, 
                target_max_token_len = 100,
                batch_size = 2,
                max_epochs = 3,
                use_gpu = True,
                outputdir = "outputs",
                logger = wandb_logger
                )
    

    I am running this code on the following machine. A vertex AI workbench from Google Cloud. N1-Standard-16 machine type with 16 core and 60 GB Memory. And added GPU P100. Any help is much appreciated ! Thanks in advance!

    opened by kkrishnan90 1
  • Is the task string necessary?

    Is the task string necessary?

    Hi,

    I have fine-tuned the model to write a compliment for a person, given the person's profile and it works pretty well. In the training examples, I haven't prepended the string 'summarize :' to the source_string column entries. Is it necessary (does it lead to better results) to prepend the string indicating the task?

    opened by nikogamulin 1
  • Push finished model

    Push finished model

    Is there a way of automatically pushing the checkpoints to the HuggingFace hub? I am running this mainly in Colab. Works great but often the Colab has timed out, and the checkpoints are lost.

    opened by peregilk 2
  • Unicode Charecter training issue

    Unicode Charecter training issue

    I tried to train My model for translating English to Bengali. After Training when I run the code, The output is not Unicode Bengali character.

    I Eat Rice (eng)=> আমি ভাত খাই (Bn)

    this type of data is input to the model while training. After complete, when I tested the model by inputting "I Eat Rice" I was expecting "আমি ভাত খাই" as output. But instead of this, the model gave me "Ich esse Reis." I dont know what kind of language is this. Its not related to bengali.

    opened by rahat10120141 5
  • ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

    ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

    import soundfile as sf from scipy.io import wavfile from IPython.display import Audio from transformers import Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer

    import speech_recognition as sr import io from pydub import AudioSegment

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h") model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    r = sr.Recognizer() with sr.Microphone(sample_rate=16000) as source: print("speak") while True: audio = r.listen(source) data = io.BytesIO(audio.get_wav_data()) clip = AudioSegment.from_file(data) x = torch.FloatTensor(clip.get_array_of_samples()) print(x)

        inputs = tokenizer(x, sampling_rate=16000, return_tensors='pt', padding='longest').input_values
        logits = model(inputs).logits
        tokens = torch.argmax(logits, axis=-1)
        text = tokenizer.batch_decode(tokens)
    
        print('you said: ', str(text).lower())
    
    opened by Ushanjay 1
  • Saved model name not customizable

    Saved model name not customizable

    def training_epoch_end(self, training_step_outputs): """ save tokenizer and model on epoch end """ self.average_training_loss = np.round( torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(), 4, ) path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(self.average_training_loss)}-val-loss-{str(self.average_validation_loss)}"

    Will be very helpful if you can allow the name customizable (note the 'path' assignment).

    Btw, SimpleT5 is simply cool!

    opened by ke-lara 0
Releases(v0.1.4)
  • v0.1.4(Feb 15, 2022)

    SimpleT5 v0.1.4

    • Added support for all the pytorch-lightning supported loggers #11
    • Added support to save model only at last epoch #21
    • Added dataloader_num_workers parameter to.train( ) method to specify number of worker in train/test/val dataloader #19
    • fixed warnings and made compatible with latest transformers and pytorch-lightning
    Source code(tar.gz)
    Source code(zip)
  • v0.1.3(Sep 4, 2021)

  • version-0.1.0(Jul 13, 2021)

    SimpleT5 - version 0.1.0

    • Supports ByT5 model - Thanks to @mapmeld for his contribution
    from simplet5 import SimpleT5
    model = SimpleT5()
    model.from_pretrained("byt5", "google/byt5-small")
    
    • Added precision flag to support mixed precision training
    # train
    model.train(train_df=train_df, # pandas dataframe with 2 columns: source_text & target_text
                eval_df=eval_df, # pandas dataframe with 2 columns: source_text & target_text
                source_max_token_len = 512, 
                target_max_token_len = 128,
                batch_size = 8,
                max_epochs = 5,
                use_gpu = True,
                outputdir = "outputs",
                early_stopping_patience_epochs = 0,
                precision = 32
                )
    
    Source code(tar.gz)
    Source code(zip)
Owner
Shivanand Roy
Data Scientist.
Shivanand Roy
code for modular summarization work published in ACL2021 by Krishna et al

This repository contains the code for running modular summarization pipelines as described in the publication Krishna K, Khosla K, Bigham J, Lipton ZC

Kundan Krishna 6 Jun 04, 2021
PyWorld3 is a Python implementation of the World3 model

The World3 model revisited in Python Install & Hello World3 How to tune your own simulation Licence How to cite PyWorld3 with Bibtex References & ackn

Charles Vanwynsberghe 248 Dec 14, 2022
Code for the paper "Are Sixteen Heads Really Better than One?"

Are Sixteen Heads Really Better than One? This repository contains code to reproduce the experiments in our paper Are Sixteen Heads Really Better than

Paul Michel 143 Dec 14, 2022
justCTF [*] 2020 challenges sources

justCTF [*] 2020 This repo contains sources for justCTF [*] 2020 challenges hosted by justCatTheFish. TLDR: Run a challenge with ./run.sh (requires Do

justCatTheFish 25 Dec 27, 2022
nlpcommon is a python Open Source Toolkit for text classification.

nlpcommon nlpcommon, Python Text Tool. Guide Feature Install Usage Dataset Contact Cite Reference Feature nlpcommon is a python Open Source

xuming 3 May 29, 2022
nlp-tutorial is a tutorial for who is studying NLP(Natural Language Processing) using Pytorch

nlp-tutorial is a tutorial for who is studying NLP(Natural Language Processing) using Pytorch. Most of the models in NLP were implemented with less than 100 lines of code.(except comments or blank li

Tae-Hwan Jung 11.9k Jan 08, 2023
SASE : Self-Adaptive noise distribution network for Speech Enhancement with heterogeneous data of Cross-Silo Federated learning

SASE : Self-Adaptive noise distribution network for Speech Enhancement with heterogeneous data of Cross-Silo Federated learning We propose a SASE mode

Tower 1 Nov 20, 2021
An official implementation for "CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval"

The implementation of paper CLIP4Clip: An Empirical Study of CLIP for End to End Video Clip Retrieval. CLIP4Clip is a video-text retrieval model based

ArrowLuo 456 Jan 06, 2023
This repository contains the code for running the character-level Sandwich Transformers from our ACL 2020 paper on Improving Transformer Models by Reordering their Sublayers.

Improving Transformer Models by Reordering their Sublayers This repository contains the code for running the character-level Sandwich Transformers fro

Ofir Press 53 Sep 26, 2022
Python interface for converting Penn Treebank trees to Stanford Dependencies and Universal Depenencies

PyStanfordDependencies Python interface for converting Penn Treebank trees to Universal Dependencies and Stanford Dependencies. Example usage Start by

David McClosky 64 May 08, 2022
NeoDays-based tileset for the roguelike CDDA (Cataclysm Dark Days Ahead)

NeoDaysPlus Reduced contrast, expanded, and continuously developed version of the CDDA tileset NeoDays that's being completed with new sprites for mis

0 Nov 12, 2022
This repository describes our reproducible framework for assessing self-supervised representation learning from speech

LeBenchmark: a reproducible framework for assessing SSL from speech Self-Supervised Learning (SSL) using huge unlabeled data has been successfully exp

49 Aug 24, 2022
Code Implementation of "Learning Span-Level Interactions for Aspect Sentiment Triplet Extraction".

Span-ASTE: Learning Span-Level Interactions for Aspect Sentiment Triplet Extraction ***** New March 31th, 2022: Scikit-Style API for Easy Usage *****

Chia Yew Ken 111 Dec 23, 2022
Intent parsing and slot filling in PyTorch with seq2seq + attention

PyTorch Seq2Seq Intent Parsing Reframing intent parsing as a human - machine translation task. Work in progress successor to torch-seq2seq-intent-pars

Sean Robertson 159 Apr 04, 2022
📝An easy-to-use package to restore punctuation of the text.

✏️ rpunct - Restore Punctuation This repo contains code for Punctuation restoration. This package is intended for direct use as a punctuation restorat

Daulet Nurmanbetov 72 Dec 30, 2022
Model parallel transformers in JAX and Haiku

Table of contents Mesh Transformer JAX Updates Pretrained Models GPT-J-6B Links Acknowledgments License Model Details Zero-Shot Evaluations Architectu

Ben Wang 4.9k Jan 04, 2023
open-information-extraction-system, build open-knowledge-graph(SPO, subject-predicate-object) by pyltp(version==3.4.0)

中文开放信息抽取系统, open-information-extraction-system, build open-knowledge-graph(SPO, subject-predicate-object) by pyltp(version==3.4.0)

7 Nov 02, 2022
Code release for NeX: Real-time View Synthesis with Neural Basis Expansion

NeX: Real-time View Synthesis with Neural Basis Expansion Project Page | Video | Paper | COLAB | Shiny Dataset We present NeX, a new approach to novel

537 Jan 05, 2023
Pattern Matching in Python

Pattern Matching finalmente chega no Python 3.10. E daí? "Pattern matching", ou "correspondência de padrões" como é conhecido no Brasil. Algumas pesso

Fabricio Werneck 6 Feb 16, 2022
A python gui program to generate reddit text to speech videos from the id of any post.

Reddit text to speech generator A python gui program to generate reddit text to speech videos from the id of any post. Current functionality Generate

Aadvik 17 Dec 19, 2022