Dynamic Shapes and Broadcasting

torch.export.export() makes strict assumptions on dynamic shapes and does not capture the generic case. Let’s consider two tensors with a single dimension each. Broadcasting allows x * y in four configurations:

  • shape(x) = (1,) and shape(y) = (1,)

  • shape(x) = (1,) and shape(y) = (p,)

  • shape(x) = (q,) and shape(y) = (1,)

  • shape(x) = (p,) and shape(y) = (p,)

In the generic case, the expected shape of x * y is (max(p, q),).
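
A quick eager-mode check confirms this behaviour; the sizes 3 and 4 below are arbitrary values chosen for illustration.

import torch

# Eager mode applies broadcasting rules: a dimension of size 1 is
# stretched to match the other operand.
print((torch.zeros(1) * torch.zeros(1)).shape)  # torch.Size([1])
print((torch.zeros(1) * torch.zeros(3)).shape)  # torch.Size([3])
print((torch.zeros(4) * torch.zeros(1)).shape)  # torch.Size([4])
print((torch.zeros(3) * torch.zeros(3)).shape)  # torch.Size([3])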

Simple Case

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from onnx_diagnostic.torch_export_patches import torch_export_patches
from torch.fx import Tracer


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y


Dim = torch.export.Dim

ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s17]", y: "f32[s17]"):
            # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[s17]" = torch.ops.aten.mul.Tensor(x, y);  x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s17: VR[2, int_oo]}

We see clearly that the export assumed x and y have the same shape: both dimensions are mapped to the same symbol s17. No other configuration seemed to work at export time, including the case where one of the tensors has shape (1,), even with torch.fx.experimental._config.patch(backed_size_oblivious=True).
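
The snippet below is a minimal sketch of such an attempt: it feeds inputs of shapes (2,) and (1,) under the backed_size_oblivious patch and catches the failure; the exact behaviour and error message depend on the torch version.

import torch.fx.experimental._config as fx_config

try:
    with fx_config.patch(backed_size_oblivious=True):
        ep_attempt = torch.export.export(
            Model(),
            (
                torch.tensor([2, 3], dtype=torch.float32),  # shape (2,)
                torch.tensor([2], dtype=torch.float32),  # shape (1,)
            ),
            dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
        )
    print("export succeeded:", ep_attempt)
except Exception as e:
    print("export failed:", e)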

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
output is  output  arg is (mul,)

The final shape is:

shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output shape is  torch.Size([s17])
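
The placeholders confirm that both inputs received the same symbolic dimension; their metadata can be read the same way as the output's.

placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"]
print([n.meta["val"].shape for n in placeholders])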

Tracing

Let’s compare with what a simple tracing does, using torch.fx.Tracer.

graph = Tracer().trace(Model())
print(graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%x, %y), kwargs = {})
    return mul

output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no trace:", output.args[0].__dict__)
output is  output  arg is mul
The tracer leaves no trace: {}

Shape propagation
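
Before turning to fake tensors, note that a traced graph can also be annotated with shapes computed from concrete example inputs, for instance with torch.fx.passes.shape_prop.ShapeProp. The sketch below only records fixed sizes, which is why fake tensors are used next to keep the shapes symbolic.

from torch.fx.passes.shape_prop import ShapeProp

gm_concrete = torch.fx.GraphModule(Model(), graph)
ShapeProp(gm_concrete).propagate(torch.zeros(3), torch.zeros(3))
for node in gm_concrete.graph.nodes:
    if "tensor_meta" in node.meta:
        # Concrete propagation stores fixed sizes such as torch.Size([3]).
        print(node.name, node.meta["tensor_meta"].shape)

Now let’s propagate with FakeTensorMode so that the shapes stay symbolic.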

gm = torch.fx.GraphModule(Model(), graph)

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
    torch.zeros((3,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are  (FakeTensor(..., size=(s26,)), FakeTensor(..., size=(s26,)))
output is FakeTensor(..., size=(s26,))
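
FakeTensorProp also attaches the propagated fake tensors to the graph nodes (in node.meta["val"]), so the symbolic shapes can be read back from the graph metadata, in the same way as for the exported program above.

for node in gm.graph.nodes:
    if "val" in node.meta:
        # Each node now carries a fake tensor with a symbolic shape.
        print(node.name, "->", node.meta["val"].shape)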

Handle Different Shapes

What happens when one of the tensors has a dimension equal to 1?

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
fake_inputs are  (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(1,)))
output is FakeTensor(..., size=(s53,))

The dimension equal to 1 stays a constant (no symbol is created for it), so broadcasting resolves and the output reuses the symbolic dimension s53 of the first input.

Conclusion

To get distinct symbolic names, we need to give the inputs distinct dimension values (both greater than 1).

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)
fake_inputs are  (FakeTensor(..., size=(s53,)), FakeTensor(..., size=(s26,)))

try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print("error", e)

The propagation fails: the two distinct symbolic dimensions are treated as incompatible for broadcasting. By applying the patches:

with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
    print("output is", res)
output is FakeTensor(..., size=(s53,))

This is what we want. Let’s go back to torch.export.export() with the same patches applied.

with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s77]", y: "f32[s17]"):
            # File: ~/github/onnx-diagnostic/_doc/technical/plot_broadcast_export_issue.py:31 in forward, code: return x * y
            mul: "f32[Max(s17, s77)]" = torch.ops.aten.mul.Tensor(x, y);  x = y = None
            return (mul,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    mul: USER_OUTPUT

Range constraints: {s77: VR[2, int_oo], s17: VR[2, int_oo]}

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
output is  output  arg is (mul,)
output shape is  torch.Size([Max(s17, s77)])

