import ast
import inspect
import types
import textwrap
from typing import Callable, Dict
import torch
# Capitalised AST node classes exported by :mod:`ast` (``ast.AST`` included),
# in ``dir(ast)`` order.  The ``issubclass`` test keeps out capitalised
# classes that are not nodes (``NodeVisitor``, ``NodeTransformer``, the
# ``IntEnum`` the module imports, ...), so ``isinstance(x, NODE_TYPES)``
# is only true for real AST nodes.
NODE_TYPES = tuple(
    candidate
    for candidate in (getattr(ast, name) for name in dir(ast) if "A" <= name[0] <= "Z")
    if isinstance(candidate, type) and issubclass(candidate, ast.AST)
)
def _settl(node, lineno, level=0):
if isinstance(node, (str, int, float)):
return node
if isinstance(node, list):
for n in node:
_settl(n, lineno, level=level + 1)
return node
if isinstance(node, NODE_TYPES):
if not hasattr(node, "lineno") or node.lineno is None:
node.lineno = lineno
for k in dir(node):
if k in {"s", "n"}:
continue
if k[0] == "_":
continue
v = getattr(node, k)
_settl(v, max(lineno, node.lineno), level=level + 1)
return node
[docs]
class RewriteControlFlow(ast.NodeTransformer):
    """
    Rewrites simple ``if``/``else`` statements found inside a function into
    a call to a wrapper callable.

    Two shapes are rewritten, any other ``if`` is left untouched:

    * both branches are one assignment to the same single name::

          if test:
              y = <then_expr>
          else:
              y = <else_expr>

      becomes two local functions followed by
      ``y = wrapper(test, then_fn, else_fn, (then_args...), (else_args...))``
      where every branch function receives the names its own expression
      reads;

    * both branches are one ``return <expr>``, which becomes two local
      functions followed by ``return wrapper(test, then_fn, else_fn, [args...])``
      where both branch functions share the same parameters: the union of
      the names read by the two expressions, ``torch`` excluded.

    Branch functions are named ``<wrapper_name>_then_<i>`` and
    ``<wrapper_name>_else_<i>``, *i* being a counter incremented for every
    rewritten ``if``.

    NOTE(review): the two shapes call the wrapper with a different number of
    arguments and only the second one filters ``torch`` out; confirm this
    against the wrapper actually bound to ``wrapper_name``.

    :param wrapper_name: name of the wrapper called by the rewritten code
    """

    def __init__(self, wrapper_name):
        self.wrapper_name = wrapper_name
        self.counter = 0
        # Argument names of the function being visited, None outside of any
        # function: ``if`` statements are only rewritten inside a function.
        self.current_func_args = None

    def visit_FunctionDef(self, node):
        """Visits the body of a function, remembering its argument names."""
        old_args = self.current_func_args
        self.current_func_args = [arg.arg for arg in node.args.args]
        try:
            new_body = []
            for stmt in node.body:
                result = self.visit(stmt)
                if result is None:
                    continue
                # visit_If may return several statements: they must be
                # spliced into the body, a nested list is not a valid AST.
                if isinstance(result, list):
                    new_body.extend(result)
                else:
                    new_body.append(result)
            node.body = new_body
        finally:
            self.current_func_args = old_args
        return node

    @staticmethod
    def _loaded_names(expr):
        """Returns the sorted names read (``Load`` context) by *expr*."""
        return sorted(
            {
                n.id
                for n in ast.walk(expr)
                if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load)
            }
        )

    @staticmethod
    def _branch_function(name, params, expr, location):
        """Builds ``def name(*params): return expr`` located like *location*."""
        func_def = ast.FunctionDef(
            name=name,
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg=p, annotation=None) for p in params],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=[ast.Return(expr)],
            decorator_list=[],
            returns=None,
        )
        ast.copy_location(func_def, location)
        ast.fix_missing_locations(func_def)
        return func_def

    def _new_branch_names(self):
        """Increments the counter and returns the next (then, else) names."""
        self.counter += 1
        return (
            f"{self.wrapper_name}_then_{self.counter}",
            f"{self.wrapper_name}_else_{self.counter}",
        )

    def _wrapper_call(self, test, then_name, else_name, operands):
        """Builds ``wrapper(test, then_name, else_name, *operands)``."""
        return ast.Call(
            func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
            args=[
                test,
                ast.Name(id=then_name, ctx=ast.Load()),
                ast.Name(id=else_name, ctx=ast.Load()),
                *operands,
            ],
            keywords=[],
        )

    def _rewrite_assign(self, node, then_assign, else_assign):
        """Rewrites an ``if`` whose branches assign the same single name."""
        # A chained assignment (``a = b = ...``) has several targets:
        # rewriting it would silently drop all of them but the first one.
        if len(then_assign.targets) != 1 or len(else_assign.targets) != 1:
            return node
        tgt = then_assign.targets[0]
        else_tgt = else_assign.targets[0]
        if not (
            isinstance(tgt, ast.Name)
            and isinstance(else_tgt, ast.Name)
            and tgt.id == else_tgt.id
        ):
            return node

        then_name, else_name = self._new_branch_names()
        then_vars = self._loaded_names(then_assign.value)
        else_vars = self._loaded_names(else_assign.value)
        then_def = self._branch_function(then_name, then_vars, then_assign.value, node)
        else_def = self._branch_function(else_name, else_vars, else_assign.value, node)

        # Each branch gets its own tuple of operands.
        operands = [
            ast.Tuple(
                [ast.Name(id=v, ctx=ast.Load()) for v in names],
                ctx=ast.Load(),
            )
            for names in (then_vars, else_vars)
        ]
        call = self._wrapper_call(node.test, then_name, else_name, operands)
        assign = ast.Assign(targets=[tgt], value=call)
        ast.copy_location(assign, node)
        ast.fix_missing_locations(assign)
        return [then_def, else_def, assign]

    def _rewrite_return(self, node, then_ret, else_ret):
        """Rewrites an ``if`` whose branches both return an expression."""
        # A bare ``return`` has no expression to move into a branch function.
        if then_ret.value is None or else_ret.value is None:
            return node

        then_name, else_name = self._new_branch_names()
        # Both branches share the same parameters.  The names are sorted so
        # the generated code does not depend on the iteration order of a set.
        shared_vars = sorted(
            {
                *self._loaded_names(then_ret.value),
                *self._loaded_names(else_ret.value),
            }
            - {"torch"}
        )
        then_def = self._branch_function(then_name, shared_vars, then_ret.value, node)
        else_def = self._branch_function(else_name, shared_vars, else_ret.value, node)

        operands = ast.List(
            [ast.Name(id=v, ctx=ast.Load()) for v in shared_vars],
            ctx=ast.Load(),
        )
        call = self._wrapper_call(node.test, then_name, else_name, [operands])
        ret = ast.Return(call)
        ast.copy_location(ret, node)
        ast.fix_missing_locations(ret)
        return [then_def, else_def, ret]

    def visit_If(self, node):
        """
        Rewrites the ``if`` when it matches one of the two supported shapes.

        :return: the node itself when nothing is rewritten, otherwise the
            list of statements replacing it (two function definitions
            followed by an assignment or a return)
        """
        # First recurse into subnodes
        node = self.generic_visit(node)
        if self.current_func_args is None:
            return node
        if len(node.body) != 1 or len(node.orelse) != 1:
            return node
        then_stmt, else_stmt = node.body[0], node.orelse[0]
        # Case 1: simple assignment in both branches
        if isinstance(then_stmt, ast.Assign) and isinstance(else_stmt, ast.Assign):
            return self._rewrite_assign(node, then_stmt, else_stmt)
        # Case 2: simple return in both branches
        if isinstance(then_stmt, ast.Return) and isinstance(else_stmt, ast.Return):
            return self._rewrite_return(node, then_stmt, else_stmt)
        return node

    def generic_visit(self, node):
        """Delegates to :meth:`ast.NodeTransformer.generic_visit`."""
        return super().generic_visit(node)
[docs]
class RewrittenMethod:
    """
    Holds the outcome of a rewriting done with
    :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.

    :param tree: ast tree
    :param func: callable compiled from the tree
    """

    def __init__(self, tree, func):
        self.func = func
        self.tree = tree

    @property
    def code(self) -> str:
        """Source text regenerated from the stored tree."""
        source = ast.unparse(self.tree)
        return source

    def __repr__(self):
        """Class name followed by the stored callable between parentheses."""
        class_name = self.__class__.__name__
        return "{}({})".format(class_name, self.func)