Do not use python int with dynamic shapes¶
torch.export.export() uses torch.SymInt to operate on shapes and to optimize the graph it produces. It checks whether two tensors share the same dimension, whether shapes can be broadcast, … To do that, Python types must not be used, or the algorithm loses information.
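As a minimal illustration (the Probe module below is a sketch, not part of the recipe), integer arithmetic applied directly to x.shape[0] stays symbolic during export, whereas wrapping it in int() would pin it to the sample input's value:

import torch

class Probe(torch.nn.Module):
    def forward(self, x):
        d = x.shape[0]  # a torch.SymInt during export, a plain int in eager mode
        return torch.zeros((d // 2, x.shape[1]))

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(Probe(), (torch.rand(10, 4),), dynamic_shapes=({0: DYN, 1: DYN},))
# The graph keeps the symbolic expression (s0 // 2) instead of the constant 5.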
Wrong Model¶
import math

import torch

from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class Model(torch.nn.Module):
    def dim(self, i, divisor):
        return int(math.ceil(i / divisor))  # noqa: RUF046

    def forward(self, x):
        new_shape = (self.dim(x.shape[0], 8), x.shape[1])
        return torch.zeros(new_shape)


model = Model()
x = torch.rand((10, 15))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([10, 15]), y.shape=torch.Size([2, 15])
Export¶
DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
            #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:26 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[2, s16]" = torch.ops.aten.zeros.default([2, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
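The exported program hard-coded ceil(10 / 8) == 2 into the graph. A quick check (hypothetical, not part of the original script) makes that visible: feeding a larger first dimension should still yield 2 rows, since the zeros call no longer depends on the input's shape.

x_big = torch.rand((32, 15))
print(ep.module()(x_big).shape)  # expected: torch.Size([2, 15]), not [4, 15]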
The first dimension became static. We must not use int(). math.ceil() should be avoided as well since it is a Python operation: the exporter may fail to detect that it is operating on shapes. For positive integers, math.ceil(i / divisor) equals (i + divisor - 1) // divisor, which keeps the computation in integer arithmetic the exporter can trace symbolically.
Rewrite¶
class RewrittenModel(torch.nn.Module):
    def dim(self, i, divisor):
        return (i + divisor - 1) // divisor

    def forward(self, x):
        new_shape = (self.dim(x.shape[0], 8), x.shape[1])
        return torch.zeros(new_shape)


rewritten_model = RewrittenModel()
y = rewritten_model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([10, 15]), y.shape=torch.Size([2, 15])
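For x.shape[0] == 10, the rewritten formula gives (10 + 8 - 1) // 8 == 17 // 8 == 2, the same value math.ceil(10 / 8) produced in eager mode.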
Export¶
ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
            #
            sym_size_int_2: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:56 in forward, code: new_shape = (self.dim(x.shape[0], 8), x.shape[1])
            add: "Sym(s35 + 8)" = sym_size_int_2 + 8;  sym_size_int_2 = None
            sub: "Sym(s35 + 7)" = add - 1;  add = None

            # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/__init__.py:465 in __floordiv__, code: return self.__int_floordiv__(other)
            floordiv: "Sym(((s35 + 7)//8))" = sub // 8;  sub = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:57 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[((s35 + 7)//8), s16]" = torch.ops.aten.zeros.default([floordiv, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  floordiv = sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
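The first output dimension is now the symbolic expression ((s35 + 7)//8), so the exported module follows the input size. A minimal check (assumption: ep.module() is used to run the exported graph eagerly):

x2 = torch.rand((17, 15))
print(ep.module()(x2).shape)  # expected: torch.Size([3, 15]) since (17 + 7) // 8 == 3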
Find the error¶
Function onnx_diagnostic.torch_export_patches.bypass_export_some_errors() has a parameter stop_if_static which patches torch to raise an exception when a dynamic dimension turns static like this.
with bypass_export_some_errors(stop_if_static=True):
    ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
            #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

            # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:26 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[2, s16]" = torch.ops.aten.zeros.default([2, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
doc.plot_legend("dynamic shapes\ndo not cast to\npython int", "dynamic shapes", "yellow")

Total running time of the script: (0 minutes 0.126 seconds)