ShapeBuilder¶
onnx.shape_inference.infer_shapes() tries to infer
shapes and types based on input shapes. It does not
supports formulas and introduces new symbols.
yobx.xshape.ShapeBuilder
class walks through all nodes and looks into a list of functions
computing the output shapes based on the node type.
It tries as much as possible to express the new shape with formulas
based on the dimensions used to defined the inputs.
The list of functions is available in yobx.xshape.shape_type_compute
called from class _InferenceRuntime.
While doing this, every function may try to compute some tiny constants
in _BuilderRuntime.
This is used by _ShapeRuntime
to deduce some shapes.
For example, if X has shape ("d1", 2) then Shape(X, start=1) is constant [2].
This can be later used to infer the shape after a reshape.
After getting an expression, a few postprocessing are applied to reduce
its complexity. This relies on ast. It is done by function
simplify_expression.
d + f - f is replaced by d.
Symbolic Expressions¶
When input shapes contain unknown (dynamic) dimensions, ShapeBuilder represents each dimension as either:
an integer — for statically known sizes, or
a string — for symbolic (dynamic) sizes.
Symbolic strings are valid Python arithmetic expressions built from the names
of the original dynamic dimensions. For example, if the two inputs of a
Concat(axis=1) node have shapes ("batch", "seq1") and
("batch", "seq2"), the output shape is ("batch", "seq1+seq2").
Supported operators in symbolic expressions¶
+addition (e.g.seq1+seq2)-subtraction (e.g.total-seq)*multiplication (e.g.2*seq)//floor division (e.g.seq//2)%modulo^used internally to representmax(a, b)(e.g.a^bevaluates tomax(a, b))
Automatic simplification¶
Before storing a symbolic dimension,
simplify_expression
rewrites the expression to its simplest equivalent form:
<<<
from yobx.xshape.simplify_expressions import simplify_expression
print(simplify_expression("d + f - f")) # d
print(simplify_expression("2 * seq // 2")) # seq
print(simplify_expression("1024 * a // 2")) # 512*a
print(simplify_expression("b + a")) # a+b (terms sorted)
>>>
d
seq
512*a
a+b
Evaluating symbolic expressions at runtime¶
Once the concrete integer values of the input dimensions are known,
evaluate_expression
can resolve any symbolic dimension to its actual integer value.
evaluate_shape applies this
to a whole shape at once.
<<<
import onnx
import onnx.helper as oh
from yobx.xshape import BasicShapeBuilder
from yobx.xshape.evaluate_expressions import evaluate_expression
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[oh.make_node("Concat", ["X", "Y"], ["Z"], axis=1)],
"graph",
[
oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq1"]),
oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq2"]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
# Symbolic shape of Z
sym_shape = builder.get_shape("Z")
print("symbolic shape :", sym_shape)
# Evaluate each dimension given concrete values
context = dict(batch=3, seq1=5, seq2=7)
concrete = builder.evaluate_shape("Z", context)
print("concrete shape :", concrete)
>>>
symbolic shape : ('batch', 'seq1+seq2')
concrete shape : (3, 12)
See also
Expressions in Shape Computation — sphinx-gallery
example demonstrating Concat, Reshape, and Split symbolic
expressions, automatic simplification, and evaluation with concrete values.
Example¶
The following example builds a small ONNX graph, runs
BasicShapeBuilder
on it, and prints the inferred shapes and types.
<<<
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import numpy as np
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
# A small model: reshape X then multiply by a weight matrix W.
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Reshape", ["X", "shape"], ["Xr"]),
oh.make_node("MatMul", ["Xr", "W"], ["Z"]),
],
"graph",
[oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 64])],
[oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 32])],
[
onh.from_array(np.array([0, 0, 64], dtype=np.int64), name="shape"),
onh.from_array(np.random.randn(64, 32).astype(np.float32), name="W"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
builder = BasicShapeBuilder()
builder.run_model(model)
for name in ["X", "Xr", "W", "Z"]:
print(
f"{name:5s} type={builder.get_type(name)}" f" shape={builder.get_shape(name)}"
)
>>>
X type=1 shape=('batch', 'seq', 64)
Xr type=1 shape=('batch', 'seq', 64)
W type=1 shape=(64, 32)
Z type=1 shape=('batch', 'seq', 32)
Comparison with ONNX shape inference¶
onnx.shape_inference.infer_shapes() is ONNX’s built-in shape
propagation pass. It works well for models with fully static dimensions but
loses symbolic relationships when dimensions are dynamic: intermediate results
receive freshly generated, unrelated symbols (e.g. unk__0, unk__1)
instead of expressions derived from the input dimensions.
BasicShapeBuilder
does better in this case because it:
Carries symbolic names — every dynamic dimension keeps the name given in the input
value_info(e.g.batch,seq,d_model).Builds arithmetic expressions — when an operator changes a dimension (e.g.
Concatalong an axis doublesd_model) the result is stored as the string expression"2*d_model"rather than a new opaque symbol.Folds constants — initializer tensors that appear as shape arguments (e.g. the
[0, 0, -1]passed toReshape) are evaluated at inference-time, which lets the builder resolve the-1placeholder to the correct symbolic formula.Simplifies — the resulting expression is reduced to its simplest form by
simplify_expressionbefore being stored (2*d_model//2→d_model, etc.).
The table below summarises the difference for a model that applies
Add → Concat(axis=2) → Reshape([0,0,-1]) to inputs of shape
(batch, seq, d_model):
result |
|
|
|---|---|---|
|
|
|
|
|
|
|
|
|
See Computed Shapes: Add + Concat + Reshape for a runnable example that demonstrates this comparison step by step.