ShapeBuilder#

onnx.shape_inference.infer_shapes() tries to infer shapes and types based on input shapes. It does not support formulas and introduces new, unrelated symbols instead.

The yobx.xshape.ShapeBuilder class walks through all nodes and looks up, in a list of functions indexed by node type, the function that computes the output shapes. It tries as much as possible to express the new shape with formulas based on the dimensions used to define the inputs. The list of functions lives in yobx.xshape.shape_type_compute and is called from class _InferenceRuntime.

While doing this, every function may ask _BuilderRuntime to compute some tiny constants. _ShapeRuntime then uses these constants to deduce some shapes.

Class Hierarchy#

BasicShapeBuilder is the main concrete implementation and is composed of four cooperative base classes:

  • ShapeBuilder — the public API contract: get_shape, get_type, get_rank, has_shape, has_type, has_rank, set_shape, set_type, set_rank, evaluate_shape, compare_with_true_inputs, update_shapes.

  • _InferenceRuntime — walks the graph node by node, dispatching each node to the matching per-operator handler in yobx.xshape.shape_type_compute.

  • _BuilderRuntime — evaluates small constant sub-expressions (e.g. the [0, 0, -1] passed to a Reshape node) so the builder can resolve -1 to the correct symbolic formula.

  • _ShapeRuntime — handles the special value-as-shape tracking needed by operators such as Shape, Gather, Concat, and Slice when their output feeds directly into a Reshape.

For example, if X has shape ("d1", 2) then Shape(X, start=1) is the constant [2]. This can later be used to infer the shape after a Reshape.
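The slicing behaviour of the Shape operator can be sketched in plain Python. This is a sketch of the ONNX semantics of the start/end attributes, not the library's internal code:

```python
# Sketch of ONNX Shape(start, end) semantics on a possibly symbolic shape:
# the operator returns the dimensions in the half-open range [start, end).
def shape_op(shape, start=0, end=None):
    return list(shape[start:len(shape) if end is None else end])

print(shape_op(("d1", 2), start=1))  # [2]
print(shape_op(("d1", 2)))           # ['d1', 2]
```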

After an expression is obtained, a few postprocessing steps are applied to reduce its complexity. This relies on the ast module and is done by function simplify_expression: d + f - f is replaced by d.

Symbolic Expressions#

When input shapes contain unknown (dynamic) dimensions, ShapeBuilder represents each dimension as either:

  • an integer — for statically known sizes, or

  • a string — for symbolic (dynamic) sizes.

Symbolic strings are valid Python arithmetic expressions built from the names of the original dynamic dimensions. For example, if the two inputs of a Concat(axis=1) node have shapes ("batch", "seq1") and ("batch", "seq2"), the output shape is ("batch", "seq1+seq2").

Supported operators in symbolic expressions#

  • + addition (e.g. seq1+seq2)

  • - subtraction (e.g. total-seq)

  • * multiplication (e.g. 2*seq)

  • // floor division (e.g. seq//2)

  • % modulo

  • ^ used internally to represent max(a, b) (e.g. a^b evaluates to max(a, b))
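To make these operator semantics concrete, here is a minimal evaluator for such expression strings. It is an illustrative sketch, not the library's implementation; note that ast parses ^ as BitXor, which the sketch maps to max:

```python
# Minimal sketch: evaluate a symbolic dimension string given concrete
# values, treating "^" as max(a, b) as described above.
import ast
import operator

def eval_dim(expr, values):
    """Evaluate a symbolic dimension such as "2*seq//2" or "a^b"."""
    ops = {
        ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.BitXor: max,  # "^" stands for max(a, b) in shape expressions
    }

    def rec(node):
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            return values[node.id]
        if isinstance(node, ast.BinOp):
            return ops[type(node.op)](rec(node.left), rec(node.right))
        raise TypeError(f"unsupported node {ast.dump(node)}")

    return rec(ast.parse(expr, mode="eval").body)

print(eval_dim("seq1+seq2", {"seq1": 5, "seq2": 7}))  # 12
print(eval_dim("a^b", {"a": 3, "b": 8}))              # 8
```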

Automatic simplification#

Before storing a symbolic dimension, simplify_expression rewrites the expression to its simplest equivalent form:

<<<

from yobx.xexpressions import simplify_expression

print(simplify_expression("d + f - f"))  # d
print(simplify_expression("2 * seq // 2"))  # seq
print(simplify_expression("1024 * a // 2"))  # 512*a
print(simplify_expression("b + a"))  # a+b  (terms sorted)

>>>

    d
    seq
    512*a
    a+b
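As an illustration of the ast-based approach, the cancellation and sorting above can be sketched for the simplest case. This is a toy sketch handling only flat sums and differences of symbol names, not the real simplify_expression:

```python
# Toy sketch using the ast module (not the real simplify_expression):
# cancel +x/-x pairs and sort the remaining terms of a flat sum.
import ast
from collections import Counter

def simplify_sum(expr):
    counts = Counter()

    def rec(node, sign):
        if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)):
            rec(node.left, sign)
            rec(node.right, sign if isinstance(node.op, ast.Add) else -sign)
        elif isinstance(node, ast.Name):
            counts[node.id] += sign
        else:
            raise TypeError("only flat sums of names are handled")

    rec(ast.parse(expr, mode="eval").body, 1)
    terms = []
    for name in sorted(counts):  # canonical order: terms sorted by name
        c = counts[name]
        if c == 1:
            terms.append(name)
        elif c > 1:
            terms.append(f"{c}*{name}")
        elif c == -1:
            terms.append(f"-{name}")
        elif c < -1:
            terms.append(f"-{-c}*{name}")
    return "+".join(terms).replace("+-", "-") or "0"

print(simplify_sum("d + f - f"))  # d
print(simplify_sum("b + a"))      # a+b
```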

Evaluating symbolic expressions at runtime#

Once the concrete integer values of the input dimensions are known, evaluate_expression can resolve any symbolic dimension to its actual integer value. evaluate_shape applies this to a whole shape at once.

<<<

import onnx
import onnx.helper as oh
from yobx.xexpressions import evaluate_expression
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
        "graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

# Symbolic shape of Z
sym_shape = builder.get_shape("Z")
print("symbolic shape :", sym_shape)

# Evaluate each dimension given concrete values
context = dict(batch=3, seq1=5, seq2=7)
concrete = builder.evaluate_shape("Z", context)
print("concrete shape :", concrete)

>>>

    symbolic shape : ('batch', 'seq1+seq2')
    concrete shape : (3, 12)

See also

Expressions in Shape Computation — sphinx-gallery example demonstrating Concat, Reshape, and Split symbolic expressions, automatic simplification, and evaluation with concrete values.

Example#

The following example builds a small ONNX graph, runs BasicShapeBuilder on it, and prints the inferred shapes and types.

<<<

import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

# A small model: reshape X then multiply by a weight matrix W.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
            oh.make_node("MatMul", ["Xr", "W"], ["Z"]),
        ],
        "graph",
        [oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 32])],
        [
            onh.from_array(np.array([0, 0, 64], dtype=np.int64), name="shape"),
            onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

for name in ["X", "Xr", "W", "Z"]:
    print(
        f"{name:5s}  type={builder.get_type(name)}" f"  shape={builder.get_shape(name)}"
    )

>>>

    X      type=1  shape=('batch', 'seq', 64)
    Xr     type=1  shape=('batch', 'seq', 64)
    W      type=1  shape=(64, 32)
    Z      type=1  shape=('batch', 'seq', 32)

Comparison with ONNX shape inference#

onnx.shape_inference.infer_shapes() is ONNX’s built-in shape propagation pass. It works well for models with fully static dimensions but loses symbolic relationships when dimensions are dynamic: intermediate results receive freshly generated, unrelated symbols (e.g. unk__0, unk__1) instead of expressions derived from the input dimensions.

BasicShapeBuilder does better in this case because it:

  1. Carries symbolic names — every dynamic dimension keeps the name given in the input value_info (e.g. batch, seq, d_model).

  2. Builds arithmetic expressions — when an operator changes a dimension (e.g. Concat along an axis doubles d_model) the result is stored as the string expression "2*d_model" rather than a new opaque symbol.

  3. Folds constants — initializer tensors that appear as shape arguments (e.g. the [0, 0, -1] passed to Reshape) are evaluated at inference-time, which lets the builder resolve the -1 placeholder to the correct symbolic formula.

  4. Simplifies — the resulting expression is reduced to its simplest form by simplify_expression before being stored (2*d_model//2 becomes d_model, etc.).
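The constant folding of point 3 can be illustrated with a small sketch. This is a hypothetical helper, not the library's code, and it covers only the common case of a single trailing -1 in a target such as [0, 0, -1]:

```python
# Hypothetical helper: resolve the 0 and -1 placeholders of a Reshape
# target, keeping symbolic dimensions as strings.
def resolve_reshape(input_shape, target):
    # 0 means "copy the input dimension at the same position"
    out = [input_shape[i] if t == 0 else t for i, t in enumerate(target)]
    if out and out[-1] == -1:
        # a trailing -1 absorbs the product of the remaining input dimensions
        rest = input_shape[len(target) - 1:]
        prod = 1
        syms = []
        for d in rest:
            if isinstance(d, int):
                prod *= d
            else:
                syms.append(d)
        if syms:
            terms = ([str(prod)] if prod != 1 else []) + syms
            out[-1] = "*".join(terms)
        else:
            out[-1] = prod
    return tuple(out)

print(resolve_reshape(("batch", "seq", 8, 8), (0, 0, -1)))
# ('batch', 'seq', 64)
print(resolve_reshape(("batch", "seq", "h", 64), (0, 0, -1)))
# ('batch', 'seq', '64*h')
```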

The table below summarises the difference for a model that applies Add, then Concat(axis=2), then Reshape([0,0,-1]) to inputs of shape (batch, seq, d_model):

    result        infer_shapes              BasicShapeBuilder
    added         (batch, seq, d_model)     (batch, seq, d_model)
    concat_out    (batch, seq, unk__0)      (batch, seq, 2*d_model)
    Z             (batch, seq, unk__1)      (batch, seq, 2*d_model)

See Computed Shapes: Add + Concat + Reshape for a runnable example that demonstrates this comparison step by step.

Validating computed shapes#

Once the model has been run with concrete inputs you can verify that the symbolic shapes predicted by BasicShapeBuilder agree with the actual tensor shapes using compare_with_true_inputs. The method accepts the concrete input and output dictionaries (or lists) and returns, for every output result, the list of (symbolic_expr, expected, computed) triples.

<<<

import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("Concat", ["added", "X"], ["Z"], axis=2),
        ],
        "add_concat",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

feeds = {
    "X": np.random.rand(2, 5, 4).astype(np.float32),
    "Y": np.random.rand(2, 5, 4).astype(np.float32),
}
session = ExtendedReferenceEvaluator(model)
outputs = session.run(None, feeds)

result = builder.compare_with_true_inputs(feeds, outputs)
for name, dims in result.items():
    print(f"{name}: {dims}")

>>>

    Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 8, 8))

Each triple (expr, expected, computed) confirms that evaluating the symbolic expression with the concrete dimension values yields the same size as the tensor produced by the runtime.
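This check can be reproduced by hand with the values from the run above, where batch is 2, seq is 5, and d_model is 4:

```python
# Each (expr, expected, computed) triple asserts that evaluating the
# symbolic expression with the concrete dimensions matches the runtime size.
triples = (("batch", 2, 2), ("seq", 5, 5), ("2*d_model", 8, 8))
values = {"batch": 2, "seq": 5, "d_model": 4}
for expr, expected, computed in triples:
    assert expected == computed == eval(expr, {}, values)
print("all triples verified")
```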

Writing shapes back to a model#

update_shapes writes the inferred shapes and types back into the value_info section of the onnx.ModelProto. Inputs, outputs, and initializers are left untouched; only intermediate results (node outputs that are neither inputs, outputs, nor initializers) are annotated.

This is useful for visualisation tools and downstream passes that rely on value_info being populated.

<<<

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("MatMul", ["added", "W"], ["Z"]),
        ],
        "add_matmul",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 64]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
        [onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

print("value_info before:", [vi.name for vi in model.graph.value_info])
builder.update_shapes(model)
print("value_info after :", [vi.name for vi in model.graph.value_info])

vi = model.graph.value_info[0]
t = vi.type.tensor_type
shape = tuple(d.dim_param if d.dim_param else d.dim_value for d in t.shape.dim)
print(f"  {vi.name}: dtype={t.elem_type}  shape={shape}")

>>>

    value_info before: []
    value_info after : ['added']
      added: dtype=1  shape=('batch', 'seq', 64)

Debugging#

BasicShapeBuilder respects several environment variables that help narrow down shape-inference problems:

ONNXSTOPSHAPE=<name>
    Raises an exception the moment result <name> receives a shape. Useful for finding the first place where a wrong shape is assigned.

ONNXSTOPTYPE=<name>
    Raises an exception the moment result <name> receives a type.

ONNXDYNDIM=<name>
    Prints a message every time the dynamic dimension <name> is encountered during shape propagation.

ONNXCST=1
    Prints which constant value is being requested during inference.

ONNXSHAPECOMPUTE=1
    Raises an exception when a shape is missing for a result that should have one.

ONNXSTOPVALUESHAPE=<name>
    Prints extra information inside the function that tracks shapes of results used as shape arguments (e.g. inputs to Reshape).

In addition, get_debug_msg returns a detailed text dump of the builder’s internal state (known shapes, types, constants, ranks, and the sequence of calls) which can be printed or logged whenever an assertion fails.

Implementing a new shape function#

Adding support for a new operator (or overriding an existing one) requires writing a small function and registering it in yobx.xshape.shape_type_compute.

Shape functions signature#

Every shape function receives two arguments:

  • g — the ShapeBuilder instance that holds all currently known shapes, types, ranks, and devices.

  • node — the onnx.NodeProto being processed.

A minimal shape function relies on the ShapeBuilder API described above and should do the following:

  1. Propagate device — if g.has_device(input) is true, copy the device to the output with set_device.

  2. Propagate type — guard with g.has_type(input) before calling set_type on every output; return None early if the type is not yet known.

  3. Compute and set the shape — guard with g.has_shape(input) before deriving the output shape and calling set_shape. When the full shape is unavailable, fall back to g.has_rank / set_rank.

  4. Return the shape (or True if only a rank was set, or None if nothing could be done).

Example: a custom element-wise scaling operator#

The following example shows a shape function for a hypothetical Scale operator (domain "my.domain") that multiplies its first input X by a scalar scale and returns a result with the same shape and type as X.

from onnx import NodeProto
from yobx.xshape.shape_builder import ShapeBuilder

def _set_shape_type_scale(g: ShapeBuilder, node: NodeProto):
    "Shape function for the custom Scale operator."
    x = node.input[0]
    out = node.output[0]

    # 1. propagate device
    if g.has_device(x):
        g.set_device(out, g.get_device(x))

    # 2. propagate element type
    if not g.has_type(x):
        return None
    g.set_type(out, g.get_type(x))

    # 3. compute output shape (same shape as input)
    if g.has_shape(x):
        shape = g.get_shape(x)
        g.set_shape(out, shape)
        return shape

    # fallback: propagate rank only
    if g.has_rank(x):
        g.set_rank(out, g.get_rank(x))
        return True

    return None

To register the function so that BasicShapeBuilder calls it automatically, add it to the appropriate registry dictionary in yobx.xshape.shape_type_compute:

  • _set_shape_type_op_any_known — for standard ONNX operators (domain "").

  • _set_shape_type_op_any_custom — for operators in non-standard domains (e.g. "com.microsoft").

# In yobx/xshape/shape_type_compute.py:
_set_shape_type_op_any_custom["Scale"] = _set_shape_type_scale

The function will then be called automatically whenever run_node processes a Scale node.