to_onnx and padding one dimension to a multiple of a constant

This is a frequent task that does not play well with dynamic shapes. Let’s see how to avoid using torch.cond(): the padding size is computed with integer arithmetic, so no data-dependent branch is needed.

A model with a test

import onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.helpers import pretty_onnx, max_diff
from experimental_experiment.torch_interpreter import to_onnx

We define a model that pads one dimension to a multiple of a constant.

class PadToMultiple(torch.nn.Module):
    """Pads one dimension of a tensor with zeros up to the next multiple of a constant.

    The padding size is computed with integer arithmetic
    (``ceil(size / multiple) * multiple - size``) instead of a data-dependent
    branch, so the module can be exported with a dynamic size along the padded
    dimension without resorting to ``torch.cond``.

    :param multiple: the padded dimension is rounded up to a multiple of this value
    :param dim: the dimension to pad; negative values count from the end,
        as in :func:`torch.cat`
    """

    def __init__(
        self,
        multiple: int,
        dim: int = 0,
    ):
        super().__init__()
        self.dim_to_pad = dim
        self.multiple = multiple

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns ``x`` concatenated with a zero block along ``self.dim_to_pad``.

        The zero block has the same dtype and device as ``x``. It is empty when the
        size along the padded dimension is already a multiple of ``self.multiple``.

        :raises IndexError: if ``self.dim_to_pad`` is out of range for ``x``
        """
        shape = x.shape
        ndim = x.dim()
        # The slicing below needs a non-negative axis: with a negative one,
        # ``shape[axis + 1 :]`` would select the wrong trailing dimensions.
        axis = self.dim_to_pad + ndim if self.dim_to_pad < 0 else self.dim_to_pad
        if not 0 <= axis < ndim:
            raise IndexError(
                f"dim={self.dim_to_pad} is out of range for a tensor with {ndim} dimensions"
            )
        dim = shape[axis]
        # Round up to the next multiple without any branch (export friendly).
        next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
        to_pad = next_dim - dim
        pad = torch.zeros(
            (*shape[:axis], to_pad, *shape[axis + 1 :]),
            dtype=x.dtype,
            # Allocate next to x so torch.cat does not hit a device mismatch
            # when the input lives on an accelerator.
            device=x.device,
        )
        return torch.cat([x, pad], dim=axis)


# Pad dimension 1 of the input up to the next multiple of 4.
model = PadToMultiple(multiple=4, dim=1)

Let’s check that it runs.

# First input: 7 is not a multiple of 4, so dimension 1 should grow.
x = torch.randn(6, 7, 8)
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")

# Second input: 8 is already a multiple of 4, so the shape should not change.
x2 = torch.randn(6, 8, 8)
y2 = model(x2)
print(f"x2.shape={x2.shape}, y2.shape={y2.shape}")
x.shape=torch.Size([6, 7, 8]), y.shape=torch.Size([6, 8, 8])
x2.shape=torch.Size([6, 8, 8]), y2.shape=torch.Size([6, 8, 8])

Export

Let’s define the dynamic shapes and check that the model exports.

DYNAMIC = torch.export.Dim.DYNAMIC
# Every one of the three axes of the single input is declared dynamic.
ep = torch.export.export(
    model,
    (x,),
    dynamic_shapes=({axis: DYNAMIC for axis in range(3)},),
    strict=False,
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1, s2]"):
             #
            sym_size_int_3: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_4: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_5: "Sym(s2)" = torch.ops.aten.sym_size.int(x, 2)
            add_1: "Sym(s1 + 3)" = 3 + sym_size_int_4
            floordiv_1: "Sym(((s1 + 3)//4))" = add_1 // 4;  add_1 = None
            mul_1: "Sym(4*(((s1 + 3)//4)))" = 4 * floordiv_1;  floordiv_1 = None
            le: "Sym(s1 <= 4*(((s1 + 3)//4)))" = sym_size_int_4 <= mul_1
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression s1 <= 4*(((s1 + 3)//4)) on node 'le'");  le = _assert_scalar_default = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:39 in forward, code: next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
            add: "Sym(s1 + 4)" = sym_size_int_4 + 4;  add = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:40 in forward, code: to_pad = next_dim - dim
            sub_1: "Sym(-s1 + 4*(((s1 + 3)//4)))" = mul_1 - sym_size_int_4;  mul_1 = sym_size_int_4 = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:41 in forward, code: pad = torch.zeros(
            zeros: "f32[s0, -s1 + 4*(((s1 + 3)//4)), s2]" = torch.ops.aten.zeros.default([sym_size_int_3, sub_1, sym_size_int_5], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = sub_1 = sym_size_int_5 = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:44 in forward, code: return torch.cat([x, pad], dim=self.dim_to_pad)
            cat: "f32[s0, 4*(((s1 + 3)//4)), s2]" = torch.ops.aten.cat.default([x, zeros], 1);  x = zeros = None
            return (cat,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s2: VR[2, int_oo]}

Let’s now convert the model to ONNX with to_onnx.

# Give each of the three input axes a name; the names show up in the ONNX input shape.
onx = to_onnx(
    model,
    (x,),
    dynamic_shapes=(dict(enumerate(("batch", "seq_len", "num_frames"))),),
)
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=['batch', 'seq_len', 'num_frames']
init: name='init7_s_3' type=int64 shape=() -- array([3])              -- shape_type_compute._cast_inputs.1(add)
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_4' type=int64 shape=(1,) -- array([4])           -- Opset.make_node.1/Shape
init: name='init7_s_4' type=int64 shape=() -- array([4])              -- shape_type_compute._cast_inputs.1(mul)
Reshape(init7_s_4, init7_s1_1) -> _reshape_init7_s_40
Shape(x, end=1, start=0) -> _shape_x0
Shape(x, end=2, start=1) -> _shape_x02
  Squeeze(_shape_x02) -> sym_size_int_4
    Add(init7_s_3, sym_size_int_4) -> _onx_add_init7_s_30
      Div(_onx_add_init7_s_30, init7_s1_4) -> _onx_div_add_10
  Mul(_reshape_init7_s_40, _onx_div_add_10) -> _onx_mul__reshape_init7_s_400
Shape(x, end=3, start=2) -> _shape_x03
Reshape(sym_size_int_4, init7_s1_1) -> _reshape_sym_size_int_403
  Sub(_onx_mul__reshape_init7_s_400, _reshape_sym_size_int_403) -> sub
  Concat(_shape_x0, sub, _shape_x03, axis=0) -> _onx_concat_unsqueeze_sym_size_int_300
    ConstantOfShape(_onx_concat_unsqueeze_sym_size_int_300, value=[0.0]) -> zeros
      Concat(x, zeros, axis=1) -> output_0
output: name='output_0' type=dtype('float32') shape=['batch', 'seq_len+-seq_len+4*((seq_len+3)//4)', 'num_frames']

We save it.

# Write the exported model next to the script.
onnx.save_model(onx, "plot_exporter_recipes_c_dynpad.onnx")

Validation

Let’s validate the exported model on a set of inputs.

ref = ExtendedReferenceEvaluator(onx)
# Sizes along dimension 1 mix values that need padding (7, 9)
# with values that are already multiples of 4 (8, 24).
inputs = [
    torch.randn(shape)
    for shape in ((6, 8, 8), (6, 7, 8), (5, 8, 17), (1, 24, 4), (3, 9, 11))
]
# Compare eager PyTorch with the ONNX reference evaluator on each input.
for inp in inputs:
    expected = model(inp)
    got = ref.run(None, {"x": inp.numpy()})
    diff = max_diff(expected, got[0])
    print(f"diff with shape={inp.shape} -> {expected.shape}: discrepancies={diff['abs']}")
diff with shape=torch.Size([6, 8, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([6, 7, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([5, 8, 17]) -> torch.Size([5, 8, 17]): discrepancies=0.0
diff with shape=torch.Size([1, 24, 4]) -> torch.Size([1, 24, 4]): discrepancies=0.0
diff with shape=torch.Size([3, 9, 11]) -> torch.Size([3, 12, 11]): discrepancies=0.0

And visually.

plot exporter recipes c dynpad

Total running time of the script: (0 minutes 0.585 seconds)

Related examples

torch.onnx.export and padding one dimension to a multiple of a constant

torch.onnx.export and padding one dimension to a multiple of a constant

to_onnx and submodules from LLMs

to_onnx and submodules from LLMs

A dynamic dimension lost by torch.export.export

A dynamic dimension lost by torch.export.export

to_onnx and a model with a test

to_onnx and a model with a test

to_onnx: Rename Dynamic Shapes

to_onnx: Rename Dynamic Shapes

Gallery generated by Sphinx-Gallery