.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_computed_shapes.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_computed_shapes.py:


.. _l-plot-computed-shapes:

Computed Shapes: Add + Concat + Reshape
========================================

This example shows how :class:`BasicShapeBuilder `
tracks symbolic dimension expressions through a sequence of ``Add``, ``Concat``,
and ``Reshape`` nodes, and compares the result with the standard
:func:`onnx.shape_inference.infer_shapes`.

The key difference lies in how dynamic (symbolic) dimensions are handled.
``onnx.shape_inference.infer_shapes`` keeps a symbolic dimension name only when
it passes through an operator unchanged; it can combine dimensions
arithmetically only when they are statically known integers. A dimension
computed from symbolic inputs, such as the ``d_model + d_model`` produced by
``Concat``, is replaced by a fresh, unrelated name (for example ``unk__0``) or
left unknown. :class:`BasicShapeBuilder` instead keeps such dimensions as
symbolic arithmetic expressions, so output shapes are expressed in terms of the
input dimension names.

See :ref:`l-design-shape` for a detailed description of how
:class:`BasicShapeBuilder ` works and a comparison table with
:func:`onnx.shape_inference.infer_shapes`.

.. GENERATED FROM PYTHON SOURCE LINES 24-35

.. code-block:: Python


    import numpy as np
    import onnx
    import onnx.helper as oh
    import onnx.numpy_helper as onh
    from yobx.reference import ExtendedReferenceEvaluator
    from yobx.xshape import BasicShapeBuilder

    TFLOAT = onnx.TensorProto.FLOAT

.. GENERATED FROM PYTHON SOURCE LINES 36-48

Build a small model
--------------------

The graph performs the following steps:

1. ``Add(X, Y)``: element-wise addition of two tensors with shape
   ``(batch, seq, d_model)``.
2. ``Concat(added, X, axis=2)``: concatenation of the result with the original
   ``X`` along the last axis, giving shape ``(batch, seq, 2*d_model)``.
3.
   ``Reshape(concat_out, reshape_shape)``: reshape using the constant target
   shape ``[0, 0, -1]``. With the default ``allowzero=0``, ONNX ``Reshape``
   copies the input dimension wherever the target holds ``0`` and infers the
   dimension marked ``-1`` from the remaining number of elements. On this
   rank-3 input the shape therefore stays ``(batch, seq, 2*d_model)``.

.. GENERATED FROM PYTHON SOURCE LINES 48-70

.. code-block:: Python


    model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Add", ["X", "Y"], ["added"]),
                oh.make_node("Concat", ["added", "X"], ["concat_out"], axis=2),
                oh.make_node("Reshape", ["concat_out", "reshape_shape"], ["Z"]),
            ],
            "add_concat_reshape",
            [
                oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", "d_model"]),
                oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", "d_model"]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, [None, None, None])],
            [
                onh.from_array(np.array([0, 0, -1], dtype=np.int64), name="reshape_shape"),
            ],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

.. GENERATED FROM PYTHON SOURCE LINES 71-77

Shape inference with ONNX
--------------------------

``onnx.shape_inference.infer_shapes`` propagates shapes through the model.
The dimensions ``batch`` and ``seq`` pass through every node unchanged and keep
their names. The last dimension of ``concat_out`` and ``Z``, however, is computed
from ``d_model``, so it is replaced by the fresh placeholders ``unk__0`` and
``unk__1``, and the relation ``2*d_model`` is lost.

.. GENERATED FROM PYTHON SOURCE LINES 77-94

.. code-block:: Python


    inferred = onnx.shape_inference.infer_shapes(model)

    print("=== onnx.shape_inference.infer_shapes ===")
    for vi in (
        list(inferred.graph.input)
        + list(inferred.graph.value_info)
        + list(inferred.graph.output)
    ):
        t = vi.type.tensor_type
        if t.HasField("shape"):
            shape = tuple(
                d.dim_param if d.dim_param else (d.dim_value if d.dim_value else None)
                for d in t.shape.dim
            )
        else:
            shape = "unknown"
        print(f" {vi.name:15s} shape={shape}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    === onnx.shape_inference.infer_shapes ===
     X               shape=('batch', 'seq', 'd_model')
     Y               shape=('batch', 'seq', 'd_model')
     added           shape=('batch', 'seq', 'd_model')
     concat_out      shape=('batch', 'seq', 'unk__0')
     Z               shape=('batch', 'seq', 'unk__1')

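
The difference between a fresh placeholder and a symbolic expression can be
illustrated with a toy propagator. The sketch below is hypothetical and is not
how :class:`BasicShapeBuilder` is implemented: it represents each dimension as a
linear expression over the input names (a dict such as ``{"d_model": 2}`` for
``2*d_model``), sums the axis dimensions for ``Concat``, and supports only the
``[0, ..., 0, -1]`` ``Reshape`` pattern used in this example. The helper names
(``sym``, ``dim_sum``, ``concat_shape`` and so on) are made up for illustration.

.. code-block:: Python

    from typing import Dict, List

    # Hypothetical representation: a dimension is a linear expression over symbols,
    # e.g. {"d_model": 2} stands for 2*d_model.
    Dim = Dict[str, int]
    Shape = List[Dim]


    def sym(name: str) -> Dim:
        """A dimension equal to a single input symbol."""
        return {name: 1}


    def dim_sum(a: Dim, b: Dim) -> Dim:
        """Add two linear expressions term by term."""
        out = dict(a)
        for name, coef in b.items():
            out[name] = out.get(name, 0) + coef
        return {name: coef for name, coef in out.items() if coef != 0}


    def dim_str(d: Dim) -> str:
        """Render a linear expression, e.g. {"d_model": 2} -> "2*d_model"."""
        terms = [name if coef == 1 else f"{coef}*{name}" for name, coef in sorted(d.items())]
        return "+".join(terms) if terms else "0"


    def add_shape(x: Shape, y: Shape) -> Shape:
        # Element-wise Add; this toy version requires identical shapes (no broadcasting).
        if x != y:
            raise ValueError("toy Add only supports identical symbolic shapes")
        return list(x)


    def concat_shape(x: Shape, y: Shape, axis: int) -> Shape:
        # All dimensions except `axis` must match; the `axis` dimensions are summed.
        if len(x) != len(y) or any(x[i] != y[i] for i in range(len(x)) if i != axis):
            raise ValueError("toy Concat: non-axis dimensions differ")
        out = list(x)
        out[axis] = dim_sum(x[axis], y[axis])
        return out


    def reshape_prefix_shape(x: Shape, target: List[int]) -> Shape:
        # ONNX Reshape (allowzero=0): 0 copies the input dimension, -1 is inferred.
        # Only [0, ..., 0, -1] with exactly one remaining input dimension is handled,
        # in which case -1 resolves to that last dimension.
        k = len(target) - 1
        if target != [0] * k + [-1] or len(x) != k + 1:
            raise NotImplementedError("toy Reshape only handles [0, ..., 0, -1] of matching rank")
        return list(x)


    x = [sym("batch"), sym("seq"), sym("d_model")]
    y = [sym("batch"), sym("seq"), sym("d_model")]
    added = add_shape(x, y)
    concat_out = concat_shape(added, x, axis=2)
    z = reshape_prefix_shape(concat_out, [0, 0, -1])
    print([dim_str(d) for d in z])  # ['batch', 'seq', '2*d_model']

Because every dimension stays an expression over ``batch``, ``seq`` and
``d_model``, the ``Concat`` output is ``2*d_model`` rather than a new symbol. A
general implementation would also need broadcasting and products of symbols (for
example, to merge two symbolic dimensions in a ``Reshape``), which this linear
representation cannot express.
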
.. GENERATED FROM PYTHON SOURCE LINES 95-102

Shape inference with BasicShapeBuilder
----------------------------------------

:class:`BasicShapeBuilder ` keeps the shapes as
symbolic expressions. Because ``reshape_shape`` is a constant ``[0, 0, -1]``, the
builder can evaluate the ``Reshape`` and express the output shape as a function
of the input dimensions.

.. GENERATED FROM PYTHON SOURCE LINES 102-111

.. code-block:: Python


    builder = BasicShapeBuilder()
    builder.run_model(model)

    print("\n=== BasicShapeBuilder ===")
    for name in ["X", "Y", "added", "concat_out", "Z"]:
        print(f" {name:15s} shape={builder.get_shape(name)}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    === BasicShapeBuilder ===
     X               shape=('batch', 'seq', 'd_model')
     Y               shape=('batch', 'seq', 'd_model')
     added           shape=('batch', 'seq', 'd_model')
     concat_out      shape=('batch', 'seq', '2*d_model')
     Z               shape=('batch', 'seq', '2*d_model')

.. GENERATED FROM PYTHON SOURCE LINES 112-118

Evaluate symbolic shapes with concrete values
-----------------------------------------------

Once the concrete values of the dynamic dimensions are known,
:meth:`evaluate_shape ` resolves each symbolic
expression to its actual integer value.

.. GENERATED FROM PYTHON SOURCE LINES 118-124

.. code-block:: Python


    context = dict(batch=2, seq=5, d_model=8)
    for name in ["X", "Y", "added", "concat_out", "Z"]:
        concrete = builder.evaluate_shape(name, context)
        print(f" {name:15s} concrete shape={concrete}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

     X               concrete shape=(2, 5, 8)
     Y               concrete shape=(2, 5, 8)
     added           concrete shape=(2, 5, 8)
     concat_out      concrete shape=(2, 5, 16)
     Z               concrete shape=(2, 5, 16)

.. GENERATED FROM PYTHON SOURCE LINES 125-131

Verify with real data
----------------------

Finally, run the model with concrete NumPy arrays and confirm that the shapes
predicted by :class:`BasicShapeBuilder` match the actual output shapes.

.. GENERATED FROM PYTHON SOURCE LINES 131-143
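
The evaluation and verification steps can also be mimicked without yobx, which
helps when reading their outputs. The sketch below is hypothetical:
``eval_dim`` is a small helper (not part of yobx) that evaluates a dimension
string such as ``"2*d_model"`` with Python's ``ast`` module, and the three nodes
are recomputed with plain NumPy instead of ``ExtendedReferenceEvaluator``. The
symbolic shapes are copied from the :class:`BasicShapeBuilder` output of this
example.

.. code-block:: Python

    import ast
    import operator

    import numpy as np

    # Binary operators allowed in a dimension expression.
    _OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.FloorDiv: operator.floordiv,
    }


    def eval_dim(expr: str, context: dict) -> int:
        """Evaluate a dimension expression such as "2*d_model" (hypothetical helper)."""

        def walk(node: ast.AST) -> int:
            if isinstance(node, ast.Constant) and isinstance(node.value, int):
                return node.value
            if isinstance(node, ast.Name):
                return context[node.id]
            if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
                return _OPS[type(node.op)](walk(node.left), walk(node.right))
            raise ValueError(f"unsupported dimension expression: {expr!r}")

        return walk(ast.parse(expr, mode="eval").body)


    # Symbolic shapes as printed by BasicShapeBuilder in this example.
    predicted = {
        "added": ("batch", "seq", "d_model"),
        "concat_out": ("batch", "seq", "2*d_model"),
        "Z": ("batch", "seq", "2*d_model"),
    }
    context = dict(batch=2, seq=5, d_model=8)

    # Recompute the three nodes with NumPy on random inputs.
    rng = np.random.default_rng(0)
    X = rng.random((2, 5, 8), dtype=np.float32)
    Y = rng.random((2, 5, 8), dtype=np.float32)
    actual = {"added": X + Y}
    actual["concat_out"] = np.concatenate([actual["added"], X], axis=2)
    # Reshape with [0, 0, -1]: keep the first two dimensions, infer the last one.
    batch, seq = actual["concat_out"].shape[:2]
    actual["Z"] = actual["concat_out"].reshape(batch, seq, -1)

    for name, dims in predicted.items():
        expected = tuple(eval_dim(d, context) for d in dims)
        assert expected == actual[name].shape, (name, expected, actual[name].shape)
        print(f" {name:12s} {dims} -> {expected}")

In this example the comparison printed by ``compare_with_true_inputs`` lists only
the graph output ``Z``; recomputing the nodes directly also lets the sketch check
the intermediate results ``added`` and ``concat_out``.
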

.. code-block:: Python


    feeds = {
        "X": np.random.rand(2, 5, 8).astype(np.float32),
        "Y": np.random.rand(2, 5, 8).astype(np.float32),
    }
    session = ExtendedReferenceEvaluator(model)
    outputs = session.run(None, feeds)

    result = builder.compare_with_true_inputs(feeds, outputs)
    print("\n=== shape comparison (expr, expected, computed) ===")
    for name, dims in result.items():
        print(f" {name}: {dims}")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    === shape comparison (expr, expected, computed) ===
     Z: (('batch', 2, 2), ('seq', 5, 5), ('2*d_model', 16, 16))


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.054 seconds)


.. _sphx_glr_download_auto_examples_plot_computed_shapes.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_computed_shapes.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_computed_shapes.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_computed_shapes.zip `


.. include:: plot_computed_shapes.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_