ShapeBuilder#
onnx.shape_inference.infer_shapes() tries to infer
shapes and types based on the input shapes. It does not
support formulas and introduces new symbols instead.
The yobx.xshape.ShapeBuilder
class walks through all nodes and looks up, for each node type, the function
computing the output shapes.
It tries as much as possible to express the new shapes with formulas
based on the dimensions used to define the inputs.
The list of functions is available in yobx.xshape.shape_type_compute
and is called from class _InferenceRuntime.
While doing so, every function may compute some tiny constants
with _BuilderRuntime.
These constants are used by _ShapeRuntime
to deduce some shapes.
The whole algorithm relies on four components:

- An analyser for expressions, able to parse and simplify numerical expressions built upon the names of the dynamic dimensions of the inputs,
- A list of functions inferring shapes, including the numerical expressions, for every ONNX operator,
- A very simple runtime able to run the short list of kernels usually used to handle shapes (Add, Sub, Mul, Div, Concat, Squeeze, Unsqueeze, Shape, Size, Reshape),
- An algorithm solving constraints after the inference functions were run. An unknown dimension may be known, or at least constrained to a short set of values, after a binary operator (or any other) was processed. The constraint mechanism implements a kind of backward pass where output dimensions restrict the number of possible values for input dimensions.
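The third component can be pictured as a tiny interpreter over plain Python lists. The sketch below is illustrative only (the names KERNELS and run_kernel are hypothetical, not part of yobx); it shows why a handful of integer kernels is enough to fold the small tensors used as shape arguments.

```python
# Illustrative sketch only: a micro-runtime over plain Python lists,
# enough to fold the tiny integer tensors used as shape arguments.
# KERNELS and run_kernel are hypothetical names, not the yobx API.
KERNELS = {
    "Add": lambda a, b: [x + y for x, y in zip(a, b)],
    "Sub": lambda a, b: [x - y for x, y in zip(a, b)],
    "Mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "Concat": lambda *tensors: [x for t in tensors for x in t],
}

def run_kernel(op, *inputs):
    """Dispatch one shape-related kernel on small integer lists."""
    return KERNELS[op](*inputs)

# Folding Concat(<first dims of a shape>, [64]) without a full ONNX runtime:
print(run_kernel("Concat", [2, 5], [64]))  # [2, 5, 64]
```

A real implementation supports a few more kernels (Div, Squeeze, Unsqueeze, Shape, Size, Reshape as listed above), but the principle is the same: these computations are cheap because the tensors involved are almost always one-dimensional and tiny.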
Class Hierarchy#
BasicShapeBuilder
is the main concrete implementation and is composed of four cooperative
base classes:
- ShapeBuilder — the public API contract: get_shape, get_type, get_rank, has_shape, has_type, has_rank, set_shape, set_type, set_rank, evaluate_shape, compare_with_true_inputs, update_shapes.
- _InferenceRuntime — walks the graph node by node, dispatching each node to the matching per-operator handler in yobx.xshape.shape_type_compute.
- _BuilderRuntime — evaluates small constant sub-expressions (e.g. the [0, 0, -1] passed to a Reshape node) so the builder can resolve -1 to the correct symbolic formula.
- _ShapeRuntime — handles the special value-as-shape tracking needed by operators such as Shape, Gather, Concat, and Slice when their output feeds directly into a Reshape.
For example, if X has shape ("d1", 2) then Shape(X, start=1) is constant [2].
This can be later used to infer the shape after a reshape.
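That resolution step can be sketched in a few lines of plain Python. resolve_reshape below is a hypothetical helper (not the yobx API) handling the common target pattern [0, ..., 0, -1], where leading zeros copy input dimensions and the final -1 collapses the remaining ones, symbolically if needed.

```python
def resolve_reshape(input_shape, target):
    """Sketch only (not the yobx API): resolve a Reshape target of the
    form [0, ..., 0, -1].

    Leading zeros copy the corresponding input dimensions; the final -1
    absorbs the remaining input dimensions, as a symbolic product when
    any of them is a string.
    """
    assert target[-1] == -1 and all(t == 0 for t in target[:-1])
    kept = list(input_shape[: len(target) - 1])
    rest = input_shape[len(target) - 1 :]
    if all(isinstance(d, int) for d in rest):
        tail = 1
        for d in rest:
            tail *= d                          # plain integer product
    else:
        tail = "*".join(str(d) for d in rest)  # symbolic product
    return tuple(kept + [tail])

print(resolve_reshape(("batch", "seq", 64), [0, 0, -1]))  # ('batch', 'seq', 64)
print(resolve_reshape(("batch", "seq", 8, 8), [0, -1]))   # ('batch', 'seq*8*8')
```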
After getting an expression, a few postprocessing steps are applied to reduce
its complexity. This relies on the ast module and is done by the function
simplify_expression.
For example, d + f - f is replaced by d.
Symbolic Expressions#
When input shapes contain unknown (dynamic) dimensions, ShapeBuilder represents each dimension as either:
an integer — for statically known sizes, or
a string — for symbolic (dynamic) sizes.
Symbolic strings are valid Python arithmetic expressions built from the names
of the original dynamic dimensions. For example, if the two inputs of a
Concat(axis=1) node have shapes ("batch", "seq1") and
("batch", "seq2"), the output shape is ("batch", "seq1+seq2").
Supported operators in symbolic expressions#
- + : addition (e.g. seq1+seq2)
- - : subtraction (e.g. total-seq)
- * : multiplication (e.g. 2*seq)
- // : floor division (e.g. seq//2)
- % : modulo
- ^ : used internally to represent max(a, b) (e.g. a^b evaluates to max(a, b))
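The ^ convention can be made executable with a small wrapper around int: overriding __xor__ turns Python's XOR into max, while the other operators keep their usual meaning. The _Dim class and evaluate_dim helper below are illustrative assumptions, not the yobx implementation (which is evaluate_expression, shown later).

```python
class _Dim(int):
    """Sketch only: an integer whose ^ means max, matching the convention above."""

    def __xor__(self, other):
        return _Dim(max(int(self), int(other)))

    __rxor__ = __xor__

    # Keep arithmetic results as _Dim so a chained ^ still means max.
    def __add__(self, other):
        return _Dim(int(self) + int(other))

    def __sub__(self, other):
        return _Dim(int(self) - int(other))

    def __mul__(self, other):
        return _Dim(int(self) * int(other))

    def __floordiv__(self, other):
        return _Dim(int(self) // int(other))

    def __mod__(self, other):
        return _Dim(int(self) % int(other))


def evaluate_dim(expr: str, values: dict) -> int:
    """Evaluate a symbolic dimension string given concrete integer values."""
    env = {name: _Dim(v) for name, v in values.items()}
    return int(eval(expr, {"__builtins__": {}}, env))


print(evaluate_dim("seq1+seq2", {"seq1": 5, "seq2": 7}))   # 12
print(evaluate_dim("a^b", {"a": 3, "b": 8}))               # 8
print(evaluate_dim("(a+b)^c", {"a": 1, "b": 2, "c": 10}))  # 10
```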
Automatic simplification#
Before storing a symbolic dimension,
simplify_expression
rewrites the expression to its simplest equivalent form:
<<<
from yobx.xexpressions import simplify_expression
print(simplify_expression("d + f - f")) # d
print(simplify_expression("2 * seq // 2")) # seq
print(simplify_expression("1024 * a // 2")) # 512*a
print(simplify_expression("b + a")) # a+b (terms sorted)
>>>
d
seq
512*a
a+b
Evaluating symbolic expressions at runtime#
Once the concrete integer values of the input dimensions are known,
evaluate_expression
can resolve any symbolic dimension to its actual integer value.
evaluate_shape applies this
to a whole shape at once.
<<<
import onnx
import onnx.helper as oh
from yobx.xexpressions import evaluate_expression
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
"graph",
[
oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
# Symbolic shape of Z
sym_shape = builder.get_shape("Z")
print("symbolic shape :", sym_shape)
# Evaluate each dimension given concrete values
context = dict(batch=3, seq1=5, seq2=7)
concrete = builder.evaluate_shape("Z", context)
print("concrete shape :", concrete)
>>>
symbolic shape : ('batch', 'seq1+seq2')
concrete shape : (3, 12)
See also
Expressions in Shape Computation — sphinx-gallery
example demonstrating Concat, Reshape, and Split symbolic
expressions, automatic simplification, and evaluation with concrete values.
Example#
The following example builds a small ONNX graph, runs
BasicShapeBuilder
on it, and prints the inferred shapes and types.
<<<
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
# A small model: reshape X then multiply by a weight matrix W.
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
oh.make_node("MatMul", ["Xr", "W"], ["Z"]),
],
"graph",
[oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64])],
[oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 32])],
[
onh.from_array(np.array([0, 0, 64], dtype=np.int64), name="shape"),
onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
for name in ["X", "Xr", "W", "Z"]:
    print(f"{name:5s} type={builder.get_type(name)} shape={builder.get_shape(name)}")
>>>
X type=1 shape=('batch', 'seq', 64)
Xr type=1 shape=('batch', 'seq', 64)
W type=1 shape=(64, 32)
Z type=1 shape=('batch', 'seq', 32)
Comparison with ONNX shape inference#
onnx.shape_inference.infer_shapes() is ONNX’s built-in shape
propagation pass. It works well for models with fully static dimensions but
loses symbolic relationships when dimensions are dynamic: intermediate results
receive freshly generated, unrelated symbols (e.g. unk__0, unk__1)
instead of expressions derived from the input dimensions.
BasicShapeBuilder
does better in this case because it:
- Carries symbolic names — every dynamic dimension keeps the name given in the input value_info (e.g. batch, seq, d_model).
- Builds arithmetic expressions — when an operator changes a dimension (e.g. Concat along an axis doubles d_model) the result is stored as the string expression "2*d_model" rather than a new opaque symbol.
- Folds constants — initializer tensors that appear as shape arguments (e.g. the [0, 0, -1] passed to Reshape) are evaluated at inference time, which lets the builder resolve the -1 placeholder to the correct symbolic formula.
- Simplifies — the resulting expression is reduced to its simplest form by simplify_expression before being stored (2*d_model//2 → d_model, etc.).
To summarise the difference, consider a model that applies
Add → Concat(axis=2) → Reshape([0,0,-1]) to inputs of shape
(batch, seq, d_model). BasicShapeBuilder keeps symbolic expressions for
every intermediate result: the concatenated and reshaped outputs both end with
the dimension 2*d_model. onnx.shape_inference.infer_shapes() instead
replaces the dynamic dimensions of the same results with fresh, unrelated
symbols such as unk__0 and unk__1.
See Computed Shapes: Add + Concat + Reshape for a runnable example that demonstrates this comparison step by step.
Validating computed shapes#
Once the model has been run with concrete inputs you can verify that the
symbolic shapes predicted by
BasicShapeBuilder
agree with the actual tensor shapes using
compare_with_true_inputs.
The method accepts the concrete input and output dictionaries (or lists) and
returns, for every output result, the list of (symbolic_expr, expected, computed)
triples.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["added"]),
oh.make_node("Concat", ["added", "X"], ["Z"], axis=2),
],
"add_concat",
[
oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
feeds = {
"X": np.random.rand(2, 5, 4).astype(np.float32),
"Y": np.random.rand(2, 5, 4).astype(np.float32),
}
session = ExtendedReferenceEvaluator(model)
outputs = session.run(None, feeds)
result = builder.compare_with_true_inputs(feeds, outputs)
for name, dims in result.items():
    print(f"{name}: {dims}")
>>>
Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 8, 8))
Each triple (expr, expected, computed) confirms that evaluating the
symbolic expression with the concrete dimension values yields the same size as
the tensor produced by the runtime.
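The check behind each triple can be reproduced by hand: evaluate the symbolic expression with the concrete dimension values and compare against the size reported by the runtime. The snippet below does it with plain eval, a simplification of what evaluate_expression performs.

```python
# Reproducing the check behind the triple ('2*d_model', 8, 8) by hand.
# In the example above, the inputs were (2, 5, 4), so d_model is 4.
expr, expected, computed = "2*d_model", 8, 8
context = {"d_model": 4}

# Evaluate the symbolic expression with the concrete dimension values.
value = eval(expr, {"__builtins__": {}}, context)
print(value, value == expected == computed)  # 8 True
```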
Writing shapes back to a model#
update_shapes writes the
inferred shapes and types back into the value_info section of the
onnx.ModelProto. Inputs, outputs, and initializers are left untouched;
only intermediate results (node outputs that are neither inputs, outputs, nor
initializers) are annotated.
This is useful for visualisation tools and downstream passes that rely on
value_info being populated.
<<<
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["added"]),
oh.make_node("MatMul", ["added", "W"], ["Z"]),
],
"add_matmul",
[
oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64]),
oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 64]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
[onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W")],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
print("value_info before:", [vi.name for vi in model.graph.value_info])
builder.update_shapes(model)
print("value_info after :", [vi.name for vi in model.graph.value_info])
vi = model.graph.value_info[0]
t = vi.type.tensor_type
shape = tuple(d.dim_param if d.dim_param else d.dim_value for d in t.shape.dim)
print(f" {vi.name}: dtype={t.elem_type} shape={shape}")
>>>
value_info before: []
value_info after : ['added']
added: dtype=1 shape=('batch', 'seq', 64)
Debugging Shape Inference with Environment Variables#
BasicShapeBuilder
respects several environment variables that help narrow down shape-inference
problems:
Their effects include:

- raising an exception the moment a given result is assigned a shape or a type (two separate variables),
- printing a message every time a given dynamic dimension is created or used,
- printing which constant value is being requested during inference,
- raising an exception when a shape is missing for a result that should have one,
- printing extra information inside the function that tracks shapes of results used as shape arguments (e.g. inputs to Reshape).
In addition, get_debug_msg returns a
detailed text dump of the builder’s internal state (known shapes, types,
constants, ranks, and the sequence of calls) which can be printed or logged
whenever an assertion fails.
Constraint Mechanism#
When BasicShapeBuilder
processes a broadcasting operation (e.g. Add, Mul, Where) it computes
the output shape with
broadcast_shape.
If one input dimension is a symbolic string (unknown at graph-construction time)
and the other is a concrete integer that is not 1, the builder registers a
constraint equating the symbolic name to the concrete value.
Why constraints are needed#
Without constraints, the shape of the broadcast result would be left as the
symbolic name (e.g. "d_model"), and any operation that follows would inherit
this uncertainty. Later, when a downstream node reveals the concrete value, the
builder would have to backtrack and update all earlier shapes — an expensive
and error-prone operation that is not implemented.
With constraints, the concrete value is used immediately as the output
dimension, and the equality symbolic_name = concrete_value is stored.
Downstream operations can propagate the concrete shape without revisiting
previous nodes.
How constraints are registered#
broadcast_shape applies
the following rules for each pair of aligned dimensions (a, b):
| a | b | Result | Constraint registered |
|---|---|---|---|
| symbolic string | concrete int (not 1) | the concrete int | a = b |
| concrete int (not 1) | symbolic string | the concrete int | b = a |
| symbolic string | 1 | the symbolic string | (none — 1 broadcasts freely) |
| symbolic string | the same symbolic string | the shared string | (none — already equal) |
| symbolic string | a different symbolic string | a^b | (none — max expression) |
The concrete integer is always chosen as the output dimension so that subsequent operations see a precise shape immediately.
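These rules condense into a few lines of plain Python. broadcast_pair and broadcast below are hypothetical sketches, not the yobx broadcast_shape, but they reproduce the behaviour described here, including the constraint registered for the ("d_model", 64) pair.

```python
def broadcast_pair(a, b, constraints):
    """Sketch only (not yobx): broadcast one pair of aligned dimensions.

    a and b are ints or symbolic strings; constraints maps a symbolic
    name to the set of values it is constrained to.
    """
    if isinstance(a, int) and isinstance(b, str):
        a, b = b, a                          # normalise: symbolic first
    if isinstance(a, str) and isinstance(b, int):
        if b == 1:
            return a                         # 1 broadcasts freely, no constraint
        constraints.setdefault(a, set()).add(b)
        return b                             # the concrete value wins
    if isinstance(a, int):
        return max(a, b)                     # two ints: plain broadcasting
    if a == b:
        return a                             # same symbol on both sides
    return f"{a}^{b}"                        # two different symbols: max expression


def broadcast(shape1, shape2, constraints):
    """Right-align both shapes, pad with 1, and broadcast pairwise."""
    n = max(len(shape1), len(shape2))
    s1 = (1,) * (n - len(shape1)) + tuple(shape1)
    s2 = (1,) * (n - len(shape2)) + tuple(shape2)
    return tuple(broadcast_pair(a, b, constraints) for a, b in zip(s1, s2))


constraints = {}
print(broadcast(("batch", "seq", "d_model"), (64,), constraints))
print(constraints)
# ('batch', 'seq', 64)
# {'d_model': {64}}
```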
Example: broadcasting after an unknown dimension#
<<<
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder
from yobx.xshape.shape_type_compute import broadcast_shape
TFLOAT = onnx.TensorProto.FLOAT
# X has dynamic last dimension "d_model"; bias has static size 64.
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "bias"], ["Z"]),
oh.make_node("MatMul", ["Z", "W"], ["Out"]),
],
"graph",
[oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"])],
[oh.make_tensor_value_info("Out", TFLOAT, [None, None, None])],
[
onh.from_array(np.zeros((64,), dtype=np.float32), name="bias"),
onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
for name in ["X", "Z", "Out"]:
    print(f"{name:5s} shape={builder.get_shape(name)}")
# The constraint records that d_model equals 64
print("constraints:", builder.get_registered_constraints())
>>>
X shape=('batch', 'seq', 'd_model')
Z shape=('batch', 'seq', 64)
Out shape=('batch', 'seq', 32)
constraints: {'d_model': {64}}
When the Add node is processed, broadcast_shape aligns
("batch", "seq", "d_model") with (64,) (right-padded to
(1, 1, 64)). The pair ("d_model", 64) triggers the constraint
"d_model" = 64. The output shape Z therefore becomes
("batch", "seq", 64) rather than ("batch", "seq", "d_model"), and
the MatMul handler can propagate the shape of Out immediately as
("batch", "seq", 32) without any backtracking.
Constraint API#
Three methods on ShapeBuilder
expose the constraint registry:
- register_constraint_dimension(dim_name, value) — record that the symbolic dimension dim_name is equal to value (an integer or another symbolic name). Called automatically by broadcast_shape when needed.
- add_to_constraints(dim_name, value) — lower-level helper that accepts a set of values as well as a single value.
- get_registered_constraints() — returns the full mapping {dim_name: {values}} accumulated so far.
The registry is also used by
_improves_dynamic_dimension_naming to replace
internal opaque tokens (e.g. "s0", "DYN0") with user-visible names
once the relationships between them are known.
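That renaming can be sketched as follows (rename is a hypothetical helper, not the actual _improves_dynamic_dimension_naming): when the registry maps an opaque token to exactly one other name, every occurrence of the token in a stored shape can be substituted.

```python
def rename(shape, constraints):
    """Sketch only: substitute opaque tokens that the constraint
    registry maps to exactly one user-visible name."""
    out = []
    for d in shape:
        values = constraints.get(d)
        if values is not None and len(values) == 1:
            out.append(next(iter(values)))   # unambiguous: substitute
        else:
            out.append(d)                    # unknown or ambiguous: keep
    return tuple(out)


# Hypothetical registry: internal tokens resolved to input names.
constraints = {"s0": {"batch"}, "DYN0": {"seq"}}
print(rename(("s0", "DYN0", 64), constraints))  # ('batch', 'seq', 64)
```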
Implementing a new shape function#
Adding support for a new operator (or overriding an existing one) requires
writing a small function and registering it in
yobx.xshape.shape_type_compute.
Shape functions signature#
Every shape function receives two arguments:
- g — the ShapeBuilder instance that holds all currently known shapes, types, ranks, and devices.
- node — the onnx.NodeProto being processed.
A minimal shape function should:
1. Propagate device — if g.has_device(input) is true, copy the device to the output with set_device.
2. Propagate type — guard with g.has_type(input) before calling set_type on every output; return None early if the type is not yet known.
3. Compute and set the shape — guard with g.has_shape(input) before deriving the output shape and calling set_shape. When the full shape is unavailable, fall back to g.has_rank / set_rank.
4. Return the shape (or True if only a rank was set, or None if nothing could be done).
Example: a custom element-wise scaling operator#
The following example shows a shape function for a hypothetical Scale
operator (domain "my.domain") that multiplies its first input X by a
scalar scale and returns a result with the same shape and type as X.
from onnx import NodeProto
from yobx.xshape.shape_builder import ShapeBuilder
def _set_shape_type_scale(g: ShapeBuilder, node: NodeProto):
    "Shape function for the custom Scale operator."
    x = node.input[0]
    out = node.output[0]
    # 1. propagate device
    if g.has_device(x):
        g.set_device(out, g.get_device(x))
    # 2. propagate element type
    if not g.has_type(x):
        return None
    g.set_type(out, g.get_type(x))
    # 3. compute output shape (same shape as input)
    if g.has_shape(x):
        shape = g.get_shape(x)
        g.set_shape(out, shape)
        return shape
    # fallback: propagate rank only
    if g.has_rank(x):
        g.set_rank(out, g.get_rank(x))
        return True
    return None
To register the function so that
BasicShapeBuilder
calls it automatically, add it to the appropriate registry dictionary in
yobx.xshape.shape_type_compute:
- _set_shape_type_op_any_known — for standard ONNX operators (domain "").
- _set_shape_type_op_any_custom — for operators in non-standard domains (e.g. "com.microsoft").
# In yobx/xshape/shape_type_compute.py:
_set_shape_type_op_any_custom["Scale"] = _set_shape_type_scale
The function will then be called automatically whenever
run_node
processes a Scale node.