Einsum Decomposition into ONNX Nodes#
The ONNX specification includes an Einsum operator, but several runtimes
either do not support it or only support a limited subset of equations.
decompose_einsum replaces
a single Einsum node with a sequence of simpler, universally supported ONNX
operators: Transpose, Reshape, MatMul, Mul, ReduceSum,
Unsqueeze, Squeeze, and Identity.
The implementation lives in yobx.helpers._einsum, a self-contained
sub-package with no external dependencies beyond NumPy and ONNX.
Overview#
An einsum equation such as "bij,bjk->bik" describes a contraction of one
or more input tensors into an output tensor. The subscripts before ->
label the dimensions of each input operand; the subscript after -> labels
the dimensions of the output.
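As a concrete reading of the equation above, the following NumPy sketch (independent of yobx, with arbitrary shapes chosen for illustration) checks that "bij,bjk->bik" is a batched matrix product: the letter j appears in both inputs but not in the output, so it is summed out.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 3, 4))  # axes labelled b, i, j
b = rng.random((2, 4, 5))  # axes labelled b, j, k

# j is shared by both inputs and absent from the output, so it is contracted.
via_einsum = np.einsum("bij,bjk->bik", a, b)

# The same contraction written as a batched matrix product.
via_matmul = np.matmul(a, b)

print(via_einsum.shape)
print(np.allclose(via_einsum, via_matmul))
```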
The decomposition algorithm proceeds in three stages:
Equation analysis: parse the equation into a compact matrix representation that records, for every letter, its position in each operand and in the output.
Graph construction: traverse the operands left to right, emitting EinsumSubOp nodes that align dimensions, contract pairs of operands, and reduce dimensions that are no longer needed.
ONNX emission: walk the GraphEinsumSubOp and lower each EinsumSubOp to one or more ONNX nodes.
Stage 1 — Equation Analysis#
analyse_einsum_equation
parses the equation and returns four objects:
letters: sorted string of all unique letters that appear in the equation (e.g. "abcdef" for "bac,cd,def->ebc").
mat: a (n_inputs + 1) × n_letters integer matrix where each entry mat[i, j] is the position of letter j in operand i, or -1 if that letter does not appear in operand i. The last row encodes the output.
lengths: list of ranks (one per operand plus one for the output).
duplicates: per-operand dict of letters that appear more than once (used to detect diagonal / trace operations).
Example for "bac,cd,def->ebc":
<<<
from yobx.helpers._einsum import analyse_einsum_equation
letters, mat, lengths, duplicates = analyse_einsum_equation("bac,cd,def->ebc")
print("letters :", letters)
print("lengths :", lengths)
print("mat :")
print(mat)
>>>
letters : abcdef
lengths : [3, 2, 3, 3]
mat :
[[ 1 0 2 -1 -1 -1]
[-1 -1 0 1 -1 -1]
[-1 -1 -1 0 1 2]
[-1 1 2 -1 0 -1]]
The matrix encodes the full algebraic structure of the equation in a form that is easy to manipulate programmatically.
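To make the encoding concrete, here is a minimal pure-Python sketch that rebuilds the same letters and position matrix. The helper name position_matrix is hypothetical and is not part of yobx. It assumes an explicit "->" and no repeated letter within an operand, so it does not reproduce the duplicates handling of analyse_einsum_equation.

```python
def position_matrix(equation):
    """Return (letters, rows) for an einsum equation with an explicit '->'.

    Hypothetical helper for illustration only; repeated letters inside one
    operand are not handled (only the first occurrence is recorded).
    """
    inputs, output = equation.replace(" ", "").split("->")
    operands = inputs.split(",") + [output]
    letters = "".join(sorted(set("".join(operands))))
    # One row per operand (output last). str.find returns -1 when the
    # letter is missing and its first position otherwise.
    rows = [[term.find(letter) for letter in letters] for term in operands]
    return letters, rows


letters, rows = position_matrix("bac,cd,def->ebc")
print(letters)
for row in rows:
    print(row)
```

Each row can be compared directly with the matrix printed above.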
Stage 2 — Graph Construction#
decompose_einsum_equation
builds a directed acyclic graph of
EinsumSubOp
nodes. Each node represents one primitive tensor operation.
The available node types are:
| Node type | Meaning |
|---|---|
| | Identity / input placeholder: references one of the original input operands. |
| | Inserts size-1 axes so that all operands share the same set of dimension positions (analogous to …). |
| | Permutes axes to bring them into the canonical alphabetical order. |
| | Extracts the diagonal when a letter appears twice in the same operand (e.g. …). |
| | Sums out dimensions that are no longer needed after the current contraction step. |
| | Like … |
| | Generic matrix contraction implemented with a single … |
| | Matrix multiplication for the … |
| | Element-wise multiplication (used when there are no contraction axes). |
| | Like … |
| | Removes the size-1 axes that were added by … |
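The primitives described in the table can be illustrated with plain NumPy. The sketch below is not yobx code: it evaluates "ij,jk->ik" by aligning both operands on the shared letter order (i, j, k), multiplying element-wise with broadcasting, summing out j, and squeezing the leftover axis. Shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 4))  # "ij"
b = rng.random((4, 5))  # "jk"

# Align both operands on the letter order (i, j, k) by inserting a
# size-1 axis for the letter each operand lacks.
a_aligned = np.expand_dims(a, axis=2)  # shape (3, 4, 1)
b_aligned = np.expand_dims(b, axis=0)  # shape (1, 4, 5)

# Broadcasted element-wise product, then sum out j (axis 1), which is
# absent from the output. keepdims preserves the aligned layout.
product = a_aligned * b_aligned               # shape (3, 4, 5)
reduced = product.sum(axis=1, keepdims=True)  # shape (3, 1, 5)

# Drop the leftover size-1 axis to obtain the "ik" result.
result = np.squeeze(reduced, axis=1)          # shape (3, 5)

print(np.allclose(result, np.einsum("ij,jk->ik", a, b)))
```

Keeping every intermediate tensor in the same aligned layout is what lets each step be expressed as an independent, simple operator.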
The graph is built by iterating over the input operands in order. For each operand the algorithm:
Emits an id node to reference the raw input.
Handles any diagonal operation if the same letter appears twice.
Calls _apply_transpose_reshape to insert expand_dims and transpose nodes that bring all dimensions into the shared alphabetical order.
Emits an optional reduce_sum to eliminate dimensions that will not appear in any later operand or in the output.
If a previous partial result already exists, calls _apply_einsum_matmul to contract the previous result with the current operand, producing matmul or batch_dot (and auxiliary transpose / reduce_sum) nodes as required.
After all operands have been processed, a final reduce_sum and
squeeze / transpose step brings the accumulated result into the
shape and axis order demanded by the output subscript.
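The early reduce_sum step can be approximated with a short pure-Python sketch. The helper name private_letters is hypothetical, and the rule it applies is a simplified reading of the step above, not the library's actual logic: a letter is summed out of an operand on its own when it occurs in no other operand and not in the output.

```python
def private_letters(equation):
    """For each input operand, list the letters that occur nowhere else.

    Hypothetical helper for illustration; assumes an explicit '->'.
    """
    inputs, output = equation.replace(" ", "").split("->")
    operands = inputs.split(",")
    result = []
    for index, term in enumerate(operands):
        others = "".join(t for i, t in enumerate(operands) if i != index)
        # A letter absent from every other operand and from the output
        # can be summed out of this operand before any contraction.
        result.append(sorted(set(term) - set(others) - set(output)))
    return result


print(private_letters("bac,cd,def->ebc"))  # [['a'], [], ['f']]
```

For "bac,cd,def->ebc" this yields a for the first operand and f for the third. The letter d is shared by the second and third operands, so it has to wait for the contraction between them.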
[Figure: graph produced for "bac,cd,def->ebc".]
Decomposition strategies#
Two strategies are available via the strategy parameter of
decompose_einsum_equation:
"simple": simplified approach. Contractions between two aligned operands are emitted as a single matmul node which is still evaluated with numpy.einsum internally.
"numpy": contractions are fully expanded into Transpose + Reshape + batch_dot nodes so that no numpy.einsum call remains. This is the default used by decompose_einsum when generating ONNX models.
The simplified approach is useful for understanding or debugging the contraction plan:
the intermediate graph is typically shorter and closer to the original equation.
The "numpy" strategy is more explicit and is the one intended for ONNX
export because each contraction is rewritten into standard tensor primitives.
The main differences are summarized below:
| Aspect | "simple" | "numpy" |
|---|---|---|
| Contraction implementation | | contractions are expanded into … |
| Graph structure | shorter and easier to read when inspecting decomposition logic | more explicit, with additional layout and reshape nodes |
| Intended use | introspection, teaching, debugging decomposition behavior | robust ONNX model generation without remaining … |
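The general idea behind rewriting a contraction as Transpose + Reshape + a batched product can be sketched in NumPy. This is an illustration under assumed shapes and an arbitrary equation, not the node sequence yobx actually emits: the contracted axes are moved next to each other, flattened into a single axis, and then consumed by one batched matrix multiplication.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 5))  # "abcd"
y = rng.random((2, 5, 4, 6))  # "adce": batch over a, contract over c and d

# Transpose y so that its contracted axes follow the same order as in x:
# (a, d, c, e) becomes (a, c, d, e).
y_t = y.transpose(0, 2, 1, 3)

# Reshape: flatten the contracted axes (c, d) into one axis of size 4 * 5.
x_flat = x.reshape(2, 3, 4 * 5)    # (a, b, c*d)
y_flat = y_t.reshape(2, 4 * 5, 6)  # (a, c*d, e)

# A single batched matrix product performs the whole contraction.
result = np.matmul(x_flat, y_flat)  # (a, b, e)

expected = np.einsum("abcd,adce->abe", x, y)
print(np.allclose(result, expected))
```

The transpose matters: without it the flattened index of x (c major, d minor) would not line up with the flattened index of y.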
Printing the operation sequence#
The GraphEinsumSubOp
object can be iterated to inspect the full sequence:
<<<
from yobx.helpers._einsum import decompose_einsum_equation
seq = decompose_einsum_equation("bac,cd,def->ebc")
for op in seq:
print(op)
>>>
EinsumSubOp('id', 0, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5)))
EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5)))
EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,))
EinsumSubOp('id', 1, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5)))
EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6)
EinsumSubOp('id', 2, )
EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2)))
EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,))
EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6)
EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5)))
EinsumSubOp('squeeze', EinsumSubOp('transpose', EinsumSubOp('matmul', EinsumSubOp('matmul', EinsumSubOp('reduce_sum', EinsumSubOp('transpose', EinsumSubOp('expand_dims', EinsumSubOp('id', 0, ), axes=((3, 3), (3, 4), (3, 5))), perm=(np.int64(1), np.int64(0), np.int64(2), np.int64(3), np.int64(4), np.int64(5))), axes=(0,)), EinsumSubOp('expand_dims', EinsumSubOp('id', 1, ), axes=((0, 0), (0, 1), (2, 4), (2, 5))), axes=(), left=(1, 2), right=(2, 3), ndim=6), EinsumSubOp('reduce_sum', EinsumSubOp('expand_dims', EinsumSubOp('id', 2, ), axes=((0, 0), (0, 1), (0, 2))), axes=(5,)), axes=(3,), left=(1, 2), right=(4,), ndim=6), perm=(np.int64(0), np.int64(4), np.int64(1), np.int64(3), np.int64(2), np.int64(5))), axes=(0, 3, 5))
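The printed sequence can be replayed by hand in NumPy to check that it reproduces the original equation. The sketch below follows the same order of operations. It assumes that a matmul node over aligned operands is numerically equivalent to a broadcasted product followed by a sum over its contraction axes (an empty axes tuple means no sum). Input shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.random((2, 3, 4))  # "bac": b=2, a=3, c=4
x1 = rng.random((4, 5))     # "cd":  c=4, d=5
x2 = rng.random((5, 6, 7))  # "def": d=5, e=6, f=7

# Operand 0: expand_dims, transpose to alphabetical order, reduce a.
t0 = x0[..., None, None, None]       # (b, a, c, 1, 1, 1)
t0 = t0.transpose(1, 0, 2, 3, 4, 5)  # (a, b, c, 1, 1, 1)
t0 = t0.sum(axis=0, keepdims=True)   # (1, b, c, 1, 1, 1)

# Operand 1: expand_dims only, it is already in alphabetical order.
t1 = x1[None, None, :, :, None, None]  # (1, 1, c, d, 1, 1)

# First matmul has axes=(): nothing is contracted yet (assumed to be
# equivalent to a broadcasted product).
r = t0 * t1                          # (1, b, c, d, 1, 1)

# Operand 2: expand_dims, then reduce f.
t2 = x2[None, None, None]            # (1, 1, 1, d, e, f)
t2 = t2.sum(axis=5, keepdims=True)   # (1, 1, 1, d, e, 1)

# Second matmul has axes=(3,): contract over d.
r = (r * t2).sum(axis=3, keepdims=True)  # (1, b, c, 1, e, 1)

# Final transpose and squeeze give the "ebc" layout.
r = r.transpose(0, 4, 1, 3, 2, 5)    # (1, e, b, 1, c, 1)
r = r.squeeze(axis=(0, 3, 5))        # (e, b, c)

expected = np.einsum("bac,cd,def->ebc", x0, x1, x2)
print(r.shape, np.allclose(r, expected))
```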
Stage 3 — ONNX Emission#
decompose_einsum calls
GraphEinsumSubOp.to_onnx
which walks the graph and converts each
EinsumSubOp into
one or more ONNX nodes according to the following mapping:
| EinsumSubOp | ONNX node(s) |
|---|---|
| | direct wire (no node emitted) |
| | one … |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
The resulting model is a stand-alone onnx.ModelProto that can be
run directly with any compliant ONNX runtime.
End-to-end example#
The following snippet decomposes the batched matrix multiplication
"bij,bjk->bik" and validates the result numerically:
<<<
import numpy as np
import onnxruntime
from yobx.helpers.einsum_helper import decompose_einsum
model = decompose_einsum("bij,bjk->bik", (2, 3, 4), (2, 4, 5))
print("ONNX node types:", [n.op_type for n in model.graph.node])
sess = onnxruntime.InferenceSession(
model.SerializeToString(), providers=["CPUExecutionProvider"]
)
a = np.random.rand(2, 3, 4).astype(np.float32)
b = np.random.rand(2, 4, 5).astype(np.float32)
(result,) = sess.run(None, {"X0": a, "X1": b})
expected = np.einsum("bij,bjk->bik", a, b)
print("max |error|:", np.max(np.abs(result - expected)))
>>>
ONNX node types: ['MatMul']
max |error|: 1.1920929e-07
More examples (including a three-operand contraction and a chart comparing node counts) are shown in the gallery example Decompose Einsum into Regular ONNX Operators.
See also
decompose_einsum: public API function.
decompose_einsum_equation: internal function that builds the operation graph.
EinsumSubOp: a single node in the decomposition graph.
GraphEinsumSubOp: the full decomposition graph.
Decompose Einsum into Regular ONNX Operators: gallery example with numerical validation and node-count comparison.