Einsum Decomposition into ONNX Nodes

The ONNX specification includes an Einsum operator, but several runtimes either do not support it or only support a limited subset of equations. decompose_einsum replaces a single Einsum node with a sequence of simpler, universally supported ONNX operators: Transpose, Reshape, MatMul, Mul, ReduceSum, Unsqueeze, Squeeze, and Identity.

The implementation lives in yobx.helpers._einsum, a self-contained sub-package with no external dependencies beyond NumPy and ONNX.

Overview

An einsum equation such as "bij,bjk->bik" describes a contraction of one or more input tensors into an output tensor. The subscripts before -> label the dimensions of each input operand; the subscript after -> labels the dimensions of the output.
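As a concrete reading of the notation, the NumPy sketch below evaluates "bij,bjk->bik" with explicit loops. The letter j appears in both inputs but not in the output, so it is summed over, while b, i, and k index the result. The sketch only illustrates einsum semantics and uses nothing from yobx.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))  # subscripts "bij"
B = rng.standard_normal((2, 4, 5))  # subscripts "bjk"

# Every output entry out[b, i, k] sums the products over the contracted letter j.
out = np.zeros((2, 3, 5))
for b in range(2):
    for i in range(3):
        for k in range(5):
            out[b, i, k] = sum(A[b, i, j] * B[b, j, k] for j in range(4))

print(np.allclose(out, np.einsum("bij,bjk->bik", A, B)))  # True
print(np.allclose(out, A @ B))                             # True: a batched matmul
```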

The decomposition algorithm proceeds in three stages:

  1. Equation analysis — parse the equation into a compact matrix representation that records, for every letter, its position in each operand and in the output.

  2. Graph construction — traverse the operands left to right, emitting EinsumSubOp nodes that align dimensions, contract pairs of operands, and reduce dimensions that are no longer needed.

  3. ONNX emission — walk the GraphEinsumSubOp and lower each EinsumSubOp to one or more ONNX nodes.

Stage 1 — Equation Analysis

analyse_einsum_equation parses the equation and returns four objects:

  • letters — sorted string of all unique letters that appear in the equation (e.g. "abcdef" for "bac,cd,def->ebc").

  • mat — a (n_inputs + 1) × n_letters integer matrix where each entry mat[i, j] is the position of letter j in operand i, or -1 if that letter does not appear in operand i. The last row encodes the output.

  • lengths — list of ranks (one per operand plus one for the output).

  • duplicates — per-operand dict of letters that appear more than once (used to detect diagonal / trace operations).

Example for "bac,cd,def->ebc":

<<<

from yobx.helpers._einsum import analyse_einsum_equation

letters, mat, lengths, duplicates = analyse_einsum_equation("bac,cd,def->ebc")
print("letters :", letters)
print("lengths :", lengths)
print("mat     :")
print(mat)

>>>

    letters : abcdef
    lengths : [3, 2, 3, 3]
    mat     :
    [[ 1  0  2 -1 -1 -1]
     [-1 -1  0  1 -1 -1]
     [-1 -1 -1  0  1  2]
     [-1  1  2 -1  0 -1]]

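The matrix records, for every letter, which operands contain it and at which axis. Questions such as "does letter j survive in the output?" or "which operands share letter j?" then reduce to checking column j for entries other than -1, which makes the equation easy to manipulate programmatically.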
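The sketch below is a simplified, hypothetical stand-in for analyse_einsum_equation (analyse_sketch is not part of yobx), written only to show how the four objects can be derived from the equation string. It assumes an explicit "->" output. For a repeated letter it keeps the last position in mat and stores letter counts in duplicates; how yobx encodes those cases is not shown here and may differ.

```python
import numpy as np

def analyse_sketch(equation):
    """Hypothetical, simplified stand-in for analyse_einsum_equation.

    Assumes an explicit "->" output.
    """
    inputs, output = equation.split("->")
    operands = inputs.split(",") + [output]      # last entry is the output
    letters = "".join(sorted(set("".join(operands))))
    # mat[i, j] = position of letter j in operand i, -1 when absent.
    mat = np.full((len(operands), len(letters)), -1, dtype=np.int64)
    duplicates = []
    for i, op in enumerate(operands):
        for pos, ch in enumerate(op):
            mat[i, letters.index(ch)] = pos      # a repeated letter keeps its last position
        # Letters occurring more than once in this operand, with their counts.
        duplicates.append({ch: op.count(ch) for ch in set(op) if op.count(ch) > 1})
    lengths = [len(op) for op in operands]
    return letters, mat, lengths, duplicates

letters, mat, lengths, duplicates = analyse_sketch("bac,cd,def->ebc")
print("letters :", letters)   # abcdef
print("lengths :", lengths)   # [3, 2, 3, 3]
print(mat)
```

For "bac,cd,def->ebc" it reproduces the letters, lengths, and mat printed by analyse_einsum_equation in the example above.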

Stage 2 — Graph Construction

decompose_einsum_equation builds a directed acyclic graph of EinsumSubOp nodes. Each node represents one primitive tensor operation.

The available node types are:

  • id: identity / input placeholder that references one of the original input operands.

  • expand_dims: inserts size-1 axes so that all operands share the same set of dimension positions (analogous to numpy.expand_dims).

  • transpose: permutes axes to bring them into the canonical alphabetical order.

  • diagonal: extracts the diagonal when a letter appears twice in the same operand (e.g. "ii->i").

  • reduce_sum: sums out dimensions that are no longer needed after the current contraction step.

  • reduce_sum_mm: like reduce_sum but takes two inputs; only the first is reduced (used internally during matrix-multiplication decomposition).

  • matmul: generic matrix contraction implemented with a single numpy.einsum call on two already-aligned operands.

  • batch_dot: matrix multiplication for the "numpy" strategy; combines Transpose + Reshape + MatMul to avoid any remaining einsum call.

  • mul: element-wise multiplication (used when there are no contraction axes).

  • transpose_mm: like transpose but takes two inputs; only the first is permuted.

  • squeeze: removes the size-1 axes that were added by expand_dims once they are no longer required.
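The diagonal node can be illustrated in NumPy for "bii->bi". The sketch shows two equivalent ways to keep the entries where both i axes agree: advanced indexing, and a Gather-style lookup (np.take, NumPy's counterpart of ONNX Gather) on the flattened n*n block. Whether to_onnx builds its Gather with exactly this index pattern is an assumption.

```python
import numpy as np

x = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # subscripts "bii"
n = x.shape[1]
idx = np.arange(n)

# Advanced indexing keeps the entries where both "i" axes agree.
diag = x[:, idx, idx]  # shape (2, 3), subscripts "bi"

# Gather-style variant: flatten the two "i" axes; x[b, i, i] sits at flat index i * (n + 1).
flat = x.reshape(x.shape[0], n * n)
diag_gather = np.take(flat, idx * (n + 1), axis=1)

print(np.array_equal(diag, np.einsum("bii->bi", x)))  # True
print(np.array_equal(diag, diag_gather))               # True
```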

The graph is built by iterating over the input operands in order. For each operand the algorithm:

  1. Emits an id node to reference the raw input.

  2. Handles any diagonal operation if the same letter appears twice.

  3. Calls _apply_transpose_reshape to insert expand_dims and transpose nodes that bring all dimensions into the shared alphabetical order.

  4. Emits an optional reduce_sum to eliminate dimensions that will not appear in any later operand or in the output.

  5. If a previous partial result already exists, calls _apply_einsum_matmul to contract the previous result with the current operand, producing matmul or batch_dot (and auxiliary transpose / reduce_sum) nodes as required.
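Step 3 can be sketched as follows, assuming no letter repeats inside the operand. align_operand is a hypothetical helper, not _apply_transpose_reshape: it always appends the size-1 axes at the end and then transposes, so for input 0 ("bac") it yields the same perm=(1, 0, 2, 3, 4, 5) as the plan for this equation shown later in this section. For input 1 ("cd") that plan instead inserts the size-1 axes in place (expand_dims axes=((0, 0), (0, 1), (2, 4), (2, 5))) and needs no transpose; both routes end at shape (1, 1, c, d, 1, 1).

```python
import numpy as np

def align_operand(x, row):
    """Hypothetical helper: give operand x one axis per letter, in alphabetical order.

    row is the operand's line of mat (position of each letter, or -1 if absent).
    Assumes no letter is repeated inside the operand.
    """
    row = [int(p) for p in row]
    missing = [j for j, p in enumerate(row) if p < 0]
    # expand_dims: append one size-1 axis per absent letter.
    y = np.expand_dims(x, axis=tuple(range(x.ndim, x.ndim + len(missing))))
    # Where each letter currently lives inside y.
    pos = list(row)
    for k, j in enumerate(missing):
        pos[j] = x.ndim + k
    # transpose: axis j of the result holds letter j.
    return np.transpose(y, pos), tuple(pos)

x0 = np.arange(24).reshape(3, 2, 4)  # "bac" with b=3, a=2, c=4
t0, perm0 = align_operand(x0, [1, 0, 2, -1, -1, -1])
print(t0.shape, perm0)  # (2, 3, 4, 1, 1, 1) (1, 0, 2, 3, 4, 5)
```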

After all operands have been processed, a final reduce_sum and squeeze / transpose step brings the accumulated result into the shape and axis order demanded by the output subscript.
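Steps 4 and 5 depend only on which operands contain each letter, so both reduction decisions can be read off mat. The hypothetical reduction_plan below is one plausible reading, not the library's code: a letter is summed out right after alignment when it belongs to no other operand and not to the output, and it is contracted by matmul when it is shared with the partial result but needed by neither a later operand nor the output. For "bac,cd,def->ebc" it recovers the axes of the plan shown later in this section: reduce_sum axes=(0,), matmul axes=(), reduce_sum axes=(5,), and matmul axes=(3,).

```python
import numpy as np

def reduction_plan(mat):
    """Hypothetical helper: for every input operand i, list the letter indices
    (1) summed out as soon as operand i is aligned, and
    (2) contracted when operand i is combined with the partial result."""
    n_inputs = mat.shape[0] - 1
    present = mat >= 0          # present[i, j]: letter j occurs in operand i
    inputs = present[:n_inputs]
    in_output = present[-1]
    plan = []
    for i in range(n_inputs):
        others = np.delete(inputs, i, axis=0).any(axis=0)
        earlier = inputs[:i].any(axis=0)      # all False for the first operand
        later = inputs[i + 1:].any(axis=0)    # all False for the last operand
        # Private to operand i and dropped by the output: reduce_sum right away.
        private = present[i] & ~others & ~in_output
        # Shared with the partial result, unused afterwards: contracted by matmul.
        contracted = present[i] & earlier & ~later & ~in_output
        plan.append((np.flatnonzero(private).tolist(), np.flatnonzero(contracted).tolist()))
    return plan

mat = np.array([[1, 0, 2, -1, -1, -1],
                [-1, -1, 0, 1, -1, -1],
                [-1, -1, -1, 0, 1, 2],
                [-1, 1, 2, -1, 0, -1]])  # "bac,cd,def->ebc", letters "abcdef"
print(reduction_plan(mat))  # [([0], []), ([], []), ([5], [3])]
```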

The graph produced for "bac,cd,def->ebc" looks like this:

Figure: EinsumSubOp graph for "bac,cd,def->ebc"; red nodes mark the partial result after each input (I0, I1, I2) and the final output (I-1).

    input 0 (bac): id -> expand_dims axes=((3, 3), (3, 4), (3, 5)) -> transpose perm=(1, 0, 2, 3, 4, 5) -> reduce_sum axes=(0,)   [I0]
    input 1 (cd):  id -> expand_dims axes=((0, 0), (0, 1), (2, 4), (2, 5))
                   matmul(I0, input 1) axes=() left=(1, 2) right=(2, 3) ndim=6   [I1]
    input 2 (def): id -> expand_dims axes=((0, 0), (0, 1), (0, 2)) -> reduce_sum axes=(5,)
                   matmul(I1, input 2) axes=(3,) left=(1, 2) right=(4,) ndim=6   [I2]
    output (ebc):  transpose perm=(0, 4, 1, 3, 2, 5) -> squeeze axes=(0, 3, 5)   [I-1]
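The plan in the figure can be replayed with plain NumPy to check that it really computes "bac,cd,def->ebc". This sketch assumes that reduce_sum keeps the reduced axes with size 1 (consistent with the final squeeze removing axes 0, 3, and 5) and that a matmul on aligned operands amounts to a broadcast product followed by a sum over its axes; the letter sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, c, d, e, f = 2, 3, 4, 5, 6, 7   # arbitrary sizes for the letters a..f
X0 = rng.standard_normal((b, a, c))   # "bac"
X1 = rng.standard_normal((c, d))      # "cd"
X2 = rng.standard_normal((d, e, f))   # "def"

# Input 0: expand_dims, transpose to alphabetical order, reduce_sum over "a".
t0 = np.expand_dims(X0, axis=(3, 4, 5))        # (b, a, c, 1, 1, 1)
t0 = np.transpose(t0, (1, 0, 2, 3, 4, 5))      # (a, b, c, 1, 1, 1)
t0 = t0.sum(axis=0, keepdims=True)             # (1, b, c, 1, 1, 1)  [I0]

# Input 1: expand_dims only; matmul with axes=() is a broadcast product.
t1 = np.expand_dims(X1, axis=(0, 1, 4, 5))     # (1, 1, c, d, 1, 1)
r1 = t0 * t1                                   # (1, b, c, d, 1, 1)  [I1]

# Input 2: expand_dims, reduce_sum over "f"; matmul with axes=(3,) contracts "d".
t2 = np.expand_dims(X2, axis=(0, 1, 2))        # (1, 1, 1, d, e, f)
t2 = t2.sum(axis=5, keepdims=True)             # (1, 1, 1, d, e, 1)
r2 = (r1 * t2).sum(axis=3, keepdims=True)      # (1, b, c, 1, e, 1)  [I2]

# Output: transpose, then squeeze the size-1 axes of "a", "d" and "f".
out = np.transpose(r2, (0, 4, 1, 3, 2, 5)).squeeze(axis=(0, 3, 5))  # (e, b, c)

expected = np.einsum("bac,cd,def->ebc", X0, X1, X2)
print(out.shape, np.allclose(out, expected))   # (6, 3, 4) True
```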

Decomposition strategies

Two strategies are available via the strategy parameter of decompose_einsum_equation:

  • "simple" uses a simplified approach: contractions between two aligned operands are emitted as a single matmul node, which is still evaluated with numpy.einsum internally.

  • "numpy" — contractions are fully expanded into Transpose + Reshape + batch_dot nodes so that no numpy.einsum call remains. This is the default used by decompose_einsum when generating ONNX models.

The simplified approach is useful to understand or debug the contraction plan: the intermediate graph is typically shorter and closer to the original equation. The "numpy" strategy is more explicit and is the one intended for ONNX export because each contraction is rewritten into standard tensor primitives.

The main differences are summarized below:

  • Contraction implementation: with "simple", the matmul sub-op may still rely on an internal numpy.einsum call; with "numpy", contractions are expanded into Transpose + Reshape + batch_dot-style steps.

  • Graph structure: "simple" yields a shorter graph that is easier to read when inspecting the decomposition logic; "numpy" is more explicit, with additional layout and reshape nodes.

  • Intended use: "simple" serves introspection, teaching, and debugging of the decomposition behavior; "numpy" targets robust ONNX model generation with no remaining einsum semantics.
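The "numpy" strategy rests on a standard technique: group the axes of each operand into batch, kept, and contracted letters, transpose so that each group is contiguous, reshape to 3-D, call a batched MatMul, then reshape and transpose back. contract_pair below is a hypothetical NumPy illustration of that technique for two operands, not yobx's batch_dot; it assumes no letter repeats inside an operand.

```python
import math
import numpy as np

def contract_pair(sx, sy, so, x, y):
    """Hypothetical helper: np.einsum(f"{sx},{sy}->{so}", x, y) via transpose/reshape/matmul."""
    # Letters private to one operand and absent from the output are summed first.
    for ch in [ch for ch in sx if ch not in sy and ch not in so]:
        x, sx = x.sum(axis=sx.index(ch)), sx.replace(ch, "")
    for ch in [ch for ch in sy if ch not in sx and ch not in so]:
        y, sy = y.sum(axis=sy.index(ch)), sy.replace(ch, "")
    dims = {**dict(zip(sx, x.shape)), **dict(zip(sy, y.shape))}

    batch = [ch for ch in sx if ch in sy and ch in so]       # kept, in both operands
    summed = [ch for ch in sx if ch in sy and ch not in so]  # contracted
    left = [ch for ch in sx if ch not in sy]                 # kept, x only
    right = [ch for ch in sy if ch not in sx]                # kept, y only

    def size(group):
        return math.prod(dims[ch] for ch in group)  # 1 for an empty group

    # Transpose so each group is contiguous, then flatten the groups into 3-D tensors.
    xt = np.transpose(x, [sx.index(ch) for ch in batch + left + summed])
    yt = np.transpose(y, [sy.index(ch) for ch in batch + summed + right])
    B, M, K, N = size(batch), size(left), size(summed), size(right)
    z = np.matmul(xt.reshape(B, M, K), yt.reshape(B, K, N))  # (B, M, N)

    # Restore one axis per kept letter, then reorder to the output subscripts.
    kept = batch + left + right
    z = z.reshape([dims[ch] for ch in kept])
    return np.transpose(z, [kept.index(ch) for ch in so])

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
y = rng.standard_normal((2, 4, 5))
print(np.allclose(contract_pair("bij", "bjk", "bik", x, y),
                  np.einsum("bij,bjk->bik", x, y)))  # True
```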

Printing the operation sequence

The GraphEinsumSubOp object can be iterated to inspect the full sequence:

<<<

from yobx.helpers._einsum import decompose_einsum_equation

seq = decompose_einsum_equation("bac,cd,def->ebc")
for op in seq:
    print(op)

>>>

    EinsumSubOp('id', 0, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5)))
    EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5)))
    EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,))
    EinsumSubOp('id', 1, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5)))
    EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6)
    EinsumSubOp('id', 2, )
    EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2)))
    EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,))
    EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6)
    EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5)))
    EinsumSubOp('squeeze', EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5))), axes=(0, 3, 5))

Stage 3 — ONNX Emission

decompose_einsum calls GraphEinsumSubOp.to_onnx which walks the graph and converts each EinsumSubOp into one or more ONNX nodes according to the following mapping:

  • id: direct wire (no node emitted)

  • expand_dims: one Unsqueeze node per inserted axis

  • transpose: Transpose (omitted when the permutation is the identity)

  • diagonal: Gather along the diagonal axis

  • reduce_sum: ReduceSum

  • matmul: MatMul (after optional Transpose + Reshape)

  • batch_dot: Transpose + Reshape + MatMul + Reshape + Transpose

  • mul: Mul

  • squeeze: Squeeze

The resulting model is a stand-alone onnx.ModelProto that can be run directly with any compliant ONNX runtime.

End-to-end example

The following snippet decomposes the batched matrix multiplication "bij,bjk->bik" and validates the result numerically:

<<<

import numpy as np
import onnxruntime
from yobx.helpers.einsum_helper import decompose_einsum

# Build a stand-alone ONNX model for the equation and the given input shapes.
model = decompose_einsum("bij,bjk->bik", (2, 3, 4), (2, 4, 5))
print("ONNX node types:", [n.op_type for n in model.graph.node])

# Run the model with onnxruntime; the inputs are fed as X0 and X1.
sess = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)
a = np.random.rand(2, 3, 4).astype(np.float32)
b = np.random.rand(2, 4, 5).astype(np.float32)
(result,) = sess.run(None, {"X0": a, "X1": b})
# Compare against NumPy's reference einsum.
expected = np.einsum("bij,bjk->bik", a, b)
print("max |error|:", np.max(np.abs(result - expected)))

>>>

    ONNX node types: ['MatMul']
    max |error|: 1.1920929e-07

More examples (including a three-operand contraction and a chart comparing node counts) are shown in the gallery example Decompose Einsum into Regular ONNX Operators.

See also