Note
Go to the end to download the full example code.
Use DYNAMIC or AUTO when exporting if dynamic shapes have constraints
Setting the dynamic shapes is not always easy. Here are a few tricks to make it work.
dx + dy not allowed?
import torch
from onnx_diagnostic import doc
class Model(torch.nn.Module):
    """Toy model mixing a concatenation and a strided slice.

    The addition requires the width of ``torch.cat((x, y), dim=1)``
    (``x.shape[1] + y.shape[1]``) to be broadcast-compatible with the width
    of ``z[:, ::2]`` (``(z.shape[1] + 1) // 2``). That relation between the
    three inputs is what makes the dynamic shapes of ``z`` hard to express
    with plain ``torch.export.Dim`` objects.
    """

    def forward(self, x, y, z):
        # ``dim`` is the canonical keyword of torch.cat; ``axis`` is only a
        # NumPy-compatibility alias.
        return torch.cat((x, y), dim=1) + z[:, ::2]
model = Model()
# Inputs: concatenating x and y gives width 3 + 5 = 8, which matches the
# width of z[:, ::2] (16 columns taken every other one).
x, y, z = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 16)
# Run eagerly first so any shape error shows up before exporting.
model(x, y, z)
# Export with static shapes and display the resulting FX graph.
static_ep = torch.export.export(model, (x, y, z))
print(static_ep.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%z,), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, None, None, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %slice_2), kwargs = {})
    return (add,)
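Everything is fine so far. Now let's try dynamic shapes. Expressing the second dimension of z in terms of dx and dy would require an expression such as dx + dy, which is not allowed…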
batch = torch.export.Dim("batch")
dx = torch.export.Dim("dx")
dy = torch.export.Dim("dy")
try:
    dz = dx + dy
    raise AssertionError("able to add dynamic dimensions, please update the tutorial")
except NotImplementedError as e:
    print(f"unable to add dynamic dimensions because {type(e)}, {e}")
unable to add dynamic dimensions because <class 'NotImplementedError'>, Attempted to add Dim('dy', min=0) to dx, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)
Then we could give z its own, independent dimension dz.
# Give z a free dimension unrelated to dx and dy. Export is expected to fail:
# the exporter generates a guard on dz that not every value of a free
# dimension satisfies.
dz = torch.export.Dim("dz")
shapes_with_dz = {
    "x": {0: batch, 1: dx},
    "y": {0: batch, 1: dy},
    "z": {0: batch, 1: dz},
}
try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=shapes_with_dz)
except torch._dynamo.exc.UserError as err:
    print(f"unable to use Dim('dz') because {type(err)}, {err}")
else:
    print(ep)
    raise AssertionError("able to export this model, please update the tutorial")
unable to use Dim('dz') because <class 'torch._dynamo.exc.UserError'>, Constraints violated (dz)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of dz = L['z'].size()[1] in the specified range satisfy the generated guard ((1 + L['z'].size()[1]) // 2) != 1.
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
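Export fails: the exporter generates a guard on dz that a free Dim('dz') does not satisfy for all values. Instead, we can use
torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO
for the dimensions we cannot express. The exporter then records the relations between them as runtime assertions.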
DYNAMIC = torch.export.Dim.DYNAMIC
# Keep the named dimensions dx and dy. Mark the batch dimension and the width
# of z as DYNAMIC, so the exporter works out how they relate to the others.
dynamic_shapes = {
    "x": {0: DYNAMIC, 1: dx},
    "y": {0: DYNAMIC, 1: dy},
    "z": {0: DYNAMIC, 1: DYNAMIC},
}
ep = torch.export.export(model, (x, y, z), dynamic_shapes=dynamic_shapes)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[s17, s32]"):
             #
            sym_size_int: "Sym(s17)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_1: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_2: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_3: "Sym(s94)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(z, 0)
            sym_size_int_5: "Sym(s32)" = torch.ops.aten.sym_size.int(z, 1)
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_with_dynamic.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
             #
            eq: "Sym(True)" = sym_size_int_2 == sym_size_int;  sym_size_int = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s17, s77) on node 'eq'");  eq = _assert_scalar_default = None
            add_1: "Sym(s27 + s94)" = sym_size_int_1 + sym_size_int_3;  sym_size_int_1 = sym_size_int_3 = None
            add_2: "Sym(s32 + 1)" = 1 + sym_size_int_5;  sym_size_int_5 = None
            floordiv: "Sym(((s32 + 1)//2))" = add_2 // 2;  add_2 = None
            eq_1: "Sym(Eq(s27 + s94, ((s32 + 1)//2)))" = add_1 == floordiv;  add_1 = floordiv = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s27 + s94, ((s32 + 1)//2)) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_2 == sym_size_int_4;  sym_size_int_2 = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s17, s68) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_with_dynamic.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s17, s32]" = torch.ops.aten.slice.Tensor(z);  z = None
            slice_2: "f32[s17, ((s32 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, None, None, 2);  slice_1 = None
            add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, slice_2);  cat = slice_2 = None
            return (add,)
Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT
    # outputs
    add: USER_OUTPUT
Range constraints: {s17: VR[2, int_oo], s27: VR[0, int_oo], s94: VR[0, int_oo], s32: VR[2, int_oo]}
Nearly the same result can be obtained by using torch.export.Dim.AUTO for every dimension; only the inferred range constraints differ.
AUTO = torch.export.Dim.AUTO
# Every dimension of every positional input is AUTO, so the exporter infers
# the dynamic dimensions itself. Each input gets its own fresh dict.
auto_shapes = tuple({0: AUTO, 1: AUTO} for _ in (x, y, z))
ep = torch.export.export(model, (x, y, z), dynamic_shapes=auto_shapes)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[s17, s32]"):
             #
            sym_size_int: "Sym(s17)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_1: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_2: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_3: "Sym(s94)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(z, 0)
            sym_size_int_5: "Sym(s32)" = torch.ops.aten.sym_size.int(z, 1)
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_with_dynamic.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
             #
            eq: "Sym(True)" = sym_size_int_2 == sym_size_int;  sym_size_int = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s17, s77) on node 'eq'");  eq = _assert_scalar_default = None
            add_1: "Sym(s27 + s94)" = sym_size_int_1 + sym_size_int_3;  sym_size_int_1 = sym_size_int_3 = None
            add_2: "Sym(s32 + 1)" = 1 + sym_size_int_5;  sym_size_int_5 = None
            floordiv: "Sym(((s32 + 1)//2))" = add_2 // 2;  add_2 = None
            eq_1: "Sym(Eq(s27 + s94, ((s32 + 1)//2)))" = add_1 == floordiv;  add_1 = floordiv = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s27 + s94, ((s32 + 1)//2)) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_2 == sym_size_int_4;  sym_size_int_2 = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s17, s68) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_with_dynamic.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s17, s32]" = torch.ops.aten.slice.Tensor(z);  z = None
            slice_2: "f32[s17, ((s32 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, None, None, 2);  slice_1 = None
            add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, slice_2);  cat = slice_2 = None
            return (add,)
Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT
    # outputs
    add: USER_OUTPUT
Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo], s94: VR[2, int_oo], s32: VR[2, int_oo]}
# Draw the recipe's legend with the onnx_diagnostic documentation helper
# (argument meaning presumably label, title, colour; check doc.plot_legend).
legend_label = "torch.export.Dim\nor DYNAMIC\nor AUTO"
doc.plot_legend(legend_label, "torch.export.export", "green")

Total running time of the script: (0 minutes 0.587 seconds)
Related examples
 
0, 1, 2 for a Dynamic Dimension in the dummy example to export a model
 
Cannot export torch.sym_max(x.shape[0], y.shape[0])
