to_onnx and a model with a loop (scan)

Control flow cannot be exported without a change. The code of the model can be changed or patched to introduce the function torch.ops.higher_order.scan().

Pairwise Distance

We apply loops to the pairwise distances (torch.nn.PairwiseDistance).

import scipy.spatial.distance as spd
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx


class ModuleWithControlFlowLoop(torch.nn.Module):
    """Pairwise Euclidean distances computed with a Python ``for`` loop.

    The loop iterates over the rows of ``x``, so the number of iterations
    depends on ``x.shape[0]``.
    """

    def forward(self, x, y):
        """Return the matrix of distances between rows of ``x`` and ``y``.

        :param x: 2D tensor, one point per row
        :param y: 2D tensor with the same number of columns as ``x``
        :return: tensor of shape ``(x.shape[0], y.shape[0])`` with the
            dtype and device of ``x``
        """
        # Allocate the result next to the inputs: without ``device=`` the
        # buffer is created on the default device, whatever ``x.device`` is.
        dist = torch.empty((x.shape[0], y.shape[0]), dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            # x[i : i + 1] keeps a leading axis of size 1 so that it
            # broadcasts against every row of y.
            sub = y - x[i : i + 1]
            d = torch.sqrt((sub * sub).sum(axis=1))
            dist[i, :] = d
        return dist


# Instantiate the loop-based module and build two small random inputs
# sharing the same number of columns (4).
model = ModuleWithControlFlowLoop()
x = torch.randn(3, 4)
y = torch.randn(5, 4)
# Reference distances computed by scipy; they are reused later to check
# the scan-based module as well.
pwd = spd.cdist(x.numpy(), y.numpy())
expected = torch.from_numpy(pwd)
# Largest absolute difference between the reference and the module output.
print(f"shape={pwd.shape}, discrepancies={torch.abs(expected - model(x,y)).max()}")
shape=(3, 5), discrepancies=1.857952760531134e-07

torch.export.export() works because it unrolls the loop. It works if the input size never changes.

ep = torch.export.export(model, (x, y))
print(ep.graph)
graph():
    %x : [num_users=3] = placeholder[target=x]
    %y : [num_users=3] = placeholder[target=y]
    %empty : [num_users=3] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 5],), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
    %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1]), kwargs = {})
    %sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_1,), kwargs = {})
    %select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%empty, 0, 0), kwargs = {})
    %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select, 0, 0, 9223372036854775807), kwargs = {})
    %copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %sqrt), kwargs = {})
    %select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%empty, 0, 0), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_1, %copy, 0, 0, 9223372036854775807), kwargs = {})
    %select_scatter : [num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%empty, %slice_scatter, 0, 0), kwargs = {})
    %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 1, 2), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_4), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %sub_1), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_1, [1]), kwargs = {})
    %sqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_2,), kwargs = {})
    %select_4 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter, 0, 1), kwargs = {})
    %slice_6 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_4, 0, 0, 9223372036854775807), kwargs = {})
    %copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_6, %sqrt_1), kwargs = {})
    %select_5 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter, 0, 1), kwargs = {})
    %slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_5, %copy_1, 0, 0, 9223372036854775807), kwargs = {})
    %select_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter, %slice_scatter_1, 0, 1), kwargs = {})
    %slice_8 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 2, 3), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_8), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %sub_2), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_2, [1]), kwargs = {})
    %sqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_3,), kwargs = {})
    %select_8 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_1, 0, 2), kwargs = {})
    %slice_10 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_8, 0, 0, 9223372036854775807), kwargs = {})
    %copy_2 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_10, %sqrt_2), kwargs = {})
    %select_9 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_1, 0, 2), kwargs = {})
    %slice_scatter_2 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_9, %copy_2, 0, 0, 9223372036854775807), kwargs = {})
    %select_scatter_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_1, %slice_scatter_2, 0, 2), kwargs = {})
    return (select_scatter_2,)

However, with dynamic shapes, that’s another story.

# One symbolic dimension per varying axis. The same ``dim`` object is used
# for axis 1 of both inputs, so both are declared to share that size.
x_rows = torch.export.Dim("x_rows")
y_rows = torch.export.Dim("y_rows")
dim = torch.export.Dim("dim")
# The export is attempted inside try/except so that the error message is
# printed and the script keeps running.
try:
    ep = torch.export.export(
        model, (x, y), dynamic_shapes={"x": {0: x_rows, 1: dim}, "y": {0: y_rows, 1: dim}}
    )
    print(ep.graph)
except Exception as e:
    print(e)
Constraints violated (x_rows)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of x_rows = L['x'].size()[0] in the specified range are valid because x_rows was inferred to be a constant (3).

Suggested fixes:
  x_rows = 3

Suggested Patch

We need to rewrite the module with function torch.ops.higher_order.scan().

def dist(y: torch.Tensor, scanned_x: torch.Tensor):
    """Compute the distances between one scanned row and every row of ``y``.

    :param y: tensor carried from one step to the next, one point per row
    :param scanned_x: the slice of the scanned tensor for the current step
    :return: a two-element list, a copy of ``y`` followed by the
        distances between ``scanned_x`` and each row of ``y``
    """
    # Give the scanned slice a leading axis of size 1 so it broadcasts
    # against the rows of y.
    row = scanned_x.reshape((1, -1))
    diff = y - row
    distances = torch.sqrt((diff * diff).sum(axis=1))
    # The carry must be a copy: returning ``y`` itself raises
    # UnsupportedAliasMutationException
    # ("Combine_fn might be aliasing the input!").
    carried = y.clone()
    return [carried, distances]


class ModuleWithControlFlowLoopScan(torch.nn.Module):
    """Pairwise-distance module written with ``torch.ops.higher_order.scan``
    instead of a Python ``for`` loop."""

    def forward(self, x, y):
        # ``dist`` is applied at every step; ``[y]`` is passed as the initial
        # state and ``[x]`` as the scanned input, iterated along ``dim=0``
        # in forward order (``reverse=False``).
        carry, out = torch.ops.higher_order.scan(
            dist,
            [y],
            [x],
            dim=0,
            reverse=False,
            additional_inputs=[],
        )
        # Only the second result is returned; ``carry`` is not used.
        return out


# Rebind ``model`` to the scan-based module and compare it with the same
# scipy reference (``pwd`` / ``expected``) computed earlier.
model = ModuleWithControlFlowLoopScan()
print(f"shape={pwd.shape}, discrepancies={torch.abs(expected - model(x,y)).max()}")
shape=(3, 5), discrepancies=1.857952760531134e-07

That works. Let’s export again.

# Same dynamic-shape specification as the earlier attempt, this time
# applied to the scan-based module.
ep = torch.export.export(
    model, (x, y), dynamic_shapes={"x": {0: x_rows, 1: dim}, "y": {0: y_rows, 1: dim}}
)
print(ep.graph)
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
    %scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
    return (getitem_1,)

Let’s export again with ONNX.

# Convert the scan-based module to ONNX with the same dynamic shapes and
# print a text rendering of the resulting model.
onx = to_onnx(
    model, (x, y), dynamic_shapes={"x": {0: x_rows, 1: dim}, "y": {0: y_rows, 1: dim}}
)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=['x_rows', 'dim']
input: name='y' type=dtype('float32') shape=['y_rows', 'dim']
Scan(y, x, body=G1, num_scan_inputs=1, scan_input_directions=[0], scan_output_axes=[0], scan_output_directions=[0]) -> scan#0, output_0
output: name='output_0' type=dtype('float32') shape=['x_rows', 'y_rows']
----- subgraph ---- Scan - aten_scan - att.body=G1 -- level=1 -- init_0_y,scan_0_x -> output_0,output_1
input: name='init_0_y' type='NOTENSOR' shape=None
input: name='scan_0_x' type='NOTENSOR' shape=None
scan_combine_graph_0[local_functions](init_0_y, scan_0_x) -> output_0, output_1
output: name='output_0' type='NOTENSOR' shape=None
output: name='output_1' type='NOTENSOR' shape=None
----- function name=scan_combine_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'arg0_1'
input: 'arg1_1'
Constant(value=[1, -1]) -> init7_s2_1_-1
  Reshape(arg1_1, init7_s2_1_-1) -> view
    Sub(arg0_1, view) -> sub_4
      Mul(sub_4, sub_4) -> mul_7
Constant(value=[1]) -> init7_s1_1
  ReduceSum(mul_7, init7_s1_1, keepdims=0) -> sum_1
    Sqrt(sum_1) -> output_1
Identity(arg0_1) -> output_0
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?

We can also inline the local function.

# Same conversion with ``inline=True``: the body of the local function
# appears directly inside the Scan subgraph in the printed model.
onx = to_onnx(
    model,
    (x, y),
    dynamic_shapes={"x": {0: x_rows, 1: dim}, "y": {0: y_rows, 1: dim}},
    inline=True,
)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=True, external_threshold=1024...
input: name='x' type=dtype('float32') shape=['x_rows', 'dim']
input: name='y' type=dtype('float32') shape=['y_rows', 'dim']
Scan(y, x, body=G1, num_scan_inputs=1, scan_input_directions=[0], scan_output_axes=[0], scan_output_directions=[0]) -> scan#0, output_0
output: name='output_0' type=dtype('float32') shape=['x_rows', 'y_rows']
----- subgraph ---- Scan - aten_scan - att.body=G1 -- level=1 -- init_0_y,scan_0_x -> output_0,output_1
input: name='init_0_y' type='NOTENSOR' shape=None
input: name='scan_0_x' type='NOTENSOR' shape=None
Constant(value=[1]) -> init7_s1_12
Constant(value=[1, -1]) -> init7_s2_1_-12
  Reshape(scan_0_x, init7_s2_1_-12) -> view2
    Sub(init_0_y, view2) -> sub_42
      Mul(sub_42, sub_42) -> mul_72
  ReduceSum(mul_72, init7_s1_12, keepdims=0) -> sum_12
    Sqrt(sum_12) -> output_1
Identity(init_0_y) -> output_0
output: name='output_0' type='NOTENSOR' shape=None
output: name='output_1' type='NOTENSOR' shape=None

And visually.

plot exporter recipes c scan pdist
<Axes: >

Total running time of the script: (0 minutes 2.354 seconds)

Gallery generated by Sphinx-Gallery