.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_sklearn/plot_sklearn_with_spox.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_sklearn_plot_sklearn_with_spox.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sklearn_plot_sklearn_with_spox.py:

.. _l-plot-sklearn-with-spox:

Converting a scikit-learn model to ONNX with spox
==================================================

:epkg:`spox` is a Python library for constructing ONNX graphs that exposes
strongly-typed, opset-versioned operator functions rather than the raw
``onnx.helper`` API. Each operator is a Python function
(e.g. ``op.MatMul(A, B)``), so the construction code is type-safe,
IDE-friendly, and always produces a graph that is valid for the chosen opset.

:class:`~yobx.builder.spox.SpoxGraphBuilder` implements the same
:class:`~yobx.typing.GraphBuilderExtendedProtocol` as the default
:class:`~yobx.xbuilder.GraphBuilder`, but delegates every operator
construction call to the corresponding :epkg:`spox` opset module. Existing
:mod:`yobx.sklearn` converters work without modification: the only change is
passing ``builder_cls=SpoxGraphBuilder`` to :func:`yobx.sklearn.to_onnx`.

Covered in this example:

1. Binary classification pipeline (``StandardScaler`` + ``LogisticRegression``)
2. Multiclass classification
3. ``DecisionTreeClassifier``, which uses the ``ai.onnx.ml`` domain and so
   exercises :class:`~yobx.builder.spox.SpoxGraphBuilder`'s secondary-domain
   dispatch.
4. Visualising the exported ONNX graph.

.. GENERATED FROM PYTHON SOURCE LINES 32-44

.. code-block:: Python

    import spox  # noqa: F401
    import numpy as np
    import onnxruntime
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from yobx.doc import plot_dot
    from yobx.builder.spox import SpoxGraphBuilder
    from yobx.sklearn import to_onnx

.. GENERATED FROM PYTHON SOURCE LINES 45-51

1. Train a binary classification pipeline
------------------------------------------

We train a two-step pipeline, ``StandardScaler`` followed by
``LogisticRegression``, on a small synthetic dataset
(80 samples, 4 features, 2 classes).

.. GENERATED FROM PYTHON SOURCE LINES 51-63

.. code-block:: Python

    rng = np.random.default_rng(0)
    X_train = rng.standard_normal((80, 4)).astype(np.float32)
    y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)

    pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
    pipe.fit(X_train, y_train)

    print("Binary pipeline steps:")
    for name, step in pipe.steps:
        print(f" {name}: {step}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Binary pipeline steps:
     scaler: StandardScaler()
     clf: LogisticRegression()

.. GENERATED FROM PYTHON SOURCE LINES 64-71

2. Convert to ONNX using SpoxGraphBuilder
------------------------------------------

The only difference from a plain :func:`~yobx.sklearn.to_onnx` call is
passing ``builder_cls=SpoxGraphBuilder``. The converter for each pipeline
step then uses the :epkg:`spox` opset modules to emit its ONNX nodes.

.. GENERATED FROM PYTHON SOURCE LINES 71-79

.. code-block:: Python

    onx = to_onnx(pipe, (X_train[:1],), builder_cls=SpoxGraphBuilder)

    print(f"\nONNX opset : {onx.opset_import[0].version}")
    print("Node types :", [n.op_type for n in onx.graph.node])
    print("Graph input : ", [(inp.name, inp.type.tensor_type.elem_type) for inp in onx.graph.input])
    print("Graph outputs:", [out.name for out in onx.graph.output])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ONNX opset : 21
    Node types : ['Constant', 'Constant', 'Constant', 'Sub', 'Constant', 'Div', 'Constant', 'Constant', 'Gemm', 'Sigmoid', 'Sub', 'Concat', 'ArgMax', 'Cast', 'Gather', 'Identity', 'Identity']
    Graph input : [('X', 1)]
    Graph outputs: ['label', 'probabilities']

.. GENERATED FROM PYTHON SOURCE LINES 80-85

3. Run and verify: binary classification
------------------------------------------

We run the exported model with :epkg:`onnxruntime` and check that the class
labels and probabilities match scikit-learn's predictions.

.. GENERATED FROM PYTHON SOURCE LINES 85-101

.. code-block:: Python

    X_test = rng.standard_normal((20, 4)).astype(np.float32)

    ref = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
    label_onnx, proba_onnx = ref.run(None, {"X": X_test})

    label_sk = pipe.predict(X_test)
    proba_sk = pipe.predict_proba(X_test).astype(np.float32)

    print("\nFirst 5 labels (sklearn):", label_sk[:5])
    print("First 5 labels (ONNX) :", label_onnx[:5])

    assert np.array_equal(label_sk, label_onnx), "Label mismatch!"
    assert np.allclose(proba_sk, proba_onnx, atol=1e-5), "Probability mismatch!"
    print("\nBinary predictions match ✓")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    First 5 labels (sklearn): [0 1 1 1 0]
    First 5 labels (ONNX) : [0 1 1 1 0]

    Binary predictions match ✓

.. GENERATED FROM PYTHON SOURCE LINES 102-108

4. Multiclass classification
-----------------------------

The same workflow applies to multiclass problems. The ``LogisticRegression``
converter automatically switches from a sigmoid-based binary graph to a
softmax-based multiclass graph.

.. GENERATED FROM PYTHON SOURCE LINES 108-130

.. code-block:: Python

    X_mc = rng.standard_normal((120, 4)).astype(np.float32)
    y_mc = rng.integers(0, 3, size=120)

    pipe_mc = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(max_iter=500))])
    pipe_mc.fit(X_mc, y_mc)

    X_test_mc = rng.standard_normal((30, 4)).astype(np.float32)
    onx_mc = to_onnx(pipe_mc, (X_test_mc[:1],), builder_cls=SpoxGraphBuilder)

    ref_mc = onnxruntime.InferenceSession(
        onx_mc.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    label_mc_onnx, proba_mc_onnx = ref_mc.run(None, {"X": X_test_mc})

    label_mc_sk = pipe_mc.predict(X_test_mc)
    proba_mc_sk = pipe_mc.predict_proba(X_test_mc).astype(np.float32)

    assert np.array_equal(label_mc_sk, label_mc_onnx), "Multiclass label mismatch!"
    assert np.allclose(proba_mc_sk, proba_mc_onnx, atol=1e-5), "Multiclass proba mismatch!"
    print("Multiclass predictions match ✓")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Multiclass predictions match ✓

.. GENERATED FROM PYTHON SOURCE LINES 131-133

5. Visualise the ONNX graph
---------------------------

.. GENERATED FROM PYTHON SOURCE LINES 133-135

.. code-block:: Python

    plot_dot(onx)

.. image-sg:: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_with_spox_001.png
   :alt: plot sklearn with spox
   :srcset: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_with_spox_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.647 seconds)

.. _sphx_glr_download_auto_examples_sklearn_plot_sklearn_with_spox.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sklearn_with_spox.ipynb <plot_sklearn_with_spox.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sklearn_with_spox.py <plot_sklearn_with_spox.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_sklearn_with_spox.zip <plot_sklearn_with_spox.zip>`

.. include:: plot_sklearn_with_spox.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
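
The node types printed for the binary pipeline (``Sub``, ``Div``, ``Gemm``, ``Sigmoid``, ``Sub``, ``Concat``, ``ArgMax``, ``Gather``) can be read as a short chain of arithmetic steps. The sketch below mirrors one plausible reading of that chain in plain Python for a single sample. The role assigned to each node is an assumption inferred from the operator names, not taken from the yobx converter source, and the function name ``binary_logreg_graph`` and the toy parameters are hypothetical.

```python
import math


def binary_logreg_graph(x, mean, scale, coef, intercept, classes):
    """Plain-Python mirror of the assumed node sequence, for one sample."""
    # Sub + Div: StandardScaler, (x - mean) / scale
    z = [(xi - m) / s for xi, m, s in zip(x, mean, scale)]
    # Gemm: linear decision function, coef . z + intercept
    logit = sum(w * zi for w, zi in zip(coef, z)) + intercept
    # Sigmoid: probability of the positive class
    p1 = 1.0 / (1.0 + math.exp(-logit))
    # Sub + Concat: build the two-column row [1 - p1, p1]
    proba = [1.0 - p1, p1]
    # ArgMax + Gather: index of the largest probability, mapped to a class label
    idx = max(range(len(proba)), key=proba.__getitem__)
    return classes[idx], proba


# Toy parameters chosen for illustration only (not the fitted pipeline's values).
label, proba = binary_logreg_graph(
    x=[0.5, 0.25],
    mean=[0.0, 0.0],
    scale=[1.0, 1.0],
    coef=[1.0, 1.0],
    intercept=0.0,
    classes=[0, 1],
)
print(label, proba)  # the logit is 0.75 > 0, so the label is 1
```

The two ``Identity`` nodes at the end of the printed list presumably only rename the results to ``label`` and ``probabilities``, so they have no counterpart in the sketch.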