Note
Go to the end to download the full example code.
0, 1, 2 for a Dynamic Dimension in the Dummy Example to Export a Model
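torch.export.export()
does not work, by default, if a tensor given to the function
has size 0 or 1 along a dimension declared as dynamic: the exporter
specializes such a dimension to a constant (see the error below).
Dummy inputs should therefore use sizes of at least 2 along dynamic dimensions.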
Simple model, no dimension equal to 0 or 1
import torch
from onnx_diagnostic import doc
class Model(torch.nn.Module):
    """Toy model: concatenate ``x`` and ``y`` along dimension 1, then add ``z``.

    With the dummy inputs used in this example, ``x``, ``y`` and ``z`` share
    dimension 0 and ``z.shape[1] == x.shape[1] + y.shape[1]``; the exporter
    records these equalities as runtime assertions in the exported graph.
    """

    def forward(self, x, y, z):
        # ``dim`` is the documented keyword of torch.cat; ``axis`` is only a
        # NumPy-compatibility alias accepted by the argument parser.
        return torch.cat((x, y), dim=1) + z
# Build the model and dummy inputs whose dimensions are all >= 2, so that
# no dimension declared as dynamic gets specialized during export.
model = Model()
x, y, z = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 8)
# Run the model once eagerly as a sanity check.
model(x, y, z)

# Declare both dimensions of every input as dynamic and export.
DYN = torch.export.Dim.DYNAMIC
ds = dict.fromkeys((0, 1), DYN)
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds,) * 3)
print(ep.graph)
graph():
%x : [num_users=3] = placeholder[target=x]
%y : [num_users=3] = placeholder[target=y]
%z : [num_users=3] = placeholder[target=z]
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
%sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
%sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
%sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
%sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
%sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
%eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
%eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
return (add,)
Same model, a dynamic dimension equal to 1
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].
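It failed: the exporter specialized inputs['z'].shape[0] to 1 because the dummy input has size 1 along that dynamic dimension. Let’s try a little trick.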
Same model, a dynamic dimension equal to 1 and backed_size_oblivious=True
graph():
%x : [num_users=3] = placeholder[target=x]
%y : [num_users=3] = placeholder[target=y]
%z : [num_users=3] = placeholder[target=z]
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 0), kwargs = {})
%sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%x, 1), kwargs = {})
%sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
%sym_size_int_3 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 1), kwargs = {})
%sym_size_int_4 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 0), kwargs = {})
%sym_size_int_5 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%z, 1), kwargs = {})
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
%eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_2, %sym_size_int), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq, Runtime assertion failed for expression Eq(s58, s35) on node 'eq'), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%sym_size_int_1, %sym_size_int_3), kwargs = {})
%eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%add_1, %sym_size_int_5), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_1, Runtime assertion failed for expression Eq(s16 + s43, s23) on node 'eq_1'), kwargs = {})
%eq_2 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, %sym_size_int_4), kwargs = {})
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%eq_2, Runtime assertion failed for expression Eq(s35, s7) on node 'eq_2'), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
return (add,)
It worked: the exported graph is the same as in the first case, with every dimension kept dynamic.
# Render this gallery example's legend through onnx_diagnostic.doc.plot_legend;
# presumably the arguments are (text, title, color) — confirm against doc.plot_legend.
doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Total running time of the script: (0 minutes 0.391 seconds)
Related examples

Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints

Cannot export torch.sym_max(x.shape[0], y.shape[0])
Cannot export torch.sym_max(x.shape[0], y.shape[0])