Expressions in Shape Computation

When an ONNX model contains dynamic (unknown) input dimensions, BasicShapeBuilder represents every output dimension as either a plain integer or a symbolic string expression built from the names of the input dimensions.

This example walks through several common patterns:

  • Concat: sums two dynamic dimensions to produce "seq1+seq2"

  • Reshape with -1: uses floor-division to produce "c//2"

  • Split: introduces ceiling-division via CeilToInt(…)

  • Automatic simplification: d + f - f → d, 2*x//2 → x

  • Evaluation: resolving symbolic shapes to concrete integers once the actual dimension values are known

How it works

BasicShapeBuilder walks every node of the ONNX graph in order. For each node it calls an op-specific handler (defined in yobx.xshape.shape_type_compute) that derives the output shape from the input shapes. When a dimension cannot be expressed as a plain integer, it is stored as a Python-arithmetic string (e.g. "seq1+seq2", "c//2"). Before storing, the string is normalised by simplify_expression, which uses Python’s ast module to cancel identical sub-expressions (d + f - f → d) and fold constants (2 * seq // 2 → seq). Once the actual input sizes are available at runtime, every symbolic dimension can be resolved to an integer with evaluate_expression or the higher-level evaluate_shape.
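The node-by-node walk described above can be sketched in plain Python. This is a minimal illustration, not the yobx implementation: the handler table HANDLERS, the propagate function and the tuple-based node format are hypothetical, and only two kinds of ops (element-wise unary ops and Transpose) are covered. Each handler receives the input shapes, whose dimensions are ints or strings, and returns the output shapes.

```python
from typing import Callable, Dict, List, Sequence, Tuple, Union

Dim = Union[int, str]          # a dimension is a concrete int or a symbolic string
Shape = Tuple[Dim, ...]
Handler = Callable[[List[Shape], dict], List[Shape]]


def unary_handler(shapes: List[Shape], attrs: dict) -> List[Shape]:
    # Element-wise unary ops (Relu, Neg, ...) keep the input shape unchanged.
    return [shapes[0]]


def transpose_handler(shapes: List[Shape], attrs: dict) -> List[Shape]:
    shape = shapes[0]
    # ONNX Transpose reverses the dimensions when no perm attribute is given.
    perm = attrs.get("perm", list(reversed(range(len(shape)))))
    return [tuple(shape[p] for p in perm)]


# Hypothetical dispatch table: op_type -> shape handler.
HANDLERS: Dict[str, Handler] = {
    "Relu": unary_handler,
    "Neg": unary_handler,
    "Transpose": transpose_handler,
}


def propagate(
    nodes: Sequence[Tuple[str, List[str], List[str], dict]],
    input_shapes: Dict[str, Shape],
) -> Dict[str, Shape]:
    """Walk the nodes in order and record the shape of every result."""
    shapes = dict(input_shapes)
    for op_type, inputs, outputs, attrs in nodes:
        handler = HANDLERS.get(op_type)
        if handler is None:
            raise NotImplementedError(f"no shape handler for {op_type!r}")
        results = handler([shapes[name] for name in inputs], attrs)
        shapes.update(zip(outputs, results))
    return shapes
```

For example, transposing a ("batch", "seq", 64) input with perm=[1, 0, 2] and applying Relu gives ("seq", "batch", 64) for the final result; symbolic names simply travel through handlers that only move dimensions around.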

For a deeper description of the design, see the ShapeBuilder design page.


Concat: summing two dynamic dimensions

When two tensors are concatenated along a dynamic axis, the output dimension is the sum of the two input dimensions. Because both seq1 and seq2 are unknown at graph-construction time, the result is the symbolic expression "seq1+seq2".

import onnx
import numpy as np
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.xshape import BasicShapeBuilder
from yobx.xshape.simplify_expressions import simplify_expression

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
        "concat_graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

print("shape of X :", builder.get_shape("X"))
print("shape of Y :", builder.get_shape("Y"))
print("shape of Z :", builder.get_shape("Z"))  # ('batch', 'seq1+seq2')
shape of X : ('batch', 'seq1')
shape of Y : ('batch', 'seq2')
shape of Z : ('batch', 'seq1+seq2')
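The Concat rule can be sketched as a small standalone function. The names concat_dim and concat_shape are hypothetical and not part of the yobx API. The sketch also folds integer dimensions into a single constant, which is an assumption about how mixed int/string dimensions might be combined, and it assumes the non-concatenated dimensions of all inputs already match.

```python
from typing import List, Sequence, Tuple, Union

Dim = Union[int, str]


def concat_dim(dims: Sequence[Dim]) -> Dim:
    """Sum the sizes along the concatenation axis (hypothetical helper)."""
    const = sum(d for d in dims if isinstance(d, int))   # fold concrete sizes
    symbols = [d for d in dims if isinstance(d, str)]    # keep symbolic sizes in input order
    if not symbols:
        return const
    expr = "+".join(symbols)
    return f"{expr}+{const}" if const else expr


def concat_shape(shapes: List[Tuple[Dim, ...]], axis: int) -> Tuple[Dim, ...]:
    rank = len(shapes[0])
    axis = axis % rank  # accept negative axes like ONNX does
    # Other dimensions are taken from the first input (assumed equal across inputs).
    return tuple(
        concat_dim([s[axis] for s in shapes]) if i == axis else shapes[0][i]
        for i in range(rank)
    )
```

With the inputs of this model, concat_shape([("batch", "seq1"), ("batch", "seq2")], 1) yields ("batch", "seq1+seq2"), the same string the builder reports.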

Evaluating symbolic shapes with concrete values

Once the actual sizes of the input dimensions are known, every symbolic dimension can be resolved to an integer with evaluate_shape, which takes a result name and a dictionary mapping each dimension name to its value.

context = dict(batch=2, seq1=5, seq2=7)
print("concrete shape of Z:", builder.evaluate_shape("Z", context))  # (2, 12)
concrete shape of Z: (2, 12)
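To see what resolving a symbolic dimension involves, here is a minimal evaluator built on Python's ast module. It is a sketch under stated assumptions, not the code behind evaluate_shape: the helpers eval_dim and eval_shape are hypothetical, only +, -, *, // and the CeilToInt(x, d) call used in this example are accepted, and CeilToInt is assumed to mean integer ceiling division. Walking a restricted set of ast node types avoids calling eval on arbitrary strings.

```python
import ast
import operator
from typing import Dict, Sequence, Tuple, Union

Dim = Union[int, str]

# Only these binary operators are allowed in a dimension expression.
_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def _ceil_to_int(x: int, d: int) -> int:
    # Integer ceiling division: ceil(x / d) without going through floats.
    return -(-x // d)


def _eval(node: ast.AST, ctx: Dict[str, int]) -> int:
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name):
        return ctx[node.id]  # KeyError if a dimension value is missing
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval(node.operand, ctx)
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        return _BINOPS[type(node.op)](_eval(node.left, ctx), _eval(node.right, ctx))
    if (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "CeilToInt"
        and len(node.args) == 2
    ):
        return _ceil_to_int(_eval(node.args[0], ctx), _eval(node.args[1], ctx))
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def eval_dim(dim: Dim, ctx: Dict[str, int]) -> int:
    """Resolve one dimension (hypothetical helper)."""
    if isinstance(dim, int):
        return dim
    return _eval(ast.parse(dim, mode="eval").body, ctx)


def eval_shape(shape: Sequence[Dim], ctx: Dict[str, int]) -> Tuple[int, ...]:
    return tuple(eval_dim(d, ctx) for d in shape)
```

With the context used above, eval_shape(("batch", "seq1+seq2"), dict(batch=2, seq1=5, seq2=7)) returns (2, 12).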

Reshape: floor-division expressions

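A Reshape node that halves a dynamic dimension produces the symbolic expression "c//2". In the target shape [0, 0, 2, -1], each 0 copies the input dimension at the same position (a and b), and the -1 sentinel is inferred so that the total number of elements is preserved: (a*b*c) / (a*b*2) reduces to the quotient expression c//2.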

model_reshape = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
        ],
        "reshape_graph",
        [oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"])],
        [oh.make_tensor_value_info("Xr", TFLOAT, [None, None, None, None])],
        [onh.from_array(np.array([0, 0, 2, -1], dtype=np.int64), name="shape")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder_reshape = BasicShapeBuilder()
builder_reshape.run_model(model_reshape)

print("shape of X  :", builder_reshape.get_shape("X"))
print("shape of Xr :", builder_reshape.get_shape("Xr"))  # ('a', 'b', 2, 'c//2')
shape of X  : ('a', 'b', 'c')
shape of Xr : ('a', 'b', 2, 'c//2')
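The -1 inference described above can be sketched as follows. The helper infer_reshape is hypothetical: it cancels symbolic factors shared by the input shape and the known output dimensions, then divides what remains. That reproduces "c//2" for this model, but it does not handle symbolic divisors or the other cases a real shape builder must cover, and it assumes ONNX's default allowzero=0 semantics for 0 entries.

```python
from math import prod
from typing import List, Sequence, Tuple, Union

Dim = Union[int, str]


def infer_reshape(in_shape: Sequence[Dim], target: Sequence[int]) -> Tuple[Dim, ...]:
    """Resolve 0 and -1 in a Reshape target shape (hypothetical sketch)."""
    # With allowzero=0, a 0 copies the input dimension at the same index.
    out: List[Dim] = [in_shape[i] if t == 0 else t for i, t in enumerate(target)]
    num = list(in_shape)               # factors of the total element count
    den = [d for d in out if d != -1]  # factors of the known output dimensions
    # Cancel symbolic factors present on both sides (a and b in the example).
    for d in list(den):
        if isinstance(d, str) and d in num:
            num.remove(d)
            den.remove(d)
    if any(isinstance(d, str) for d in den):
        raise NotImplementedError("symbolic divisor not handled in this sketch")
    num_int = prod(d for d in num if isinstance(d, int))
    den_int = prod(d for d in den if isinstance(d, int))
    symbols = [d for d in num if isinstance(d, str)]
    if num_int % den_int == 0:
        # Fold the integer part exactly, e.g. 4*a // 2 -> 2*a.
        num_int, den_int = num_int // den_int, 1
    missing: Dim
    if not symbols:
        missing = num_int // den_int
    else:
        factors = ([str(num_int)] if num_int != 1 else []) + symbols
        missing = "*".join(factors)
        if den_int != 1:
            missing = f"{missing}//{den_int}"
    return tuple(missing if d == -1 else d for d in out)
```

Called as infer_reshape(("a", "b", "c"), (0, 0, 2, -1)), it returns ("a", "b", 2, "c//2"), matching the shape printed for Xr.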

Split: ceiling-division expressions

The model below first concatenates X and Y, so the split axis has the symbolic size b+c. Split with num_outputs=2 and no explicit split input then divides that axis as evenly as possible. When the dimension is odd, the first output is one element larger than the second; this is captured by the expression CeilToInt(b+c, 2) for the first output and b+c - CeilToInt(b+c, 2) for the second.

model_split = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Concat", ["X", "Y"], ["xy"], axis=1),
            oh.make_node("Split", ["xy"], ["S1", "S2"], axis=1, num_outputs=2),
        ],
        "split_graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["a", "b"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["a", "c"]),
        ],
        [
            oh.make_tensor_value_info("S1", TFLOAT, [None, None]),
            oh.make_tensor_value_info("S2", TFLOAT, [None, None]),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder_split = BasicShapeBuilder()
builder_split.run_model(model_split)

print("shape of xy :", builder_split.get_shape("xy"))
print("shape of S1 :", builder_split.get_shape("S1"))
print("shape of S2 :", builder_split.get_shape("S2"))

context_split = dict(a=3, b=4, c=6)
print("concrete shape of S1:", builder_split.evaluate_shape("S1", context_split))
print("concrete shape of S2:", builder_split.evaluate_shape("S2", context_split))
shape of xy : ('a', 'b+c')
shape of S1 : ('a', 'CeilToInt(b+c,2)')
shape of S2 : ('a', 'b+c-CeilToInt(b+c,2)')
concrete shape of S1: (3, 5)
concrete shape of S2: (3, 5)
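The sizes above follow the ONNX rule for num_outputs: chunks have size ceil(n / num_outputs), and when the axis is not evenly divisible the last chunk is smaller. A minimal sketch with hypothetical helper names computes both the concrete sizes and the symbolic strings printed above; it is an illustration, not the yobx handler.

```python
from typing import List


def ceil_div(n: int, k: int) -> int:
    # Integer ceiling division, the concrete meaning assumed for CeilToInt(n, k).
    return -(-n // k)


def split_sizes(n: int, num_outputs: int) -> List[int]:
    """Concrete chunk sizes for Split(num_outputs=...) (hypothetical helper)."""
    chunk = ceil_div(n, num_outputs)
    last = n - chunk * (num_outputs - 1)  # the last chunk takes what remains
    if last < 0:
        raise ValueError("axis too small for this many outputs in this sketch")
    return [chunk] * (num_outputs - 1) + [last]


def split_sizes_symbolic(n: str, num_outputs: int) -> List[str]:
    """Same rule expressed as strings, mirroring the CeilToInt notation."""
    chunk = f"CeilToInt({n},{num_outputs})"
    rest = chunk if num_outputs == 2 else f"{num_outputs - 1}*{chunk}"
    return [chunk] * (num_outputs - 1) + [f"{n}-{rest}"]
```

For b+c = 10 this gives [5, 5], as in the output above, while an odd size such as 7 gives [4, 3]; split_sizes_symbolic("b+c", 2) reproduces the two strings reported for S1 and S2.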

Automatic expression simplification

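Before a symbolic dimension is stored, simplify_expression rewrites it into a simplified, canonical form: identical terms cancel, constant factors are folded, and terms are sorted, so that for example "b + a" and "a + b" produce the same string.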

examples = [
    "d + f - f",  # cancellation → d
    "2 * seq // 2",  # multiplication and floor-division cancel → seq
    "1024 * a // 2",  # partial fold → 512*a
    "b + a",  # terms are sorted → a+b
]

for expr in examples:
    print(f"  simplify({expr!r:20s}) = {simplify_expression(expr)!r}")
simplify('d + f - f'         ) = 'd'
simplify('2 * seq // 2'      ) = 'seq'
simplify('1024 * a // 2'     ) = '512*a'
simplify('b + a'             ) = 'a+b'
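The results on these examples can be approximated by a much simpler technique, canonicalising linear expressions. The sketch below is not the yobx algorithm: simplify_linear is a hypothetical helper that parses an expression with ast, collects one integer coefficient per symbol, drops zero terms, divides by a floor-division constant only when it divides every coefficient exactly (so the result stays exact for integer inputs), and prints the terms in sorted order. It rejects anything non-linear, such as a*b.

```python
import ast
from typing import Dict

Linear = Dict[str, int]  # symbol -> coefficient; the key "" holds the constant term


def _lin(node: ast.AST) -> Linear:
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return {"": node.value}
    if isinstance(node, ast.Name):
        return {node.id: 1}
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return {k: -v for k, v in _lin(node.operand).items()}
    if isinstance(node, ast.BinOp):
        left, right = _lin(node.left), _lin(node.right)
        if isinstance(node.op, (ast.Add, ast.Sub)):
            sign = 1 if isinstance(node.op, ast.Add) else -1
            out = dict(left)
            for k, v in right.items():
                out[k] = out.get(k, 0) + sign * v
            return out
        if isinstance(node.op, ast.Mult):
            # Linear only if one side is a pure constant.
            for const, other in ((left, right), (right, left)):
                if set(const) <= {""}:
                    c = const.get("", 0)
                    return {k: c * v for k, v in other.items()}
        if isinstance(node.op, ast.FloorDiv) and set(right) <= {""}:
            d = right.get("", 0)
            # Exact only when d divides every coefficient.
            if d != 0 and all(v % d == 0 for v in left.values()):
                return {k: v // d for k, v in left.items()}
    raise ValueError(f"not a simple linear expression: {ast.dump(node)}")


def simplify_linear(expr: str) -> str:
    terms = _lin(ast.parse(expr, mode="eval").body)
    const = terms.pop("", 0)
    parts = []
    for name in sorted(k for k, v in terms.items() if v != 0):
        coeff = terms[name]
        body = name if abs(coeff) == 1 else f"{abs(coeff)}*{name}"
        parts.append(("-" if coeff < 0 else "+") + body)
    if const:
        parts.append(("-" if const < 0 else "+") + str(abs(const)))
    if not parts:
        return "0"
    out = "".join(parts)
    return out[1:] if out.startswith("+") else out
```

On the four inputs above, simplify_linear returns the same strings as simplify_expression ('d', 'seq', '512*a', 'a+b'). A linear form is easy to canonicalise, which is why the sketch stays short; expressions like CeilToInt(b+c,2) fall outside it.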

Total running time of the script: (0 minutes 0.014 seconds)

Related examples

Computed Shapes: Add + Concat + Reshape

ONNX Graph Visualization with to_dot

ExtendedReferenceEvaluator: running models with contrib operators