A few tricks about dynamic shapes

Setting dynamic shapes is not always easy. Here are a few tricks to make it work.

dx + dy not allowed?

import torch


class Model(torch.nn.Module):
    """Concatenate ``x`` and ``y`` along dimension 1, then add ``z``.

    ``z`` must be broadcastable with the concatenation; in the example
    below its second dimension equals ``x.shape[1] + y.shape[1]``.
    """

    def forward(self, x, y, z):
        joined = torch.cat((x, y), dim=1)
        return joined + z


# Example inputs: z's second dimension (7) matches the concatenation 3 + 4.
model = Model()
x, y, z = (torch.randn(2, width) for width in (3, 4, 7))

# Sanity check that the model runs eagerly before exporting it.
model(x, y, z)

# Export with fully static shapes and show the resulting FX graph.
print(torch.export.export(model, (x, y, z)).graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %z : [num_users=1] = placeholder[target=z]
    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %y], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%cat, %z), kwargs = {})
    return (add,)

Everything works so far. Now let's make the shapes dynamic. The second dimension of z is the sum of the second dimensions of x and y, but dx + dy is not allowed…

batch = torch.export.Dim("batch")
dx = torch.export.Dim("dz")
dy = torch.export.Dim("dy")

try:
    dz = dx + dy
except Exception as e:
    print(f"unable to add dimension because {e}")
unable to add dimension because Attempted to add <class '__main__.dy'> to dz, where an integer was expected. (Only increasing linear operations with integer coefficients are supported.)

Then we could declare the second dimension of z as a separate, independent dimension.

# Give z's second dimension its own Dim. The model ties it to
# x.size(1) + y.size(1), so this export is expected to fail; the error
# message is printed instead of being raised.
dz = torch.export.Dim("dz")
try:
    torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes=dict(
            x={0: batch, 1: dx},
            y={0: batch, 1: dy},
            z={0: batch, 1: dz},
        ),
    )
except Exception as e:
    print(e)
L['z'].size()[1] = 7 is not equal to L['x'].size()[1] = 3

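Still no luck. However, it works with torch.export.Dim.DYNAMIC: the exporter infers by itself that the second dimension of z is the sum of the second dimensions of x and y.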

# Let the exporter infer z's second dimension by marking it Dim.DYNAMIC
# instead of binding it to a user-declared Dim.
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=dict(
        x={0: batch, 1: dx},
        y={0: batch, 1: dy},
        z={0: batch, 1: torch.export.Dim.DYNAMIC},
    ),
)

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[s0, s3]", z: "f32[s0, s1 + s3]"):
             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes.py:19 in forward, code: return torch.cat((x, y), axis=1) + z
            cat: "f32[s0, s1 + s3]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
            add: "f32[s0, s1 + s3]" = torch.ops.aten.add.Tensor(cat, z);  cat = z = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[0, int_oo], s1: VR[0, int_oo], s3: VR[0, int_oo], s1 + s3: VR[4, int_oo]}

It also works with torch.export.Dim.AUTO. Here the second dimension of x is kept static, and the shapes are given positionally as a tuple.

# Positional dynamic_shapes: one spec per forward input (x, y, z), in order.
# The batch dimension is shared; the second dimension is static for x and
# left to Dim.AUTO for y and z.
print(
    torch.export.export(
        model,
        (x, y, z),
        dynamic_shapes=tuple(
            {0: batch, 1: hint}
            for hint in (
                torch.export.Dim.STATIC,  # x
                torch.export.Dim.AUTO,  # y
                torch.export.Dim.AUTO,  # z
            )
        ),
    )
)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, 3]", y: "f32[s0, s2]", z: "f32[s0, s2 + 3]"):
             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_dynamic_shapes.py:19 in forward, code: return torch.cat((x, y), axis=1) + z
            cat: "f32[s0, s2 + 3]" = torch.ops.aten.cat.default([x, y], 1);  x = y = None
            add: "f32[s0, s2 + 3]" = torch.ops.aten.add.Tensor(cat, z);  cat = z = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[0, int_oo], s2: VR[2, int_oo], s2 + 3: VR[5, int_oo]}

Total running time of the script: (0 minutes 1.050 seconds)

Related examples

Do not use Module as inputs!

Do not use Module as inputs!

to_onnx and submodules from LLMs

to_onnx and submodules from LLMs

Export a model using a custom type as input

Export a model using a custom type as input

Gallery generated by Sphinx-Gallery