scikit-learn Export to ONNX#

See also

Numpy-Tracing and FunctionTransformer — the numpy-tracing mechanism used by FunctionTransformer is documented in the core design section.

A basic scikit-learn model may look like the following: a scaler followed by an estimator. Every model can be converted with to_onnx().

Pipeline(steps=[('scaler', StandardScaler()), ('clf', LogisticRegression())])

Models based on scikit-learn are made of a custom collection of known transformers or estimators. The main function has to call a converter for every piece and assemble the results into a single ONNX model.

The custom collection is mainly implemented through the classes sklearn.pipeline.Pipeline, sklearn.pipeline.FeatureUnion and sklearn.compose.ColumnTransformer. Everything else is well defined and can be mapped to its ONNX counterpart. The collection can also be built through meta-estimators combining other estimators, such as sklearn.multiclass.OneVsRestClassifier or sklearn.ensemble.VotingClassifier.
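The dispatch over pipeline steps can be sketched as follows. This is an illustration only: the CONVERTERS registry and the convert_pipeline helper are hypothetical names chosen for the example, not the actual yobx internals, and the converters return operator names instead of real ONNX nodes.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical registry: one converter per known estimator class.
# A real converter would emit ONNX nodes; here we emit operator names.
CONVERTERS = {
    StandardScaler: lambda est: ["Sub", "Div"],
    LogisticRegression: lambda est: ["Gemm", "Sigmoid"],
}


def convert_pipeline(pipe):
    """Call the converter registered for every step and assemble the results."""
    ops = []
    for name, step in pipe.steps:
        ops.extend(CONVERTERS[type(step)](step))
    return ops


rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4)).astype(np.float32)
y = (X[:, 0] > 0).astype(np.int64)
pipe = Pipeline(
    [("scaler", StandardScaler()), ("clf", LogisticRegression())]
).fit(X, y)
print(convert_pipeline(pipe))  # ['Sub', 'Div', 'Gemm', 'Sigmoid']
```

An unknown estimator class would raise a KeyError here; a real implementation needs a registration mechanism so users can plug in converters for their own estimators.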

Common API

Putting ONNX nodes together into a model is not difficult, but almost everybody has already implemented their own way of doing it: ir-py, onnxscript, spox. Every converting library also has its own: sklearn-onnx, tensorflow-onnx, onnxmltools… The choice was made not to create a new one but rather to define what the converters expect to find in a class called GraphBuilder. It then becomes possible to create a bridge such as yobx.builder.onnxscript.OnnxScriptGraphBuilder which implements this API for every known way. See Expected API for further details.
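The shape of such a builder API can be sketched with a toy class. The method names used here (make_initializer, make_node) are assumptions chosen for the illustration; the real contract is described in the Expected API page.

```python
class TinyGraphBuilder:
    """Toy sketch of a graph-builder API: converters call it to register
    initializers and nodes; a bridge class would forward these calls to
    ir-py, onnxscript or spox."""

    def __init__(self):
        self.nodes = []          # (op_type, inputs, outputs, attributes)
        self.initializers = {}   # name -> constant value

    def make_initializer(self, name, value):
        # Register a constant tensor and return its name.
        self.initializers[name] = value
        return name

    def make_node(self, op_type, inputs, outputs, **attrs):
        # Register one node; return the output name(s) so calls can be chained.
        self.nodes.append((op_type, list(inputs), list(outputs), attrs))
        return outputs[0] if len(outputs) == 1 else list(outputs)


builder = TinyGraphBuilder()
mean = builder.make_initializer("mean", [0.1, 0.2, 0.3, 0.4])
centered = builder.make_node("Sub", ["X", mean], ["centered"])
print(centered, len(builder.nodes))  # centered 1
```

Because converters only talk to this narrow interface, the same converter code can produce graphs through any backend that implements it.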

Opsets

yobx.sklearn.to_onnx() converts scikit-learn models into ONNX. The function exposes an argument target_opset. The conversion is done for opset 18 if target_opset=18. The conversion may include optimized kernels for onnxruntime if target_opset={'': 18, 'com.microsoft': 1} (see ONNX Runtime Contrib Ops (com.microsoft domain)).
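An argument accepting either form can be normalized to a single {domain: version} mapping, where the empty string denotes the default ONNX domain. The helper below is a hypothetical sketch of that convention, not a function exposed by yobx.

```python
def normalize_target_opset(target_opset):
    """Normalize target_opset to a {domain: version} dict.

    An int applies to the default ONNX domain (the empty string '');
    a dict may also pin versions for extra domains such as 'com.microsoft'.
    """
    if isinstance(target_opset, int):
        return {"": target_opset}
    return dict(target_opset)


print(normalize_target_opset(18))
# {'': 18}
print(normalize_target_opset({"": 18, "com.microsoft": 1}))
# {'': 18, 'com.microsoft': 1}
```

The resulting mapping corresponds to the opset_import entries stored in the ONNX model, one per operator domain used by the graph.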

Discrepancies

scikit-learn==1.8 is stricter with computation types, which reduces the number of discrepancies. Switching to float32 in a matrix multiplication usually introduces discrepancies when the order of magnitude of the coefficients is quite large. That is often the case when a matrix is the inverse of another one (see Float32 vs Float64: precision loss with PLSRegression). Prior to that, it was not rare to see huge differences when using a model just after normalizing the data. The normalizer was implicitly switching the type to float64 while ONNX was keeping float32. If followed by a tree, a small difference could make the model choose a different decision path and produce a very different output.
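The decision-path effect can be reproduced with a contrived split. A tree node tests "x <= 0.3" with a float64 threshold; the float32 representation of 0.3 is slightly above the float64 value, so the same nominal input takes a different branch depending on the dtype it was computed in.

```python
import numpy as np

# A tree split "x <= 0.3" evaluated with the same nominal value in two dtypes.
threshold = 0.3                 # float64, as stored in the fitted tree
x64 = np.float64(0.3)           # exactly the threshold value
x32 = np.float32(0.3)           # 0.30000001192092896 once widened to float64

print(bool(x64 <= threshold))        # True  -> left branch
print(bool(float(x32) <= threshold))  # False -> right branch
```

The outputs of the two branches can be arbitrarily far apart, which is why a tiny dtype mismatch upstream of a tree can produce a very different prediction.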

Finally, the example given at the top of the page would be converted into the model which follows.

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from yobx.sklearn import to_onnx
from yobx.helpers.dot_helper import to_dot

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 4)).astype(np.float32)
y = (X[:, 0] > 0).astype(np.int64)

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression()),
]).fit(X, y)

model = to_onnx(pipe, (X,))
print("DOT-SECTION", to_dot(model))

digraph { graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8]; node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box]; edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0]; I_0 [label="X\nFLOAT(batch,4)", fillcolor="#aaeeaa"]; i_1 [label="init1_s1x4_\nFLOAT(1, 4)", fillcolor="#cccc00"]; Sub_2 [label="Sub\n(., [-0.11890265, 0.053688034, 0.25009614, 0.29648674])", fillcolor="#cccccc"]; Div_3 [label="Div\n(., [0.8643964, 0.91257733, 1.0399438, 0.9712117])", fillcolor="#cccccc"]; Gemm_4 [label="Gemm(., ., [-0.6392022])", fillcolor="#cccccc"]; Sigmoid_5 [label="Sigmoid(.)", fillcolor="#cccccc"]; Sub_6 [label="Sub([1.0], .)", fillcolor="#cccccc"]; Concat_7 [label="Concat(., ., axis=-1)", fillcolor="#cccccc"]; ArgMax_8 [label="ArgMax(., axis=1)", fillcolor="#cccccc"]; Gather_9 [label="Gather([0, 1], ., axis=0)", fillcolor="#cccccc"]; I_0 -> Sub_2 [label="FLOAT(batch,4)"]; Sub_2 -> Div_3; Div_3 -> Gemm_4; i_1 -> Gemm_4 [label="FLOAT(1, 4)"]; Gemm_4 -> Sigmoid_5; Sigmoid_5 -> Sub_6; Sub_6 -> Concat_7; Sigmoid_5 -> Concat_7; Concat_7 -> ArgMax_8; ArgMax_8 -> Gather_9; O_10 [label="label\n", fillcolor="#aaaaee"]; Gather_9 -> O_10; O_11 [label="probabilities\n", fillcolor="#aaaaee"]; Concat_7 -> O_11; }