PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

Related tags

Deep Learningpytorch
Overview

PyTorch-LIT

PyPI version

PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

With the rapid growth of deep learning research, models are becoming increasingly complex in terms of parameters and complexity, making it difficult to run the models on currently available end devices. For example, GPT-J with 6B parameters only needs 24 GB of RAM in full-precision mode to be ready for execution, which may be impossible in most systems; even a powerful GPU like the RTX 2060 with 6 GB of memory can't even contain GPT-J in half-precision mode, making direct inference impossible.

To address this issue when training large models, libraries such as DeepSpeed use offload techniques (e.g., ZeRO) to handle the parameters and make training possible by dividing the weights between devices. In contrast, there is no direct library/framework available for inference.

PyTorch-LIT allows the inference of large models by loading weights as needed from secondary specified memory, which could be disk, CPU, or GPU, allowing the inference of models that do not even fit in the system's main memory simply by trading off time.

Quick Start

  1. Install the library
pip install pytorch-lit
  1. You have to save the model's weight in a way that toolkit can use
from pytorch_lit.export import prepare_params

weights = {} # your model's parameters (state_dict)
# change the directory to save your model and specify data-type
prepare_params(weights, ".models/my-model", dtype="float32")
  1. After preparing the weights, you can infer your model
from pytorch_lit import LitModule

# pass your model construction as a closure, 
# specify weights path and inference device 
model = LitModule.from_params(".models/my-model",
                                  lambda: MyModel(),
                                  device="cuda")
result = model(*arg, **kwargs)
  1. Have fun enjoying the inference of the large model on a lower memory device:)

Examples

The repo's examples directory contains examples. There are currently two examples of GPT-J, one for text generation and the other for extracting hidden states as feature representations.

Development

This is a work in progress that will require further development before it can be considered a stable inference toolkit. Here is a list of potential future developments:

  • Caching and batch loading as many weights as memory allows, with weights being replaced in parallel with future ones (through the order of the execution graph)
  • C++ extension for PyTorch jit, so the solution applies to the majority of production end devices
  • Add functions to make it easier to export large models to onnx or trace with jit
  • Use better and faster format than numpy memmap

Contributions are welcome; to discuss your idea further, open an issue with the discussion tag. Finally, you can submit a pull request to merge your fork.

How does it work?

This implementation was made possible primarily by two ideas:

  • The first issue was that PyTorch initialized the model object's parameters when constructing it, causing the construction to fail when the model couldn't fit into memory. To address this, we proposed temporarily hijacking PyTorch's Parameter class's __new__ method during model construction, allowing us to replace the parameter's tensor with a view from a shared global tensor immediately after creation. By doing so, all parameters use the same shared big tensor as their primary storage, allowing the model to be built and tested with inputs to follow and trace the execution graph.
  • The second issue was the large size of model parameters; in the preparation step, we built a numpy memmap(np.memmap) and saved metadata that provided us with the location of each key in the memmap. This allowed us to read parameters from the memmap as needed. Following that, we use the PyTorch hooks (forward and pre_forward) to load and unload a module's parameters before and after execution.

Citation

Please cite PyTorch-LIT if it helps your research. You can use the following BibTeX entry:

@misc{pytorch_lit,
	title = {PyTorch-LIT},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/PyTorch-LIT}},
	year = {2021}
}
You might also like...
FPGA: Fast Patch-Free Global Learning Framework for Fully End-to-End Hyperspectral Image Classification
FPGA: Fast Patch-Free Global Learning Framework for Fully End-to-End Hyperspectral Image Classification

FPGA & FreeNet Fast Patch-Free Global Learning Framework for Fully End-to-End Hyperspectral Image Classification by Zhuo Zheng, Yanfei Zhong, Ailong M

 WarpDrive: Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning on a GPU
WarpDrive: Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning on a GPU

WarpDrive is a flexible, lightweight, and easy-to-use open-source reinforcement learning (RL) framework that implements end-to-end multi-agent RL on a single GPU (Graphics Processing Unit).

this is a lite easy to use virtual keyboard project for anyone to use
this is a lite easy to use virtual keyboard project for anyone to use

virtual_Keyboard this is a lite easy to use virtual keyboard project for anyone to use motivation I made this for this year's recruitment for RobEn AA

Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite.
Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite.

TFlite Ultra Fast Lane Detection Inference Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite. So

Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.
Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.

InfoPro-Pytorch The Information Propagation algorithm for training deep networks with local supervision. (ICLR 2021) Revisiting Locally Supervised Lea

Code & Models for 3DETR - an End-to-end transformer model for 3D object detection
Code & Models for 3DETR - an End-to-end transformer model for 3D object detection

3DETR: An End-to-End Transformer Model for 3D Object Detection PyTorch implementation and models for 3DETR. 3DETR (3D DEtection TRansformer) is a simp

Python scripts to detect faces in Python with the BlazeFace Tensorflow Lite models
Python scripts to detect faces in Python with the BlazeFace Tensorflow Lite models

Python scripts to detect faces using Python with the BlazeFace Tensorflow Lite models. Tested on Windows 10, Tensorflow 2.4.0 (Python 3.8).

A repository that shares tuning results of trained models generated by TensorFlow / Keras. Post-training quantization (Weight Quantization, Integer Quantization, Full Integer Quantization, Float16 Quantization), Quantization-aware training. TensorFlow Lite. OpenVINO. CoreML. TensorFlow.js. TF-TRT. MediaPipe. ONNX. [.tflite,.h5,.pb,saved_model,tfjs,tftrt,mlmodel,.xml/.bin, .onnx] An end-to-end PyTorch framework for image and video classification
An end-to-end PyTorch framework for image and video classification

What's New: March 2021: Added RegNetZ models November 2020: Vision Transformers now available, with training recipes! 2020-11-20: Classy Vision v0.5 R

Comments
  • RuntimeError : OrderdDict mutated during iteration.

    RuntimeError : OrderdDict mutated during iteration.

    Hi, there are new problems. When the model parameters forward, raise a RuntimeError : OrderdDict mutated during iteration. detail as below: Traceback (most recent call last): File "nlp/rct-FPM-rhino/big_model/predict.py", line 24, in result = model(**tokens) File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 34, in call return self.forward(*args, **kwargs) File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 31, in forward return self.module(*args, **kwargs) File "miniconda3/envs/rhino/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1057, in _call_impl for hook in itertools.chain( RuntimeError: OrderedDict mutated during iteration

    enviroments:

    GPU:NVIDIA GeForce 3090 CUDA version 11.4 pip list: certifi 2021.10.8 charset-normalizer 2.0.8 click 8.0.3 filelock 3.4.0 huggingface-hub 0.2.0 idna 3.3 joblib 1.1.0 numpy 1.21.4 packaging 21.3 Pillow 8.4.0 pip 21.2.4 pyparsing 3.0.6 pytorch-lit 0.1.7 PyYAML 6.0 regex 2021.11.10 requests 2.26.0 sacremoses 0.0.46 setuptools 58.0.4 six 1.16.0 tokenizer 3.3.2 tokenizers 0.10.3 torch 1.9.1+cu111 torchaudio 0.8.1 torchvision 0.9.1+cu111 tqdm 4.62.3 transformers 4.12.5 typing_extensions 4.0.1 urllib3 1.26.7

    I think this problem caused by PyTorch hooks (forward and pre_forward) to load and unload a module's parameters before and after execution, when load and unload the parameters,the OrderedDict was be mutated.

    opened by changleilei 9
  • TypeError: <lambda>() missing 1 required positional argument: 'k'

    TypeError: () missing 1 required positional argument: 'k'

    Hello, when i use pytorch-lit prepare a model, got a TypeError as title. The detail as blow:

    File "nlp/rct-FPM-rhino/big_model/prepare_model.py", line 16, in prepare_model prepare_params(model, args.save_path, dtype='float32') File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 19, in prepare_params _params_to_memmap(parameters, path.join(save_dir, "model.bin"), File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 52, in _params_to_memmap param = get_param(k) File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 50, in get_param = lambda key: params"get" TypeError: () missing 1 required positional argument: 'k'

    package list:

    certifi 2021.10.8 numpy 1.21.4 pip 21.2.4 pytorch-lit 0.1.6 setuptools 58.0.4 torch 1.10.0 tqdm 4.62.3 typing_extensions 4.0.1 wheel 0.37.0

    model: gpt-j-6B

    Have any suggesstion? Thanks.

    opened by changleilei 1
  • gpt-j generation speed very low

    gpt-j generation speed very low

    The output of gpt-j is very slow, for a 200 output token generation it takes about 20 minutes, for 2048 it takes more than an hour, this significantly limits any experimentation with the model.

    I checked Gpu utilization during inference which is about 1 percent or 4 percent, and gpu memory usage is below 4GB usage, my system has 8GB Gpu memory, if full Gpu is utilized it may be significantly increase the inference speed

    Are their simple hacks to speedup inference time ?

    opened by usama-ahmedkhan 3
  • Weights file format is changed, function partial_loader fails

    Weights file format is changed, function partial_loader fails

    Hi, thanks for your effort for making it easy to load and do inference from large models. I tried your code on a gpt-j model with different model file format, the weight files of the model are in several .pt files not like a single .bin file which your code function partial_loader() expects, does the code work with multiple weight file ? , how can i change it.

    opened by usama-ahmedkhan 4
Releases(0.1.7)
Owner
Amin Rezaei
Computer Science BSc, Neural Networks Enthusiast
Amin Rezaei
Official Repository for the paper "Improving Baselines in the Wild".

iWildCam and FMoW baselines (WILDS) This repository was originally forked from the official repository of WILDS datasets (commit 7e103ed) For general

Kazuki Irie 3 Nov 24, 2022
Unofficial implementation of the ImageNet, CIFAR 10 and SVHN Augmentation Policies learned by AutoAugment using pillow

AutoAugment - Learning Augmentation Policies from Data Unofficial implementation of the ImageNet, CIFAR10 and SVHN Augmentation Policies learned by Au

Philip Popien 1.3k Jan 02, 2023
FluxTraining.jl gives you an endlessly extensible training loop for deep learning

A flexible neural net training library inspired by fast.ai

86 Dec 31, 2022
A ssl analyzer which could analyzer target domain's certificate.

ssl_analyzer A ssl analyzer which could analyzer target domain's certificate. Analyze the domain name ssl certificate information according to the inp

vincent 17 Dec 12, 2022
Using machine learning to predict and analyze high and low reader engagement for New York Times articles posted to Facebook.

How The New York Times can increase Engagement on Facebook Using machine learning to understand characteristics of news content that garners "high" Fa

Jessica Miles 0 Sep 16, 2021
A deep-learning pipeline for segmentation of ambiguous microscopic images.

Welcome to Official repository of deepflash2 - a deep-learning pipeline for segmentation of ambiguous microscopic images. Quick Start in 30 seconds se

Matthias Griebel 39 Dec 19, 2022
A python toolbox for predictive uncertainty quantification, calibration, metrics, and visualization

Website, Tutorials, and Docs    Uncertainty Toolbox A python toolbox for predictive uncertainty quantification, calibration, metrics, and visualizatio

Uncertainty Toolbox 1.4k Dec 28, 2022
Official implementation for NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning The predictive learning of spatiotemporal sequences aims to generate future

THUML: Machine Learning Group @ THSS 243 Dec 26, 2022
Code to train models from "Paraphrastic Representations at Scale".

Paraphrastic Representations at Scale Code to train models from "Paraphrastic Representations at Scale". The code is written in Python 3.7 and require

John Wieting 71 Dec 19, 2022
A criticism of a recent paper on buggy image downsampling methods in popular image processing and deep learning libraries.

A criticism of a recent paper on buggy image downsampling methods in popular image processing and deep learning libraries.

70 Jul 12, 2022
(ICCV 2021) PyTorch implementation of Paper "Progressive Correspondence Pruning by Consensus Learning"

CLNet (ICCV 2021) PyTorch implementation of Paper "Progressive Correspondence Pruning by Consensus Learning" [project page] [paper] Citing CLNet If yo

Chen Zhao 22 Aug 26, 2022
People log into different sites every day to get information and browse through these sites one by one

HyperLink People log into different sites every day to get information and browse through these sites one by one. And they are exposed to advertisemen

0 Feb 17, 2022
The official implementation of NeurIPS 2021 paper: Finding Optimal Tangent Points for Reducing Distortions of Hard-label Attacks

Introduction This repository includes the source code for "Finding Optimal Tangent Points for Reducing Distortions of Hard-label Attacks", which is pu

machen 11 Nov 27, 2022
Collect super-resolution related papers, data, repositories

Collect super-resolution related papers, data, repositories

WangChaofeng 1.7k Jan 03, 2023
PyTorch implementation of D2C: Diffuison-Decoding Models for Few-shot Conditional Generation.

D2C: Diffuison-Decoding Models for Few-shot Conditional Generation Project | Paper PyTorch implementation of D2C: Diffuison-Decoding Models for Few-sh

Jiaming Song 90 Dec 27, 2022
[ECCV 2020] Gradient-Induced Co-Saliency Detection

Gradient-Induced Co-Saliency Detection Zhao Zhang*, Wenda Jin*, Jun Xu, Ming-Ming Cheng ⭐ Project Home » The official repo of the ECCV 2020 paper Grad

Zhao Zhang 35 Nov 25, 2022
The Deep Learning with Julia book, using Flux.jl.

Deep Learning with Julia DL with Julia is a book about how to do various deep learning tasks using the Julia programming language and specifically the

Logan Kilpatrick 67 Dec 25, 2022
Transfer SemanticKITTI labeles into other dataset/sensor formats.

LiDAR-Transfer Transfer SemanticKITTI labeles into other dataset/sensor formats. Content Convert datasets (NUSCENES, FORD, NCLT) to KITTI format Minim

Photogrammetry & Robotics Bonn 64 Nov 21, 2022
Foreground-Action Consistency Network for Weakly Supervised Temporal Action Localization

FAC-Net Foreground-Action Consistency Network for Weakly Supervised Temporal Action Localization Linjiang Huang (CUHK), Liang Wang (CASIA), Hongsheng

21 Nov 22, 2022
LogAvgExp - Pytorch Implementation of LogAvgExp

LogAvgExp - Pytorch Implementation of LogAvgExp for Pytorch Install $ pip instal

Phil Wang 31 Oct 14, 2022