Export a model with a control flow (If)

Control flow cannot be exported without a change. The code of the model can be changed or patched to introduce the function torch.cond().

A model with a test

import torch
from onnx_diagnostic import doc

We define a model whose last submodule contains a control flow: a test on the values of a tensor (-> graph break).

class ForwardWithControlFlowTest(torch.nn.Module):
    """Return ``x * 2`` if ``x`` sums to a non-zero value, ``-x`` otherwise.

    The ``if`` depends on the *values* of ``x``, not only on its shape.
    This data-dependent control flow is what :func:`torch.export.export`
    cannot trace.
    """

    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlow(torch.nn.Module):
    """Two linear layers followed by a :class:`ForwardWithControlFlowTest`.

    Input is expected to have 3 features (``Linear(3, 2)``). The output has
    a single feature.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, setting ``self.mlp`` raises an AttributeError.
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            # Last step: the submodule holding the data-dependent ``if``.
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out

# Instantiate the model containing the data-dependent control flow.
model = ModelWithControlFlow()

Let’s check it runs.

# One sample with 3 features, matching the first Linear(3, 2) layer.
x = torch.randn(1, 3)
# Eager execution works: the ``if`` is evaluated by Python on concrete values.
print(model(x))
tensor([[-0.8528]], grad_fn=<MulBackward0>)

As expected, it does not export.

    torch.export.export(model, (x,))
    raise AssertionError("This export should failed unless pytorch now supports this model.")
except Exception as e:
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[1, 3]"):
     # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[1, 2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1, 1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_cond.py:25 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[1, 3]"):
     # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[1, 2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1, 1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_cond.py:25 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:689 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_cond.py", line 25, in forward
    if x.sum():

Suggested Patch

Let’s avoid the graph break by replacing the forward method of the submodule containing the test with a version based on torch.cond().

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))

# List every submodule and swap in the cond-based forward wherever the
# data-dependent test lives. The assignment shadows the class method on
# that instance only.
print("the list of submodules")
for qualified_name, submodule in model.named_modules():
    print(qualified_name, type(submodule))
    if not isinstance(submodule, ForwardWithControlFlowTest):
        continue
    submodule.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlow'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>

Let’s see what the fx graph looks like.

# With the patched submodule, export succeeds and the branch appears as a
# ``torch.ops.higher_order.cond`` node in the fx graph.
ep = torch.export.export(model, (x,))
print(ep.graph)
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%linear_1,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
# NOTE(review): presumably renders the legend image shown for this gallery
# example (onnx_diagnostic.doc helper with a title, a subtitle and a color);
# confirm against onnx_diagnostic.doc.plot_legend.
doc.plot_legend("If -> torch.cond", "torch.export.export", "tomato")
plot export cond

