0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

torch.export.export() fails when one of the input tensors has size 0 or 1 along a dimension declared as dynamic.
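The condition can be checked on the example inputs before calling the exporter. The following is a minimal sketch, not part of torch or onnx_diagnostic: the helper name `dims_at_risk` is hypothetical, shapes are passed as plain tuples (for a tensor `t`, `tuple(t.shape)`), and the string `"DYN"` stands in for `torch.export.Dim.DYNAMIC`.

```python
def dims_at_risk(shapes, dynamic_shapes):
    """Return (input index, axis, size) for every dynamic axis whose size is 0 or 1.

    shapes: one tuple of ints per input, e.g. tuple(t.shape) for a tensor.
    dynamic_shapes: one dict per input mapping axis -> marker
                    (the marker itself is not inspected).
    """
    found = []
    for i, (shape, ds) in enumerate(zip(shapes, dynamic_shapes)):
        for axis in sorted(ds):
            if shape[axis] in (0, 1):
                found.append((i, axis, shape[axis]))
    return found


# Placeholder for torch.export.Dim.DYNAMIC (assumption made for this sketch).
ds = {0: "DYN", 1: "DYN"}

# Shapes of x, y, z as used on this page, first with z of shape (2, 8),
# then with z of shape (1, 8).
print(dims_at_risk([(2, 3), (2, 5), (2, 8)], (ds, ds, ds)))
print(dims_at_risk([(2, 3), (2, 5), (1, 8)], (ds, ds, ds)))
```

An empty list means no declared dynamic dimension has size 0 or 1 in the example inputs. A non-empty list points at the inputs that are likely to trigger the failure described above.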

Simple model, no dimension with 0 or 1

import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 8)
model(x, y, z)

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep)
-- export shape: (T1s2x3,T1s2x5,T1s2x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[s17, s27 + s94]"):
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_dim1.py:22 in forward, code: return torch.cat((x, y), axis=1) + z
            cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
            add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, z);  cat = z = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo], s94: VR[2, int_oo], s27 + s94: VR[4, int_oo]}

Same model, a dynamic dimension = 1

z = z[:1]

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep)
except Exception as e:
    print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].

It failed. Let’s try a little trick.

Same model, a dynamic dimension = 1 and backed_size_oblivious=True

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

try:
    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
        print(ep)
except RuntimeError as e:
    print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ERROR The size of tensor a (s17) must match the size of tensor b (s68) at non-singleton dimension 0)
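The phrase "non-singleton dimension" in this message refers to broadcasting: two sizes along an axis are compatible when they are equal or when one of them is 1. One plausible reading of the failure, stated here as an assumption, is that once the first dimension of z is treated as an unknown symbol (s68), the exporter can no longer rely on it being 1 and requires it to match s17. The rule itself can be sketched in plain Python. The helper names are hypothetical, and the sketch only handles shapes of equal rank.

```python
def broadcast_dim(a, b):
    """Size of one broadcast axis for sizes a and b."""
    if a == b:
        return a
    if a == 1:
        return b
    if b == 1:
        return a
    raise ValueError(f"size {a} must match size {b} at a non-singleton dimension")


def broadcast_shape(s1, s2):
    """Broadcast two shapes of equal rank, axis by axis."""
    if len(s1) != len(s2):
        raise ValueError("this sketch only handles shapes of equal rank")
    return tuple(broadcast_dim(a, b) for a, b in zip(s1, s2))


print(broadcast_shape((2, 8), (2, 8)))  # cat(x, y) + z with z of shape (2, 8)
print(broadcast_shape((2, 8), (1, 8)))  # cat(x, y) + z with z of shape (1, 8)
```

Both calls give the shape (2, 8), but for different reasons: the first because the sizes are equal, the second because size 1 is broadcast. That difference is what makes the value 1 special for a dimension meant to be dynamic.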

Final try with patches…

print("-- export shape:", string_type((x, y, z), with_shape=True))
print("-- dynamic shapes:", string_type((ds, ds, ds)))

with torch_export_patches(patch_torch=1):
    try:
        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
        print(ep)
    except RuntimeError as e:
        print("ERROR", e)
-- export shape: (T1s2x3,T1s2x5,T1s1x8)
-- dynamic shapes: ({0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC},{0:DYNAMIC,1:DYNAMIC})
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17, s27]", y: "f32[s17, s94]", z: "f32[1, s27 + s94]"):
             # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_dim1.py:22 in forward, code: return torch.cat((x, y), axis=1) + z
            cat: "f32[s17, s27 + s94]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
            add: "f32[s17, s27 + s94]" = torch.ops.aten.add.Tensor(cat, z);  cat = z = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo], s94: VR[2, int_oo], s27 + s94: VR[4, int_oo]}

Finding the right option is difficult. It is possible on a simple model but sometimes impossible on a bigger model mixing different shapes.
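The three attempts on this page can be organized as a chain of fallbacks that stops at the first one that succeeds. The sketch below is a generic pattern, not an API of torch or onnx_diagnostic: `first_successful` is a hypothetical helper, and the three functions are stand-ins that mimic the outcomes observed above. In practice, each body would be replaced with the corresponding torch.export.export call.

```python
def first_successful(strategies):
    """Try each (name, callable) pair in order.

    Returns (name, result) for the first callable that does not raise.
    Raises RuntimeError listing every failure if none succeeds.
    """
    errors = []
    for name, fn in strategies:
        try:
            return name, fn()
        except Exception as e:  # export failures surface as different exception types
            errors.append(f"{name}: {e}")
    raise RuntimeError("all strategies failed:\n" + "\n".join(errors))


# Stand-ins for the three attempts (assumptions for illustration only).
def plain():
    raise ValueError("0/1 specialized")


def size_oblivious():
    raise RuntimeError("size mismatch at non-singleton dimension 0")


def patched():
    return "ExportedProgram"


name, result = first_successful(
    [("plain", plain), ("backed_size_oblivious", size_oblivious), ("patches", patched)]
)
print(name, result)
```

Collecting the error of every failed attempt keeps the diagnostic information available when no option works, which is the case the paragraph above warns about for bigger models.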

doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Total running time of the script: (0 minutes 6.053 seconds)

Related examples

Builds dynamic shapes from any input

Do not use python int with dynamic shapes

Cannot export torch.sym_max(x.shape[0], y.shape[0])
