to_onnx and a model with a test
Data-dependent control flow cannot be exported without a change.
The code of the model can be rewritten or patched
to use the function torch.cond().
A model with a test
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.helpers import pretty_onnx
from experimental_experiment.torch_interpreter import to_onnx
We define a model with data-dependent control flow, which causes a graph break.
class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlow()
Let’s check it runs.
x = torch.randn(1, 3)
model(x)
tensor([[5.4273]], grad_fn=<MulBackward0>)
As expected, it does not export.
try:
    torch.export.export(model, (x,))
except Exception as e:
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by the except clause.
    raise AssertionError("This export should fail unless pytorch now supports this model.")
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 42, in forward
out = self.mlp(x)
File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 27, in forward
if x.sum():
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
The exporter to_onnx fails with the same error because it relies on torch.export.export succeeding.
Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 42, in forward
out = self.mlp(x)
File "/home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_recipes_c_cond.py", line 27, in forward
if x.sum():
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Suggested Patch
Let’s avoid the graph break by replacing the forward.
def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
<class '__main__.ModelWithControlFlow'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
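The patch above assigns a plain function to the instance attribute forward, which is why new_forward takes no self argument. The sketch below illustrates that mechanism with a plain Python class standing in for a module, and wraps the patch in a context manager so the original method comes back afterwards. The names `Block`, `replacement` and `patched_forward` are hypothetical and not part of the original example.

```python
import contextlib


class Block:
    """Stand-in for a module: calling it dispatches to self.forward."""

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        return self.forward(x)


def replacement(x):
    # Stored on the instance, so Python does not bind it: no `self`.
    return x + 100


@contextlib.contextmanager
def patched_forward(obj, new_forward):
    """Temporarily override obj.forward on this instance only."""
    obj.forward = new_forward
    try:
        yield obj
    finally:
        # Deleting the instance attribute makes the class method visible again.
        del obj.forward


block, other = Block(), Block()
with patched_forward(block, replacement):
    inside = block(1)        # uses the replacement
    untouched = other(1)     # other instances are not affected
after = block(1)             # original method is back
print(inside, untouched, after)
```

Patching the instance rather than the class keeps the change local to the model being exported, which matters when the same class is used elsewhere in the program.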
Let’s see what the fx graph looks like.
print(torch.export.export(model, (x,)).graph)
graph():
%p_mlp_0_weight : [num_users=1] = placeholder[target=p_mlp_0_weight]
%p_mlp_0_bias : [num_users=1] = placeholder[target=p_mlp_0_bias]
%p_mlp_1_weight : [num_users=1] = placeholder[target=p_mlp_1_weight]
%p_mlp_1_bias : [num_users=1] = placeholder[target=p_mlp_1_bias]
%x : [num_users=1] = placeholder[target=x]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_mlp_0_weight, %p_mlp_0_bias), kwargs = {})
%linear_1 : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%linear, %p_mlp_1_weight, %p_mlp_1_bias), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%linear_1,), kwargs = {})
%gt : [num_users=1] = call_function[target=torch.ops.aten.gt.Scalar](args = (%sum_1, 0), kwargs = {})
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%gt, %true_graph_0, %false_graph_0, [%linear_1]), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%cond, 0), kwargs = {})
return (getitem,)
Let’s export again.
onx = to_onnx(model, (x,))
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='x' type=dtype('float32') shape=[1, 3]
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='mlp.0.weight' type=dtype('float32') shape=(2, 3)
init: name='mlp.0.bias' type=dtype('float32') shape=(2,) -- array([0.39745384, 0.5367278 ], dtype=float32)
init: name='mlp.1.weight' type=dtype('float32') shape=(1, 2) -- array([0.6619442 , 0.53932554], dtype=float32)
init: name='mlp.1.bias' type=dtype('float32') shape=(1,) -- array([0.44847998], dtype=float32)
Gemm(x, mlp.0.weight, mlp.0.bias, transB=1) -> linear
Gemm(linear, mlp.1.weight, mlp.1.bias, transB=1) -> linear_1
ReduceSum(linear_1, keepdims=0) -> sum_1
Greater(sum_1, init1_s_) -> gt
If(gt, else_branch=G1, then_branch=G2) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 1]
----- subgraph ---- If - aten_cond - att.else_branch=G1 -- level=1 -- -> cond#0
false_graph_0[local_functions](linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- subgraph ---- If - aten_cond - att.then_branch=G2 -- level=1 -- -> cond#0
true_graph_0[local_functions](linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- function name=true_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
input: 'linear_1'
Constant(value=2.0) -> init1_s_
Constant(value=[1]) -> init7_s1_1
Reshape(init1_s_, init7_s1_1) -> _onx_reshape0
Mul(linear_1, _onx_reshape0) -> output_0
output: name='output_0' type=? shape=?
----- function name=false_graph_0 domain=local_functions
----- doc_string: function_options=FunctionOptions(export_as_function=Tru...
opset: domain='' version=18
opset: domain='local_functions' version=1
input: 'linear_1'
Neg(linear_1) -> output_0
output: name='output_0' type=? shape=?
We can also inline the local functions.
onx = to_onnx(model, (x,), inline=True)
print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_functions' version=1
doc_string: large_model=False, inline=True, external_threshold=1024...
input: name='x' type=dtype('float32') shape=[1, 3]
init: name='init1_s_' type=dtype('float32') shape=() -- array([0.], dtype=float32)
init: name='mlp.0.weight' type=dtype('float32') shape=(2, 3)
init: name='mlp.0.bias' type=dtype('float32') shape=(2,) -- array([0.39745384, 0.5367278 ], dtype=float32)
init: name='mlp.1.weight' type=dtype('float32') shape=(1, 2) -- array([0.6619442 , 0.53932554], dtype=float32)
init: name='mlp.1.bias' type=dtype('float32') shape=(1,) -- array([0.44847998], dtype=float32)
Gemm(x, mlp.0.weight, mlp.0.bias, transB=1) -> linear
Gemm(linear, mlp.1.weight, mlp.1.bias, transB=1) -> linear_1
ReduceSum(linear_1, keepdims=0) -> sum_1
Greater(sum_1, init1_s_) -> gt
If(gt, else_branch=G1, then_branch=G2) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 1]
----- subgraph ---- If - aten_cond - att.else_branch=G1 -- level=1 -- -> cond#0
Neg(linear_1) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
----- subgraph ---- If - aten_cond - att.then_branch=G2 -- level=1 -- -> cond#0
Constant(value=[1]) -> init7_s1_122
Constant(value=2.0) -> init1_s_22
Reshape(init1_s_22, init7_s1_122) -> _onx_reshape032
Mul(linear_1, _onx_reshape032) -> cond#0
output: name='cond#0' type='NOTENSOR' shape=None
And visually.
[Figure: rendering of the exported ONNX graph]