Note
Go to the end to download the full example code.
Using 0, 1, or 2 for a dynamic dimension in a dummy example when exporting a model¶
torch.export.export() fails if a tensor given to the function
has a size of 0 or 1 in a dimension declared as dynamic.
Simple model, no dimension with 0 or 1¶
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
class Model(torch.nn.Module):
    """Toy module: concatenate ``x`` and ``y`` along axis 1, then add ``z``.

    ``z`` must therefore have ``x.shape[1] + y.shape[1]`` columns (or
    broadcast against that shape) and match the batch dimension.
    """

    def forward(self, x, y, z):
        merged = torch.cat((x, y), axis=1)
        return merged + z
# Instantiate the model and three inputs whose sizes are all >= 2,
# so none of the dimensions declared dynamic carries a 0/1 hint.
model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 8)
model(x, y, z)  # sanity run in eager mode

# Mark both axes of every input as dynamic for the export.
dyn = torch.export.Dim.DYNAMIC
ds = {0: dyn, 1: dyn}
print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))
ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep)
-- export shape: (T1s2x3,T1s2x5,T1s2x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[s17, s27 + s94]"):
# File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_dim1.py:22 in forward, code: return torch.cat((x, y), axis=1) + z
cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, z); cat = z = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo], s94: VR[2, int_oo], s27 + s94: VR[4, int_oo]}
Same model, a dynamic dimension = 1¶
# Same model, but z now has a first dimension equal to 1: export
# specializes dimensions hinted at 0/1 and rejects the DYNAMIC marker
# for them, so this attempt is expected to fail.
z = z[:1]
dyn = torch.export.Dim.DYNAMIC
ds = {axis: dyn for axis in (0, 1)}
print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))
try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep)
except Exception as e:
    print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].
It failed. Let’s try a little trick.
Same model, a dynamic dimension = 1 and backed_size_oblivious=True¶
# Retry the export with backed_size_oblivious=True, an experimental torch.fx
# config flag intended to avoid specializing on the 0/1 size hint
# (NOTE(review): private API, semantics may change between torch versions).
print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))
try:
    # Patch the experimental config only for the duration of the export.
    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep)
except RuntimeError as e:
    # As the captured output below shows, this still fails: the traced
    # graph ends up with mismatching batch symbols (s17 vs s68).
    print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ERROR The size of tensor a (s17) must match the size of tensor b (s68) at non-singleton dimension 0)
Final try with patches…¶
# Final attempt: apply onnx-diagnostic's patches to torch before exporting.
print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))
# patch_torch=1 enables the patches that alter torch.export's handling of
# 0/1 dimension hints; the export below succeeds (see the printed
# ExportedProgram, where z keeps a static first dimension of 1).
with torch_export_patches(patch_torch=1):
    try:
        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
        print(ep)
    except RuntimeError as e:
        print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[1, s27 + s94]"):
# File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_dim1.py:22 in forward, code: return torch.cat((x, y), axis=1) + z
cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1); x = y = None
add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, z); cat = z = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
z: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo], s94: VR[2, int_oo], s27 + s94: VR[4, int_oo]}
It is difficult to find the right option. It is possible on a simple model but sometimes impossible on a bigger model mixing different shapes.
# Render the gallery legend summarizing this example's outcome.
doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Total running time of the script: (0 minutes 6.053 seconds)
Related examples
Cannot export torch.sym_max(x.shape[0], y.shape[0])
Cannot export torch.sym_max(x.shape[0], y.shape[0])