Use DYNAMIC or AUTO when exporting if dynamic shapes have constraints
Specifying dynamic shapes is not always easy. Here are a few tricks to make it work.
dx + dy not allowed?
import torch


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z[:, ::2]


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 16)
model(x, y, z)
print(torch.export.export(model, (x, y, z)).graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%z, 0, 0, 9223372036854775807), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 9223372036854775807, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %slice_2), kwargs = {})
    return (add,)
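The graph makes the model's shape constraint explicit: torch.cat produces a width of x.shape[1] + y.shape[1], while z[:, ::2] has width (z.shape[1] + 1) // 2, and the addition requires the two to match. A quick check on the sample inputs (a small sketch added here for illustration):

# The addition only works if the concatenated width equals the
# width of the strided slice: 3 + 5 == (16 + 1) // 2 == 8.
assert x.shape[1] + y.shape[1] == (z.shape[1] + 1) // 2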
Everything is fine so far. Now let us make the shapes dynamic: dx + dy is not allowed…
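A minimal sketch of that attempt, reconstructed to match the error printed below (batch, dx and dy are plain torch.export.Dim objects, reused in the next snippet):

batch = torch.export.Dim("batch")
dx = torch.export.Dim("dx")
dy = torch.export.Dim("dy")

try:
    # z's last dimension should depend on dx + dy, but two Dim
    # objects cannot be added together.
    dz = dx + dy
except NotImplementedError as e:
    print(f"unable to add dynamic dimensions because {type(e)}, {e}")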
unable to add dynamic dimensions because <class 'NotImplementedError'>, Attempted to add <class '__main__.dy'> to dx, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)
Since the two dimensions cannot be added, we can instead declare dz as a distinct dimension; multiplying by 2 constrains it to even values so that z[:, ::2] keeps an expressible width.
dz = torch.export.Dim("dz") * 2

try:
    ep = torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes={
            "x": {0: batch, 1: dx},
            "y": {0: batch, 1: dy},
            "z": {0: batch, 1: dz},
        },
    )
    print(ep)
    raise AssertionError("able to export this model, please update the tutorial")
except torch._dynamo.exc.UserError as e:
    print(f"unable to use Dim('dz') because {type(e)}, {e}")
unable to use Dim('dz') because <class 'torch._dynamo.exc.UserError'>, Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][0][0].size()[0] in the specified range satisfy the generated guard L['args'][0][0].size()[0] != 9223372036854775807.
As expected, that fails too. We can use torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO for the dimension we cannot set.
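A sketch of that export, using DYNAMIC for every axis (reconstructed; it produces the ExportedProgram printed below):

DYNAMIC = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes={
        "x": {0: DYNAMIC, 1: DYNAMIC},
        "y": {0: DYNAMIC, 1: DYNAMIC},
        "z": {0: DYNAMIC, 1: DYNAMIC},
    },
)
print(ep)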
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
            #
            sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_4: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_5: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)

            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None

            #
            add_1: "Sym(s1 + s3)" = sym_size_int_3 + sym_size_int_4; sym_size_int_3 = sym_size_int_4 = None
            add_2: "Sym(s5 + 1)" = 1 + sym_size_int_5; sym_size_int_5 = None
            floordiv: "Sym(((s5 + 1)//2))" = add_2 // 2; add_2 = None
            eq_2: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_1 == floordiv; add_1 = floordiv = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_2'"); eq_2 = _assert_scalar_default = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
            slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
            add: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[0, int_oo], s3: VR[0, int_oo], s5: VR[2, int_oo]}
The same result can be obtained with torch.export.Dim.AUTO.
AUTO = torch.export.Dim.AUTO
print(
    torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
    )
)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
            #
            sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_4: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
            sym_size_int_5: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)

            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None

            #
            add_1: "Sym(s1 + s3)" = sym_size_int_3 + sym_size_int_4; sym_size_int_3 = sym_size_int_4 = None
            add_2: "Sym(s5 + 1)" = 1 + sym_size_int_5; sym_size_int_5 = None
            floordiv: "Sym(((s5 + 1)//2))" = add_2 // 2; add_2 = None
            eq_2: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_1 == floordiv; add_1 = floordiv = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_2'"); eq_2 = _assert_scalar_default = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
            slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
            slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
            add: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s3: VR[2, int_oo], s5: VR[2, int_oo]}
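The two programs differ only in their range constraints: DYNAMIC leaves s1 and s3 with VR[0, int_oo], whereas AUTO infers VR[2, int_oo] from the sample inputs. As a final sanity check (a sketch, not part of the original script), the exported module accepts any new shapes satisfying the recorded guard s1 + s3 == (s5 + 1) // 2:

# Re-export with AUTO and run the exported module on new shapes.
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
)
x2 = torch.randn(4, 2)
y2 = torch.randn(4, 4)
z2 = torch.randn(4, 12)  # 2 + 4 == (12 + 1) // 2 == 6
print(ep.module()(x2, y2, z2).shape)  # torch.Size([4, 6])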