torch.onnx.export and padding one dimension to a multiple of a constant

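Padding one dimension to a multiple of a constant is a frequent task that does not play well with dynamic shapes, because a naive implementation tests whether the dimension is already a multiple. Let’s see how to avoid using torch.cond().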

A model with a test

from onnx.reference import ReferenceEvaluator
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from experimental_experiment.helpers import max_diff

We define a model that pads one dimension to a multiple of a constant.

class PadToMultiple(torch.nn.Module):
    def __init__(
        self,
        multiple: int,
        dim: int = 0,
    ):
        super().__init__()
        self.dim_to_pad = dim
        self.multiple = multiple

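    def forward(self, x):
        shape = x.shape
        dim = x.shape[self.dim_to_pad]
        # Round dim up to the next multiple with integer arithmetic only:
        # there is no test on dim, hence no need for torch.cond when dim is dynamic.
        next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
        # to_pad is 0 when dim is already a multiple: torch.zeros then returns
        # an empty tensor and torch.cat leaves x unchanged.
        to_pad = next_dim - dim
        pad = torch.zeros(
            (*shape[: self.dim_to_pad], to_pad, *shape[self.dim_to_pad + 1 :]),
            dtype=x.dtype,
            device=x.device,
        )
        return torch.cat([x, pad], dim=self.dim_to_pad)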


model = PadToMultiple(4, dim=1)
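The key point of forward is that next_dim is computed by ceiling division written with floor division, so the padding size is an arithmetic expression of the dimension rather than the result of a branch. The sketch below compares it with an equivalent version that uses a test. Both function names are hypothetical helpers introduced only for this illustration.

```python
def next_multiple_branchfree(dim: int, multiple: int) -> int:
    # Ceiling division written with floor division, as in PadToMultiple.forward:
    # the result is an arithmetic expression of dim, with no branch on its value.
    return ((dim + multiple - 1) // multiple) * multiple


def next_multiple_with_test(dim: int, multiple: int) -> int:
    # Equivalent version with a test on dim: the kind of control flow this
    # example avoids, since a branch on a dynamic dimension would typically
    # force the exporter to specialize the dimension or to rely on torch.cond.
    if dim % multiple == 0:
        return dim
    return (dim // multiple + 1) * multiple


# Both versions agree on every size from 1 to 20.
assert all(
    next_multiple_branchfree(d, 4) == next_multiple_with_test(d, 4) for d in range(1, 21)
)

# Number of padded rows for dim = 5, 6, 7, 8, 9 and multiple = 4.
print([next_multiple_branchfree(d, 4) - d for d in range(5, 10)])  # [3, 2, 1, 0, 3]
```

When dim is already a multiple (8 here), the padding size is simply 0, which is why the model needs no special case.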

Let’s check that it runs.

x = torch.randn((6, 7, 8))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")

# Let's check that it runs on another example.
x2 = torch.randn((6, 8, 8))
y2 = model(x2)
print(f"x2.shape={x2.shape}, y2.shape={y2.shape}")
x.shape=torch.Size([6, 7, 8]), y.shape=torch.Size([6, 8, 8])
x2.shape=torch.Size([6, 8, 8]), y2.shape=torch.Size([6, 8, 8])
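The shapes are right, but a shape check alone does not confirm that the original values are kept and that the padding is made of zeros. A minimal standalone check is sketched below. pad_with_cat is a hypothetical helper repeating the computation of forward, and torch.nn.functional.pad is used only as an independent reference (it pads with zeros by default). Nothing is claimed here about how that function exports.

```python
import torch
import torch.nn.functional as F


def pad_with_cat(x: torch.Tensor, multiple: int, dim: int) -> torch.Tensor:
    # Same computation as PadToMultiple.forward, written as a function.
    size = x.shape[dim]
    to_pad = ((size + multiple - 1) // multiple) * multiple - size
    shape = list(x.shape)
    shape[dim] = to_pad
    pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)


x = torch.randn((6, 7, 8))
y = pad_with_cat(x, 4, dim=1)

# F.pad lists (left, right) pairs starting from the LAST dimension:
# (0, 0) for dim 2, then (0, 1) for dim 1.
y_ref = F.pad(x, (0, 0, 0, 1))

print(tuple(y.shape))                                # (6, 8, 8)
print(torch.equal(y[:, :7], x))                      # True: original values are kept
print(torch.equal(y[:, 7:], torch.zeros(6, 1, 8)))  # True: the padding is zeros
print(torch.equal(y, y_ref))                         # True
```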

Export

Let’s define the dynamic shapes and check that the model exports.

DYNAMIC = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model, (x,), dynamic_shapes=({0: DYNAMIC, 1: DYNAMIC, 2: DYNAMIC},), strict=False
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1, s2]"):
             #
            sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_5: "Sym(s2)" = torch.ops.aten.sym_size.int(x, 2)
            add_1: "Sym(s1 + 3)" = 3 + sym_size_int_4
            floordiv_1: "Sym(((s1 + 3)//4))" = add_1 // 4;  add_1 = None
            mul_1: "Sym(4*(((s1 + 3)//4)))" = 4 * floordiv_1;  floordiv_1 = None
            le: "Sym(s1 <= 4*(((s1 + 3)//4)))" = sym_size_int_4 <= mul_1
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression s1 <= 4*(((s1 + 3)//4)) on node 'le'");  le = _assert_scalar_default = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_dynpad.py:37 in forward, code: next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
            add: "Sym(s1 + 4)" = sym_size_int_4 + 4;  add = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_dynpad.py:38 in forward, code: to_pad = next_dim - dim
            sub_1: "Sym(-s1 + 4*(((s1 + 3)//4)))" = mul_1 - sym_size_int_4;  mul_1 = sym_size_int_4 = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_dynpad.py:39 in forward, code: pad = torch.zeros(
            zeros: "f32[s0, -s1 + 4*(((s1 + 3)//4)), s2]" = torch.ops.aten.zeros.default([sym_size_int_3, sub_1, sym_size_int_5], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = sub_1 = sym_size_int_5 = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_dynpad.py:42 in forward, code: return torch.cat([x, pad], dim=self.dim_to_pad)
            cat: "f32[s0, 4*(((s1 + 3)//4)), s2]" = torch.ops.aten.cat.default([x, zeros], 1);  x = zeros = None
            return (cat,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s2: VR[2, int_oo]}

Let’s now export the model to ONNX with torch.onnx.export(..., dynamo=True) and give a name to each dynamic dimension.

ep = torch.onnx.export(
    model, (x,), dynamic_shapes=({0: "batch", 1: "seq_len", 2: "num_frames"},), dynamo=True
)
[torch.onnx] Obtain model graph for `PadToMultiple()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `PadToMultiple()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅

Let’s save it.

ep.save("plot_exporter_recipes_oe_dynpad.onnx")

Validation

Let’s validate the exported model on a set of inputs, including shapes whose second dimension is already a multiple of 4.

ref = ReferenceEvaluator(ep.model_proto)
inputs = [
    torch.randn((6, 8, 8)),
    torch.randn((6, 7, 8)),
    torch.randn((5, 8, 17)),
    torch.randn((1, 24, 4)),
    torch.randn((3, 9, 11)),
]
for inp in inputs:
    expected = model(inp)
    got = ref.run(None, {"x": inp.numpy()})
    diff = max_diff(expected, got[0])
    print(f"diff with shape={inp.shape} -> {expected.shape}: discrepancies={diff['abs']}")
diff with shape=torch.Size([6, 8, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([6, 7, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([5, 8, 17]) -> torch.Size([5, 8, 17]): discrepancies=0.0
diff with shape=torch.Size([1, 24, 4]) -> torch.Size([1, 24, 4]): discrepancies=0.0
diff with shape=torch.Size([3, 9, 11]) -> torch.Size([3, 12, 11]): discrepancies=0.0

And visually.

[Figure: plot exporter recipes oe dynpad]
