Use DYNAMIC or AUTO when dynamic shapes have constraints

Setting the dynamic shapes is not always easy. Here are a few tricks to make it work.

dx + dy not allowed?

import torch


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        # z[:, ::2] keeps every other column, so the addition implicitly
        # requires ceil(dim1(z) / 2) == dim1(x) + dim1(y).
        return torch.cat((x, y), axis=1) + z[:, ::2]


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 16)  # 16 == 2 * (3 + 5)
model(x, y, z)


print(torch.export.export(model, (x, y, z)).graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%z,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, None, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %slice_2), kwargs = {})
    return (add,)

Everything is fine so far. Now let's make the shapes dynamic. dx + dy is not allowed…

batch = torch.export.Dim("batch")
dx = torch.export.Dim("dx")
dy = torch.export.Dim("dy")

try:
    dz = dx + dy
    raise AssertionError("able to add dynamic dimensions, please update the tutorial")
except NotImplementedError as e:
    print(f"unable to add dynamic dimensions because {type(e)}, {e}")
unable to add dynamic dimensions because <class 'NotImplementedError'>, Attempted to add Dim('dy', min=0) to dx, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)
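
The error message states the rule: only increasing linear operations with integer coefficients, applied to a single Dim, are supported. A minimal sketch of expressions that are accepted (these derived dimensions are illustrative and not used by the model):

d1 = dx * 2  # integer coefficient on a single Dim
d2 = dx + 1  # integer offset on a single Dim

# dx + dy mixes two symbolic dimensions and is rejected, as shown above.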

The two symbolic dimensions cannot be added, but we can introduce a separate one for z. Since z[:, ::2] keeps every other column, we declare z's second dimension as an even number, twice a new dimension dz.

dz = torch.export.Dim("dz") * 2  # derived dimension: z's width must be even
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes={
        "x": {0: batch, 1: dx},
        "y": {0: batch, 1: dy},
        "z": {0: batch, 1: dz},
    },
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[s35, s43]", z: "f32[s35, 2*s55]"):
             # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s35, s16 + s43]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
            slice_1: "f32[s35, 2*s55]" = torch.ops.aten.slice.Tensor(z);  z = None
            slice_2: "f32[s35, s55]" = torch.ops.aten.slice.Tensor(slice_1, 1, None, None, 2);  slice_1 = None
            add: "f32[s35, s16 + s43]" = torch.ops.aten.add.Tensor(cat, slice_2);  cat = slice_2 = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[0, int_oo], s16: VR[0, int_oo], s43: VR[0, int_oo], 2*s55: VR[0, int_oo], s55: VR[2, int_oo]}
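
Before moving on, the exported program can be sanity-checked with inputs of different sizes, as long as they respect the implied constraint: z must have twice as many columns as x and y combined. A minimal sketch, with arbitrary but consistent shapes:

# batch=3, dx=4, dy=6, so z needs 2 * (4 + 6) = 20 columns
got = ep.module()(torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 20))
print(got.shape)  # torch.Size([3, 10])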

That works. We could also use torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO for the dimensions we cannot set explicitly and let the exporter discover the constraints.

DYNAMIC = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes={
        "x": {0: DYNAMIC, 1: dx},
        "y": {0: DYNAMIC, 1: dy},
        "z": {0: DYNAMIC, 1: DYNAMIC},
    },
)

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[s35, s43]", z: "f32[s35, s23]"):
             #
            sym_size_int_1: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_2: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_3: "Sym(s23)" = torch.ops.aten.sym_size.int(z, 1)

             # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s35, s16 + s43]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None

             #
            add_1: "Sym(s16 + s43)" = sym_size_int_1 + sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
            add_2: "Sym(s23 + 1)" = 1 + sym_size_int_3;  sym_size_int_3 = None
            floordiv: "Sym(((s23 + 1)//2))" = add_2 // 2;  add_2 = None
            eq: "Sym(Eq(s16 + s43, ((s23 + 1)//2)))" = add_1 == floordiv;  add_1 = floordiv = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16 + s43, ((s23 + 1)//2)) on node 'eq'");  eq = _assert_scalar_default = None

             # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s35, s23]" = torch.ops.aten.slice.Tensor(z);  z = None
            slice_2: "f32[s35, ((s23 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, None, None, 2);  slice_1 = None
            add: "f32[s35, s16 + s43]" = torch.ops.aten.add.Tensor(cat, slice_2);  cat = slice_2 = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[0, int_oo], s43: VR[0, int_oo], s23: VR[2, int_oo]}
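
Notice the runtime assertion Eq(s16 + s43, ((s23 + 1)//2)): instead of requiring an even width for z, the exporter inferred the relationship between the dimensions and checks it when the program runs. One consequence, sketched below with illustrative shapes, is that both an even and an odd width for z pass the assertion, since (20 + 1)//2 == (19 + 1)//2 == 10:

m = ep.module()
print(m(torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 20)).shape)  # torch.Size([3, 10])
print(m(torch.randn(3, 4), torch.randn(3, 6), torch.randn(3, 19)).shape)  # torch.Size([3, 10])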

The same graph can be obtained with torch.export.Dim.AUTO for every dimension.

AUTO = torch.export.Dim.AUTO
print(
    torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
    )
)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[s35, s43]", z: "f32[s35, s23]"):
             #
            sym_size_int_1: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_2: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_3: "Sym(s23)" = torch.ops.aten.sym_size.int(z, 1)

             # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s35, s16 + s43]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None

             #
            add_1: "Sym(s16 + s43)" = sym_size_int_1 + sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
            add_2: "Sym(s23 + 1)" = 1 + sym_size_int_3;  sym_size_int_3 = None
            floordiv: "Sym(((s23 + 1)//2))" = add_2 // 2;  add_2 = None
            eq: "Sym(Eq(s16 + s43, ((s23 + 1)//2)))" = add_1 == floordiv;  add_1 = floordiv = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s16 + s43, ((s23 + 1)//2)) on node 'eq'");  eq = _assert_scalar_default = None

             # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s35, s23]" = torch.ops.aten.slice.Tensor(z);  z = None
            slice_2: "f32[s35, ((s23 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, None, None, 2);  slice_1 = None
            add: "f32[s35, s16 + s43]" = torch.ops.aten.add.Tensor(cat, slice_2);  cat = slice_2 = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo], s43: VR[2, int_oo], s23: VR[2, int_oo]}
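
One difference shows up in the range constraints: every dimension marked AUTO (or DYNAMIC) gets a lower bound of 2, consistent with torch.export specializing sizes 0 and 1, whereas the explicit Dim objects above kept VR[0, int_oo]. A minimal sketch of the consequence, assuming the exported module enforces its range constraints at call time:

ep_auto = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
)
try:
    # batch=1 falls outside VR[2, int_oo] for the first dimension
    ep_auto.module()(torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 16))
except Exception as e:
    print("batch=1 is rejected:", e)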

Total running time of the script: (0 minutes 3.635 seconds)

Related examples

to_onnx: Rename Dynamic Shapes

torch.onnx.export: Rename Dynamic Shapes

Infer dynamic shapes before exporting

torch.onnx.export and a custom operator registered with a function

to_onnx and Phi-2
