

# A dynamic dimension lost by torch.export.export

Dynamic shapes ensure that a model stays valid whatever the value
of a dynamic dimension is.
:func:`torch.export.export` tries to keep track of that information
for every intermediate result the model produces.
But sometimes it fails. Let's look at one case.

## A dynamic dimension replaced by a constant in function pad

It could be any other function taking an integer as an argument.
Although this value may change with the input, the exporter
loses that information: it considers the value as an integer,
therefore as a constant.


In [None]:
import torch


def dummy_function(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)


class Model(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


model = Model()
inputs = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5], dtype=torch.int64),
)
print(model(*inputs))
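To make the intermediate values concrete, here is a sketch replaying the two steps of ``Model.forward`` by hand for ``inputs``, where ``y.shape[0]`` is 2. The names ``idx``, ``padded`` and ``result`` are only local to this illustration.

```python
import torch

idx = torch.tensor([1, 2, 3], dtype=torch.int64)

# Pad nothing on the left and one element on the right, filled with 2,
# the length of y in ``inputs``.
padded = torch.nn.functional.pad(idx, (0, 1), value=2)
print(padded)  # tensor([1, 2, 3, 2])

# Broadcast a (4, 1) column against a (1, padded.max()) row.
result = padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))
print(result.shape)  # torch.Size([4, 3])
```

The number of columns depends on the values of ``padded``, so the last dimension of the output is data-dependent, while the number of rows follows the length of ``x``.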

Let's export.



In [None]:
AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
    model, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)

Let's check it works.



In [None]:
print(ep.module()(*inputs))

Let's print the graph.



In [None]:
print(ep.graph)

It shows the following line
``[torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0)``
which corresponds to ``torch.nn.functional.pad(idx, (0, 1), value=x_len)``.
The padding value ``2.0`` is the length of ``y`` in ``inputs``, hard-coded
in the graph. But ``x_len`` is equal to ``y.shape[0]``, which was declared
as a dynamic dimension. So let's choose a ``y`` of a different length:



In [None]:
inputs2 = (
    (torch.arange(3) + 1).to(torch.int64),
    torch.tensor([0, 5, 6], dtype=torch.int64),
)

The original model works.



In [None]:
print(model(*inputs2))

But the exported program does not.



In [None]:
try:
    print(ep.module()(*inputs2))
except Exception as e:
    print(e)

## How to fix it?

In this particular case, ``pad`` is not the only function
able to produce the desired result: ``torch.cat`` can append the value too.
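In eager mode both formulations give the same tensor, as this sketch checks for the values of ``inputs``; only the way the exporter traces them differs.

```python
import torch

idx = torch.tensor([1, 2, 3], dtype=torch.int64)
x_len = 2

# Append x_len as the fill value of a one-element right padding.
via_pad = torch.nn.functional.pad(idx, (0, 1), value=x_len)
# Append x_len as the single element of a new tensor.
via_cat = torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)

print(via_pad, via_cat)
```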



In [None]:
def dummy_function_cat(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)


class ModelCat(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function_cat(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))


modelcat = ModelCat()
print(modelcat(*inputs))

Let's export.



In [None]:
epcat = torch.export.export(
    modelcat, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)

Let's check it works.



In [None]:
print(epcat.module()(*inputs))

Let's print the graph.



In [None]:
print(epcat.graph)

And the final verification.



In [None]:
print(epcat.module()(*inputs2))

It finally works.

