.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_sklearn/plot_sklearn_kmeans.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_sklearn_plot_sklearn_kmeans.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sklearn_plot_sklearn_kmeans.py:

.. _l-plot-sklearn-kmeans:

Converting a scikit-learn KMeans to ONNX
=========================================

:func:`yobx.sklearn.to_onnx` converts a fitted :class:`sklearn.cluster.KMeans`
into an :class:`onnx.ModelProto` that can be executed with any ONNX-compatible
runtime.

The converted model produces two outputs:

* **label** - cluster index for each sample (equivalent to
  :meth:`~sklearn.cluster.KMeans.predict`).
* **distances** - Euclidean distance from each sample to every centroid
  (equivalent to :meth:`~sklearn.cluster.KMeans.transform`).

The workflow is:

1. **Train** a :class:`~sklearn.cluster.KMeans` as usual.
2. Call :func:`yobx.sklearn.to_onnx` with a representative dummy input.
3. **Run** the ONNX model with any ONNX runtime (this example uses
   :epkg:`onnxruntime`).
4. **Verify** that the ONNX outputs match scikit-learn's predictions.

.. GENERATED FROM PYTHON SOURCE LINES 27-36

.. code-block:: Python


    import numpy as np
    import onnxruntime
    from sklearn.cluster import KMeans
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from yobx.doc import plot_dot
    from yobx.sklearn import to_onnx


.. GENERATED FROM PYTHON SOURCE LINES 37-39

1. Train a KMeans model
-----------------------

.. GENERATED FROM PYTHON SOURCE LINES 39-46

.. code-block:: Python


    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 4)).astype(np.float32)

    km = KMeans(n_clusters=3, random_state=0, n_init=10)
    km.fit(X)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    KMeans(n_clusters=3, n_init=10, random_state=0)

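
The two outputs described in the introduction are closely related: the
predicted label of a sample is the index of its nearest centroid, so it should
coincide with the ``argmin`` of that sample's row of distances. The following
is a minimal NumPy-only sketch of that relationship, not the graph that the
converter emits. The helper ``label_and_distances``, the centroids ``centers``
and the samples ``X_demo`` are hypothetical names and made-up values introduced
only for illustration.

```python
import numpy as np


def label_and_distances(X, centers):
    """Return (label, distances) for samples X and centroids centers.

    Hypothetical helper for illustration only; it is not part of yobx,
    scikit-learn, or onnxruntime.
    """
    # Pairwise differences via broadcasting:
    # shape (n_samples, n_clusters, n_features).
    diff = X[:, None, :] - centers[None, :, :]
    # Euclidean distance from every sample to every centroid.
    distances = np.sqrt((diff ** 2).sum(axis=2))
    # The label is the index of the nearest centroid.
    label = distances.argmin(axis=1).astype(np.int64)
    return label, distances


# Made-up data: three 2-D centroids and four samples.
centers = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 5.0]], dtype=np.float32)
X_demo = np.array(
    [[0.1, -0.2], [4.0, 6.0], [-6.0, 4.5], [3.0, 4.0]], dtype=np.float32
)

label_demo, distances_demo = label_and_distances(X_demo, centers)
print(label_demo)            # one cluster index per sample
print(distances_demo.shape)  # (n_samples, n_clusters)
```

With the fitted model from this section, one would expect
``km.transform(X).argmin(axis=1)`` to agree with ``km.predict(X)`` (barring
exact ties), which is the same relationship the ``label`` and ``distances``
outputs of the converted model should satisfy.
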


.. GENERATED FROM PYTHON SOURCE LINES 47-49

2. Convert to ONNX
------------------

.. GENERATED FROM PYTHON SOURCE LINES 49-54

.. code-block:: Python


    onx = to_onnx(km, (X,))

    print(f"ONNX model inputs : {[i.name for i in onx.graph.input]}")
    print(f"ONNX model outputs: {[o.name for o in onx.graph.output]}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ONNX model inputs : ['X']
    ONNX model outputs: ['label', 'distances']


.. GENERATED FROM PYTHON SOURCE LINES 55-57

3. Run the ONNX model and compare outputs
------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 57-74

.. code-block:: Python


    ref = onnxruntime.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
    label_onnx, distances_onnx = ref.run(None, {"X": X})

    label_sk = km.predict(X).astype(np.int64)
    distances_sk = km.transform(X).astype(np.float32)

    print("\nFirst 5 labels (sklearn):", label_sk[:5])
    print("First 5 labels (ONNX) :", label_onnx[:5])

    print("\nFirst 5 distances (sklearn):", distances_sk[:5].round(4))
    print("First 5 distances (ONNX) :", distances_onnx[:5].round(4))

    assert (label_sk == label_onnx).all(), "Labels differ!"
    assert np.allclose(distances_sk, distances_onnx, atol=1e-4), "Distances differ!"
    print("\nAll labels and distances match ✓")


.. rst-class:: sphx-glr-script-out

.. code-block:: none


    First 5 labels (sklearn): [2 1 0 0 1]
    First 5 labels (ONNX) : [2 1 0 0 1]

    First 5 distances (sklearn): [[1.9018 1.3076 1.0513]
     [2.7406 1.1723 2.3066]
     [0.7925 2.3442 2.1114]
     [1.8823 3.3654 3.1916]
     [1.8663 0.9522 2.1298]]
    First 5 distances (ONNX) : [[1.9018 1.3076 1.0513]
     [2.7406 1.1723 2.3066]
     [0.7925 2.3442 2.1114]
     [1.8823 3.3654 3.1916]
     [1.8663 0.9522 2.1298]]

    All labels and distances match ✓


.. GENERATED FROM PYTHON SOURCE LINES 75-79

4. KMeans inside a Pipeline
----------------------------

KMeans also works as the final step of a :class:`~sklearn.pipeline.Pipeline`.

.. GENERATED FROM PYTHON SOURCE LINES 79-97

.. code-block:: Python


    pipe = Pipeline(
        [("scaler", StandardScaler()), ("km", KMeans(n_clusters=3, random_state=0, n_init=10))]
    )
    pipe.fit(X)

    onx_pipe = to_onnx(pipe, (X,))

    ref_pipe = onnxruntime.InferenceSession(
        onx_pipe.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    label_pipe_onnx, _ = ref_pipe.run(None, {"X": X})
    label_pipe_sk = pipe.predict(X).astype(np.int64)

    assert (label_pipe_sk == label_pipe_onnx).all(), "Pipeline labels differ!"
    print("Pipeline labels match ✓")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Pipeline labels match ✓


.. GENERATED FROM PYTHON SOURCE LINES 98-100

5. Visualize the pipeline
-------------------------

.. GENERATED FROM PYTHON SOURCE LINES 100-102

.. code-block:: Python


    plot_dot(onx_pipe)


.. image-sg:: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_kmeans_001.png
   :alt: plot sklearn kmeans
   :srcset: /auto_examples_sklearn/images/sphx_glr_plot_sklearn_kmeans_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.437 seconds)


.. _sphx_glr_download_auto_examples_sklearn_plot_sklearn_kmeans.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sklearn_kmeans.ipynb <plot_sklearn_kmeans.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sklearn_kmeans.py <plot_sklearn_kmeans.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_sklearn_kmeans.zip <plot_sklearn_kmeans.zip>`


.. include:: plot_sklearn_kmeans.recommendations


.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_