onnx_diagnostic.export.cf_simple_loop_for

class onnx_diagnostic.export.cf_simple_loop_for.SimpleLoopForOp[source]

Higher-order op for simple_loop_for().

onnx_diagnostic.export.cf_simple_loop_for.inner(mode, n_iter, body_fn, operands, concatenation_dims=None)[source]

Registered tracing implementation.

onnx_diagnostic.export.cf_simple_loop_for.loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None)[source]

Registered eager mode implementation.

onnx_diagnostic.export.cf_simple_loop_for.simple_loop_for(n_iter: int | Tensor, body_fn: Callable, operands: Tuple[Tensor, ...] = (), concatenation_dims: int | Sequence[int] | None = None) → Tensor | Tuple[Tensor, ...][source]

Implements a simple for loop. The body is defined by a function that takes the iteration number, stored in a tensor, together with other tensors, and returns one or several tensors in a tuple. The outputs collected over all iterations are finally concatenated, along the first dimension by default, as sketched after the parameter list.

Parameters:
  • n_iter – number of iterations, as a Python int or a tensor

  • body_fn – function implementing the loop body

  • operands – arguments passed to body_fn

  • concatenation_dims – dimension or dimensions used to concatenate the output sequences

Returns:

the concatenated outputs, a Tensor or a tuple of Tensors
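In eager mode the call behaves roughly like a plain Python loop followed by torch.cat. The following is a minimal sketch of that semantics, not the registered loop_for_op_dense implementation; eager_loop_for_sketch is a hypothetical name, and the sketch assumes a single output comes back as a bare Tensor, as the return annotation suggests:

import torch


def eager_loop_for_sketch(n_iter, body_fn, operands=(), concatenation_dims=None):
    # Hypothetical stand-in mirroring the documented semantics of simple_loop_for.
    n = int(n_iter)  # accepts a Python int or a 0-d int64 tensor
    # Call the body once per iteration; the index is passed as a tensor.
    results = [body_fn(torch.tensor(i, dtype=torch.int64), *operands) for i in range(n)]
    n_out = len(results[0])
    if concatenation_dims is None:
        dims = [0] * n_out  # default: concatenate along the first dimension
    elif isinstance(concatenation_dims, int):
        dims = [concatenation_dims] * n_out
    else:
        dims = list(concatenation_dims)
    # Concatenate the k-th output of every iteration along its own dimension.
    outputs = tuple(torch.cat([r[k] for r in results], dim=dims[k]) for k in range(n_out))
    return outputs[0] if n_out == 1 else outputs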

An example with one output:

<<<

import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            return (x[: i.item() + 1].unsqueeze(1),)

        return simple_loop_for(n_iter, body, (x,))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

>>>

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, n_iter: "i64[]", x: "f32[s77]"):
                # No stacktrace found for following nodes
                body_graph_0 = self.body_graph_0
                simple_loop_for = torch.ops.higher_order.simple_loop_for(n_iter, body_graph_0, (x,), None);  n_iter = body_graph_0 = x = None
                getitem: "f32[u2, 1]" = simple_loop_for[0];  simple_loop_for = None
                return (getitem,)
                
            class body_graph_0(torch.nn.Module):
                def forward(self, i_1: "i64[]", x_1: "f32[s77]"):
                    # No stacktrace found for following nodes
                    sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0)
                    item: "Sym(u0)" = torch.ops.aten.item.default(i_1);  i_1 = None
                    add: "Sym(u0 + 1)" = item + 1;  item = None
                    slice_1: "f32[u1]" = torch.ops.aten.slice.Tensor(x_1, 0, 0, add);  x_1 = add = None
                    sym_size_int_1: "Sym(u1)" = torch.ops.aten.sym_size.int(slice_1, 0)
                    ge: "Sym(u1 >= 0)" = sym_size_int_1 >= 0
                    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'");  ge = _assert_scalar_default = None
                    le: "Sym(u1 <= s77)" = sym_size_int_1 <= sym_size_int;  sym_size_int_1 = sym_size_int = None
                    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= s77 on node 'le'");  le = _assert_scalar_default_1 = None
                    unsqueeze: "f32[u1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
                    return (unsqueeze,)
                    
    Graph signature: 
        # inputs
        n_iter: USER_INPUT
        x: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        
    Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], s77: VR[2, int_oo]}
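Run eagerly with the inputs above, iteration i contributes a slice of length i + 1, so the single output should have 1 + 2 + 3 + 4 = 10 rows. A quick check, hedged in case the single output comes back wrapped in a tuple:

got = model(n_iter, x)
out = got[0] if isinstance(got, tuple) else got
print(out.shape)  # expected: torch.Size([10, 1])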

Another example with two outputs, each concatenated along a different axis.

<<<

import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))

        return simple_loop_for(n_iter, body, (x,), (0, 1))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

>>>

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, n_iter: "i64[]", x: "f32[s77]"):
                # No stacktrace found for following nodes
                body_graph_0 = self.body_graph_0
                simple_loop_for = torch.ops.higher_order.simple_loop_for(n_iter, body_graph_0, (x,), (0, 1));  n_iter = body_graph_0 = x = None
                getitem: "f32[u4, 1]" = simple_loop_for[0]
                getitem_1: "f32[1, u5]" = simple_loop_for[1];  simple_loop_for = None
                return (getitem, getitem_1)
                
            class body_graph_0(torch.nn.Module):
                def forward(self, i_1: "i64[]", x_1: "f32[s77]"):
                    # No stacktrace found for following nodes
                    sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0)
                    item: "Sym(u0)" = torch.ops.aten.item.default(i_1);  i_1 = None
                    add: "Sym(u0 + 1)" = item + 1;  item = None
                    slice_1: "f32[u1]" = torch.ops.aten.slice.Tensor(x_1, 0, 0, add)
                    sym_size_int_1: "Sym(u1)" = torch.ops.aten.sym_size.int(slice_1, 0)
                    ge: "Sym(u1 >= 0)" = sym_size_int_1 >= 0
                    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'");  ge = _assert_scalar_default = None
                    le: "Sym(u1 <= s77)" = sym_size_int_1 <= sym_size_int;  sym_size_int_1 = None
                    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= s77 on node 'le'");  le = _assert_scalar_default_1 = None
                    unsqueeze: "f32[u1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
                    slice_2: "f32[u2]" = torch.ops.aten.slice.Tensor(x_1, 0, add, 9223372036854775807);  x_1 = add = None
                    sym_size_int_2: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_2, 0)
                    sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_2)
                    ge_1: "Sym(u2 >= 0)" = sym_size_int_2 >= 0
                    _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default_2 = None
                    le_1: "Sym(u2 <= s77)" = sym_size_int_2 <= sym_size_int;  sym_size_int_2 = sym_size_int = None
                    _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u2 <= s77 on node 'le_1'");  le_1 = _assert_scalar_default_3 = None
                    ge_2: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0;  sym_storage_offset_default = None
                    _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u3 >= 0 on node 'ge_2'");  ge_2 = _assert_scalar_default_4 = None
                    unsqueeze_1: "f32[1, u2]" = torch.ops.aten.unsqueeze.default(slice_2, 0);  slice_2 = None
                    return (unsqueeze, unsqueeze_1)
                    
    Graph signature: 
        # inputs
        n_iter: USER_INPUT
        x: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        getitem_1: USER_OUTPUT
        
    Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], u3: VR[0, int_oo], u4: VR[0, int_oo], u5: VR[0, int_oo], u6: VR[0, int_oo], s77: VR[2, int_oo]}
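Checked eagerly under the same assumptions as the sketch above, the first output stacks slices of lengths 1, 2, 3, 4 along dim 0, while the second stacks slices of lengths 9, 8, 7, 6 along dim 1:

got = model(n_iter, x)
print(got[0].shape)  # expected: torch.Size([10, 1]) -> 1 + 2 + 3 + 4 rows
print(got[1].shape)  # expected: torch.Size([1, 30]) -> 9 + 8 + 7 + 6 columns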

onnx_diagnostic.export.cf_simple_loop_for.simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None)[source]

Registered FakeTensorMode implementation.

onnx_diagnostic.export.cf_simple_loop_for.trace_simple_loop_for(proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims)[source]

See simple_loop_for().