0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

torch.export.export() fails if an example tensor passed to it has size 0 or 1 along a dimension declared as dynamic.
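The failure only shows up at export time, so it can help to check the example inputs beforehand. Below is a minimal sketch of a hypothetical helper, `find_specializable_dims`, which is not part of torch or onnx_diagnostic. It lists every dimension declared dynamic whose example size is 0 or 1. It assumes that dynamic_shapes is a sequence of dicts, one per positional input, as in this example. It works on plain shape tuples, and `torch.Size` is a tuple subclass, so `x.shape` can be passed directly.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def find_specializable_dims(
    shapes: Sequence[Sequence[int]],
    dynamic_shapes: Sequence[Optional[Dict[int, object]]],
) -> List[Tuple[int, int, int]]:
    """Return (input index, axis, example size) for every dynamic axis of size 0 or 1."""
    risky = []
    for i, (shape, dims) in enumerate(zip(shapes, dynamic_shapes)):
        # An input without dynamic dimensions may be described by None.
        for axis in dims or {}:
            size = shape[axis]
            if size in (0, 1):
                risky.append((i, axis, size))
    return risky


# Shapes used later in this example: z has a first dimension of size 1.
print(find_specializable_dims([(2, 3), (2, 5), (1, 8)], [{0: "DYN", 1: "DYN"}] * 3))
# [(2, 0, 1)]
```

With the tensors of this script, `find_specializable_dims([t.shape for t in (x, y, z)], (ds, ds, ds))` flags the first dimension of z once `z = z[:1]` is applied.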

Simple model, no dimension with 0 or 1

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 8)
model(x, y, z)

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
graph():
    %x : [num_users=3] = placeholder[target=x]
    %y : [num_users=3] = placeholder[target=y]
    %z : [num_users=3] = placeholder[target=z]
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
    %eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)

Same model, a dynamic dimension = 1

z = z[:1]

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep.graph)
except Exception as e:
    print("ERROR", e)
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].

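Export fails. The example value of z.shape[0] is 1, so the exporter specializes that dimension to the constant 1, which conflicts with the requested Dim.DYNAMIC. Let’s try a little trick.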
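In this particular model, size 1 also has a meaning of its own because of broadcasting. With a single row, `+ z` adds that row to every row of the concatenation. Any other number of rows must match exactly. The following short eager-mode sketch illustrates the three cases. It only shows PyTorch broadcasting rules and does not claim to describe how the exporter decides to specialize.

```python
import torch

x, y = torch.randn(2, 3), torch.randn(2, 5)
xy = torch.cat((x, y), dim=1)  # shape (2, 8)

# z with a single row: that row is broadcast (added) to every row of xy.
z1 = torch.randn(1, 8)
print((xy + z1).shape)  # torch.Size([2, 8])

# z with as many rows as xy: plain element-wise addition.
z2 = torch.randn(2, 8)
print((xy + z2).shape)  # torch.Size([2, 8])

# Any other number of rows is rejected by the broadcasting rules.
failed = False
try:
    xy + torch.randn(3, 8)
except RuntimeError:
    failed = True
print(failed)  # True
```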

Same model, a dynamic dimension = 1 and backed_size_oblivious=True

with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep.graph)
graph():
    %x : [num_users=3] = placeholder[target=x]
    %y : [num_users=3] = placeholder[target=y]
    %z : [num_users=3] = placeholder[target=z]
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
    %sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
    %sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
    %sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
    %eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)

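The export succeeds, and the graph is identical to the first one: z.shape[0] was not specialized to 1 and remains the symbol s7. The runtime assertion Eq(s35, s7) means the exported program expects x.shape[0] == z.shape[0]. The size-1 dimension is handled as a generic dynamic dimension, not as a broadcast.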

doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Total running time of the script: (0 minutes 0.391 seconds)

Related examples

Export a model with a control flow (If)

Cannot export torch.sym_max(x.shape[0], y.shape[0])

Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints