to_onnx and padding one dimension to a multiple of a constant¶
This is a frequent task which does not play well with dynamic shapes.
Let’s see how to avoid using torch.cond().
A model with a test¶
import onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.torch_interpreter import to_onnx
We define a model that pads one dimension to a multiple of a constant.
class PadToMultiple(torch.nn.Module):
    def __init__(
        self,
        multiple: int,
        dim: int = 0,
    ):
        super().__init__()
        self.dim_to_pad = dim
        self.multiple = multiple

    def forward(self, x):
        shape = x.shape
        dim = x.shape[self.dim_to_pad]
        # round the padded dimension up to the next multiple of self.multiple
        next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
        to_pad = next_dim - dim
        # zeros appended along the padded dimension, same dtype as the input
        pad = torch.zeros(
            (*shape[: self.dim_to_pad], to_pad, *shape[self.dim_to_pad + 1 :]), dtype=x.dtype
        )
        return torch.cat([x, pad], dim=self.dim_to_pad)


model = PadToMultiple(4, dim=1)
Let’s check it runs.
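The code producing the shapes below is not reproduced here; a minimal sketch consistent with the printed output (the variable names x, y, x2, y2 come from that output) would be:

x = torch.randn((6, 7, 8))   # second dimension is not a multiple of 4
y = model(x)
x2 = torch.randn((6, 8, 8))  # second dimension is already a multiple of 4
y2 = model(x2)
print(f"{x.shape=}, {y.shape=}")
print(f"{x2.shape=}, {y2.shape=}")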
x.shape=torch.Size([6, 7, 8]), y.shape=torch.Size([6, 8, 8])
x2.shape=torch.Size([6, 8, 8]), y2.shape=torch.Size([6, 8, 8])
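The second dimension jumps from 7 to 8 while 8 stays unchanged because of the integer rounding formula used in forward. A quick illustration (an addition, not part of the original script):

multiple = 4
for dim in [5, 6, 7, 8, 9]:
    next_dim = ((dim + multiple - 1) // multiple) * multiple
    print(f"dim={dim}: padded to {next_dim}, zeros added: {next_dim - dim}")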
Export¶
Let’s define the dynamic shapes and check it exports.
DYNAMIC = torch.export.Dim.DYNAMIC
ep = torch.export.export(
    model, (x,), dynamic_shapes=({0: DYNAMIC, 1: DYNAMIC, 2: DYNAMIC},), strict=False
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s77, s27, s53]"):
            #
            sym_size_int_3: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_4: "Sym(s27)" = torch.ops.aten.sym_size.int(x, 1)
            sym_size_int_5: "Sym(s53)" = torch.ops.aten.sym_size.int(x, 2)

            # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:40 in forward, code: next_dim = ((dim + self.multiple - 1) // self.multiple) * self.multiple
            add: "Sym(s27 + 4)" = sym_size_int_4 + 4
            sub: "Sym(s27 + 3)" = add - 1; add = None
            floordiv: "Sym(((s27 + 3)//4))" = sub // 4; sub = None
            mul: "Sym(4*(((s27 + 3)//4)))" = floordiv * 4; floordiv = None

            # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:41 in forward, code: to_pad = next_dim - dim
            sub_1: "Sym(-s27 + 4*(((s27 + 3)//4)))" = mul - sym_size_int_4; mul = sym_size_int_4 = None

            # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:42 in forward, code: pad = torch.zeros(
            zeros: "f32[s77, -s27 + 4*(((s27 + 3)//4)), s53]" = torch.ops.aten.zeros.default([sym_size_int_3, sub_1, sym_size_int_5], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); sym_size_int_3 = sub_1 = sym_size_int_5 = None

            # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_dynpad.py:45 in forward, code: return torch.cat([x, pad], dim=self.dim_to_pad)
            cat: "f32[s77, 4*(((s27 + 3)//4)), s53]" = torch.ops.aten.cat.default([x, zeros], 1); x = zeros = None
            return (cat,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cat: USER_OUTPUT

Range constraints: {s77: VR[2, int_oo], s27: VR[2, int_oo], s53: VR[2, int_oo]}
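The exported program can also be executed directly to confirm it matches the eager model; a small check (an addition, not in the original script) using the max_diff helper imported above:

got = ep.module()(x)
print("difference with eager mode:", max_diff(model(x), got)["abs"])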
Let’s now convert the model into ONNX.
onx = to_onnx(model, (x,), dynamic_shapes=({0: "batch", 1: "seq_len", 2: "num_frames"},))
print(pretty_onnx(onx))
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=['batch', 'seq_len', 'num_frames']
init: name='init7_s_4' type=int64 shape=() -- array([4]) -- shape_type_compute._cast_inputs.1(add)##Opset.make_node.1/Shape##shape_type_compute._cast_inputs.1(mul)
init: name='init7_s_1' type=int64 shape=() -- array([1]) -- shape_type_compute._cast_inputs.0
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0]) -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
Shape(x, end=1, start=0) -> x::Shape:1
Shape(x, end=2, start=1) -> x::Shape1:2
Squeeze(x::Shape1:2) -> sym_size_int_4
Add(sym_size_int_4, init7_s_4) -> _onx_add_sym_size_int_4
Sub(_onx_add_sym_size_int_4, init7_s_1) -> sub
Div(sub, init7_s_4) -> _onx_div_sub
Mul(_onx_div_sub, init7_s_4) -> _onx_mul_floordiv
Sub(_onx_mul_floordiv, sym_size_int_4) -> sub_1
Unsqueeze(sub_1, init7_s1_0) -> sub_1::UnSq0
Shape(x, end=3, start=2) -> x::Shape2:3
Concat(x::Shape:1, sub_1::UnSq0, x::Shape2:3, axis=0) -> _onx_concat_sym_size_int_3::UnSq0
ConstantOfShape(_onx_concat_sym_size_int_3::UnSq0, value=[0.0]) -> zeros
Concat(x, zeros, axis=1) -> output_0
output: name='output_0' type=dtype('float32') shape=['batch', 'seq_len+-seq_len+4*((seq_len+3)//4)', 'num_frames']
We save it.
onnx.save(onx, "plot_exporter_recipes_c_dynpad.onnx")
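As an extra sanity check (an addition, not in the original script), the saved file can be reloaded and validated with onnx’s checker:

reloaded = onnx.load("plot_exporter_recipes_c_dynpad.onnx")
onnx.checker.check_model(reloaded)
print("reloaded model has", len(reloaded.graph.node), "nodes")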
Validation¶
Let’s validate the exported model on a set of inputs.
ref = ExtendedReferenceEvaluator(onx)

inputs = [
    torch.randn((6, 8, 8)),
    torch.randn((6, 7, 8)),
    torch.randn((5, 8, 17)),
    torch.randn((1, 24, 4)),
    torch.randn((3, 9, 11)),
]

for inp in inputs:
    expected = model(inp)
    got = ref.run(None, {"x": inp.numpy()})
    diff = max_diff(expected, got[0])
    print(f"diff with shape={inp.shape} -> {expected.shape}: discrepancies={diff['abs']}")
diff with shape=torch.Size([6, 8, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([6, 7, 8]) -> torch.Size([6, 8, 8]): discrepancies=0.0
diff with shape=torch.Size([5, 8, 17]) -> torch.Size([5, 8, 17]): discrepancies=0.0
diff with shape=torch.Size([1, 24, 4]) -> torch.Size([1, 24, 4]): discrepancies=0.0
diff with shape=torch.Size([3, 9, 11]) -> torch.Size([3, 12, 11]): discrepancies=0.0
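The same comparison could be run with onnxruntime instead of the reference evaluator; a sketch, assuming onnxruntime is installed:

import onnxruntime

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
for inp in inputs:
    expected = model(inp)
    got = sess.run(None, {"x": inp.numpy()})
    print(f"shape={inp.shape}: discrepancies={max_diff(expected, got[0])['abs']}")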
And visually.
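The plotting code is not reproduced here; a sketch using the plot_dot helper imported at the top (assuming it accepts the ONNX model and returns a matplotlib axes; the title is an assumption):

ax = plot_dot(onx)
ax.set_title("PadToMultiple exported to ONNX")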

Total running time of the script: (0 minutes 0.617 seconds)
Related examples

to_onnx and a custom operator registered with a function