import contextlib
import inspect
import math
import operator
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.fx import Node
from torch.fx.proxy import TracerBase
from ..helpers import string_type
_torch_cat = torch.cat
class CustomProxy(torch.fx.proxy.Proxy):
"""
Defines a custom proxy to trace the execution of a model
and converts it into a fx graph.
Works with :class:`CustomTracer
<experimental_experiment.torch_interpreter.tracing.CustomTracer>`.
"""
def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
super().__init__(node, tracer=tracer)
assert isinstance(
self.tracer, CustomTracer
), f"Unexpected type {type(self.tracer)} for the tracer."
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.node.name})"
def _custom_fx_repr_fn(self) -> str:
"To avoid bugs."
return f"CustomProxy(%{str(self.node)})"
def __getattr__(self, k) -> "CustomAttribute":
# note: not added to the graph yet; if this is a method call,
# we peephole-optimize it into the method invocation
return CustomAttribute(self, k)
@classmethod
def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
if isinstance(orig_method, torch._ops.HigherOrderOperator):
# HigherOrderOperators such as torch.cond are not handled
# by the default Proxy.__torch_function__
if orig_method is torch.cond:
assert (
not kwargs
), f"Unexpected kwargs={kwargs}, args={args}, orig_method={orig_method}"
assert (
len(args) == 4
), f"Unexpected kwargs={kwargs}, args={args}, orig_method={orig_method}"
assert isinstance(
args[3], list
), f"Unexpected type {type(args[3])} for the last argument"
root = args[0]
cond_true = root.tracer.register_callable("cond", args[1])
cond_false = root.tracer.register_callable("cond", args[2])
node = root.tracer.create_node(
"call_function",
orig_method,
args=(
args[0].node,
cond_true,
cond_false,
type(args[3])(a.node for a in args[3]),
),
kwargs={},
)
return root.tracer.proxy(node)
return torch.fx.proxy.Proxy.__torch_function__(
orig_method, types, args=args, kwargs=kwargs
)
def __setitem__(self, *args, **kwargs):
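"""Records ``operator.setitem`` in the graph so that ``x[indices] = values`` can be traced."""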
assert not kwargs, f"Unexpected not empty kwargs={kwargs!r}"
assert len(args) == 2, f"Unexpected number of args={len(args)}: {args}"
indices, values = args
if isinstance(indices, CustomProxy):
indices = indices.node
node = self.tracer.create_node(
"call_function",
operator.setitem,
args=(self.node, indices, values.node if hasattr(values, "node") else values),
kwargs={},
)
# node_to_replace = self.node
return self.tracer.proxy(node)
def __len__(self):
raise RuntimeError(
"len(.) must return an integer; len needs to be replaced by _len when tracing."
)
def length(self):
"""Returns a proxy for the length."""
node = self.tracer.create_node("call_method", "__len__", args=(self.node,), kwargs={})
tt = self.tracer.proxy(node, cls=CustomProxyInt)
return tt
def instanceof(self, cls):
"""Tells if this proxy represents a specific class."""
raise RuntimeError(f"Unable to know if cls is from type {cls}.")
@classmethod
def cat(
cls,
tensors: List["CustomProxy"],
dim: int = 0,
*,
out=None,
axis: Optional[int] = None,
) -> "CustomProxy":
"""Implements cat for tensors."""
assert out is None, "Tracing is not implementing is out is not None."
if isinstance(tensors, list):
return _torch_cat(tensors, dim)
if axis is not None and dim == 0:
dim = axis
proxy = tensors
node = proxy.tracer.create_node(
"call_function", torch.cat, args=(proxy.node, dim), kwargs={}
)
return proxy.tracer.proxy(node)
def _len(x: Any) -> Union[int, CustomProxy]:
"""
Overloads `len` to return a proxy if the input is the proxy.
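A minimal sketch (assuming ``forward`` belongs to a model traced with
:class:`CustomTracer`)::
def forward(self, x):
# _len records a __len__ call and returns a CustomProxyInt
return x.sum() + _len(x)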
"""
if isinstance(x, CustomProxy):
return x.length()
return len(x)
def _isinstance(x, cls):
"""
Overloads `isinstance` to deal with CustomProxy.
"""
if isinstance(x, CustomProxy):
return x.instanceof(cls)
return isinstance(x, cls)
class CustomProxyInt(CustomProxy):
"A proxy for an integer."
def instanceof(self, cls):
"""isinstance"""
return cls in {CustomProxyInt, CustomProxy, int}
class CustomProxyFloat(CustomProxy):
"A proxy for a float."
def instanceof(self, cls):
"""isinstance"""
return cls in {CustomProxyFloat, CustomProxy, float}
class CustomAttribute(CustomProxy):
"""
To trace attributes.
"""
def __init__(self, root: CustomProxy, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node: Optional[Node] = None
@property
def node(self):
# the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getattr call
if self._node is None:
self._node = self.tracer.create_proxy(
"call_function", getattr, (self.root, self.attr), {}
).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root, *args), kwargs)
class CustomParameterProxy(CustomProxy):
"""
A special proxy which lets "shape", "size", "dim", and a few other
attribute accesses pass through to the underlying module parameter object,
so that conditional tests on these attributes will not throw exceptions during tracing.
"""
def __init__(self, tracer: TracerBase, node: Node, name, param):
super().__init__(node, tracer)
assert isinstance(param, torch.nn.Parameter)
self.param = param
self.name = name
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name})"
@property
def shape(self):
return self.param.shape
def size(self):
return self.param.size()
def dim(self):
return self.param.dim()
@property
def ndim(self):
return self.param.ndim
def numel(self):
return self.param.numel()
def nelement(self):
return self.param.nelement()
class CondCCOp(torch._ops.HigherOrderOperator):
"""
Cannot be imported from ``torch.ops.higher_order.cond``
(function ``cond`` overwrites submodule ``cond``).
"""
def __init__(self):
# we cannot use "cond" to avoid confusion with the existing cond
super().__init__("condcc")
def __call__(self, pred, true_fn, false_fn, operands):
# torch._higher_order_ops.utils.validate_subgraph_args_types(operands)
return super().__call__(pred, true_fn, false_fn, operands)
@contextlib.contextmanager
def replace_problematic_function_before_tracing():
"""
Replaces functions that cannot be traced with the default tracer
such as :func:`torch.cat`.
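A minimal sketch of the effect (the original functions are restored on exit)::
with replace_problematic_function_before_tracing():
assert torch.cat == CustomProxy.cat  # temporarily patched
assert torch.cat == _torch_cat  # original function restored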
"""
saved = {
"cat": torch.cat,
"cond": torch.cond,
# ("torch.ops.higher_order", "cond"): torch.ops.higher_order.cond,
}
newf = {
"cat": CustomProxy.cat,
"cond": CondCCOp(),
# ("torch.ops.higher_order", "cond"): CondOp(),
}
for k, v in newf.items():
if isinstance(k, tuple):
setattr(k[0], k[1], v)
else:
setattr(torch, k, v)
try:
yield
finally:
for k, v in saved.items():
if isinstance(k, tuple):
setattr(k[0], k[1], v)
else:
setattr(torch, k, v)
class CustomTracer(torch.fx.Tracer):
"""
Defines a custom tracer to trace the execution of a model
and converts it into a fx graph.
Works with :class:`CustomProxy
<experimental_experiment.torch_interpreter.tracing.CustomProxy>`.
::
from experimental_experiment.torch_interpreter.tracing import CustomTracer
graph = CustomTracer().trace(model)
"""
def __init__(
self,
autowrap_modules: Tuple["ModuleType"] = (math,), # noqa: F821
autowrap_functions: Tuple[Callable, ...] = (),
param_shapes_constant: bool = False,
):
super().__init__(
autowrap_modules=autowrap_modules,
autowrap_functions=autowrap_functions,
param_shapes_constant=param_shapes_constant,
)
self._callables = {}
def register_callable(self, name: str, fn: Callable) -> torch.fx.Node:
"""
Registers a callable under a unique attribute name.
:param name: prefix to prepend to the function name
:param fn: function
:return: a ``get_attr`` node whose target is the new unique name
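A minimal sketch (assuming ``tracer`` is a :class:`CustomTracer` currently
tracing a model)::
def true_branch(x):
return x + 1
node = tracer.register_callable("cond", true_branch)
# node.op == "get_attr", node.target == "_cb_cond_true_branch_0"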
"""
cand = f"_cb_{name}_{fn.__name__}_0"
if cand in self._callables:
i = 1
cand = f"_cb_{name}_{fn.__name__}_{i}"
while cand in self._callables:
i += 1
cand = f"_cb_{name}_{fn.__name__}_{i}"
self._callables[cand] = fn
return self.create_node("get_attr", cand, args=(), kwargs={})
def proxy(
self, node: torch.fx.Node, cls: type[CustomProxy] = CustomProxy
) -> torch.fx.Proxy:
"""
Overrides this method to replace the default Proxy with CustomProxy.
"""
return cls(node, self)
def create_arg(self, a: Any) -> "Argument": # noqa: F821
"""
Overrides this method to handle more argument types: Python builtin types are mapped to torch dtypes.
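A minimal sketch of the mapping::
tracer = CustomTracer()
assert tracer.create_arg(int) == torch.int64
assert tracer.create_arg(float) == torch.float32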
"""
if a is bool:
return torch.bool
if a is int:
return torch.int64
if a is float:
return torch.float32
if a is complex:
return torch.complex64
return super().create_arg(a)
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
"""
See :meth:`torch.fx.Tracer.getattr`.
"""
def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
for n, p in collection_to_search:
if attr_val is p:
if n not in parameter_proxy_cache:
kwargs = {}
if (
"proxy_factory_fn"
in inspect.signature(self.create_proxy).parameters
):
kwargs["proxy_factory_fn"] = (
None
if not self.param_shapes_constant
else lambda node, n=n, attr_val=attr_val: CustomParameterProxy(
self, node, n, attr_val
)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n]
return None
if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None:
return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(
attr_val, self.root.named_buffers(), parameter_proxy_cache
)
if maybe_buffer_proxy is not None:
return maybe_buffer_proxy
return attr_val
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
remove_inplace: bool = True,
update_model_with_callable: bool = True,
) -> torch.fx.Graph:
"""
Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
can either be an ``nn.Module`` instance or a Python callable.
Note that after this call, ``self.root`` may be different from the ``root`` passed
in here. For example, when a free function is passed to ``trace()``, we will
create an ``nn.Module`` instance to use as the root and add embedded constants to.
:param root: Either a ``Module`` or a function to be
traced through. Backwards-compatibility for this parameter is
guaranteed.
:param concrete_args: Concrete arguments that should
not be treated as Proxies. This parameter is experimental and
its backwards-compatibility is *NOT* guaranteed.
:param remove_inplace: Removes inplace nodes
:param update_model_with_callable: in some cases (control flow),
the callables registered during tracing need to be added
to the model as attributes
:return: A ``Graph`` representing the semantics of the passed-in ``root``.
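A minimal usage sketch (assuming ``model`` takes a single tensor argument ``x``)::
tracer = CustomTracer()
graph = tracer.trace(model, concrete_args={"x": torch.rand(2, 3)})
mod = torch.fx.GraphModule(tracer.root, graph)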
"""
assert concrete_args is None or isinstance(
concrete_args, dict
), f"Unexpected type for concrete_args: {string_type(concrete_args)}"
with replace_problematic_function_before_tracing():
graph = super().trace(root)
if concrete_args:
for node in graph.nodes:
if node.op == "placeholder":
if node.name in concrete_args:
node.meta["example_value"] = concrete_args[node.name]
self._replace_problematic_functions(graph)
if update_model_with_callable and self._callables:
for k, v in self._callables.items():
setattr(root, k, v)
self.remove_unnecessary_slices(graph)
if not remove_inplace:
graph.lint()
return graph
self.remove_inplace(graph)
graph.lint()
return graph
@classmethod
def _replace_problematic_functions(cls, graph: torch.fx.Graph) -> int:
"""
The tracing introduced some problematic functions which need to be replaced.
:return: number of impacted nodes
"""
replaces = {
CustomProxy.cat: torch.cat,
# CondCCOp: torch.ops.higher_order.cond,
}
n = 0
for node in graph.nodes:
if node.op == "call_function":
if node.target in replaces:
n += 1
node.target = replaces[node.target]
elif isinstance(node.target, CondCCOp):
n += 1
node.target = torch.ops.higher_order.cond
return n
@classmethod
def _get_aten_name(cls, node: torch.fx.Node) -> str:
"""
Returns the aten name for the target as a string.
"""
if node.target == operator.getitem:
return "getitem"
if isinstance(node.target, torch._ops.OpOverloadPacket):
if node.target != torch.ops.aten.sym_size:
raise RuntimeError(f"Unsupported function {node!r}.")
raise NotImplementedError(f"Unsupported function {node!r} (not implemented).")
if isinstance(node.target, types.BuiltinFunctionType):
return str(node.target)
if isinstance(node.target, torch._ops.OpOverload):
return node.target.name()
if callable(node.target):
# a single function
return f"aten_{node.target.__name__}"
if isinstance(node.target, str):
return node.target
raise NotImplementedError(
f"Unsupported function {node!r} (not implemented), "
f"node.target={node.target}, type is {type(node.target)}."
)
@classmethod
def _inplace_nodes(cls, graph: torch.fx.Graph) -> List[Tuple[int, torch.fx.Node]]:
"""
Returns the position and the node of candidate inplace modifications: call nodes without users, excluding getitem and known side-effect operators.
"""
return [
(i, node)
for i, node in enumerate(graph.nodes)
if node.op != "output"
and len(node.users) == 0
and node.op.startswith("call_")
and node.target not in {operator.getitem}
and cls._get_aten_name(node)
not in {
"aten::_assert_scalar",
"aten::sym_constrain_range_for_size",
"aten::_log_api_usage_once",
"aten::_enter_autocast",
"aten::_set_grad_enabled",
}
]
@classmethod
def _replace_meth_setitem(cls, graph: torch.fx.Graph) -> int:
"""
The execution of ``op="call_method", target="__setitem__" `` returns None
We replace it by ``op="call_function", target="operator.setitem"``.
:return: number of impacted nodes
"""
n = 0
for node in graph.nodes:
if node.op == "call_method" and node.target == "__setitem__":
node.op = "call_function"
node.target = operator.setitem
n += 1
return n
@classmethod
def _replace_getattr(cls, graph: torch.fx.Graph) -> int:
"""
Duplicated ``get_attr`` nodes such as
``%_tensor_constant0_1 : [num_users=1] = get_attr[target=_tensor_constant0]``
interfere with the replacements done by ``replace_all_uses_with``.
Let's remove the duplicates first.
:return: number of impacted get_attr nodes
"""
targets = {}
to_replace = []
for node in graph.nodes:
if node.op == "get_attr":
if node.target in targets:
# replacements
to_replace.append((node, targets[node.target]))
else:
targets[node.target] = node
if to_replace:
for node, by in to_replace:
node.replace_all_uses_with(by)
graph.erase_node(node)
return len(to_replace)
@classmethod
def remove_unnecessary_slices(cls, graph: torch.fx.Graph) -> int:
"""
Removes unnecessary slices.
:param graph: graph to modify
:return: number of removed slice nodes
A full slice such as the following is an identity
(``9223372036854775807`` is ``2**63 - 1``, the default end value):
::
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
(args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
"""
nodes = list(enumerate(graph.nodes))
removed = 0
for pos, node in nodes:
if not hasattr(node.target, "name"):
continue
if node.target.name() != "aten::slice.Tensor":
continue
if len(node.args) != 4 or node.args[2] != 0 or node.args[3] != 9223372036854775807:
continue
# The first argument is the node to keep.
new_name = node.args[0]
old_name = node
# Let's replace.
changed = old_name.replace_all_uses_with(new_name)
assert changed, (
f"No change applied, the node [{node}] at position {pos} "
f"was not replaced by {new_name} in \n{graph}."
)
graph.erase_node(old_name)
removed += 1
return removed
@classmethod
def graph_erase_node(cls, graph: torch.fx.Graph, node: torch.fx.Node):
"""
Removes a node together with the chain of predecessors consumed only by this node.
"""
nodes = [node]
while (
node.op == "call_function"
and node.args
and isinstance(node.args[0], torch.fx.Node)
and all(isinstance(_, (int, float)) for _ in node.args[1:])
and len(node.args[0].users) == 1
):
node = node.args[0]
nodes.append(node)
for node in nodes:
graph.erase_node(node)
@classmethod
def remove_inplace(
cls,
graph: torch.fx.Graph,
exported_program: Optional[torch.export.ExportedProgram] = None,
) -> int:
"""
Removes inplace operations.
:param graph: graph to modify
:param exported_program: if available, it is used in the error messages
to make it easier to locate the original source code
:return: number of inplace nodes removed
The most difficult pattern is the following:
::
%slice_11 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
(args = (%clone, 0, 0, 9223372036854775807), kwargs = {})
%slice_12 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
(args = (%slice_11, 1, 0, 9223372036854775807), kwargs = {})
%slice_13 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor]
(args = (%slice_12, 2, 0, 9223372036854775807), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default]
(args = (%slice_13, %masked_fill), kwargs = {})
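Once :meth:`remove_unnecessary_slices` has collapsed the full slices,
``%copy_`` reads ``(args = (%clone, %masked_fill))`` and the rewrite below
replaces it with (a sketch)::
%expand_as : [num_users=1] = call_method[target=expand_as]
(args = (%masked_fill, %clone), kwargs = {})
# later uses of %clone are redirected to %expand_as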
"""
inplace = cls._inplace_nodes(graph)
if len(inplace) == 0:
# No inplace node.
return 0
n_inplace = len(inplace)
cls._replace_getattr(graph)
cls._replace_meth_setitem(graph)
def delete_user_cb(n, nodes_to_leave):
return n not in nodes_to_leave
err_graph = str(graph)
max_iter = 10
while inplace and max_iter > 0:
existing_nodes = list(enumerate(graph.nodes))
for pos, node in reversed(inplace):
if node.target in {
operator.add,
operator.floordiv,
operator.mul,
operator.mod,
operator.sub,
}:
# This node cannot be an inplace modification; it is just unused.
graph.erase_node(node)
continue
if hasattr(node.target, "name"):
if (
node.target.name()
in {
"aten::view",
"aten::detach_", # output = input
"aten::add.Tensor", # it happens when running
"aten::div.Tensor", # z = f(x=x, y=x+1) but f does not use y
"aten::mul.Tensor",
"aten::sub.Tensor",
"aten::zeros", # unused as it does not end up with '_'
}
or node.target.name()[-1] != "_" # not an inplace modification
):
# This node cannot be an inplace modification;
# it is just unused.
cls.graph_erase_node(graph, node)
continue
if len(node.args) == 1:
# Simple case: the node and the chain of predecessors it
# solely consumes can be removed as well.
predecessor = node.args[0]
if len(predecessor.users):
# We can safely remove the node; graph_erase_node also erases
# the predecessors consumed only by this node.
cls.graph_erase_node(graph, node)
continue
assert node.target.name() in {"aten::copy_"} and len(node.args) == 2, (
f"(inplace) Unsupported target {node.target!r}, target_name="
f"{node.target.name()!r}, name={node.name!r}, node.args={node.args} "
f"at position {pos}/{len(graph.nodes)}"
f"\n--original graph--\n{err_graph}"
f"\n--graph\n{exported_program or graph}"
)
# We check that the predecessor of the node is a clone node.
predecessor = node.args[0]
assert (
hasattr(predecessor.target, "name")
and predecessor.target.name() == "aten::clone"
), (
f"(inplace) Unexpected predecessor {predecessor.target!r} "
f"for node {node.name!r} with args={node.args} at position "
f"{pos}/{len(graph.nodes)}"
f"\n--original graph--\n{err_graph}"
f"\n--graph\n{exported_program or graph}"
)
# class Node can be used as a dictionary key
# We also assume the remaining users come after this node.
nodes_to_leave = {n[1] for n in existing_nodes[: pos + 1]}
node_args = node.args
p_users = predecessor.users
# We can then replace the inplace copy with expand_as.
with graph.inserting_before(node):
# We assume the first argument is the one modified inplace.
new_node = graph.call_method(
"expand_as", args=(node_args[1], predecessor)
)
# let's replace
changed = predecessor.replace_all_uses_with(
new_node,
delete_user_cb=(
lambda n, leave=nodes_to_leave: delete_user_cb(n, leave)
),
)
graph.erase_node(node)
# new_node is replaced as well so we manually revert the replacement
new_node.update_arg(1, predecessor)
assert changed, (
f"No change applied, the inplace node [{node}] "
f"at position {pos} with node.args={node_args}, was not replaced "
f"by [{new_node}] with target {new_node.target!r} and "
f"new_node.args={new_node.args}, predecessor="
f"[{predecessor}] with target={predecessor.target!r}, "
f"p_users={list(p_users)}, "
f"predecessor.users={list(predecessor.users)}, "
f"new_node.users={list(new_node.users)} in "
f"\n{exported_program or graph}"
)
else:
assert node.target in {
"add_",
"div_",
"mul_",
"mod_",
"sub_",
operator.setitem,
}, (
f"Unsupported target {node.target!r}, name={node.name!r} "
f"at position {pos}/{len(graph.nodes)}\n--graph\n{graph}"
)
# We assume the first argument is the one modified inplace.
new_name = node
old_name = node.args[0]
# class Node can be used as a dictionary key
# We also assume the remaining users come after this node.
nodes_to_leave = {n[1] for n in existing_nodes[: pos + 1]}
# let's replace
changed = old_name.replace_all_uses_with(
new_name,
delete_user_cb=(
lambda n, leave=nodes_to_leave: delete_user_cb(n, leave)
),
)
assert changed, (
f"No change applied, the inplace node [{node}] at position {pos} "
f"does not replace [{old_name}] in \n{graph}\n-- node to keep --"
f"\n{nodes_to_leave}"
)
# We need to continue in case removing one unused node
# leaves another one unused.
inplace = cls._inplace_nodes(graph)
if len(inplace) == 0:
# No inplace left.
break
max_iter -= 1
assert len(inplace) == 0, (
f"Inplace nodes remain at positions {sorted(inplace)}"
f"/{len(graph.nodes)} in\n{graph}\n--original graph--\n{err_graph}"
)
return n_inplace