onnx_diagnostic.torch_export_patches.patch_module
- class onnx_diagnostic.torch_export_patches.patch_module.RewriteControlFlow(prefix: str = 'branch_cond', skip_objects: Dict[str, object] | None = None, args_names: Set[str] | None = None, filter_node: Callable[[ast.Node], bool] | None = None, pre_rewriter: Callable[[ast.Node], ast.Node] | None = None, post_rewriter: Callable[[ast.Node], ast.Node] | None = None)[source]
- The class rewrites tests with torch.cond(). empty_tensor is a function returning an empty tensor, used when one branch returns something that the other branch does not. A hand-driven sketch follows the parameter list.
- Parameters:
- prefix – prefix used for nested tests 
- skip_objects – variable names to skip if included in that list, such as modules 
- args_names – defines the local variables 
- filter_node – a function deciding which nodes to rewrite; by default, every node is rewritten 
- pre_rewriter – a rewriter applied before the automated rewriting 
- post_rewriter – a rewriter applied after the automated rewriting 
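The class is normally driven by transform_method() below, but it can also be applied by hand. A minimal sketch, assuming RewriteControlFlow follows the standard ast.NodeTransformer protocol (visit() plus ast.fix_missing_locations()); only the constructor arguments are taken from the signature above:

    import ast
    import inspect
    import textwrap

    from onnx_diagnostic.torch_export_patches.patch_module import RewriteControlFlow

    def forward(x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return x - y, x + y

    # Parse the source, rewrite the test, and print the result.
    # Assumption: the rewriter behaves as an ast.NodeTransformer;
    # transform_method() wraps these steps with extra checks.
    tree = ast.parse(textwrap.dedent(inspect.getsource(forward)))
    rewriter = RewriteControlFlow(prefix="branch_cond", args_names={"x", "y"})
    new_tree = ast.fix_missing_locations(rewriter.visit(tree))
    print(ast.unparse(new_tree))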
 
 
- class onnx_diagnostic.torch_export_patches.patch_module.RewrittenMethod(tree, func)[source]
- Stores a method rewritten by onnx_diagnostic.torch_export_patches.patch_module.transform_method(); a short usage sketch follows the parameter list.
- Parameters:
- tree – ast tree 
- func – callable compiled from the tree 
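A short usage sketch, mirroring the examples under transform_method() below; the code attribute is the source rebuilt from tree:

    import torch
    from onnx_diagnostic.torch_export_patches.patch_module import transform_method

    def f(x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return torch.abs(x) + y, torch.abs(x) - y

    rewritten = transform_method(f)
    print(rewritten.code)  # rewritten source, rebuilt from the stored AST
    g = rewritten.func     # callable compiled from the tree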
 
 
- class onnx_diagnostic.torch_export_patches.patch_module.ShapeFinder[source]
- Finds the variable <x> in an expression such as x.shape[0].
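The idea can be pictured with a plain ast.NodeVisitor; this is an illustrative sketch, not the class's actual implementation:

    import ast

    class NaiveShapeFinder(ast.NodeVisitor):
        """Hypothetical sketch: collects <x> in expressions like x.shape[0]."""

        def __init__(self):
            self.found = set()

        def visit_Attribute(self, node):
            # Matches <name>.shape and records <name>.
            if node.attr == "shape" and isinstance(node.value, ast.Name):
                self.found.add(node.value.id)
            self.generic_visit(node)

    finder = NaiveShapeFinder()
    finder.visit(ast.parse("d = x.shape[0] + y.shape[1]"))
    print(finder.found)  # {'x', 'y'} (set order may vary)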
- class onnx_diagnostic.torch_export_patches.patch_module.UsedVarsFinder[source]
- Finds used and defined local variables within a section. 
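Again as an illustrative sketch (not the actual implementation): distinguishing ast.Load from ast.Store contexts is enough to separate used names from defined ones:

    import ast

    class NaiveUsedVarsFinder(ast.NodeVisitor):
        """Hypothetical sketch: Load contexts are uses, Store contexts are definitions."""

        def __init__(self):
            self.used = set()
            self.defined = set()

        def visit_Name(self, node):
            if isinstance(node.ctx, ast.Load):
                self.used.add(node.id)
            elif isinstance(node.ctx, ast.Store):
                self.defined.add(node.id)

    finder = NaiveUsedVarsFinder()
    finder.visit(ast.parse("w = x + y\nz = w - y"))
    print(finder.used)     # {'x', 'y', 'w'}
    print(finder.defined)  # {'w', 'z'}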
- onnx_diagnostic.torch_export_patches.patch_module.inplace_add_parent(tree: ast.Node)[source]
- Adds parents to an AST tree. 
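Python's ast module does not store parent links, so a pass like this has to add them explicitly. A minimal equivalent (the attribute name parent is an assumption; the library may use another convention):

    import ast

    def naive_add_parent(tree: ast.AST) -> None:
        # Attach a parent attribute to every child node
        # (attribute name is an assumption).
        for node in ast.walk(tree):
            for child in ast.iter_child_nodes(node):
                child.parent = node

    tree = ast.parse("z = x + y")
    naive_add_parent(tree)
    assign = tree.body[0]
    assert assign.parent is tree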
- onnx_diagnostic.torch_export_patches.patch_module.make_diff(code1: str, code2: str, output: str | None = None) → str[source]
- Creates a diff between two pieces of code; a usage sketch follows the return value.
- Parameters:
- code1 – first code 
- code2 – second code 
- output – if not empty, stores the output in this file 
 
- Returns:
- diff 
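A usage sketch based only on the signature above:

    from onnx_diagnostic.torch_export_patches.patch_module import make_diff

    code1 = "def f(x):\n    return x + 1\n"
    code2 = "def f(x):\n    return x + 2\n"

    # Prints a diff between the two snippets; per the output
    # parameter, output="diff.txt" stores the diff in a file.
    print(make_diff(code1, code2))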
 
- onnx_diagnostic.torch_export_patches.patch_module.normalize_assignment_in_test(tree: ast.Node)[source]
- Splits an AugAssign into a BinOp and an Assign to simplify whatever comes after. 
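The transformation can be pictured with a standalone ast.NodeTransformer (a sketch, not the library's code): x += y becomes x = x + y.

    import ast

    class NaiveAugAssignNormalizer(ast.NodeTransformer):
        """Hypothetical sketch: rewrites `x += y` as `x = x + y`."""

        def visit_AugAssign(self, node):
            # Assumes the target is a simple name (enough for a sketch).
            binop = ast.BinOp(
                left=ast.Name(id=node.target.id, ctx=ast.Load()),
                op=node.op,
                right=node.value,
            )
            return ast.copy_location(
                ast.Assign(targets=[node.target], value=binop), node
            )

    tree = ast.parse("x += y")
    tree = ast.fix_missing_locations(NaiveAugAssignNormalizer().visit(tree))
    print(ast.unparse(tree))  # x = x + y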
- onnx_diagnostic.torch_export_patches.patch_module.transform_method(func: Callable, prefix: str = 'branch_cond', verbose: int = 0, filter_node: Callable[[ast.Node], bool] | None = None, pre_rewriter: Callable[[ast.Node], ast.Node] | None = None, post_rewriter: Callable[[ast.Node], ast.Node] | None = None) → RewrittenMethod[source]
- Returns a new function based on func where every test (if statement) is replaced by a call to torch.cond(). Some known rewritings are part of the default patches (see Control flow rewriting).
- Both branches of a test must behave the same way: either both return or both assign. A test cannot return in one branch and assign in the other branch.
- Warning: room for improvement. When the branches assign values, the current implementation does not check which ones are really used after the test: the rewritten local functions return every assigned variable. This could be reduced. See method _filter_target.
- Parameters:
- func – method or function to rewrite 
- prefix – prefix used to create the functions for the branches 
- verbose – verbosity 
- filter_node – a function which tells which node to rewrite 
- pre_rewriter – a rewriter applied before the automated rewriting 
- post_rewriter – a rewriter applied after the automated rewriting 
 
- Returns:
- rewritten method 
- An example with return:

<<<

    import torch
    from onnx_diagnostic.torch_export_patches.patch_module import transform_method

    class Model(torch.nn.Module):
        def forward(self, x, y):
            if x.sum() > 0:
                return x + y, x - y
            else:
                return torch.abs(x) + y, torch.abs(x) - y

    x, y = torch.rand((3, 4)), torch.rand((3, 4))
    expected = Model()(x, y)
    rewritten = transform_method(Model.forward)
    print("-- code --")
    print(rewritten.code)
    print(" -- export --")
    Model.forward = rewritten.func
    DYN = torch.export.Dim.DYNAMIC
    ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
    ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
    print(ep)

>>>

    -- code --
    def forward(self, x, y):

        def branch_cond_then_1(x, y):
            return (x + y, x - y)

        def branch_cond_else_1(x, y):
            return (torch.abs(x) + y, torch.abs(x) - y)
        return torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x, y])

     -- export --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s17, s27]", y: "f32[s17, s27]"):
                #
                sum_1: "f32[]" = torch.ops.aten.sum.default(x)
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None

                # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, y)); gt = true_graph_0 = false_graph_0 = x = y = None
                getitem: "f32[s17, s27]" = cond[0]
                getitem_1: "f32[s17, s27]" = cond[1]; cond = None
                return (getitem, getitem_1)

            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s17, s27]", y: "f32[s17, s27]"):
                    # File: <eval_with_key>.0:7 in forward, code: child = l_args_3_0__1.add(l_args_3_1__1)
                    add: "f32[s17, s27]" = torch.ops.aten.add.Tensor(x, y)

                    # File: <eval_with_key>.0:8 in forward, code: child_1 = l_args_3_0__1.sub(l_args_3_1__1); l_args_3_0__1 = l_args_3_1__1 = None
                    sub: "f32[s17, s27]" = torch.ops.aten.sub.Tensor(x, y); x = y = None
                    return (add, sub)

            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[s17, s27]", y: "f32[s17, s27]"):
                    # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_0__1)
                    abs_1: "f32[s17, s27]" = torch.ops.aten.abs.default(x)

                    # File: <eval_with_key>.1:8 in forward, code: child = abs_1.add(l_args_3_1__1); abs_1 = None
                    add: "f32[s17, s27]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None

                    # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_0__1); l_args_3_0__1 = None
                    abs_2: "f32[s17, s27]" = torch.ops.aten.abs.default(x); x = None

                    # File: <eval_with_key>.1:10 in forward, code: child_1 = abs_2.sub(l_args_3_1__1); abs_2 = l_args_3_1__1 = None
                    sub: "f32[s17, s27]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None
                    return (add, sub)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        getitem: USER_OUTPUT
        getitem_1: USER_OUTPUT

    Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo]}

- An example with assignments:

<<<

    import torch
    from onnx_diagnostic.torch_export_patches.patch_module import transform_method

    class Model(torch.nn.Module):
        def forward(self, x, y):
            if x.sum() > 0:
                w = x + y
                z = x - y
            else:
                w = torch.abs(x) + y
                z = torch.abs(x) - y
            return w, z

    x, y = torch.rand((3, 4)), torch.rand((3, 4))
    expected = Model()(x, y)
    rewritten = transform_method(Model.forward)
    print("-- code --")
    print(rewritten.code)
    print(" -- export --")
    Model.forward = rewritten.func
    DYN = torch.export.Dim.DYNAMIC
    ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
    ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
    print(ep)

>>>

    -- code --
    def forward(self, x, y):

        def branch_cond_then_1(y, x):
            w = x + y
            z = x - y
            return (w.clone(), z.clone())

        def branch_cond_else_1(y, x):
            w = torch.abs(x) + y
            z = torch.abs(x) - y
            return (w.clone(), z.clone())
        w, z = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x])
        return (w, z)

     -- export --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s17, s27]", y: "f32[s17, s27]"):
                #
                sum_1: "f32[]" = torch.ops.aten.sum.default(x)
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None

                # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x)); gt = true_graph_0 = false_graph_0 = y = x = None
                getitem: "f32[s17, s27]" = cond[0]
                getitem_1: "f32[s17, s27]" = cond[1]; cond = None
                return (getitem, getitem_1)

            class true_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s17, s27]", x: "f32[s17, s27]"):
                    # File: <eval_with_key>.0:7 in forward, code: w = l_args_3_1__1.add(l_args_3_0__1)
                    add: "f32[s17, s27]" = torch.ops.aten.add.Tensor(x, y)

                    # File: <eval_with_key>.0:8 in forward, code: z = l_args_3_1__1.sub(l_args_3_0__1); l_args_3_1__1 = l_args_3_0__1 = None
                    sub: "f32[s17, s27]" = torch.ops.aten.sub.Tensor(x, y); x = y = None

                    # File: <eval_with_key>.0:9 in forward, code: child = w.clone(); w = None
                    clone: "f32[s17, s27]" = torch.ops.aten.clone.default(add); add = None

                    # File: <eval_with_key>.0:10 in forward, code: child_1 = z.clone(); z = None
                    clone_1: "f32[s17, s27]" = torch.ops.aten.clone.default(sub); sub = None
                    return (clone, clone_1)

            class false_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s17, s27]", x: "f32[s17, s27]"):
                    # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1)
                    abs_1: "f32[s17, s27]" = torch.ops.aten.abs.default(x)

                    # File: <eval_with_key>.1:8 in forward, code: w = abs_1.add(l_args_3_0__1); abs_1 = None
                    add: "f32[s17, s27]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None

                    # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1); l_args_3_1__1 = None
                    abs_2: "f32[s17, s27]" = torch.ops.aten.abs.default(x); x = None

                    # File: <eval_with_key>.1:10 in forward, code: z = abs_2.sub(l_args_3_0__1); abs_2 = l_args_3_0__1 = None
                    sub: "f32[s17, s27]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None

                    # File: <eval_with_key>.1:11 in forward, code: child = w.clone(); w = None
                    clone: "f32[s17, s27]" = torch.ops.aten.clone.default(add); add = None

                    # File: <eval_with_key>.1:12 in forward, code: child_1 = z.clone(); z = None
                    clone_1: "f32[s17, s27]" = torch.ops.aten.clone.default(sub); sub = None
                    return (clone, clone_1)

    Graph signature:
        # inputs
        x: USER_INPUT
        y: USER_INPUT

        # outputs
        getitem: USER_OUTPUT
        getitem_1: USER_OUTPUT

    Range constraints: {s17: VR[2, int_oo], s27: VR[2, int_oo]}