A dynamic dimension lost by torch.export.export

Dynamic shapes ensure a model remains valid no matter what value a dynamic dimension takes. torch.export.export() tries to keep track of that information for every intermediate result the model produces, but sometimes it fails. Let’s see one case.

A dynamic dimension replaced by a constant in function pad

It could be any other function. The function takes an integer as an argument. Even though this value may change with different inputs, the exporter loses that information because it treats the value as a plain integer and therefore as a constant.

import torch

def dummy_function(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)

class Model(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))

model = Model()
inputs = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5], dtype=torch.int64),
)
print(model(*inputs))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s export.

AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
    model, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)

Let’s check it works.

print(ep.module()(*inputs))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s print the graph.

print(ep.graph)

    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %pad : [num_users=2] = call_function[target=torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0), kwargs = {})
    %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%pad, [-1, 1]), kwargs = {})
    %max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%pad,), kwargs = {})
    %item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
    %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
    return (add,)
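One visible symptom in this graph is that %y has num_users=0: nothing reads the input whose dimension was supposed to stay dynamic. The sketch below scans a torch.fx graph for such unused placeholders. The helper name unused_placeholders and the Toy module are hypothetical and only serve the illustration. The graph is built with torch.fx.symbolic_trace to keep the example small, on the assumption that the same helper can be applied to ep.graph.

```python
import torch
import torch.fx


def unused_placeholders(graph):
    """Return the names of the graph inputs that no node consumes."""
    return [
        node.name
        for node in graph.nodes
        if node.op == "placeholder" and len(node.users) == 0
    ]


class Toy(torch.nn.Module):
    # Hypothetical module: y is accepted but never used.
    def forward(self, x, y):
        return x + 1


traced = torch.fx.symbolic_trace(Toy())
print(unused_placeholders(traced.graph))
```

An unused input is only a hint. A model may legitimately ignore one of its inputs, so the result still has to be compared against what the eager code does.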

The graph contains the node call_function[target=torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0)), which corresponds to torch.nn.functional.pad(idx, (0, 1), value=x_len). The value 2.0 is y.shape[0] for the inputs used during the export, even though that dimension was declared dynamic, and the placeholder %y is no longer used (num_users=0). So if we choose inputs where y has a different length, such as the following:

inputs2 = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5, 6], dtype=torch.int64),
)

The original model works.

print(model(*inputs2))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [3, 4, 5]])

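But the exported program does not. It raises no error, but it still pads with the constant 2, so the last row is wrong.

try:
    print(ep.module()(*inputs2))
except Exception as e:
    print(e)

tensor([[1, 2, 3],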
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])
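Because the exported program fails silently here, the discrepancy only shows up when its output is compared with the eager one on inputs of a different size. The following is a minimal sketch of such a comparison. The names find_mismatches, eager and frozen are hypothetical, and frozen is a hand-written stand-in that mimics the exported graph by padding with the constant 2. It is not the exported program itself.

```python
import torch


def find_mismatches(reference, candidate, list_of_inputs):
    """Return the indices of the input sets on which both callables disagree."""
    mismatches = []
    for i, args in enumerate(list_of_inputs):
        expected = reference(*args)
        try:
            got = candidate(*args)
        except Exception:
            # A failure of the candidate counts as a mismatch.
            mismatches.append(i)
            continue
        if not torch.equal(expected, got):
            mismatches.append(i)
    return mismatches


def eager(x, y):
    # Same computation as the model: the padding value follows y.shape[0].
    padded = torch.nn.functional.pad(x, (0, 1), value=y.shape[0])
    return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


def frozen(x, y):
    # Stand-in for the exported graph: the padding value is frozen to 2.
    padded = torch.nn.functional.pad(x, (0, 1), value=2)
    return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


x = (torch.arange(3) + 1).to(torch.int64)
cases = [
    (x, torch.tensor([0, 5], dtype=torch.int64)),
    (x, torch.tensor([0, 5, 6], dtype=torch.int64)),
]
mismatches = find_mismatches(eager, frozen, cases)
print(mismatches)
```

With the objects defined in this example, the same helper could presumably be called as find_mismatches(model, ep.module(), [inputs, inputs2]).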

How to fix it?

In this particular case, function pad is not the only way to produce the desired result.

def dummy_function_cat(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)

class ModelCat(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function_cat(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))

modelcat = ModelCat()
print(modelcat(*inputs))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])
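Before exporting the rewritten model, it may be worth confirming that the two formulations agree in eager mode for several lengths, not just the one used above. The sketch below does that. pad_variant and cat_variant are hypothetical names that restate dummy_function and dummy_function_cat so the block is self-contained.

```python
import torch


def pad_variant(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len], using pad
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)


def cat_variant(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len], using cat
    return torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)


idx = (torch.arange(3) + 1).to(torch.int64)
agree = all(
    torch.equal(pad_variant(idx, n), cat_variant(idx, n)) for n in (0, 2, 3, 7)
)
print(agree)
```

This only checks eager behavior. Whether the dimension stays dynamic after export still has to be verified on the exported program.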

Let’s export.

epcat = torch.export.export(
    modelcat, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)

Let’s check it works.

print(epcat.module()(*inputs))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s print the graph.

print(epcat.graph)

    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_2,), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %stack : [num_users=1] = call_function[target=torch.ops.aten.stack.default](args = ([%scalar_tensor],), kwargs = {})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%stack, cpu, torch.int64), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%to,), kwargs = {})
    %cat : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %detach_],), kwargs = {})
    %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [-1, 1]), kwargs = {})
    %max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%cat,), kwargs = {})
    %item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
    %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
    return (add,)
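This time the graph reads the dimension through torch.ops.aten.sym_size.int applied to %y, which is what keeps it dynamic. The check can be automated by listing the inputs whose size is read by such a node. The sketch below uses a hypothetical helper, inputs_with_symbolic_size, and a small hand-built torch.fx.Graph as a stand-in for the exported graph, on the assumption that the same traversal applies to epcat.graph.

```python
import torch
import torch.fx


def inputs_with_symbolic_size(graph):
    """Return the names of the inputs whose size is read by a sym_size node."""
    names = []
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.target != torch.ops.aten.sym_size.int:
            continue
        source = node.args[0]
        if isinstance(source, torch.fx.Node) and source.op == "placeholder":
            names.append(source.name)
    return names


# Hand-built graph mimicking the beginning of the exported one.
graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
size = graph.call_function(torch.ops.aten.sym_size.int, (y, 0))
graph.output((x, size))
print(inputs_with_symbolic_size(graph))
```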

And the final verification.

print(epcat.module()(*inputs2))

tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [3, 4, 5]])

It finally works.
