0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

torch.export.export() does not work if a tensor given to the function has 0 or 1 as the size of a dimension declared as dynamic. The exporter specializes dimensions whose example value is 0 or 1 (the so-called 0/1 specialization), and that specialization conflicts with the requested dynamic dimension.

Simple model, no dimension with 0 or 1

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        # Concatenate x and y along the second dimension, then add z.
        return torch.cat((x, y), dim=1) + z


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 8)
model(x, y, z)

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
graph():
    %x : [num_users=3] = placeholder[target=x]
    %y : [num_users=3] = placeholder[target=y]
    %z : [num_users=3] = placeholder[target=z]
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
    %eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)
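
Every dimension appears as a symbolic size and the relationships between them are recorded as runtime assertions: equal first dimensions, and x.shape[1] + y.shape[1] == z.shape[1]. As a quick sanity check (a sketch that is not part of the original example), the exported program should accept inputs of other sizes as long as those assertions hold.

m = ep.module()
print(m(torch.randn(4, 2), torch.randn(4, 7), torch.randn(4, 9)).shape)
# expected: torch.Size([4, 9])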

Same model, a dynamic dimension = 1

z = z[:1]  # z now has shape (1, 8): its first dimension is 1

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep.graph)
except Exception as e:
    print("ERROR", e)
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].

It failed. Let’s try a little trick.
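
Before that, a simpler alternative (a sketch, not part of the original example) is to not declare the size-1 dimension as dynamic at all. The export is then expected to go through, but z's first dimension stays fixed at 1 in the exported program.

# Hypothetical variant: only the second dimension of z is declared dynamic.
ds_z = {1: DYN}
try:
    ep_static = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds_z))
    print(ep_static.graph)
except Exception as e:
    print("ERROR", e)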

Same model, a dynamic dimension = 1 and backed_size_oblivious=True

# backed_size_oblivious=True asks the symbolic shape engine not to
# specialize on the 0/1 values found in the example inputs.
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep.graph)
graph():
    %x : [num_users=3] = placeholder[target=x]
    %y : [num_users=3] = placeholder[target=y]
    %z : [num_users=3] = placeholder[target=z]
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
    %eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)

It worked: the export succeeds even though z's first dimension is 1 in the example inputs, and the resulting graph is the same as in the first export.
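
As a quick sanity check (again a sketch, not part of the original example), the exported program should run on any inputs satisfying the recorded assertions, including inputs whose first dimension is 1.

m = ep.module()
print(m(torch.randn(3, 3), torch.randn(3, 5), torch.randn(3, 8)).shape)
# expected: torch.Size([3, 8])
print(m(torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 8)).shape)
# expected: torch.Size([1, 8])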

doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Related examples

Do not use python int with dynamic shapes

Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints

Cannot export torch.sym_max(x.shape[0], y.shape[0])
