Convert scikit-learn models to PyTorch modules

Related tags

Deep Learningsk2torch
Overview

sk2torch

sk2torch converts scikit-learn models into PyTorch modules that can be tuned with backpropagation and even compiled as TorchScript.

Problems solved by this project:

  1. scikit-learn cannot perform inference on a GPU. Models like SVMs have a lot to gain from fast GPU primitives, and converting the models to PyTorch gives immediate access to these primitives.
  2. While scikit-learn supports serialization through pickle, saved models are not reproducible across versions of the library. On the other hand, TorchScript provides a convenient, safe way to save a model with its corresponding implementation. The resulting models can be loaded anywhere that PyTorch is installed, even without importing sk2torch.
  3. While certain models like SVMs and linear classifiers are theoretically end-to-end differentiable, scikit-learn provides no mechanism to compute gradients through trained models. PyTorch provides this functionality mostly for free.

See Usage for a high-level example of using the library. See How it works to see which modules are supported.

For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by this script):

A vector field quiver plot with two modes

Usage

First, train a model with scikit-learn as usual:

from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = create_some_dataset()
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)

Then call sk2torch.wrap on the model to create a PyTorch equivalent:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))

You can save a model with TorchScript:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... sk2torch need not be installed to load the model.
loaded_model = torch.jit.load("path.pt")

For a full example of training a model and using its PyTorch translation, see examples/svm_vector_field.py.

How it works

sk2torch contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator X, a class TorchX in sk2torch will be able to read the attributes of X and convert them to torch.Tensor or simple Python types. TorchX subclasses torch.nn.Module and has a method for each inference API of X (e.g. predict, decision_function, etc.).

Which modules are supported? The easiest way to get an up-to-date list is via the supported_classes() function, which returns all wrap()able scikit-learn classes:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

Comparison to sklearn-onnx

sklearn-onnx is an open source package for converting trained scikit-learn models into ONNX. Like sk2torch, sklearn-onnx re-implements inference functions for various models, meaning that it can also provide serialization and GPU acceleration for supported modules.

Naturally, neither library will support modules that aren't manually ported. As a result, the two libraries support different subsets of all available models/methods. For example, sk2torch supports the SVC probability prediction methods predict_proba and predict_log_prob, whereas sklearn-onnx does not.

While sklearn-onnx exports models to ONNX, sk2torch exports models to Python objects with familiar method names that can be fine-tuned, backpropagated through, and serialized in a user-friendly way. PyTorch is strictly more general than ONNX, since PyTorch models can be converted to ONNX if desired.

Owner
Alex Nichol
Web developer, math geek, and AI enthusiast.
Alex Nichol
Code and data to accompany the camera-ready version of "Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation" in EMNLP 2021

Code and data to accompany the camera-ready version of "Cross-Attention is All You Need: Adapting Pretrained Transformers for Machine Translation" in EMNLP 2021

Mozhdeh Gheini 16 Jul 16, 2022
This is Official implementation for "Pose-guided Feature Disentangling for Occluded Person Re-Identification Based on Transformer" in AAAI2022

PFD:Pose-guided Feature Disentangling for Occluded Person Re-identification based on Transformer This repo is the official implementation of "Pose-gui

Tao Wang 93 Dec 18, 2022
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time

T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time The first Lidar-only odometry framework with high performance based on tr

Pengwei Zhou 183 Dec 01, 2022
Code for the paper "Next Generation Reservoir Computing"

Next Generation Reservoir Computing This is the code for the results and figures in our paper "Next Generation Reservoir Computing". They are written

OSU QuantInfo Lab 105 Dec 20, 2022
A user-friendly research and development tool built to standardize RL competency assessment for custom agents and environments.

Built with ❤️ by Sam Showalter Contents Overview Installation Dependencies Usage Scripts Standard Execution Environment Development Environment Benchm

SRI-AIC 1 Nov 18, 2021
Seq2seq - Sequence to Sequence Learning with Keras

Seq2seq Sequence to Sequence Learning with Keras Hi! You have just found Seq2Seq. Seq2Seq is a sequence to sequence learning add-on for the python dee

Fariz Rahman 3.1k Dec 18, 2022
The source code for Generating Training Data with Language Models: Towards Zero-Shot Language Understanding.

SuperGen The source code for Generating Training Data with Language Models: Towards Zero-Shot Language Understanding. Requirements Before running, you

Yu Meng 38 Dec 12, 2022
The source code for the Cutoff data augmentation approach proposed in this paper: "A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation".

Cutoff: A Simple Data Augmentation Approach for Natural Language This repository contains source code necessary to reproduce the results presented in

Dinghan Shen 49 Dec 22, 2022
This repository contains part of the code used to make the images visible in the article "How does an AI Imagine the Universe?" published on Towards Data Science.

Generative Adversarial Network - Generating Universe This repository contains part of the code used to make the images visible in the article "How doe

Davide Coccomini 9 Dec 18, 2022
MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc.

MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc. ⭐⭐⭐⭐⭐

568 Jan 04, 2023
A Context-aware Visual Attention-based training pipeline for Object Detection from a Webpage screenshot!

CoVA: Context-aware Visual Attention for Webpage Information Extraction Abstract Webpage information extraction (WIE) is an important step to create k

Keval Morabia 41 Jan 01, 2023
Self-Supervised Speech Pre-training and Representation Learning Toolkit.

What's New Sep 2021: We host a challenge in AAAI workshop: The 2nd Self-supervised Learning for Audio and Speech Processing! See SUPERB official site

s3prl 1.6k Jan 08, 2023
Open-Ended Commonsense Reasoning (NAACL 2021)

Open-Ended Commonsense Reasoning Quick links: [Paper] | [Video] | [Slides] | [Documentation] This is the repository of the paper, Differentiable Open-

(Bill) Yuchen Lin 31 Oct 19, 2022
Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks

flownet2-pytorch Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks. Multiple GPU training is supported, a

NVIDIA Corporation 2.8k Dec 27, 2022
Official repository for "Restormer: Efficient Transformer for High-Resolution Image Restoration". SOTA for motion deblurring, image deraining, denoising (Gaussian/real data), and defocus deblurring.

Restormer: Efficient Transformer for High-Resolution Image Restoration Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan,

Syed Waqas Zamir 906 Dec 30, 2022
Open source code for the paper of Neural Sparse Voxel Fields.

Neural Sparse Voxel Fields (NSVF) Project Page | Video | Paper | Data Photo-realistic free-viewpoint rendering of real-world scenes using classical co

Meta Research 647 Dec 27, 2022
StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks

StackGAN Pytorch implementation Inception score evaluation StackGAN-v2-pytorch Tensorflow implementation for reproducing main results in the paper Sta

Han Zhang 1.8k Dec 21, 2022
[ECCVW2020] Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DiMP)

Feel free to visit my homepage Robust Long-Term Object Tracking via Improved Discriminative Model Prediction (RLT-DIMP) [ECCVW2020 paper] Presentation

Seokeon Choi 35 Oct 26, 2022
Escaping the Gradient Vanishing: Periodic Alternatives of Softmax in Attention Mechanism

Period-alternatives-of-Softmax Experimental Demo for our paper 'Escaping the Gradient Vanishing: Periodic Alternatives of Softmax in Attention Mechani

slwang9353 0 Sep 06, 2021