0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

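torch.export.export() does not work when an input tensor has size 0 or 1 along a dimension declared as dynamic. Instead of keeping that dimension symbolic, the exporter specializes it to the value it sees in the example input.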
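
For this model, one reason a size of 1 is special is broadcasting. When one operand has size 1 along a dimension, PyTorch stretches it to match the other operand. Otherwise the two sizes must be equal. The sketch below is a plain-Python re-implementation of the standard NumPy/PyTorch broadcasting rule for two shapes. It is only meant to show that rule. `broadcast_two` is a hypothetical helper written for illustration; it is not part of onnx_diagnostic. torch ships its own torch.broadcast_shapes.

```python
from itertools import zip_longest


def broadcast_two(a, b):
    """Return the broadcast shape of two shapes given as tuples of ints.

    Shapes are aligned on the right; two sizes are compatible when they
    are equal or when one of them is 1 (missing leading dims count as 1).
    """
    result = []
    # Walk both shapes from the last dimension to the first.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"incompatible sizes {da} and {db}")
    return tuple(reversed(result))


# cat((x, y), axis=1) has shape (2, 8) in the example below.
print(broadcast_two((2, 8), (2, 8)))  # (2, 8): sizes match, no broadcasting
print(broadcast_two((2, 8), (1, 8)))  # (2, 8): the size-1 dim is stretched
```

With size 1, the output size along that dimension comes from the other operand. With any other size, both sizes must be equal. The shape relations recorded while tracing therefore depend on whether the example input has a 1 there.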

Simple model, no dimension with 0 or 1

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), dim=1) + z


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 8)
model(x, y, z)

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
print(ep.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)

Same model, a dynamic dimension = 1

z = z[:1]

DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN}

try:
    ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
    print(ep.graph)
except Exception as e:
    print("ERROR", e)
ERROR Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but export 0/1 specialized due to hint of 1 for dimension inputs['z'].shape[0].
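
You can catch this conflict before calling export by looking for dimensions that are declared dynamic but have size 0 or 1 in the example inputs. The following is a minimal sketch. It assumes dynamic_shapes is given as a tuple with one dict (or None) per positional input, as in this example. `find_01_dynamic_dims` is a hypothetical helper written for illustration, not an onnx_diagnostic or torch function.

```python
def find_01_dynamic_dims(shapes, dynamic_shapes):
    """Return (input index, dim) pairs where a dimension declared dynamic
    has size 0 or 1 in the example inputs.

    shapes: one shape tuple per positional input, e.g. tuple(t.shape).
    dynamic_shapes: one dict {dim index: dim spec} (or None) per input,
    laid out like the tuple form of torch.export.export's dynamic_shapes.
    """
    flagged = []
    for i, (shape, spec) in enumerate(zip(shapes, dynamic_shapes)):
        if not spec:  # None or empty dict: this input is fully static
            continue
        for dim in spec:  # only the keys (dimension indices) matter here
            if shape[dim] in (0, 1):
                flagged.append((i, dim))
    return flagged


# Shapes of x, y and z = z[:1] from the failing example above.
shapes = [(2, 3), (2, 5), (1, 8)]
ds = {0: "DYNAMIC", 1: "DYNAMIC"}  # placeholder for torch.export.Dim.DYNAMIC
print(find_01_dynamic_dims(shapes, [ds, ds, ds]))  # [(2, 0)]
```

With real tensors, pass [tuple(t.shape) for t in (x, y, z)] together with the actual dynamic_shapes. The pair (2, 0) points at the same dimension as the error message, inputs['z'].shape[0].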

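It failed: export specialized z.shape[0] to the value 1 it saw during tracing, which conflicts with Dim.DYNAMIC. Let’s try a little trick.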

Same model, a dynamic dimension = 1 and backed_size_oblivious=True

try:
    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        ep = torch.export.export(model, (x, y, z), dynamic_shapes=(ds, ds, ds))
        print(ep.graph)
except RuntimeError as e:
    print("ERROR", e)
ERROR The size of tensor a (s17) must match the size of tensor b (s68) at non-singleton dimension 0)

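It did not work either. The 0/1 specialization error is gone, but tracing now fails on the addition: dimension 0 of cat(x, y) and dimension 0 of z are required to match, and they do not (2 against 1).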
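
The toy model below mirrors the error printed above. It rests on an assumption: with backed_size_oblivious=True, the tracer stops treating a size as possibly equal to 1 when it decides how to broadcast, so it requires the two sizes to be equal instead. `broadcast_dim` is hypothetical. It works on concrete sizes rather than on symbols like s17 and s68, and it is not torch's implementation.

```python
def broadcast_dim(a, b, size_oblivious=False):
    """Toy broadcast of two sizes along a single dimension.

    With size_oblivious=True a size is never considered equal to 1
    (an assumption mimicking backed_size_oblivious), so broadcasting
    cannot stretch it and the two sizes must match exactly.
    """

    def is_one(n):
        return (not size_oblivious) and n == 1

    if a == b:
        return a
    if is_one(b):
        return a
    if is_one(a):
        return b
    raise ValueError(
        f"The size of tensor a ({a}) must match the size of tensor b ({b})"
    )


# Dimension 0 of cat((x, y), axis=1) is 2; dimension 0 of z = z[:1] is 1.
print(broadcast_dim(2, 1))  # 2: the size-1 operand is stretched
try:
    broadcast_dim(2, 1, size_oblivious=True)
except ValueError as e:
    print("ERROR", e)  # same kind of failure as the export above
```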

doc.plot_legend("dynamic dimension\nworking with\n0 or 1", "torch.export.export", "green")

Total running time of the script: (0 minutes 0.253 seconds)

Related examples

Do not use python int with dynamic shapes

Export a model with a control flow (If)

Cannot export torch.sym_max(x.shape[0], y.shape[0])

Gallery generated by Sphinx-Gallery