onnx_diagnostic.export.control_flow

onnx_diagnostic.export.control_flow.convert_custom_loop_into_onnx(g: Any, sts: Dict[str, Any], outputs: List[str], *args: str, body_callable: Callable[[...], ModelProto], reduction_dim: Sequence[int] | None = None, name: str = 'loop_for') → str | List[str][source]

Converts a custom op higher_ops::loop_for... into a sequence of nodes.

Parameters:
  • g – GraphBuilder

  • sts – if not defined, torch does not know the output shapes

  • outputs – output names

  • args – input arguments known at export time

  • body_callable – a callable returning the loop body as a ModelProto

  • reduction_dim – the dimension to follow when aggregating the list of tensors after the loop has run

  • name – to give the onnx nodes a name

Returns:

output names

onnx_diagnostic.export.control_flow.convert_into_onnx(body_gm: GraphModule, args: Sequence[Tensor], target_opset: int | None = None, verbose: int = 0, exporter_kwargs: Dict[str, Any] | None = None) → ModelProto[source]

Converts a torch.fx.GraphModule into ONNX. It returns a ModelProto.

Parameters:
  • body_gm – a torch.fx.GraphModule

  • args – arguments known at export time

  • target_opset – targeted opset

  • verbose – verbosity level

  • exporter_kwargs – additional exporter arguments

Returns:

a ModelProto
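
A minimal sketch of a call, assuming the GraphModule is obtained with torch.fx.symbolic_trace (any tracing producing a GraphModule should do):

    import torch
    from onnx_diagnostic.export.control_flow import convert_into_onnx


    class Body(torch.nn.Module):
        def forward(self, i, x):
            # i: iteration index, x: tensor carried through the loop
            return x + i


    body_gm = torch.fx.symbolic_trace(Body())
    onx = convert_into_onnx(
        body_gm,
        (torch.tensor(0, dtype=torch.int64), torch.arange(4, dtype=torch.float32)),
    )
    print(type(onx))  # a ModelProto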

onnx_diagnostic.export.control_flow.enable_code_export_control_flow()[source]

Enables the code meant to be exported.

onnx_diagnostic.export.control_flow.is_exporting() → bool[source]

Returns torch.compiler.is_exporting() or torch.compiler.is_compiling(). Changing _TEST_EXPORT makes it trigger.
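
It makes it possible to follow a different path in the code while the model is being exported. An illustrative sketch relying only on this function:

    import torch
    from onnx_diagnostic.export.control_flow import is_exporting


    class Model(torch.nn.Module):
        def forward(self, x):
            if is_exporting():
                # branch taken while torch.export.export or torch.compile runs
                return x * 2
            # eager branch, numerically identical here
            return x + x


    print(Model()(torch.arange(3, dtype=torch.float32)))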

onnx_diagnostic.export.control_flow.loop_for(n_iter: SymInt | Tensor, body_fn: Callable[[...], Tuple[Tensor]], args: Sequence[Tensor], reduction_dim: Sequence[int] | None = None) → Tuple[Tensor, ...][source]

A higher-order operator used to easily export a loop to ONNX. It does not fully work with torch.export.export(): the loop is first captured as a custom op, which is replaced with a Loop operator afterwards. Every iteration produces tensors; all of them are gathered into lists, and these lists are concatenated into tensors.

Parameters:
  • n_iter – number of iterations; it can be fixed or variable, in which case it should be a tensor with no dimension

  • body_fn – the function body; it takes only tensors and returns only tensors, the first argument is the iteration number as a tensor with no dimension, all the others are left unchanged during the loop

  • args – the tensors available at every iteration

  • reduction_dim – the loop aggregates the results into lists, one for each output; each list is concatenated into one tensor along one dimension, the first dimension by default, but it can be defined otherwise

<<<

import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
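            # i is the iteration index, a tensor with no dimension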
            return x[: i.item() + 1].unsqueeze(1)

        return loop_for(n_iter, body, (x,))


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = model(n_iter, x)
print("expected:", expected)

onx = to_onnx(
    model,
    (n_iter, x),
    dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
    exporter="custom",
    use_control_flow_dispatcher=True,
).model_proto

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
print("got:", got)


# The loop is exported as a custom op.
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

>>>

    expected: tensor([[0.],
            [0.],
            [1.],
            [0.],
            [1.],
            [2.],
            [0.],
            [1.],
            [2.],
            [3.]])
    got: [array([[0.],
           [0.],
           [1.],
           [0.],
           [1.],
           [2.],
           [0.],
           [1.],
           [2.],
           [3.]], dtype=float32)]
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, n_iter: "i64[]", x: "f32[s77]"):
                # No stacktrace found for following nodes
                loop_for_body_129679871440864_u1x1_: "f32[u1, 1]" = torch.ops.onnx_higher_ops.loop_for_body_129679871440864_u1x1_.default(n_iter, x);  n_iter = x = None
                return (loop_for_body_129679871440864_u1x1_,)
                
    Graph signature: 
        # inputs
        n_iter: USER_INPUT
        x: USER_INPUT
        
        # outputs
        loop_for_body_129679871440864_u1x1_: USER_OUTPUT
        
    Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], s77: VR[2, int_oo]}

Another example with two outputs:

<<<

import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(1) + 1

        two = loop_for(n_iter, body, (x,))
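        # loop_for returns one concatenated tensor per output of body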
        return two[0] + two[1]


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = model(n_iter, x)
print("expected:", expected)

onx = to_onnx(
    model,
    (n_iter, x),
    dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
    exporter="custom",
    use_control_flow_dispatcher=True,
).model_proto

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
print("got:", got)


# The loop is exported as a custom op.
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

>>>

    expected: tensor([[1.],
            [1.],
            [3.],
            [1.],
            [3.],
            [5.],
            [1.],
            [3.],
            [5.],
            [7.]])
    got: [array([[1.],
           [1.],
           [3.],
           [1.],
           [3.],
           [5.],
           [1.],
           [3.],
           [5.],
           [7.]], dtype=float32)]
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, n_iter: "i64[]", x: "f32[s77]"):
                # No stacktrace found for following nodes
                loop_for_body_129679802890688_u1x1_u2x1_ = torch.ops.onnx_higher_ops.loop_for_body_129679802890688_u1x1_u2x1_.default(n_iter, x);  n_iter = x = None
                getitem: "f32[u1, 1]" = loop_for_body_129679802890688_u1x1_u2x1_[0]
                getitem_1: "f32[u2, 1]" = loop_for_body_129679802890688_u1x1_u2x1_[1];  loop_for_body_129679802890688_u1x1_u2x1_ = None
                add: "f32[u1, 1]" = torch.ops.aten.add.Tensor(getitem, getitem_1);  getitem = getitem_1 = None
                return (add,)
                
    Graph signature: 
        # inputs
        n_iter: USER_INPUT
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], s77: VR[2, int_oo]}

A last example with reduction_dim:

<<<

import torch
import onnxruntime
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.export.control_flow import loop_for


class Model(torch.nn.Module):
    def forward(self, n_iter, x):
        def body(i, x):
            return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1

        two = loop_for(n_iter, body, (x,), reduction_dim=[0, 1])
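        # reduction_dim=[0, 1]: first output concatenated along dim 0, second along dim 1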
        return two[0] + two[1].T


model = Model()
n_iter = torch.tensor(4, dtype=torch.int64)
x = torch.arange(10, dtype=torch.float32)
expected = model(n_iter, x)
print("expected:", expected)

onx = to_onnx(
    model,
    (n_iter, x),
    dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC})),
    exporter="custom",
    use_control_flow_dispatcher=True,
).model_proto

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)
got = sess.run(None, dict(zip(["n_iter", "x"], [n_iter.numpy(), x.numpy()])))
print("got:", got)


# The loop is exported as a custom op.
ep = torch.export.export(
    model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
)
print(ep)

>>>

    expected: tensor([[1.],
            [1.],
            [3.],
            [1.],
            [3.],
            [5.],
            [1.],
            [3.],
            [5.],
            [7.]])
    got: [array([[1.],
           [1.],
           [3.],
           [1.],
           [3.],
           [5.],
           [1.],
           [3.],
           [5.],
           [7.]], dtype=float32)]
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, n_iter: "i64[]", x: "f32[s77]"):
                # No stacktrace found for following nodes
                loop_for_body_129680165357184_u1x1_1xu2_0x1 = torch.ops.onnx_higher_ops.loop_for_body_129680165357184_u1x1_1xu2_0x1.default(n_iter, x);  n_iter = x = None
                getitem: "f32[u1, 1]" = loop_for_body_129680165357184_u1x1_1xu2_0x1[0]
                getitem_1: "f32[1, u2]" = loop_for_body_129680165357184_u1x1_1xu2_0x1[1];  loop_for_body_129680165357184_u1x1_1xu2_0x1 = None
                numpy_t: "f32[u2, 1]" = torch.ops.aten.numpy_T.default(getitem_1);  getitem_1 = None
                add: "f32[u1, 1]" = torch.ops.aten.add.Tensor(getitem, numpy_t);  getitem = numpy_t = None
                return (add,)
                
    Graph signature: 
        # inputs
        n_iter: USER_INPUT
        x: USER_INPUT
        
        # outputs
        add: USER_OUTPUT
        
    Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[0, int_oo], u2: VR[0, int_oo], s77: VR[2, int_oo]}
onnx_diagnostic.export.control_flow.make_custom_loop_for(n_iter: Tensor, body_fn: Callable, reduction_dim: Sequence[int] | None, args: Sequence[Tensor], body_gm: GraphModule | None = None, body_mutated_inputs: List[Any] | None = None, body_outputs: List[Any] | None = None) → Tuple[str, CustomOpDef][source]

Defines a custom operator for a loop in order to avoid torch.export.export() digging into it. It registers the custom op and a custom conversion to ONNX.

Parameters:
  • n_iter – number of iterations, defined by a tensor with no dimension

  • body_fn – the loop body defined as a function

  • reduction_dim – dimension used to concatenate the results

  • args – list of tensors, input to the body

  • body_gm – torch.fx.GraphModule equivalent to body_fn

  • body_mutated_inputs – the mutated inputs of body_gm

  • body_outputs – outputs of body_gm

Returns:

a name and the custom op definition; the name is used to cache the custom op
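
A hedged sketch of a direct call; loop_for() presumably relies on this function, and the returned name serves as the cache key:

    import torch
    from onnx_diagnostic.export.control_flow import make_custom_loop_for


    def body(i, x):
        # i: iteration index as a tensor with no dimension
        return (x[: i.item() + 1].unsqueeze(1),)


    n_iter = torch.tensor(3, dtype=torch.int64)
    x = torch.arange(8, dtype=torch.float32)
    # registers the custom op and its ONNX conversion, or returns the cached one
    name, op_def = make_custom_loop_for(n_iter, body, None, (x,))
    print(name)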