.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_sklearn/plot_sklearn_pipeline.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_sklearn_plot_sklearn_pipeline.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sklearn_plot_sklearn_pipeline.py:

.. _l-plot-sklearn-pipeline:

Converting a scikit-learn Pipeline to ONNX
==========================================

:func:`yobx.sklearn.to_onnx` converts fitted :epkg:`scikit-learn`
estimators and pipelines into :class:`onnx.ModelProto` objects that can be
executed with any ONNX-compatible runtime.

The converter covers the following estimators (see :mod:`yobx.sklearn` for
the full registry):

* :class:`sklearn.preprocessing.StandardScaler`
* :class:`sklearn.linear_model.LogisticRegression` /
  :class:`~sklearn.linear_model.LogisticRegressionCV`
* :class:`sklearn.tree.DecisionTreeClassifier` /
  :class:`~sklearn.tree.DecisionTreeRegressor`
* :class:`sklearn.pipeline.Pipeline`, which chains the above step by step

The workflow is:

1. **Train** a scikit-learn estimator (or pipeline) as usual.
2. **Convert** the fitted model into an ONNX graph by calling
   :func:`yobx.sklearn.to_onnx` with a representative dummy input.
3. **Run** the ONNX model with any ONNX runtime; this example uses
   :epkg:`onnxruntime`.
4. **Verify** that the ONNX outputs match scikit-learn's predictions.

.. GENERATED FROM PYTHON SOURCE LINES 32-41

.. code-block:: Python

    import numpy as np
    import onnxruntime
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from yobx.doc import plot_dot
    from yobx.sklearn import to_onnx

.. GENERATED FROM PYTHON SOURCE LINES 42-48

1. Train a scikit-learn pipeline
---------------------------------

We train a simple two-step pipeline: ``StandardScaler`` followed by
``LogisticRegression``. The dataset has eighty samples, four features, and
two classes.

.. GENERATED FROM PYTHON SOURCE LINES 48-60

.. code-block:: Python

    rng = np.random.default_rng(0)
    X_train = rng.standard_normal((80, 4)).astype(np.float32)
    y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)

    pipe = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())])
    pipe.fit(X_train, y_train)

    print("Pipeline steps:")
    for name, step in pipe.steps:
        print(f" {name}: {step}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Pipeline steps:
     scaler: StandardScaler()
     clf: LogisticRegression()

.. GENERATED FROM PYTHON SOURCE LINES 61-67

2. Convert to ONNX
-------------------

:func:`yobx.sklearn.to_onnx` requires a representative *dummy input*
(here a one-row NumPy array wrapped in a tuple) so that it can infer the
input dtype and shape. The first axis is automatically treated as the batch
dimension.

.. GENERATED FROM PYTHON SOURCE LINES 67-78

.. code-block:: Python

    onx = to_onnx(pipe, (X_train[:1],))

    print(f"\nONNX model opset : {onx.opset_import[0].version}")
    print(f"Number of nodes : {len(onx.graph.node)}")
    print("Node op-types :", [n.op_type for n in onx.graph.node])
    print(
        "Graph inputs :", [(inp.name, inp.type.tensor_type.elem_type) for inp in onx.graph.input]
    )
    print("Graph outputs :", [out.name for out in onx.graph.output])

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    ONNX model opset : 21
    Number of nodes : 8
    Node op-types : ['Sub', 'Div', 'Gemm', 'Sigmoid', 'Sub', 'Concat', 'ArgMax', 'Gather']
    Graph inputs : [('X', 1)]
    Graph outputs : ['label', 'probabilities']

.. GENERATED FROM PYTHON SOURCE LINES 79-84

3. Run the ONNX model and compare outputs
------------------------------------------

We run the converted model on a held-out test set and verify that the ONNX
predictions match those produced by scikit-learn.

.. GENERATED FROM PYTHON SOURCE LINES 84-103

.. code-block:: Python

    X_test = rng.standard_normal((20, 4)).astype(np.float32)
    y_test = (X_test[:, 0] + X_test[:, 1] > 0).astype(int)

    ref = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
    label_onnx, proba_onnx = ref.run(None, {"X": X_test})

    label_sk = pipe.predict(X_test)
    proba_sk = pipe.predict_proba(X_test).astype(np.float32)

    print("\nFirst 5 labels (sklearn):", label_sk[:5])
    print("First 5 labels (ONNX) :", label_onnx[:5])
    print("First 5 probas (sklearn):", proba_sk[:5])
    print("First 5 probas (ONNX) :", proba_onnx[:5])

    assert np.array_equal(label_sk, label_onnx), "Label mismatch!"
    assert np.allclose(proba_sk, proba_onnx, atol=1e-5), "Probability mismatch!"
    print("\nAll predictions match ✓")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    First 5 labels (sklearn): [0 1 1 1 0]
    First 5 labels (ONNX) : [0 1 1 1 0]
    First 5 probas (sklearn): [[0.6968902  0.30310982]
     [0.06376062 0.93623936]
     [0.38313922 0.6168608 ]
     [0.16703625 0.83296376]
     [0.63305384 0.36694616]]
    First 5 probas (ONNX) : [[0.6968902  0.30310982]
     [0.06376064 0.93623936]
     [0.3831392  0.6168608 ]
     [0.16703618 0.8329638 ]
     [0.63305384 0.36694616]]

    All predictions match ✓

.. GENERATED FROM PYTHON SOURCE LINES 104-110

4. Multiclass pipeline
-----------------------

The same API works for multiclass problems without any change to the call.
The ``LogisticRegression`` converter automatically switches from a
sigmoid-based binary graph to a softmax-based multiclass graph.

.. GENERATED FROM PYTHON SOURCE LINES 110-132

.. code-block:: Python

    X_mc = rng.standard_normal((120, 4)).astype(np.float32)
    y_mc = rng.integers(0, 3, size=120)

    pipe_mc = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(max_iter=500))])
    pipe_mc.fit(X_mc, y_mc)

    X_test_mc = rng.standard_normal((30, 4)).astype(np.float32)
    onx_mc = to_onnx(pipe_mc, (X_test_mc[:1],))

    ref_mc = onnxruntime.InferenceSession(
        onx_mc.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    label_mc_onnx, proba_mc_onnx = ref_mc.run(None, {"X": X_test_mc})

    label_mc_sk = pipe_mc.predict(X_test_mc)
    proba_mc_sk = pipe_mc.predict_proba(X_test_mc).astype(np.float32)

    assert np.array_equal(label_mc_sk, label_mc_onnx), "Multiclass label mismatch!"
    assert np.allclose(proba_mc_sk, proba_mc_onnx, atol=1e-5), "Multiclass proba mismatch!"
    print("Multiclass predictions match ✓")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Multiclass predictions match ✓

.. GENERATED FROM PYTHON SOURCE LINES 133-136

5. Visualize the ONNX graph
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 136-137

.. code-block:: Python

    plot_dot(onx)

.. image-sg:: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_pipeline_001.png
   :alt: plot sklearn pipeline
   :srcset: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_pipeline_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.292 seconds)

.. _sphx_glr_download_auto_examples_sklearn_plot_sklearn_pipeline.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sklearn_pipeline.ipynb <plot_sklearn_pipeline.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sklearn_pipeline.py <plot_sklearn_pipeline.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_sklearn_pipeline.zip <plot_sklearn_pipeline.zip>`

.. include:: plot_sklearn_pipeline.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
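
As a supplementary note on the binary graph from section 2, the eight node
op-types printed there (``Sub``, ``Div``, ``Gemm``, ``Sigmoid``, ``Sub``,
``Concat``, ``ArgMax``, ``Gather``) can be read as a scaler followed by a
logistic regression. The NumPy sketch below emulates one plausible wiring of
those nodes. The wiring is an assumption inferred from the op-type list, not
taken from the converter's source, and the function name
``binary_pipeline_reference`` and all parameter values are hypothetical.

```python
import numpy as np


def binary_pipeline_reference(X, mean, scale, coef, intercept, classes):
    """Emulate the eight-node binary graph in NumPy (assumed wiring)."""
    centered = X - mean                        # Sub: remove the per-feature mean
    scaled = centered / scale                  # Div: divide by the per-feature scale
    logits = scaled @ coef.T + intercept       # Gemm: linear decision function
    p1 = 1.0 / (1.0 + np.exp(-logits))         # Sigmoid: P(class 1)
    p0 = 1.0 - p1                              # Sub: P(class 0)
    proba = np.concatenate([p0, p1], axis=1)   # Concat: (n_samples, 2)
    idx = np.argmax(proba, axis=1)             # ArgMax: most probable column
    label = classes[idx]                       # Gather: map column index to class label
    return label, proba


# Hypothetical parameters for illustration only; they are NOT the values
# learned by the fitted pipeline in this example.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4)).astype(np.float32)
mean = X.mean(axis=0)
scale = X.std(axis=0)
coef = rng.standard_normal((1, 4)).astype(np.float32)
intercept = np.zeros(1, dtype=np.float32)
classes = np.array([0, 1], dtype=np.int64)

label, proba = binary_pipeline_reference(X, mean, scale, coef, intercept, classes)
print("labels:", label)
print("probabilities:", proba)
```

With a fitted pipeline, the corresponding parameters would presumably come
from ``scaler.mean_``, ``scaler.scale_``, ``clf.coef_``, ``clf.intercept_``
and ``clf.classes_``. In the multiclass case the ``Sigmoid``, ``Sub`` and
``Concat`` steps would be replaced by a softmax over one logit per class.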