Convert scikit-learn models to PyTorch modules


sk2torch

sk2torch converts scikit-learn models into PyTorch modules that can be tuned with backpropagation and even compiled as TorchScript.

Problems solved by this project:

  1. scikit-learn cannot perform inference on a GPU. Models like SVMs have a lot to gain from fast GPU primitives, and converting the models to PyTorch gives immediate access to these primitives.
  2. While scikit-learn supports serialization through pickle, saved models are not reproducible across versions of the library. On the other hand, TorchScript provides a convenient, safe way to save a model with its corresponding implementation. The resulting models can be loaded anywhere that PyTorch is installed, even without importing sk2torch.
  3. While certain models like SVMs and linear classifiers are theoretically end-to-end differentiable, scikit-learn provides no mechanism to compute gradients through trained models. PyTorch provides this functionality mostly for free.

See Usage for a high-level example of using the library. See How it works to see which modules are supported.

For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by this script):

[Figure: a vector field quiver plot with two modes]
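The gist of that script can be sketched roughly as follows. This is a hedged illustration, not the script itself: it assumes a small two-feature SVC trained with probability=True, and uses torch.autograd to differentiate the wrapped model's class-1 probability with respect to its inputs.

import numpy as np
import sk2torch
import torch
from sklearn.svm import SVC

# Train a small two-class SVM with probability estimates enabled.
xs = np.random.randn(200, 2)
ys = (xs[:, 0] ** 2 + xs[:, 1] ** 2 > 1.0).astype(int)
sk_model = SVC(probability=True).fit(xs, ys)

# Wrap it and differentiate the class-1 probability w.r.t. the inputs.
torch_model = sk2torch.wrap(sk_model)
points = torch.randn(5, 2, dtype=torch.float64, requires_grad=True)
probs = torch_model.predict_proba(points)[:, 1]
(grads,) = torch.autograd.grad(probs.sum(), points)
print(grads)  # one gradient vector per input point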

Usage

First, train a model with scikit-learn as usual:

from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

x, y = create_some_dataset()  # placeholder: substitute your own training data
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)

Then call sk2torch.wrap on the model to create a PyTorch equivalent:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))
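Since the wrapped model is an ordinary torch.nn.Module, inference can also be moved to a GPU. A minimal sketch, assuming a CUDA device is available and that the conversion registers its tensors so that .to() relocates them:

if torch.cuda.is_available():
    torch_model.to("cuda")  # move the converted tensors to the GPU
    xs = torch.tensor([[1., 2., 3.]], dtype=torch.float64, device="cuda")
    print(torch_model.predict(xs))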

You can save a model with TorchScript:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... sk2torch need not be installed to load the model.
loaded_model = torch.jit.load("path.pt")
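The loaded TorchScript module can then be used for inference on its own. A brief sketch, assuming the inference methods such as predict are exported when the model is scripted, which is the point of saving it this way:

xs = torch.tensor([[1., 2., 3.]], dtype=torch.float64)
print(loaded_model.predict(xs))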

For a full example of training a model and using its PyTorch translation, see examples/svm_vector_field.py.

How it works

sk2torch contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator X, a class TorchX in sk2torch will be able to read the attributes of X and convert them to torch.Tensor or simple Python types. TorchX subclasses torch.nn.Module and has a method for each inference API of X (e.g. predict, decision_function, etc.).
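A rough illustration of this pattern (hedged: it assumes the wrapper for StandardScaler mirrors scikit-learn's transform method, as the description above implies):

import numpy as np
import sk2torch
import torch
from sklearn.preprocessing import StandardScaler

# Fit a plain scikit-learn estimator, then wrap it.
scaler = StandardScaler().fit(np.random.randn(100, 3))
torch_scaler = sk2torch.wrap(scaler)

print(isinstance(torch_scaler, torch.nn.Module))  # True
print(torch_scaler.transform(torch.randn(5, 3, dtype=torch.float64)))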

Which modules are supported? The easiest way to get an up-to-date list is via the supported_classes() function, which returns all wrap()able scikit-learn classes:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

Comparison to sklearn-onnx

sklearn-onnx is an open source package for converting trained scikit-learn models into ONNX. Like sk2torch, sklearn-onnx re-implements inference functions for various models, meaning that it can also provide serialization and GPU acceleration for supported modules.

Naturally, neither library will support modules that aren't manually ported. As a result, the two libraries support different subsets of all available models/methods. For example, sk2torch supports the SVC probability prediction methods predict_proba and predict_log_proba, whereas sklearn-onnx does not.

While sklearn-onnx exports models to ONNX, sk2torch exports models to Python objects with familiar method names that can be fine-tuned, backpropagated through, and serialized in a user-friendly way. PyTorch is strictly more general than ONNX, since PyTorch models can be converted to ONNX if desired.
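That last step can be sketched roughly as follows. The DecisionWrapper class is a hypothetical helper, torch_model is the wrapped pipeline from the Usage section, and not every supported estimator is guaranteed to trace and export cleanly:

import torch

class DecisionWrapper(torch.nn.Module):
    """Expose one inference method as forward() for ONNX export."""

    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        return self.wrapped.decision_function(x)

example_input = torch.zeros(1, 3, dtype=torch.float64)
torch.onnx.export(DecisionWrapper(torch_model), example_input, "model.onnx")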

Owner

Alex Nichol, web developer, math geek, and AI enthusiast.