Dynamic Shapes and Broadcasting

torch.export.export() makes stricter assumptions on dynamic shapes than the generic case requires. Let’s consider two tensors with a single dimension each. x * y allows four configurations:

  • shape(x) = (1,) and shape(y) = (1,)

  • shape(x) = (1,) and shape(y) = (p,)

  • shape(x) = (q,) and shape(y) = (1,)

  • shape(x) = (p,) and shape(y) = (p,)

In every configuration, the expected value of shape(x * y) is (max(p, q),).
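As a quick sanity check, eager mode already implements this rule; the snippet below is a small sketch, with 3 and 5 standing in for p and q, printing the broadcast shape for the four configurations.

import torch

for sx, sy in [((1,), (1,)), ((1,), (5,)), ((3,), (1,)), ((3,), (3,))]:
    x, y = torch.zeros(sx), torch.zeros(sy)
    # Eager broadcasting: the result has shape (max(p, q),).
    print(sx, "*", sy, "->", tuple((x * y).shape))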

Simple Case

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from onnx_diagnostic.torch_export_patches import torch_export_patches
from torch.fx import Tracer


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y


Dim = torch.export.Dim

ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17]", y: "f32[s17]"):
             # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[s17]" = torch.ops.aten.mul.Tensor(x, y);  x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s17: VR[2, int_oo]}

We see clearly that the export assumed x and y have the same shape. No other configuration seemed to work at export time, not even the case where one tensor has shape (1,), including with torch.fx.experimental._config.patch(backed_size_oblivious=True).
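For reference, here is a sketch of that kind of attempt: exporting with one input of shape (1,) under the backed_size_oblivious patch. The exact failure (or specialization) depends on the torch version, so the call is simply wrapped in a try/except.

from torch.fx.experimental import _config

with _config.patch(backed_size_oblivious=True):
    try:
        # One input has shape (2,), the other (1,); both dimensions are marked dynamic.
        ep_try = torch.export.export(
            Model(),
            (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([4.0])),
            dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
        )
        print(ep_try)
    except Exception as e:
        print("export failed:", e)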

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
output is  output  arg is (mul,)

The final shape is:

shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output shape is  torch.Size([s17])
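The same symbol is attached to both inputs, which can be confirmed by inspecting the placeholders and the recorded range constraints; a small sketch reusing the meta["val"] attribute from above and ExportedProgram.range_constraints.

print("range constraints:", ep.range_constraints)
for n in ep.graph.nodes:
    if n.op == "placeholder":
        # Both placeholders carry the same symbolic dimension (s17 above).
        print(n.name, "has shape", n.meta["val"].shape)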

Tracing

Let’s compare with what a simple tracing would do, using torch.fx.Tracer.

graph = Tracer().trace(Model())
print(graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %y), kwargs = {})
    return mul
output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no trace:", output.args[0].__dict__)
output is  output  arg is mul
The tracer leaves no trace: {}
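The same holds for every node of the traced graph: none of them carries a "val" entry with a fake tensor yet, which is what the shape propagation in the next section adds. A quick sketch:

for n in graph.nodes:
    # Before any propagation, node.meta contains no "val" entry.
    print(n.op, n.name, "meta:", dict(n.meta))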

Shape propagation

gm = torch.fx.GraphModule(Model(), graph)

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
    torch.zeros((3,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are  (FakeTensor(..., size=(s26,)), FakeTensor(..., size=(s26,)))
output is FakeTensor(..., size=(s26,))
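Besides the returned value, FakeTensorProp also stores the propagated fake tensors on the nodes themselves (in node.meta["val"]), so the intermediate shapes can be read back from the graph; a short sketch:

for n in gm.graph.nodes:
    # The propagated FakeTensor, if any, is stored under meta["val"].
    print(n.op, n.name, "->", n.meta.get("val"))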

Handle Different Shapes

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are  (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(1,)))
output is FakeTensor(..., size=(s53,))
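Note that the dimension equal to 1 stays concrete in the fake inputs above: sizes 0 and 1 are specialized by default rather than given a symbol, which is why broadcasting against (s53,) goes through without creating any new expression.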

Conclusion

We need to give the two inputs distinct dimension values so that they receive distinct symbolic names.

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)
fake_inputs are  (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(s26,)))
try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print("error", e)
error The size of tensor a (s53) must match the size of tensor b (s26) at non-singleton dimension 0)

While executing %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %y), kwargs = {})
Original traceback:
None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

By applying the patches:

with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
    print("output is", res)
output is FakeTensor(..., size=(Max(s26, s53),))

This is what we want. Let’s go back to torch.export.export().

with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s77]", y: "f32[s17]"):
             # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[Max(s17, s77)]" = torch.ops.aten.mul.Tensor(x, y);  x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s77: VR[2, int_oo], s17: VR[2, int_oo]}
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output is  output  arg is (mul,)
output shape is  torch.Size([Max(s17, s77)])
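As a final check, the exported module can be run on real tensors through ExportedProgram.module(). This is only a sketch: the recorded range constraints ([2, int_oo] for both symbols) may reject an input with a dimension of 1, so the broadcasting call is wrapped in a try/except.

m = ep.module()
# Same length on both sides.
print(m(torch.zeros((4,)), torch.zeros((4,))).shape)
# Broadcasting against a dimension of 1 may be rejected by the input
# constraints recorded at export time.
try:
    print(m(torch.zeros((4,)), torch.zeros((1,))).shape)
except Exception as e:
    print("rejected:", e)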

Total running time of the script: (0 minutes 0.482 seconds)
