Export a model with control flow (If)

Control flow that depends on tensor values cannot be exported without modifying the model. The model code can be changed or patched to use the function torch.cond() instead.

A model with a test

import torch

We define a model whose forward branches on the value of a tensor, which causes a graph break.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlow()

Let’s check it runs.

x = torch.randn(1, 3)
model(x)
tensor([[0.1805]], grad_fn=<MulBackward0>)

As expected, it does not export.

try:
    torch.export.export(model, (x,))
except Exception as e:
    # The export is expected to fail on the data-dependent `if`.
    print(e)
else:
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[1, 3]"):
     # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[1, 2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1, 1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_cond.py:24 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:683 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_cond.py", line 24, in forward
    if x.sum():

Suggested Patch

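Let’s avoid the graph break by replacing the forward method of ForwardWithControlFlowTest with a function based on torch.cond(). The predicate x.sum() > 0 is not strictly equivalent to the original test if x.sum(), which is true for any nonzero sum: the two differ when the sum is negative.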

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlow'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>

Let’s see what the FX graph looks like.

ep = torch.export.export(model, (x,))
print(ep.graph)
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%linear_1]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)

Related examples

Use DYNAMIC or AUTO when exporting if dynamic shapes have constraints

Export with DynamicCache and dynamic shapes

Steal method forward to guess the dynamic shapes (with Tiny-LLM)
