to_onnx and a model with a test

Control flow cannot be exported without a change. The code of the model can be changed or patched to use the function torch.cond() instead.
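Why a Python if breaks the export can be illustrated without PyTorch. The sketch below uses a hypothetical toy tracer, Sym, which records operations as strings instead of computing them. It is only an analogy, not how torch.export is implemented. Branching on a symbolic value forces bool(), which the tracer cannot answer. A cond-style operator (here the hypothetical toy_cond) instead traces both branches and records a single conditional node.

```python
class Sym:
    """Hypothetical symbolic value: records operations instead of computing them."""

    def __init__(self, expr):
        self.expr = expr

    def __mul__(self, k):
        return Sym(f"mul({self.expr}, {k})")

    def __neg__(self):
        return Sym(f"neg({self.expr})")

    def __gt__(self, k):
        return Sym(f"gt({self.expr}, {k})")

    def sum(self):
        return Sym(f"sum({self.expr})")

    def __bool__(self):
        # No concrete value exists while tracing, so no branch can be chosen.
        raise TypeError(f"data-dependent branch on {self.expr}")


def forward_with_if(x):
    # Mirrors a forward with a Python if: the if statement calls bool() on x.sum().
    if x.sum():
        return x * 2
    return -x


def toy_cond(pred, true_fn, false_fn, operands):
    # Trace BOTH branches and keep the choice inside the recorded expression.
    then_expr = true_fn(*operands).expr
    else_expr = false_fn(*operands).expr
    return Sym(f"if({pred.expr}, then={then_expr}, else={else_expr})")


def forward_with_cond(x):
    # Same calling convention as torch.cond: predicate, two branches, operand tuple.
    return toy_cond(x.sum() > 0, lambda v: v * 2, lambda v: -v, (x,))


try:
    forward_with_if(Sym("x"))
except TypeError as e:
    print("graph break:", e)  # graph break: data-dependent branch on sum(x)
print(forward_with_cond(Sym("x")).expr)  # if(gt(sum(x), 0), then=mul(x, 2), else=neg(x))
```

The cond version never needs a concrete boolean, so the whole computation, branches included, ends up in one recorded expression.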

A model with a test

import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx

We define a model with control flow: a Python if statement on a tensor value, which causes a graph break.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlow()

Let’s check that it runs.

x = torch.randn(1, 3)
model(x)
tensor([[5.4273]], grad_fn=<MulBackward0>)

As expected, it does not export.

try:
    torch.export.export(model, (x,))
except Exception as e:
    print(e)
else:
    # Only reached if the export succeeds; not swallowed by the except clause.
    raise AssertionError("This export should fail unless pytorch now supports this model.")
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 42, in forward
    out = self.mlp(x)
  File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 27, in forward
    if x.sum():

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

The exporter fails with the same error because it relies on torch.export.export succeeding.

try:
    to_onnx(model, (x,))
except Exception as e:
    print(e)
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 42, in forward
    out = self.mlp(x)
  File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 27, in forward
    if x.sum():

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
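Both cells above follow the same pattern: run a call that is expected to fail and print the error. If an AssertionError is raised inside a try block guarded by except Exception, that except clause catches it too. The error would then be printed instead of stopping the script. A small helper built on try/except/else keeps the two outcomes apart. expect_failure is a hypothetical name and is not part of experimental_experiment.

```python
def expect_failure(fn, *args, **kwargs):
    """Call fn and return its error message; fail loudly if fn succeeds.

    Hypothetical helper for recipes that document known export failures.
    """
    try:
        fn(*args, **kwargs)
    except Exception as e:  # expected path: keep the message for display
        return str(e)
    else:
        # Reached only when fn did not raise, so this error is not swallowed.
        raise AssertionError(
            f"{getattr(fn, '__name__', fn)} should fail unless support was added."
        )


def always_fails(x):
    raise RuntimeError(f"cannot export {x!r}")


print(expect_failure(always_fails, "model"))  # cannot export 'model'
```

In this recipe, it would be called as print(expect_failure(torch.export.export, model, (x,))) or print(expect_failure(to_onnx, model, (x,))).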

Suggested Patch

Let’s avoid the graph break by replacing the forward method of the submodule that contains the test.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlow'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
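The loop above replaces forward on the instance permanently. Assigning a plain function to an instance attribute works because Python binds functions to the object only when they are found on the class. As a result, mod.forward(x) calls new_forward(x) without self. When the original behaviour must be restored after export, a context manager can record and undo the patch. The sketch below shows one way to do this with the standard library only. patched_forward is a hypothetical helper, and Block stands in for a submodule.

```python
import contextlib


@contextlib.contextmanager
def patched_forward(modules, predicate, new_forward):
    """Temporarily set forward on every module matching predicate.

    Hypothetical helper: modules is any iterable of (name, module) pairs,
    such as the one returned by torch.nn.Module.named_modules().
    """
    patched = []
    for _name, mod in modules:
        if predicate(mod):
            # Remember whether forward was already an instance attribute.
            had_own = "forward" in vars(mod)
            patched.append((mod, had_own, vars(mod).get("forward")))
            mod.forward = new_forward  # plain function on the instance: no self
    try:
        yield [mod for mod, _, _ in patched]
    finally:
        for mod, had_own, old in reversed(patched):
            if had_own:
                mod.forward = old
            else:
                del mod.forward  # fall back to the class method again


class Block:
    def forward(self, x):
        return x + 1


blocks = [("a", Block()), ("b", Block())]
with patched_forward(blocks, lambda m: isinstance(m, Block), lambda x: x * 10) as mods:
    print(len(mods), blocks[0][1].forward(3))  # 2 30
print(blocks[0][1].forward(3))  # 4
```

Applied to this recipe, the export would run inside the block:

    with patched_forward(model.named_modules(),
                         lambda m: isinstance(m, ForwardWithControlFlowTest),
                         new_forward):
        onx = to_onnx(model, (x,))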

Let’s see what the fx graph looks like.

print(torch.export.export(model, (x,)).graph)
graph():
    %p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
    %p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
    %p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
    %p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
    %linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
    %gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
    %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
    %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
    %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%linear_1]), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
    return (getitem,)
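The cond node takes four arguments: the predicate, the two branch subgraphs true_graph_0 and false_graph_0, and the operand list [%linear_1]. A getitem node then extracts the first output. The stdlib mirror of that graph below makes the data flow explicit. Lists of floats stand in for tensors. reference_cond is a hypothetical eager reference written for this illustration, not PyTorch's implementation of torch.ops.higher_order.cond.

```python
import operator


def reference_cond(pred, true_fn, false_fn, operands):
    """Hypothetical eager reference for the cond node: run exactly one branch."""
    branch = true_fn if pred else false_fn
    outputs = branch(*operands)
    # Return a tuple of outputs so the caller can index them, as getitem does.
    return outputs if isinstance(outputs, tuple) else (outputs,)


def true_graph_0(linear_1):
    return [v * 2 for v in linear_1]


def false_graph_0(linear_1):
    return [-v for v in linear_1]


def graph_like(linear_1):
    # linear_1: list of floats standing in for the [1, 1] tensor %linear_1.
    sum_1 = sum(linear_1)  # aten.sum.default
    gt = sum_1 > 0  # aten.gt.Scalar
    cond = reference_cond(gt, true_graph_0, false_graph_0, [linear_1])
    getitem = operator.getitem(cond, 0)  # first (and only) output
    return (getitem,)


print(graph_like([1.5]))  # ([3.0],)
print(graph_like([-0.5]))  # ([0.5],)
```

Both branches are stored in the graph, but only the one selected by the predicate is evaluated.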

Let’s export again.

onx = to_onnx(model, (x,))
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 3]
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='mlp.0.weight' type=dtype('float32') shape=(2, 3)
init: name='mlp.0.bias' type=dtype('float32') shape=(2,) -- array([0.39745384, 0.5367278 ], dtype=float32)
init: name='mlp.1.weight' type=dtype('float32') shape=(1, 2) -- array([0.6619442 , 0.53932554], dtype=float32)
init: name='mlp.1.bias' type=dtype('float32') shape=(1,) -- array([0.44847998], dtype=float32)
Gemm(x, mlp.0.weight, mlp.0.bias, transB=1) -> linear
  Gemm(linear, mlp.1.weight, mlp.1.bias, transB=1) -> linear_1
    ReduceSum(linear_1, keepdims=0) -> sum_1
      Greater(sum_1, init1_s_) -> gt
        If(gt, else_branch=G1, then_branch=G2) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 1]
----- subgraph ---- If - aten_cond - att.else_branch=G1 -- level=1 --  -> cond#0
false_graph_0[local_functions](linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- subgraph ---- If - aten_cond - att.then_branch=G2 -- level=1 --  -> cond#0
true_graph_0[local_functions](linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- function name=true_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'linear_1'
Constant(value=2.0) -> init1_s_
Constant(value=[1]) -> init7_s1_1
  Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
    Mul(linear_1, _onx_reshape0) -> output_0
output: name='output_0' type=? shape=?
----- function name=false_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='local_functions' version=1
input: 'linear_1'
Neg(linear_1) -> output_0
output: name='output_0' type=? shape=?

We can also inline the local functions.

onx = to_onnx(model, (x,), inline=True)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=True, external_threshold=1024...
input: name='x' type=dtype('float32') shape=[1, 3]
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='mlp.0.weight' type=dtype('float32') shape=(2, 3)
init: name='mlp.0.bias' type=dtype('float32') shape=(2,) -- array([0.39745384, 0.5367278 ], dtype=float32)
init: name='mlp.1.weight' type=dtype('float32') shape=(1, 2) -- array([0.6619442 , 0.53932554], dtype=float32)
init: name='mlp.1.bias' type=dtype('float32') shape=(1,) -- array([0.44847998], dtype=float32)
Gemm(x, mlp.0.weight, mlp.0.bias, transB=1) -> linear
  Gemm(linear, mlp.1.weight, mlp.1.bias, transB=1) -> linear_1
    ReduceSum(linear_1, keepdims=0) -> sum_1
      Greater(sum_1, init1_s_) -> gt
        If(gt, else_branch=G1, then_branch=G2) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 1]
----- subgraph ---- If - aten_cond - att.else_branch=G1 -- level=1 --  -> cond#0
Neg(linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- subgraph ---- If - aten_cond - att.then_branch=G2 -- level=1 --  -> cond#0
Constant(value=[1]) -> init7_s1_122
Constant(value=2.0) -> init1_s_22
  Reshape(init1_s_22, init7_s1_122) -> _onx_reshape032
Mul(linear_1, _onx_reshape032) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
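With inline=True, each call to a local function is replaced by the function body. The body's intermediate names can collide with names the caller already uses. Here, init1_s_ exists both in the main graph and in true_graph_0, which explains why the inlined branch uses new names such as init1_s_22. The sketch below illustrates that renaming on a hypothetical (op_type, inputs, outputs) tuple representation. It is not the library's algorithm, and its suffix scheme differs from the one visible in the output.

```python
def inline_call(fn_inputs, fn_outputs, fn_nodes, call_inputs, call_outputs, taken):
    """Rewrite a function body so it can be pasted into the calling graph.

    Nodes are hypothetical (op_type, inputs, outputs) tuples; taken is the
    set of names already used in the caller and is updated in place.
    """
    # Formal parameters map to the caller's actual arguments and results.
    mapping = dict(zip(fn_inputs, call_inputs))
    mapping.update(zip(fn_outputs, call_outputs))

    def rename(name):
        if name not in mapping:
            new, k = name, 0
            while new in taken:  # append a counter until the name is free
                k += 1
                new = f"{name}{k}"
            mapping[name] = new
            taken.add(new)
        return mapping[name]

    return [
        (op, [rename(i) for i in ins], [rename(o) for o in outs])
        for op, ins, outs in fn_nodes
    ]


# Body of true_graph_0 as printed by pretty_onnx.
true_graph_0 = [
    ("Constant", [], ["init1_s_"]),
    ("Constant", [], ["init7_s1_1"]),
    ("Reshape", ["init1_s_", "init7_s1_1"], ["_onx_reshape0"]),
    ("Mul", ["linear_1", "_onx_reshape0"], ["output_0"]),
]
caller_names = {"x", "init1_s_", "linear", "linear_1", "sum_1", "gt", "output_0"}
for node in inline_call(["linear_1"], ["output_0"], true_graph_0,
                        ["linear_1"], ["cond#0"], caller_names):
    print(node)
```

Only init1_s_ clashes with a caller name and gets renamed (to init1_s_1). The function output output_0 is mapped to the If branch output cond#0.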

And visually.

[Figure: plot exporter recipes c cond]
