onnx_diagnostic.export.cf_simple_loop_for¶
- class onnx_diagnostic.export.cf_simple_loop_for.SimpleLoopForOp[source][source]¶
Higher order op for
simple_loop_for().
- onnx_diagnostic.export.cf_simple_loop_for.inner(mode, n_iter, body_fn, operands, concatenation_dims=None)[source][source]¶
Registered tracing implementation.
- onnx_diagnostic.export.cf_simple_loop_for.loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None)[source][source]¶
Registered eager mode implementation.
- onnx_diagnostic.export.cf_simple_loop_for.simple_loop_for(n_iter: int | Tensor, body_fn: Callable, operands: Tuple[Tensor, ...] = (), concatenation_dims: int | Sequence[int] | None = None) Tensor | Tuple[Tensor, ...][source][source]¶
Implements a simple for loop. The body is defined by a function that takes the iteration number (stored in a tensor) and other tensors, and returns one or several tensors in a tuple. Across iterations, each output is concatenated along the first dimension by default, or along the dimension(s) given by concatenation_dims.
- Parameters:
n_iter – number of iterations (an int or a scalar tensor)
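body_fn – function called at every iteration with the iteration number (as a tensor) followed by the operands; it returns a tuple of tensors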
operands – additional tensors passed as arguments to body_fn
concatenation_dims – dimension (an int) or dimensions (one per output) along which the per-iteration outputs are concatenated; None means the first dimension
- Returns:
concatenated outputs, as a Tensor or a tuple of Tensors
An example with one output:
<<<
import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        # The body receives the iteration index as a 0-d tensor, then the operands.
        # Each iteration returns the first i + 1 elements of x as a column.
        def body(i, x):
            return (x[: i.item() + 1].unsqueeze(1),)

        return simple_loop_for(n_iter, body, (x,))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)
>>>
<frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, n_iter: "i64[]", x: "f32[s77]"): # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 simple_loop_for = torch.ops.higher_order.simple_loop_for(n_iter, body_graph_0, (x,), None); n_iter = body_graph_0 = x = None getitem: "f32[u2, 1]" = simple_loop_for[0]; simple_loop_for = None return (getitem,) class body_graph_0(torch.nn.Module): def forward(self, i_1: "i64[]", x_1: "f32[s77]"): # No stacktrace found for following nodes sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0) item: "Sym(u0)" = torch.ops.aten.item.default(i_1); i_1 = None add: "Sym(u0 + 1)" = item + 1; item = None slice_1: "f32[u1]" = torch.ops.aten.slice.Tensor(x_1, 0, 0, add); x_1 = add = None sym_size_int_1: "Sym(u1)" = torch.ops.aten.sym_size.int(slice_1, 0) ge: "Sym(u1 >= 0)" = sym_size_int_1 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar_default = None le: "Sym(u1 <= s77)" = sym_size_int_1 <= sym_size_int; sym_size_int_1 = sym_size_int = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= s77 on node 'le'"); le = _assert_scalar_default_1 = None unsqueeze: "f32[u1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None return (unsqueeze,) Graph signature: # inputs n_iter: USER_INPUT x: USER_INPUT # outputs getitem: USER_OUTPUT Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], s77: VR[2, int_oo]}
Another example with two outputs and a final concatenation on different axes.
<<<
import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        # The body returns two tensors per iteration: the head of x as a column
        # and the tail of x as a row.
        def body(i, x):
            return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))

        # The first output is concatenated along axis 0, the second along axis 1.
        return simple_loop_for(n_iter, body, (x,), (0, 1))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)
>>>
ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, n_iter: "i64[]", x: "f32[s77]"): # No stacktrace found for following nodes body_graph_0 = self.body_graph_0 simple_loop_for = torch.ops.higher_order.simple_loop_for(n_iter, body_graph_0, (x,), (0, 1)); n_iter = body_graph_0 = x = None getitem: "f32[u4, 1]" = simple_loop_for[0] getitem_1: "f32[1, u5]" = simple_loop_for[1]; simple_loop_for = None return (getitem, getitem_1) class body_graph_0(torch.nn.Module): def forward(self, i_1: "i64[]", x_1: "f32[s77]"): # No stacktrace found for following nodes sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0) item: "Sym(u0)" = torch.ops.aten.item.default(i_1); i_1 = None add: "Sym(u0 + 1)" = item + 1; item = None slice_1: "f32[u1]" = torch.ops.aten.slice.Tensor(x_1, 0, 0, add) sym_size_int_1: "Sym(u1)" = torch.ops.aten.sym_size.int(slice_1, 0) ge: "Sym(u1 >= 0)" = sym_size_int_1 >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar_default = None le: "Sym(u1 <= s77)" = sym_size_int_1 <= sym_size_int; sym_size_int_1 = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= s77 on node 'le'"); le = _assert_scalar_default_1 = None unsqueeze: "f32[u1, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None slice_2: "f32[u2]" = torch.ops.aten.slice.Tensor(x_1, 0, add, 9223372036854775807); x_1 = add = None sym_size_int_2: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_2, 0) sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_2) ge_1: "Sym(u2 >= 0)" = sym_size_int_2 >= 0 _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None le_1: "Sym(u2 <= s77)" = sym_size_int_2 <= sym_size_int; sym_size_int_2 = sym_size_int = None 
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u2 <= s77 on node 'le_1'"); le_1 = _assert_scalar_default_3 = None ge_2: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u3 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_4 = None unsqueeze_1: "f32[1, u2]" = torch.ops.aten.unsqueeze.default(slice_2, 0); slice_2 = None return (unsqueeze, unsqueeze_1) Graph signature: # inputs n_iter: USER_INPUT x: USER_INPUT # outputs getitem: USER_OUTPUT getitem_1: USER_OUTPUT Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], u3: VR[0, int_oo], u4: VR[0, int_oo], u5: VR[0, int_oo], u6: VR[0, int_oo], s77: VR[2, int_oo]}