Cannot export torch.sym_max(x.shape[0], y.shape[0])

This example is related to the following issue: Cannot export torch.sym_max(x.shape[0], y.shape[0]).

The algorithm that automatically infers shapes after every operator in the exported program is quite aggressive. Here is a case where it makes a wrong decision, and how to work around it.

Wrong Model

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])

Export

DYN = torch.export.Dim.DYNAMIC

ep = torch.export.export(
    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, s16]"):
             #
            sym_size_int_8: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_9: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_10: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_11: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:26 in forward, code: z = torch.zeros((s1, s2), dtype=x.dtype)
            zeros: "i64[s58, s16]" = torch.ops.aten.zeros.default([sym_size_int_10, sym_size_int_9], dtype = torch.int64, device = device(type='cpu'), pin_memory = False);  sym_size_int_9 = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:27 in forward, code: z[: x.shape[0], : x.shape[1]] = x
            slice_1: "i64[s35, s16]" = torch.ops.aten.slice.Tensor(zeros, 0, 0, sym_size_int_8);  sym_size_int_8 = None
            copy_: "i64[s35, s16]" = torch.ops.aten.copy_.default(slice_1, x);  slice_1 = x = copy_ = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:28 in forward, code: z[: y.shape[0], : y.shape[1]] += y
            slice_2: "i64[s58, s16]" = torch.ops.aten.slice.Tensor(zeros, 0, None, sym_size_int_10);  sym_size_int_10 = None
            slice_3: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(slice_2, 1, None, sym_size_int_11);  slice_2 = None
            add_: "i64[s58, s43]" = torch.ops.aten.add_.Tensor(slice_3, y);  slice_3 = y = None
            slice_4: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(zeros, 1, 0, sym_size_int_11);  sym_size_int_11 = None
            copy__1: "i64[s58, s43]" = torch.ops.aten.copy_.default(slice_4, add_);  slice_4 = add_ = copy__1 = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:29 in forward, code: return z * fact
            mul: "i64[s58, s16]" = torch.ops.aten.mul.Tensor(zeros, fact);  zeros = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo]}

Note in the graph above how zeros is typed i64[s58, s16]: shape inference picked y's first dimension and x's second one instead of keeping the maximum of both. But does it really work? Let's print the shapes.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)
case 1: torch.Size([3, 3]) torch.Size([3, 3])

Case with different shapes.

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
case 2 failed: Expected input at *args[2].shape[1] to be equal to 2, but got 3

It does not even run: because zeros was specialized to (s58, s16), fact's last dimension was unified with x's, and the guard fails as soon as they differ. Even if it ran, the output shape would be wrong.

Rewritten Model

The built-in max does not get captured, torch.sym_max() is no better, and torch.max() only works on tensors. Nothing really works. The trick is to introduce a new dimension the shape-inference algorithm cannot reason about, which requires hiding the failing logic inside a custom operator.
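
For reference, here is what the torch.sym_max variant mentioned in the title looks like (a minimal sketch; as noted above, it does not export any better than the built-in max):

class SymMaxModel(torch.nn.Module):
    def forward(self, x, y, fact):
        # torch.sym_max is the symbolic counterpart of the built-in max,
        # but the exporter still specializes the resulting dimensions.
        s1 = torch.sym_max(x.shape[0], y.shape[0])
        s2 = torch.sym_max(x.shape[1], y.shape[1])
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact

The workaround itself follows.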

def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    # Build a tensor with exactly ``i`` zeros followed by ``i`` ones, then
    # count the nonzero entries: the result is ``i`` again, but computed
    # from the tensor's content, so at trace time it becomes an unbacked,
    # data-dependent dimension the exporter cannot specialize.
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res
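
# Eager sanity check (illustrative): the function is the identity on
# integers; the torch.nonzero call only matters at trace time.
assert make_undefined_dimension(4) == 4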


def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Elementwise maximum of the two shapes: the output is large enough
    # to hold both inputs.
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z
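
# Eager sanity check of the kernel before registration (illustrative).
assert copy_max_dimensions(
    torch.arange(6).reshape((2, 3)), torch.arange(6).reshape((3, 2))
).shape == (3, 3)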


def symbolic_shape(x, y):
    # Abstract (fake) implementation: wrap every output dimension in
    # make_undefined_dimension so the exporter keeps it symbolic.
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )


def register(fct, fct_shape, namespace, fname):
    # Infer the operator schema from the Python signature, create the
    # custom operator, register the eager kernel for CPU, and attach
    # the abstract function used for shape inference.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)
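
On recent PyTorch versions, the same registration can also be written with the public torch.library.custom_op decorator (a sketch under that assumption; the name mylib::copy_max_dimensions2 is made up to avoid clashing with the operator registered above):

@torch.library.custom_op("mylib::copy_max_dimensions2", mutates_args=())
def copy_max_dimensions2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Eager kernel: delegate to the implementation above.
    return copy_max_dimensions(x, y)

@copy_max_dimensions2.register_fake
def _(x, y):
    # Abstract implementation used during export.
    return symbolic_shape(x, y)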

Now everything is registered. Let’s rewrite the model.

class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact

And check that it works.

rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])

Export again

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, s23]"):
             #
            sym_size_int_4: "Sym(s23)" = torch.ops.aten.sym_size.int(fact, 1)

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            copy_max_dimensions: "i64[u0, s23]" = torch.ops.mylib.copy_max_dimensions.default(x, y);  x = y = None

             #
            sym_size_int_5: "Sym(u0)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 0)
            sym_size_int_6: "Sym(s23)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 1)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_5);  sym_constrain_range_for_size_default = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge: "Sym(u0 >= 0)" = sym_size_int_5 >= 0;  sym_size_int_5 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

             #
            sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_6);  sym_constrain_range_for_size_default_1 = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge_1: "Sym(s23 >= 2)" = sym_size_int_6 >= 2
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 2 on node 'ge_1'");  ge_1 = _assert_scalar_default_1 = None

             #
            eq: "Sym(True)" = sym_size_int_6 == sym_size_int_4;  sym_size_int_6 = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, s23) on node 'eq'");  eq = _assert_scalar_default_2 = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:133 in forward, code: return z * fact
            mul: "i64[u0, s23]" = torch.ops.aten.mul.Tensor(copy_max_dimensions, fact);  copy_max_dimensions = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {u0: VR[0, 9223372036854775806], u1: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo], s23: VR[2, 9223372036854775806]}

Note how the first output dimension is now the unbacked symbol u0, guarded by runtime assertions rather than tied to x or y. We check that it works.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
case 1: torch.Size([3, 3]) torch.Size([3, 3])
case 2: torch.Size([3, 3]) torch.Size([3, 3])

Final check with very different dimensions

x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))

print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)
final case: torch.Size([10, 8]) torch.Size([10, 8])

This is not perfect: we do get an exported program, but part of the logic is now hidden inside a custom operator.

doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")

Total running time of the script: (0 minutes 0.261 seconds)

Related examples

Do not use python int with dynamic shapes

Half certain nonzero