onnx_diagnostic.torch_export_patches.patch_module

class onnx_diagnostic.torch_export_patches.patch_module.RewriteControlFlow(prefix: str = 'branch_cond', skip_objects: Dict[str, object] | None = None, args_names: Set[str] | None = None)[source]

The class rewrites tests (if statements) with torch.cond(). empty_tensor is a function returning an empty tensor; it is used when one branch returns a value the other branch does not.
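A rough, hypothetical illustration of why such a helper is needed (the helper name and padding strategy here are assumptions, not the exact implementation): torch.cond() expects both branches to produce the same number of outputs, so a branch lacking a value has to be padded.

    import torch


    def empty_tensor():
        # hypothetical stand-in for the helper described above:
        # returns an empty tensor for an output only one branch produces
        return torch.zeros(0)


    def branch_then(x, y):
        return x + y, x - y


    def branch_else(x, y):
        # this branch computes only one value; padding with an empty
        # tensor keeps both branch signatures identical
        return x + y, empty_tensor()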

generic_visit(node)[source]

Called if no explicit visitor function exists for a node.

class onnx_diagnostic.torch_export_patches.patch_module.RewrittenMethod(tree, func)[source]

Stores a method rewritten by onnx_diagnostic.torch_export_patches.patch_module.transform_method().

Parameters:
  • tree – AST tree

  • func – callable compiled from the tree

property code: str

Returns the source code of the rewritten method.

property dump: str

Returns the AST dumped as a string.
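A minimal usage sketch (the function below is illustrative, not from the library):

    import torch
    from onnx_diagnostic.torch_export_patches.patch_module import transform_method


    def f(x):
        if x.sum() > 0:
            return x + 1
        else:
            return x - 1


    rewritten = transform_method(f)
    print(rewritten.code)   # the rewritten source
    print(rewritten.dump)   # the AST dumped as a string
    new_f = rewritten.func  # the compiled callable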

onnx_diagnostic.torch_export_patches.patch_module.inplace_add_parent(tree: ast.Node)[source]

Adds parent references to the nodes of an AST tree, in place.
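The general technique looks like the following sketch (the parent attribute name is an assumption about this particular implementation):

    import ast


    def add_parents(tree: ast.AST) -> None:
        # walk the tree and attach to every child a reference
        # to the node that contains it
        for node in ast.walk(tree):
            for child in ast.iter_child_nodes(node):
                child.parent = node  # type: ignore[attr-defined]


    tree = ast.parse("w = x + y")
    add_parents(tree)
    assign = tree.body[0]
    assert assign.parent is tree  # the Module node owns the Assign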

onnx_diagnostic.torch_export_patches.patch_module.normalize_assignment_in_test(tree: ast.Node)[source][source]

Splits an AugAssign into a BinOp and an Assign to simplify subsequent processing.
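A self-contained sketch of the idea, turning x += 1 into x = x + 1 (it only handles simple Name targets):

    import ast


    class AugAssignToAssign(ast.NodeTransformer):
        def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
            # rebuild the target as a Load so it can appear
            # on the right-hand side of the new assignment
            load_target = ast.Name(id=node.target.id, ctx=ast.Load())
            new_node = ast.Assign(
                targets=[node.target],
                value=ast.BinOp(left=load_target, op=node.op, right=node.value),
            )
            return ast.copy_location(new_node, node)


    tree = ast.parse("x += 1")
    tree = ast.fix_missing_locations(AugAssignToAssign().visit(tree))
    print(ast.unparse(tree))  # x = x + 1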

onnx_diagnostic.torch_export_patches.patch_module.transform_method(func: Callable, prefix: str = 'branch_cond', verbose: int = 0) → RewrittenMethod[source]

Returns a new function based on func where every test (if) is replaced by a call to torch.cond().

Both branches of a test must do the same kind of thing: either both return values or both assign variables. A test cannot return in one branch and assign in the other, as illustrated below.
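For instance (illustrative functions, not from the library):

    import torch


    def ok_return(x, y):
        # supported: both branches return
        if x.sum() > 0:
            return x + y
        else:
            return x - y


    def ok_assign(x, y):
        # supported: both branches assign the same variable
        if x.sum() > 0:
            z = x + y
        else:
            z = x - y
        return z


    def not_supported(x, y):
        # rejected: one branch returns, the other assigns
        if x.sum() > 0:
            return x + y
        else:
            z = x - y
        return z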

Warning

room for improvement

When the test assigns values to variables, the current implementation does not check which of them are actually used after the test. The rewritten local functions return every assigned variable; this could be reduced. See method _filter_target and the sketch below.
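For example, in the following sketch, w is never used after the test, yet both rewritten branch functions would still return it:

    import torch


    class Model(torch.nn.Module):
        def forward(self, x, y):
            if x.sum() > 0:
                w = x + y  # w is dead after the test
                z = x - y
            else:
                w = x * y
                z = x / y
            # the current rewrite makes both branches return (w, z)
            # even though only z is needed here
            return z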

Parameters:
  • func – method or function to rewrite

  • prefix – prefix used to create the functions for the branches

  • verbose – verbosity

Returns:

rewritten method

An example with return statements:

<<<

import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            return x + y, x - y
        else:
            return torch.abs(x) + y, torch.abs(x) - y


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func

DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)

>>>

    -- code --
    def forward(self, x, y):
    
        def branch_cond_then_1(y, x):
            return (x + y, x - y)
    
        def branch_cond_else_1(y, x):
            return (torch.abs(x) + y, torch.abs(x) - y)
        return torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x])
     -- export --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
                 # 
                sum_1: "f32[]" = torch.ops.aten.sum.default(x)
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
                
                 # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x));  gt = true_graph_0 = false_graph_0 = y = x = None
                getitem: "f32[s35, s16]" = cond[0]
                getitem_1: "f32[s35, s16]" = cond[1];  cond = None
                return (getitem, getitem_1)
                
            class true_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                     # File: <eval_with_key>.0:7 in forward, code: child = l_args_3_1__1.add(l_args_3_0__1)
                    add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y)
                    
                     # File: <eval_with_key>.0:8 in forward, code: child_1 = l_args_3_1__1.sub(l_args_3_0__1);  l_args_3_1__1 = l_args_3_0__1 = None
                    sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                    return (add, sub)
                    
            class false_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                     # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1)
                    abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x)
                    
                     # File: <eval_with_key>.1:8 in forward, code: child = abs_1.add(l_args_3_0__1);  abs_1 = None
                    add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y);  abs_1 = None
                    
                     # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1);  l_args_3_1__1 = None
                    abs_2: "f32[s35, s16]" = torch.ops.aten.abs.default(x);  x = None
                    
                     # File: <eval_with_key>.1:10 in forward, code: child_1 = abs_2.sub(l_args_3_0__1);  abs_2 = l_args_3_0__1 = None
                    sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y);  abs_2 = y = None
                    return (add, sub)
                    
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        getitem_1: USER_OUTPUT
        
    Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}

An example with assignments:

<<<

import torch
from onnx_diagnostic.torch_export_patches.patch_module import transform_method


class Model(torch.nn.Module):
    def forward(self, x, y):
        if x.sum() > 0:
            w = x + y
            z = x - y
        else:
            w = torch.abs(x) + y
            z = torch.abs(x) - y
        return w, z


x, y = torch.rand((3, 4)), torch.rand((3, 4))
expected = Model()(x, y)

rewritten = transform_method(Model.forward)
print("-- code --")
print(rewritten.code)

print(" -- export --")
Model.forward = rewritten.func

DYN = torch.export.Dim.DYNAMIC
ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds)
print(ep)

>>>

    -- code --
    def forward(self, x, y):
    
        def branch_cond_then_1(y, x):
            w = x + y
            z = x - y
            return (w, z)
    
        def branch_cond_else_1(y, x):
            w = torch.abs(x) + y
            z = torch.abs(x) - y
            return (w, z)
        w, z = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [y, x])
        return (w, z)
     -- export --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[s35, s16]", y: "f32[s35, s16]"):
                 # 
                sum_1: "f32[]" = torch.ops.aten.sum.default(x)
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
                
                 # File: <eval_with_key>.3:10 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_, l_args_3_1_));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = l_args_3_1_ = None
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (y, x));  gt = true_graph_0 = false_graph_0 = y = x = None
                getitem: "f32[s35, s16]" = cond[0]
                getitem_1: "f32[s35, s16]" = cond[1];  cond = None
                return (getitem, getitem_1)
                
            class true_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                     # File: <eval_with_key>.0:7 in forward, code: w = l_args_3_1__1.add(l_args_3_0__1)
                    add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y)
                    
                     # File: <eval_with_key>.0:8 in forward, code: z = l_args_3_1__1.sub(l_args_3_0__1);  l_args_3_1__1 = l_args_3_0__1 = None
                    sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                    return (add, sub)
                    
            class false_graph_0(torch.nn.Module):
                def forward(self, y: "f32[s35, s16]", x: "f32[s35, s16]"):
                     # File: <eval_with_key>.1:7 in forward, code: abs_1 = torch.abs(l_args_3_1__1)
                    abs_1: "f32[s35, s16]" = torch.ops.aten.abs.default(x)
                    
                     # File: <eval_with_key>.1:8 in forward, code: w = abs_1.add(l_args_3_0__1);  abs_1 = None
                    add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(abs_1, y);  abs_1 = None
                    
                     # File: <eval_with_key>.1:9 in forward, code: abs_2 = torch.abs(l_args_3_1__1);  l_args_3_1__1 = None
                    abs_2: "f32[s35, s16]" = torch.ops.aten.abs.default(x);  x = None
                    
                     # File: <eval_with_key>.1:10 in forward, code: z = abs_2.sub(l_args_3_0__1);  abs_2 = l_args_3_0__1 = None
                    sub: "f32[s35, s16]" = torch.ops.aten.sub.Tensor(abs_2, y);  abs_2 = y = None
                    return (add, sub)
                    
    Graph signature: 
        # inputs
        x: USER_INPUT
        y: USER_INPUT
        
        # outputs
        getitem: USER_OUTPUT
        getitem_1: USER_OUTPUT
        
    Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}