to_onnx: Rename Dynamic Shapes

The example given in Use DYNAMIC or AUTO when dynamic shapes has constraints can only be exported with dynamic shapes defined through torch.export.Dim.AUTO. As a result, the exported ONNX models have dynamic dimensions with unpredictable names.

Model with unpredictable names for dynamic shapes

import torch
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z[:, ::2]


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 16)
model(x, y, z)
tensor([[ 0.6733, -1.4223,  1.3222,  1.0038,  1.7007, -1.6315,  1.1124,  2.6534],
        [ 0.1644, -0.9957, -0.0325,  1.6376,  1.2567,  0.7279,  0.9898, -2.3211]])
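
The forward method implies constraints between the inputs: x and y must share the same first dimension, and z[:, ::2] must have the same number of columns as the concatenation of x and y (here the 16 columns of z reduce to 8 = 3 + 5). A quick sanity check on the concrete values, not part of the original recipe:

# x, y and z share the batch dimension, and z[:, ::2] must match
# cat((x, y), axis=1): 16 columns sliced with step 2 give 8 == 3 + 5.
assert x.shape[0] == y.shape[0] == z.shape[0]
assert z[:, ::2].shape[1] == x.shape[1] + y.shape[1]
print(model(x, y, z).shape)  # torch.Size([2, 8])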

Let’s export it.

AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
)
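
The exported program already contains symbolic names chosen by torch.export itself (s0, s1, ...). They can be inspected before the conversion, a minimal sketch:

# The AUTO dimensions become symbols picked by torch.export; their names
# and bounds appear in the range constraints of the exported program.
print(ep.range_constraints)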

Let’s convert it into ONNX.

onx = to_onnx(ep)

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
 input: EXTERNAL[s0,s1] x
 input: EXTERNAL[s2,s3] y
 input: EXTERNAL[s4,s5] z
output: EXTERNAL[s0,s1+s3] output_0
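
The same names can be read programmatically from the ONNX graph, where every dynamic dimension is stored as a dim_param string. A short sketch using the standard onnx protobuf API:

# dim_param holds the symbolic name of a dynamic dimension,
# dim_value a static size.
for inp in onx.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)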

Rename the dynamic shapes

We just need to give the ONNX exporter the same information torch.export.export() received, but with AUTO replaced by the name each dimension should have.

onx = to_onnx(
    ep,
    dynamic_shapes=(
        {0: "batch", 1: "dx"},
        {0: "batch", 1: "dy"},
        {0: "batch", 1: "dx+dy"},
    ),
)

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
 input: EXTERNAL[batch,dx] x
 input: EXTERNAL[batch,dy] y
 input: EXTERNAL[batch,dx+dy] z
output: EXTERNAL[batch,dx+dy] output_0
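
To check that the renamed dimensions are truly dynamic, the model can be run with different sizes through onnxruntime. This is a hedged sketch, assuming onnxruntime is installed and the inputs are float32 as above:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
# A new batch size and new dx, dy; z keeps 2 * (dx + dy) columns so that
# z[:, ::2] still matches the concatenation of x and y.
feeds = {
    "x": np.random.randn(5, 4).astype(np.float32),
    "y": np.random.randn(5, 6).astype(np.float32),
    "z": np.random.randn(5, 20).astype(np.float32),
}
print(sess.run(None, feeds)[0].shape)  # expected (5, 10)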

A model with an unknown output shape

class UnknownOutputModel(torch.nn.Module):
    def forward(self, x):
        return torch.nonzero(x)


model = UnknownOutputModel()
x = torch.randint(0, 2, (10, 2))
model(x)
tensor([[0, 0],
        [1, 1],
        [2, 1],
        [3, 0],
        [4, 1],
        [5, 0],
        [6, 0],
        [8, 1],
        [9, 0],
        [9, 1]])
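
The first dimension of the result depends on the values stored in x, not only on its shape: it equals the number of nonzero entries, so no input dimension can describe it. A quick check:

# The output row count is data dependent: it equals the number of
# nonzero entries of x.
print(x.shape, int((x != 0).sum()), model(x).shape)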

Let’s export it.

ep = torch.export.export(
    model, (x,), dynamic_shapes=({0: torch.export.Dim("batch"), 1: AUTO},)
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s0, s1]"):
             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_named_ds_auto.py:80 in forward, code: return torch.nonzero(x)
            nonzero: "i64[u0, 2]" = torch.ops.aten.nonzero.default(x);  x = None

             #
            sym_size_int_3: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_3);  sym_constrain_range_for_size_default = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_named_ds_auto.py:80 in forward, code: return torch.nonzero(x)
            ge_1: "Sym(u0 >= 0)" = sym_size_int_3 >= 0;  sym_size_int_3 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            return (nonzero,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='nonzero'), target=None)])
Range constraints: {u0: VR[0, 9223372036854775806], u1: VR[0, 9223372036854775806], s0: VR[0, int_oo], s1: VR[2, int_oo]}

Let’s export it into ONNX.

onx = to_onnx(ep, dynamic_shapes=({0: "batch", 1: "dx"},))

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
 input: INT64[batch,dx] x
output: INT64[NEWDIM_nonzero,2] output_0

The exporter has detected that a dimension somewhere in the graph could not be inferred from the input shapes and has introduced a new dimension name. Let's rename it as well. Let's also change the output name, because the renaming may not be implemented yet when the output dynamic shapes are given as a tuple.

onx = to_onnx(
    ep,
    dynamic_shapes=({0: "batch", 1: "dx"},),
    output_dynamic_shapes={"zeros": {0: "num_zeros"}},
    output_names=["zeros"],
)

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
 input: INT64[batch,dx] x
output: INT64[num_zeros,2] zeros
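
The renamed dimension num_zeros is data dependent: running the model with onnxruntime on random inputs shows the first output dimension changing with the content of x. A hedged sketch, assuming onnxruntime is installed:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
for n in (4, 10):
    feeds = {"x": np.random.randint(0, 2, (n, 2)).astype(np.int64)}
    # The first dimension of the result (num_zeros) depends on the values of x.
    print(sess.run(None, feeds)[0].shape)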

Total running time of the script: (0 minutes 0.344 seconds)

Related examples

torch.onnx.export: Rename Dynamic Shapes

Use DYNAMIC or AUTO when dynamic shapes has constraints

A dynamic dimension lost by torch.export.export

to_onnx and a custom operator inplace

Do no use Module as inputs!