import ast
import copy
import contextlib
import difflib
import inspect
import os
import types
import textwrap
import sys
from typing import Callable, Dict, List, Set, Optional, Tuple, Union
from .patch_module_helper import code_needing_rewriting
NODE_TYPES = tuple(
getattr(ast, k)
for k in dir(ast)
if "A" <= k[0] <= "Z" and isinstance(getattr(ast, k), type)
)
def _settl(node, lineno, level=0):
if isinstance(node, (str, int, float)):
return node
if isinstance(node, list):
for n in node:
_settl(n, lineno, level=level + 1)
return node
if isinstance(node, NODE_TYPES):
if not hasattr(node, "lineno") or node.lineno is None:
node.lineno = lineno
for k in dir(node):
if k in {"s", "n", "parent"}:
continue
if k[0] == "_":
continue
v = getattr(node, k)
_settl(v, max(lineno, node.lineno), level=level + 1)
return node
[docs]
class UsedVarsFinder(ast.NodeVisitor):
    """
    Collects the variable names read (attribute ``used``) and written
    (attribute ``defined``) in the visited code. Names deleted with ``del``
    are ignored, ``global`` and ``nonlocal`` declarations are not visited.
    """

    def __init__(self):
        self.used = set()
        self.defined = set()

    def visit_Name(self, node):
        context = node.ctx
        if isinstance(context, ast.Load):
            target_set = self.used
        elif isinstance(context, ast.Store):
            target_set = self.defined
        else:
            target_set = None
        if target_set is not None:
            target_set.add(node.id)
        self.generic_visit(node)

    def visit_Global(self, node):
        """Declarations are skipped."""
        return None

    visit_Nonlocal = visit_Global
[docs]
class ShapeFinder(ast.NodeVisitor):
    """Collects every name ``x`` appearing in a call ``range(x.shape[0])``."""

    def __init__(self):
        self.found_shape = set()
        super().__init__()

    @staticmethod
    def _first_dim_owner(expr):
        """Returns ``x`` if *expr* is ``x.shape[0]``, None otherwise."""
        if not isinstance(expr, ast.Subscript):
            return None
        index = expr.slice
        if not (
            isinstance(index, ast.Constant)
            and isinstance(index.value, int)
            and index.value == 0
        ):
            return None
        shape = expr.value
        if not (
            isinstance(shape, ast.Attribute)
            and isinstance(shape.value, ast.Name)
            and shape.attr == "shape"
        ):
            return None
        return shape.value.id

    def visit_Call(self, node):
        callee = node.func
        if isinstance(callee, ast.Name) and callee.id == "range" and len(node.args) == 1:
            owner = self._first_dim_owner(node.args[0])
            if owner is not None:
                self.found_shape.add(owner)
        self.generic_visit(node)
[docs]
class RewriteControlFlow(ast.NodeTransformer):
"""
The class rewrites tests with function :func:`torch.cond`.
``empty_tensor`` is a function returning an empty tensor,
when a branch returns something the other branch does not.
:param prefix: prefix used for nested tests
:param skip_objects: to skip variable names if included in that list
such as modules
:param args_names: defines the local variables
:param filter_nodes: a function which is used to decide which node
to rewrite, True by default
:param pre_rewriter: a rewriter applied before the automated rewriting
:param post_rewriter: a rewriter applied after the automated rewriting
"""
def __init__(
self,
prefix: str = "branch_cond",
skip_objects: Optional[Dict[str, object]] = None,
args_names: Optional[Set[str]] = None,
filter_node: Optional[Callable[["ast.Node"], bool]] = None,
pre_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
post_rewriter: Optional[Callable[["ast.Node"], "ast.Node"]] = None,
):
self.counter_test = 0
self.counter_loop = 0
self.current_func_args = None
self.prefix = prefix
self.skip_objects = skip_objects or {}
self.args_names = args_names or set()
self.local_variables = self.args_names.copy()
self.filter_node = filter_node or (lambda _node: True)
self.pre_rewriter = pre_rewriter or (lambda node: node)
self.post_rewriter = post_rewriter or (lambda node: node)
[docs]
def generic_visit(self, node):
return super().generic_visit(node)
def _check(
self, cond: bool, node: "ast.Node", msg: str, cls: Optional[type[Exception]] = None
):
"""
Checks the condition is True, otherwise raises an exception with an error message
including the parsed code.
"""
if cls is not None:
if not cond:
smsg = msg if isinstance(msg, str) else msg()
raise cls(f"{smsg}\n\n--\n{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}")
return
assert cond, (
f"{msg if isinstance(msg, str) else msg()}\n\n--\n"
f"{ast.unparse(node)}\n--\n{ast.dump(node, indent=2)}"
)
def visit_Name(self, node):
node = self.generic_visit(node)
if isinstance(node.ctx, ast.Store):
self.local_variables.add(node.id)
return node
def visit_FunctionDef(self, node):
# Capture argument names for branch functions
old_args = self.current_func_args
self.current_func_args = [arg.arg for arg in node.args.args]
node.body = [self.visit(n) for n in node.body]
self.current_func_args = old_args
return node
def _find_id(self, exprs: List["ast.Node"]) -> List[str]:
vars = []
for expr in exprs:
for n in ast.walk(expr):
if (
isinstance(n, ast.Name)
# and isinstance(n.ctx, ast.Load)
and n.id not in self.skip_objects
):
vars.append(n.id)
return sorted(set(vars))
def _clone(self, name):
assert isinstance(name, ast.Name), f"Unexpected type {type(name)} for name"
return ast.Call(
func=ast.Attribute(value=name, attr="clone", ctx=ast.Load()), args=[], keywords=[]
)
def _rewrite_if(
self, node, then_exprs, else_exprs, tgt_mapping=None, known_local_variables=None
):
assert known_local_variables is not None, "known_local_variables cannot be None"
test_node = node.test
drop = set()
# extract free variables
then_name = f"{self.prefix}_then_{self.counter_test}"
else_name = f"{self.prefix}_else_{self.counter_test}"
then_vars = self._find_id(then_exprs)
else_vars = self._find_id(else_exprs)
then_else_vars = set(_ for _ in [*then_vars, *else_vars] if _ in known_local_variables)
then_ret, else_ret = None, None
if tgt_mapping is None and len(then_exprs) == 1 and len(else_exprs) == 1:
# return
then_ret = then_exprs[0]
else_ret = else_exprs[0]
then_exprs = [n for n in node.body if not isinstance(n, ast.Return)]
else_exprs = [n for n in node.orelse if not isinstance(n, ast.Return)]
is_tuple_or_list = (
isinstance(then_ret, (ast.Tuple, ast.List)),
isinstance(else_ret, (ast.Tuple, ast.List)),
)
assert len(set(is_tuple_or_list)) == 1, (
f"is_tuple_or_list={is_tuple_or_list}, inconsistencies return "
f"then value={then_ret}, "
f"else value={else_ret}"
)
if is_tuple_or_list[0]:
assert len(then_ret.elts) == len(else_ret.elts), (
f"Unexpected number of elements on both branches, "
f"then:{then_ret.elts}, else:{else_ret.elts}"
)
n_returned_values = len(then_ret.elts)
else:
n_returned_values = 0
else:
self._check(
tgt_mapping,
node,
"then and else branches do not have the same number "
"of assignments, we need more information to understand "
"which ones to return",
)
drop = set()
then_exprs, else_exprs = node.body, node.orelse
then_rets, else_rets = [], []
for t, then_else in sorted(tgt_mapping.items()):
then_e, else_e = then_else
if (then_e is None or else_e is None) and t not in then_else_vars:
# The variable is not used by one branch and it is not an input.
# Let's drop it.
drop.add(t)
continue
then_rets.append(then_e or ast.Name(else_e.id, ctx=ast.Load()))
else_rets.append(else_e or ast.Name(then_e.id, ctx=ast.Load()))
then_ret = (
self._clone(then_rets[0])
if len(then_rets) == 1
else ast.Tuple([self._clone(r) for r in then_rets], ctx=ast.Load())
)
else_ret = (
self._clone(else_rets[0])
if len(else_rets) == 1
else ast.Tuple([self._clone(r) for r in else_rets], ctx=ast.Load())
)
n_returned_values = len(then_rets) if len(then_rets) > 1 else 0
# build local funcs
then_def = ast.FunctionDef(
name=then_name,
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=[*then_exprs, ast.Return(then_ret)],
decorator_list=[],
returns=None,
)
else_def = ast.FunctionDef(
name=else_name,
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg=v, annotation=None) for v in then_else_vars],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=[*else_exprs, ast.Return(else_ret)],
decorator_list=[],
returns=None,
)
# fix locations
for n in (then_def, else_def):
ast.copy_location(n, node)
ast.fix_missing_locations(n)
assert hasattr(n, "lineno")
# wrapper call and assignment
then_else_args_list = ast.List(
[ast.Name(id=v, ctx=ast.Load()) for v in then_else_vars],
ctx=ast.Load(),
)
call = ast.Call(
func=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()), attr="cond", ctx=ast.Load()
),
args=[
test_node,
ast.Name(id=then_name, ctx=ast.Load()),
ast.Name(id=else_name, ctx=ast.Load()),
then_else_args_list,
],
keywords=[],
)
return then_def, else_def, call, drop, n_returned_values
def _filter_target(self, node, tgt_mapping):
"""
This function should reduce the number of elements to return
by looking at the one used after the If statement.
"""
return tgt_mapping
def _make_targets(self, node, then_assigns, else_assigns):
tgt_mapping = {}
for a, then_or_else in [
*[(a, True) for a in then_assigns],
*[(a, False) for a in else_assigns],
]:
for t in a.targets:
if isinstance(t, ast.Name) and isinstance(t.ctx, ast.Store):
if t.id not in tgt_mapping:
tgt_mapping[t.id] = (t, None) if then_or_else else (None, t)
else:
v = tgt_mapping[t.id]
tgt_mapping[t.id] = (t, v[1]) if then_or_else else (v[0], t)
continue
self._check(
isinstance(t, ast.Tuple) and all(isinstance(_, ast.Name) for _ in t.elts),
node,
"Unexpected assignment. Not Supported.",
)
for _t in t.elts:
if not isinstance(_t, ast.Name) or not isinstance(_t.ctx, ast.Store):
continue
if _t.id not in tgt_mapping:
tgt_mapping[_t.id] = (_t, None) if then_or_else else (None, _t)
else:
v = tgt_mapping[_t.id]
tgt_mapping[_t.id] = (_t, v[1]) if then_or_else else (v[0], _t)
tgt_mapping = self._filter_target(node, tgt_mapping)
d = [(v[0] or v[1]) for k, v in sorted(dict(tgt_mapping).items())]
tgt = d[0] if len(d) == 1 else ast.Tuple(d, ctx=ast.Load())
return tgt, tgt_mapping
def visit_If(self, node):
if not self.filter_node(node):
return [node]
node = self.pre_rewriter(node)
# First recurse into subnodes
known_local_variables = self.local_variables.copy()
node = self.generic_visit(node)
has_then_return = any(isinstance(n, ast.Return) for n in node.body)
has_else_return = any(isinstance(n, ast.Return) for n in node.orelse)
ok = (has_then_return and has_else_return) or (
not has_then_return and not has_else_return
)
self._check(
ok,
node,
"Cannot mix return and assignment in a test or a "
"unique then branch with a return",
NotImplementedError,
)
self._check(self.current_func_args is not None, node, "current_func_args is None")
self.counter_test += 1
if not has_then_return:
# Case 1: simple assignment in both branches
then_assigns = [n for n in node.body if isinstance(n, ast.Assign)]
else_assigns = [n for n in node.orelse if isinstance(n, ast.Assign)]
self._check(then_assigns or else_assigns, node, "Missing assignment")
# the targets we need to export
tgt, tgt_mapping = self._make_targets(node, then_assigns, else_assigns)
then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
node,
then_assigns,
else_assigns,
tgt_mapping=tgt_mapping,
known_local_variables=known_local_variables,
)
if dropped and isinstance(tgt, ast.Tuple):
tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
elif isinstance(tgt, ast.Tuple):
tgt_elts = tuple(t for t in tgt.elts if t.id not in dropped)
else:
tgt_elts = [tgt]
if n_returned_values == 0:
assert len(tgt_elts) == 1, (
f"Inconsistencies between n_returned_values={n_returned_values}, "
f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
)
tgt = tgt_elts[0]
else:
assert n_returned_values == len(tgt_elts), (
f"Inconsistencies between n_returned_values={n_returned_values}, "
f"dropped={dropped}, tgt.elts={tgt.elts}, tgt_elts={tgt_elts}"
)
tgt = ast.Tuple(tgt_elts, ctx=ast.Store())
added = {tgt.id} if isinstance(tgt, ast.Name) else set(t.id for t in tgt.elts)
assign = ast.Assign(targets=[tgt], value=call)
ast.copy_location(assign, node)
ast.fix_missing_locations(assign)
self.local_variables = known_local_variables | added
return [self.post_rewriter(n) for n in [then_def, else_def, assign]]
# Case 2: return in both branches, we assume both branches return the same results.
then_ret = node.body[-1]
else_ret = node.orelse[-1]
self._check(
isinstance(then_ret, ast.Return),
node,
"return is not the last instruction of then branch",
)
self._check(
isinstance(else_ret, ast.Return),
node,
"return is not the last instruction of else branch",
)
then_expr = then_ret.value
else_expr = else_ret.value
then_def, else_def, call, dropped, n_returned_values = self._rewrite_if(
node, [then_expr], [else_expr], known_local_variables=known_local_variables
)
ret = ast.Return(call)
ast.copy_location(ret, node)
ast.fix_missing_locations(ret)
return [self.post_rewriter(n) for n in [then_def, else_def, ret]]
def _find_loop_vars(self, node):
assert isinstance(node, ast.For), f"Unexpected type {type(node)} for node"
finder = ShapeFinder()
finder.visit(node.iter)
scan_shape_vars = finder.found_shape
scan_vars = set()
finder = UsedVarsFinder()
for stmt in node.body:
finder.visit(stmt)
assigned_in_body = set()
for stmt in node.body:
if isinstance(stmt, ast.Assign):
for tgt in stmt.targets:
if isinstance(tgt, ast.Name) and isinstance(tgt.value.ctx, ast.Store):
assigned_in_body |= {tgt.value.id}
extra_defined = set()
for stmt in node.body:
if isinstance(stmt, ast.Assign):
for tgt in stmt.targets:
if isinstance(tgt, ast.Subscript):
# It means the target existed before.
if (
isinstance(tgt.value, ast.Name)
and tgt.value.id not in assigned_in_body
):
extra_defined.add(tgt.value.id)
loop_vars = set()
if isinstance(node.target, ast.Name):
loop_vars.add(node.target.id)
elif isinstance(node.target, (ast.Tuple, ast.List)):
loop_vars |= {elt.id for elt in node.target.elts if isinstance(elt, ast.Name)}
output_vars = finder.defined | assigned_in_body
input_vars = (
finder.used
- finder.defined
- loop_vars
- scan_shape_vars
- scan_vars
- output_vars
- assigned_in_body
- extra_defined
)
return dict(
init=sorted(extra_defined),
loop=sorted(loop_vars),
scan_shape=sorted(scan_shape_vars),
scan=sorted(scan_vars),
input=sorted(input_vars),
output=sorted(output_vars),
)
def visit_For(self, node):
if not self.filter_node(node):
return [node]
node = self.pre_rewriter(node)
# For nested loops.
self.generic_visit(node)
# look for variables, loop, inputs and outputs of the body
vars = self._find_loop_vars(node)
init_vars, loop_vars, scan_shape_vars, scan_vars, input_vars, output_vars = [
vars[k] for k in ["init", "loop", "scan_shape", "scan", "input", "output"]
]
self._check(
len(scan_shape_vars) == len(loop_vars),
node,
lambda: (
f"Inconsistencies between loop_vars={loop_vars} "
f"and scan_shape_vars={scan_shape_vars}"
),
)
self._check(
len(scan_shape_vars) in {0, 1},
node,
lambda: f"Inconsistencies with scan_shape_vars={scan_shape_vars}",
)
self._check(
(len(scan_shape_vars) == 0 or len(scan_vars) == 0)
and (scan_shape_vars or scan_vars),
node,
lambda: (
f"Inconsistencies between scan_vars={scan_vars} "
f"and scan_shape_vars={scan_shape_vars}"
),
)
# creates the function
func_name = f"loop_body_{self.counter_loop}"
self.counter_loop += 1
func_def = ast.FunctionDef(
name=func_name,
args=ast.arguments(
posonlyargs=[],
args=[
ast.arg(arg=v)
for v in [
*init_vars,
*loop_vars,
*scan_vars,
*scan_shape_vars,
*input_vars,
]
],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
),
body=[
*[
ast.Assign(
targets=[ast.Name(id=i, ctx=ast.Load())],
value=[
ast.Call(
func=ast.Attribute(
value=ast.Name(id=i, ctx=ast.Load()),
attr="clone",
ctx=ast.Load(),
),
args=[],
keywords=[],
ctx=ast.Load(),
)
],
)
for i in init_vars
],
*node.body,
ast.Return(
value=ast.List(
[
ast.Name(id=v, ctx=ast.Load())
for v in [*init_vars, *loop_vars, *output_vars]
],
ctx=ast.Load(),
)
),
],
decorator_list=[],
ctx=ast.Store(),
)
# final rewriting
call = ast.Call(
func=(
ast.Attribute(
value=ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()),
attr="ops",
ctx=ast.Load(),
),
attr="higher_order",
ctx=ast.Load(),
),
attr="scan",
ctx=ast.Load(),
)
),
args=[
ast.Name(id=func_name, ctx=ast.Load()),
ast.List(
elts=[ast.Name(id=v, ctx=ast.Load()) for v in init_vars], ctx=ast.Store()
),
ast.List(
elts=[
*[
ast.Call(
ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()),
attr="arange",
ctx=ast.Load(),
),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id=v, ctx=ast.Load()),
attr="shape",
ctx=ast.Load(),
),
slice=ast.Constant(value=0, ctx=ast.Load()),
ctx=ast.Load(),
),
],
keywords=[
ast.keyword(
arg="dtype",
value=ast.Attribute(
value=ast.Name(id="torch", ctx=ast.Load()),
attr="int64",
ctx=ast.Load(),
),
)
],
ctx=ast.Load(),
)
for v in scan_shape_vars
],
*[ast.Name(id=v, ctx=ast.Load()) for v in scan_vars],
],
ctx=ast.Store(),
),
ast.List(
elts=[
ast.Name(id=v, ctx=ast.Load()) for v in [*scan_shape_vars, *input_vars]
],
ctx=ast.Store(),
),
],
keywords=[],
ctx=ast.Load(),
)
target = ast.Tuple(
[ast.Name(id=v, ctx=ast.Store()) for v in [*init_vars, *loop_vars, *output_vars]],
ctx=ast.Store(),
)
assign = ast.Assign(targets=[target], value=call)
return [self.post_rewriter(func_def), self.post_rewriter(assign)]
[docs]
class RewrittenMethod:
    """
    Holds the result of rewriting a method or a function with
    :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.

    :param tree: rewritten ast tree
    :param func: callable compiled from the tree
    """

    def __init__(self, tree, func):
        self.tree, self.func = tree, func

    @property
    def code(self) -> str:
        """Returns the source code generated from the tree."""
        source = ast.unparse(self.tree)
        return source

    @property
    def dump(self) -> str:
        """Returns the tree dumped as a string (indentation of 2)."""
        return ast.dump(self.tree, indent=2)

    def __repr__(self):
        "usual"
        cls_name = type(self).__name__
        return f"{cls_name}({self.func})"
class _AddParentTransformer(ast.NodeTransformer):
parent = None
def visit(self, node):
node.parent = self.parent
self.parent = node
node = super().visit(node)
if isinstance(node, ast.AST):
self.parent = node.parent
return node
class _SelectiveAssignNormalizer(ast.NodeTransformer):
def visit_If(self, node):
self.generic_visit(node)
node.body = [self._transform_if_needed(stmt) for stmt in node.body]
node.orelse = [self._transform_if_needed(stmt) for stmt in node.orelse]
return node
def _transform_if_needed(self, stmt):
if isinstance(stmt, ast.AugAssign):
return ast.Assign(
targets=[stmt.target],
value=ast.BinOp(left=copy.deepcopy(stmt.target), op=stmt.op, right=stmt.value),
)
if isinstance(stmt, ast.AnnAssign) and stmt.value is not None:
return ast.Assign(targets=[stmt.target], value=stmt.value)
return self.visit(stmt)
[docs]
def inplace_add_parent(tree: "ast.Node"):
    """
    Stores in every node of *tree* a reference to its parent node
    (attribute ``parent``, None for the root). *tree* is modified in place.
    """
    adder = _AddParentTransformer()
    adder.visit(tree)
[docs]
def normalize_assignment_in_test(tree: "ast.Node"):
    """
    Replaces augmented assignments (``x += y``) and annotated assignments
    (``x: T = y``) located directly in the branches of ``if`` statements by
    plain assignments to simplify whatever comes after.
    """
    normalizer = _SelectiveAssignNormalizer()
    normalizer.visit(tree)
[docs]
@contextlib.contextmanager
def torch_export_rewrite(
    rewrite: Optional[
        Union["torch.nn.Module", List[Union[Tuple[type, str], Callable]]]  # noqa: F821
    ] = None,
    dump_rewriting: Optional[str] = None,
    verbose: int = 0,
):
    """
    Automatically rewrite the methods given in `rewrite` to export
    control flows (test and loops).

    :param rewrite: methods or functions to rewrite, if not empty, the function may try
        to discover them, a method is defined by its class (a type) and its name
        if the class is local, by itself otherwise, it can also be a model,
        in that case, the function calls :func:`code_needing_rewriting
        <onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting>`
        to retrieve the necessary rewriting; an element can also be a dictionary
        with key ``'function'`` and other arguments given to ``transform_method``
        (``filter_node``, ``pre_rewriter``, ``post_rewriter``), that dictionary
        is not modified
    :param dump_rewriting: dumps rewriting into that folder, if it does not exist,
        it creates it.
    :param verbose: verbosity, up to 10, 10 shows the rewritten code,
        ``verbose=1`` shows the rewritten function,
        ``verbose=2`` shows the rewritten code as well

    Every patched attribute is restored when leaving the context,
    including when one of the rewritings fails.

    Example:

    .. code-block:: python

        class Model(torch.nn.Module):
            def forward(self, x, y):
                if x.sum() > 0:
                    return x + y
                else:
                    return torch.abs(x) + y + 1

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[(Model, "forward")]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    If the method to rewrite is not local, then the following can be used:

    .. code-block:: python

        with torch_export_rewrite(rewrite=[Model.forward]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)

    Functions (if not local) can also be rewritten:

    .. code-block:: python

        def outside(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return torch.abs(x) + y + 1

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return outside(x, y)

        model = Model()
        x, y = torch.rand((4, 5)), torch.rand((4, 5))
        DYN = torch.export.Dim.DYNAMIC
        ds = ({0: DYN, 1: DYN}, {0: DYN, 1: DYN})
        with torch_export_rewrite(rewrite=[outside]):
            ep = torch.export.export(model, (x, y), dynamic_shapes=ds)
    """
    if hasattr(rewrite, "forward"):
        # It is a torch.nn.Module.
        # Let's retrieve the known rewriting for this model class.
        rewrite = code_needing_rewriting(rewrite.__class__.__name__)
    assert rewrite, "rewrite is empty, automated discovery is not implemented yet"
    # marks an attribute absent from the class (or module) dictionary
    # before being patched, an inherited method for example
    missing = object()
    # (cls, name) -> (kind, attribute found in cls.__dict__ before patching)
    keep = {}
    # The whole patching happens inside try: patches applied before a failing
    # rewriting are restored as well.
    try:
        for me in rewrite:
            if isinstance(me, tuple):
                assert len(me) == 2, f"Unexpected value for a rewritten method or function {me}"
                cls, name = me
                cls_name = cls.__name__
                to_rewrite = getattr(cls, name)
                kind = "method"
                kws = {}  # type: ignore[var-annotated]
            else:
                if isinstance(me, dict):
                    assert "function" in me and (
                        "filter_node" in me or "pre_rewriter" in me or "post_rewriter" in me
                    ), (
                        f"If the rewriting code is defined as a dictionary, key "
                        f"'function' must be defined, other arguments must be understood by "
                        f"{transform_method.__name__}, "
                        f"the given value is {me!r}."
                    )
                    # copy: the dictionary given by the caller may be shared
                    # (code_needing_rewriting) and must remain usable
                    kws = dict(me)
                    me = kws.pop("function")
                else:
                    kws = {}
                name = me.__qualname__
                spl = name.split(".")
                if len(spl) == 1:
                    # This a function, it is patched in its module
                    module = me.__module__
                    if module in me.__globals__:
                        mod = me.__globals__[module]
                    else:
                        assert module in sys.modules, (
                            f"Cannot find module name {module!r} in sys.modules or "
                            f"__globals__={sorted(me.__globals__)}"
                        )
                        mod = sys.modules[module]
                    cls_name = module
                    cls = mod
                    to_rewrite = me
                    kind = "function"
                else:
                    # This is a method, the class name precedes the method name
                    kind = "method"
                    cls_name = spl[-2]
                    assert cls_name in me.__globals__, (
                        f"Class name {cls_name!r} from method {name!r} "
                        f"could not be found in set(me.__globals__)={sorted(me.__globals__)}"
                    )
                    cls = me.__globals__[cls_name]
                    name = me.__name__
                    to_rewrite = me
                    assert hasattr(
                        cls, name
                    ), f"Method {name!r} inferred form {me} was not found in class {cls}."
            assert (cls, name) not in keep, f"{kind} {me} cannot be rewritten twice."
            if verbose:
                print(f"[torch_export_rewrite] rewrites {kind} {cls.__name__}.{name}")
            # the raw attribute (staticmethod, ...) is saved, not the resolved one
            keep[cls, name] = (kind, vars(cls).get(name, missing))
            if dump_rewriting:
                os.makedirs(dump_rewriting, exist_ok=True)
                filename1 = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.original.py")
                if verbose:
                    print(f"[torch_export_rewrite] dump original code in {filename1!r}")
                with open(filename1, "w", encoding="utf-8") as f:
                    code = _clean_code(inspect.getsource(to_rewrite))
                    f.write(code)
            rewr = transform_method(to_rewrite, verbose=max(verbose - 1, 0), **kws)
            if dump_rewriting:
                filename2 = os.path.join(
                    dump_rewriting, f"{kind}.{cls_name}.{name}.rewritten.py"
                )
                if verbose:
                    print(f"[torch_export_rewrite] dump rewritten code in {filename2!r}")
                with open(filename2, "w", encoding="utf-8") as f:
                    rcode = _clean_code(rewr.code)
                    f.write(rcode)
                diff = os.path.join(dump_rewriting, f"{kind}.{cls_name}.{name}.diff")
                make_diff(code, rcode, diff)
            setattr(cls, name, rewr.func)
        yield
    finally:
        for (cls, name), (kind, original) in keep.items():
            if verbose:
                print(f"[torch_export_rewrite] restored {kind} {cls.__name__}.{name}")
            if original is not missing:
                setattr(cls, name, original)
            elif name in vars(cls):
                # The attribute was inherited: removing the patch makes
                # the original one visible again.
                delattr(cls, name)
def _clean_code(code: str) -> str:
try:
import black
except ImportError:
return code
return black.format_str(code, mode=black.FileMode(line_length=98))
[docs]
def make_diff(code1: str, code2: str, output: Optional[str] = None) -> str:
    """
    Creates a diff between two codes.

    :param code1: first code
    :param code2: second code
    :param output: if not empty, stores the output in this file (UTF-8)
    :return: diff in unified format, empty if both codes are identical
        once stripped
    """
    text = "\n".join(
        difflib.unified_diff(
            code1.strip().splitlines(),
            code2.strip().splitlines(),
            fromfile="original",
            tofile="rewritten",
            lineterm="",
        )
    )
    if output:
        # Source code may contain characters the platform default
        # encoding cannot write.
        with open(output, "w", encoding="utf-8") as f:
            f.write(text)
    return text