Source code for experimental_experiment.torch_interpreter.dispatcher

import inspect
from typing import Any, Callable, Dict, List, Optional


[docs] class Dispatcher: """ Used to changes the way class :class:`DynamoInterpreter <experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter>` selects the function translating aten function or module. :param registered_functions: registered functions :param verbose: verbose """ def __init__(self, registered_functions: Dict[str, Callable], verbose: int = 0): self.registered_functions = registered_functions self.verbose = verbose def _get_function_name(self, name: Any) -> str: if isinstance(name, str): return name if isinstance(name, type(abs)): new_name = f"aten_{name.__name__.replace('.', '_')}" if new_name in self.registered_functions: return new_name lookup_names = ["__qualname__", "__name__"] for att in lookup_names: if hasattr(name, att): v = getattr(name, att).replace(".", "_") if v in self.registered_functions: return v return str(v)
[docs] def find_function(self, name: Any) -> Optional[Callable]: """ Finds the most suitable function to translate a function. :param name: function name or definition :return: the function or None if not found The signature of the returned function is similar to a function such as :func:`aten_elu <experimental_experiment.torch_interpreter._aten_functions.aten_elu>`. """ key = self._get_function_name(name) if key not in self.registered_functions: if self.verbose > 3: print( f"[Dispatcher.find_function] could not find a " f"function for key={key!r} with name={name!r}" ) return None return self.registered_functions[key]
[docs] def find_method(self, name: Any) -> Optional[Callable]: """ Finds the most suitable function to translate a method. :param name: method name or definition :return: the function or None if not found The signature of the returned function is similar to a function such as :func:`aten_elu <experimental_experiment.torch_interpreter._aten_functions.aten_elu>`. """ if name not in self.registered_functions: if self.verbose > 3: print(f"[Dispatcher.find_method] could not find a method for name={name!r}") return None return self.registered_functions[name]
[docs] def fallback( self, name: Any, fct: Optional[Callable], args: List[Any], kwargs: Dict[str, Any], builder: "GraphBuilder", # noqa: F821 ) -> Optional[Callable]: """ The function is called after the function converting an aten function into ONNX. *fct* is this function. It can be changed and just set when mapping was found. :param name: object or str :param fct: function found so far :param args: known arguments coming from the graph module :param kwargs: known named arguments coming from the graph module :param builder: GraphBuilder :return: callable """ return fct
[docs] class ForceDispatcher(Dispatcher): """ Implements a dispatcher which fails whenever there is no converting for a node in the fx graph. There is no fallback to the existing functions. When no function is found, an onnx node is added with a non standard domain. :param signatures: function used only for their signature mapping a name to a function in order to have parameter names :param verbose: verbose :param domain: domain of the added node :param version: version of the domain :param strict: when an input is not a tensor, it becomes a named parameter if strict is False :param only_registered: fails if a function is not found in signatures """ def __init__( self, signatures: Optional[Dict[str, Callable]] = None, verbose: int = 0, domain: str = "aten.lib", version: int = 1, strict: bool = False, only_registered: bool = False, ): super().__init__({}, verbose=verbose) self.signatures = signatures or {} self.domain = domain self.version = version self.strict = strict self.only_registered = only_registered self._process_signatures() @classmethod def _convert_into_type(cls, annotation): assert ( annotation is not None and annotation is not inspect._empty ), f"Unexpected annotation={annotation}" if annotation in (float, int, bool): return annotation if hasattr(annotation, "_name") and annotation._name == "List": assert len(annotation.__args__) == 1, f"Unexpected annotation {annotation}" assert annotation.__args__[0] in (float, int, bool), ( f"Unexpected annotation {annotation}, " f"annotation.__args__[0]={annotation.__args__[0]!r}" ) t = annotation.__args__[0] return lambda v, t=t: [t(_) for _ in v] raise RuntimeError(f"Unexpected annotation {annotation!r}") def _process_signature(self, f: Callable): args = [] kwargs = [] sig = inspect.signature(f) has_annotation = any( (p.annotation is not None and p.annotation is not inspect._empty) for p in sig.parameters.values() ) # If there is annotation, we assume every result = None # without annotation is an optional Tensor. for name, p in sig.parameters.items(): ann = p.annotation if p.default is inspect._empty: args.append(name) elif p.default is None: noann = p.annotation is None or p.annotation is inspect._empty if has_annotation and noann: args.append(name) elif not noann: kwargs.append( ( name, p.default, ( None if ann is inspect._empty or ann is None else self._convert_into_type(ann) ), ) ) else: raise RuntimeError( f"Unable to determine if parameter {name!r} " f"is an input or a parameter, annotation is {p.annotation}, " f"default is {p.default!r} for function {f}, " f"has_annotation={has_annotation}" ) else: kwargs.append( ( name, p.default, ( None if ann is inspect._empty or ann is None else self._convert_into_type(ann) ), ) ) return args, kwargs def _process_signatures(self): self.sigs_ = {} for k, v in self.signatures.items(): sig = self._process_signature(v) self.sigs_[k] = sig
[docs] def fallback( self, name: Any, fct: Optional[Callable], args: List[Any], kwargs: Dict[str, Any], builder: "GraphBuilder", # noqa: F821 ) -> Optional[Callable]: """ The function is called after the function converting an aten function into ONNX. *fct* is this function. It can be changed and just set when mapping was found. :param name: object or str :param fct: function found so far :param args: known arguments coming from the graph module :param kwargs: known named arguments coming from the graph module :param builder: GraphBuilder :return: callable """ if fct is not None: # The conversion has been found. return fct fname = self._get_function_name(name) def wrapper( g, sts, outputs, *args, _name=fname, _domain=self.domain, _version=self.version, _only_registered=self.only_registered, **kwargs, ): sig = self.sigs_.get(_name, None) assert ( not _only_registered or sig is not None ), f"Unable to find a function with {_name!r}{g.get_debug_msg()}" kwargs = kwargs.copy() new_args = [] for i, n in enumerate(args): if isinstance(n, str): new_args.append(n) continue if isinstance(n, g.torch.Tensor): init = g.make_initializer( "", n, source=f"ForceDispatcher.fallback.wrapper.{i}" ) new_args.append(init) continue if not sig: if self.strict: raise RuntimeError( f"Unsupported type {type(n)} for argument {i} " f"for function {_name!r}{g.get_debug_msg()}" ) kwargs[f"param_{i}"] = n continue a, kw = sig if n is None and i < len(a): # An optional input. new_args.append("") continue assert i >= len(a), ( f"Unsupported type {type(n)} for argument {i} for function {_name!r}" f"sig={sig}, {g.get_debug_msg()}" ) ni = i - len(a) assert ni < len(kw), ( f"Unexpected argument at position {i}, for function {_name!r}" f"sig={sig}{g.get_debug_msg()}" ) p = kw[ni] kwargs[p[0]] = n if p[2] is None else p[2](n) # for some arguments given as named arguments new_kwargs = {} for k, v in kwargs.items(): if isinstance(v, g.torch.fx.node.Node): new_args.append(v.name) continue new_kwargs[k] = v # Let's get rid of the empty name at the end of the inputs. i = len(new_args) - 1 while i >= 0 and new_args[i] == "": i -= 1 new_args = new_args[: i + 1] if i >= 0 else [] g.add_domain(_domain, _version) g.make_node( _name, new_args, outputs=outputs, domain=_domain, name=g.unique_node_name(_name), **new_kwargs, ) if len(outputs) == 1: return outputs[0] return tuple(outputs) return wrapper