Note
Go to the end to download the full example code.
to_onnx and a model with a loop (scan)
Control flow cannot be exported without a change.
The code of the model can be changed or patched
to introduce the function torch.ops.higher_order.scan().
Pairwise Distance
We apply loops to compute pairwise distances (see torch.nn.PairwiseDistance).
import scipy.spatial.distance as spd
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
class ModuleWithControlFlowLoop(torch.nn.Module):
    """Pairwise Euclidean distances between the rows of ``x`` and ``y``.

    The rows of ``x`` are processed one at a time by a Python ``for`` loop.
    ``torch.export.export`` unrolls that loop, so the number of rows of ``x``
    ends up baked into the exported graph.
    """

    def forward(self, x, y):
        """Return the matrix of distances between rows of ``x`` and ``y``.

        :param x: 2D tensor, one point per row
        :param y: 2D tensor with the same number of columns as ``x``
        :return: tensor of shape ``(x.shape[0], y.shape[0])`` where entry
            ``[i, j]`` is the Euclidean distance between ``x[i]`` and ``y[j]``
        """
        # Allocate on the inputs' device: without ``device=``, torch.empty
        # uses the default device, and the result would not follow the inputs.
        out = torch.empty((x.shape[0], y.shape[0]), dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            # x[i : i + 1] keeps a leading dimension of size 1 so that it
            # broadcasts against every row of y.
            sub = y - x[i : i + 1]
            out[i, :] = torch.sqrt((sub * sub).sum(axis=1))
        return out
# Numerical check of the loop-based module against scipy's cdist
# (default metric: Euclidean).
model = ModuleWithControlFlowLoop()
x, y = torch.randn(3, 4), torch.randn(5, 4)
pwd = spd.cdist(x.numpy(), y.numpy())
expected = torch.from_numpy(pwd)
print(f"shape={pwd.shape}, discrepancies={(expected - model(x, y)).abs().max()}")
shape=(3, 5), discrepancies=1.857952760531134e-07
torch.export.export() works because it unrolls the loop.
This only works as long as the input sizes never change.
# Export with static shapes: the Python loop gets unrolled, one copy of the
# loop body per row of x.
ep = torch.export.export(model, args=(x, y))
print(ep.graph)
graph():
%x : [num_users=3] = placeholder[target=x]
%y : [num_users=3] = placeholder[target=y]
%empty : [num_users=3] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 5],), kwargs = {dtype: torch.float32, device: cpu, pin_memory: False})
%slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 0, 1), kwargs = {})
%sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub, %sub), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul, [1]), kwargs = {})
%sqrt : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_1,), kwargs = {})
%select : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%empty, 0, 0), kwargs = {})
%slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select, 0, 0, 9223372036854775807), kwargs = {})
%copy : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_2, %sqrt), kwargs = {})
%select_1 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%empty, 0, 0), kwargs = {})
%slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_1, %copy, 0, 0, 9223372036854775807), kwargs = {})
%select_scatter : [num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%empty, %slice_scatter, 0, 0), kwargs = {})
%slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 1, 2), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_4), kwargs = {})
%mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_1, %sub_1), kwargs = {})
%sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_1, [1]), kwargs = {})
%sqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_2,), kwargs = {})
%select_4 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter, 0, 1), kwargs = {})
%slice_6 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_4, 0, 0, 9223372036854775807), kwargs = {})
%copy_1 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_6, %sqrt_1), kwargs = {})
%select_5 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter, 0, 1), kwargs = {})
%slice_scatter_1 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_5, %copy_1, 0, 0, 9223372036854775807), kwargs = {})
%select_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter, %slice_scatter_1, 0, 1), kwargs = {})
%slice_8 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%x, 0, 2, 3), kwargs = {})
%sub_2 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %slice_8), kwargs = {})
%mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_2, %sub_2), kwargs = {})
%sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_2, [1]), kwargs = {})
%sqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.sqrt.default](args = (%sum_3,), kwargs = {})
%select_8 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_1, 0, 2), kwargs = {})
%slice_10 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%select_8, 0, 0, 9223372036854775807), kwargs = {})
%copy_2 : [num_users=1] = call_function[target=torch.ops.aten.copy.default](args = (%slice_10, %sqrt_2), kwargs = {})
%select_9 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%select_scatter_1, 0, 2), kwargs = {})
%slice_scatter_2 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%select_9, %copy_2, 0, 0, 9223372036854775807), kwargs = {})
%select_scatter_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%select_scatter_1, %slice_scatter_2, 0, 2), kwargs = {})
return (select_scatter_2,)
However, with dynamic shapes, that’s another story.
Constraints violated (x_rows)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of x_rows = L['x'].size()[0] in the specified range are valid because x_rows was inferred to be a constant (3).
Suggested fixes:
x_rows = 3
Suggested Patch
We need to rewrite the module with the function
torch.ops.higher_order.scan().
def dist(y: torch.Tensor, scanned_x: torch.Tensor):
    """Scan body: distances between one row of ``x`` and every row of ``y``.

    :param y: carried state, passed through to the next iteration
    :param scanned_x: the current slice of ``x`` along the scan dimension
    :return: ``[carry, distances]`` where ``carry`` is a copy of ``y`` and
        ``distances`` holds one value per row of ``y``
    """
    row = scanned_x.reshape((1, -1))
    delta = y - row
    distances = torch.sqrt((delta * delta).sum(axis=1))
    # The carry must be a fresh tensor: returning ``y`` itself raises
    # UnsupportedAliasMutationException
    # ("Combine_fn might be aliasing the input!").
    return [y.clone(), distances]
class ModuleWithControlFlowLoopScan(torch.nn.Module):
    """Pairwise distances computed with ``torch.ops.higher_order.scan``.

    Instead of a Python loop, the iteration over the rows of ``x`` is
    expressed as a single scan node, which the exporter keeps as one node
    rather than unrolling it.
    """

    def forward(self, x, y):
        """Scan ``dist`` over the rows of ``x`` and return the stacked outputs.

        ``y`` is the scan carry (``dist`` returns a copy of it each step);
        ``x`` is scanned along ``dim=0``. The final carry is discarded.
        """
        _final_carry, stacked = torch.ops.higher_order.scan(
            dist,
            [y],  # initial carried state
            [x],  # tensors scanned along ``dim``
            dim=0,
            reverse=False,
            additional_inputs=[],
        )
        return stacked
# Same numerical check as for the loop version, now with the scan-based module.
model = ModuleWithControlFlowLoopScan()
print(f"shape={pwd.shape}, discrepancies={(expected - model(x, y)).abs().max()}")
shape=(3, 5), discrepancies=1.857952760531134e-07
That works. Let’s export again.
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%scan_combine_graph_0 : [num_users=1] = get_attr[target=scan_combine_graph_0]
%scan : [num_users=1] = call_function[target=torch.ops.higher_order.scan](args = (%scan_combine_graph_0, [%y], [%x], 0, False, []), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%scan, 1), kwargs = {})
return (getitem_1,)
Let’s export again with ONNX.
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=['x_rows', 'dim']
input: name='y' type=dtype('float32') shape=['y_rows', 'dim']
Scan(y, x, body=G1, num_scan_inputs=1, scan_input_directions=[0], scan_output_axes=[0], scan_output_directions=[0]) -> scan#0, output_0
output: name='output_0' type=dtype('float32') shape=['x_rows', 'y_rows']
----- subgraph ---- Scan - aten_scan - att.body=G1 -- level=1 -- init_0_y,scan_0_x -> output_0,output_1
input: name='init_0_y' type='NOTENSOR' shape=None
input: name='scan_0_x' type='NOTENSOR' shape=None
scan_combine_graph_0[local_functions](init_0_y, scan_0_x) -> output_0, output_1
output: name='output_0' type='NOTENSOR' shape=None
output: name='output_1' type='NOTENSOR' shape=None
----- function name=scan_combine_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'arg0_1'
input: 'arg1_1'
Constant(value=[1, -1]) -> init7_s2_1_-1
Reshape(arg1_1, init7_s2_1_-1) -> view
Sub(arg0_1, view) -> sub_4
Mul(sub_4, sub_4) -> mul_7
Constant(value=[1]) -> init7_s1_1
ReduceSum(mul_7, init7_s1_1, keepdims=0) -> sum_1
Sqrt(sum_1) -> output_1
Identity(arg0_1) -> output_0
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?
We can also inline the local function.
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=True, external_threshold=1024...
input: name='x' type=dtype('float32') shape=['x_rows', 'dim']
input: name='y' type=dtype('float32') shape=['y_rows', 'dim']
Scan(y, x, body=G1, num_scan_inputs=1, scan_input_directions=[0], scan_output_axes=[0], scan_output_directions=[0]) -> scan#0, output_0
output: name='output_0' type=dtype('float32') shape=['x_rows', 'y_rows']
----- subgraph ---- Scan - aten_scan - att.body=G1 -- level=1 -- init_0_y,scan_0_x -> output_0,output_1
input: name='init_0_y' type='NOTENSOR' shape=None
input: name='scan_0_x' type='NOTENSOR' shape=None
Constant(value=[1]) -> init7_s1_12
Constant(value=[1, -1]) -> init7_s2_1_-12
Reshape(scan_0_x, init7_s2_1_-12) -> view2
Sub(init_0_y, view2) -> sub_42
Mul(sub_42, sub_42) -> mul_72
ReduceSum(mul_72, init7_s1_12, keepdims=0) -> sum_12
Sqrt(sum_12) -> output_1
Identity(init_0_y) -> output_0
output: name='output_0' type='NOTENSOR' shape=None
output: name='output_1' type='NOTENSOR' shape=None
And visually.
<Axes: >
Total running time of the script: (0 minutes 2.354 seconds)