Export with loops
This is a simple example of a loop that cannot be efficiently rewritten
with scan.
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
class Model(torch.nn.Module):
    def __init__(self, crop_size):
        super().__init__()
        self.crop_size = crop_size

    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]
        cropped = []
        for start, end in zip(starts, ends):
            extract = W[:, start:end]
            if extract.shape[1] < crop_size:
                cropped.append(extract)
            else:
                cropped.append(extract[:, :crop_size])
        return torch.cat(cropped, axis=1)
model = Model(4)
args = (torch.rand((2, 22)), torch.tensor([0, 5, 15, 20, 22], dtype=torch.int64))
expected = model(*args)
print(f"-- exected shape: {expected.shape}")
-- exected shape: torch.Size([2, 14])
Rewrite with the higher-order op scan
The loop cannot be exported as is. It needs to be rewritten.
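To see why, we can first try to export the original model directly. This experiment is not part of the original script, and its outcome depends on the PyTorch version: the data-dependent loop either makes the export fail or gets unrolled for the concrete values of splits.

try:
    torch.export.export(model, args)
    print("direct export did not fail, the loop was probably unrolled")
except Exception as e:
    print(f"direct export failed due to {e}")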
class ModelWithScan(Model):
    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]

        def body_scan(init, split, W):
            extract = W[:, split[0].item() : split[1].item()]
            cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
            carried = torch.cat([init, cropped], axis=1)
            return carried

        starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
        return torch.ops.higher_order.scan(
            body_scan, [torch.empty((W.shape[0], 0), dtype=W.dtype)], [starts_ends], [W]
        )
rewritten_model_with_scan = ModelWithScan(4)
(results,) = rewritten_model_with_scan(*args)
print(f"-- max discrepancies with scan: { torch.abs(expected - results).max()}")
-- max discrepancies with scan: 0.0
This approach has one flaw: the carried variable grows at every
iteration, so the cost of the copies is quadratic, whereas the same operation
in the first model is linear.
We cannot simply return the variable cropped because its shape
is not the same at every iteration.
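A small illustration of that difference, not part of the original script: growing a carried tensor copies all previously accumulated values at every iteration, while collecting the chunks in a list and concatenating once copies every value a single time.

chunks = [torch.rand(2, 4) for _ in range(100)]

# quadratic: the carried tensor is fully copied at every iteration
carried = torch.empty((2, 0))
for c in chunks:
    carried = torch.cat([carried, c], dim=1)

# linear: a single final concatenation
once = torch.cat(chunks, dim=1)
assert torch.equal(carried, once)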
Introducing a new higher-order op: simple_loop_for
simple_loop_for was designed to support this specific case.
It takes the outputs produced by the body function at every iteration
and stores them in lists. It then concatenates each list according to
concatenation_dims.
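As a mental model, the following eager sketch computes the same result, assuming the semantics just described; the helper name, the 0-d tensor passed as the iteration counter, and the single-output return convention are illustrative assumptions, not the actual implementation.

def simple_loop_for_reference(n_iterations, body, args, concatenation_dims):
    # collect the outputs of every iteration, one list per output
    collected = None
    for i in range(int(n_iterations)):
        outputs = body(torch.tensor(i, dtype=torch.int64), *args)
        if collected is None:
            collected = [[] for _ in outputs]
        for store, out in zip(collected, outputs):
            store.append(out)
    # one concatenation per output, along the requested dimension
    results = tuple(
        torch.cat(store, dim=dim) for store, dim in zip(collected, concatenation_dims)
    )
    return results[0] if len(results) == 1 else results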
class ModelWithLoop(Model):
    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]

        def body_loop(i, splits, W):
            split = splits[i.item() : (i + 1).item()][0]  # [i.item()] fails
            extract = W[:, split[0].item() : split[1].item()]
            cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
            return (cropped,)

        starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
        n_iterations = torch.tensor(starts_ends.shape[0], dtype=torch.int64)
        return simple_loop_for(
            n_iterations, body_loop, (starts_ends, W), concatenation_dims=[1]
        )
rewritten_model_with_loop = ModelWithLoop(4)
results = rewritten_model_with_loop(*args)
print(f"-- max discrepancies with loop: { torch.abs(expected - results).max()}")
-- max discrepancies with loop: 0.0
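A quick extra check, not part of the original script: the rewritten loop also matches the eager model on a different split pattern, where the cropped widths differ again.

args2 = (torch.rand((2, 30)), torch.tensor([0, 2, 11, 30], dtype=torch.int64))
expected2 = model(*args2)
results2 = rewritten_model_with_loop(*args2)
print(f"-- max discrepancies with loop (other inputs): {torch.abs(expected2 - results2).max()}")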
torch.export.export?
dynamic_shapes = (
    {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
    {0: torch.export.Dim.DYNAMIC},
)

try:
    ep = torch.export.export(rewritten_model_with_scan, args, dynamic_shapes=dynamic_shapes)
    print("----- exported program with scan:")
    print(ep)
except Exception as e:
    print(f"export failed due to {e}")
export failed due to object of type 'Node' has no len()
And loops?
ep = torch.export.export(rewritten_model_with_loop, args, dynamic_shapes=dynamic_shapes)
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, w: "f32[s55, s62]", splits: "i64[s10]"):
# No stacktrace found for following nodes
sym_size_int_3: "Sym(s10)" = torch.ops.aten.sym_size.int(splits, 0)
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:90 in forward, code: starts = splits[:-1]
slice_1: "i64[s10 - 1]" = torch.ops.aten.slice.Tensor(splits, 0, 0, -1)
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:91 in forward, code: ends = splits[1:]
slice_2: "i64[s10 - 1]" = torch.ops.aten.slice.Tensor(splits, 0, 1, 9223372036854775807); splits = None
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:99 in forward, code: starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
unsqueeze: "i64[s10 - 1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None
unsqueeze_1: "i64[s10 - 1, 1]" = torch.ops.aten.unsqueeze.default(slice_2, 1); slice_2 = None
cat: "i64[s10 - 1, 2]" = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1); unsqueeze = unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:100 in forward, code: n_iterations = torch.tensor(starts_ends.shape[0], dtype=torch.int64)
add: "Sym(s10 - 1)" = -1 + sym_size_int_3; sym_size_int_3 = None
scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(add, dtype = torch.int64, device = device(type='cpu'), pin_memory = False); add = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(scalar_tensor, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[]" = torch.ops.aten.to.device(scalar_tensor, device(type='cpu'), torch.int64); scalar_tensor = None
detach_: "i64[]" = torch.ops.aten.detach_.default(to); to = None
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:101 in forward, code: return simple_loop_for(
body_graph_0 = self.body_graph_0
simple_loop_for = torch.ops.higher_order.simple_loop_for(detach_, body_graph_0, (cat, w), [1]); detach_ = body_graph_0 = cat = w = None
getitem: "f32[s55, u16]" = simple_loop_for[0]; simple_loop_for = None
return (getitem,)
class body_graph_0(torch.nn.Module):
def forward(self, i_1: "i64[]", splits_1: "i64[s10 - 1, 2]", W_1: "f32[s55, s62]"):
w_1 = W_1
# No stacktrace found for following nodes
sym_size_int_5: "Sym(s62)" = torch.ops.aten.sym_size.int(w_1, 1)
# File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:101 in forward, code: return simple_loop_for(
item: "Sym(u0)" = torch.ops.aten.item.default(i_1)
add: "i64[]" = torch.ops.aten.add.Tensor(i_1, 1); i_1 = None
item_1: "Sym(u1)" = torch.ops.aten.item.default(add); add = None
slice_1: "i64[u2, 2]" = torch.ops.aten.slice.Tensor(splits_1, 0, item, item_1); splits_1 = item = item_1 = None
sym_size_int_6: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
ge: "Sym(u2 >= 1)" = sym_size_int_6 >= 1
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'"); ge = _assert_scalar_default = None
gt_2: "Sym(u2 > 0)" = sym_size_int_6 > 0
sym_sum: "Sym(u2 + 1)" = torch.sym_sum([1, sym_size_int_6]); sym_size_int_6 = None
gt_3: "Sym(u2 + 1 > 0)" = sym_sum > 0; sym_sum = None
and__1: "Sym((u2 > 0) & (u2 + 1 > 0))" = gt_2 & gt_3; gt_2 = gt_3 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(and__1, "Runtime assertion failed for expression (0 < u2) & (0 < u2 + 1) on node 'and__1'"); and__1 = _assert_scalar_default_1 = None
ge_1: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None
select: "i64[2]" = torch.ops.aten.select.int(slice_1, 0, 0); slice_1 = None
select_1: "i64[]" = torch.ops.aten.select.int(select, 0, 0)
item_2: "Sym(u4)" = torch.ops.aten.item.default(select_1); select_1 = None
select_2: "i64[]" = torch.ops.aten.select.int(select, 0, 1); select = None
item_3: "Sym(u5)" = torch.ops.aten.item.default(select_2); select_2 = None
slice_2: "f32[s55, s62]" = torch.ops.aten.slice.Tensor(w_1, 0, 0, 9223372036854775807); w_1 = None
slice_3: "f32[s55, u6]" = torch.ops.aten.slice.Tensor(slice_2, 1, item_2, item_3); slice_2 = item_2 = item_3 = None
sym_size_int_7: "Sym(u6)" = torch.ops.aten.sym_size.int(slice_3, 1)
sym_storage_offset_default_1: "Sym(u7)" = torch.ops.aten.sym_storage_offset.default(slice_3)
ge_2: "Sym(u6 >= 0)" = sym_size_int_7 >= 0
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u6 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_3 = None
le_3: "Sym(u6 <= s62)" = sym_size_int_7 <= sym_size_int_5; sym_size_int_5 = None
_assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(le_3, "Runtime assertion failed for expression u6 <= s62 on node 'le_3'"); le_3 = _assert_scalar_default_4 = None
ge_3: "Sym(u7 >= 0)" = sym_storage_offset_default_1 >= 0; sym_storage_offset_default_1 = None
_assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u7 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_default_5 = None
sym_min: "Sym(Min(4, u6))" = torch.sym_min(sym_size_int_7, 4)
slice_4: "f32[s55, u6]" = torch.ops.aten.slice.Tensor(slice_3, 0, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s55, u8]" = torch.ops.aten.slice.Tensor(slice_4, 1, 0, sym_min); slice_4 = sym_min = None
sym_size_int_8: "Sym(u8)" = torch.ops.aten.sym_size.int(slice_5, 1)
ge_4: "Sym(u8 >= 0)" = sym_size_int_8 >= 0
_assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u8 >= 0 on node 'ge_4'"); ge_4 = _assert_scalar_default_6 = None
le_4: "Sym(u8 <= u6)" = sym_size_int_8 <= sym_size_int_7; sym_size_int_8 = sym_size_int_7 = None
_assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(le_4, "Runtime assertion failed for expression u8 <= u6 on node 'le_4'"); le_4 = _assert_scalar_default_7 = None
return (slice_5,)
Graph signature:
# inputs
w: USER_INPUT
splits: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[1, int_oo], u3: VR[0, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[0, int_oo], u7: VR[0, int_oo], u8: VR[0, int_oo], u9: VR[-int_oo, int_oo], u10: VR[1, int_oo], u11: VR[0, int_oo], u12: VR[-int_oo, int_oo], u13: VR[-int_oo, int_oo], u14: VR[0, int_oo], u15: VR[0, int_oo], u16: VR[0, int_oo], s55: VR[2, int_oo], s62: VR[2, int_oo], s10: VR[3, int_oo]}
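A last sanity check, not part of the original script: the exported program can be run through ep.module() and compared with the eager result.

ep_results = ep.module()(*args)
print(f"-- max discrepancies after export: {torch.abs(expected - ep_results).max()}")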
doc.plot_legend("export a loop\nreturning\ndifferent shapes", "loops", "green")

Total running time of the script: (0 minutes 4.778 seconds)