Source code for experimental_experiment.torch_interpreter.tracing

import contextlib
import inspect
import math
import operator
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.fx import Node
from torch.fx.proxy import TracerBase

# Reference to the genuine torch.cat, captured at import time.
# replace_problematic_function_before_tracing rebinds torch.cat to
# CustomProxy.cat while tracing; CustomProxy.cat calls this alias for
# sequences of tensors so that it does not end up calling itself.
_torch_cat = torch.cat


class CustomProxy(torch.fx.proxy.Proxy):
    """
    Defines a custom proxy to trace the execution of a model
    and converts it into a fx graph.
    Works with :class:`CustomTracer
    <experimental_experiment.torch_interpreter.tracing.CustomTracer>`.
    """

    def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None):
        super().__init__(node, tracer=tracer)
        assert isinstance(
            self.tracer, CustomTracer
        ), f"Unexpected type {type(self.tracer)} for the tracer."

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.node.name})"

    def _custom_fx_repr_fn(self) -> str:
        "To avoid bugs."
        return f"CustomProxy(%{str(self.node)})"

    def __getattr__(self, k) -> "CustomAttribute":
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return CustomAttribute(self, k)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        """
        Intercepts torch calls made on proxies.

        ``torch.cond`` (a higher order operator) is handled here: both
        branches are registered on the tracer with
        :meth:`CustomTracer.register_callable` and a single ``call_function``
        node is created. Every other call is delegated to
        :meth:`torch.fx.proxy.Proxy.__torch_function__`.
        """
        if isinstance(orig_method, torch._ops.HigherOrderOperator):
            # not implemented by torch
            if orig_method is torch.cond:
                assert (
                    not kwargs
                ), f"Unexpected kwargs={kwargs}, args={args}, orig_method={orig_method}"
                assert (
                    len(args) == 4
                ), f"Unexpected kwargs={kwargs}, args={args}, orig_method={orig_method}"
                assert isinstance(
                    args[3], list
                ), f"Unexpected type {type(args[3])} for the last argument"
                root = args[0]
                cond_true = root.tracer.register_callable("cond", args[1])
                cond_false = root.tracer.register_callable("cond", args[2])
                node = root.tracer.create_node(
                    "call_function",
                    orig_method,
                    args=(
                        args[0].node,
                        cond_true,
                        cond_false,
                        type(args[3])(a.node for a in args[3]),
                    ),
                    kwargs={},
                )
                return root.tracer.proxy(node)
        return torch.fx.proxy.Proxy.__torch_function__(
            orig_method, types, args=args, kwargs=kwargs
        )

    def __setitem__(self, *args, **kwargs):
        """
        Records ``self[indices] = values`` as an ``operator.setitem`` node
        and returns a proxy on that node.
        """
        assert not kwargs, f"Unexpected not empty kwargs={kwargs!r}"
        assert len(args) == 2, f"Unexpected number of args={len(args)}: {args}"
        indices, values = args
        if isinstance(indices, CustomProxy):
            indices = indices.node
        node = self.tracer.create_node(
            "call_function",
            operator.setitem,
            args=(self.node, indices, values.node if hasattr(values, "node") else values),
            kwargs={},
        )
        return self.tracer.proxy(node)

    def __len__(self):
        # len() must return a concrete int, which a proxy cannot provide.
        raise RuntimeError(
            "len(.) expects an integer, len needs to be replaced. You should use _len."
        )

    def length(self):
        """Returns a proxy for the length."""
        node = self.tracer.create_node("call_method", "__len__", args=(self.node,), kwargs={})
        tt = self.tracer.proxy(node, cls=CustomProxy and CustomProxyInt)
        return tt

    def instanceof(self, cls):
        """Tells if this proxy represents a specific class."""
        raise RuntimeError(f"Unable to know if cls is from type {cls}.")

    @classmethod
    def cat(
        cls,
        tensors: List["CustomProxy"],
        dim: int = 0,
        *,
        out=None,
        axis: Optional[int] = None,
    ) -> "CustomProxy":
        """
        Implements :func:`torch.cat` while tracing.

        :param tensors: a list or tuple of tensors/proxies, or a single
            proxy standing for the whole sequence
        :param dim: concatenation axis
        :param out: not supported, must be None
        :param axis: alias for ``dim`` (used only when ``dim`` is left at 0)
        :return: the concatenated tensor or a proxy on it
        """
        assert out is None, "Tracing is not implemented if out is not None."
        # ``axis`` is an alias of ``dim``: resolve it before dispatching so
        # that both code paths below honour it.
        if axis is not None and dim == 0:
            dim = axis
        if isinstance(tensors, (list, tuple)):
            # An explicit sequence: the genuine torch.cat handles it
            # (proxies inside dispatch through __torch_function__).
            return _torch_cat(tensors, dim)
        # Otherwise ``tensors`` is expected to be a proxy on a sequence.
        proxy = tensors
        # torch.cat is CustomProxy.cat during tracing; the tracer maps this
        # target back to the real torch.cat once tracing is over.
        node = proxy.tracer.create_node(
            "call_function", torch.cat, args=(proxy.node, dim), kwargs={}
        )
        return proxy.tracer.proxy(node)
def _len(x: Any) -> Union[int, CustomProxy]:
    """
    Overloads `len` to return a proxy if the input is a proxy.

    :param x: any object supporting ``len`` or a :class:`CustomProxy`
    :return: ``len(x)`` or a proxy on the length
    """
    if isinstance(x, CustomProxy):
        return x.length()
    return len(x)


def _isinstance(x, cls):
    """
    Overloads `isinstance` to deal with CustomProxy.

    :param x: any object or a :class:`CustomProxy`
    :param cls: a class or a tuple of classes, as for :func:`isinstance`
    :return: ``x.instanceof(cls)`` for proxies, ``isinstance(x, cls)`` otherwise
    """
    if isinstance(x, CustomProxy):
        return x.instanceof(cls)
    # Regular objects must be checked against the requested class,
    # not against a hard-coded type.
    return isinstance(x, cls)
class CustomProxyInt(CustomProxy):
    "A proxy for an integer."

    def instanceof(self, cls):
        """
        Mimics :func:`isinstance` for a proxy standing for an integer.

        :param cls: a class or a tuple of classes (as accepted by isinstance)
        :return: True if ``cls`` (or any element of the tuple) is
            ``CustomProxyInt``, ``CustomProxy`` or ``int``
        """
        if isinstance(cls, tuple):
            # isinstance(x, (A, B)) is True when any of the classes matches.
            return any(self.instanceof(c) for c in cls)
        return cls in {CustomProxyInt, CustomProxy, int}
class CustomProxyFloat(CustomProxy):
    "A proxy for a float."

    def instanceof(self, cls):
        """
        Mimics :func:`isinstance` for a proxy standing for a float.

        :param cls: a class or a tuple of classes (as accepted by isinstance)
        :return: True if ``cls`` (or any element of the tuple) is
            ``CustomProxyFloat``, ``CustomProxy`` or ``float``
        """
        if isinstance(cls, tuple):
            # isinstance(x, (A, B)) is True when any of the classes matches.
            return any(self.instanceof(c) for c in cls)
        # A float proxy is a CustomProxyFloat, not a CustomProxyInt.
        return cls in {CustomProxyFloat, CustomProxy, float}
class CustomAttribute(CustomProxy):
    """
    Proxy standing for ``root.attr`` during tracing.

    The corresponding graph node is only created the first time
    :attr:`node` is requested. Calling the attribute records a
    ``call_method`` node on ``root`` instead.
    """

    def __init__(self, root: CustomProxy, attr: str):
        # The parent constructor is deliberately not called: no node exists yet.
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # Created lazily: most attributes are immediately called as methods
        # and never need a getattr node of their own.
        if self._node is not None:
            return self._node
        getattr_proxy = self.tracer.create_proxy(
            "call_function", getattr, (self.root, self.attr), {}
        )
        self._node = getattr_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        method_args = (self.root,) + args
        return self.tracer.create_proxy("call_method", self.attr, method_args, kwargs)
class CustomParameterProxy(CustomProxy):
    """
    Proxy bound to a module parameter.

    The attributes ``shape`` and ``ndim`` and the methods ``size()``,
    ``dim()``, ``numel()`` and ``nelement()`` are answered by the wrapped
    :class:`torch.nn.Parameter` itself rather than being traced, so that
    conditional tests on these values do not raise during tracing.
    """

    def __init__(self, tracer: TracerBase, node: Node, name, param):
        super().__init__(node, tracer)
        assert isinstance(param, torch.nn.Parameter)
        self.param = param
        self.name = name

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.name})"

    # Metadata read directly from the wrapped parameter.

    @property
    def shape(self):
        return self.param.shape

    @property
    def ndim(self):
        return self.param.ndim

    def size(self):
        return self.param.size()

    def dim(self):
        return self.param.dim()

    def numel(self):
        return self.param.numel()

    def nelement(self):
        return self.param.nelement()
class CondCCOp(torch._ops.HigherOrderOperator):
    """
    Higher order operator substituted for :func:`torch.cond` while tracing
    (nodes using it are mapped back to ``torch.ops.higher_order.cond`` after
    tracing).

    The operator cannot be imported from ``torch.ops.higher_order.cond``
    because the function ``cond`` overwrites the submodule ``cond``.
    """

    def __init__(self):
        # we cannot use "cond" to avoid confusion with the existing cond
        super().__init__("condcc")

    def __call__(self, pred, true_fn, false_fn, operands):
        # Forwards the four arguments unchanged to HigherOrderOperator.__call__.
        # NOTE(review): operand validation below is currently disabled.
        # torch._higher_order_ops.utils.validate_subgraph_args_types(operands)
        return super().__call__(pred, true_fn, false_fn, operands)
@contextlib.contextmanager
def replace_problematic_function_before_tracing():
    """
    Context manager that temporarily swaps functions the default tracer
    cannot handle for traceable substitutes: :func:`torch.cat` becomes
    :meth:`CustomProxy.cat` and :func:`torch.cond` becomes a
    :class:`CondCCOp` instance. The original functions are put back on
    exit, even when the body raises.
    """

    def _install(mapping):
        # A tuple key ``(owner, attribute)`` targets an arbitrary object,
        # a plain string key targets an attribute of the ``torch`` module.
        for key, value in mapping.items():
            if isinstance(key, tuple):
                setattr(key[0], key[1], value)
            else:
                setattr(torch, key, value)

    originals = {
        "cat": torch.cat,
        "cond": torch.cond,
    }
    substitutes = {
        "cat": CustomProxy.cat,
        "cond": CondCCOp(),
    }
    _install(substitutes)
    try:
        yield
    finally:
        _install(originals)
class CustomTracer(torch.fx.Tracer):
    """
    Defines a custom tracer to trace the execution of a model
    and converts it into a fx graph.
    Works with :class:`CustomProxy
    <experimental_experiment.torch_interpreter.tracing.CustomProxy>`.

    ::

        from experimental_experiment.torch_interpreter.tracing import CustomTracer

        graph = CustomTracer().trace(model)
    """

    def __init__(
        self,
        autowrap_modules: Tuple[types.ModuleType, ...] = (math,),
        autowrap_functions: Tuple[Callable, ...] = (),
        param_shapes_constant: bool = False,
    ):
        super().__init__(
            autowrap_modules=autowrap_modules,
            autowrap_functions=autowrap_functions,
            param_shapes_constant=param_shapes_constant,
        )
        # Functions referenced by ``get_attr`` nodes (torch.cond branches),
        # keyed by their unique attribute name; attached to the traced
        # module at the end of :meth:`trace`.
        self._callables = {}

    def register_callable(self, name: str, fn: Callable) -> torch.fx.Node:
        """
        Registers a function under a unique attribute name and returns
        a ``get_attr`` node referring to it.

        :param name: prefix to prepend to the function name
        :param fn: function
        :return: a ``get_attr`` node whose target is the new unique name
            ``_cb_<name>_<fn.__name__>_<i>``
        """
        i = 0
        cand = f"_cb_{name}_{fn.__name__}_{i}"
        while cand in self._callables:
            i += 1
            cand = f"_cb_{name}_{fn.__name__}_{i}"
        self._callables[cand] = fn
        return self.create_node("get_attr", cand, args=(), kwargs={})

    def proxy(
        self, node: torch.fx.Node, cls: type[CustomProxy] = CustomProxy
    ) -> torch.fx.Proxy:
        """
        Overwrites this method to replace the default Proxy by CustomProxy.
        """
        return cls(node, self)

    def create_arg(self, a: Any) -> "Argument":  # noqa: F821
        """
        Overwrites this method to deal with more argument types:
        the Python types ``bool``, ``int``, ``float`` and ``complex``
        are mapped to ``torch.bool``, ``torch.int64``, ``torch.float32``
        and ``torch.complex64``.
        """
        if a is bool:
            return torch.bool
        if a is int:
            return torch.int64
        if a is float:
            return torch.float32
        if a is complex:
            return torch.complex64
        return super().create_arg(a)

    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """
        See :meth:`torch.fx.Tracer.getattr`.
        When ``param_shapes_constant`` is True, parameters and buffers are
        wrapped into :class:`CustomParameterProxy`.
        """

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                            kwargs["proxy_factory_fn"] = (
                                None
                                if not self.param_shapes_constant
                                else lambda node, n=n, attr_val=attr_val: CustomParameterProxy(
                                    self, node, n, attr_val
                                )
                            )
                        val_proxy = self.create_proxy(
                            "get_attr", n, (), {}, **kwargs
                        )  # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_parameters(), parameter_proxy_cache
            )
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_buffers(), parameter_proxy_cache
            )
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        return attr_val

    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        remove_inplace: bool = True,
        update_model_with_callable: bool = True,
    ) -> torch.fx.Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation.
        ``root`` can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root``
        passed in here. For example, when a free function is passed to ``trace()``,
        we will create an ``nn.Module`` instance to use as the root and add
        embedded constants to.

        :param root: Either a ``Module`` or a function to be traced through.
            Backwards-compatibility for this parameter is guaranteed.
        :param concrete_args: Concrete arguments that should not be treated as
            Proxies. This parameter is experimental and its backwards-compatibility
            is *NOT* guaranteed.
        :param remove_inplace: Removes inplace nodes
        :param update_model_with_callable: in some cases (control flow), the graph
            refers through ``get_attr`` nodes to functions registered with
            :meth:`register_callable`; if True, these functions are attached
            as attributes to the traced module
        :return: A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        with replace_problematic_function_before_tracing():
            graph = super().trace(root, concrete_args)
        self._replace_problematic_functions(graph)
        if update_model_with_callable and self._callables:
            # self.root is the module holding embedded constants: it is ``root``
            # itself for a module, a newly created module for a free function.
            # ``root`` keeps receiving the attributes as before.
            targets = [self.root] if self.root is root else [self.root, root]
            for k, v in self._callables.items():
                for target in targets:
                    setattr(target, k, v)
        if remove_inplace:
            self.remove_inplace(graph)
        graph.lint()
        return graph

    @classmethod
    def _replace_problematic_functions(cls, graph: torch.fx.Graph) -> int:
        """
        The tracing introduced some problematic functions which need to be replaced.

        :return: number of impacted nodes
        """
        replaces = {
            CustomProxy.cat: torch.cat,
            # CondCCOp: torch.ops.higher_order.cond,
        }
        n = 0
        for node in graph.nodes:
            if node.op == "call_function":
                if node.target in replaces:
                    n += 1
                    node.target = replaces[node.target]
                elif isinstance(node.target, CondCCOp):
                    n += 1
                    node.target = torch.ops.higher_order.cond
        return n

    @classmethod
    def _get_aten_name(cls, node: torch.fx.Node) -> str:
        """
        Returns the aten name for the target as a string.
        """
        if node.target == operator.getitem:
            return "getitem"
        if isinstance(node.target, torch._ops.OpOverloadPacket):
            if node.target != torch.ops.aten.sym_size:
                raise RuntimeError(f"Unsupported function {node!r}.")
            raise NotImplementedError(f"Unsupported function {node!r} (not implemented).")

        if isinstance(node.target, types.BuiltinFunctionType):
            return str(node.target)

        if isinstance(node.target, torch._ops.OpOverload):
            return node.target.name()

        if callable(node.target):
            # a single function
            return f"aten_{node.target.__name__}"

        if isinstance(node.target, str):
            return node.target

        raise NotImplementedError(
            f"Unsupported function {node!r} (not implemented), "
            f"node.target={node.target}, type is {type(node.target)}."
        )

    @classmethod
    def _inplace_nodes(cls, graph: torch.fx.Graph) -> List[Tuple[int, torch.fx.Node]]:
        """
        Returns the position and the node involved in inplace modifications:
        ``call_*`` nodes without any user, excluding ``getitem`` and a few
        side-effect-only aten functions.
        """
        return [
            (i, node)
            for i, node in enumerate(graph.nodes)
            if node.op != "output"
            and len(node.users) == 0
            and node.op.startswith("call_")
            and node.target not in {operator.getitem}
            and cls._get_aten_name(node)
            not in {
                "aten::_assert_scalar",
                "aten::sym_constrain_range_for_size",
                "aten::_log_api_usage_once",
                "aten::_enter_autocast",
                "aten::_set_grad_enabled",
            }
        ]

    @classmethod
    def _replace_meth_setitem(cls, graph: torch.fx.Graph) -> int:
        """
        The execution of ``op="call_method", target="__setitem__" `` returns None
        We replace it by ``op="call_function", target="operator.setitem"``.

        :return: number of impacted nodes
        """
        n = 0
        for node in graph.nodes:
            if node.op == "call_method" and node.target == "__setitem__":
                node.op = "call_function"
                node.target = operator.setitem
                n += 1
        return n

    @classmethod
    def _replace_getattr(cls, graph: torch.fx.Graph) -> int:
        """
        Nodes such as
        ``%_tensor_constant0_1 : [num_users=1] = get_attr[target=_tensor_constant0]``
        are part of the replacement in function ``replace_all_uses_with``.
        Let's remove the duplicates first.

        :return: number of impacted get_attr nodes
        """
        targets = {}
        to_replace = []
        for node in graph.nodes:
            if node.op == "get_attr":
                if node.target in targets:
                    # replacements
                    to_replace.append((node, targets[node.target]))
                else:
                    targets[node.target] = node
        if to_replace:
            for node, by in to_replace:
                node.replace_all_uses_with(by)
                graph.erase_node(node)
        return len(to_replace)

    @classmethod
    def remove_inplace(cls, graph: torch.fx.Graph) -> int:
        """
        Removes inplace operations.

        For every inplace node, all uses of its first argument located after
        that node are redirected to the inplace node itself.

        :return: number of inplace nodes removed (0 if there is none)
        """
        inplace = cls._inplace_nodes(graph)
        if len(inplace) == 0:
            # No inplace.
            return 0

        n_inplace = len(inplace)
        cls._replace_getattr(graph)
        cls._replace_meth_setitem(graph)

        # _replace_getattr may erase nodes, which shifts every later position.
        # Positions must therefore be computed again on the modified graph so
        # that they index ``existing_nodes`` correctly.
        inplace = cls._inplace_nodes(graph)
        existing_nodes = list(graph.nodes)

        def delete_user_cb(n, nodes_to_leave):
            return n not in nodes_to_leave

        for pos, node in reversed(inplace):
            assert node.target in {
                "add_",
                "div_",
                "mul_",
                "sub_",
                "mod_",
                operator.setitem,
            }, (
                f"Unsupported target {node.target!r} at position {pos}/{len(graph.nodes)}"
                f"\n--graph\n{graph}"
            )
            # We assume the first argument is the one modified inplace.
            new_name = node
            old_name = node.args[0]

            # class Node can be used as a key
            # We also assume a user is placed after this node.
            nodes_to_leave = set(existing_nodes[: pos + 1])

            # let's replace
            changed = old_name.replace_all_uses_with(
                new_name,
                delete_user_cb=lambda n, leave=nodes_to_leave: delete_user_cb(n, leave),
            )

            assert changed, (
                f"No change applied, the inplace node [{node}] at position {pos} "
                f"does not replace [{old_name}] in \n{graph}\n-- node to keep --"
                f"\n{nodes_to_leave}"
            )

        inplace = cls._inplace_nodes(graph)
        assert (
            len(inplace) == 0
        ), f"Inplace nodes remain at positions {sorted(_[0] for _ in inplace)} in\n{graph}"
        return n_inplace