Export with loops

This is a simple example of a loop that cannot be efficiently rewritten with scan.

import torch
from onnx_diagnostic import doc
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


class Model(torch.nn.Module):
    def __init__(self, crop_size):
        super().__init__()
        self.crop_size = crop_size

    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]
        cropped = []
        # Slice W column-wise between consecutive split positions and keep
        # at most crop_size columns of every segment.
        for start, end in zip(starts, ends):
            extract = W[:, start:end]
            if extract.shape[1] < crop_size:
                cropped.append(extract)
            else:
                cropped.append(extract[:, :crop_size])
        return torch.cat(cropped, axis=1)


model = Model(4)
args = (torch.rand((2, 22)), torch.tensor([0, 5, 15, 20, 22], dtype=torch.int64))

expected = model(*args)
print(f"-- exected shape: {expected.shape}")
-- exected shape: torch.Size([2, 14])
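
As a quick sanity check (a small sketch, not part of the original script): the splits define segments of widths 5, 10, 5 and 2, each cropped to at most crop_size=4 columns, hence 4 + 4 + 4 + 2 = 14 columns in the result.

widths = [min(e - s, 4) for s, e in zip([0, 5, 15, 20], [5, 15, 20, 22])]
print(widths, sum(widths))  # [4, 4, 4, 2] 14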

Rewrite with the higher-order op scan

The loop cannot be exported as is; it needs to be rewritten. torch.ops.higher_order.scan iterates over the first dimension of the scanned tensors and threads a carried value through the body function, here the concatenation of all segments cropped so far.

class ModelWithScan(Model):
    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]

        def body_scan(init, split, W):
            # split is one row of starts_ends; crop the segment and append it
            # to the carried tensor, which grows at every iteration.
            extract = W[:, split[0].item() : split[1].item()]
            cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
            carried = torch.cat([init, cropped], axis=1)
            return carried

        starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
        return torch.ops.higher_order.scan(
            body_scan, [torch.empty((W.shape[0], 0), dtype=W.dtype)], [starts_ends], [W]
        )


rewritten_model_with_scan = ModelWithScan(4)
(results,) = rewritten_model_with_scan(*args)

print(f"-- max discrepancies with scan: { torch.abs(expected - results).max()}")
-- max discrepancies with scan: 0.0

This approach has one flaw: the carried variable grows at every iteration, so the cost of the copies is quadratic, whereas the same operation in the first model is linear. We cannot simply return the variable cropped because its shape is not the same at every iteration.
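
A back-of-the-envelope count makes the quadratic cost concrete (a minimal sketch with illustrative numbers, not part of the original script): with n iterations each appending a slice of width c, iteration k copies a carried tensor of about k * c columns, so the total number of copied columns is c * n * (n + 1) / 2.

n, c = 1000, 4
total_copied = sum(k * c for k in range(1, n + 1))
print(total_copied)  # 2002000 columns copied overall, quadratic in n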

Introducing a new higher-order op: simple_loop_for

simple_loop_for was designed to support this specific case. It takes all the outputs produced by the body function at every iteration, stores them in lists, then concatenates them along the dimensions given by concatenation_dims. Its semantics are sketched below, before the rewritten model.
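
In eager terms the behaviour can be approximated as follows (a hypothetical helper written for illustration, not the actual implementation, which is registered as torch.ops.higher_order.simple_loop_for so that torch.export can capture it as a subgraph):

def eager_loop_for_sketch(n_iterations, body, args, concatenation_dims):
    # Call the body for i = 0 .. n_iterations - 1, collecting one tuple of
    # tensors per iteration; their shapes may differ across iterations.
    outputs = [body(torch.tensor(i), *args) for i in range(int(n_iterations))]
    # Concatenate the k-th output of every iteration along concatenation_dims[k].
    return tuple(
        torch.cat([out[k] for out in outputs], dim=dim)
        for k, dim in enumerate(concatenation_dims)
    )

The real operator returns the concatenated tensors directly (a single tensor here, since the body has a single output). The rewritten model follows.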

class ModelWithLoop(Model):
    def forward(self, W, splits):
        crop_size = self.crop_size
        starts = splits[:-1]
        ends = splits[1:]

        def body_loop(i, splits, W):
            split = splits[i.item() : (i + 1).item()][0]  # [i.item()] fails
            extract = W[:, split[0].item() : split[1].item()]
            cropped = extract[:, : torch.sym_min(extract.shape[1], crop_size)]
            # cropped may have a different width at every iteration.
            return (cropped,)

        starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
        n_iterations = torch.tensor(starts_ends.shape[0], dtype=torch.int64)
        return simple_loop_for(
            n_iterations, body_loop, (starts_ends, W), concatenation_dims=[1]
        )


rewritten_model_with_loop = ModelWithLoop(4)
results = rewritten_model_with_loop(*args)

print(f"-- max discrepancies with loop: { torch.abs(expected - results).max()}")
-- max discrepancies with loop: 0.0

torch.export.export?

dynamic_shapes = (
    {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC},
    {0: torch.export.Dim.DYNAMIC},
)
try:
    ep = torch.export.export(rewritten_model_with_scan, args, dynamic_shapes=dynamic_shapes)
    print("----- exported program with scan:")
    print(ep)
except Exception as e:
    print(f"export failed due to {e}")
export failed due to object of type 'Node' has no len()

And loops?

ep = torch.export.export(rewritten_model_with_loop, args, dynamic_shapes=dynamic_shapes)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, w: "f32[s55, s62]", splits: "i64[s10]"):
            # No stacktrace found for following nodes
            sym_size_int_3: "Sym(s10)" = torch.ops.aten.sym_size.int(splits, 0)

            # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:90 in forward, code: starts = splits[:-1]
            slice_1: "i64[s10 - 1]" = torch.ops.aten.slice.Tensor(splits, 0, 0, -1)

            # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:91 in forward, code: ends = splits[1:]
            slice_2: "i64[s10 - 1]" = torch.ops.aten.slice.Tensor(splits, 0, 1, 9223372036854775807);  splits = None

            # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:99 in forward, code: starts_ends = torch.cat([starts.unsqueeze(1), ends.unsqueeze(1)], axis=1)
            unsqueeze: "i64[s10 - 1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
            unsqueeze_1: "i64[s10 - 1, 1]" = torch.ops.aten.unsqueeze.default(slice_2, 1);  slice_2 = None
            cat: "i64[s10 - 1, 2]" = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1);  unsqueeze = unsqueeze_1 = None

            # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:100 in forward, code: n_iterations = torch.tensor(starts_ends.shape[0], dtype=torch.int64)
            add: "Sym(s10 - 1)" = -1 + sym_size_int_3;  sym_size_int_3 = None
            scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(add, dtype = torch.int64, device = device(type='cpu'), pin_memory = False);  add = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(scalar_tensor, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[]" = torch.ops.aten.to.device(scalar_tensor, device(type='cpu'), torch.int64);  scalar_tensor = None
            detach_: "i64[]" = torch.ops.aten.detach_.default(to);  to = None

            # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:101 in forward, code: return simple_loop_for(
            body_graph_0 = self.body_graph_0
            simple_loop_for = torch.ops.higher_order.simple_loop_for(detach_, body_graph_0, (cat, w), [1]);  detach_ = body_graph_0 = cat = w = None
            getitem: "f32[s55, u16]" = simple_loop_for[0];  simple_loop_for = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, i_1: "i64[]", splits_1: "i64[s10 - 1, 2]", W_1: "f32[s55, s62]"):
                w_1 = W_1

                # No stacktrace found for following nodes
                sym_size_int_5: "Sym(s62)" = torch.ops.aten.sym_size.int(w_1, 1)

                # File: ~/github/onnx-diagnostic/_doc/technical/plot_simple_for_loop.py:101 in forward, code: return simple_loop_for(
                item: "Sym(u0)" = torch.ops.aten.item.default(i_1)
                add: "i64[]" = torch.ops.aten.add.Tensor(i_1, 1);  i_1 = None
                item_1: "Sym(u1)" = torch.ops.aten.item.default(add);  add = None
                slice_1: "i64[u2, 2]" = torch.ops.aten.slice.Tensor(splits_1, 0, item, item_1);  splits_1 = item = item_1 = None
                sym_size_int_6: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
                sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
                ge: "Sym(u2 >= 1)" = sym_size_int_6 >= 1
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'");  ge = _assert_scalar_default = None
                gt_2: "Sym(u2 > 0)" = sym_size_int_6 > 0
                sym_sum: "Sym(u2 + 1)" = torch.sym_sum([1, sym_size_int_6]);  sym_size_int_6 = None
                gt_3: "Sym(u2 + 1 > 0)" = sym_sum > 0;  sym_sum = None
                and__1: "Sym((u2 > 0) & (u2 + 1 > 0))" = gt_2 & gt_3;  gt_2 = gt_3 = None
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(and__1, "Runtime assertion failed for expression (0 < u2) & (0 < u2 + 1) on node 'and__1'");  and__1 = _assert_scalar_default_1 = None
                ge_1: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0;  sym_storage_offset_default = None
                _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default_2 = None
                select: "i64[2]" = torch.ops.aten.select.int(slice_1, 0, 0);  slice_1 = None
                select_1: "i64[]" = torch.ops.aten.select.int(select, 0, 0)
                item_2: "Sym(u4)" = torch.ops.aten.item.default(select_1);  select_1 = None
                select_2: "i64[]" = torch.ops.aten.select.int(select, 0, 1);  select = None
                item_3: "Sym(u5)" = torch.ops.aten.item.default(select_2);  select_2 = None
                slice_2: "f32[s55, s62]" = torch.ops.aten.slice.Tensor(w_1, 0, 0, 9223372036854775807);  w_1 = None
                slice_3: "f32[s55, u6]" = torch.ops.aten.slice.Tensor(slice_2, 1, item_2, item_3);  slice_2 = item_2 = item_3 = None
                sym_size_int_7: "Sym(u6)" = torch.ops.aten.sym_size.int(slice_3, 1)
                sym_storage_offset_default_1: "Sym(u7)" = torch.ops.aten.sym_storage_offset.default(slice_3)
                ge_2: "Sym(u6 >= 0)" = sym_size_int_7 >= 0
                _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u6 >= 0 on node 'ge_2'");  ge_2 = _assert_scalar_default_3 = None
                le_3: "Sym(u6 <= s62)" = sym_size_int_7 <= sym_size_int_5;  sym_size_int_5 = None
                _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(le_3, "Runtime assertion failed for expression u6 <= s62 on node 'le_3'");  le_3 = _assert_scalar_default_4 = None
                ge_3: "Sym(u7 >= 0)" = sym_storage_offset_default_1 >= 0;  sym_storage_offset_default_1 = None
                _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u7 >= 0 on node 'ge_3'");  ge_3 = _assert_scalar_default_5 = None
                sym_min: "Sym(Min(4, u6))" = torch.sym_min(sym_size_int_7, 4)
                slice_4: "f32[s55, u6]" = torch.ops.aten.slice.Tensor(slice_3, 0, 0, 9223372036854775807);  slice_3 = None
                slice_5: "f32[s55, u8]" = torch.ops.aten.slice.Tensor(slice_4, 1, 0, sym_min);  slice_4 = sym_min = None
                sym_size_int_8: "Sym(u8)" = torch.ops.aten.sym_size.int(slice_5, 1)
                ge_4: "Sym(u8 >= 0)" = sym_size_int_8 >= 0
                _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u8 >= 0 on node 'ge_4'");  ge_4 = _assert_scalar_default_6 = None
                le_4: "Sym(u8 <= u6)" = sym_size_int_8 <= sym_size_int_7;  sym_size_int_8 = sym_size_int_7 = None
                _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(le_4, "Runtime assertion failed for expression u8 <= u6 on node 'le_4'");  le_4 = _assert_scalar_default_7 = None
                return (slice_5,)

Graph signature:
    # inputs
    w: USER_INPUT
    splits: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[1, int_oo], u3: VR[0, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[0, int_oo], u7: VR[0, int_oo], u8: VR[0, int_oo], u9: VR[-int_oo, int_oo], u10: VR[1, int_oo], u11: VR[0, int_oo], u12: VR[-int_oo, int_oo], u13: VR[-int_oo, int_oo], u14: VR[0, int_oo], u15: VR[0, int_oo], u16: VR[0, int_oo], s55: VR[2, int_oo], s62: VR[2, int_oo], s10: VR[3, int_oo]}
doc.plot_legend("export a loop\nreturning\ndifferent shapes", "loops", "green")
(figure: plot simple for loop)

Total running time of the script: (0 minutes 4.778 seconds)

Related examples

Gemm or Matmul + Add

Dynamic Shapes and Broadcasting

Reproducible Parallelized Reduction is difficult