.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_core/plot_einsum.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end ` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_core_plot_einsum.py:


.. _l-plot-einsum-decomposition:

Decompose Einsum into Regular ONNX Operators
============================================

The ONNX ``Einsum`` operator is very expressive, but not all runtimes support it.
:func:`decompose_einsum ` converts an einsum equation into a sub-graph built
from simpler operators (``Transpose``, ``Reshape``, ``MatMul``, ``Mul``,
``ReduceSum``, …) that every compliant ONNX runtime understands.

The decomposition is implemented in :mod:`yobx.helpers._einsum`, a
self-contained sub-package; no external dependency is required.

.. GENERATED FROM PYTHON SOURCE LINES 17-23

.. code-block:: Python

    import numpy as np
    import onnxruntime

    from yobx.doc import plot_dot
    from yobx.helpers.einsum_helper import decompose_einsum

.. GENERATED FROM PYTHON SOURCE LINES 24-28

1. Matrix multiplication — ``ij,jk->ik``
-----------------------------------------

The simplest useful einsum: multiply two 2-D matrices.

.. GENERATED FROM PYTHON SOURCE LINES 28-42

.. code-block:: Python

    # Build the decomposed ONNX model; the extra arguments give each input's shape.
    model_mm = decompose_einsum("ij,jk->ik", (3, 4), (4, 5))

    # Validate the result numerically.
    sess = onnxruntime.InferenceSession(
        model_mm.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    a = np.random.rand(3, 4).astype(np.float32)
    b = np.random.rand(4, 5).astype(np.float32)
    (result,) = sess.run(None, {"X0": a, "X1": b})
    expected = np.einsum("ij,jk->ik", a, b)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    max |error|: 1.1920929e-07

.. GENERATED FROM PYTHON SOURCE LINES 43-48

Graph of ``ij,jk->ik``
~~~~~~~~~~~~~~~~~~~~~~~

:func:`~yobx.doc.plot_dot` renders the decomposed ONNX graph so you can see
every node and edge at a glance.

.. GENERATED FROM PYTHON SOURCE LINES 48-51

.. code-block:: Python

    plot_dot(model_mm)

.. image-sg:: /auto_examples_core/images/sphx_glr_plot_einsum_001.png
   :alt: plot einsum
   :srcset: /auto_examples_core/images/sphx_glr_plot_einsum_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 52-56

2. Batched matrix multiplication — ``bij,bjk->bik``
----------------------------------------------------

A 3-D batched version of the matrix product: ``b`` is a batch index shared by
both inputs and the output.

.. GENERATED FROM PYTHON SOURCE LINES 56-69

.. code-block:: Python

    model_bmm = decompose_einsum("bij,bjk->bik", (2, 3, 4), (2, 4, 5))

    # Compare the ONNX Runtime output with NumPy's reference einsum.
    sess2 = onnxruntime.InferenceSession(
        model_bmm.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    a = np.random.rand(2, 3, 4).astype(np.float32)
    b = np.random.rand(2, 4, 5).astype(np.float32)
    (result,) = sess2.run(None, {"X0": a, "X1": b})
    expected = np.einsum("bij,bjk->bik", a, b)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    max |error|: 1.1920929e-07

.. GENERATED FROM PYTHON SOURCE LINES 70-72

Graph of ``bij,bjk->bik``
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 72-75

.. code-block:: Python

    plot_dot(model_bmm)

.. image-sg:: /auto_examples_core/images/sphx_glr_plot_einsum_002.png
   :alt: plot einsum
   :srcset: /auto_examples_core/images/sphx_glr_plot_einsum_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 76-80

3. Three-operand contraction — ``bac,cd,def->ebc``
---------------------------------------------------

A more complex equation involving three input tensors. The indices ``a``,
``d`` and ``f`` do not appear in the output, so they are summed over, while
``b``, ``c`` and ``e`` are kept in the order given after ``->``.

.. GENERATED FROM PYTHON SOURCE LINES 80-94

.. code-block:: Python

    model_3op = decompose_einsum("bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))

    # Compare the ONNX Runtime output with NumPy's reference einsum.
    sess3 = onnxruntime.InferenceSession(
        model_3op.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    x0 = np.random.rand(2, 2, 2).astype(np.float32)
    x1 = np.random.rand(2, 2).astype(np.float32)
    x2 = np.random.rand(2, 2, 2).astype(np.float32)
    (result,) = sess3.run(None, {"X0": x0, "X1": x1, "X2": x2})
    expected = np.einsum("bac,cd,def->ebc", x0, x1, x2)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    max |error|: 2.3841858e-07

.. GENERATED FROM PYTHON SOURCE LINES 95-97

Graph of ``bac,cd,def->ebc``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 97-99

.. code-block:: Python

    plot_dot(model_3op)

.. image-sg:: /auto_examples_core/images/sphx_glr_plot_einsum_003.png
   :alt: plot einsum
   :srcset: /auto_examples_core/images/sphx_glr_plot_einsum_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.672 seconds)


.. _sphx_glr_download_auto_examples_core_plot_einsum.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_einsum.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_einsum.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_einsum.zip `


.. include:: plot_einsum.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
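
Why can an einsum be rewritten with only ``Transpose``, ``Reshape``, ``MatMul``
and ``ReduceSum``? The NumPy code below is a minimal sketch of that idea. It is
an illustration only, *not* the algorithm implemented in
:mod:`yobx.helpers._einsum`, and the helper names ``contract_pair`` and
``einsum_by_pairs`` are hypothetical. Operands are contracted two at a time.
Indices that appear on one side only and are needed nowhere else are summed
away, as a ``ReduceSum`` would do. Each operand is then transposed and reshaped
into a 3-D ``(batch, free, contracted)`` layout, so that a single batched
``MatMul`` performs the contraction. The sketch assumes an explicit ``->``
output, no index repeated inside a single operand, no ``...`` broadcasting, and
a non-scalar result.

.. code-block:: python

    import numpy as np


    def _sum_out(x, letters, keep):
        # Like a ReduceSum node: sum every axis whose index letter is not in
        # ``keep`` and return the remaining letters.
        axes = tuple(i for i, c in enumerate(letters) if c not in keep)
        kept = "".join(c for c in letters if c in keep)
        return (x.sum(axis=axes) if axes else x), kept


    def contract_pair(x, xl, y, yl, keep):
        # Hypothetical helper: contract x (letters xl) with y (letters yl) using
        # only sum, transpose, reshape and matmul. ``keep`` holds the letters
        # still needed by the output or by operands not processed yet.
        x, xl = _sum_out(x, xl, keep | set(yl))
        y, yl = _sum_out(y, yl, keep | set(xl))
        batch = [c for c in xl if c in yl and c in keep]
        contracted = [c for c in xl if c in yl and c not in keep]
        left = [c for c in xl if c not in yl]
        right = [c for c in yl if c not in xl]
        dim = {**dict(zip(xl, x.shape)), **dict(zip(yl, y.shape))}

        def size(letters):
            return int(np.prod([dim[c] for c in letters], dtype=np.int64))

        # Transpose + Reshape each operand into a 3-D tensor.
        x3 = x.transpose([xl.index(c) for c in batch + left + contracted]).reshape(
            size(batch), size(left), size(contracted)
        )
        y3 = y.transpose([yl.index(c) for c in batch + contracted + right]).reshape(
            size(batch), size(contracted), size(right)
        )
        # A single batched MatMul performs the whole contraction.
        z = np.matmul(x3, y3)
        out = batch + left + right
        return z.reshape(tuple(dim[c] for c in out)), "".join(out)


    def einsum_by_pairs(equation, *operands):
        # Hypothetical helper: evaluate an equation such as "bac,cd,def->ebc"
        # by contracting the operands from left to right.
        inputs, output = equation.replace(" ", "").split("->")
        terms = inputs.split(",")
        cur, cur_l = operands[0], terms[0]
        for i in range(1, len(operands)):
            keep = set(output).union(*terms[i + 1:])
            cur, cur_l = contract_pair(cur, cur_l, operands[i], terms[i], keep)
        # Sum any leftover index, then a final Transpose restores output order.
        cur, cur_l = _sum_out(cur, cur_l, set(output))
        return cur.transpose([cur_l.index(c) for c in output])


    rng = np.random.default_rng(0)
    x0 = rng.random((2, 3, 4))
    x1 = rng.random((4, 5))
    x2 = rng.random((5, 6, 7))
    got = einsum_by_pairs("bac,cd,def->ebc", x0, x1, x2)
    print(got.shape, np.allclose(got, np.einsum("bac,cd,def->ebc", x0, x1, x2)))
    # prints: (6, 2, 4) True

Contracting strictly from left to right is the simplest possible order. A real
decomposition may pick a different pairing to keep intermediate tensors small;
this sketch ignores that cost trade-off.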