torch.onnx.export and a model with a test

Control flow cannot be exported without a change: the code of the model must be changed or patched to introduce the function torch.cond().
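
The patch relies on torch.cond(pred, true_fn, false_fn, operands): pred is a boolean scalar tensor and both branches receive the same operands. A minimal sketch of its usage, independent of the model built below:

import torch


def true_fn(x):
    return x * 2


def false_fn(x):
    return -x


x = torch.randn(3)
# Under torch.export, both branches are captured and the predicate selects
# one at runtime; in eager mode, torch.cond simply evaluates the predicate.
y = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))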

A model with a test

from onnx.printer import to_text
import torch

We define a model with control flow, which causes a graph break.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        # Data-dependent test: the branch taken depends on the input values,
        # which torch.export cannot capture without torch.cond.
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

Let’s check it runs.

x = torch.randn(3)
model(x)
tensor([1.6173], grad_fn=<MulBackward0>)

As expected, it does not export.

try:
    torch.export.export(model, (x,))
    raise AssertionError("This export should failed unless pytorch now supports this model.")
except Exception as e:
    print(e)
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py", line 40, in forward
    out = self.mlp(x)
  File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py", line 25, in forward
    if x.sum():

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

It does export with torch.onnx.export because the exporter falls back to JIT tracing, as the logs below show. But the exported model is not exactly the initial model: tracing records only the branch taken for the example input and freezes the test into a constant.
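
We can verify this on the submodule alone (a quick check, not part of the original recipe): trace it with an input whose sum is nonzero, then call the traced module with a zero-sum input for which eager execution takes the other branch.

sub = ForwardWithControlFlowTest()
traced = torch.jit.trace(sub, (torch.tensor([1.0, 1.0]),))
x0 = torch.tensor([1.0, -1.0])  # x0.sum() == 0: eager takes the negation branch
print(sub(x0))     # tensor([-1.,  1.])
print(traced(x0))  # tensor([ 2., -2.]) -- the multiply branch was frozen

Now the export itself: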

ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script...
/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:25: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if x.sum():
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅
[torch.onnx] Run decomposition...
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  getattr_node = gm.graph.get_attr(lifted_node)
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18],
   producer_name: "pytorch",
   producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] input_1) => (float[1] mul)
   <float[2] "model.mlp.0.bias" =  {-0.288676,-0.368527}, float[2,3] "model.mlp.0.weight" =  {-0.471258,0.324141,0.257305,0.0853913,-0.559284,-0.361027}, float[1] "model.mlp.1.bias" =  {-0.145392}, float[1,2] "model.mlp.1.weight" =  {0.501595,-0.663844}, float[1,3] view, float[3,2] t, float[1,2] addmm, float[2] view_1, float[1,2] view_2, float[2,1] t_1, float[1,1] addmm_1, float[1] view_3, float scalar_tensor_default>
{
   [node_Constant_0] val_0 = Constant <value: tensor = int64[2] {1,3}> ()
   [node_Cast_1] val_1 = Cast <to: int = 7> (val_0)
   [node_Reshape_2] view = Reshape <allowzero: int = 0> (input_1, val_1)
   [node_Transpose_3] t = Transpose <perm: ints = [1, 0]> ("model.mlp.0.weight")
   [node_Gemm_4] addmm = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view, t, "model.mlp.0.bias")
   [node_Constant_5] val_2 = Constant <value: tensor = int64[1] {2}> ()
   [node_Cast_6] val_3 = Cast <to: int = 7> (val_2)
   [node_Reshape_7] view_1 = Reshape <allowzero: int = 0> (addmm, val_3)
   [node_Constant_8] val_4 = Constant <value: tensor = int64[2] {1,2}> ()
   [node_Cast_9] val_5 = Cast <to: int = 7> (val_4)
   [node_Reshape_10] view_2 = Reshape <allowzero: int = 0> (view_1, val_5)
   [node_Transpose_11] t_1 = Transpose <perm: ints = [1, 0]> ("model.mlp.1.weight")
   [node_Gemm_12] addmm_1 = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view_2, t_1, "model.mlp.1.bias")
   [node_Constant_13] val_6 = Constant <value: tensor = int64[1] {1}> ()
   [node_Cast_14] val_7 = Cast <to: int = 7> (val_6)
   [node_Reshape_15] view_3 = Reshape <allowzero: int = 0> (addmm_1, val_7)
   [node_Constant_16] val_8 = Constant <value: tensor = int64 {2}> ()
   [node_Cast_17] scalar_tensor_default = Cast <to: int = 1> (val_8)
   [node_Mul_18] mul = Mul (view_3, scalar_tensor_default)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   [n0] tmp = Shape (input)
   [n1] return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   [n0] tmp = Shape (input)
   [n1] tmp_0 = Size (tmp)
   [n2] tmp_1 = Constant <value_int: int = 0> ()
   [n3] return_val = Equal (tmp_0, tmp_1)
}
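
Only the multiply branch survived: the graph ends with node_Mul_18, a multiplication by the constant 2, and the test itself has disappeared.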

Suggested Patch

Let’s avoid the graph break by replacing the forward method. torch.cond requires both branches to take the same operands and to return tensors with matching shapes and dtypes. Note that the predicate x.sum() > 0 is not strictly equivalent to the original test if x.sum():, which is also true for negative sums.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        # Assigned on the instance, so Python does not bind `self`:
        # new_forward(x) receives the input only.
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
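
A quick sanity check, not part of the original recipe: the patched model should still run in eager mode, since torch.cond falls back to evaluating the predicate.

print(model(x))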

Let’s see what the FX graph looks like.

print(torch.export.export(model, (x,)).graph)
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%linear_1]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
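
Each branch of torch.ops.higher_order.cond is captured as a submodule of the exported graph module; they can be inspected directly (a sketch, assuming the attribute names shown in the graph above):

exported = torch.export.export(model, (x,))
print(exported.graph_module.true_graph_0.graph)
print(exported.graph_module.false_graph_0.graph)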

Let’s export again.

ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18, "pkg.torch.__subgraph__" : 1],
   producer_name: "pytorch",
   producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] x) => (float[1] getitem)
   <float[2,3] "mlp.0.weight" =  {-0.471258,0.324141,0.257305,0.0853913,-0.559284,-0.361027}, float[2] "mlp.0.bias" =  {-0.288676,-0.368527}, float[1,2] "mlp.1.weight" =  {0.501595,-0.663844}, float[1] "mlp.1.bias" =  {-0.145392}, float[1,3] view, float[3,2] t, float[1,2] addmm, float[2] view_1, float[1,2] view_2, float[2,1] t_1, float[1,1] addmm_1, float[1] view_3, float sum_1, float scalar_tensor_default, bool gt>
{
   [node_Constant_0] val_0 = Constant <value: tensor = int64[2] {1,3}> ()
   [node_Cast_1] val_1 = Cast <to: int = 7> (val_0)
   [node_Reshape_2] view = Reshape <allowzero: int = 0> (x, val_1)
   [node_Transpose_3] t = Transpose <perm: ints = [1, 0]> ("mlp.0.weight")
   [node_Gemm_4] addmm = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view, t, "mlp.0.bias")
   [node_Constant_5] val_2 = Constant <value: tensor = int64[1] {2}> ()
   [node_Cast_6] val_3 = Cast <to: int = 7> (val_2)
   [node_Reshape_7] view_1 = Reshape <allowzero: int = 0> (addmm, val_3)
   [node_Constant_8] val_4 = Constant <value: tensor = int64[2] {1,2}> ()
   [node_Cast_9] val_5 = Cast <to: int = 7> (val_4)
   [node_Reshape_10] view_2 = Reshape <allowzero: int = 0> (view_1, val_5)
   [node_Transpose_11] t_1 = Transpose <perm: ints = [1, 0]> ("mlp.1.weight")
   [node_Gemm_12] addmm_1 = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view_2, t_1, "mlp.1.bias")
   [node_Constant_13] val_6 = Constant <value: tensor = int64[1] {1}> ()
   [node_Cast_14] val_7 = Cast <to: int = 7> (val_6)
   [node_Reshape_15] view_3 = Reshape <allowzero: int = 0> (addmm_1, val_7)
   [node_ReduceSum_16] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (view_3)
   [node_Constant_17] val_8 = Constant <value: tensor = int64 {0}> ()
   [node_Cast_18] scalar_tensor_default = Cast <to: int = 1> (val_8)
   [node_Greater_19] gt = Greater (sum_1, scalar_tensor_default)
   [node_If_20] getitem = If (gt) <then_branch: graph = true_graph_0 () => ( mul_true_graph_0) {
      [node_true_graph_0_0] mul_true_graph_0 = pkg.torch.__subgraph__.true_graph_0 (view_3)
   }, else_branch: graph = false_graph_0 () => ( neg_false_graph_0) {
      [node_false_graph_0_0] neg_false_graph_0 = pkg.torch.__subgraph__.false_graph_0 (view_3)
   }>
}
<
  domain: "pkg.torch.__subgraph__",
  opset_import: ["" : 18]
>
false_graph_0 (view_3) => (neg)
{
   [node_Neg_0] neg = Neg (view_3)
}
<
  domain: "pkg.torch.__subgraph__",
  opset_import: ["" : 18]
>
true_graph_0 (view_3) => (mul)
{
   [node_Constant_0] val_0 = Constant <value: tensor = int64 {2}> ()
   [node_Cast_1] scalar_tensor_default = Cast <to: int = 1> (val_0)
   [node_Mul_2] mul = Mul (view_3, scalar_tensor_default)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   [n0] tmp = Shape (input)
   [n1] return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   [n0] tmp = Shape (input)
   [n1] tmp_0 = Size (tmp)
   [n2] tmp_1 = Constant <value_int: int = 0> ()
   [n3] return_val = Equal (tmp_0, tmp_1)
}

Let’s run the optimizer to obtain a smaller model.

ep = torch.onnx.export(model, (x,), dynamo=True)
ep.optimize()
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 2 of general pattern rewrite rules.
<
   ir_version: 10,
   opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18, "pkg.torch.__subgraph__" : 1],
   producer_name: "pytorch",
   producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] x) => (float[1] getitem)
   <float[2] "mlp.0.bias" =  {-0.288676,-0.368527}, float[1] "mlp.1.bias" =  {-0.145392}, float[3,2] t, float[2] val_9, float[2] view_1, float[2,1] t_1, float[1] val_10, float[1] view_3, float sum_1, float scalar_tensor_default, bool gt>
{
   [node_Constant_23] t = Constant <value: tensor = float[3,2] t {-0.471258,0.0853913,0.324141,-0.559284,0.257305,-0.361027}> ()
   [node_MatMul_32] val_9 = MatMul (x, t)
   [node_Add_33] view_1 = Add (val_9, "mlp.0.bias")
   [node_Constant_28] t_1 = Constant <value: tensor = float[2,1] t_1 {0.501595,-0.663844}> ()
   [node_MatMul_34] val_10 = MatMul (view_1, t_1)
   [node_Add_35] view_3 = Add (val_10, "mlp.1.bias")
   [node_ReduceSum_16] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (view_3)
   [node_Constant_31] scalar_tensor_default = Constant <value: tensor = float scalar_tensor_default {0}> ()
   [node_Greater_19] gt = Greater (sum_1, scalar_tensor_default)
   [node_If_20] getitem = If (gt) <then_branch: graph = true_graph_0 () => (float[1] mul_true_graph_0)
      <float scalar_tensor_default_2>
{
      [node_Constant_1] scalar_tensor_default_2 = Constant <value: tensor = float scalar_tensor_default_2 {2}> ()
      [node_Mul_2] mul_true_graph_0 = Mul (view_3, scalar_tensor_default_2)
   }, else_branch: graph = false_graph_0 () => (float[1] neg_false_graph_0) {
      [node_Neg_0] neg_false_graph_0 = Neg (view_3)
   }>
}
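
As a final check, we can compare the optimized ONNX model against the eager model (a sketch, assuming onnxruntime is installed; the input name "x" comes from the graph printed above):

import onnxruntime

sess = onnxruntime.InferenceSession(
    ep.model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
print(model(x))
print(sess.run(None, {"x": x.numpy()})[0])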
