Do not use python int with dynamic shapes

torch.export.export() uses torch.SymInt to operate on shapes and to optimize the graph it produces. It checks whether two tensors share the same dimension, whether two shapes can be broadcast, … For that analysis to work, python types must not be used, or the algorithm loses information.

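As an illustration (a minimal sketch, not part of the original script): inside torch.export.export(), a shape such as x.shape[0] is a torch.SymInt. Integer arithmetic on it returns another torch.SymInt, so the symbolic information survives, while int() or math.ceil() produce a plain python int evaluated on the example input.

import math


def keeps_symbolic(dim, divisor):
    # SymInt in, SymInt out: the exporter keeps tracking the dimension.
    return (dim + divisor - 1) // divisor


def loses_symbolic(dim, divisor):
    # int() collapses the SymInt into a constant taken from the example input.
    return int(math.ceil(dim / divisor))
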
Wrong Model

import math
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class Model(torch.nn.Module):
    def dim(self, i, divisor):
        return int(math.ceil(i / divisor))  # noqa: RUF046

    def forward(self, x):
        new_shape = (self.dim(x.shape[0], 8), x.shape[1])
        return torch.zeros(new_shape)


model = Model()
x = torch.rand((10, 15))
y = model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([10, 15]), y.shape=torch.Size([2, 15])

Export

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
             #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:26 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[2, s16]" = torch.ops.aten.zeros.default([2, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}

The first dimension of the output became static: ceil(10 / 8) was evaluated on the example input and hard-coded as 2. We must not cast shapes to int. math.ceil() should be avoided as well since it is a python operation: the exporter may fail to detect that it is operating on shapes. The same ceiling can be computed with integer arithmetic, (i + divisor - 1) // divisor, which keeps the computation symbolic.
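
One way to see the consequence (a sketch, not part of the original script) is to run the exported module on an input whose first dimension differs from the example input:

x2 = torch.rand((24, 15))
print(model(x2).shape)  # torch.Size([3, 15]): eager mode recomputes ceil(24 / 8)
print(ep.module()(x2).shape)  # the exported graph keeps the hard-coded 2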

Rewrite

class RewrittenModel(torch.nn.Module):
    def dim(self, i, divisor):
        return (i + divisor - 1) // divisor

    def forward(self, x):
        new_shape = (self.dim(x.shape[0], 8), x.shape[1])
        return torch.zeros(new_shape)


rewritten_model = RewrittenModel()
y = rewritten_model(x)
print(f"x.shape={x.shape}, y.shape={y.shape}")
x.shape=torch.Size([10, 15]), y.shape=torch.Size([2, 15])

Export

ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
             #
            sym_size_int_2: "Sym(s35)" = torch.ops.aten.sym_size.int(x, 0)
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:56 in forward, code: new_shape = (self.dim(x.shape[0], 8), x.shape[1])
            add: "Sym(s35 + 8)" = sym_size_int_2 + 8;  sym_size_int_2 = None
            sub: "Sym(s35 + 7)" = add - 1;  add = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/__init__.py:465 in __floordiv__, code: return self.__int_floordiv__(other)
            floordiv: "Sym(((s35 + 7)//8))" = sub // 8;  sub = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:57 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[((s35 + 7)//8), s16]" = torch.ops.aten.zeros.default([floordiv, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  floordiv = sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
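
As a quick check (again a sketch, not part of the original script), the rewritten exported program now follows the first dimension of its input:

x2 = torch.rand((24, 15))
print(rewritten_model(x2).shape)  # torch.Size([3, 15])
print(ep.module()(x2).shape)  # torch.Size([3, 15]) as well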

Find the error

Function onnx_diagnostic.torch_export_patches.bypass_export_some_errors() has a parameter stop_if_static which patches torch to raise an exception when a dynamic dimension becomes static, making it easier to locate the offending line.

with bypass_export_some_errors(stop_if_static=True):
    ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]"):
             #
            sym_size_int_3: "Sym(s16)" = torch.ops.aten.sym_size.int(x, 1);  x = None

             # File: /home/xadupre/github/onnx-diagnostic/_doc/recipes/plot_dynamic_shapes_python_int.py:26 in forward, code: return torch.zeros(new_shape)
            zeros: "f32[2, s16]" = torch.ops.aten.zeros.default([2, sym_size_int_3], device = device(type='cpu'), pin_memory = False);  sym_size_int_3 = None
            return (zeros,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    zeros: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
doc.plot_legend("dynamic shapes\ndo not cast to\npython int", "dynamic shapes", "yellow")

