Convert BART models to ONNX with quantization. 3X reduction in size, and upto 3X boost in inference speed

Overview

fast-Bart

Reduction of BART model size by 3X, and boost in inference speed up to 3X

BART implementation of the fastT5 library (https://github.com/Ki6an/fastT5)

Pytorch model -> ONNX model -> Quantized ONNX model


Install

Install using requirements.txt file

git clone https://github.com/siddharth-sharma7/fast-Bart
cd fast-Bart
pip install -r requirements.txt

Usage

The export_and_get_onnx_model() method exports the given pretrained Bart model to onnx, quantizes it and runs it on the onnxruntime with default settings. The returned model from this method supports the generate() method of huggingface.

If you don't wish to quantize the model then use quantized=False in the method.

from fastBart import export_and_get_onnx_model
from transformers import AutoTokenizer

model_name = 'facebook/bart-base'
model = export_and_get_onnx_model(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input = "This is a very long sentence and needs to be summarized."
token = tokenizer(input, return_tensors='pt')

tokens = model.generate(input_ids=token['input_ids'],
               attention_mask=token['attention_mask'],
               num_beams=3)

output = tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
print(output)

to run the already exported model use get_onnx_model()

you can customize the whole pipeline as shown in the below code example:

from fastBart import (OnnxBart, get_onnx_runtime_sessions,
                    generate_onnx_representation, quantize)
from transformers import AutoTokenizer

model_or_model_path = 'facebook/bart-base'

# Step 1. convert huggingfaces bart model to onnx
onnx_model_paths = generate_onnx_representation(model_or_model_path)

# Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
# The process is slow for the decoder and init-decoder onnx files (can take up to 15 mins)
quant_model_paths = quantize(onnx_model_paths)

# step 3. setup onnx runtime
model_sessions = get_onnx_runtime_sessions(quant_model_paths)

# step 4. get the onnx model
model = OnnxBart(model_or_model_path, model_sessions)

                      ...
custom output paths

By default, fastBart creates a models-bart folder in the current directory and stores all the models. You can provide a custom path for a folder to store the exported models. And to run already exported models that are stored in a custom folder path: use get_onnx_model(onnx_models_path="/path/to/custom/folder/")

from fastBart import export_and_get_onnx_model, get_onnx_model

model_name = "facebook/bart-base"
custom_output_path = "/path/to/custom/folder/"

# 1. stores models to custom_output_path
model = export_and_get_onnx_model(model_name, custom_output_path)

# 2. run already exported models that are stored in custom path
# model = get_onnx_model(model_name, custom_output_path)

Functionalities

  • Export any pretrained Bart model to ONNX easily.
  • The exported model supports beam search and greedy search and more via generate() method.
  • Reduce the model size by 3X using quantization.
  • Up to 3X speedup compared to PyTorch execution for greedy search and 2-3X for beam search.
Owner
Siddharth Sharma
Machine learning | NLP | Computer Vision
Siddharth Sharma
Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21)

AdvRush Official Code for AdvRush: Searching for Adversarially Robust Neural Architectures (ICCV '21) Environmental Set-up Python == 3.6.12, PyTorch =

11 Dec 10, 2022
Styled Augmented Translation

SAT Style Augmented Translation Introduction By collecting high-quality data, we were able to train a model that outperforms Google Translate on 6 dif

139 Dec 29, 2022
A machine learning package for streaming data in Python. The other ancestor of River.

scikit-multiflow is a machine learning package for streaming data in Python. creme and scikit-multiflow are merging into a new project called River. W

670 Dec 30, 2022
Unified learning approach for egocentric hand gesture recognition and fingertip detection

Unified Gesture Recognition and Fingertip Detection A unified convolutional neural network (CNN) algorithm for both hand gesture recognition and finge

Mohammad 227 Dec 25, 2022
(under submission) Bayesian Integration of a Generative Prior for Image Restoration

BIGPrior: Towards Decoupling Learned Prior Hallucination and Data Fidelity in Image Restoration Authors: Majed El Helou, and Sabine Süsstrunk {Note: p

Majed El Helou 22 Dec 17, 2022
This is an official implementation of the CVPR2022 paper "Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots".

Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots Blind2Unblind Citing Blind2Unblind @inproceedings{wang2022blind2unblind, tit

demonsjin 58 Dec 06, 2022
The ARCA23K baseline system

ARCA23K Baseline System This is the source code for the baseline system associated with the ARCA23K dataset. Details about ARCA23K and the baseline sy

4 Jul 02, 2022
Code for the paper: "On the Bottleneck of Graph Neural Networks and Its Practical Implications"

On the Bottleneck of Graph Neural Networks and its Practical Implications This is the official implementation of the paper: On the Bottleneck of Graph

75 Dec 22, 2022
We propose a new method for effective shadow removal by regarding it as an exposure fusion problem.

Auto-exposure fusion for single-image shadow removal We propose a new method for effective shadow removal by regarding it as an exposure fusion proble

Qing Guo 146 Dec 31, 2022
A framework that allows people to write their own Rocket League bots.

YOU PROBABLY SHOULDN'T PULL THIS REPO Bot Makers Read This! If you just want to make a bot, you don't need to be here. Instead, start with one of thes

543 Dec 20, 2022
LSTM and QRNN Language Model Toolkit for PyTorch

LSTM and QRNN Language Model Toolkit This repository contains the code used for two Salesforce Research papers: Regularizing and Optimizing LSTM Langu

Salesforce 1.9k Jan 08, 2023
Official PyTorch implementation of Joint Object Detection and Multi-Object Tracking with Graph Neural Networks

This is the official PyTorch implementation of our paper: "Joint Object Detection and Multi-Object Tracking with Graph Neural Networks". Our project website and video demos are here.

Richard Wang 443 Dec 06, 2022
[CVPR'21] Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild

IVOS-W Paper Learning to Recommend Frame for Interactive Video Object Segmentation in the Wild Zhaoyun Yin, Jia Zheng, Weixin Luo, Shenhan Qian, Hanli

SVIP Lab 38 Dec 12, 2022
Numbering permanent and deciduous teeth via deep instance segmentation in panoramic X-rays

Numbering permanent and deciduous teeth via deep instance segmentation in panoramic X-rays In this repo, you will find the instructions on how to requ

Intelligent Vision Research Lab 4 Jul 21, 2022
This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CNPs), Neural Processes (NPs), Attentive Neural Processes (ANPs).

The Neural Process Family This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CN

DeepMind 892 Dec 28, 2022
OpenDelta - An Open-Source Framework for Paramter Efficient Tuning.

OpenDelta is a toolkit for parameter efficient methods (we dub it as delta tuning), by which users could flexibly assign (or add) a small amount parameters to update while keeping the most paramters

THUNLP 386 Dec 26, 2022
Implementing Vision Transformer (ViT) in PyTorch

Lightning-Hydra-Template A clean and scalable template to kickstart your deep learning project 🚀 ⚡ 🔥 Click on Use this template to initialize new re

2 Dec 24, 2021
This repository contains the code for the paper "PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization"

PIFu: Pixel-Aligned Implicit Function for High-Resolution Clothed Human Digitization News: [2020/05/04] Added EGL rendering option for training data g

Shunsuke Saito 1.5k Jan 03, 2023
Graph Self-Attention Network for Learning Spatial-Temporal Interaction Representation in Autonomous Driving

GSAN Introduction Code for paper GSAN: Graph Self-Attention Network for Learning Spatial-Temporal Interaction Representation in Autonomous Driving, wh

YE Luyao 6 Oct 27, 2022
NaturalProofs: Mathematical Theorem Proving in Natural Language

NaturalProofs: Mathematical Theorem Proving in Natural Language NaturalProofs: Mathematical Theorem Proving in Natural Language Sean Welleck, Jiacheng

Sean Welleck 83 Jan 05, 2023