Converting a scikit-learn model to ONNX with spox

spox is a Python library for constructing ONNX graphs. Instead of the raw onnx.helper API, it exposes strongly typed, opset-versioned operator functions. Each operator is a Python function (e.g. op.MatMul(A, B)), so the construction code is type-checked, IDE-friendly, and produces a graph that is valid for the chosen opset.

SpoxGraphBuilder implements the same GraphBuilderExtendedProtocol as the default GraphBuilder, but delegates every operator construction call to the corresponding spox opset module. Existing yobx.sklearn converters work without modification: the only change is passing builder_cls=SpoxGraphBuilder to yobx.sklearn.to_onnx().
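The idea that a converter written against a shared protocol keeps working when the builder class is swapped can be sketched in plain Python. Everything below is hypothetical and only illustrates the pattern: NodeBuilder, ListBuilder, DelegatingBuilder, convert_scaler and export are made-up names, not part of yobx or spox.

```python
from typing import Protocol


class NodeBuilder(Protocol):
    """Hypothetical stand-in for a graph-builder protocol."""

    def make_node(self, op_type: str, *inputs: str) -> str:
        ...


class ListBuilder:
    """Default backend: records every node directly."""

    def __init__(self):
        self.nodes = []

    def make_node(self, op_type, *inputs):
        out = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, inputs, out))
        return out


class DelegatingBuilder:
    """Alternative backend: forwards each call to a per-operator function."""

    def __init__(self):
        self.nodes = []
        # Toy "opset module": one callable per supported operator.
        self._opset = {"Sub": self._binary, "Div": self._binary}

    def _binary(self, op_type, a, b):
        out = f"{op_type.lower()}_{len(self.nodes)}"
        self.nodes.append((op_type, (a, b), out))
        return out

    def make_node(self, op_type, *inputs):
        return self._opset[op_type](op_type, *inputs)


def convert_scaler(g: NodeBuilder, x: str) -> str:
    """Toy converter, written once against the protocol."""
    centered = g.make_node("Sub", x, "mean")
    return g.make_node("Div", centered, "scale")


def export(builder_cls=ListBuilder):
    """Only the builder class changes between backends."""
    g = builder_cls()
    output = convert_scaler(g, "X")
    return g.nodes, output


print(export(ListBuilder))
print(export(DelegatingBuilder))
```

Because convert_scaler only relies on make_node, both backends produce the same node list here; the real builders are of course far richer than this toy.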

Covered in this example:

  1. Binary classification pipeline (StandardScaler + LogisticRegression)
  2. Multiclass classification
  3. DecisionTreeClassifier, which uses the ai.onnx.ml domain and so exercises SpoxGraphBuilder’s secondary-domain dispatch.
  4. Visualising the exported ONNX graph.

import spox  # noqa: F401
import numpy as np
import onnxruntime
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from yobx.doc import plot_dot
from yobx.builder.spox import SpoxGraphBuilder
from yobx.sklearn import to_onnx

1. Train a binary classification pipeline

We train a two-step pipeline: StandardScaler followed by LogisticRegression on a small synthetic dataset (80 samples, 4 features, 2 classes).

rng = np.random.default_rng(0)
X_train = rng.standard_normal((80, 4)).astype(np.float32)
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)

pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
pipe.fit(X_train, y_train)

print("Binary pipeline steps:")
for name, step in pipe.steps:
    print(f"  {name}: {step}")
Output:
Binary pipeline steps:
  scaler: StandardScaler()
  clf: LogisticRegression()

2. Convert to ONNX using SpoxGraphBuilder

The only difference from a plain to_onnx() call is the builder_cls=SpoxGraphBuilder argument. The converter for each pipeline step then uses spox opset modules to emit its ONNX nodes. The tuple (X_train[:1],) supplies a single sample row as the example input.

onx = to_onnx(pipe, (X_train[:1],), builder_cls=SpoxGraphBuilder)

print(f"\nONNX opset  : {onx.opset_import[0].version}")
print("Node types  :", [n.op_type for n in onx.graph.node])
print("Graph input : ", [(inp.name, inp.type.tensor_type.elem_type) for inp in onx.graph.input])
print("Graph outputs:", [out.name for out in onx.graph.output])
Output:
ONNX opset  : 21
Node types  : ['Constant', 'Constant', 'Constant', 'Sub', 'Constant', 'Div', 'Constant', 'Constant', 'Gemm', 'Sigmoid', 'Sub', 'Concat', 'ArgMax', 'Cast', 'Gather', 'Identity', 'Identity']
Graph input :  [('X', 1)]
Graph outputs: ['label', 'probabilities']
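The node list above can be read as a standard scaler followed by a logistic model. The NumPy sketch below mirrors that chain. The mapping of each node to a step is an interpretation of the printed op types, not something the exporter confirms, and the parameter values are invented for illustration rather than taken from the fitted pipeline.

```python
import numpy as np


def binary_pipeline_reference(X, mean, scale, coef, intercept, classes):
    """NumPy mirror of a Sub/Div/Gemm/Sigmoid/Sub/Concat/ArgMax/Gather chain."""
    Xs = (X - mean) / scale                    # Sub, Div: standardisation
    logit = Xs @ coef.T + intercept            # Gemm: linear decision function
    p1 = 1.0 / (1.0 + np.exp(-logit))          # Sigmoid: P(class 1)
    p0 = 1.0 - p1                              # Sub: P(class 0)
    proba = np.concatenate([p0, p1], axis=1)   # Concat: (N, 2) probabilities
    idx = np.argmax(proba, axis=1)             # ArgMax: winning column
    label = classes[idx]                       # Gather: column index to class label
    return label, proba


# Made-up parameters (assumption): 4 features, weights on the first two only.
mean = np.zeros(4, dtype=np.float32)
scale = np.ones(4, dtype=np.float32)
coef = np.array([[1.0, 1.0, 0.0, 0.0]], dtype=np.float32)
intercept = np.array([0.0], dtype=np.float32)
classes = np.array([0, 1])

X = np.array([[2.0, 1.0, 0.0, 0.0], [-1.0, -1.0, 5.0, 5.0]], dtype=np.float32)
label, proba = binary_pipeline_reference(X, mean, scale, coef, intercept, classes)
print(label)
print(proba)
```

With a fitted pipeline, the corresponding arrays would come from the scaler's mean_ and scale_ attributes and the classifier's coef_, intercept_ and classes_ attributes.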

3. Run and verify: binary classification

We run the exported model with onnxruntime and check that class labels and probabilities match scikit-learn’s predictions.

X_test = rng.standard_normal((20, 4)).astype(np.float32)

ref = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
label_onnx, proba_onnx = ref.run(None, {"X": X_test})

label_sk = pipe.predict(X_test)
proba_sk = pipe.predict_proba(X_test).astype(np.float32)

print("\nFirst 5 labels (sklearn):", label_sk[:5])
print("First 5 labels (ONNX)   :", label_onnx[:5])

assert np.array_equal(label_sk, label_onnx), "Label mismatch!"
assert np.allclose(proba_sk, proba_onnx, atol=1e-5), "Probability mismatch!"
print("\nBinary predictions match ✓")
Output:
First 5 labels (sklearn): [0 1 1 1 0]
First 5 labels (ONNX)   : [0 1 1 1 0]

Binary predictions match ✓
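When an assertion like the ones above fails, it helps to know how large the discrepancy is and where it occurs. The helper below is a small sketch for that purpose. compare_outputs is a hypothetical name, and unlike np.allclose (which also applies a relative tolerance) it uses a purely absolute threshold.

```python
import numpy as np


def compare_outputs(expected, actual, atol=1e-5):
    """Report the largest absolute difference between two arrays."""
    expected = np.asarray(expected, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    if expected.shape != actual.shape:
        raise ValueError(f"shape mismatch: {expected.shape} vs {actual.shape}")
    if expected.size == 0:
        # Nothing to compare: treat empty arrays as matching.
        return {"max_abs_diff": 0.0, "worst_index": (), "ok": True}
    diff = np.abs(expected - actual)
    worst = np.unravel_index(int(np.argmax(diff)), diff.shape)
    max_diff = float(diff.max())
    return {
        "max_abs_diff": max_diff,
        "worst_index": tuple(int(i) for i in worst),
        "ok": max_diff <= atol,
    }


# Illustrative data (not produced by the pipeline above).
reference = np.array([[0.1, 0.9], [0.5, 0.5]])
perturbed = reference + np.array([[0.0, 0.0], [0.0, 2e-5]])
print(compare_outputs(reference, reference))
print(compare_outputs(reference, perturbed))
```

In the example above, this could be called as compare_outputs(proba_sk, proba_onnx) before the assertions to get a more informative failure report.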

4. Multiclass classification

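The same workflow applies to multiclass problems. The LogisticRegression converter switches from a sigmoid-based binary graph to a softmax-based multiclass graph automatically. The labels below are drawn at random, independently of the features, so the fitted model has no predictive value; the example only checks that the ONNX outputs agree with scikit-learn.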

X_mc = rng.standard_normal((120, 4)).astype(np.float32)
y_mc = rng.integers(0, 3, size=120)

pipe_mc = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(max_iter=500))])
pipe_mc.fit(X_mc, y_mc)

X_test_mc = rng.standard_normal((30, 4)).astype(np.float32)
onx_mc = to_onnx(pipe_mc, (X_test_mc[:1],), builder_cls=SpoxGraphBuilder)

ref_mc = onnxruntime.InferenceSession(
    onx_mc.SerializeToString(), providers=["CPUExecutionProvider"]
)
label_mc_onnx, proba_mc_onnx = ref_mc.run(None, {"X": X_test_mc})

label_mc_sk = pipe_mc.predict(X_test_mc)
proba_mc_sk = pipe_mc.predict_proba(X_test_mc).astype(np.float32)

assert np.array_equal(label_mc_sk, label_mc_onnx), "Multiclass label mismatch!"
assert np.allclose(proba_mc_sk, proba_mc_onnx, atol=1e-5), "Multiclass proba mismatch!"
print("Multiclass predictions match ✓")
Output:
Multiclass predictions match ✓
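For reference, the softmax-based computation mentioned at the start of this section can be sketched in NumPy. This is the textbook multinomial logistic formula applied to already-scaled inputs, with invented parameters; it is not a dump of the exported graph, and multiclass_reference is a hypothetical helper name.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax with the usual max-subtraction for stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def multiclass_reference(Xs, coef, intercept, classes):
    """Linear scores per class, softmax probabilities, arg-max label."""
    logits = Xs @ coef.T + intercept       # shape (N, n_classes)
    proba = softmax(logits)
    label = classes[np.argmax(proba, axis=1)]
    return label, proba


# Made-up parameters (assumption): 3 features, 3 classes, identity weights.
coef = np.eye(3)
intercept = np.zeros(3)
classes = np.array([10, 20, 30])
Xs = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]])

label, proba = multiclass_reference(Xs, coef, intercept, classes)
print(label)
print(proba.sum(axis=1))
```

In the binary case the same formula reduces to a single sigmoid, which is why the two exported graphs differ.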

5. Visualise the ONNX graph

plot_dot(onx)
[Figure: rendered graph of the exported binary pipeline (plot sklearn with spox)]


Related examples

Converting a scikit-learn Pipeline to ONNX

Using sklearn-onnx to convert any scikit-learn estimator

Exporting sklearn estimators as ONNX local functions