Converting a scikit-learn Pipeline to ONNX

yobx.sklearn.to_onnx() converts fitted scikit-learn estimators and pipelines into onnx.ModelProto objects that can be executed with any ONNX-compatible runtime.

The converter supports a registry of estimators; see yobx.sklearn for the full list.

The workflow is:

  1. Train a scikit-learn estimator (or pipeline) as usual.

  2. Call yobx.sklearn.to_onnx() with a representative dummy input to convert the fitted model into an ONNX graph.

  3. Run the ONNX model with any ONNX runtime — this example uses onnxruntime.

  4. Verify that the ONNX outputs match scikit-learn’s predictions.

import numpy as np
import onnxruntime
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from yobx.doc import plot_dot
from yobx.sklearn import to_onnx

1. Train a scikit-learn pipeline

We train a simple two-step pipeline: StandardScaler followed by LogisticRegression. The synthetic dataset has 80 samples and four features, and the binary label is 1 whenever the sum of the first two features is positive.

rng = np.random.default_rng(0)
X_train = rng.standard_normal((80, 4)).astype(np.float32)
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)

pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
pipe.fit(X_train, y_train)

print("Pipeline steps:")
for name, step in pipe.steps:
    print(f"  {name}: {step}")
Pipeline steps:
  scaler: StandardScaler()
  clf: LogisticRegression()

2. Convert to ONNX

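yobx.sklearn.to_onnx() takes a tuple of representative dummy inputs (NumPy arrays) and infers the input dtype and shape from them. The first axis is automatically treated as the batch dimension. A single row such as X_train[:1] is therefore enough, and the converted model still accepts batches of any size.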
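The helper below makes the role of the dummy input concrete. It is a hypothetical illustration: it is not part of yobx and does not describe its internals. It extracts the two facts a converter can read from such an array: the element type, and the shape with the first axis left open as the batch dimension. It also suggests why the example casts the data to float32 before converting.

```python
import numpy as np


def describe_input_signature(dummy):
    """Hypothetical helper (not a yobx API): summarize a dummy input.

    The dtype is taken directly from the array, and the first axis is
    replaced by None to mark it as a dynamic batch dimension.
    """
    dummy = np.asarray(dummy)
    return dummy.dtype.name, (None,) + dummy.shape[1:]


rng = np.random.default_rng(0)
X = rng.standard_normal((80, 4)).astype(np.float32)

# One row is enough: only the dtype and the trailing dimensions matter.
print(describe_input_signature(X[:1]))
# Without the float32 cast, the dummy input would advertise float64 instead.
print(describe_input_signature(X[:1].astype(np.float64)))
```

Under this reading, a float64 dummy input would presumably produce a double-precision graph input. The float32 arrays fed to onnxruntime later would then no longer match the declared type.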

onx = to_onnx(pipe, (X_train[:1],))

print(f"\nONNX model opset : {onx.opset_import[0].version}")
print(f"Number of nodes  : {len(onx.graph.node)}")
print("Node op-types    :", [n.op_type for n in onx.graph.node])
print(
    "Graph inputs     :", [(inp.name, inp.type.tensor_type.elem_type) for inp in onx.graph.input]
)
print("Graph outputs    :", [out.name for out in onx.graph.output])
ONNX model opset : 21
Number of nodes  : 8
Node op-types    : ['Sub', 'Div', 'Gemm', 'Sigmoid', 'Sub', 'Concat', 'ArgMax', 'Gather']
Graph inputs     : [('X', 1)]
Graph outputs    : ['label', 'probabilities']
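One way to read the node list above is as the following NumPy computation:

- Sub and Div implement StandardScaler.
- Gemm computes the linear score.
- Sigmoid, Sub and Concat build the two probability columns.
- ArgMax followed by Gather maps the winning column back to a class label.

This is a sketch under that assumption, not the converter's own code. The parameter values are hypothetical placeholders.

```python
import numpy as np


def binary_pipeline_numpy(X, mean, scale, coef, intercept, classes):
    """Mirror the node sequence Sub, Div, Gemm, Sigmoid, Sub, Concat, ArgMax, Gather."""
    z = (X - mean) / scale                      # Sub, Div: standardize each feature
    score = z @ coef.T + intercept              # Gemm: linear decision score, shape (n, 1)
    p1 = 1.0 / (1.0 + np.exp(-score))           # Sigmoid: probability of the positive class
    p0 = 1.0 - p1                               # Sub: probability of the negative class
    proba = np.concatenate([p0, p1], axis=1)    # Concat: probabilities, shape (n, 2)
    label = classes[np.argmax(proba, axis=1)]   # ArgMax + Gather: column index -> class label
    return label, proba


# Hypothetical parameters standing in for the fitted scaler and classifier.
mean = np.array([0.1, -0.2, 0.0, 0.3], dtype=np.float32)
scale = np.array([1.0, 0.9, 1.1, 1.2], dtype=np.float32)
coef = np.array([[2.5, 2.4, 0.1, -0.1]], dtype=np.float32)
intercept = np.array([0.05], dtype=np.float32)
classes = np.array([0, 1])

X = np.random.default_rng(0).standard_normal((5, 4)).astype(np.float32)
label, proba = binary_pipeline_numpy(X, mean, scale, coef, intercept, classes)
print(label)
print(proba)
```

To compare this sketch against the fitted pipeline from step 1, replace the placeholders with the fitted attributes. The scaler values are pipe.named_steps["scaler"].mean_ and .scale_. The classifier values are pipe.named_steps["clf"].coef_, .intercept_ and .classes_.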

3. Run the ONNX model and compare outputs

We run the converted model on a held-out test set and verify that the ONNX predictions match those produced by scikit-learn.

X_test = rng.standard_normal((20, 4)).astype(np.float32)
y_test = (X_test[:, 0] + X_test[:, 1] > 0).astype(int)

ref = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
label_onnx, proba_onnx = ref.run(None, {"X": X_test})

label_sk = pipe.predict(X_test)
proba_sk = pipe.predict_proba(X_test).astype(np.float32)

print("\nFirst 5 labels  (sklearn):", label_sk[:5])
print("First 5 labels  (ONNX)   :", label_onnx[:5])
print("First 5 probas  (sklearn):", proba_sk[:5])
print("First 5 probas  (ONNX)   :", proba_onnx[:5])

assert np.array_equal(label_sk, label_onnx), "Label mismatch!"
assert np.allclose(proba_sk, proba_onnx, atol=1e-5), "Probability mismatch!"
print("\nAll predictions match ✓")
First 5 labels  (sklearn): [0 1 1 1 0]
First 5 labels  (ONNX)   : [0 1 1 1 0]
First 5 probas  (sklearn): [[0.6968902  0.30310982]
 [0.06376062 0.93623936]
 [0.38313922 0.6168608 ]
 [0.16703625 0.83296376]
 [0.63305384 0.36694616]]
First 5 probas  (ONNX)   : [[0.6968902  0.30310982]
 [0.06376064 0.93623936]
 [0.3831392  0.6168608 ]
 [0.16703618 0.8329638 ]
 [0.63305384 0.36694616]]

All predictions match ✓
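The two assertions above deliberately use different checks. Labels are discrete and must agree exactly. Probabilities are floating-point values that may differ in the last digits. The snippet below reuses the five probability rows printed above to show the size of that gap. The values are copied from the output, so only the displayed digits are kept.

```python
import numpy as np

# First five probability rows, copied from the printed output above.
proba_sk = np.array(
    [[0.6968902, 0.30310982],
     [0.06376062, 0.93623936],
     [0.38313922, 0.6168608],
     [0.16703625, 0.83296376],
     [0.63305384, 0.36694616]],
    dtype=np.float32,
)
proba_onnx = np.array(
    [[0.6968902, 0.30310982],
     [0.06376064, 0.93623936],
     [0.3831392, 0.6168608],
     [0.16703618, 0.8329638],
     [0.63305384, 0.36694616]],
    dtype=np.float32,
)

# Exact equality fails, presumably because scikit-learn and onnxruntime
# do not perform the arithmetic in the same precision or order.
print("exactly equal:", np.array_equal(proba_sk, proba_onnx))
# The largest gap is far below the 1e-5 tolerance used in the assertion.
print("max abs diff :", np.abs(proba_sk - proba_onnx).max())
print("allclose     :", np.allclose(proba_sk, proba_onnx, atol=1e-5))
```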

4. Multiclass pipeline

The same API works transparently for multiclass problems. The LogisticRegression converter automatically switches from a sigmoid-based binary graph to a softmax-based multiclass graph.
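The relationship between the two graphs can be checked numerically. The sketch below assumes the multinomial formulation that LogisticRegression uses for multiclass targets with its default solver. It shows that a softmax over the scores [0, s] reproduces the sigmoid probabilities [1 - σ(s), σ(s)], where σ(s) = 1 / (1 + exp(-s)). The binary graph is therefore a special case of the multiclass one. The function names are illustrative and are not part of yobx.

```python
import numpy as np


def sigmoid_proba(score):
    """Binary case: one decision score per row -> columns [1 - p, p]."""
    p1 = 1.0 / (1.0 + np.exp(-score))
    return np.column_stack([1.0 - p1, p1])


def softmax_proba(scores):
    """Multiclass case: one score per class per row, normalized with softmax."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
score = rng.standard_normal(6)

# A two-class softmax over [0, s] equals the sigmoid of s.
two_class = softmax_proba(np.column_stack([np.zeros_like(score), score]))
print(np.allclose(two_class, sigmoid_proba(score)))

# With three classes, as in y_mc, each row gets three probabilities summing to 1.
scores3 = rng.standard_normal((6, 3))
print(softmax_proba(scores3).sum(axis=1))
```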

X_mc = rng.standard_normal((120, 4)).astype(np.float32)
y_mc = rng.integers(0, 3, size=120)

pipe_mc = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(max_iter=500))])
pipe_mc.fit(X_mc, y_mc)

X_test_mc = rng.standard_normal((30, 4)).astype(np.float32)
onx_mc = to_onnx(pipe_mc, (X_test_mc[:1],))

ref_mc = onnxruntime.InferenceSession(
    onx_mc.SerializeToString(), providers=["CPUExecutionProvider"]
)
label_mc_onnx, proba_mc_onnx = ref_mc.run(None, {"X": X_test_mc})

label_mc_sk = pipe_mc.predict(X_test_mc)
proba_mc_sk = pipe_mc.predict_proba(X_test_mc).astype(np.float32)

assert np.array_equal(label_mc_sk, label_mc_onnx), "Multiclass label mismatch!"
assert np.allclose(proba_mc_sk, proba_mc_onnx, atol=1e-5), "Multiclass proba mismatch!"
print("Multiclass predictions match ✓")
Multiclass predictions match ✓

5. Visualize the ONNX graph

plot_dot(onx)
[Figure: ONNX graph of the binary pipeline (onx), rendered with plot_dot]


Related examples

Converting a scikit-learn model to ONNX with spox

Converting a scikit-learn KMeans to ONNX

Converting sksurv IPCRidge to ONNX