onnx_diagnostic.torch_export_patches.patch_module¶
- class onnx_diagnostic.torch_export_patches.patch_module.RewriteControlFlow(prefix: str = 'branch_cond', skip_objects: Dict[str, object] | None = None, args_names: Set[str] | None = None, filter_node: Callable[[ast.Node], bool] | None = None, pre_rewriter: Callable[[ast.Node], ast.Node] | None = None, post_rewriter: Callable[[ast.Node], ast.Node] | None = None)[source]¶
The class rewrites tests (if statements) with function torch.cond(). empty_tensor is a function returning an empty tensor, used when one branch returns something the other branch does not. A hand-written sketch of the transformation follows the parameter list.
- Parameters:
prefix – prefix used for nested tests
skip_objects – variable names to skip if they appear in this list, such as modules
args_names – defines the local variables
filter_node – a function used to decide which nodes to rewrite; by default, every node is rewritten
pre_rewriter – a rewriter applied before the automated rewriting
post_rewriter – a rewriter applied after the automated rewriting
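The transformation this class automates can be written by hand. The sketch below shows its overall shape on a simple test, mirroring the output of transform_method() further down this page; the hand-written helper names are only illustrative.

import torch


# Before: a data-dependent test that torch.export cannot trace through.
def forward_before(x, y):
    if x.sum() > 0:
        return x + y
    return x - y


# After: the test becomes a call to torch.cond with one local function per
# branch (RewriteControlFlow names them with the given prefix, e.g. branch_cond_then_1).
def forward_after(x, y):
    def branch_cond_then_1(x, y):
        return x + y

    def branch_cond_else_1(x, y):
        return x - y

    return torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x, y])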
- class onnx_diagnostic.torch_export_patches.patch_module.RewrittenMethod(tree, func)[source]¶
Stores a rewritten method built by onnx_diagnostic.torch_export_patches.patch_module.transform_method().
- Parameters:
tree – ast tree
func – callable compiled from the tree
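transform_method() returns a RewrittenMethod; the examples further down use its code attribute (the rewritten source) and its func attribute (the compiled callable). A minimal usage sketch:

import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return torch.abs(x) + y, torch.abs(x) - y


rewritten = transform_method(Model.forward)
print(rewritten.code)           # rewritten source code
Model.forward = rewritten.func  # replace the method with the compiled version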
- class onnx_diagnostic.torch_export_patches.patch_module.ShapeFinder[source]¶
Finds <x> in the expression x.shape[0].
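ShapeFinder's exact visitor interface is not documented here; the sketch below only illustrates the pattern it looks for, using the standard ast module (the class name and the found_names attribute are hypothetical, not the library's API).

import ast


class ShapeNameFinder(ast.NodeVisitor):
    # Hypothetical illustration: collect <x> from expressions like x.shape[0].
    def __init__(self):
        self.found_names = set()

    def visit_Subscript(self, node):
        value = node.value
        if (
            isinstance(value, ast.Attribute)
            and value.attr == "shape"
            and isinstance(value.value, ast.Name)
        ):
            self.found_names.add(value.value.id)
        self.generic_visit(node)


finder = ShapeNameFinder()
finder.visit(ast.parse("batch = x.shape[0]"))
print(finder.found_names)  # {'x'}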
- class onnx_diagnostic.torch_export_patches.patch_module.UsedVarsFinder[source]¶
Finds used and defined local variables within a section.
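The distinction between used and defined variables can be read from the ast contexts (Load versus Store). The sketch below illustrates the idea with the standard ast module; it is not UsedVarsFinder's actual interface.

import ast

code = "z = x + y\nw = z * 2"
used, defined = set(), set()
for node in ast.walk(ast.parse(code)):
    if isinstance(node, ast.Name):
        if isinstance(node.ctx, ast.Load):
            used.add(node.id)      # variable read in this section
        elif isinstance(node.ctx, ast.Store):
            defined.add(node.id)   # variable assigned in this section
print(used)     # {'x', 'y', 'z'}
print(defined)  # {'z', 'w'}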
- onnx_diagnostic.torch_export_patches.patch_module.inplace_add_parent(tree: ast.Node)[source]¶
Adds parents to an AST tree.
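Attaching parent links is a common AST utility. A minimal sketch of the idea follows; the attribute name parent is an assumption, not necessarily the one inplace_add_parent sets.

import ast


def add_parents(tree: ast.AST) -> None:
    # Attach a reference to the enclosing node on every child, in place.
    for node in ast.walk(tree):
        for child in ast.iter_child_nodes(node):
            child.parent = node


tree = ast.parse("if x.sum() > 0:\n    y = x + 1")
add_parents(tree)
assign = tree.body[0].body[0]        # the assignment inside the if
print(type(assign.parent).__name__)  # If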
- onnx_diagnostic.torch_export_patches.patch_module.make_diff(code1: str, code2: str, output: str | None = None) → str[source]¶
Creates a diff between two codes.
- Parameters:
code1 – first code
code2 – second code
output – if not empty, stores the output in this file
- Returns:
diff
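A quick usage sketch of make_diff; the exact formatting of the returned diff is not specified here, a unified-diff-like string is assumed.

from onnx_diagnostic.torch_export_patches.patch_module import make_diff

code1 = "def f(x):\n    if x > 0:\n        return x\n    return -x\n"
code2 = "def f(x):\n    return abs(x)\n"
print(make_diff(code1, code2))              # diff between the two snippets
make_diff(code1, code2, output="diff.txt")  # also stores the diff in a file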
- onnx_diagnostic.torch_export_patches.patch_module.normalize_assignment_in_test(tree: ast.Node)[source]¶
Splits an AugAssign into a BinOp and an Assign to simplify whatever comes after.
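In source terms, the normalization turns x += y into x = x + y before the control-flow rewriting. The standalone sketch below shows that step with the standard ast module; it is an illustration, not the function's actual code, and it assumes the augmented target is a plain name.

import ast


class SplitAugAssign(ast.NodeTransformer):
    # Illustration: rewrite 'x += y' into 'x = x + y'.
    def visit_AugAssign(self, node):
        new_value = ast.BinOp(
            left=ast.Name(id=node.target.id, ctx=ast.Load()),
            op=node.op,
            right=node.value,
        )
        return ast.copy_location(
            ast.Assign(targets=[node.target], value=new_value), node
        )


tree = ast.parse("w += x - y")
tree = ast.fix_missing_locations(SplitAugAssign().visit(tree))
print(ast.unparse(tree))  # w = w + (x - y)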
- onnx_diagnostic.torch_export_patches.patch_module.transform_method(func: Callable, prefix: str = 'branch_cond', verbose: int = 0, filter_node: Callable[[ast.Node], bool] | None = None, pre_rewriter: Callable[[ast.Node], ast.Node] | None = None, post_rewriter: Callable[[ast.Node], ast.Node] | None = None) → RewrittenMethod[source]¶
Returns a new function based on func where every test (if statement) is replaced by a call to torch.cond(). Some known rewritings are part of the default patches (see Control flow rewriting). Both branches of a test must produce the same things, whether they return values or assign them; a test cannot return in one branch and assign in the other.
Warning
room for improvement
When a test assigns values, the rewritten local functions return every assigned variable, without checking which ones are really used after the test. This could be reduced. See method _filter_target.
- Parameters:
func – method or function to rewrite
prefix – prefix used to create the functions for the branches
verbose – verbosity
filter_node – a function which tells which node to rewrite
pre_rewriter – a rewriter applied before the automated rewriting
post_rewriter – a rewriter applied after the automated rewriting
- Returns:
rewritten method
An example with return:
<<<
import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return torch.abs(x) + y, torch.abs(x) - y


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)
>>>
-- code --
def forward(self, x, y):

    def branch_cond_then_1(x, y):
        return (x + y, x - y)

    def branch_cond_else_1(x, y):
        return (torch.abs(x) + y, torch.abs(x) - y)
    return torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x, y])

 -- export --
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
            #
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None

            # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, y)); gt = true_graph_0 = false_graph_0 = x = y = None
            getitem: "f32[s35, s16]" = cond[0]
            getitem_1: "f32[s35, s16]" = cond[1]; cond = None
            return (getitem, getitem_1)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
                # File: <eval_with_key>.0:7 in forward, code: child = l_args_3_0__1.add(l_args_3_1__1)
                add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y)

                # File: <eval_with_key>.0:8 in forward, code: child_1 = l_args_3_0__1.sub(l_args_3_1__1); l_args_3_0__1 = l_args_3_1__1 = None
                sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y); x = y = None
                return (add, sub)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
                # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_0__1)
                abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x)

                # File: <eval_with_key>.1:8 in forward, code: child = abs_1.add(l_args_3_1__1); abs_1 = None
                add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None

                # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_0__1); l_args_3_0__1 = None
                abs_2: "f32[s35, s16]" = torch.ops.aten.abs.default(x); x = None

                # File: <eval_with_key>.1:10 in forward, code: child_1 = abs_2.sub(l_args_3_1__1); abs_2 = l_args_3_1__1 = None
                sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None
                return (add, sub)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    getitem: USER_OUTPUT
    getitem_1: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
An example with assignments:
<<<
import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            w = x + y
            z = x - y
        else:
            w = torch.abs(x) + y
            z = torch.abs(x) - y
        return w, z


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func
DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)
>>>
-- code --
def forward(self, x, y):

    def branch_cond_then_1(y, x):
        w = x + y
        z = x - y
        return (w.clone(), z.clone())

    def branch_cond_else_1(y, x):
        w = torch.abs(x) + y
        z = torch.abs(x) - y
        return (w.clone(), z.clone())
    w, z = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x])
    return (w, z)

 -- export --
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
            #
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None

            # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x)); gt = true_graph_0 = false_graph_0 = y = x = None
            getitem: "f32[s35, s16]" = cond[0]
            getitem_1: "f32[s35, s16]" = cond[1]; cond = None
            return (getitem, getitem_1)

        class true_graph_0(torch.nn.Module):
            def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                # File: <eval_with_key>.0:7 in forward, code: w = l_args_3_1__1.add(l_args_3_0__1)
                add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y)

                # File: <eval_with_key>.0:8 in forward, code: z = l_args_3_1__1.sub(l_args_3_0__1); l_args_3_1__1 = l_args_3_0__1 = None
                sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y); x = y = None

                # File: <eval_with_key>.0:9 in forward, code: child = w.clone(); w = None
                clone: "f32[s35, s16]" = torch.ops.aten.clone.default(add); add = None

                # File: <eval_with_key>.0:10 in forward, code: child_1 = z.clone(); z = None
                clone_1: "f32[s35, s16]" = torch.ops.aten.clone.default(sub); sub = None
                return (clone, clone_1)

        class false_graph_0(torch.nn.Module):
            def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1)
                abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x)

                # File: <eval_with_key>.1:8 in forward, code: w = abs_1.add(l_args_3_0__1); abs_1 = None
                add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None

                # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1); l_args_3_1__1 = None
                abs_2: "f32[s35, s16]" = torch.ops.aten.abs.default(x); x = None

                # File: <eval_with_key>.1:10 in forward, code: z = abs_2.sub(l_args_3_0__1); abs_2 = l_args_3_0__1 = None
                sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None

                # File: <eval_with_key>.1:11 in forward, code: child = w.clone(); w = None
                clone: "f32[s35, s16]" = torch.ops.aten.clone.default(add); add = None

                # File: <eval_with_key>.1:12 in forward, code: child_1 = z.clone(); z = None
                clone_1: "f32[s35, s16]" = torch.ops.aten.clone.default(sub); sub = None
                return (clone, clone_1)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    getitem: USER_OUTPUT
    getitem_1: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}