Computed Shapes: Add + Concat + Reshape¶
This example shows how BasicShapeBuilder tracks symbolic dimension
expressions through a sequence of Add, Concat, and Reshape nodes,
and compares the result with the standard
onnx.shape_inference.infer_shapes().
The key difference is that onnx.shape_inference.infer_shapes can only compute a new dimension when the dimensions it depends on are statically known integers. Symbolic dimension names that need no arithmetic are forwarded unchanged (for example through Add). A dimension that must be derived from symbolic ones is marked as unknown instead, either left empty (None) or given a generated placeholder name such as unk__0. BasicShapeBuilder keeps such dimensions as symbolic arithmetic expressions, so output shapes are expressed in terms of the input dimension names.
See ShapeBuilder for a detailed description of how BasicShapeBuilder works and a comparison table with onnx.shape_inference.infer_shapes().
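The difference can be illustrated with a toy rule for the size of the axis produced by Concat. This is a minimal sketch, not BasicShapeBuilder's implementation. The function names and the string encoding of expressions are assumptions made for illustration. The static rule gives up as soon as one input is symbolic, while the symbolic rule keeps a linear expression over the dimension names.

```python
from collections import Counter
from typing import Optional, Union

# A dimension is either an int (statically known) or a str naming a symbol.
Dim = Union[int, str]


def static_concat_dim(a: Dim, b: Dim) -> Optional[int]:
    """Static-only rule (hypothetical): known only if both inputs are integers."""
    if isinstance(a, int) and isinstance(b, int):
        return a + b
    return None  # any symbolic input makes the result unknown


def symbolic_concat_dim(a: Dim, b: Dim) -> str:
    """Symbolic rule (hypothetical): keep the sum as an expression string."""
    terms: Counter = Counter()
    constant = 0
    for d in (a, b):
        if isinstance(d, int):
            constant += d
        else:
            terms[d] += 1  # count how many times each symbol appears
    parts = [name if k == 1 else f"{k}*{name}" for name, k in sorted(terms.items())]
    if constant:
        parts.append(str(constant))
    return "+".join(parts) if parts else "0"


print(static_concat_dim("d_model", "d_model"))    # None
print(symbolic_concat_dim("d_model", "d_model"))  # 2*d_model
```

The same inputs yield None under the static rule and 2*d_model under the symbolic one, which mirrors the concat_out rows in the two outputs shown further down.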
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.reference import ExtendedReferenceEvaluator
from yobx.xshape import BasicShapeBuilder
TFLOAT = onnx.TensorProto.FLOAT
Build a small model¶
The graph performs the following steps:
1. Add(X, Y): element-wise addition of two tensors of shape (batch, seq, d_model).
2. Concat(added, X, axis=2): concatenates the result with the original X along the last axis, giving shape (batch, seq, 2*d_model).
3. Reshape(concat_out, reshape_shape): reshapes with the constant target [0, 0, -1]. In ONNX, 0 copies the input dimension at the same position and -1 is inferred from the remaining element count, so the output keeps the shape (batch, seq, 2*d_model).
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("Concat", ["added", "X"], ["concat_out"], axis=2),
            oh.make_node("Reshape", ["concat_out", "reshape_shape"], ["Z"]),
        ],
        "add_concat_reshape",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
        [
            onh.from_array(np.array([0, 0, -1], dtype=np.int64), name="reshape_shape"),
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)
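The shape rules applied by these three nodes can be mimicked with a toy symbolic representation. This is a hedged sketch under stated assumptions. Dimensions are linear expressions stored in collections.Counter. Add is assumed to receive identical shapes (no broadcasting). Only Reshape targets made of 0 and a single -1 with the input's rank are handled. The helpers sym, fmt, add_rule, concat_rule and reshape_rule are hypothetical and not part of yobx.

```python
from collections import Counter
from typing import List

# Hypothetical encoding: a dimension is a linear expression mapping each
# symbol name to its integer coefficient, e.g. Counter({"d_model": 2}).
Expr = Counter


def sym(name: str) -> Expr:
    return Counter({name: 1})


def fmt(e: Expr) -> str:
    return "+".join(n if k == 1 else f"{k}*{n}" for n, k in sorted(e.items()))


def add_rule(a: List[Expr], b: List[Expr]) -> List[Expr]:
    # Element-wise Add without broadcasting: both shapes must match exactly.
    if a != b:
        raise ValueError("this sketch does not handle broadcasting")
    return list(a)


def concat_rule(a: List[Expr], b: List[Expr], axis: int) -> List[Expr]:
    # Concat sums the sizes along `axis`; the other dimensions are kept from `a`.
    out = list(a)
    out[axis] = a[axis] + b[axis]  # Counter addition adds the coefficients
    return out


def reshape_rule(x: List[Expr], target: List[int]) -> List[Expr]:
    # Only same-rank targets made of 0 (copy) and at most one -1 are supported.
    if len(target) != len(x) or any(t not in (0, -1) for t in target):
        raise NotImplementedError("only same-rank targets of 0 and -1 are supported")
    if target.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    # Each 0 copies x[i]; a single -1 at position j equals
    # total / product(copied dims) = x[j], so the shape is unchanged.
    return list(x)


X = [sym("batch"), sym("seq"), sym("d_model")]
Y = [sym("batch"), sym("seq"), sym("d_model")]
added = add_rule(X, Y)
concat_out = concat_rule(added, X, axis=2)
Z = reshape_rule(concat_out, [0, 0, -1])
print([fmt(d) for d in Z])  # ['batch', 'seq', '2*d_model']
```

A linear encoding is enough here because Concat only adds sizes. Operators that multiply symbolic dimensions would need a richer expression type.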
Shape inference with ONNX¶
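onnx.shape_inference.infer_shapes propagates shapes through the model. Dimension names pass through Add unchanged. Dimensions that would have to be computed from symbolic values (the Concat axis, and therefore the Reshape output) come back as unknown, shown below as the generated placeholders unk__0 and unk__1.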
inferred = onnx.shape_inference.infer_shapes(model)
print("=== onnx.shape_inference.infer_shapes ===")
for vi in (
    list(inferred.graph.input) + list(inferred.graph.value_info) + list(inferred.graph.output)
):
    t = vi.type.tensor_type
    if t.HasField("shape"):
        shape = tuple(
            d.dim_param if d.dim_param else (d.dim_value if d.dim_value else None)
            for d in t.shape.dim
        )
    else:
        shape = "unknown"
    print(f" {vi.name:15s} shape={shape}")
=== onnx.shape_inference.infer_shapes ===
X shape=('batch', 'seq', 'd_model')
Y shape=('batch', 'seq', 'd_model')
added shape=('batch', 'seq', 'd_model')
concat_out shape=('batch', 'seq', 'unk__0')
Z shape=('batch', 'seq', 'unk__1')
Shape inference with BasicShapeBuilder¶
BasicShapeBuilder keeps the shapes as symbolic expressions. Because reshape_shape is the constant [0, 0, -1], the builder can evaluate the Reshape and express the output shape as a function of the input dimensions.
=== BasicShapeBuilder ===
X shape=('batch', 'seq', 'd_model')
Y shape=('batch', 'seq', 'd_model')
added shape=('batch', 'seq', 'd_model')
concat_out shape=('batch', 'seq', '2*d_model')
Z shape=('batch', 'seq', '2*d_model')
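A constant target is what makes Reshape resolvable. Under the ONNX Reshape rules with the default allowzero=0, an entry of 0 copies the input dimension at the same position, and a single -1 is inferred from the remaining element count. The following minimal sketch applies these rules to concrete integer shapes. The function name is hypothetical, and a symbolic builder would apply the same rules to expressions instead of integers.

```python
import math
from typing import List, Sequence


def onnx_reshape_output_shape(input_shape: Sequence[int], shape: Sequence[int]) -> List[int]:
    """Resolve an ONNX Reshape target (allowzero=0) against a concrete input shape."""
    out = []
    for i, s in enumerate(shape):
        # 0 copies the input dimension at the same index; other values are kept.
        out.append(input_shape[i] if s == 0 else s)
    if out.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    total = math.prod(input_shape)
    if -1 in out:
        # -1 takes whatever is left once the known dimensions are accounted for.
        known = math.prod(d for d in out if d != -1)
        if known == 0 or total % known:
            raise ValueError("cannot infer the -1 dimension")
        out[out.index(-1)] = total // known
    elif math.prod(out) != total:
        raise ValueError("element count mismatch")
    return out


print(onnx_reshape_output_shape([2, 5, 16], [0, 0, -1]))  # [2, 5, 16]
print(onnx_reshape_output_shape([2, 5, 16], [0, -1]))     # [2, 80]
```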
Evaluate symbolic shapes with concrete values¶
Once the concrete values of the dynamic dimensions are known, evaluate_shape resolves each symbolic expression to its actual integer value.
X concrete shape=(2, 5, 8)
Y concrete shape=(2, 5, 8)
added concrete shape=(2, 5, 8)
concat_out concrete shape=(2, 5, 16)
Z concrete shape=(2, 5, 16)
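The resolution step can be pictured as evaluating each dimension expression against a mapping from names to integers. This is a minimal sketch and not yobx's evaluate_shape. The helpers evaluate_dim and evaluate_shape_sketch are hypothetical, and expressions are assumed to use only +, -, * and // over dimension names and integer literals. Parsing with the standard ast module avoids calling eval on arbitrary strings.

```python
import ast
import operator

# Operators accepted in a dimension expression; anything else is rejected.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def evaluate_dim(expr, values):
    """Evaluate a dimension given as an int or an expression string like '2*d_model'."""
    if isinstance(expr, int):
        return expr

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.Name):
            return values[node.id]  # KeyError if a dimension value is missing
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    return walk(ast.parse(expr, mode="eval"))


def evaluate_shape_sketch(shape, values):
    return tuple(evaluate_dim(d, values) for d in shape)


values = {"batch": 2, "seq": 5, "d_model": 8}
print(evaluate_shape_sketch(("batch", "seq", "2*d_model"), values))  # (2, 5, 16)
```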
Verify with real data¶
Finally, run the model with concrete numpy arrays and confirm that the
shapes predicted by BasicShapeBuilder match the actual output
shapes.
feeds = {
    "X": np.random.rand(2, 5, 8).astype(np.float32),
    "Y": np.random.rand(2, 5, 8).astype(np.float32),
}
session = ExtendedReferenceEvaluator(model)
outputs = session.run(None, feeds)
result = builder.compare_with_true_inputs(feeds, outputs)
print("\n=== shape comparison (expr, expected, computed) ===")
for name, dims in result.items():
    print(f" {name}: {dims}")
=== shape comparison (expr, expected, computed) ===
Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 16, 16))
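The same check can be reproduced without an ONNX runtime by replaying the three operators in NumPy. This sketch assumes the dimension values used above (batch=2, seq=5, d_model=8). NumPy's reshape treats 0 as a literal size, so the ONNX "copy" code is translated to the input size first.

```python
import numpy as np

batch, seq, d_model = 2, 5, 8
rng = np.random.default_rng(0)
X = rng.random((batch, seq, d_model), dtype=np.float32)
Y = rng.random((batch, seq, d_model), dtype=np.float32)

added = X + Y                                    # Add
concat_out = np.concatenate([added, X], axis=2)  # Concat(axis=2)
# Reshape [0, 0, -1]: replace each ONNX 0 with the input size at that position.
target = [concat_out.shape[i] if s == 0 else s for i, s in enumerate([0, 0, -1])]
Z = concat_out.reshape(target)

expected = (batch, seq, 2 * d_model)  # the symbolic shape with values substituted
print(Z.shape, Z.shape == expected)   # (2, 5, 16) True
```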