Note
Go to the end to download the full example code.
Use DYNAMIC or AUTO when dynamic shapes have constraints
Setting the dynamic shapes is not always easy. Here are a few tricks to make it work.
dx + dy not allowed?¶
import torch
class Model(torch.nn.Module):
    """Toy module whose output ties the sizes of its three inputs together.

    The second dimensions of ``x`` and ``y`` are summed by the concatenation
    and must match the number of columns kept from ``z`` by the strided
    slice, which is what makes its dynamic shapes awkward to declare.
    """

    def forward(self, x, y, z):
        """Concatenate ``x`` and ``y`` along axis 1 and add every other column of ``z``.

        :param x: first tensor to concatenate
        :param y: second tensor to concatenate
        :param z: tensor whose columns 0, 2, 4, ... are added to the concatenation
        :return: ``cat((x, y), axis=1) + z[:, ::2]``
        """
        return torch.cat((x, y), axis=1) + z[:, ::2]
# Instantiate the module and draw one random sample per input.
model = Model()
x, y, z = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 16)

# Eager run first: checks that the sample shapes are compatible.
model(x, y, z)

# Export with static shapes only and display the resulting graph.
print(
    torch.export.export(
        model,
        (x, y, z),
    ).graph
)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%z : [num_users=1] = placeholder[target=z]
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%z, 0, 0, 9223372036854775807), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%slice_1, 1, 0, 9223372036854775807, 2), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %slice_2), kwargs = {})
return (add,)
Everything is fine so far. Let's now export with dynamic shapes. Adding two dynamic dimensions, dx + dy, is not allowed:
unable to add dynamic dimensions because <class 'NotImplementedError'>, Attempted to add <class '__main__.dy'> to dx, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)
Since dx + dy cannot be expressed, we could try to declare the second dimension of z as a separate dimension dz.
# Declare the second dimension of z as a derived dimension (2 * dz)
# instead of trying to express it from dx and dy.
# batch, dx and dy are the dimensions created earlier in the script.
dz = torch.export.Dim("dz") * 2
try:
    # Only the export call is guarded: it is the statement expected to fail.
    ep = torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes={
            "x": {0: batch, 1: dx},
            "y": {0: batch, 1: dy},
            "z": {0: batch, 1: dz},
        },
    )
except torch._dynamo.exc.UserError as e:
    # Expected outcome: the exporter rejects this dynamic shapes specification.
    print(f"unable to use Dim('dz') because {type(e)}, {e}")
else:
    # The export unexpectedly succeeded: show the result and fail loudly
    # so that the tutorial gets updated.
    print(ep)
    raise AssertionError("able to export this model, please update the tutorial")
unable to use Dim('dz') because <class 'torch._dynamo.exc.UserError'>, Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['x'].size()[0] in the specified range satisfy the generated guard L['x'].size()[0] != 9223372036854775807.
That does not work either: the export fails with a constraint violation. Instead, we can use torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO for the dimensions we cannot set.
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
#
sym_size_int_5: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
sym_size_int_6: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
sym_size_int_7: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
#
add_3: "Sym(s1 + s3)" = sym_size_int_5 + sym_size_int_6; sym_size_int_5 = sym_size_int_6 = None
add_4: "Sym(s5 + 1)" = 1 + sym_size_int_7; sym_size_int_7 = None
floordiv_1: "Sym(((s5 + 1)//2))" = add_4 // 2; add_4 = None
eq_3: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_3 == floordiv_1; add_3 = floordiv_1 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_3'"); eq_3 = _assert_scalar_default = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
add_2: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[0, int_oo], s3: VR[0, int_oo], s5: VR[2, int_oo]}
The same result can be obtained with torch.export.Dim.AUTO.
# Let the exporter infer every dimension of every input.
AUTO = torch.export.Dim.AUTO
print(
    torch.export.export(
        model,
        (x, y, z),
        # One {axis: AUTO} mapping per positional input, in the same order.
        dynamic_shapes=tuple({0: AUTO, 1: AUTO} for _ in (x, y, z)),
    )
)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s5]"):
#
sym_size_int_5: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
sym_size_int_6: "Sym(s3)" = torch.ops.aten.sym_size.int(y, 1)
sym_size_int_7: "Sym(s5)" = torch.ops.aten.sym_size.int(z, 1)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
#
add_3: "Sym(s1 + s3)" = sym_size_int_5 + sym_size_int_6; sym_size_int_5 = sym_size_int_6 = None
add_4: "Sym(s5 + 1)" = 1 + sym_size_int_7; sym_size_int_7 = None
floordiv_1: "Sym(((s5 + 1)//2))" = add_4 // 2; add_4 = None
eq_3: "Sym(Eq(s1 + s3, ((s5 + 1)//2)))" = add_3 == floordiv_1; add_3 = floordiv_1 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s1 + s3, ((s5 + 1)//2)) on node 'eq_3'"); eq_3 = _assert_scalar_default = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes_auto.py:19 in forward, code: return torch.cat((x, y), axis=1) + z[:, ::2]
slice_1: "f32[s0, s5]" = torch.ops.aten.slice.Tensor(z, 0, 0, 9223372036854775807); z = None
slice_2: "f32[s0, ((s5 + 1)//2)]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807, 2); slice_1 = None
add_2: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, slice_2); cat = slice_2 = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s3: VR[2, int_oo], s5: VR[2, int_oo]}
Total running time of the script: (0 minutes 1.038 seconds)
Related examples