Note
Go to the end to download the full example code.
Use DYNAMIC or AUTO when exporting if dynamic shapes have constraints
Setting the dynamic shapes is not always easy. Here are a few tricks to make it work.
dx + dy not allowed?
import torch
from onnx_diagnostic import doc
class Model(torch.nn.Module):
    """Toy model whose output shape ties the widths of its three inputs together.

    ``forward`` concatenates ``x`` and ``y`` along dimension 1 and adds every
    second column of ``z``, so the call only succeeds when
    ``x.shape[1] + y.shape[1] == len(range(0, z.shape[1], 2))``.
    This cross-input relation is what makes the dynamic shapes hard to declare.
    """

    def forward(self, x, y, z):
        # ``dim`` is the documented keyword of torch.cat (``axis`` is only a
        # numpy-compatibility alias).
        return torch.cat((x, y), dim=1) + z[:, ::2]
model = Model()
# Sample inputs: x and y contribute 3 + 5 columns, which matches the
# 8 columns selected by z[:, ::2] on a 16-column z.
x, y, z = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 16)
# Eager run first, to check the inputs are compatible before exporting.
model(x, y, z)
print(torch.export.export(model, (x, y, z)).graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%z : [num_users=1] = placeholder[target=z]
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%z, 0, 0, 9223372036854775807), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 9223372036854775807, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %slice_2), kwargs = {})
return (add,)
Everything is fine so far. Now let us add dynamic shapes. The first obstacle is that dx + dy is not allowed…
batch = torch.export.Dim("batch")
dx = torch.export.Dim("dx")
dy = torch.export.Dim("dy")
try:
dz = dx + dy
raise AssertionError("able to add dynamic dimensions, please update the tutorial")
except NotImplementedError as e:
print(f"unable to add dynamic dimensions because {type(e)}, {e}")
unable to add dynamic dimensions because <class 'NotImplementedError'>, Attempted to add <class '__main__.dy'> to dx, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)
Then we could declare the width of z as a separate dimension instead, defined as twice a new dimension dz.
dz = torch.export.Dim("dz") * 2
# Attempt the export with three independent width dimensions. The tutorial
# expects a UserError here; if export succeeds, the text is out of date.
try:
    ep = torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes=dict(
            x={0: batch, 1: dx},
            y={0: batch, 1: dy},
            z={0: batch, 1: dz},
        ),
    )
except torch._dynamo.exc.UserError as err:
    print(f"unable to use Dim('dz') because {type(err)}, {err}")
else:
    print(ep)
    raise AssertionError("able to export this model, please update the tutorial")
unable to use Dim('dz') because <class 'torch._dynamo.exc.UserError'>, Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][0][0].size()[0] in the specified range satisfy the generated guard L['args'][0][0].size()[0] != 9223372036854775807.
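That does not work either: export reports a violated constraint. Instead, we can use torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO for the dimensions we cannot express, and let export infer the relation between them.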
DYNAMIC = torch.export.Dim.DYNAMIC
# The batch axis and z's width are marked DYNAMIC.
# The widths of x and y keep their named dimensions dx and dy.
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=dict(
        x={0: DYNAMIC, 1: dx},
        y={0: DYNAMIC, 1: dy},
        z={0: DYNAMIC, 1: DYNAMIC},
    ),
)
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
#
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
sym_size_int_4: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
sym_size_int_5: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
#
add_1: "Sym(s1 + s3)" = sym_size_int_3 + sym_size_int_4; sym_size_int_3 = sym_size_int_4 = None
add_2: "Sym(s5 + 1)" = 1 + sym_size_int_5; sym_size_int_5 = None
floordiv: "Sym(((s5 + 1)//2))" = add_2 // 2; add_2 = None
eq_2: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_1 == floordiv; add_1 = floordiv = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_2'"); eq_2 = _assert_scalar_default = None
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
add: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[0, int_oo], s3: VR[0, int_oo], s5: VR[2, int_oo]}
The same graph can be obtained with torch.export.Dim.AUTO for every dimension. Only the inferred ranges differ: s1 and s3 now start at 2 instead of 0.
AUTO = torch.export.Dim.AUTO
# Every dimension of every input is marked AUTO so export decides which
# are dynamic and how they relate. The spec is keyed by forward() argument
# name rather than by position, so each entry is explicitly tied to its input.
print(
    torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes={
            "x": {0: AUTO, 1: AUTO},
            "y": {0: AUTO, 1: AUTO},
            "z": {0: AUTO, 1: AUTO},
        },
    )
)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
#
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
sym_size_int_4: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
sym_size_int_5: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
#
add_1: "Sym(s1 + s3)" = sym_size_int_3 + sym_size_int_4; sym_size_int_3 = sym_size_int_4 = None
add_2: "Sym(s5 + 1)" = 1 + sym_size_int_5; sym_size_int_5 = None
floordiv: "Sym(((s5 + 1)//2))" = add_2 // 2; add_2 = None
eq_2: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_1 == floordiv; add_1 = floordiv = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_2'"); eq_2 = _assert_scalar_default = None
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:20 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
add: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s3: VR[2, int_oo], s5: VR[2, int_oo]}
# Helper from onnx_diagnostic.doc. It presumably draws the legend/thumbnail
# image shown for this gallery example (TODO confirm against the helper).
doc.plot_legend("dynamic shapes\ninferred", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 0.378 seconds)
Related examples