Cannot export torch.sym_max(x.shape[0], y.shape[0])

This is related to the following issue: Cannot export torch.sym_max(x.shape[0], y.shape[0]).

The algorithm that automatically infers shapes after every operator in the exported program is quite aggressive. Here is a case where it makes a wrong decision, and how to work around it.

This bug was fixed after 4/24/2025.

Wrong Model

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])

Export

DYN = torch.export.Dim.DYNAMIC

ep = torch.export.export(
    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, Max(s16, s43)]"):
             #
            sym_size_int_10: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_11: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_12: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_13: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_14: "Sym(Max(s16, s43))" = torch.ops.aten.sym_size.int(fact, 1)
            sym_max_2: "Sym(Max(s16, s43))" = torch.sym_max(sym_size_int_11, sym_size_int_13)
            eq_8: "Sym(True)" = sym_max_2 == sym_size_int_14;  sym_size_int_14 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(Max(s16, s43), s23) on node 'eq_8'");  eq_8 = _assert_scalar_default = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:863 in sym_max, code: return a.__sym_max__(b)
            sym_max: "Sym(Max(s35, s58))" = torch.sym_max(sym_size_int_10, sym_size_int_12)

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:28 in forward, code: z = torch.zeros((s1, s2), dtype=x.dtype)
            zeros: "i64[Max(s35, s58), Max(s16, s43)]" = torch.ops.aten.zeros.default([sym_max, sym_max_2], dtype = torch.int64, device = device(type='cpu'), pin_memory = False);  sym_max = sym_max_2 = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:29 in forward, code: z[: x.shape[0], : x.shape[1]] = x
            slice_1: "i64[s35, Max(s16, s43)]" = torch.ops.aten.slice.Tensor(zeros, 0, 0, sym_size_int_10);  sym_size_int_10 = None
            slice_2: "i64[s35, s16]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, sym_size_int_11);  slice_1 = sym_size_int_11 = None
            copy_: "i64[s35, s16]" = torch.ops.aten.copy_.default(slice_2, x);  slice_2 = x = copy_ = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:30 in forward, code: z[: y.shape[0], : y.shape[1]] += y
            slice_3: "i64[s58, Max(s16, s43)]" = torch.ops.aten.slice.Tensor(zeros, 0, None, sym_size_int_12)
            slice_4: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(slice_3, 1, None, sym_size_int_13);  slice_3 = None
            add_: "i64[s58, s43]" = torch.ops.aten.add_.Tensor(slice_4, y);  slice_4 = y = None
            slice_5: "i64[s58, Max(s16, s43)]" = torch.ops.aten.slice.Tensor(zeros, 0, 0, sym_size_int_12);  sym_size_int_12 = None
            slice_6: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(slice_5, 1, 0, sym_size_int_13);  slice_5 = sym_size_int_13 = None
            copy__1: "i64[s58, s43]" = torch.ops.aten.copy_.default(slice_6, add_);  slice_6 = add_ = copy__1 = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:31 in forward, code: return z * fact
            mul: "i64[Max(s35, s58), Max(s16, s43)]" = torch.ops.aten.mul.Tensor(zeros, fact);  zeros = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo], Max(s16, s43): VR[2, int_oo]}

But does it really work? Let’s print the shapes.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)
case 1: torch.Size([3, 3]) torch.Size([3, 3])

Case with different shapes.

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
case 2: torch.Size([3, 3]) torch.Size([3, 3])

Before the fix mentioned above, this second case did not compute correctly: the exported program did not produce the right shape. With a recent torch, both shapes now match.

Rewritten Model

The builtin max does not get captured, torch.sym_max() is no better, and torch.max() only works on tensors. Nothing really works. We use a trick to introduce a new dimension the shape inference algorithm cannot know about. This requires hiding the failing logic in a custom operator.

def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    # Build a tensor with exactly i nonzero elements, then count them.
    # Eagerly this returns i, but during export the output size of
    # torch.nonzero is data-dependent, so the result becomes an
    # unbacked symbolic dimension.
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res
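
Eagerly, this function is the identity on non-negative integers: t holds i * 2 elements, i of which are set to zero, so torch.nonzero counts exactly i. The detour only matters while tracing, where the count becomes an unbacked symbolic dimension. A minimal sanity check:

# Outside of export, make_undefined_dimension returns its input unchanged.
for i in range(5):
    assert make_undefined_dimension(i) == i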


def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # The element-wise maximum of both shapes gives the output shape.
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z


def symbolic_shape(x, y):
    # Abstract (fake) implementation: only the output shape matters,
    # every dimension is hidden behind make_undefined_dimension.
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )
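
The abstract function can be tried eagerly as well; it returns an empty tensor whose shape combines both inputs. A small check (the values are irrelevant, only the shapes matter):

print(symbolic_shape(torch.zeros((2, 3)), torch.zeros((3, 2))).shape)  # torch.Size([3, 3])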


def register(fct, fct_shape, namespace, fname):
    # Infer the operator schema from the eager implementation,
    # register the CPU kernel, and attach the shape (abstract) function.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)
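
Once registered, the operator is reachable through torch.ops and can be checked eagerly; a minimal sketch:

a = torch.arange(6).reshape((2, 3))
b = torch.arange(6).reshape((3, 2))
print(torch.ops.mylib.copy_max_dimensions(a, b).shape)  # torch.Size([3, 3])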

Now everything is registered. Let’s rewrite the model.

class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact

And check it works.

rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])

Export again

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, s23]"):
             #
            sym_size_int_4: "Sym(s23)" = torch.ops.aten.sym_size.int(fact, 1)

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:134 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            copy_max_dimensions: "i64[u0, s23]" = torch.ops.mylib.copy_max_dimensions.default(x, y);  x = y = None

             #
            sym_size_int_5: "Sym(u0)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 0)
            sym_size_int_6: "Sym(s23)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 1)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_5);  sym_constrain_range_for_size_default = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:134 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge: "Sym(u0 >= 0)" = sym_size_int_5 >= 0;  sym_size_int_5 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             #
            sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_6);  sym_constrain_range_for_size_default_1 = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:134 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge_1: "Sym(s23 >= 2)" = sym_size_int_6 >= 2
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 2 on node 'ge_1'");  ge_1 = _assert_scalar_default_1 = None

             #
            eq: "Sym(True)" = sym_size_int_6 == sym_size_int_4;  sym_size_int_6 = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, s23) on node 'eq'");  eq = _assert_scalar_default_2 = None

             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:135 in forward, code: return z * fact
            mul: "i64[u0, s23]" = torch.ops.aten.mul.Tensor(copy_max_dimensions, fact);  copy_max_dimensions = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {u0: VR[0, 9223372036854775806], u1: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo], s23: VR[2, 9223372036854775806]}

We check it works.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
case 1: torch.Size([3, 3]) torch.Size([3, 3])
case 2: torch.Size([3, 3]) torch.Size([3, 3])

Final Check with very different dimensions

x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))

print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)
final case: torch.Size([10, 8]) torch.Size([10, 8])

This approach is not perfect: we do get an exported program, but part of the logic is hidden inside a custom operator.
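
For reference, the same registration can also be written with the decorator API of more recent versions of torch; a sketch, assuming torch.library.custom_op and register_fake are available (torch>=2.4):

@torch.library.custom_op("mylib::copy_max_dimensions_v2", mutates_args=())
def copy_max_dimensions_v2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same eager implementation as above.
    return copy_max_dimensions(x, y)


@copy_max_dimensions_v2.register_fake
def _(x, y):
    # Same shape function hiding the max behind an undefined dimension.
    return symbolic_shape(x, y)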

doc.plot_legend(
    "Fixed in torch==2.8\nmax(d1, d2)\nwith d1, d2\ndimensions", "dynamic shapes", "green"
)
[figure: plot dynamic shapes max]

Total running time of the script: (0 minutes 0.992 seconds)

Related examples

Do not use python int with dynamic shapes

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

Export a model with a control flow (If)