.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_core/plot_einsum.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_core_plot_einsum.py:


.. _l-plot-einsum-decomposition:

Decompose Einsum into Regular ONNX Operators
============================================

The ONNX ``Einsum`` operator is very expressive, but not every runtime
supports it. :func:`decompose_einsum ` converts an einsum equation into a
sub-graph built from simpler operators (``Transpose``, ``Reshape``,
``MatMul``, ``Mul``, ``ReduceSum``, …) that every compliant ONNX runtime
understands.

The decomposition is implemented in :mod:`yobx.helpers._einsum`, a
self-contained sub-package that requires no external dependency.

.. GENERATED FROM PYTHON SOURCE LINES 17-22

.. code-block:: Python

    import numpy as np
    import onnxruntime
    from yobx.helpers.einsum_helper import decompose_einsum

.. GENERATED FROM PYTHON SOURCE LINES 23-27

1. Matrix multiplication — ``ij,jk->ik``
-----------------------------------------

The simplest useful einsum: multiply two 2-D matrices.

.. GENERATED FROM PYTHON SOURCE LINES 27-44

.. code-block:: Python

    model_mm = decompose_einsum("ij,jk->ik", (3, 4), (4, 5))

    # Inspect the generated node types.
    print("Node types:", [n.op_type for n in model_mm.graph.node])

    # Validate the result numerically.
    sess = onnxruntime.InferenceSession(
        model_mm.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    a = np.random.rand(3, 4).astype(np.float32)
    b = np.random.rand(4, 5).astype(np.float32)
    (result,) = sess.run(None, {"X0": a, "X1": b})
    expected = np.einsum("ij,jk->ik", a, b)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Node types: ['Identity', 'Unsqueeze', 'Identity', 'Unsqueeze', 'Transpose', 'Transpose', 'Shape', 'Shape', 'Gather', 'Gather', 'Concat', 'Concat', 'Reshape', 'Reshape', 'Gemm', 'Gather', 'Gather', 'Concat', 'Reshape', 'Transpose', 'Squeeze', 'Identity', 'Identity']
    max |error|: 5.9604645e-08


.. GENERATED FROM PYTHON SOURCE LINES 45-49

2. Batched matrix multiplication — ``bij,bjk->bik``
----------------------------------------------------

A 3-D batched version of the matrix product: one independent
``ij,jk->ik`` product for every value of the index ``b``.

.. GENERATED FROM PYTHON SOURCE LINES 49-64

.. code-block:: Python

    model_bmm = decompose_einsum("bij,bjk->bik", (2, 3, 4), (2, 4, 5))
    print("Node types:", [n.op_type for n in model_bmm.graph.node])

    sess2 = onnxruntime.InferenceSession(
        model_bmm.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    a = np.random.rand(2, 3, 4).astype(np.float32)
    b = np.random.rand(2, 4, 5).astype(np.float32)
    (result,) = sess2.run(None, {"X0": a, "X1": b})
    expected = np.einsum("bij,bjk->bik", a, b)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Node types: ['Identity', 'Unsqueeze', 'Identity', 'Unsqueeze', 'Transpose', 'Transpose', 'Shape', 'Shape', 'Gather', 'Gather', 'Gather', 'Gather', 'Concat', 'Concat', 'Reshape', 'Reshape', 'Transpose', 'MatMul', 'Max', 'Gather', 'Gather', 'Concat', 'Reshape', 'Transpose', 'Squeeze', 'Identity', 'Identity']
    max |error|: 1.1920929e-07


.. GENERATED FROM PYTHON SOURCE LINES 65-69

3. Three-operand contraction — ``bac,cd,def->ebc``
---------------------------------------------------

A more complex equation involving three input tensors. The indices ``a``,
``d`` and ``f`` do not appear in the output, so they are summed over;
``c`` is shared by the first two operands but kept in the output.

.. GENERATED FROM PYTHON SOURCE LINES 69-85

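
Because ``a`` and ``f`` each occur in a single operand, and ``d`` is
shared by the second and third operands but dropped from the output, this
equation factors into two reductions, one matrix product and one broadcast
element-wise product. The NumPy sketch below checks that factorization by
hand on arrays with the same shapes as in this section. It is an
illustration only, assuming NumPy is available; the contraction order it
uses is not necessarily the one ``decompose_einsum`` picks.

.. code-block:: Python

    import numpy as np

    # Random inputs with the same shapes as in this section.
    rng = np.random.default_rng(0)
    x0 = rng.random((2, 2, 2), dtype=np.float32)  # indices b, a, c
    x1 = rng.random((2, 2), dtype=np.float32)  # indices c, d
    x2 = rng.random((2, 2, 2), dtype=np.float32)  # indices d, e, f

    # "a" occurs only in x0 and "f" only in x2: sum them out first.
    x0_bc = x0.sum(axis=1)  # bac -> bc
    x2_de = x2.sum(axis=2)  # def -> de

    # "d" is shared by x1 and x2_de and absent from the output: a matrix product.
    m_ce = x1 @ x2_de  # cd,de -> ce

    # "c" is kept in the output, so it becomes a broadcast element-wise
    # product, laid out directly in the (e, b, c) output order.
    result = x0_bc[np.newaxis, :, :] * m_ce.T[:, np.newaxis, :]

    expected = np.einsum("bac,cd,def->ebc", x0, x1, x2)
    print("max |error|:", np.max(np.abs(result - expected)))

Summing out indices that occur in only one operand before any
multiplication keeps the intermediate tensors small.
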
.. code-block:: Python

    model_3op = decompose_einsum("bac,cd,def->ebc", (2, 2, 2), (2, 2), (2, 2, 2))
    print("Node types:", [n.op_type for n in model_3op.graph.node])

    sess3 = onnxruntime.InferenceSession(
        model_3op.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    x0 = np.random.rand(2, 2, 2).astype(np.float32)
    x1 = np.random.rand(2, 2).astype(np.float32)
    x2 = np.random.rand(2, 2, 2).astype(np.float32)
    (result,) = sess3.run(None, {"X0": x0, "X1": x1, "X2": x2})
    expected = np.einsum("bac,cd,def->ebc", x0, x1, x2)
    print("max |error|:", np.max(np.abs(result - expected)))
    assert np.allclose(result, expected, atol=1e-5)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Node types: ['Identity', 'Unsqueeze', 'Transpose', 'ReduceSum', 'Identity', 'Unsqueeze', 'Transpose', 'Transpose', 'Shape', 'Shape', 'Gather', 'Gather', 'ReduceProd', 'ReduceProd', 'Concat', 'Concat', 'Reshape', 'Reshape', 'Transpose', 'MatMul', 'Max', 'Gather', 'Gather', 'Concat', 'Reshape', 'Transpose', 'Identity', 'Unsqueeze', 'ReduceSum', 'Transpose', 'Shape', 'Shape', 'Gather', 'Gather', 'Gather', 'Gather', 'Concat', 'Concat', 'Reshape', 'Reshape', 'Gemm', 'Max', 'Gather', 'Gather', 'Concat', 'Reshape', 'Transpose', 'Squeeze', 'Identity', 'Identity']
    max |error|: 1.1920929e-07


.. GENERATED FROM PYTHON SOURCE LINES 86-91

4. Operator count comparison
------------------------------

The bar chart below shows how many ONNX nodes each decomposed graph
contains, compared with the single ``Einsum`` node it replaces.

.. GENERATED FROM PYTHON SOURCE LINES 91-124

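
The node lists printed above suggest where many of these nodes come from:
each ``MatMul`` or ``Gemm`` is surrounded by ``Transpose``, ``Shape``,
``Gather``, ``Concat`` and ``Reshape`` nodes. A common way to map a pairwise
contraction onto a single matrix product is to transpose each operand so
its axes are grouped as batch, kept and summed, flatten each group into one
axis, multiply, and then restore the original axes. The NumPy sketch below
shows that technique with a hypothetical helper, ``contract_pair``, which is
not part of yobx. It assumes every index that occurs in only one operand
also appears in the output, and it is not necessarily how
``decompose_einsum`` builds its graph.

.. code-block:: Python

    import numpy as np


    def contract_pair(x, y, x_axes, y_axes, out_axes):
        """Hypothetical helper: contract two operands with one batched matmul.

        Axes are strings of single-letter labels, as in einsum. Assumes every
        label found in only one operand also appears in ``out_axes``.
        """
        batch = [c for c in x_axes if c in y_axes and c in out_axes]
        summed = [c for c in x_axes if c in y_axes and c not in out_axes]
        left = [c for c in x_axes if c not in y_axes]
        right = [c for c in y_axes if c not in x_axes]
        size = {**dict(zip(x_axes, x.shape)), **dict(zip(y_axes, y.shape))}

        def prod(labels):
            # Product of the axis sizes; 1 for an empty group.
            return int(np.prod([size[c] for c in labels], dtype=np.int64))

        # Group axes as (batch, left, summed) and (batch, summed, right),
        # then flatten each group into a single axis.
        xt = x.transpose([x_axes.index(c) for c in batch + left + summed])
        yt = y.transpose([y_axes.index(c) for c in batch + summed + right])
        x3 = xt.reshape(prod(batch), prod(left), prod(summed))
        y3 = yt.reshape(prod(batch), prod(summed), prod(right))

        # One batched matrix product does all the multiply-and-sum work.
        z3 = x3 @ y3

        # Restore one axis per label, then permute into the output order.
        z = z3.reshape([size[c] for c in batch + left + right])
        current = "".join(batch + left + right)
        return z.transpose([current.index(c) for c in out_axes])


    rng = np.random.default_rng(0)
    x = rng.random((2, 3, 4))  # labels a, b, c
    y = rng.random((4, 3, 5))  # labels c, b, d
    out = contract_pair(x, y, "abc", "cbd", "dab")
    print(out.shape)  # (5, 2, 3)
    print(np.allclose(out, np.einsum("abc,cbd->dab", x, y)))  # True

Every transpose and reshape in this sketch has a counterpart among the
``Transpose`` and ``Reshape`` nodes, and computing the flattened sizes at
run time is the kind of work ``Shape``, ``Gather`` and ``Concat`` nodes do.
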
.. code-block:: Python

    import matplotlib.pyplot as plt  # noqa: E402

    equations = {
        "ij,jk->ik": [(3, 4), (4, 5)],
        "bij,bjk->bik": [(2, 3, 4), (2, 4, 5)],
        "bac,cd,def->ebc": [(2, 2, 2), (2, 2), (2, 2, 2)],
    }

    node_counts = {}
    for eq, shapes in equations.items():
        model = decompose_einsum(eq, *shapes)
        node_counts[eq] = len(model.graph.node)

    labels = list(node_counts.keys())
    counts = list(node_counts.values())

    fig, ax = plt.subplots(figsize=(8, 4))
    bars = ax.barh(labels, counts, color="#4c72b0")
    ax.axvline(1, color="#dd8452", linestyle="--", label="1 Einsum node")
    ax.set_xlabel("Number of ONNX nodes after decomposition")
    ax.set_title("Einsum decomposition: node count")
    ax.legend()
    for bar, count in zip(bars, counts):
        ax.text(
            bar.get_width() + 0.3,
            bar.get_y() + bar.get_height() / 2,
            str(count),
            va="center",
            fontsize=9,
        )
    plt.tight_layout()
    plt.show()


.. image-sg:: /auto_examples_core/images/sphx_glr_plot_einsum_001.png
   :alt: Einsum decomposition: node count
   :srcset: /auto_examples_core/images/sphx_glr_plot_einsum_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.535 seconds)


.. _sphx_glr_download_auto_examples_core_plot_einsum.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_einsum.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_einsum.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_einsum.zip `


.. include:: plot_einsum.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_