torch.onnx.export and a model with a test
Control flow cannot be exported without a change.
The code of the model can be changed or patched
to introduce the function torch.cond().
A model with a test
from onnx.printer import to_text
import torch
We define a model with control flow, which causes a graph break.
class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()
Let’s check it runs.
x = torch.randn(3)
model(x)
tensor([1.6173], grad_fn=<MulBackward0>)
As expected, torch.export.export fails on it.
try:
    torch.export.export(model, (x,))
except Exception as e:
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by `except Exception`.
    raise AssertionError("This export should fail unless pytorch now supports this model.")
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py", line 40, in forward
out = self.mlp(x)
File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py", line 25, in forward
if x.sum():
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
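It does export with torch.onnx.export. As the log below shows, both torch.export.export attempts fail and the exporter falls back to Torch Script, which traces one execution of the model. The trace only records the branch taken for this input (here x * 2), so the exported model is not equivalent to the initial model.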
ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script...
/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_oe_cond.py:25: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if x.sum():
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅
[torch.onnx] Run decomposition...
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:75: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
getattr_node = gm.graph.get_attr(lifted_node)
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1800: UserWarning: Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version: 10,
opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18],
producer_name: "pytorch",
producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] input_1) => (float[1] mul)
<float[2] "model.mlp.0.bias" = {-0.288676,-0.368527}, float[2,3] "model.mlp.0.weight" = {-0.471258,0.324141,0.257305,0.0853913,-0.559284,-0.361027}, float[1] "model.mlp.1.bias" = {-0.145392}, float[1,2] "model.mlp.1.weight" = {0.501595,-0.663844}, float[1,3] view, float[3,2] t, float[1,2] addmm, float[2] view_1, float[1,2] view_2, float[2,1] t_1, float[1,1] addmm_1, float[1] view_3, float scalar_tensor_default>
{
[node_Constant_0] val_0 = Constant <value: tensor = int64[2] {1,3}> ()
[node_Cast_1] val_1 = Cast <to: int = 7> (val_0)
[node_Reshape_2] view = Reshape <allowzero: int = 0> (input_1, val_1)
[node_Transpose_3] t = Transpose <perm: ints = [1, 0]> ("model.mlp.0.weight")
[node_Gemm_4] addmm = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view, t, "model.mlp.0.bias")
[node_Constant_5] val_2 = Constant <value: tensor = int64[1] {2}> ()
[node_Cast_6] val_3 = Cast <to: int = 7> (val_2)
[node_Reshape_7] view_1 = Reshape <allowzero: int = 0> (addmm, val_3)
[node_Constant_8] val_4 = Constant <value: tensor = int64[2] {1,2}> ()
[node_Cast_9] val_5 = Cast <to: int = 7> (val_4)
[node_Reshape_10] view_2 = Reshape <allowzero: int = 0> (view_1, val_5)
[node_Transpose_11] t_1 = Transpose <perm: ints = [1, 0]> ("model.mlp.1.weight")
[node_Gemm_12] addmm_1 = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view_2, t_1, "model.mlp.1.bias")
[node_Constant_13] val_6 = Constant <value: tensor = int64[1] {1}> ()
[node_Cast_14] val_7 = Cast <to: int = 7> (val_6)
[node_Reshape_15] view_3 = Reshape <allowzero: int = 0> (addmm_1, val_7)
[node_Constant_16] val_8 = Constant <value: tensor = int64 {2}> ()
[node_Cast_17] scalar_tensor_default = Cast <to: int = 1> (val_8)
[node_Mul_18] mul = Mul (view_3, scalar_tensor_default)
}
<
domain: "pkg.onnxscript.torch_lib.common",
opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
[n0] tmp = Shape (input)
[n1] return_val = Size (tmp)
}
<
domain: "pkg.onnxscript.torch_lib.common",
opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
[n0] tmp = Shape (input)
[n1] tmp_0 = Size (tmp)
[n2] tmp_1 = Constant <value_int: int = 0> ()
[n3] return_val = Equal (tmp_0, tmp_1)
}
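The exported graph above contains a single Mul node and no test, because tracing evaluates the Python `if` once, on the example input. The following sketch illustrates that mechanism with a toy tracer written for this note. It is not how Torch Script is implemented: `Proxy`, `trace` and the scalar input are hypothetical stand-ins for a traced tensor, the tracer and `x.sum()`.

```python
class Proxy:
    """Toy stand-in for a traced tensor: holds a value and logs every op."""

    def __init__(self, value, name, tape):
        self.value, self.name, self.tape = value, name, tape

    def _record(self, op, value):
        out = Proxy(value, f"v{len(self.tape)}", self.tape)
        self.tape.append(f"{out.name} = {op}({self.name})")
        return out

    def __mul__(self, k):
        return self._record(f"mul_{k}", self.value * k)

    def __neg__(self):
        return self._record("neg", -self.value)

    def __bool__(self):
        # A Python `if` cannot be recorded: the test is evaluated once,
        # on the example input, and nothing is appended to the tape.
        return bool(self.value)


def forward(x):
    # Same shape as ForwardWithControlFlowTest.forward, on a scalar.
    if x:
        return x * 2
    return -x


def trace(fn, example):
    tape = []
    fn(Proxy(example, "x", tape))
    return tape


print(trace(forward, 1.5))  # ['v0 = mul_2(x)']
print(trace(forward, 0.0))  # ['v0 = neg(x)']
```

Each trace contains only one of the two branches, and which one depends on the example input. Replaying the first tape on an input whose sum is zero would still multiply by 2.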
Suggested Patch
Let’s avoid the graph break by replacing the forward method of the submodule that contains the test with a function based on torch.cond().
def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
<class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
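The replacement function takes no `self` argument because it is assigned to the instance, not to the class: a function stored in an instance attribute is not bound as a method. The sketch below shows this with plain Python. `Block` is a hypothetical stand-in for a module whose `__call__` dispatches to `self.forward`, and both functions are placeholders chosen only to make the two cases easy to tell apart.

```python
class Block:
    """Hypothetical stand-in for a module: calling it dispatches to forward."""

    def forward(self, x):
        return x + 1

    def __call__(self, x):
        return self.forward(x)


def new_forward(x):  # no `self`: stored on the instance, so never bound
    return x * 10


block = Block()
block.forward = new_forward  # shadows Block.forward for this instance only

print(block(5))                  # 50: the replacement is used
print(Block()(5))                # 6: other instances keep the class method
print("forward" in vars(block))  # True: the patch lives in the instance dict
```

Patching the instance keeps the change local to one model object. Patching the class instead would affect every instance and would require a function that accepts `self`.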
Let’s see what the FX graph looks like.
print(torch.export.export(model, (x,)).graph)
graph():
%p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
%p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
%p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
%p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
%linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
%gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%linear_1]), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
return (getitem,)
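The graph compares the sum with `gt(..., 0)`, whereas the original forward relied on the truthiness of `x.sum()`, which is true for any non-zero value. The two predicates disagree when the sum is negative. The sketch below makes this visible with plain floats standing in for `x.sum()`; the function names are hypothetical and the returned strings only name the branch that would run.

```python
def original_branch(total):
    # `if x.sum():` takes the first branch for any non-zero sum.
    return "x * 2" if total else "-x"


def patched_branch(total):
    # `x.sum() > 0` takes the first branch only for a positive sum.
    return "x * 2" if total > 0 else "-x"


for total in (1.5, 0.0, -1.5):
    print(total, original_branch(total), patched_branch(total))
# 1.5 x * 2 x * 2
# 0.0 -x -x
# -1.5 x * 2 -x
```

If the original behavior must be preserved exactly, a predicate such as `x.sum() != 0` is presumably the closer match. This is an assumption that is not exercised in this example.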
Let’s export again.
ep = torch.onnx.export(model, (x,), dynamo=True)
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version: 10,
opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18, "pkg.torch.__subgraph__" : 1],
producer_name: "pytorch",
producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] x) => (float[1] getitem)
<float[2,3] "mlp.0.weight" = {-0.471258,0.324141,0.257305,0.0853913,-0.559284,-0.361027}, float[2] "mlp.0.bias" = {-0.288676,-0.368527}, float[1,2] "mlp.1.weight" = {0.501595,-0.663844}, float[1] "mlp.1.bias" = {-0.145392}, float[1,3] view, float[3,2] t, float[1,2] addmm, float[2] view_1, float[1,2] view_2, float[2,1] t_1, float[1,1] addmm_1, float[1] view_3, float sum_1, float scalar_tensor_default, bool gt>
{
[node_Constant_0] val_0 = Constant <value: tensor = int64[2] {1,3}> ()
[node_Cast_1] val_1 = Cast <to: int = 7> (val_0)
[node_Reshape_2] view = Reshape <allowzero: int = 0> (x, val_1)
[node_Transpose_3] t = Transpose <perm: ints = [1, 0]> ("mlp.0.weight")
[node_Gemm_4] addmm = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view, t, "mlp.0.bias")
[node_Constant_5] val_2 = Constant <value: tensor = int64[1] {2}> ()
[node_Cast_6] val_3 = Cast <to: int = 7> (val_2)
[node_Reshape_7] view_1 = Reshape <allowzero: int = 0> (addmm, val_3)
[node_Constant_8] val_4 = Constant <value: tensor = int64[2] {1,2}> ()
[node_Cast_9] val_5 = Cast <to: int = 7> (val_4)
[node_Reshape_10] view_2 = Reshape <allowzero: int = 0> (view_1, val_5)
[node_Transpose_11] t_1 = Transpose <perm: ints = [1, 0]> ("mlp.1.weight")
[node_Gemm_12] addmm_1 = Gemm <beta: float = 1, transB: int = 0, alpha: float = 1, transA: int = 0> (view_2, t_1, "mlp.1.bias")
[node_Constant_13] val_6 = Constant <value: tensor = int64[1] {1}> ()
[node_Cast_14] val_7 = Cast <to: int = 7> (val_6)
[node_Reshape_15] view_3 = Reshape <allowzero: int = 0> (addmm_1, val_7)
[node_ReduceSum_16] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (view_3)
[node_Constant_17] val_8 = Constant <value: tensor = int64 {0}> ()
[node_Cast_18] scalar_tensor_default = Cast <to: int = 1> (val_8)
[node_Greater_19] gt = Greater (sum_1, scalar_tensor_default)
[node_If_20] getitem = If (gt) <then_branch: graph = true_graph_0 () => ( mul_true_graph_0) {
[node_true_graph_0_0] mul_true_graph_0 = pkg.torch.__subgraph__.true_graph_0 (view_3)
}, else_branch: graph = false_graph_0 () => ( neg_false_graph_0) {
[node_false_graph_0_0] neg_false_graph_0 = pkg.torch.__subgraph__.false_graph_0 (view_3)
}>
}
<
domain: "pkg.torch.__subgraph__",
opset_import: ["" : 18]
>
false_graph_0 (view_3) => (neg)
{
[node_Neg_0] neg = Neg (view_3)
}
<
domain: "pkg.torch.__subgraph__",
opset_import: ["" : 18]
>
true_graph_0 (view_3) => (mul)
{
[node_Constant_0] val_0 = Constant <value: tensor = int64 {2}> ()
[node_Cast_1] scalar_tensor_default = Cast <to: int = 1> (val_0)
[node_Mul_2] mul = Mul (view_3, scalar_tensor_default)
}
<
domain: "pkg.onnxscript.torch_lib.common",
opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
[n0] tmp = Shape (input)
[n1] return_val = Size (tmp)
}
<
domain: "pkg.onnxscript.torch_lib.common",
opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
[n0] tmp = Shape (input)
[n1] tmp_0 = Size (tmp)
[n2] tmp_1 = Constant <value_int: int = 0> ()
[n3] return_val = Equal (tmp_0, tmp_1)
}
Let’s optimize the exported model to obtain a smaller graph.
ep = torch.onnx.export(model, (x,), dynamo=True)
ep.optimize()
print(to_text(ep.model_proto))
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 2 of general pattern rewrite rules.
<
ir_version: 10,
opset_import: ["pkg.onnxscript.torch_lib.common" : 1, "" : 18, "pkg.torch.__subgraph__" : 1],
producer_name: "pytorch",
producer_version: "2.6.0.dev20241128+cu124"
>
main_graph (float[3] x) => (float[1] getitem)
<float[2] "mlp.0.bias" = {-0.288676,-0.368527}, float[1] "mlp.1.bias" = {-0.145392}, float[3,2] t, float[2] val_9, float[2] view_1, float[2,1] t_1, float[1] val_10, float[1] view_3, float sum_1, float scalar_tensor_default, bool gt>
{
[node_Constant_23] t = Constant <value: tensor = float[3,2] t {-0.471258,0.0853913,0.324141,-0.559284,0.257305,-0.361027}> ()
[node_MatMul_32] val_9 = MatMul (x, t)
[node_Add_33] view_1 = Add (val_9, "mlp.0.bias")
[node_Constant_28] t_1 = Constant <value: tensor = float[2,1] t_1 {0.501595,-0.663844}> ()
[node_MatMul_34] val_10 = MatMul (view_1, t_1)
[node_Add_35] view_3 = Add (val_10, "mlp.1.bias")
[node_ReduceSum_16] sum_1 = ReduceSum <noop_with_empty_axes: int = 0, keepdims: int = 0> (view_3)
[node_Constant_31] scalar_tensor_default = Constant <value: tensor = float scalar_tensor_default {0}> ()
[node_Greater_19] gt = Greater (sum_1, scalar_tensor_default)
[node_If_20] getitem = If (gt) <then_branch: graph = true_graph_0 () => (float[1] mul_true_graph_0)
<float scalar_tensor_default_2>
{
[node_Constant_1] scalar_tensor_default_2 = Constant <value: tensor = float scalar_tensor_default_2 {2}> ()
[node_Mul_2] mul_true_graph_0 = Mul (view_3, scalar_tensor_default_2)
}, else_branch: graph = false_graph_0 () => (float[1] neg_false_graph_0) {
[node_Neg_0] neg_false_graph_0 = Neg (view_3)
}>
}