Export a model with a control flow (If)
A Python test whose outcome depends on tensor values cannot be exported without a change.
The code of the model must be changed or patched
to express the branch with the function torch.cond().
A model with a test
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_rewrite
We define a model whose last layer contains a data-dependent test, which causes a graph break.
class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        else:
            return -x
class ModelWithControlFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )
    def forward(self, x):
        out = self.mlp(x)
        return out
model = ModelWithControlFlow()
Let’s check it runs.
x = torch.randn(1, 3)
model(x)
tensor([[0.4957]], grad_fn=<MulBackward0>)
As expected, it does not export.
try:
    torch.export.export(model, (x,))
except Exception as e:
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by the except clause.
    raise AssertionError("This export should fail unless pytorch now supports this model.")
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[1, 3]"):
     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[1, 2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1, 1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None
     # File: ~/github/onnx-diagnostic/_doc/recipes/plot_export_cond.py:26 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true
Caused by: (_export/non_strict_utils.py:1118 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
  File "~/github/onnx-diagnostic/_doc/recipes/plot_export_cond.py", line 26, in forward
    if x.sum():
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
Suggested Patch
Let’s avoid the graph break by replacing the forward.
def new_forward(x):
    def identity2(x):
        return x * 2
    def neg(x):
        return -x
    return torch.cond(x.sum() > 0, identity2, neg, (x,))
print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlow'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
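Note that the patched condition x.sum() > 0 is not strictly equivalent to the original test if x.sum(), which is true for any nonzero sum, including a negative one. The small check below illustrates the difference on a hand-picked tensor (an illustrative value, not taken from the example). The predicate x.sum() != 0 is the one that preserves the original semantics.

```python
import torch

# Hand-picked input with a negative sum (illustrative value only).
x = torch.tensor([[-0.5]])

original_test = bool(x.sum())        # truthiness of a tensor: sum != 0
patched_test = bool(x.sum() > 0)     # predicate used in new_forward
faithful_test = bool(x.sum() != 0)   # predicate matching the original `if`

print(original_test, patched_test, faithful_test)
```

For such an input, the original model and the patched model take different branches, so the choice of predicate matters when the patch is meant to be a drop-in replacement.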
Let’s see what the fx graph looks like.
ep = torch.export.export(model, (x,))
print(ep.graph)
~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%linear_1,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
Automated Rewrite of the Control Flow
The functions torch_export_rewrite
and torch_export_patches
can automatically rewrite a method of a class or a function.
The method to rewrite is specified with the parameter rewrite.
This feature is experimental. The function has options to
rewrite one test but not another one already supported by the exporter.
Even when only a manual rewrite can make the model exportable,
it may provide a first version of the rewritten code.
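To give an intuition of what a source-level rewrite involves, the sketch below uses the standard ast module to locate the if statements of a method given as source text. It is only an illustration, not the implementation of torch_export_rewrite, and the helper find_if_tests is hypothetical.

```python
import ast
import textwrap

# Source of the method to inspect, copied from ForwardWithControlFlowTest.
SOURCE = textwrap.dedent(
    """
    def forward(self, x):
        if x.sum():
            return x * 2
        else:
            return -x
    """
)


def find_if_tests(source):
    """Return (line number, test expression) for every `if` statement."""
    tree = ast.parse(source)
    return [
        (node.lineno, ast.unparse(node.test))
        for node in ast.walk(tree)
        if isinstance(node, ast.If)
    ]


found = find_if_tests(SOURCE)
print(found)
```

A real rewriter would go one step further and replace each such node with two local functions and a call to torch.cond, as the verbose output of torch_export_rewrite shows.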
with torch_export_rewrite(rewrite=[ForwardWithControlFlowTest.forward], verbose=2) as f:
    ep = torch.export.export(model, (x,))
[torch_export_rewrite] rewrites method ForwardWithControlFlowTest.forward
[transform_method] -- source -- <function ForwardWithControlFlowTest.forward at 0x700f6d4c0540>
    def forward(self, x):
        if x.sum():
            return x * 2
        else:
            return -x
[transform_method] --
[transform_method] -- new code --
def forward(self, x):
    def branch_cond_then_1(x):
        return x * 2
    def branch_cond_else_1(x):
        return -x
    return torch.cond(x.sum(), branch_cond_then_1, branch_cond_else_1, [x])
[transform_method] --
~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
[torch_export_rewrite] restored method ForwardWithControlFlowTest.forward
This gives:
print(ep.graph)
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, (%linear_1,)), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
doc.plot_legend("If -> torch.cond", "torch.export.export", "yellowgreen")
