201: Better shape inference

A simple model

import numpy as np
import onnx
import onnx.helper as oh
import onnx.shape_inference as osh
from onnx.reference import ReferenceEvaluator
from experimental_experiment.xshape.shape_builder_impl import BasicShapeBuilder

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Concat", ["X", "Y"], ["xy"], axis=1),
            oh.make_node("Split", ["xy"], ["S1", "S2"], axis=1, num_outputs=2),
            oh.make_node("Concat", ["S2", "S1"], ["zs"], axis=1),
            oh.make_node("Relu", ["zs"], ["Z"]),
        ],
        "dummy",
        [
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["a", "b"]),
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, ["a", "c"]),
        ],
        [oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, ["a", "e"])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

feeds = dict(X=np.random.rand(3, 4).astype(np.float32), Y=np.random.rand(3, 6).astype(np.float32))
ref = ReferenceEvaluator(model)
expected = ref.run(None, feeds)

Classic Shape Inference

model2 = osh.infer_shapes(model)

for info in model2.graph.value_info:
    t = info.type.tensor_type
    shape = tuple(d.dim_param or d.dim_value for d in t.shape.dim)
    print(f"{info.name}: {t.elem_type}:{shape}")
xy: 1:('a', 'unk__0')
S1: 1:('a', 'unk__1')
S2: 1:('a', 'unk__2')
zs: 1:('a', 'unk__3')

Basic Shape Inference

The algorithm infer shapes wherever the output shape of a node does not depend on the content even. The evaluation relies on ast.

builder = BasicShapeBuilder()
builder.run_model(model)
builder.update_shapes(model)

for info in model.graph.value_info:
    t = info.type.tensor_type
    shape = tuple(d.dim_param or d.dim_value for d in t.shape.dim)
    print(f"{info.name}: {t.elem_type}:{shape}")
xy: 1:('a', 'b+c')
S1: 1:('a', 'CeilToInt(b+c,2)')
S2: 1:('a', 'b+c-CeilToInt(b+c,2)')
zs: 1:('a', 'b+c')

Evaluate Expressions

We can also evaluate every expression without evaluating the model itself.

dimensions = dict(a=3, b=4, c=6)
for name in ["X", "Y", "xy", "S1", "S2", "zs", "Z"]:
    sh = builder.evaluate_shape(name, dimensions)
    print(f"shape of {name!r} is {sh}")
shape of 'X' is (3, 4)
shape of 'Y' is (3, 6)
shape of 'xy' is (3, 10)
shape of 'S1' is (3, 5)
shape of 'S2' is (3, 5)
shape of 'zs' is (3, 10)
shape of 'Z' is (3, 10)

Total running time of the script: (0 minutes 0.012 seconds)

Related examples

Playground for big optimization pattern

Playground for big optimization pattern

101: Onnx Model Rewriting

101: Onnx Model Rewriting

101: Profile an existing model with onnxruntime

101: Profile an existing model with onnxruntime

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

Gallery generated by Sphinx-Gallery