ShapeBuilder#

onnx.shape_inference.infer_shapes() tries to infer shapes and types based on input shapes. It does not support formulas: every dynamic dimension it cannot resolve receives a newly introduced, unrelated symbol.

The yobx.xshape.ShapeBuilder class walks through all nodes and looks up, in a list of functions, the one computing the output shapes for each node type. It tries as much as possible to express the new shape with formulas based on the dimensions used to define the inputs. The list of functions is available in yobx.xshape.shape_type_compute and is called from the class _InferenceRuntime.

While doing this, every function may compute some tiny constants in _BuilderRuntime. These constants are then used by _ShapeRuntime to deduce some shapes.

The whole algorithm relies on four components:

  • An analyser able to parse and simplify numerical expressions built from the names of the dynamic dimensions of the inputs,

  • A list of functions inferring shapes, including the numerical expressions for every ONNX operator,

  • A very simple runtime able to run a short list of kernels usually used to handle shapes (Add, Sub, Mul, Div, Concat, Squeeze, Unsqueeze, Shape, Size, Reshape),

  • An algorithm solving constraints after the inference functions were run. An unknown dimension may become known, or at least constrained to a short set of values, after a binary operator (or any other node) is processed. The constraint mechanism implements a kind of backward pass in which output dimensions restrict the set of possible values for input dimensions.
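The node-by-node walk of the second component can be pictured with a small self-contained sketch; Node, run_nodes, and same_shape_as_input below are illustrative stand-ins, not the library's API:

```python
from dataclasses import dataclass

@dataclass
class Node:
    # minimal stand-in for onnx.NodeProto
    op_type: str
    inputs: list
    outputs: list

def run_nodes(nodes, handlers, shapes):
    # walk the graph node by node, dispatching each node to the
    # per-operator shape function registered for its type
    for node in nodes:
        fn = handlers.get(node.op_type)
        if fn is not None:
            fn(shapes, node)
    return shapes

def same_shape_as_input(shapes, node):
    # handler for element-wise unary operators: output shape == input shape
    if node.inputs[0] in shapes:
        shapes[node.outputs[0]] = shapes[node.inputs[0]]

shapes = run_nodes(
    [Node("Relu", ["X"], ["Y"]), Node("Identity", ["Y"], ["Z"])],
    {"Relu": same_shape_as_input, "Identity": same_shape_as_input},
    {"X": ("batch", 64)},
)
print(shapes["Z"])  # ('batch', 64)
```

Unknown operator types are simply skipped, which leaves their outputs without shapes, exactly the situation the rank fallback described later is meant to mitigate.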

Class Hierarchy#

BasicShapeBuilder is the main concrete implementation and is composed of four cooperative base classes:

  • ShapeBuilder — the public API contract: get_shape, get_type, get_rank, has_shape, has_type, has_rank, set_shape, set_type, set_rank, evaluate_shape, compare_with_true_inputs, update_shapes.

  • _InferenceRuntime — walks the graph node by node, dispatching each node to the matching per-operator handler in yobx.xshape.shape_type_compute.

  • _BuilderRuntime — evaluates small constant sub-expressions (e.g. the [0, 0, -1] passed to a Reshape node) so the builder can resolve -1 to the correct symbolic formula.

  • _ShapeRuntime — handles the special value-as-shape tracking needed by operators such as Shape, Gather, Concat, and Slice when their output feeds directly into a Reshape.

For example, if X has shape ("d1", 2), then Shape(X, start=1) is the constant [2]. This can later be used to infer the shape produced by a Reshape.
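The way such a tracked value can resolve a Reshape target can be sketched as follows; resolve_reshape is a hypothetical helper, not the library's implementation, and it only handles 0 and a trailing -1:

```python
def resolve_reshape(in_shape, target):
    # Sketch of Reshape target resolution: 0 copies the corresponding
    # input dimension, and a trailing -1 absorbs the remaining input
    # dimensions (as an integer product when they are all static,
    # as a symbolic product otherwise).
    out = []
    for i, t in enumerate(target):
        if t == 0:
            out.append(in_shape[i])
        elif t == -1:
            rest = in_shape[i:]
            if all(isinstance(d, int) for d in rest):
                prod = 1
                for d in rest:
                    prod *= d
                out.append(prod)
            else:
                out.append("*".join(str(d) for d in rest))
        else:
            out.append(t)
    return tuple(out)

print(resolve_reshape(("batch", "seq", 8, 8), (0, 0, -1)))  # ('batch', 'seq', 64)
print(resolve_reshape(("batch", "seq", "d"), (0, 0, -1)))   # ('batch', 'seq', 'd')
```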

After an expression is obtained, a few postprocessing steps are applied to reduce its complexity. This relies on the ast module and is done by the function simplify_expression; for example, d + f - f is replaced by d.
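The additive-cancellation part of such a simplification can be sketched with only the standard ast module; simplify_sum and _terms below are illustrative names, not the library's API:

```python
import ast

def _terms(node, sign=1):
    # flatten an additive expression into (sign, term_source) pairs
    if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)):
        right_sign = sign if isinstance(node.op, ast.Add) else -sign
        return _terms(node.left, sign) + _terms(node.right, right_sign)
    return [(sign, ast.unparse(node))]

def simplify_sum(expr: str) -> str:
    # cancel opposite additive terms and sort the survivors
    tree = ast.parse(expr, mode="eval").body
    counts = {}
    for sign, text in _terms(tree):
        counts[text] = counts.get(text, 0) + sign
    parts = []
    for text in sorted(counts):
        c = counts[text]
        if c == 0:
            continue  # e.g. '+f' and '-f' cancelled out
        term = text if abs(c) == 1 else f"{abs(c)}*{text}"
        parts.append(("-" if c < 0 else "+", term))
    if not parts:
        return "0"
    out = ("-" + parts[0][1]) if parts[0][0] == "-" else parts[0][1]
    for s, t in parts[1:]:
        out += s + t
    return out

print(simplify_sum("d + f - f"))  # d
print(simplify_sum("b + a"))      # a+b
```

Multiplicative simplification (2*seq//2 becoming seq) needs more machinery; this sketch only shows the parse-cancel-rebuild pattern.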

Symbolic Expressions#

When input shapes contain unknown (dynamic) dimensions, ShapeBuilder represents each dimension as either:

  • an integer — for statically known sizes, or

  • a string — for symbolic (dynamic) sizes.

Symbolic strings are valid Python arithmetic expressions built from the names of the original dynamic dimensions. For example, if the two inputs of a Concat(axis=1) node have shapes ("batch", "seq1") and ("batch", "seq2"), the output shape is ("batch", "seq1+seq2").
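How a shape function might build such an expression for Concat can be sketched as follows (concat_shape is an illustrative helper, not the library's implementation):

```python
def concat_shape(shapes, axis):
    # Sketch: all dimensions except `axis` are taken from the first input;
    # the `axis` dimension is the sum of all inputs' dimensions, computed
    # numerically when every term is static and symbolically otherwise.
    out = list(shapes[0])
    dims = [s[axis] for s in shapes]
    if all(isinstance(d, int) for d in dims):
        out[axis] = sum(dims)
    else:
        out[axis] = "+".join(str(d) for d in dims)
    return tuple(out)

print(concat_shape([("batch", "seq1"), ("batch", "seq2")], axis=1))  # ('batch', 'seq1+seq2')
print(concat_shape([(2, 3), (2, 4)], axis=1))                        # (2, 7)
```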

Supported operators in symbolic expressions#

  • + addition (e.g. seq1+seq2)

  • - subtraction (e.g. total-seq)

  • * multiplication (e.g. 2*seq)

  • // floor division (e.g. seq//2)

  • % modulo

  • ^ used internally to represent max(a, b) (e.g. a^b evaluates to max(a, b))
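One way to evaluate such strings, mapping ^ to max before ordinary arithmetic evaluation, can be sketched with the standard ast module (the evaluate helper below is illustrative, not the library's evaluate_expression):

```python
import ast

class _XorToMax(ast.NodeTransformer):
    # rewrite a ^ b (parsed as ast.BitXor) into max(a, b)
    def visit_BinOp(self, node):
        self.generic_visit(node)
        if isinstance(node.op, ast.BitXor):
            return ast.Call(
                func=ast.Name(id="max", ctx=ast.Load()),
                args=[node.left, node.right],
                keywords=[],
            )
        return node

def evaluate(expr: str, context: dict) -> int:
    # parse the symbolic dimension, map '^' to max, then evaluate with the
    # concrete dimension values bound as variables
    tree = ast.fix_missing_locations(_XorToMax().visit(ast.parse(expr, mode="eval")))
    return eval(compile(tree, "<expr>", "eval"),
                {"__builtins__": {}, "max": max}, dict(context))

print(evaluate("seq1+seq2", {"seq1": 5, "seq2": 7}))  # 12
print(evaluate("a^b", {"a": 3, "b": 7}))              # 7
print(evaluate("2*seq//2", {"seq": 9}))               # 9
```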

Automatic simplification#

Before storing a symbolic dimension, simplify_expression rewrites the expression to its simplest equivalent form:

<<<

from yobx.xexpressions import simplify_expression

print(simplify_expression("d + f - f"))  # d
print(simplify_expression("2 * seq // 2"))  # seq
print(simplify_expression("1024 * a // 2"))  # 512*a
print(simplify_expression("b + a"))  # a+b  (terms sorted)

>>>

    d
    seq
    512*a
    a+b

Evaluating symbolic expressions at runtime#

Once the concrete integer values of the input dimensions are known, evaluate_expression can resolve any symbolic dimension to its actual integer value. evaluate_shape applies this to a whole shape at once.

<<<

import onnx
import onnx.helper as oh
from yobx.xexpressions import evaluate_expression
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
        "graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

# Symbolic shape of Z
sym_shape = builder.get_shape("Z")
print("symbolic shape :", sym_shape)

# Evaluate each dimension given concrete values
context = dict(batch=3, seq1=5, seq2=7)
concrete = builder.evaluate_shape("Z", context)
print("concrete shape :", concrete)

>>>

    symbolic shape : ('batch', 'seq1+seq2')
    concrete shape : (3, 12)

See also

Expressions in Shape Computation — sphinx-gallery example demonstrating Concat, Reshape, and Split symbolic expressions, automatic simplification, and evaluation with concrete values.

Example#

The following example builds a small ONNX graph, runs BasicShapeBuilder on it, and prints the inferred shapes and types.

<<<

import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

# A small model: reshape X then multiply by a weight matrix W.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
            oh.make_node("MatMul", ["Xr", "W"], ["Z"]),
        ],
        "graph",
        [oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 32])],
        [
            onh.from_array(np.array([0, 0, 64], dtype=np.int64), name="shape"),
            onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

for name in ["X", "Xr", "W", "Z"]:
    print(f"{name:5s}  type={builder.get_type(name)}  shape={builder.get_shape(name)}")

>>>

    X      type=1  shape=('batch', 'seq', 64)
    Xr     type=1  shape=('batch', 'seq', 64)
    W      type=1  shape=(64, 32)
    Z      type=1  shape=('batch', 'seq', 32)

Comparison with ONNX shape inference#

onnx.shape_inference.infer_shapes() is ONNX’s built-in shape propagation pass. It works well for models with fully static dimensions but loses symbolic relationships when dimensions are dynamic: intermediate results receive freshly generated, unrelated symbols (e.g. unk__0, unk__1) instead of expressions derived from the input dimensions.

BasicShapeBuilder does better in this case because it:

  1. Carries symbolic names — every dynamic dimension keeps the name given in the input value_info (e.g. batch, seq, d_model).

  2. Builds arithmetic expressions — when an operator changes a dimension (e.g. Concat along an axis doubles d_model) the result is stored as the string expression "2*d_model" rather than a new opaque symbol.

  3. Folds constants — initializer tensors that appear as shape arguments (e.g. the [0, 0, -1] passed to Reshape) are evaluated at inference-time, which lets the builder resolve the -1 placeholder to the correct symbolic formula.

  4. Simplifies — the resulting expression is reduced to its simplest form by simplify_expression before being stored (2*d_model//2 becomes d_model, etc.).

The table below summarises the difference for a model that applies Add, then Concat(axis=2), then Reshape([0,0,-1]) to inputs of shape (batch, seq, d_model):

    result        infer_shapes             BasicShapeBuilder
    ------------  -----------------------  -----------------------
    added         (batch, seq, d_model)    (batch, seq, d_model)
    concat_out    (batch, seq, unk__0)     (batch, seq, 2*d_model)
    Z             (batch, seq, unk__1)     (batch, seq, 2*d_model)

See Computed Shapes: Add + Concat + Reshape for a runnable example that demonstrates this comparison step by step.

Validating computed shapes#

Once the model has been run with concrete inputs you can verify that the symbolic shapes predicted by BasicShapeBuilder agree with the actual tensor shapes using compare_with_true_inputs. The method accepts the concrete input and output dictionaries (or lists) and returns, for every output result, the list of (symbolic_expr, expected, computed) triples.

<<<

import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("Concat", ["added", "X"], ["Z"], axis=2),
        ],
        "add_concat",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

feeds = {
    "X": np.random.rand(2, 5, 4).astype(np.float32),
    "Y": np.random.rand(2, 5, 4).astype(np.float32),
}
session = ExtendedReferenceEvaluator(model)
outputs = session.run(None, feeds)

result = builder.compare_with_true_inputs(feeds, outputs)
for name, dims in result.items():
    print(f"{name}: {dims}")

>>>

    Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 8, 8))

Each triple (expr, expected, computed) confirms that evaluating the symbolic expression with the concrete dimension values yields the same size as the tensor produced by the runtime.
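The check can be sketched in a few lines, under the assumption that every symbolic dimension is a plain Python arithmetic expression over the context names (check_dims is a hypothetical helper):

```python
def check_dims(symbolic_shape, concrete_shape, context):
    # Sketch: for each dimension, evaluate the symbolic expression with the
    # concrete dimension values and pair it with the runtime size.
    triples = []
    for expr, actual in zip(symbolic_shape, concrete_shape):
        if isinstance(expr, str):
            expected = eval(expr, {"__builtins__": {}}, dict(context))
        else:
            expected = expr  # already a static integer
        triples.append((expr, expected, actual))
    return tuple(triples)

print(check_dims(("batch", "seq", "2*d_model"), (2, 5, 8),
                 {"batch": 2, "seq": 5, "d_model": 4}))
# (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 8, 8))
```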

Writing shapes back to a model#

update_shapes writes the inferred shapes and types back into the value_info section of the onnx.ModelProto. Inputs, outputs, and initializers are left untouched; only intermediate results (node outputs that are neither inputs, outputs, nor initializers) are annotated.

This is useful for visualisation tools and downstream passes that rely on value_info being populated.

<<<

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("MatMul", ["added", "W"], ["Z"]),
        ],
        "add_matmul",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 64]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
        [onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

print("value_info before:", [vi.name for vi in model.graph.value_info])
builder.update_shapes(model)
print("value_info after :", [vi.name for vi in model.graph.value_info])

vi = model.graph.value_info[0]
t = vi.type.tensor_type
shape = tuple(d.dim_param if d.dim_param else d.dim_value for d in t.shape.dim)
print(f"  {vi.name}: dtype={t.elem_type}  shape={shape}")

>>>

    value_info before: []
    value_info after : ['added']
      added: dtype=1  shape=('batch', 'seq', 64)

Debugging Shape Inference with Environment Variables#

BasicShapeBuilder respects several environment variables that help narrow down shape-inference problems:

  • ONNXSTOPSHAPE=<name>: raises an exception the moment result <name> receives a shape. Useful for finding the first place where a wrong shape is assigned.

  • ONNXSTOPTYPE=<name>: raises an exception the moment result <name> receives a type.

  • ONNXDYNDIM=<name>: prints a message every time the dynamic dimension <name> is encountered during shape propagation.

  • ONNXCST=1: prints which constant value is being requested during inference.

  • ONNXSHAPECOMPUTE=1: raises an exception when a shape is missing for a result that should have one.

  • ONNXSTOPVALUESHAPE=<name>: prints extra information inside the function that tracks shapes of results used as shape arguments (e.g. inputs to Reshape).

In addition, get_debug_msg returns a detailed text dump of the builder’s internal state (known shapes, types, constants, ranks, and the sequence of calls) which can be printed or logged whenever an assertion fails.

Constraint Mechanism#

When BasicShapeBuilder processes a broadcasting operation (e.g. Add, Mul, Where) it computes the output shape with broadcast_shape. If one input dimension is a symbolic string (unknown at graph-construction time) and the other is a concrete integer that is not 1, the builder registers a constraint equating the symbolic name to the concrete value.

Why constraints are needed#

Without constraints, the shape of the broadcast result would be left as the symbolic name (e.g. "d_model"), and any operation that follows would inherit this uncertainty. Later, when a downstream node reveals the concrete value, the builder would have to backtrack and update all earlier shapes — an expensive and error-prone operation that is not implemented.

With constraints, the concrete value is used immediately as the output dimension, and the equality symbolic_name = concrete_value is stored. Downstream operations can propagate the concrete shape without revisiting previous nodes.

How constraints are registered#

broadcast_shape applies the following rules for each pair of aligned dimensions (a, b):

    a                             b                             result   constraint registered
    symbolic string               concrete int n ∉ {0, 1}       n        a = n
    concrete int n ∉ {0, 1}       symbolic string               n        b = n
    symbolic string               1                             a        none (1 broadcasts freely)
    two symbolic strings, a == b                                a        none (already equal)
    two symbolic strings, a != b                                a^b      none (max expression)

The concrete integer is always chosen as the output dimension so that subsequent operations see a precise shape immediately.
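The rules above can be sketched as a standalone function; broadcast_with_constraints is an illustrative helper, not the library's broadcast_shape:

```python
def broadcast_with_constraints(shape_a, shape_b):
    # Sketch of the broadcasting rules described above; returns the output
    # shape and the constraints registered along the way.
    constraints = {}
    rank = max(len(shape_a), len(shape_b))
    # right-align both shapes by padding the shorter one with 1s
    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if isinstance(da, str) and isinstance(db, int) and db not in (0, 1):
            constraints.setdefault(da, set()).add(db)  # register da = db
            out.append(db)
        elif isinstance(db, str) and isinstance(da, int) and da not in (0, 1):
            constraints.setdefault(db, set()).add(da)  # register db = da
            out.append(da)
        elif da == 1:
            out.append(db)  # 1 broadcasts freely
        elif db == 1:
            out.append(da)
        elif isinstance(da, str) and isinstance(db, str):
            out.append(da if da == db else f"{da}^{db}")  # max expression
        else:
            out.append(da)  # equal concrete integers
    return tuple(out), constraints

print(broadcast_with_constraints(("batch", "seq", "d_model"), (64,)))
# (('batch', 'seq', 64), {'d_model': {64}})
```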

Example: broadcasting after an unknown dimension#

<<<

import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder
from yobx.xshape.shape_type_compute import broadcast_shape

TFLOAT = onnx.TensorProto.FLOAT

# X has dynamic last dimension "d_model"; bias has static size 64.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "bias"], ["Z"]),
            oh.make_node("MatMul", ["Z", "W"], ["Out"]),
        ],
        "graph",
        [oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"])],
        [oh.make_tensor_value_info("Out", TFLOAT, [None, None, None])],
        [
            onh.from_array(np.zeros((64,), dtype=np.float32), name="bias"),
            onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

for name in ["X", "Z", "Out"]:
    print(f"{name:5s}  shape={builder.get_shape(name)}")

# The constraint records that d_model equals 64
print("constraints:", builder.get_registered_constraints())

>>>

    X      shape=('batch', 'seq', 'd_model')
    Z      shape=('batch', 'seq', 64)
    Out    shape=('batch', 'seq', 32)
    constraints: {'d_model': {64}}

When the Add node is processed, broadcast_shape aligns ("batch", "seq", "d_model") with (64,) (right-padded to (1, 1, 64)). The pair ("d_model", 64) triggers the constraint "d_model" = 64. The output shape Z therefore becomes ("batch", "seq", 64) rather than ("batch", "seq", "d_model"), and the MatMul handler can propagate the shape of Out immediately as ("batch", "seq", 32) without any backtracking.

Constraint API#

ShapeBuilder exposes the constraint registry through dedicated accessor methods; get_registered_constraints, used in the example above, returns the mapping from each symbolic dimension name to its set of possible values.

The registry is also used by _improves_dynamic_dimension_naming to replace internal opaque tokens (e.g. "s0", "DYN0") with user-visible names once the relationships between them are known.

Implementing a new shape function#

Adding support for a new operator (or overriding an existing one) requires writing a small function and registering it in yobx.xshape.shape_type_compute.

Shape functions signature#

Every shape function receives two arguments:

  • g — the ShapeBuilder instance that holds all currently known shapes, types, ranks, and devices.

  • node — the onnx.NodeProto being processed.

A minimal shape function relies on the ShapeBuilder API described above and should:

  1. Propagate device — if g.has_device(input) is true, copy the device to the output with set_device.

  2. Propagate type — guard with g.has_type(input) before calling set_type on every output; return None early if the type is not yet known.

  3. Compute and set the shape — guard with g.has_shape(input) before deriving the output shape and calling set_shape. When the full shape is unavailable, fall back to g.has_rank / set_rank.

  4. Return the shape (or True if only a rank was set, or None if nothing could be done).

Example: a custom element-wise scaling operator#

The following example shows a shape function for a hypothetical Scale operator (domain "my.domain") that multiplies its first input X by a scalar scale and returns a result with the same shape and type as X.

from onnx import NodeProto
from yobx.xshape.shape_builder import ShapeBuilder

def _set_shape_type_scale(g: ShapeBuilder, node: NodeProto):
    "Shape function for the custom Scale operator."
    x = node.input[0]
    out = node.output[0]

    # 1. propagate device
    if g.has_device(x):
        g.set_device(out, g.get_device(x))

    # 2. propagate element type
    if not g.has_type(x):
        return None
    g.set_type(out, g.get_type(x))

    # 3. compute output shape (same shape as input)
    if g.has_shape(x):
        shape = g.get_shape(x)
        g.set_shape(out, shape)
        return shape

    # fallback: propagate rank only
    if g.has_rank(x):
        g.set_rank(out, g.get_rank(x))
        return True

    return None

To register the function so that BasicShapeBuilder calls it automatically, add it to the appropriate registry dictionary in yobx.xshape.shape_type_compute:

  • _set_shape_type_op_any_known — for standard ONNX operators (domain "").

  • _set_shape_type_op_any_custom — for operators in non-standard domains (e.g. "com.microsoft").

# In yobx/xshape/shape_type_compute.py:
_set_shape_type_op_any_custom["Scale"] = _set_shape_type_scale

The function will then be called automatically whenever run_node processes a Scale node.