Cannot export torch.sym_max(x.shape[0], y.shape[0])¶
This is related to the following issue: Cannot export torch.sym_max(x.shape[0], y.shape[0]).
The algorithm that automatically infers shapes after every operator in the exported program is very aggressive. Here is a case where it makes a wrong decision, and how to get around it.
Wrong Model¶
import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])
Export¶
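The export call itself is not reproduced on this page. A minimal sketch of what likely produced the listing below, assuming both dimensions of x and y and the second dimension of fact are declared dynamic with torch.export.Dim.DYNAMIC:

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)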
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, s16]"):
            #
            sym_size_int_8: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_9: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_10: "Sym(s58)" = torch.ops.aten.sym_size.int(y, 0)
            sym_size_int_11: "Sym(s43)" = torch.ops.aten.sym_size.int(y, 1)
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:26 in forward, code: z = torch.zeros((s1, s2), dtype=x.dtype)
            zeros: "i64[s58, s16]" = torch.ops.aten.zeros.default([sym_size_int_10, sym_size_int_9], dtype = torch.int64, device = device(type='cpu'), pin_memory = False); sym_size_int_9 = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:27 in forward, code: z[: x.shape[0], : x.shape[1]] = x
            slice_1: "i64[s35, s16]" = torch.ops.aten.slice.Tensor(zeros, 0, 0, sym_size_int_8); sym_size_int_8 = None
            copy_: "i64[s35, s16]" = torch.ops.aten.copy_.default(slice_1, x); slice_1 = x = copy_ = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:28 in forward, code: z[: y.shape[0], : y.shape[1]] += y
            slice_2: "i64[s58, s16]" = torch.ops.aten.slice.Tensor(zeros, 0, None, sym_size_int_10); sym_size_int_10 = None
            slice_3: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(slice_2, 1, None, sym_size_int_11); slice_2 = None
            add_: "i64[s58, s43]" = torch.ops.aten.add_.Tensor(slice_3, y); slice_3 = y = None
            slice_4: "i64[s58, s43]" = torch.ops.aten.slice.Tensor(zeros, 1, 0, sym_size_int_11); sym_size_int_11 = None
            copy__1: "i64[s58, s43]" = torch.ops.aten.copy_.default(slice_4, add_); slice_4 = add_ = copy__1 = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:29 in forward, code: return z * fact
            mul: "i64[s58, s16]" = torch.ops.aten.mul.Tensor(zeros, fact); zeros = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo]}
But does it really work? Let’s print the shapes.
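The check is not shown on this page; a minimal sketch, assuming the exported program is run through ep.module():

model_ep = ep.module()
z_ep = model_ep(x, y, fact)
print("case 1:", z.shape, z_ep.shape)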
case 1: torch.Size([3, 3]) torch.Size([3, 3])
Case with different shapes.
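The inputs for this case are not shown either; hypothetical values chosen to reproduce the failure below (fact keeps its shape (1, 3) while x now has only 2 columns):

x2 = torch.arange(4).reshape((2, 2))
y2 = torch.arange(9).reshape((3, 3))
try:
    print("case 2:", model(x2, y2, fact).shape, model_ep(x2, y2, fact).shape)
except Exception as e:
    print("case 2 failed:", e)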
case 2 failed: Expected input at *args[2].shape[1] to be equal to 2, but got 3
It does not even run. The exported program inferred the wrong shape: zeros was given shape (s58, s16) instead of (max(s35, s58), max(s16, s43)), and fact's second dimension was tied to s16.
Rewritten Model¶
max does not get captured, torch.sym_max() is no better, and torch.max() only works on tensors. Nothing really works.
We use a trick to introduce a new dimension the shape-inference algorithm cannot know. This requires hiding the failing logic inside a custom operator.
def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res
def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Element-wise maximum of both shapes, then copy each input into the
    # top-left corner of the zero-filled result.
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z
def symbolic_shape(x, y):
    # Shape (meta) function: every output dimension is made opaque so the
    # exporter cannot assume anything about it.
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )
def register(fct, fct_shape, namespace, fname):
    # Infer the schema from the eager implementation, register the CPU
    # kernel, and attach the shape function used at export time.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)
Now everything is registered. Let’s rewrite the model.
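The definition of RewrittenModel is missing from this page; a minimal sketch reconstructed from the exported program shown below:

class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        # The whole padding/copy logic now lives in the custom operator,
        # whose shape function hides the output dimensions from the exporter.
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact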
And check it works.
rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
x.shape=torch.Size([2, 3]), y.shape=torch.Size([3, 2]), z.shape=torch.Size([3, 3])
Export again¶
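Again, the export call is omitted; a plausible sketch, reusing the same dynamic-shape specification as before:

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)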
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s35, s16]", y: "i64[s58, s43]", fact: "i64[1, s23]"):
            #
            sym_size_int_4: "Sym(s23)" = torch.ops.aten.sym_size.int(fact, 1)
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            copy_max_dimensions: "i64[u0, s23]" = torch.ops.mylib.copy_max_dimensions.default(x, y); x = y = None
            #
            sym_size_int_5: "Sym(u0)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 0)
            sym_size_int_6: "Sym(s23)" = torch.ops.aten.sym_size.int(copy_max_dimensions, 1)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_5); sym_constrain_range_for_size_default = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge: "Sym(u0 >= 0)" = sym_size_int_5 >= 0; sym_size_int_5 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
            #
            sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_6); sym_constrain_range_for_size_default_1 = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:132 in forward, code: z = torch.ops.mylib.copy_max_dimensions(x, y)
            ge_1: "Sym(s23 >= 2)" = sym_size_int_6 >= 2
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 2 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
            #
            eq: "Sym(True)" = sym_size_int_6 == sym_size_int_4; sym_size_int_6 = sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, s23) on node 'eq'"); eq = _assert_scalar_default_2 = None
            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_max.py:133 in forward, code: return z * fact
            mul: "i64[u0, s23]" = torch.ops.aten.mul.Tensor(copy_max_dimensions, fact); copy_max_dimensions = fact = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    fact: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {u0: VR[0, 9223372036854775806], u1: VR[2, 9223372036854775806], s35: VR[2, int_oo], s16: VR[2, int_oo], s58: VR[2, int_oo], s43: VR[2, int_oo], s23: VR[2, 9223372036854775806]}
We check it works.
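The verification code is not shown; a minimal sketch consistent with the two outputs below, reusing the same hypothetical case-2 inputs as before:

model_ep = ep.module()
print("case 1:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)
x2 = torch.arange(4).reshape((2, 2))
y2 = torch.arange(9).reshape((3, 3))
print("case 2:", rewritten_model(x2, y2, fact).shape, model_ep(x2, y2, fact).shape)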
case 1: torch.Size([3, 3]) torch.Size([3, 3])
case 2: torch.Size([3, 3]) torch.Size([3, 3])
Final Check on very different dimensions¶
x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))
print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)
final case: torch.Size([10, 8]) torch.Size([10, 8])
This is not perfect: we do get an exported program, but some of the logic is now hidden inside a custom operator.
doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")

Total running time of the script: (0 minutes 0.261 seconds)