onnx_diagnostic.torch_export_patches.patch_module¶
- class onnx_diagnostic.torch_export_patches.patch_module.RewriteControlFlow(prefix: str = 'branch_cond', skip_objects: Dict[str, object] | None = None, args_names: Set[str] | None = None)[source][source]¶
The class rewrites tests (if statements) into calls to the function
torch.cond().
empty_tensor is a function returning an empty tensor; it is used when a branch returns something the other branch does not.
- class onnx_diagnostic.torch_export_patches.patch_module.RewrittenMethod(tree, func)[source][source]¶
Stores a rewritten method produced by
onnx_diagnostic.torch_export_patches.patch_module.transform_method().
- Parameters:
tree – ast tree
func – callable compiled from the tree
- onnx_diagnostic.torch_export_patches.patch_module.inplace_add_parent(tree: ast.Node)[source][source]¶
Adds parents to an AST tree.
- onnx_diagnostic.torch_export_patches.patch_module.normalize_assignment_in_test(tree: ast.Node)[source][source]¶
Splits AugAssign nodes into a BinOp and an Assign (for example, x += y becomes x = x + y) to simplify whatever processing comes after.
- onnx_diagnostic.torch_export_patches.patch_module.transform_method(func: Callable, prefix: str = 'branch_cond', verbose: int = 0) RewrittenMethod [source][source]¶
Returns a new function based on func where every test (if) is replaced by a call to
torch.cond().
Both branches of a test must return the same things if the test returns something, or assign the same things if it assigns something. A test cannot return in one branch and assign in the other branch.
Warning
room for improvement
When a test assigns values to variables, the current implementation does not check which of them are actually used after the test: the rewritten local functions return every assigned variable. This could be reduced. See method
_filter_target.
- Parameters:
func – method or function to rewrite
prefix – prefix used to create the functions for the branches
verbose – verbosity
- Returns:
rewritten method
An example with return:
<<<
import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return torch.abs(x) + y, torch.abs(x) - y


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func

DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)
>>>
-- code -- def forward(self, x, y): def branch_cond_then_1(y, x): return (x + y, x - y) def branch_cond_else_1(y, x): return (torch.abs(x) + y, torch.abs(x) - y) return torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x]) -- export -- ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"): # sum_1: "f32[]" = torch.ops.aten.sum.default(x) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x)); gt = true_graph_0 = false_graph_0 = y = x = None getitem: "f32[s35, s16]" = cond[0] getitem_1: "f32[s35, s16]" = cond[1]; cond = None return (getitem, getitem_1) class true_graph_0(torch.nn.Module): def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"): # File: <eval_with_key>.0:7 in forward, code: child = l_args_3_1__1.add(l_args_3_0__1) add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y) # File: <eval_with_key>.0:8 in forward, code: child_1 = l_args_3_1__1.sub(l_args_3_0__1); l_args_3_1__1 = l_args_3_0__1 = None sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y); x = y = None return (add, sub) class false_graph_0(torch.nn.Module): def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"): # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1) abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x) # File: <eval_with_key>.1:8 in forward, code: child = abs_1.add(l_args_3_0__1); abs_1 = None add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1); l_args_3_1__1 = None abs_2: "f32[s35, s16]" = 
torch.ops.aten.abs.default(x); x = None # File: <eval_with_key>.1:10 in forward, code: child_1 = abs_2.sub(l_args_3_0__1); abs_2 = l_args_3_0__1 = None sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None return (add, sub) Graph signature: # inputs x: USER_INPUT y: USER_INPUT # outputs getitem: USER_OUTPUT getitem_1: USER_OUTPUT Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
An example with assignments:
<<<
import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            w = x + y
            z = x - y
        else:
            w = torch.abs(x) + y
            z = torch.abs(x) - y
        return w, z


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func

DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)
>>>
-- code -- def forward(self, x, y): def branch_cond_then_1(y, x): w = x + y z = x - y return (w, z) def branch_cond_else_1(y, x): w = torch.abs(x) + y z = torch.abs(x) - y return (w, z) w, z = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x]) return (w, z) -- export -- ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"): # sum_1: "f32[]" = torch.ops.aten.sum.default(x) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x)); gt = true_graph_0 = false_graph_0 = y = x = None getitem: "f32[s35, s16]" = cond[0] getitem_1: "f32[s35, s16]" = cond[1]; cond = None return (getitem, getitem_1) class true_graph_0(torch.nn.Module): def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"): # File: <eval_with_key>.0:7 in forward, code: w = l_args_3_1__1.add(l_args_3_0__1) add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y) # File: <eval_with_key>.0:8 in forward, code: z = l_args_3_1__1.sub(l_args_3_0__1); l_args_3_1__1 = l_args_3_0__1 = None sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y); x = y = None return (add, sub) class false_graph_0(torch.nn.Module): def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"): # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1) abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x) # File: <eval_with_key>.1:8 in forward, code: w = abs_1.add(l_args_3_0__1); abs_1 = None add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y); abs_1 = None # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1); l_args_3_1__1 = None abs_2: "f32[s35, 
s16]" = torch.ops.aten.abs.default(x); x = None # File: <eval_with_key>.1:10 in forward, code: z = abs_2.sub(l_args_3_0__1); abs_2 = l_args_3_0__1 = None sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y); abs_2 = y = None return (add, sub) Graph signature: # inputs x: USER_INPUT y: USER_INPUT # outputs getitem: USER_OUTPUT getitem_1: USER_OUTPUT Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}