torch.onnx.export: Rename Dynamic Shapes

The example given in Use DYNAMIC or AUTO when dynamic shapes has constraints can only be exported with dynamic shapes by using torch.export.Dim.AUTO. As a result, the exported ONNX model has dynamic dimensions whose names are chosen by the exporter and are hard to predict.

Model with unpredictable names for dynamic shapes

import torch
from experimental_experiment.helpers import pretty_onnx


class Model(torch.nn.Module):
    def forward(self, x, y, z):
        return torch.cat((x, y), axis=1) + z[:, ::2]


model = Model()
x = torch.randn(2, 3)
y = torch.randn(2, 5)
z = torch.randn(2, 16)
model(x, y, z)
tensor([[ 1.7332,  2.5631,  3.2852, -1.5264,  0.3001,  0.5521, -0.3925, -0.0207],
        [-0.9006, -0.4954,  0.3012, -2.6659,  0.8444,  2.6513,  1.4366, -0.0708]])
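The three inputs are not independent. torch.cat((x, y), axis=1) has dx + dy columns, and z[:, ::2] keeps ceil(dz / 2) columns. The addition requires both widths to agree, unless one of them is 1 and broadcasts. This kind of relation between input dimensions is presumably what makes Dim.AUTO necessary here. The plain-Python sketch below needs no torch. Its helper names are illustrative only. It checks the arithmetic for the sample inputs and ignores broadcasting:

```python
def step2_slice_len(n: int) -> int:
    """Number of elements kept by a [::2] slice over an axis of size n, i.e. ceil(n / 2)."""
    return (n + 1) // 2


def output_width(dx: int, dy: int, dz: int) -> int:
    """Width of cat((x, y), axis=1) + z[:, ::2], assuming no broadcasting on axis 1."""
    cat_width = dx + dy
    sliced_width = step2_slice_len(dz)
    if cat_width != sliced_width:
        raise ValueError(f"widths differ: {cat_width} vs {sliced_width}")
    return cat_width


# Sizes of the sample inputs: x is (2, 3), y is (2, 5), z is (2, 16).
print(output_width(3, 5, 16))  # 8
```

The result matches the 8 columns printed by model(x, y, z) above.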

Let’s export it.

AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
    model,
    (x, y, z),
    dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
)

Let’s convert it into ONNX.

onx = torch.onnx.export(ep).model_proto

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 1 of general pattern rewrite rules.
 input: EXTERNAL[s0,s1] x
 input: EXTERNAL[s0,s3] y
 input: EXTERNAL[s0,s5] z
output: EXTERNAL[s0,s1 + s3] add_11
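The names s0, s1, s3 and s5 above were chosen during export, not by the user. To collect them without pretty_onnx, you can read the dim_param field of each ValueInfoProto in the graph. The helper below is a hypothetical sketch. It relies only on the standard ONNX protobuf fields type.tensor_type.shape.dim, dim_param and dim_value:

```python
def symbolic_dims(value_infos):
    """Map each tensor name to its dimensions: the dim_param string when the
    dimension is symbolic, otherwise the integer dim_value (0 if neither is set)."""
    result = {}
    for vi in value_infos:
        dims = []
        for dim in vi.type.tensor_type.shape.dim:
            # An unset string field reads as "" in protobuf, so fall back to dim_value.
            dims.append(dim.dim_param if dim.dim_param else dim.dim_value)
        result[vi.name] = dims
    return result
```

With the model above, symbolic_dims(onx.graph.input) would be expected to report the same names as the printout, for example ["s0", "s1"] for x.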

Rename the dynamic shapes

We only need to give the ONNX exporter the same information torch.export.export() was given, replacing each AUTO with the name that dimension should have.

onx = torch.onnx.export(
    ep,
    dynamic_shapes=(
        {0: "batch", 1: "dx"},
        {0: "batch", 1: "dy"},
        {0: "batch", 1: "dx+dy"},
    ),
).model_proto

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_dynamic_shapes.py:253: UserWarning: # The axis name: batch will not be used, since it shares the same shape constraints with another axis: batch.
  warnings.warn(
Applied 1 of general pattern rewrite rules.
 input: EXTERNAL[batch,dx] x
 input: EXTERNAL[batch,dy] y
 input: EXTERNAL[batch,dx+dy] z
output: EXTERNAL[batch,dx + dy] add_11
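The dynamic_shapes argument passed above is a tuple with one dictionary per positional input. Each dictionary maps an axis index to a dimension name. When a model has many inputs, a small helper can build that structure from one tuple of names per input. The function below is hypothetical and is not part of torch.onnx. It uses None to mark a static axis, which is left out of the dictionary:

```python
def named_dynamic_shapes(*axis_names):
    """Build a dynamic_shapes tuple: one {axis_index: name} dict per positional input.

    A None entry marks a static axis and is omitted from that input's dict.
    """
    return tuple(
        {axis: name for axis, name in enumerate(names) if name is not None}
        for names in axis_names
    )


spec = named_dynamic_shapes(("batch", "dx"), ("batch", "dy"), ("batch", "dx+dy"))
print(spec)  # ({0: 'batch', 1: 'dx'}, {0: 'batch', 1: 'dy'}, {0: 'batch', 1: 'dx+dy'})
```

The result could then be passed as dynamic_shapes=spec to torch.onnx.export(ep, ...), in the same way as the literal tuple above.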

A model with an unknown output shape

class UnknownOutputModel(torch.nn.Module):
    def forward(self, x):
        return torch.nonzero(x)


model = UnknownOutputModel()
x = torch.randint(0, 2, (10, 2))
model(x)
tensor([[1, 0],
        [2, 0],
        [2, 1],
        [4, 0],
        [5, 1],
        [6, 1],
        [7, 0],
        [7, 1],
        [8, 0],
        [8, 1],
        [9, 0],
        [9, 1]])
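torch.nonzero returns one row per nonzero element and one column per input dimension. The first output dimension therefore depends on the values of x, not on its shape. The pure-Python sketch below only illustrates this behaviour for a 2-D input and is not torch's implementation. It shows two inputs of the same shape producing outputs of different lengths:

```python
def nonzero_2d(rows):
    """Mimic torch.nonzero on a 2-D input: [i, j] indices of nonzero entries, in row-major order."""
    return [[i, j] for i, row in enumerate(rows) for j, value in enumerate(row) if value != 0]


a = [[0, 1], [1, 1], [0, 0]]
b = [[1, 1], [1, 1], [1, 1]]
# Both inputs have shape (3, 2), but the outputs have 3 and 6 rows.
print(len(nonzero_2d(a)), len(nonzero_2d(b)))  # 3 6
```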

Let’s export it.

ep = torch.export.export(
    model,
    (x,),
    dynamic_shapes=({0: torch.export.Dim("batch"), 1: AUTO},),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[s0, s1]"):
             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_named_ds_auto.py:79 in forward, code: return torch.nonzero(x)
            nonzero: "i64[u0, 2]" = torch.ops.aten.nonzero.default(x);  x = None

             #
            sym_size_int_3: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_3);  sym_constrain_range_for_size_default = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_named_ds_auto.py:79 in forward, code: return torch.nonzero(x)
            ge_1: "Sym(u0 >= 0)" = sym_size_int_3 >= 0;  sym_size_int_3 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            return (nonzero,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='nonzero'), target=None)])
Range constraints: {u0: VR[0, 9223372036854775806], u1: VR[0, 9223372036854775806], s0: VR[0, int_oo], s1: VR[2, int_oo]}

Let’s export it into ONNX.

onx = torch.onnx.export(ep, dynamic_shapes=({0: "batch", 1: "dx"},), dynamo=True).model_proto

for inp in onx.graph.input:
    print(f" input: {pretty_onnx(inp)}")
for out in onx.graph.output:
    print(f"output: {pretty_onnx(out)}")
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
 input: INT64[batch,dx] x
output: INT64[u0,2] nonzero

The exporter detected that a dimension could not be inferred from the input shapes somewhere in the graph, so it introduced a new dimension name, u0. Let’s rename it as well. Let’s also set the output name. The output dynamic shapes can then be given as a dictionary keyed by that name, because the tuple form may not be implemented yet.

try:
    onx = torch.onnx.export(
        ep,
        dynamic_shapes=({0: "batch", 1: "dx"},),
        output_dynamic_shapes={"zeros": {0: "num_zeros"}},
        output_names=["zeros"],
        dynamo=True,
    ).model_proto
    raise AssertionError(
        "able to rename output dynamic dimensions, please update the tutorial"
    )
except (TypeError, torch.onnx._internal.exporter._errors.ConversionError) as e:
    print(f"unable to rename output dynamic dimensions due to {e}")
    onx = None

if onx is not None:
    for inp in onx.graph.input:
        print(f" input: {pretty_onnx(inp)}")
    for out in onx.graph.output:
        print(f"output: {pretty_onnx(out)}")
unable to rename output dynamic dimensions due to export() got an unexpected keyword argument 'output_dynamic_shapes'
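In the version used to build this page, torch.onnx.export does not accept output_dynamic_shapes. One possible workaround is to rename the symbolic dimension directly in the returned ModelProto. This workaround is an assumption and is not provided by torch.onnx. The sketch below walks graph.input, graph.output and graph.value_info and rewrites every dim_param listed in a mapping. The helper name rename_dim_params is hypothetical. It only handles exact matches, so an expression such as "s1 + s3" must be matched as a whole string:

```python
def rename_dim_params(model, mapping):
    """Rename symbolic dimensions (dim_param) in graph inputs, outputs and value_info.

    Works in place on an onnx.ModelProto and returns the number of renamed dimensions.
    Dimensions inside subgraphs are not visited.
    """
    graph = model.graph
    renamed = 0
    for value_info in list(graph.input) + list(graph.output) + list(graph.value_info):
        for dim in value_info.type.tensor_type.shape.dim:
            # Skip static dimensions (empty dim_param) and names not in the mapping.
            if dim.dim_param and dim.dim_param in mapping:
                dim.dim_param = mapping[dim.dim_param]
                renamed += 1
    return renamed
```

For example, keep the model returned by the previous, successful export, whose output prints as INT64[u0,2]. You could then call rename_dim_params(model_proto, {"u0": "num_zeros"}). Afterwards, pretty_onnx would be expected to show num_zeros instead of u0.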

Total running time of the script: (0 minutes 3.220 seconds)

Related examples

to_onnx: Rename Dynamic Shapes

Use DYNAMIC or AUTO when dynamic shapes has constraints

torch.onnx.export and a model with a test

torch.onnx.export and padding one dimension to a multiple of a constant

A dynamic dimension lost by torch.export.export
