A dynamic dimension lost by torch.export.export
Dynamic shapes ensure a model remains valid whatever value a dynamic dimension takes. torch.export.export() tries to keep track of that information for every intermediate result the model produces, but sometimes it fails. Let's see one case.
A dynamic dimension replaced by a constant in function pad
It could be any other function taking an integer as an argument. Although this value may change from one input to another, the exporter loses that information: it treats the value as a plain integer, and therefore as a constant.
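To make the idea concrete, here is a deliberately simplified toy model. It is NOT how torch.export works internally, and the names ToySymInt and toy_pad_value are hypothetical. It only mimics the principle: arithmetic on a symbolic size stays symbolic, but converting it to a plain Python number freezes the example value and records an assumption (a guard) that the trace is only valid for that value.

```python
class ToySymInt:
    """Toy symbolic integer: an expression plus the value seen at trace time.

    Hypothetical illustration only, not PyTorch's SymInt implementation.
    """

    def __init__(self, expr: str, hint: int, guards: list):
        self.expr = expr      # symbolic expression, e.g. "s0" or "(s0 + 1)"
        self.hint = hint      # concrete value observed on the example input
        self.guards = guards  # shared list of assumptions recorded while tracing

    def __add__(self, other: int) -> "ToySymInt":
        # Arithmetic keeps the result symbolic.
        return ToySymInt(f"({self.expr} + {other})", self.hint + other, self.guards)

    def __float__(self) -> float:
        # Leaving the symbolic world: the trace is now only valid when the
        # expression equals the example value, so record that assumption.
        self.guards.append(f"{self.expr} == {self.hint}")
        return float(self.hint)


def toy_pad_value(x_len) -> float:
    # Stands in for a function whose argument must be a plain Python number.
    return float(x_len)


guards: list = []
n = ToySymInt("s0", hint=2, guards=guards)  # plays the role of y.shape[0]
still_symbolic = n + 1                      # stays symbolic: "(s0 + 1)"
frozen = toy_pad_value(n)                   # becomes the constant 2.0
print(still_symbolic.expr, frozen, guards)  # (s0 + 1) 2.0 ['s0 == 2']
```

In this toy, the recorded guard plays the same role as the shape check that the exported program performs on its inputs.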
import torch
def dummy_function(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)

class Model(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function(x, y.shape[0])
        # column vector (len(x) + 1, 1) + row vector (1, padded.max())
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))
model = Model()
inputs = (
(torch.arange(3) + 1).to(torch.int64),
torch.tensor([0, 5], dtype=torch.int64),
)
print(model(*inputs))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[2, 3, 4]])
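The result is a column vector plus a row vector, so its shape depends on the data: len(x) + 1 rows and padded.max() columns. The helper below (reference_forward, a hypothetical name introduced only for illustration) recomputes the same table with plain Python lists. It makes the broadcasting explicit and reproduces the tensor printed for these inputs.

```python
def reference_forward(x: list, y_len: int) -> list:
    """Pure-Python version of Model.forward for 1-D integer inputs.

    Entry (i, j) of the result is padded[i] + j for j in range(max(padded)).
    """
    padded = list(x) + [y_len]  # what pad(x, (0, 1), value=y_len) produces
    n_cols = max(padded)        # length of torch.arange(padded.max())
    return [[p + j for j in range(n_cols)] for p in padded]


print(reference_forward([1, 2, 3], 2))
# [[1, 2, 3], [2, 3, 4], [3, 4, 5], [2, 3, 4]]
print(reference_forward([1, 2, 3], 3))
# [[1, 2, 3], [2, 3, 4], [3, 4, 5], [3, 4, 5]]
```

The number of columns comes from a value computed inside the model, which is why the exported graphs contain an item node and a runtime assertion on u0.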
Let’s export.
AUTO = torch.export.Dim.AUTO
ep = torch.export.export(
model, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)
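The dynamic_shapes argument maps each argument name of forward to a dictionary {dimension index: Dim}. When every input should have a dynamic first dimension, that dictionary can be built from the signature. The sketch below is a hypothetical convenience helper (first_dim_dynamic is not a PyTorch function); it assumes forward only has named parameters and skips *args and **kwargs.

```python
import inspect


def first_dim_dynamic(forward, marker) -> dict:
    """Build {arg_name: {0: marker}} for every named parameter of `forward`.

    `forward` should be a bound method so that `self` is excluded.
    `marker` would typically be torch.export.Dim.AUTO.
    """
    params = inspect.signature(forward).parameters
    return {
        name: {0: marker}
        for name, p in params.items()
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }
```

With the model above, first_dim_dynamic(model.forward, AUTO) gives the same dictionary as the one written by hand: {"x": {0: AUTO}, "y": {0: AUTO}}.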
Let's check it works.
print(ep.module()(*inputs))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[2, 3, 4]])
Let’s print the graph.
print(ep.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=0] = placeholder[target=y]
%pad : [num_users=2] = call_function[target=torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0), kwargs = {})
%reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%pad, [-1, 1]), kwargs = {})
%max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%pad,), kwargs = {})
%item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
%ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
%arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
%reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
return (add,)
The graph contains the following node:
[torch.ops.aten.pad.default](args = (%x, [0, 1], constant, 2.0)
which corresponds to torch.nn.functional.pad(idx, (0, 1), value=x_len). The padding value was frozen to 2.0, the value of x_len for the example inputs, and %y is no longer used at all (num_users=0).
But x_len is equal to y.shape[0], which was declared as a dynamic dimension. So let's choose inputs where y has a different length:
inputs2 = (
(torch.arange(3) + 1).to(torch.int64),
torch.tensor([0, 5, 6], dtype=torch.int64),
)
The original model works.
print(model(*inputs2))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[3, 4, 5]])
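But the exported program does not. Calling ep.module()(*inputs2) fails with:
Expected input at *args[1].shape[0] to be equal to 2, but got 3
Since the padding value became a constant, the exporter specialized y.shape[0] to 2, the value it had on the example inputs.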
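Checking an exported program on a single input set is not enough to detect a lost dynamic dimension. A small sketch of a comparison loop follows; compare_on_inputs is a hypothetical helper, not part of torch.export. It runs a reference callable and a candidate on several input tuples and reports, instead of propagating, any exception raised by the candidate.

```python
from typing import Any, Callable, Iterable, List, Tuple


def compare_on_inputs(
    reference: Callable[..., Any],
    candidate: Callable[..., Any],
    input_sets: Iterable[Tuple[Any, ...]],
    equal: Callable[[Any, Any], bool] = lambda a, b: a == b,
) -> List[Tuple[bool, str]]:
    """Return one (ok, message) pair per input tuple."""
    report = []
    for args in input_sets:
        expected = reference(*args)
        try:
            got = candidate(*args)
        except Exception as exc:  # e.g. a shape check in an exported program
            report.append((False, f"{type(exc).__name__}: {exc}"))
            continue
        ok = bool(equal(expected, got))
        report.append((ok, "ok" if ok else "mismatch"))
    return report
```

For tensors, the default == comparison would return a tensor, so pass equal=torch.equal, for instance compare_on_inputs(model, ep.module(), [inputs, inputs2], equal=torch.equal). Based on the outputs shown in this example, one would expect the first entry to pass and the second to report the shape error.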
How to fix it?
In this particular case, pad is not the only way to produce the desired result: the same vector can be built with torch.cat by concatenating idx with a one-element tensor holding x_len.
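As an eager-mode sanity check (it says nothing about export by itself), the sketch below compares both ways of appending x_len for a few lengths. The names pad_version and cat_version are hypothetical. The design difference is that pad receives x_len as a Python scalar argument, while the cat variant turns it into a tensor element.

```python
import torch


def pad_version(idx: torch.Tensor, x_len: int) -> torch.Tensor:
    # Appends x_len through the scalar `value` argument of pad.
    return torch.nn.functional.pad(idx, (0, 1), value=x_len)


def cat_version(idx: torch.Tensor, x_len: int) -> torch.Tensor:
    # Appends x_len as a one-element tensor concatenated at the end.
    return torch.cat([idx, torch.tensor([x_len], dtype=idx.dtype)], dim=0)


x = torch.arange(1, 4, dtype=torch.int64)  # tensor([1, 2, 3])
for n in (2, 3, 7):
    a, b = pad_version(x, n), cat_version(x, n)
    assert torch.equal(a, b), (a, b)
print(cat_version(x, 5))  # tensor([1, 2, 3, 5])
```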
def dummy_function_cat(idx, x_len):
    # [1, 2, 3] becomes [1, 2, 3, x_len]
    return torch.cat([idx, torch.tensor([x_len], dtype=torch.int64)], dim=0)

class ModelCat(torch.nn.Module):
    def forward(self, x, y):
        padded = dummy_function_cat(x, y.shape[0])
        return padded.reshape((-1, 1)) + torch.arange(padded.max()).reshape((1, -1))
modelcat = ModelCat()
print(modelcat(*inputs))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[2, 3, 4]])
Let’s export.
epcat = torch.export.export(
modelcat, inputs, dynamic_shapes={"x": {0: AUTO}, "y": {0: AUTO}}, strict=False
)
Let’s check it works.
print(epcat.module()(*inputs))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[2, 3, 4]])
Let’s print the graph.
print(epcat.graph)
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%sym_size_int_2 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%y, 0), kwargs = {})
%scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%sym_size_int_2,), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
%stack : [num_users=1] = call_function[target=torch.ops.aten.stack.default](args = ([%scalar_tensor],), kwargs = {})
%to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%stack, cpu, torch.int64), kwargs = {})
%detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%to,), kwargs = {})
%cat : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%x, %detach_],), kwargs = {})
%reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%cat, [-1, 1]), kwargs = {})
%max_1 : [num_users=1] = call_function[target=torch.ops.aten.max.default](args = (%cat,), kwargs = {})
%item : [num_users=3] = call_function[target=torch.ops.aten.item.default](args = (%max_1,), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%item,), kwargs = {})
%ge : [num_users=1] = call_function[target=operator.ge](args = (%item, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge, Runtime assertion failed for expression u0 >= 0 on node 'ge'), kwargs = {})
%arange : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (%item,), kwargs = {device: cpu, pin_memory: False})
%reshape_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%arange, [1, -1]), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%reshape, %reshape_1), kwargs = {})
return (add,)
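A quick way to spot this kind of problem is to look for graph inputs that nothing consumes: in the first graph %y has num_users=0, while here it feeds sym_size. The sketch below uses the torch.fx Graph API (node.op, node.users); unused_placeholders is a hypothetical helper name, and it is demonstrated on a small function traced with torch.fx.symbolic_trace.

```python
import torch
import torch.fx


def unused_placeholders(graph: torch.fx.Graph) -> list:
    """Names of graph inputs that no other node consumes."""
    return [
        node.name
        for node in graph.nodes
        if node.op == "placeholder" and len(node.users) == 0
    ]


def f(x, y):
    # y is accepted but never used, like %y in the graph of the pad version.
    return x + 1


gm = torch.fx.symbolic_trace(f)
print(unused_placeholders(gm.graph))  # ['y']
```

Judging from the two graphs printed in this example, unused_placeholders(ep.graph) should report y, while unused_placeholders(epcat.graph) should return an empty list.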
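This time %y is used: aten.sym_size.int reads y.shape[0] and aten.scalar_tensor turns it into the appended element, so the dependency on the dynamic dimension survives export. Now the final verification.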
print(epcat.module()(*inputs2))
tensor([[1, 2, 3],
[2, 3, 4],
[3, 4, 5],
[3, 4, 5]])
It finally works.