torch.onnx.export and a model with a test

Tests on tensor values (a Python if whose condition depends on the data) cannot be exported to ONNX unless they are refactored to use torch.cond().

A model with a test

from onnx.printer import to_text
import torch

We define a model with data-dependent control flow, which causes a graph break.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

Let’s check it runs.

x = torch.randn(3)
model(x)
tensor([0.6760], grad_fn=<MulBackward0>)

As expected, it does not export.

try:
    torch.export.export(model, (x,))
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:24 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:973 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py", line 24, in forward
    if x.sum():


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
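The trace above ends with sum, then ne(..., 0), then item: a Python if applied to a tensor converts it to a Python bool, and that conversion needs the concrete value. A minimal sketch of the same conversion in eager mode, using made-up tensors that are not taken from the model above:

```python
import torch

# Illustrative values only, not the model's actual activations.
x = torch.tensor([0.5, -0.5, 1.0])

# `if x.sum():` implicitly calls bool() on a 0-d tensor,
# i.e. it tests whether the value differs from zero.
flag_nonzero = bool(x.sum())            # the sum is 1.0
flag_zero = bool(torch.zeros(3).sum())  # the sum is 0.0

print(flag_nonzero, flag_zero)
```

During export there is no concrete value to read, which is presumably why the exporter reports that it cannot guard on Eq(u0, 1).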

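It does export with torch.onnx.export(), but not faithfully. The log below shows that torch.export.export fails in both non-strict and strict modes, after which the exporter falls back to draft_export, which specializes the data-dependent condition to the value observed for this input. The resulting ONNX graph contains no test and keeps only the x * 2 branch, so the model is not exactly the same as the initial model.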

ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...



def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:24 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`...
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_

         # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:39 in forward, code: out = self.mlp(x)
        l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_);  l_x_ = None
        l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0);  l__self___mlp_0 = None

         # File: ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:24 in forward, code: if x.sum():
        sum_1: "f32[][]cpu" = l__self___mlp_1.sum();  l__self___mlp_1 = sum_1 = None

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`...
[torch.onnx] Draft Export report:

###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Data dependent error.
    When exporting, we were unable to evaluate the value of `Eq(u0, 1)`.
    This was encountered 1 times.
    This occurred at the following user stacktrace:
        File ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py, lineno 1767, in _wrapped_call_impl
        File ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py, lineno 1778, in _call_impl
        File ~/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py, lineno 24, in forward
            if x.sum():

        Locals:
            x: ['Tensor(shape: torch.Size([1]), stride: (1,), storage_offset: 0)']

    And the following framework stacktrace:
        File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1326, in __torch_function__
        File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1373, in __torch_function__
            return func(*args, **kwargs)

    As a result, it was specialized to a constant (e.g. `1` in the 1st occurrence), and asserts were inserted into the graph.

    Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
    Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.


[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.8.0.dev20250519+cu126"
>
main_graph (float[3] x) => (float[1] mul)
   <float[2] "mlp.0.bias" =  {-0.403322,-0.409063}, float[1] "mlp.1.bias" =  {0.19677}, float[2] "mlp.0.bias", float[1] "mlp.1.bias", float[3,2] val_0, float[2] val_1, float[2] linear, float[2,1] val_2, float[1] val_3, float[1] linear_1, float scalar_tensor_default>
{
   [node_Constant_9] val_0 = Constant <value: tensor = float[3,2] val_0 {0.284533,0.244287,0.0663717,0.37044,-0.510749,-0.546452}> ()
   [node_MatMul_1] val_1 = MatMul (x, val_0)
   [node_Add_2] linear = Add (val_1, "mlp.0.bias")
   [node_Constant_10] val_2 = Constant <value: tensor = float[2,1] val_2 {0.363914,-0.449402}> ()
   [node_MatMul_4] val_3 = MatMul (linear, val_2)
   [node_Add_5] linear_1 = Add (val_3, "mlp.1.bias")
   [node_Constant_11] scalar_tensor_default = Constant <value: tensor = float scalar_tensor_default {2}> ()
   [node_Mul_8] mul = Mul (linear_1, scalar_tensor_default)
}

Suggested Patch

Let’s avoid the graph break by replacing the forward method of the submodule with a function based on torch.cond().

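def new_forward(x):
    # Each branch is a callable that receives the operands given to torch.cond.
    def double(x):
        return x * 2

    def neg(x):
        return -x

    # Both branches are recorded in the exported graph.
    return torch.cond(x.sum() > 0, double, neg, (x,))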


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
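One detail worth checking: the predicate given to torch.cond, x.sum() > 0, is not the same test as if x.sum(), which is true whenever the sum is nonzero. The sketch below compares the two rules in eager mode with plain Python branches; original_rule and patched_rule are hypothetical helpers written for this comparison, and the inputs are made up.

```python
import torch


def original_rule(x):
    # Same test as ForwardWithControlFlowTest.forward: true when the sum is nonzero.
    return x * 2 if x.sum() else -x


def patched_rule(x):
    # Eager emulation of the predicate used in new_forward: true when the sum is positive.
    return x * 2 if x.sum() > 0 else -x


pos = torch.tensor([1.5])
neg = torch.tensor([-1.5])

# The two rules agree on a positive sum and disagree on a negative one.
same_on_positive = torch.equal(original_rule(pos), patched_rule(pos))
same_on_negative = torch.equal(original_rule(neg), patched_rule(neg))
print(same_on_positive, same_on_negative)
```

If the original behaviour must be preserved for negative sums, a predicate such as x.sum() != 0 may be closer; that variant has not been checked against the exporter here.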

Let’s see what the fx graph looks like.

print(torch.export.export(model, (x,)).graph)
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%linear_1,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
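The graph above keeps both branches as subgraphs. When both branches are cheap elementwise expressions, as in this example, a branch-free formulation with torch.where may be an alternative worth considering. This is only a sketch: BranchFreeTest is a hypothetical replacement module, not part of the original recipe, and it evaluates both branches on every call.

```python
import torch


class BranchFreeTest(torch.nn.Module):
    # Hypothetical replacement for ForwardWithControlFlowTest.
    def forward(self, x):
        # Both candidates are computed; torch.where picks one based on
        # the 0-d boolean condition, which broadcasts over x.
        return torch.where(x.sum() > 0, x * 2, -x)


branch_free = BranchFreeTest()
out_pos = branch_free(torch.tensor([1.5]))
out_neg = branch_free(torch.tensor([-1.5]))
print(out_pos, out_neg)
```

Because no Python branch depends on a tensor value, this form should not require a data-dependent guard, at the cost of always computing both results. torch.cond remains the better fit when the branches are expensive or are not both valid to evaluate.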

Let’s export again.

ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.8.0.dev20250519+cu126"
>
main_graph (float[3] x) => (float[1] getitem)
   <float[2] "mlp.0.bias" =  {-0.403322,-0.409063}, float[1] "mlp.1.bias" =  {0.19677}, float[2] "mlp.0.bias", float[1] "mlp.1.bias", float[3,2] val_0, float[2] val_1, float[2] linear, float[2,1] val_2, float[1] val_3, float[1] linear_1, float sum_1, float scalar_tensor_default, bool gt>
{
   [node_Constant_11] val_0 = Constant <value: tensor = float[3,2] val_0 {0.284533,0.244287,0.0663717,0.37044,-0.510749,-0.546452}> ()
   [node_MatMul_1] val_1 = MatMul (x, val_0)
   [node_Add_2] linear = Add (val_1, "mlp.0.bias")
   [node_Constant_12] val_2 = Constant <value: tensor = float[2,1] val_2 {0.363914,-0.449402}> ()
   [node_MatMul_4] val_3 = MatMul (linear, val_2)
   [node_Add_5] linear_1 = Add (val_3, "mlp.1.bias")
   [node_ReduceSum_6] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (linear_1)
   [node_Constant_13] scalar_tensor_default = Constant <value: tensor = float scalar_tensor_default {0}> ()
   [node_Greater_9] gt = Greater (sum_1, scalar_tensor_default)
   [node_If_10] getitem = If (gt) <then_branch: graph = true_graph_0 () => (float[1] mul_true_graph_0)
      <float scalar_tensor_default_2>
{
      [node_Constant_1] scalar_tensor_default_2 = Constant <value: tensor = float scalar_tensor_default_2 {2}> ()
      [node_Mul_2] mul_true_graph_0 = Mul (linear_1, scalar_tensor_default_2)
   }, else_branch: graph = false_graph_0 () => (float[1] neg_false_graph_0) {
      [node_Neg_0] neg_false_graph_0 = Neg (linear_1)
   }>
}

Let’s optimize the model to see whether it gets smaller. In this case the printed graph is unchanged.

ep = torch.onnx.export(model, (x,), dynamo=True)
ep.optimize()
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
   ir_version: 10,
   opset_import: ["" : 18],
   producer_name: "pytorch",
   producer_version: "2.8.0.dev20250519+cu126"
>
main_graph (float[3] x) => (float[1] getitem)
   <float[2] "mlp.0.bias" =  {-0.403322,-0.409063}, float[1] "mlp.1.bias" =  {0.19677}, float[2] "mlp.0.bias", float[1] "mlp.1.bias", float[3,2] val_0, float[2] val_1, float[2] linear, float[2,1] val_2, float[1] val_3, float[1] linear_1, float sum_1, float scalar_tensor_default, bool gt>
{
   [node_Constant_11] val_0 = Constant <value: tensor = float[3,2] val_0 {0.284533,0.244287,0.0663717,0.37044,-0.510749,-0.546452}> ()
   [node_MatMul_1] val_1 = MatMul (x, val_0)
   [node_Add_2] linear = Add (val_1, "mlp.0.bias")
   [node_Constant_12] val_2 = Constant <value: tensor = float[2,1] val_2 {0.363914,-0.449402}> ()
   [node_MatMul_4] val_3 = MatMul (linear, val_2)
   [node_Add_5] linear_1 = Add (val_3, "mlp.1.bias")
   [node_ReduceSum_6] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (linear_1)
   [node_Constant_13] scalar_tensor_default = Constant <value: tensor = float scalar_tensor_default {0}> ()
   [node_Greater_9] gt = Greater (sum_1, scalar_tensor_default)
   [node_If_10] getitem = If (gt) <then_branch: graph = true_graph_0 () => (float[1] mul_true_graph_0)
      <float scalar_tensor_default_2>
{
      [node_Constant_1] scalar_tensor_default_2 = Constant <value: tensor = float scalar_tensor_default_2 {2}> ()
      [node_Mul_2] mul_true_graph_0 = Mul (linear_1, scalar_tensor_default_2)
   }, else_branch: graph = false_graph_0 () => (float[1] neg_false_graph_0) {
      [node_Neg_0] neg_false_graph_0 = Neg (linear_1)
   }>
}

Related examples

to_onnx and a model with a test

torch.onnx.export and a custom operator inplace

torch.onnx.export and a custom operator registered with a function

torch.onnx.export: Rename Dynamic Shapes

torch.onnx.export and padding one dimension to a multiple of a constant
