Dynamic Shapes and Broadcasting¶
torch.export.export() makes stricter assumptions on dynamic shapes than the generic case allows. Let’s consider two tensors with only one dimension each.
x * y allows four configurations:
- shape(x) = (1,) and shape(y) = (1,)
- shape(x) = (1,) and shape(y) = (p,)
- shape(x) = (q,) and shape(y) = (1,)
- shape(x) = (p,) and shape(y) = (p,)
The expected shape of x * y is (max(p, q),).
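At runtime, PyTorch broadcasting already follows this rule. A minimal check (not part of the original script) illustrates it:

import torch

x = torch.tensor([1.0, 2.0, 3.0])  # shape (3,)
y = torch.tensor([10.0])           # shape (1,)
# broadcasting returns shape (max(3, 1),) = (3,)
print((x * y).shape)  # torch.Size([3])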
Simple Case¶
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from onnx_diagnostic.torch_export_patches import torch_export_patches
from torch.fx import Tracer
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y
Dim = torch.export.Dim
ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17]", y: "f32[s17]"):
            # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[s17]" = torch.ops.aten.mul.Tensor(x, y); x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s17: VR[2, int_oo]}
We see clearly that the export assumed that x and y had the same shape.
No other configuration seemed to work at export time, not even the one where one tensor has shape (1,),
including with torch.fx.experimental._config.patch(backed_size_oblivious=True).
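As an illustration of that limitation (this snippet is not part of the original script), exporting with inputs of different lengths typically fails or specializes a dimension when the patches introduced later are not applied:

# illustration only: export with two different lengths and both dimensions dynamic
try:
    torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
except Exception as e:
    print("export failed:", e)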
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
output is output arg is (mul,)
The final shape is:
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output shape is torch.Size([s17])
Tracing¶
Let’s compare this with what a simple tracer would do, using torch.fx.Tracer.
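The graph shown below can be obtained with a call along these lines (a minimal sketch, the original script may differ slightly):

# trace the model symbolically; Tracer was imported from torch.fx above
graph = Tracer().trace(Model())
print(graph)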
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %y), kwargs = {})
    return mul
output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no trace:", output.args[0].__dict__)
output is output arg is mul
The tracer leaves no trace: {}
Shape propagation¶
gm = torch.fx.GraphModule(Model(), graph)
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
    torch.zeros((3,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are (FakeTensor(..., size=(s26,)), FakeTensor(..., size=(s26,)))
output is FakeTensor(..., size=(s26,))
Handle Different Shapes¶
fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(1,)))
output is FakeTensor(..., size=(s53,))
Conclusion¶
We need to give the inputs distinct dimension sizes so that the fake tensors get distinct symbolic names.
fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)
fake_inputs are (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(s26,)))
try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print("error", e)
error The size of tensor a (s53) must match the size of tensor b (s26) at non-singleton dimension 0)
While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %y), kwargs = {})
Original traceback:
None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
By applying the patches:
with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
output is FakeTensor(..., size=(Max(s26, s53),))
This is what we want. Let’s go back to torch.export.export().
with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s77]", y: "f32[s17]"):
            # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[Max(s17, s77)]" = torch.ops.aten.mul.Tensor(x, y); x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s77: VR[2, int_oo], s17: VR[2, int_oo]}
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output is output arg is (mul,)
output shape is torch.Size([Max(s17, s77)])
Total running time of the script: (0 minutes 0.482 seconds)