A dynamic dimension lost by torch.export.export

Dynamic shapes ensure a model remains valid no matter what value a dynamic dimension takes. torch.export.export() tries to keep track of that information for every intermediate result the model produces, but sometimes it fails. Let's look at one such case.
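As a reminder, a dynamic dimension is declared through the dynamic_shapes argument of torch.export.export (a minimal sketch with a toy module):

import torch


class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2


# the first dimension of x may take any value at runtime
batch = torch.export.Dim("batch")
ep_demo = torch.export.export(
    Double(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: batch}}
)
print(ep_demo.graph)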

Function pad replaces a dynamic dimension with a constant

It could be any other function. Here, a function takes an integer as an argument. Even though this value may change with different inputs, the exporter loses that information because it sees the value as a plain integer and therefore as a constant.
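Indeed, in eager mode a dimension read from .shape is already a plain Python int, so any function receiving it only sees a constant:

# in eager mode, shapes are plain Python integers
y = torch.tensor([0, 5], dtype=torch.int64)
print(type(y.shape[0]))  # <class 'int'>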

import torch


def dummy_function(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)


class Model(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


model = Model()
inputs = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5], dtype=torch.int64),
)
print(model(*inputs))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s export.

AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
    model, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)
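Note that Dim.AUTO lets the exporter decide silently whether a dimension really stays dynamic. With named dimensions the exporter is expected to fail loudly when it specializes one of them (a hedged sketch; the exact exception and message depend on the PyTorch version):

DX, DY = torch.export.Dim("dx"), torch.export.Dim("dy")
try:
    torch.export.export(
        model, inputs, dynamic_shapes={"x": {0: DX}, "y": {0: DY}}, strict=False
    )
except Exception as e:
    # expected to report that dy was inferred to be a constant
    print(e)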

Let’s check it works.

print(ep.module()(*inputs))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s print the graph.

print(ep.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    %pad : [num_users=2] = call_function[target=torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0), kwargs = {})
    %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%pad, [-1, 1]), kwargs = {})
    %max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%pad,), kwargs = {})
    %item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
    %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
    return (add,)
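The interesting node is the call to pad. We can locate it programmatically before reading the graph (a minimal sketch):

for node in ep.graph.nodes:
    if node.op == "call_function" and "pad" in str(node.target):
        # the last argument is the fill value, already frozen by the exporter
        print(node.target, node.args)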

It shows the line torch.ops.aten.pad.default(args = (%x, [0, 1], constant, 2.0)), which corresponds to torch.nn.functional.pad(idx, (0, 1), value=x_len). The fill value became the constant 2.0 even though x_len is equal to y.shape[0], which was declared as a dynamic dimension. So if we choose inputs such as the following:

inputs2 = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5, 6], dtype=torch.int64),
)

The original model works.

print(model(*inputs2))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [3, 4, 5]])

But the exported program does not.

try:
    print(ep.module()(*inputs2))
except Exception as e:
    print(e)
Expected input at *args[1].shape[0] to be equal to 2, but got 3
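The exporter specialized y.shape[0] to the value seen at export time. We can confirm it by looking at the shapes attached to the input placeholders (a minimal sketch; the symbol names may differ between versions):

# x kept a symbolic first dimension, y did not
for node in ep.graph.nodes:
    if node.op == "placeholder":
        print(node.name, node.meta["val"].shape)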

How to fix it?

In this particular case, pad is not the only function able to produce the desired result.

def dummy_function_cat(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)


class ModelCat(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function_cat(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


modelcat = ModelCat()
print(modelcat(*inputs))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s export.

epcat = torch.export.export(
    modelcat, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)

Let’s check it works.

print(epcat.module()(*inputs))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [2, 3, 4]])

Let’s print the graph.

print(epcat.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_2,), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %stack : [num_users=1] = call_function[target=torch.ops.aten.stack.default](args = ([%scalar_tensor],), kwargs = {})
    %to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%stack, cpu, torch.int64), kwargs = {})
    %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%to,), kwargs = {})
    %cat : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %detach_],), kwargs = {})
    %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [-1, 1]), kwargs = {})
    %max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%cat,), kwargs = {})
    %item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
    %sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
    %ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
    %reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
    return (add,)
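This time a sym_size node reads the dimension of y before building the tensor to concatenate, so the dependency on the dynamic dimension is preserved. We can check it programmatically (a minimal sketch):

for node in epcat.graph.nodes:
    if node.op == "call_function" and "sym_size" in str(node.target):
        # the symbolic size of y is an explicit node in the graph
        print(node.format_node())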

And the final verification.

print(epcat.module()(*inputs2))
tensor([[1, 2, 3],
        [2, 3, 4],
        [3, 4, 5],
        [3, 4, 5]])

It finally works.
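As a last sanity check, we can compare the exported program against the eager model for several sizes of y (a quick sketch, sticking to sizes greater than 1 to avoid 0/1 specialization issues):

for n in (2, 3, 5):
    y = torch.zeros(n, dtype=torch.int64)
    assert torch.equal(epcat.module()(inputs[0], y), modelcat(inputs[0], y))
print("all sizes match")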

Total running time of the script: (0 minutes 4.679 seconds)

Related examples

to_onnx and submodules from LLMs

A few tricks about dynamic shapes

to_onnx and a model with a loop (scan)