ShapeBuilder#
onnx.shape_inference.infer_shapes() tries to infer
shapes and types based on input shapes, but it does not
support formulas and introduces new, unrelated symbols.
The yobx.xshape.ShapeBuilder
class walks through all nodes and looks up, in a list of functions,
the one computing the output shapes for each node type.
It tries as much as possible to express the new shape with formulas
based on the dimensions used to define the inputs.
The list of functions is available in yobx.xshape.shape_type_compute,
called from class _InferenceRuntime.
While doing so, every function may try to compute some tiny constants
in _BuilderRuntime.
This is used by _ShapeRuntime
to deduce some shapes.
Class Hierarchy#
BasicShapeBuilder
is the main concrete implementation and is composed of four cooperative
base classes:
- ShapeBuilder: the public API contract: get_shape, get_type, get_rank, has_shape, has_type, has_rank, set_shape, set_type, set_rank, evaluate_shape, compare_with_true_inputs, update_shapes.
- _InferenceRuntime: walks the graph node by node, dispatching each node to the matching per-operator handler in yobx.xshape.shape_type_compute.
- _BuilderRuntime: evaluates small constant sub-expressions (e.g. the [0, 0, -1] passed to a Reshape node) so the builder can resolve -1 to the correct symbolic formula.
- _ShapeRuntime: handles the special value-as-shape tracking needed by operators such as Shape, Gather, Concat, and Slice when their output feeds directly into a Reshape.
For example, if X has shape ("d1", 2) then Shape(X, start=1) is constant [2].
This can be later used to infer the shape after a reshape.
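This value-as-shape tracking can be sketched in plain Python. The helper below is a simplified illustration, not the library's _ShapeRuntime: it assumes -1 appears at most once, that it absorbs all trailing input dimensions, and it represents products involving dynamic dimensions as string formulas.

```python
def resolve_reshape(input_shape, target):
    # Sketch of how a Reshape target can be resolved symbolically:
    # a 0 entry copies the matching input dimension; a -1 entry absorbs
    # the remaining input dimensions as a product, kept as a string
    # formula when any of them is dynamic.
    out = []
    for i, t in enumerate(target):
        if t == 0:
            out.append(input_shape[i])
        elif t == -1:
            prod = input_shape[i]
            for d in input_shape[i + 1:]:
                if isinstance(prod, int) and isinstance(d, int):
                    prod = prod * d
                else:
                    prod = f"{prod}*{d}"
            out.append(prod)
        else:
            out.append(t)
    return tuple(out)


print(resolve_reshape(("batch", "seq", 64), [0, 0, -1]))  # ('batch', 'seq', 64)
print(resolve_reshape(("batch", "seq", 64), [0, -1]))     # ('batch', 'seq*64')
```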
After getting an expression, a few postprocessing steps are applied to reduce
its complexity. This relies on the ast module and is done by function
simplify_expression.
For example, d + f - f is replaced by d.
Symbolic Expressions#
When input shapes contain unknown (dynamic) dimensions, ShapeBuilder represents each dimension as either:
- an integer, for statically known sizes, or
- a string, for symbolic (dynamic) sizes.
Symbolic strings are valid Python arithmetic expressions built from the names
of the original dynamic dimensions. For example, if the two inputs of a
Concat(axis=1) node have shapes ("batch", "seq1") and
("batch", "seq2"), the output shape is ("batch", "seq1+seq2").
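The Concat rule just described can be sketched in a few lines of plain Python. This is a minimal illustration of the idea, not the library's actual handler:

```python
def concat_shape(shapes, axis):
    # All dimensions are copied from the first operand except the
    # concatenation axis, which is summed: integers add numerically,
    # dynamic dimensions are joined into a string formula.
    out = list(shapes[0])
    total = shapes[0][axis]
    for s in shapes[1:]:
        d = s[axis]
        if isinstance(total, int) and isinstance(d, int):
            total = total + d
        else:
            total = f"{total}+{d}"
    out[axis] = total
    return tuple(out)


print(concat_shape([("batch", "seq1"), ("batch", "seq2")], axis=1))
# ('batch', 'seq1+seq2')
```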
Supported operators in symbolic expressions#
- + : addition (e.g. seq1+seq2)
- - : subtraction (e.g. total-seq)
- * : multiplication (e.g. 2*seq)
- // : floor division (e.g. seq//2)
- % : modulo
- ^ : used internally to represent max(a, b) (e.g. a^b evaluates to max(a, b))
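Because `^` stands for max rather than bitwise xor, these expressions cannot be fed to eval() directly. A minimal evaluator for this mini-language can be sketched with the ast module (an illustration under that assumption, not the library's evaluate_expression):

```python
import ast
import operator

# Map each AST binary operator to its meaning in the mini-language;
# BitXor ('^') is reinterpreted as max(a, b).
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.BitXor: max,
}


def eval_dim(expr, context):
    """Evaluate a symbolic dimension given concrete values."""
    def walk(node):
        if isinstance(node, ast.BinOp):
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.Name):
            return context[node.id]
        if isinstance(node, ast.Constant):
            return node.value
        raise ValueError(f"unsupported node {node!r}")

    return walk(ast.parse(expr, mode="eval").body)


print(eval_dim("seq1+seq2", dict(seq1=5, seq2=7)))  # 12
print(eval_dim("a^b", dict(a=3, b=9)))              # 9
```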
Automatic simplification#
Before storing a symbolic dimension,
simplify_expression
rewrites the expression to its simplest equivalent form:
<<<
from yobx.xexpressions import simplify_expression
print(simplify_expression("d + f - f")) # d
print(simplify_expression("2 * seq // 2")) # seq
print(simplify_expression("1024 * a // 2")) # 512*a
print(simplify_expression("b + a")) # a+b (terms sorted)
>>>
d
seq
512*a
a+b
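One of these rewrites, folding a constant through a floor division, can be sketched with the ast module. This is a tiny illustration of a single rule, not the library's simplify_expression:

```python
import ast


def fold_constants(expr):
    # Rewrite "c1 * x // c2" into "(c1 // c2)*x" when the division is
    # exact; any other expression is returned unchanged.
    tree = ast.parse(expr, mode="eval").body
    if (
        isinstance(tree, ast.BinOp)
        and isinstance(tree.op, ast.FloorDiv)
        and isinstance(tree.right, ast.Constant)
        and isinstance(tree.left, ast.BinOp)
        and isinstance(tree.left.op, ast.Mult)
        and isinstance(tree.left.left, ast.Constant)
    ):
        c1, c2 = tree.left.left.value, tree.right.value
        if c1 % c2 == 0:
            return f"{c1 // c2}*{ast.unparse(tree.left.right)}"
    return expr


print(fold_constants("1024 * a // 2"))  # 512*a
```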
Evaluating symbolic expressions at runtime#
Once the concrete integer values of the input dimensions are known,
evaluate_expression
can resolve any symbolic dimension to its actual integer value.
evaluate_shape applies this
to a whole shape at once.
<<<
import onnx
import onnx.helper as oh
from yobx.xexpressions import evaluate_expression
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
        "graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

# Symbolic shape of Z
sym_shape = builder.get_shape("Z")
print("symbolic shape :", sym_shape)

# Evaluate each dimension given concrete values
context = dict(batch=3, seq1=5, seq2=7)
concrete = builder.evaluate_shape("Z", context)
print("concrete shape :", concrete)
>>>
symbolic shape : ('batch', 'seq1+seq2')
concrete shape : (3, 12)
See also
Expressions in Shape Computation — sphinx-gallery
example demonstrating Concat, Reshape, and Split symbolic
expressions, automatic simplification, and evaluation with concrete values.
Example#
The following example builds a small ONNX graph, runs
BasicShapeBuilder
on it, and prints the inferred shapes and types.
<<<
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

# A small model: reshape X then multiply by a weight matrix W.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
            oh.make_node("MatMul", ["Xr", "W"], ["Z"]),
        ],
        "graph",
        [oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64])],
        [oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 32])],
        [
            onh.from_array(np.array([0, 0, 64], dtype=np.int64), name="shape"),
            onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

for name in ["X", "Xr", "W", "Z"]:
    print(f"{name:5s} type={builder.get_type(name)} shape={builder.get_shape(name)}")
>>>
X type=1 shape=('batch', 'seq', 64)
Xr type=1 shape=('batch', 'seq', 64)
W type=1 shape=(64, 32)
Z type=1 shape=('batch', 'seq', 32)
Comparison with ONNX shape inference#
onnx.shape_inference.infer_shapes() is ONNX’s built-in shape
propagation pass. It works well for models with fully static dimensions but
loses symbolic relationships when dimensions are dynamic: intermediate results
receive freshly generated, unrelated symbols (e.g. unk__0, unk__1)
instead of expressions derived from the input dimensions.
BasicShapeBuilder
does better in this case because it:
- Carries symbolic names: every dynamic dimension keeps the name given in the input value_info (e.g. batch, seq, d_model).
- Builds arithmetic expressions: when an operator changes a dimension (e.g. Concat along an axis doubles d_model), the result is stored as the string expression "2*d_model" rather than a new opaque symbol.
- Folds constants: initializer tensors that appear as shape arguments (e.g. the [0, 0, -1] passed to Reshape) are evaluated at inference time, which lets the builder resolve the -1 placeholder to the correct symbolic formula.
- Simplifies: the resulting expression is reduced to its simplest form by simplify_expression before being stored (2*d_model//2 → d_model, etc.).
For a model that applies
Add → Concat(axis=2) → Reshape([0,0,-1]) to inputs of shape
(batch, seq, d_model), ONNX shape inference labels the intermediate and
final results with freshly generated, opaque symbols such as unk__0 and
unk__1, while
BasicShapeBuilder
keeps expressions such as 2*d_model derived from the input dimension names.
See Computed Shapes: Add + Concat + Reshape for a runnable example that demonstrates this comparison step by step.
Validating computed shapes#
Once the model has been run with concrete inputs you can verify that the
symbolic shapes predicted by
BasicShapeBuilder
agree with the actual tensor shapes using
compare_with_true_inputs.
The method accepts the concrete input and output dictionaries (or lists) and
returns, for every output result, the list of (symbolic_expr, expected, computed)
triples.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("Concat", ["added", "X"], ["Z"], axis=2),
        ],
        "add_concat",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

feeds = {
    "X": np.random.rand(2, 5, 4).astype(np.float32),
    "Y": np.random.rand(2, 5, 4).astype(np.float32),
}
session = ExtendedReferenceEvaluator(model)
outputs = session.run(None, feeds)

result = builder.compare_with_true_inputs(feeds, outputs)
for name, dims in result.items():
    print(f"{name}: {dims}")
>>>
Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 8, 8))
Each triple (expr, expected, computed) confirms that evaluating the
symbolic expression with the concrete dimension values yields the same size as
the tensor produced by the runtime.
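The check performed for each triple can be reproduced in plain Python. The helper below is an illustrative sketch of the described semantics, not the library's code; it evaluates the expression with a restricted eval for brevity:

```python
def verify_dim(expr, expected, computed, context):
    # Evaluate the symbolic expression with the concrete dimension
    # values and check it matches both reported sizes.
    if isinstance(expr, int):
        value = expr
    else:
        value = eval(expr, {"__builtins__": {}}, context)
    return value == expected == computed


triples = [("batch", 2, 2), ("seq", 5, 5), ("2*d_model", 8, 8)]
context = dict(batch=2, seq=5, d_model=4)
print(all(verify_dim(e, a, b, context) for e, a, b in triples))  # True
```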
Writing shapes back to a model#
update_shapes writes the
inferred shapes and types back into the value_info section of the
onnx.ModelProto. Inputs, outputs, and initializers are left untouched;
only intermediate results (node outputs that are neither inputs, outputs, nor
initializers) are annotated.
This is useful for visualisation tools and downstream passes that rely on
value_info being populated.
<<<
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.xshape import BasicShapeBuilder

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("MatMul", ["added", "W"], ["Z"]),
        ],
        "add_matmul",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 64]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
        [onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

builder = BasicShapeBuilder()
builder.run_model(model)

print("value_info before:", [vi.name for vi in model.graph.value_info])
builder.update_shapes(model)
print("value_info after :", [vi.name for vi in model.graph.value_info])

vi = model.graph.value_info[0]
t = vi.type.tensor_type
shape = tuple(d.dim_param if d.dim_param else d.dim_value for d in t.shape.dim)
print(f"  {vi.name}: dtype={t.elem_type} shape={shape}")
>>>
value_info before: []
value_info after : ['added']
added: dtype=1 shape=('batch', 'seq', 64)
Debugging#
BasicShapeBuilder
respects several environment variables that help narrow down shape-inference
problems:
| Environment variable | Effect |
|---|---|
| … | Raises an exception the moment result … |
| … | Raises an exception the moment result … |
| … | Prints a message every time the dynamic dimension … |
| … | Prints which constant value is being requested during inference. |
| … | Raises an exception when a shape is missing for a result that should have one. |
| … | Prints extra information inside the function that tracks shapes of results used as shape arguments (e.g. inputs to …). |
In addition, get_debug_msg returns a
detailed text dump of the builder’s internal state (known shapes, types,
constants, ranks, and the sequence of calls) which can be printed or logged
whenever an assertion fails.
Implementing a new shape function#
Adding support for a new operator (or overriding an existing one) requires
writing a small function and registering it in
yobx.xshape.shape_type_compute.
Shape function signature#
Every shape function receives two arguments:
- g: the ShapeBuilder instance that holds all currently known shapes, types, ranks, and devices.
- node: the onnx.NodeProto being processed.
A minimal shape function should:
1. Propagate device: if g.has_device(input) is true, copy the device to the output with set_device.
2. Propagate type: guard with g.has_type(input) before calling set_type on every output; return None early if the type is not yet known.
3. Compute and set the shape: guard with g.has_shape(input) before deriving the output shape and calling set_shape. When the full shape is unavailable, fall back to g.has_rank / set_rank.
4. Return the shape (or True if only a rank was set, or None if nothing could be done).
Example: a custom element-wise scaling operator#
The following example shows a shape function for a hypothetical Scale
operator (domain "my.domain") that multiplies its first input X by a
scalar scale and returns a result with the same shape and type as X.
from onnx import NodeProto
from yobx.xshape.shape_builder import ShapeBuilder


def _set_shape_type_scale(g: ShapeBuilder, node: NodeProto):
    "Shape function for the custom Scale operator."
    x = node.input[0]
    out = node.output[0]
    # 1. propagate device
    if g.has_device(x):
        g.set_device(out, g.get_device(x))
    # 2. propagate element type
    if not g.has_type(x):
        return None
    g.set_type(out, g.get_type(x))
    # 3. compute output shape (same shape as input)
    if g.has_shape(x):
        shape = g.get_shape(x)
        g.set_shape(out, shape)
        return shape
    # fallback: propagate rank only
    if g.has_rank(x):
        g.set_rank(out, g.get_rank(x))
        return True
    return None
To register the function so that
BasicShapeBuilder
calls it automatically, add it to the appropriate registry dictionary in
yobx.xshape.shape_type_compute:
- _set_shape_type_op_any_known: for standard ONNX operators (domain "").
- _set_shape_type_op_any_custom: for operators in non-standard domains (e.g. "com.microsoft").
# In yobx/xshape/shape_type_compute.py:
_set_shape_type_op_any_custom["Scale"] = _set_shape_type_scale
The function will then be called automatically whenever
run_node
processes a Scale node.
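The dispatch mechanism can be sketched with stand-in dictionaries. The registries below only mirror the names described above for illustration; the real ones live in yobx.xshape.shape_type_compute:

```python
# Stand-in registries (illustrative, not the library's objects).
_set_shape_type_op_any_known = {}
_set_shape_type_op_any_custom = {}


def find_shape_function(domain, op_type):
    # Pick the handler the runtime would call for a node: standard
    # operators live in the "known" registry, every other domain in
    # the "custom" one.
    registry = (
        _set_shape_type_op_any_known if domain == ""
        else _set_shape_type_op_any_custom
    )
    return registry.get(op_type)


# Registering a handler for the custom Scale operator.
_set_shape_type_op_any_custom["Scale"] = lambda g, node: "shape computed"

handler = find_shape_function("my.domain", "Scale")
print(handler(None, None))  # shape computed
```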