Source code for experimental_experiment.torch_interpreter.interpreter

import os
import inspect
import operator
import pprint
import types
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from onnx import TensorProto
from ..helpers import string_type, make_hash
from ..xbuilder import GraphBuilder, FunctionOptions, VirtualTensor
from ..xbuilder._shape_helper import all_int, DYNAMIC_SHAPE
from ..xbuilder._dtype_helper import (
    torch_dtype_to_onnx_dtype,
    onnx_dtype_to_torch_dtype,
)
from ..xbuilder.model_container import _get_type
from ..xbuilder.expression_dimension import parse_expression_tokens
from . import LOCAL_DOMAIN
from .export_options import ExportOptions
from ._exceptions import FunctionNotFoundError
from .aten_functions import find_function
from .aten_methods import find_method


class DynamoInterpreter:
    """
    Interprets a torch graph into an ONNX graph.
    Dispatches every node to the appropriate converting function.

    :param graph_builder: a graph builder
    :param retriever: callable to help retrieve the weights in a module, see function
        `_retrieve <experimental_experiment.torch_interpreter.onnx_export._retrieve>`.
    :param dispatcher: see :class:`experimental_experiment.torch_interpreter.Dispatcher`
    :param export_options: see :class:`ExportOptions
        <experimental_experiment.torch_interpreter.ExportOptions>`
    :param optimize_submodules: optimizes submodules after they are built
    :param submodule_naming: a function which returns a submodule name in the onnx graph
    :param parameter_naming: a function which returns a parameter name in the onnx graph
    :param module_name: module name (makes it easier to retrieve the parameter names)
    """

    def _hash(self) -> str:
        return make_hash(self)

    def __init__(
        self,
        graph_builder: GraphBuilder,
        retriever: Callable,
        dispatcher: Optional["Dispatcher"] = None,  # noqa: F821
        example_inputs: Optional[Tuple["torch.Tensor", ...]] = None,  # noqa: F821
        export_options: Optional[ExportOptions] = None,
        optimize_submodules: bool = False,
        function_options: Optional[FunctionOptions] = None,
        submodule_naming: Optional[Callable] = None,
        parameter_naming: Optional[Callable] = None,
        module_name: Optional[str] = None,
        default_values: Optional[Dict[str, Any]] = None,
    ):
        import torch

        from ..xbuilder import FunctionOptions

        self.torch = torch
        self.builder = graph_builder
        self.retriever = retriever
        self.dispatcher = dispatcher
        self.export_options = export_options
        self.optimize_submodules = optimize_submodules
        self.function_options = function_options or FunctionOptions(
            name="*",
            domain="*",
            export_as_function=True,
            external_threshold=256,
            move_initializer_to_constant=True,
            return_initializer=True,
            merge_allowed=True,
            rename_allowed=True,
        )
        self.example_values_ = {}
        assert example_inputs is None or isinstance(
            example_inputs, tuple
        ), f"Unexpected type for example_inputs {type(example_inputs)}"
        assert example_inputs is None or all(
            (
                t is None
                or isinstance(
                    t,
                    (
                        torch.SymInt,
                        torch.SymFloat,
                        torch.Tensor,
                        list,
                        int,
                        float,
                        VirtualTensor,
                    ),
                )
                or t.__class__.__name__
                in {"DynamicCache", "MambaCache", "patched_DynamicCache"}
            )
            for t in example_inputs
        ), (
            f"Unexpected type for one input in example_inputs "
            f"{[type(t) for t in example_inputs]}"
        )
        self.example_inputs_ = example_inputs
        self.flat_example_inputs_ = self.flatten_inputs(example_inputs)
        self.current_input_ = 0
        self.preserved_modules = set()
        self.parent_interpreter = None
        self.parameter_naming = parameter_naming
        self.submodule_naming = submodule_naming
        self.module_name = module_name
        self.default_values = default_values or {}
        self._debug_aten_as_function = int(os.environ.get("ATENDEBUG", "0"))

    def register_named_modules(
        self,
        parent_interpreter: Optional["DynamoInterpreter"],
        preserved_modules: Optional[Set[type["torch.nn.Module"]]],  # noqa: F821
        named_modules: Dict[str, "torch.nn.Module"],  # noqa: F821
    ):
        """
        Registers a list of modules to preserve as local functions in the onnx model.
        If empty, the graph is almost inlined.
        The module to convert to onnx should be the output of method
        :func:`torch.export.unflatten.unflatten`.
        """
        assert parent_interpreter is None or isinstance(
            parent_interpreter, DynamoInterpreter
        ), f"Unexpected type {type(parent_interpreter)} for the interpreter"
        if self.builder.verbose > 4 and preserved_modules:
            print(
                f"[DynamoInterpreter-{self._hash()}.register] "
                f"{sorted(c.__name__ for c in preserved_modules)}"
            )
        self.named_modules = named_modules
        self.preserved_modules = preserved_modules or parent_interpreter.preserved_modules
        if parent_interpreter is not None:
            self.submodule_naming = parent_interpreter.submodule_naming
            self.parameter_naming = parent_interpreter.parameter_naming
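
    # Illustrative sketch, not part of the original source: preserving selected
    # submodule classes so they are exported as local ONNX functions instead of
    # being inlined. ``model``, ``args``, ``MyBlock`` and ``interp`` are
    # hypothetical names.
    #
    #     ep = torch.export.export(model, args)
    #     unflat = torch.export.unflatten(ep)
    #     interp.register_named_modules(
    #         parent_interpreter=None,
    #         preserved_modules={MyBlock},
    #         named_modules=dict(unflat.named_modules()),
    #     )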

    def flatten_inputs(self, x: Any) -> List["torch.Tensor"]:  # noqa: F821
        """
        Flattens the inputs.
        """
        if x is None:
            return x
        if isinstance(x, (list, tuple)):
            res = []
            for i in x:
                if i is None or isinstance(
                    i,
                    (
                        self.torch.Tensor,
                        self.torch.SymInt,
                        self.torch.SymFloat,
                        int,
                        float,
                        VirtualTensor,
                    ),
                ):
                    res.append(i)
                else:
                    res.extend(self.flatten_inputs(i))
            return tuple(res) if isinstance(x, tuple) else res
        if x.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
            res = self.flatten_inputs(x.key_cache) + self.flatten_inputs(x.value_cache)
            return tuple(res)
        if x.__class__.__name__ in ("MambaCache",):
            return (x.conv_states, x.ssm_states)
        raise AssertionError(f"Unexpected type {type(x)} for x")
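
    # Illustrative sketch, not part of the original source: assuming ``interp`` is a
    # DynamoInterpreter instance, ``flatten_inputs`` turns a nested structure of
    # tensors and scalars into a flat tuple (or list):
    #
    #     interp.flatten_inputs((torch.randn(2, 3), [torch.randn(4), 1.5]))
    #     # -> (tensor of shape (2, 3), tensor of shape (4,), 1.5)
    #
    # DynamicCache inputs contribute their key_cache and value_cache tensors,
    # MambaCache inputs their conv_states and ssm_states.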

    def run_node(self, node: "torch.fx.Node"):  # noqa: F821
        """
        Runs a node: calls the appropriate method based on the node type.
        """
        example_value = None
        if hasattr(node, "meta") and "example_value" in node.meta:
            if isinstance(node.target, str) or callable(node.target):
                self.example_values_[node.target] = node.meta["example_value"]
                example_value = self.example_values_[node.target]
            else:
                raise RuntimeError(
                    f"Unexpected type {type(node.target)} "
                    f"for node.target in {node}, op={node.op}, "
                    f"node.target={node.target}, node.meta={node.meta}."
                )

        if self.builder.verbose > 1:
            # verbose
            exa = (
                f"{torch_dtype_to_onnx_dtype(example_value.dtype)}'{tuple(example_value.shape)}"
                if hasattr(example_value, "dtype")
                else ""
            )
            v = node.meta.get("val", None) if hasattr(node, "meta") else None
            val = (
                f"{torch_dtype_to_onnx_dtype(v.dtype)}'{tuple(v.shape)}"
                if hasattr(v, "dtype")
                else ""
            )
            symbol = "#" if self._can_set_shape_and_type(node) else "-"
            a1 = "E" if hasattr(node, "meta") and "example_value" in node.meta else "-"
            a2 = "A" if hasattr(node, "meta") and "val" in node.meta else "-"
            print(
                f"[DynamoInterpreter-{self._hash()}.run_node][{symbol}{a1}{a2}] "
                f"{node.op}:{node.name}:{exa}:{val}"
            )

        # debug
        exa = (
            ("example_value", example_value.dtype, example_value.shape)
            if hasattr(example_value, "dtype")
            else ""
        )
        v = node.meta.get("val", None) if hasattr(node, "meta") else None
        val = ("val", v.dtype, v.shape) if hasattr(v, "dtype") else ""
        self.builder.set_shapes_types(node.name, "run_node", (exa, val))
        self.builder.register_users(node.name, node.users)

        if node.op == "placeholder":
            res = self.placeholder(node)
        elif node.op == "call_function":
            res = self.call_function(node)
        elif node.op == "output":
            res = self.output(node)
        elif node.op == "call_module":
            self.builder._check_constants(f"before-{node.op}")
            res = self.call_module(node)
            self.builder._check_constants(f"after-{node.op}")
        elif node.op == "get_attr":
            res = self.get_attr(node)
        elif node.op == "call_method":
            res = self.call_method(node)
        else:
            raise ValueError(f"Unable to process node kind {node.op!r} ({node}).")

        # Checks consistency of shapes and types.
        name = node.name
        if val and len(val) == 3:
            exp_dtype, exp_shape = val[1:]
            if isinstance(exp_dtype, int):
                exp_dtype = onnx_dtype_to_torch_dtype(exp_dtype)
            if self.builder.has_type(name):
                itype = self.builder.get_type(name)
                ttype = onnx_dtype_to_torch_dtype(itype)
                aten_name = self._get_aten_name(node) if node.op == "call_function" else "-"
                assert ttype == exp_dtype, (
                    f"Type mismatch for {name!r}, node.op={node.op!r}, "
                    f"aten_name={aten_name!r}, "
                    f"onnx {ttype} != expected torch "
                    f"{exp_dtype}{self.builder.get_debug_msg()}"
                )
            if self.builder.has_shape(name):
                shape = self.builder.get_shape(name)
                self.builder._check_two_shapes_are_compatible(
                    tuple(exp_shape),
                    shape,
                    name=name,
                    register_int=False,
                )
        return res
[docs] def get_attr(self, node: "torch.fx.Node"): # noqa: F821 """ Retrieves an attribute. """ if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.get_attr][{node.name}]") try: init = getattr(node.graph.owning_module, node.target) except AttributeError as e: # Maybe it is a parameter: init = None for name, p in node.graph.owning_module.named_parameters(): if name == node.target: init = p if init is None: raise AttributeError( f"Unable to find attribute {node.target!r} (node.name={node.name!r}) in " f"type(owning_module)={type(node.graph.owning_module)}, " f"\nmodules=" f"{sorted([_[0] for _ in node.graph.owning_module.named_modules()])}" f"\nparameters=" f"{sorted([_[0] for _ in node.graph.owning_module.named_parameters()])}" f"\nnode.__dict__={node.__dict__}{self.builder.get_debug_msg()}" ) from e if isinstance(init, self.torch.fx.GraphModule): # This function is meant to be used later. if "." in self.builder.local_domain: root, n = self.builder.local_domain.split(".") n = int(n) + 1 else: root, n = self.builder.local_domain, 0 builder, _args, _kwargs, _output_names = self._interpret_sub_module( init, None, None, source_node=node, local_domain=f"{root}.{n}" ) self.builder.make_local_function( builder, function_options=FunctionOptions( name=node.name, domain=self.builder.local_domain, export_as_function=True, return_initializer=True, move_initializer_to_constant=self.function_options.move_initializer_to_constant, external_threshold=self.function_options.external_threshold, merge_allowed=self.function_options.merge_allowed, rename_allowed=self.function_options.rename_allowed, ), optimize=self.optimize_submodules, ) return None parameter_name = ( self.parameter_naming(node.name, init, node=node, prefix=self.module_name) if isinstance(init, self.builder.torch.nn.Parameter) else None ) self.builder.make_initializer( node.name, init, parameter_name=parameter_name, source=( f"DynamoInterpret.get_attr.1/P({parameter_name})" if parameter_name else "DynamoInterpret.get_attr.0" ), ) return node.name
def _make_tensor_check(self, name: str, fake_tensor: bool, users: Any): if ( not fake_tensor and self.example_inputs_ is not None and not self.builder.was_inputs_renamed ): assert len(self.builder.input_names) < len(self.flat_example_inputs_), ( f"Too many inputs already ({len(self.builder.input_names)}), " f"self.current_input_={self.current_input_}, " f"unexpected {name!r} " f"after {self.builder.input_names}" f"{self.builder.get_debug_msg()}" ) if ( not self.builder.as_function and self.flat_example_inputs_[self.current_input_] is None ): # We skip it. assert users is None or len(users) == 0, ( f"Input {name!r} (index {self.current_input_}" f"/{len(self.flat_example_inputs_)}) " f"is None but it is used by {users}, " f"as_function={self.builder.as_function}. " f"Existing inputs {self.builder.input_names}. Example inputs: " f"{['-' if t is None else t.shape for t in self.flat_example_inputs_]}" f"{self.builder.get_debug_msg()}" ) self.current_input_ += 1 return "" # second check assert self.builder.as_function or len(self.builder.input_names) < len( tuple(t for t in self.flat_example_inputs_ if t is not None) ), ( f"Too many inputs already ({len(self.builder.input_names)}), " f"unexpected {name!r} " f"after {self.builder.input_names}" f"{self.builder.get_debug_msg()}" ) return None def _make_tensor_input( self, name: str, elem_type: Any, shape: DYNAMIC_SHAPE, is_dimension: bool, users: Iterable[str], fake_tensor: bool = False, default_initializer: Optional[Any] = None, ) -> str: ret = self._make_tensor_check(name, fake_tensor, users) if ret is not None: return ret shape = self.builder.get_input_dynamic_shape(name, self.current_input_, shape) self.current_input_ += 1 return self.builder.make_tensor_input( name, elem_type, shape, is_dimension=is_dimension, default_initializer=default_initializer, marker="DynamoInterpreter._make_tensor_input", ) def _make_list_input( self, name: str, example_value: List["torch.Tensor"], # noqa: F821 users: Iterable[str], fake_tensor: bool = False, ) -> str: ret = self._make_tensor_check(name, fake_tensor, users) if ret is not None: return ret assert all(isinstance(t, self.torch.Tensor) for t in example_value), ( f"Input {name!r}, unexpected type in example_value: " f"{string_type(example_value)}{self.get_debug_msg()}" ) assert len(set(t.dtype for t in example_value)) == 1, ( f"Input {name!r}, multiple element type in example_value " f"{[t.dtype for t in example_value]}{self.get_debug_msg()}" ) shape = self.builder.get_input_dynamic_shape( name, self.current_input_, example_shape=None, example_value=example_value ) assert isinstance(shape, list) and len(shape) == 1, ( f"For a sequence, shapes should be specified as a list of 1 element, " f"shape={string_type(shape)}{self.builder.get_debug_msg()}" ) elem_type = _get_type(example_value[0].dtype) self.current_input_ += 1 return self.builder.make_tensor_sequence_input( name, elem_type, shape[0], marker="DynamoInterpreter._make_list_input", )
[docs] def placeholder(self, node: "torch.fx.Node"): # noqa: F821 """ placeholder for an input. The interpreter adds an Identity node between the input names he wants and the name it has in the graph module. """ if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.placeholder][{node.name}]") val = node.meta.get("val", None) _msg = lambda _=self: _.builder.get_debug_msg() # noqa: E731 if self.builder.verbose > 2: print( f"[DynamoInterpreter-{self._hash()}.placeholder]" f"[{node.name}] val={string_type(val)}" ) if val is None: example_value = node.meta.get("example_value", None) if self.builder.verbose > 2: print( f"[DynamoInterpreter-{self._hash()}.placeholder]" f"[{node.name}] example_value={string_type(val)}" ) # index_input may be wrong because torch.export.export may flatten the inputs. # gathering the default value may not be optimal here. if example_value is None and node.name in self.default_values: example_value = self.default_values[node.name] if self.builder.as_function and example_value is None: return self._make_tensor_input( node.name, None, None, is_dimension=False, users=node.users ) if example_value is None: # The input is not defined. # We return. self.current_input_ += 1 return if isinstance( example_value, (self.builder.torch.SymInt, self.builder.torch.SymFloat) ): # torch.SymInt self.builder.make_dynamic_object(node.name, example_value) return self._make_tensor_input( node.name, elem_type=self.builder.torch.int64, shape=(1,), is_dimension=True, users=node.users, ) if isinstance(example_value, (int, float)): # int or float return self._make_tensor_input( node.name, elem_type=( self.builder.torch.int64 if isinstance(example_value, int) else self.builder.torch.float32 ), shape=(1,), is_dimension=False, users=node.users, ) if isinstance(example_value, (self.torch.Tensor, VirtualTensor)): return self._make_tensor_input( node.name, elem_type=example_value.dtype, shape=example_value.shape, is_dimension=False, users=node.users, ) if isinstance(example_value, list) and all( isinstance(t, self.torch.Tensor) for t in example_value ): return self._make_list_input(node.name, example_value, users=node.users) if example_value.__class__.__name__ == "DynamicCache": import transformers assert isinstance(example_value, transformers.cache_utils.DynamicCache), ( f"Unexpected type {type(example_value)} for an input" f"{self.builder.get_debug_msg()}" ) if not example_value.key_cache: # The cache is empty. We create a dummy input with a default value return self._make_tensor_input( node.name, elem_type=np.float32, shape=(1,), is_dimension=False, users=None, default_initializer=np.array([0], dtype=np.float32), ) raise NotImplementedError( f"Unable to create an input {node.name!r} " f"with type {string_type(example_value)}" f"{self.builder.get_debug_msg()}" ) if isinstance(val, (self.torch.Tensor, self.torch._subclasses.fake_tensor.FakeTensor)): stack_trace = node.meta.get("stack_trace", None) value = None if stack_trace is None and "from_node" not in node.meta: # torch 2.1.0 and 2.2.0 behave differently. 
# torch 2.4.0, stack_trace is None but from_node is in node.meta value = self.retriever(node.target, val, debug={"node": node}, exc=False) if value is None: return self._make_tensor_input( node.name, elem_type=val.dtype, shape=val.shape, is_dimension=False, users=node.users, fake_tensor=isinstance( val, self.torch._subclasses.fake_tensor.FakeTensor ), ) if value is None: if "nn_module_stack" not in node.meta: value = self.retriever(node.target, val, debug={"node": node}) if value is None: return self._make_tensor_input( node.name, elem_type=val.dtype, shape=val.shape, is_dimension=False, users=node.users, ) else: value = self.retriever(node.target, val, debug={"node": node}, exc=False) if value is None: # This is probably one input then. return self._make_tensor_input( node.target, elem_type=val.dtype, shape=val.shape, is_dimension=False, users=node.users, ) if value is None or isinstance( value, self.torch._subclasses.fake_tensor.FakeTensor ): if ".FakeTensor" in str(type(val)): dtype = val.dtype shape = val.shape return self._make_tensor_input( node.name, dtype, shape, False, users=node.users, fake_tensor=True ) raise RuntimeError(f"value is None, unable to retrieve target {node.target!r}") parameter_name = ( self.parameter_naming(node.name, value, node=node, msg=_msg) if isinstance(value, self.builder.torch.nn.Parameter) else None ) return self.builder.make_initializer( node.name, value, parameter_name=parameter_name, source=( f"DynamoInterpret.placeholder.1/P({parameter_name})" if parameter_name else "DynamoInterpret.placeholder.0" ), ) if isinstance(val, (self.torch.SymInt, self.torch.SymFloat)): return self.builder.make_dynamic_object(node.name, val, shape_as_input=True) if isinstance(val, (int, float)): # scalar input return self._make_tensor_input( node.name, elem_type=TensorProto.INT64 if isinstance(val, int) else TensorProto.FLOAT, shape=(1,), is_dimension=False, users=node.users, ) if isinstance(val, VirtualTensor): return self._make_tensor_input( node.name, elem_type=val.dtype, shape=val.shape, is_dimension=False, users=node.users, ) raise RuntimeError( f"Unsupported type {type(val)} for placeholder " f"{getattr(node, 'target', '?')}{self.builder.get_debug_msg()}." )
[docs] def output(self, node): """ Adds an output to the graph. """ output_name = node.name if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.output][{output_name}]") declared = node.args assert len(declared) == 1, ( f"declared must have one element: {declared}, output_name={output_name}" f"{self.builder.get_debug_msg()}" ) output = declared[0] if hasattr(output, "name"): output = output.name self.builder.make_node( "Identity", [output], [output_name], check=False, name=".output" ) outputs = [(output, output_name)] else: outputs = [] for i, a in enumerate(output): if a is None: a_name = None o = f"{output_name}_{i}" cst = None elif isinstance(a, int): # The model seems to return an integer. o = f"{output_name}_INT_{i}" a_name = None cst = self.builder.make_node( "Constant", [], [o], value_int=a, name=".output_INT_{a}" ) self.builder.set_type(o, TensorProto.INT64) self.builder.set_shape(o, tuple()) else: cst = None a_name = a if isinstance(a, str) else a.name if self.builder.get_is_dimension(a_name, n_outputs=len(output)): o = f"{output_name}_dim_{i}" else: o = f"{output_name}_{i}" if a_name is None: # the gradient may need unused output if cst is None: o = f"{output_name}_NONE_{i}" self.builder.make_node( "Constant", [], [o], value_float=0.0, name=".output_NONE" ) self.builder.set_type(o, TensorProto.FLOAT) self.builder.set_shape(o, tuple()) outputs.append((None, o)) else: self.builder.make_node( "Identity", [a_name], [o], check=False, name=".output" ) outputs.append((a_name, o)) val = node.meta.get("val", None) if isinstance(val, tuple): assert len(val) == 1, ( f"output not yet implemented for multiple outputs, node={node}" f"{self.builder.get_debug_msg()}" ) val = val[0] if val is None: for a, o in outputs: if a is None: assert not self.builder.is_sequence(o), ( f"Output sequences are not implemented but {o!r} is one" f"{self.builder.get_debug_msg()}" ) elem_type = self.builder.get_type(o) shape = self.builder.get_shape(o) else: assert not self.builder.is_sequence(a), ( f"Output sequences are not implemented but {a!r} is one" f"{self.builder.get_debug_msg()}" ) elem_type = self.builder.get_type(a) if self.builder.has_shape(a): shape = self.builder.get_shape(a) elif self.builder.has_rank(a): shape = tuple([None] * self.builder.get_rank(a)) elif self.builder.as_function: shape = None else: raise RuntimeError( f"val is None for node={node}, " f"output={output}, a={a!r}, o={o!r}, " f"has_type={self.builder.has_type(a)}, " f"has_rank={self.builder.has_rank(a)}, " f"has_shape={self.builder.has_shape(a)}, " f"\nmeta={node.meta}" f"\nnode.__dict__={node.__dict__}" f"{self.builder.get_debug_msg()}" ) # let's avoid none ns = [] for i, d in enumerate(shape): if d is None: d = f"d_{o}_{i}" self.builder.make_dynamic_object(d, self.torch.SymInt(d)) ns.append(d) shape = tuple(ns) is_dimension = self.builder.get_is_dimension( a or o, elem_type=elem_type, shape=shape, n_outputs=len(outputs) ) self.builder.make_tensor_output( o, elem_type=elem_type, shape=shape, indexed=False, is_dimension=is_dimension, ) return [_[1] for _ in outputs] if isinstance(val, self.torch.Tensor): n_outputs = len(self.builder.outputs) output_name = f"{node.name}_{n_outputs}" shape = val.shape dtype = _get_type(val.dtype) self.builder.make_tensor_output(output_name, dtype, shape) return output_name raise TypeError(f"Unexpected output type {type(val)}.")
def _fill_in_default_kwargs( self, node: "torch.fx.Node", # noqa: F821 ) -> Tuple[List[Any], Dict[str, Any]]: if hasattr(node.target, "_schema"): node_schema = node.target._schema else: node_schema = None complete_args = [] complete_kwargs = {} if inspect.isbuiltin(node.target) or not node_schema: complete_args = list(node.args) complete_kwargs = {} for k, v in node.kwargs.items(): if isinstance(v, self.torch.fx.Node): complete_kwargs[k] = v.name elif v is None: complete_kwargs[k] = None elif isinstance(v, (int, float, str, self.torch.device, self.torch.dtype)): complete_kwargs[k] = v elif isinstance(v, self.torch.fx.immutable_collections.immutable_list) and all( isinstance(el, self.torch.fx.Node) for el in v ): complete_kwargs[k] = [t.name for t in v] else: raise AssertionError( f"Unexpected type {type(v)} for k={k!r} (v={v!r})" f"{self.builder.get_debug_msg()}" ) else: for i, expected_arg in enumerate(node_schema.arguments): if i < len(node.args): complete_args.append(node.args[i]) elif expected_arg.name in node.kwargs: complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name] else: # Get default from schema. complete_kwargs[expected_arg.name] = expected_arg.default_value return complete_args, complete_kwargs def _get_aten_name(self, node: "torch.fx.Node") -> str: # noqa: F821 if node.target == operator.getitem: return "getitem" if isinstance(node.target, self.torch._ops.OpOverloadPacket): if node.target != self.torch.ops.aten.sym_size: raise RuntimeError(f"Unsupported function {node!r}.") raise NotImplementedError(f"Unsupported function {node!r} (not implemented).") if isinstance(node.target, types.BuiltinFunctionType): return node.target if isinstance(node.target, self.torch._ops.OpOverload): return node.target if callable(node.target): # a single function return f"aten_{node.target.__name__}" raise NotImplementedError( f"Unsupported function {node!r} (not implemented), " f"node.target={node.target}, type is {type(node.target)}." ) def _getitem_slice( self, node: "torch.fx.Node", # noqa: F821 input_name: str, index_slice: slice, sts: Optional[Dict[str, Any]], axes: List[int], expand_axes: List[int], name: str = "_getitem_slice", ): assert isinstance(axes, list), f"Unexpected type {type(axes)} for axes" assert all_int(axes), f"Expected only integer axis but got {axes}" assert len(axes) == len( index_slice ), f"Length mismatch {len(axes)} != {len(index_slice)}" # axes aaxes = np.array(axes, dtype=np.int64) axes_name = self.builder.unique_name(f"{node.name}_axis") self.builder.make_initializer( axes_name, aaxes, source="DynamoInterpreter._getitem_slice.axis.1" ) shape_value = None if self.builder.has_shape(input_name): shape_value = self.builder.get_shape(input_name) starts = [] ends = [] steps = [] shape_name = None end_name = None concat = False for axis_, aslice in zip(axes, index_slice): axis = axis_ if isinstance(aslice, int): # integer starts.append(aslice) ends.append(aslice + 1) steps.append(1) continue assert isinstance( aslice, (slice, int, self.torch.fx.Node) ), f"Unexpected type {type(aslice)} ({aslice}) in {index_slice}" assert isinstance(aslice, slice), ( f"One index is given as an integer {aslice!r} but this requires " f"to append a node 'Squeeze' after this one and this is not yet " f"implemented. 
You can replace the integer by `i:i+1`" f"{self.builder.get_debug_msg()}" ) starts.append(aslice.start or 0) if aslice.stop is None: if shape_value is None or not isinstance(shape_value[axis], int): if shape_name is None: shape_name = self.builder.unique_name(f"{node.name}_shape") self.builder.make_node( "Shape", [input_name], [shape_name], name=f"{name}A" ) aaxis = np.array([axis], dtype=np.int64) axis_name = self.builder.unique_name(f"{node.name}_axis_{axis}") self.builder.make_initializer( axis_name, aaxis, source="DynamoInterpreter._getitem_slice.axis.2" ) end_name = self.builder.unique_name(f"{node.name}_end") self.builder.make_node( "GatherElements", [shape_name, axis_name], [end_name], name=f"{name}B", sts=None, ) ends.append(end_name) concat = True else: ends.append(shape_value[axis]) else: vstop = aslice.stop.name if hasattr(aslice.stop, "name") else aslice.stop concat |= isinstance(vstop, str) ends.append(vstop) steps.append(aslice.step if aslice.step else 1) # if concat: one end is coming from a shape if concat: iends = [] for i in ends: if isinstance(i, str): if self.builder.get_rank(i) == 0: iends.append( self.builder.op.UnsqueezeAnyOpset( i, np.array([0], dtype=np.int64), name=f"{name}C" ) ) else: assert self.builder.get_rank(i) == 1, ( f"Unexpected rank={self.builder.get_rank(i)} for {i!r}" f"{self.builder.get_debug_msg()}" ) iends.append(i) else: assert isinstance( i, int ), f"Unexpected value for end={i!r}{self.builder.get_debug_msg()}" iends.append(np.array([i], dtype=np.int64)) if len(iends) > 1: conc_ends = self.builder.op.Concat(*iends, axis=0, name=f"{name}D") else: conc_ends = self.builder.op.Identity(iends[0], name=f"{name}E") else: assert all_int(ends), ( f"Unexpected value for ends={ends}: {[type(_) for _ in ends]}" f"{self.builder.get_debug_msg()}" ) conc_ends = self.builder.make_initializer( "", np.array(ends, dtype=np.int64), source="DynamoInterpreter._getitem_slice.1" ) assert all_int(steps), ( f"Not implemented for steps={steps} (types are " f"{[type(c) for c in steps]}){self.builder.get_debug_msg()}" ) if all_int(starts): conc_starts = self.builder.make_initializer( self.builder.unique_name(f"{node.name}_start"), np.array(starts, dtype=np.int64), source="DynamoInterpreter._getitem_slice.2", ) else: istarts = [] for i in starts: si = i.name if hasattr(i, "name") else i if isinstance(si, str): if self.builder.get_rank(si) == 0: istarts.append( self.builder.op.UnsqueezeAnyOpset( si, np.array([0], dtype=np.int64), name=f"{name}C" ) ) else: assert self.builder.get_rank(si) == 1, ( f"Unexpected rank={self.builder.get_rank(i)} for {si!r}" f"{self.builder.get_debug_msg()}" ) istarts.append(si) else: assert isinstance( si, int ), f"Unexpected value for end={si!r}{self.builder.get_debug_msg()}" istarts.append(np.array([si], dtype=np.int64)) if len(istarts) > 1: conc_starts = self.builder.op.Concat(*istarts, axis=0, name=f"{name}SD") else: conc_starts = self.builder.op.Identity(istarts[0], name=f"{name}SE") inputs = [ input_name, conc_starts, conc_ends, axes_name, self.builder.make_initializer( self.builder.unique_name(f"{node.name}_step"), np.array(steps, dtype=np.int64), source="DynamoInterpreter._getitem_slice.3", ), ] if expand_axes: sliced = self.builder.make_node("Slice", inputs, name=f"{name}F") res = self.builder.op.UnsqueezeAnyOpset( sliced, np.array(expand_axes, dtype=np.int64), outputs=[node.name], name=f"{name}F", ) else: res = self.builder.make_node("Slice", inputs, [node.name], name=f"{name}G") if not sts: dtype = self.builder.get_type(inputs[0]) 
self.builder.set_type(node.name, dtype) if not concat and self.builder.has_shape(inputs[0]): shape = self.builder.get_shape(inputs[0]) new_shape = self.builder._apply_slice_to_shape( shape, index_slice, axes=axes, expand_axes=expand_axes ) assert not self.builder.has_shape( node.name ) or new_shape == self.builder.get_shape(node.name), ( f"Shape for node {node.name!r} is already set to " f"{self.builder.get_shape(node.name)} with type " f"{self.builder.get_type(node.name)} (expecting {dtype}) " f"new_shape={new_shape}, shape={shape}, index_slice={index_slice}, " f"axes={axes}, expand_axes={expand_axes}" f"{self.builder.get_debug_msg()}" ) self.builder.set_shape(node.name, new_shape) elif expand_axes: self.builder.set_rank( node.name, self.builder.get_rank(inputs[0]) + len(expand_axes) ) return res def _getitem_int1( self, node: "torch.fx.Node", # noqa: F821 input_name: str, indices: List[int], sts: Optional[Dict[str, Any]], axes: List[int], expand_axes: List[int], name: str = "_getitem_int1", ): from ._aten_functions import _aten_tensor_int1 return _aten_tensor_int1( self.builder, sts, [node.name], input_name, indices, axes=axes, expand_axes=expand_axes, name=name, )
[docs] def getitem(self, node: "torch.fx.Node"): # noqa: F821 """ Called when the brackets ``something[...]`` appears. The index may be another variable, an integer, a slice, a tuple, a list. """ if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.getitem]") args = node.args assert len(args) == 2 node_output, index = args result_name = node_output.name val = node.meta.get("val", None) sts = None if val is not None: if isinstance(val, self.torch.Tensor): shape = val.shape dtype = _get_type(val.dtype) # the shaphe could be new if a function produces a results # depending on the result values self._verify_new_shape(shape, node) self.builder.set_shape(node.name, tuple(shape)) self.builder.set_type(node.name, dtype) sts = {"dtype": val.dtype} elif isinstance(val, self.torch.SymInt): self.builder.set_shape(node.name, (1,)) self.builder.set_type(node.name, TensorProto.INT64) sts = {"dtype": self.torch.int64} else: raise TypeError( f"Unexpected type {type(val)} in node {node!r}" f"\n{self.builder.pretty_text(add_fx_graph=True)}" ) if hasattr(index, "name"): # A dynamic index (torch.fx.Node) res = self.builder.make_node( "Gather", [result_name, index.name], [node.name], name="getitemA" ) if not sts: self.builder.set_type(node.name, self.builder.get_type(result_name)) self.builder.set_rank( node.name, self.builder.get_rank(result_name) + self.builder.get_rank(index.name) - 1, ) return res if isinstance(index, int): name_index = f"{result_name}#{index}" if self.builder.has_name(name_index): # The user to get a tensor a tuple of tensors return self.builder.make_node( "Identity", [name_index], [node.name], name="getitemB_tuple" ) # The user mean to access the first element of a tensor or a sequence if self.builder.is_sequence(result_name): # A sequence tpos = self.builder.make_initializer( "", np.array(index, dtype=np.int64), source="DynamoInterpreter.getitem.1" ) res = self.builder.make_node( "SequenceAt", [result_name, tpos], [node.name], name="getitemB_tuple", ) if not sts: info = self.builder.get_sequence(result_name) dtype = info["dtype"] if isinstance(dtype, tuple): dtype = dtype[index] self.builder.set_type(res, dtype) if info["shapes"] is not None: self.builder.set_shape( res, info["shapes"][min(index, len(info["shapes"]) - 1)] ) elif info["ranks"] is not None: if isinstance(info["ranks"], int): self.builder.set_rank(res, info["ranks"]) else: self.builder.set_rank( res, info["ranks"][min(index, len(info["ranks"]) - 1)] ) return res else: # A tensor. 
res = self.builder.op.SqueezeAnyOpset( self.builder.op.Gather( result_name, np.array([index], dtype=np.int64), name="getitemB_index", ), np.array([0], dtype=np.int64), name="getitemB_index", outputs=[node.name], ) if not sts: self.builder.set_type(node.name, self.builder.get_type(result_name)) if self.builder.has_shape(result_name): self.builder.set_shape( node.name, self.builder.get_shape(result_name)[1:] ) else: self.builder.set_rank( node.name, self.builder.get_rank(result_name) - 1 ) return res if isinstance(index, slice): return self._getitem_slice( node, node_output.name, [index], sts=sts, axes=[0], expand_axes=[], name="_getitem_slice1", ) if isinstance(index, self.torch.fx.immutable_collections.immutable_list): # something like x[[0, 2]] if all_int(index): # something like x[[0, 1]] axes = [0] return self._getitem_int1( node, node_output.name, index, sts=sts, axes=axes, expand_axes=[], name="_getitem_int1a", ) if isinstance(index, tuple): if all(isinstance(x, (slice, self.torch.fx.Node)) for x in index): return self._getitem_slice( node, node_output.name, list(index), sts=sts, axes=list(range(len(index))), expand_axes=[], name="_getitem_slicen", ) if all(x is Ellipsis or x is None or isinstance(x, slice) for x in index): # something like x[3:4] axes = [] slices = [] expand_axes = [] ellipsis = False true_slice = False for i, ind in enumerate(index): if ind is Ellipsis: assert not ellipsis, f"Second (...) found in index={index}" ellipsis = True continue if ind is None: assert ( not ellipsis ), f"An axis cannot be inserted after (...) in index={index}" expand_axes.append(i) continue axes.append(((i - len(index)) if ellipsis else i) - len(expand_axes)) if ( not isinstance(ind, slice) or ind.start is not None or ind.stop is not None or ind.step is not None ): true_slice = True slices.append(ind) if true_slice: return self._getitem_slice( node, node_output.name, slices, sts=sts, axes=axes, expand_axes=expand_axes, name="_getitem_slice2", ) # It is just a node unsqueeze. res = self.builder.op.UnsqueezeAnyOpset( str(node.args[0]), np.array(expand_axes, dtype=np.int64), name="getitem_unsqueeze", outputs=[node.name], ) return res raise RuntimeError( f"getitem: unexpected tuple {tuple(type(x) for x in index)} " f"for index={index}, node={node}, args={args}, val={val}, " f"types={string_type(args)}{self.builder.get_debug_msg()}" ) raise RuntimeError( f"getitem: unexpected type {type(index)} for index={index}, " f"node={node}, args={args}, val={val}, " f"types={string_type(args)}{self.builder.get_debug_msg()}" )
def _verify_new_shape(self, shape, node): for dim in shape: if isinstance(dim, self.torch.SymInt): sdim = self.builder._torch_sym_int_to_str(dim) tokens = parse_expression_tokens(sdim) if len(tokens) == 1: # Only one token, possibly knew t = tokens.pop() if t not in self.builder.dynamic_objects: self.builder.add_dynamic_object(t, t) if t in self.builder.dynamic_dimensions_source: self.builder.dynamic_dimensions_source[t].append(dim) else: self.builder.dynamic_dimensions_source[t] = [dim] def _process_arg(self, node, aten_name, i): if i is None: return None if isinstance(i, str): return i if hasattr(i, "name"): return i.name if isinstance(i, tuple): return tuple(self._process_arg(node, aten_name, t) for t in i) if isinstance(i, (float, int, tuple, slice, complex)): return i if isinstance(i, list): new_list = [] for el in i: if hasattr(el, "name"): # torch.fx.Node new_list.append(el.name) continue new_list.append(el) return new_list if i is Ellipsis: return i if isinstance(i, (self.torch.dtype, self.torch.device)): return i raise RuntimeError( f"Unexpected type (argument {i}) {type(i)} " f"for function {aten_name!r} " f"in args={node.args}{self.builder.get_debug_msg()}" )
[docs] def call_function(self, node: "torch.fx.Node") -> Union[str, Tuple[str]]: # noqa: F821 """ Called for a function. """ aten_name = self._get_aten_name(node) fx_args, fx_kwargs = self._fill_in_default_kwargs(node) if aten_name == "aten_auto_functionalized": # Should we make a direct call? aten_name = node.args[0] fx_args = fx_args[1:] self.builder.add_stat(kind="aten", name=aten_name) if aten_name == "getitem": return self.getitem(node) fct, lookup, lookup_names = None, None, None if self.dispatcher is not None: fct = self.dispatcher.find_function(aten_name) lookup_names = [aten_name] if fct is None: fct, lookup, lookup_names = find_function(aten_name) if self.dispatcher is not None: fct = self.dispatcher.fallback( aten_name, fct, node.args, node.kwargs, self.builder ) if fct is None: raise FunctionNotFoundError( f"Unable to interpret function {type(aten_name)}: " f"{aten_name!r}, searched for " f"{lookup} and attributes {lookup_names}, " f"args={node.args}, kwargs={node.kwargs}" f"{self.builder.get_debug_msg()}" ) if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.call_function][{fct.__name__}]") args = [self._process_arg(node, aten_name, a) for a in fx_args] output_names = self._get_output_names(node) can_set = self._can_set_shape_and_type(node) n_nodes = len(self.builder.nodes) + len(self.builder.initializers_dict) assert ( len(node.users) > 0 or aten_name in { self.torch._C._set_grad_enabled, self.torch._C._log_api_usage_once, self.torch.amp.autocast_mode._enter_autocast, self.torch.amp.autocast_mode._exit_autocast, self.torch.ops.aten._assert_scalar.default, self.torch.torch.sym_constrain_range_for_size, "aten__exit_autocast", "aten__enter_autocast", "aten_FunctionCtx", } or ( hasattr(aten_name, "_opname") and aten_name._opname in {"sym_constrain_range_for_size"} ) ), ( f"This is probably one inplace function node={node!r}, " f"node.meta={node.meta!r}, aten_name={aten_name!r}, " f"aten_name._opname={getattr(aten_name, '_opname', '?')}, " f"output_names={output_names!r}{self.builder.get_debug_msg()}" ) if self.export_options.aten_as_function: res = self.add_aten_as_function( str(aten_name), fct, can_set, output_names, args=args, kwargs=fx_kwargs ) else: res = fct(self.builder, can_set, output_names, *args, **fx_kwargs) n_nodes_after = len(self.builder.nodes) + len(self.builder.initializers_dict) if res is None: if len(node.users) == 0: return raise RuntimeError( f"Unexpected return res=None, for node={node}, " f"output_names={output_names}" f"{self.builder.get_debug_msg()}" ) if n_nodes_after == n_nodes: raise RuntimeError( f"No node or initializer was added ({n_nodes}=={n_nodes_after}) " f"for node={node}{self.builder.get_debug_msg()}" ) self._set_shape_and_type(node, res, fct_name=aten_name) res = self._check_output_name(node, res, output_names) return res

    def call_method(self, node: "torch.fx.Node") -> Union[str, Tuple[str]]:  # noqa: F821
        """
        Called for a method.
        """
        method_name = node.target
        if self.builder.verbose > 1:
            print(f"[DynamoInterpreter-{self._hash()}.call_method][{method_name}]")
        assert isinstance(
            node.args, tuple
        ), f"Unexpected type {type(node.args)} for node.args."

        name_fct = f"aten_meth_{method_name}"
        fct = None
        if self.dispatcher is not None:
            fct = self.dispatcher.find_method(name_fct)
        if fct is None:
            # keep the dispatcher's method if it found one
            fct = find_method(name_fct)
        if self.dispatcher is not None:
            fct = self.dispatcher.fallback(name_fct, fct, node.args, node.kwargs, self.builder)
        if fct is None:
            raise FunctionNotFoundError(
                f"Unable to interpret method {name_fct!r}, "
                f"args={node.args}, kwargs={node.kwargs}, "
                f"dispatcher={self.dispatcher}"
                f"{self.builder.get_debug_msg()}"
            )

        args = [getattr(node.args[0], "name", node.args[0])]
        for i in node.args[1:]:
            args.append(i.name if hasattr(i, "name") else i)
        kwargs = node.kwargs
        output_names = self._get_output_names(node)
        can_set = self._can_set_shape_and_type(node)
        if self.export_options.aten_as_function:
            res = self.add_aten_as_function(name_fct, fct, can_set, output_names, args, kwargs)
        else:
            res = fct(self.builder, can_set, output_names, *args, **kwargs)
        self._set_shape_and_type(node, res, fct_name=method_name)
        res = self._check_output_name(node, res, output_names)
        return res
[docs] def add_aten_as_function( self, name_fct: str, fct: Callable, can_set: Optional[Dict[str, Any]], output_names: List[str], args: List[Any], kwargs: Dict[str, Any], domain: str = "aten", ) -> Union[str, Tuple[str]]: """ Converts a function into a local function and adds this local function to the graph. """ assert isinstance(name_fct, str), ( f"Unexpected type {type(name_fct)} for name_fct={name_fct}" f"{self.builder.get_debug_msg()}" ) # Collects inputs input_names = [] for a in args: if isinstance(a, str) and self.builder.has_name(a): if a not in input_names: input_names.append(a) elif isinstance(a, list): # some inputs are given as a list for n in a: if ( isinstance(n, str) and self.builder.has_name(n) and n not in input_names ): input_names.append(n) for k, v in kwargs.items(): if isinstance(v, str): raise NotImplementedError( f"This option is not implemented yet for k={k!r} " f"with type={type(v)}{self.builder.get_debug_msg()}" ) if self.builder.verbose > 1 or self._debug_aten_as_function: print( f"[DynamoInterpreter.add_aten_as_function] {name_fct}" f"({', '.join(input_names)}) -> {', '.join(output_names)}" ) new_builder = self.builder.make_subset_builder( input_names, name=name_fct, domain=domain ) try: res = fct(new_builder, can_set, output_names, *args, **kwargs) except AssertionError as e: raise AssertionError( f"The conversion of operator {name_fct!r} into a local function\n--ERROR--\n" f"{e}{self.builder.get_debug_msg()}" ) from e assert (len(output_names) == 1 and res == output_names[0]) or res == output_names, ( f"Mismatch issue res={res!r}, output_names={output_names!r} " f"for function {name_fct!r}{self.builder.get_debug_msg()}" ) for o in output_names: new_builder.make_tensor_output( o, indexed=False, is_dimension=self.builder.get_is_dimension(o, exc=False) ) inits, (fdomain, fname) = self.builder.make_local_function( new_builder, FunctionOptions( export_as_function=True, name=name_fct.replace(".", "_"), domain=domain, inline=False, merge_allowed=True, rename_allowed=True, move_initializer_to_constant=True, return_initializer=True, external_threshold=2**8, ), optimize=False, ) new_inits = [] for init in inits: new_init = self.builder.make_initializer( init.name, init, source="add_aten_as_function" ) new_inits.append(new_init) self.builder.make_node( fname, [*input_names, *new_inits], output_names, domain=fdomain, name=name_fct ) if not can_set: for o in output_names: if new_builder.has_type(o): self.builder.set_type(o) if new_builder.has_shape(o): self.builder.set_shape(o) elif new_builder.has_rank(o): self.builder.set_rank(o) return output_names[0] if len(output_names) == 1 else output_names
def _get_output_names(self, node: "torch.fx.Node") -> List[str]: # noqa: F821 val = node.meta.get("val", None) if val is not None and isinstance(val, tuple): n_outputs = len(val) output_names = [ ("" if val[i] is None else f"{node.name}#{i}") for i in range(n_outputs) ] else: assert isinstance( node.name, str ), f"Unexpected type {type(node.name)} for node.name" output_names = [node.name] return output_names def _check_output_name( self, node: "torch.fx.Node", # noqa: F821 res: Union[str, List[str]], output_names: List[str], ) -> Union[str, List[str]]: if isinstance(node.name, str): if len(output_names) != 1: if output_names != list(res): raise NotImplementedError( f"Unexpected output_names {output_names}, " f"res={res!r}, node.name={node.name!r}" ) elif isinstance(res, list) and len(res) != 1: # SplitToSequence rewritten into a Split name = output_names[0] assert all(s.startswith(name) for s in res), ( f"Unexpected output_names={output_names}, " f"res={res}, node.name={node.name}" f"{self.builder.get_debug_msg()}" ) # nothing to do res = tuple(res) elif res != node.name: assert isinstance(res, str), ( f"Unexpected res={res}, output_names={output_names}, " f"node.name={node.name}" f"{self.builder.get_debug_msg()}" ) self.builder.make_node( "Identity", [res], [node.name], name="_check_output_name" ) res = node.name else: raise NotImplementedError( f"Unexpected type {type(node.name)} for node.name={node.name!r}." ) return res def _can_set_shape_and_type( self, node: "torch.fx.Node" # noqa: F821 ) -> Optional[Dict[str, Any]]: if node.meta.get("val", None) is not None: dtype = self._get_node_output_type(node) assert dtype is not None, ( f"dtype is null, but val={node.meta.get('val', None)}" f"{self.builder.get_debug_msg()} " ) return {"dtype": dtype} return None def _get_node_output_type( self, node: "torch.fx.Node", # noqa: F821 ) -> Optional[Union["torch.dtype", Tuple["torch.dtype", ...]]]: # noqa: F821 val = node.meta.get("val", None) if val is not None: if isinstance(val, (tuple, list)): # Type list comes from SplitToSequence. 
return tuple((None if v is None else v.dtype) for v in val) if isinstance(val, (int, self.torch.SymInt)): return self.torch.SymInt if isinstance(val, self.torch.SymBool): return self.torch.SymBool if isinstance(val, (float, self.torch.SymFloat)): return self.torch.SymFloat exa = node.meta.get("example_value", None) assert exa is None or val.dtype == exa.dtype, ( f"dtype inconsistency (val, example_value) " f"{val.dtype} != {exa.dtype}{self.builder.get_debug_msg()}" ) assert hasattr(val, "dtype"), ( f"Unexpected type {type(val)} for val={val}, " f"node={node!r}{self.builder.get_debug_msg()}" ) return val.dtype return None def _set_shape_and_type( self, node: "torch.fx.Node", # noqa: F821 res: Union[str, List[str]], fct_name: Optional[str] = None, ): val = node.meta.get("val", None) exa = node.meta.get("example_value", None) if val is not None and exa is not None: assert val.dtype == exa.dtype, ( f"dtype inconsistency (val, example_value) " f"{val.dtype} != {exa.dtype}{self.builder.get_debug_msg()}" ) assert val.shape == exa.shape, ( f"shape inconsistency (val, example_value) " f"{val.shape} != {exa.shape}{self.builder.get_debug_msg()}" ) last_node = self.builder.last_added_node description = [] if val is not None and fct_name not in {"aten_cond"}: # extracting shape and types if not isinstance(val, tuple): val = (val,) res = (res,) assert isinstance( res, (list, tuple) ), f"Unexpected type {type(res)}{self.builder.get_debug_msg()}" if len(val) != len(res): raise RuntimeError( f"Length mismatch {len(val)} != {len(res)} " f"between {val} and {res}" f"{self.builder.get_debug_msg()}" ) output_sets = set(last_node.output) if last_node is not None else {} for i, (v, r) in enumerate(zip(val, res)): if isinstance(v, self.torch.Tensor): dtype = _get_type(v.dtype) if i >= 1 and node.target.name() in { "aten::_native_batch_norm_legit.no_stats", "aten::_native_batch_norm_legit_no_training", }: # It seems the type is not very consistant # and the output might not be used. self.builder.set_type(r, dtype, exc=False) else: self.builder.set_type(r, dtype) shape = tuple(v.shape) for t in shape: if isinstance(t, self.builder.torch.SymInt): expr = str(t.node._expr) if expr not in self.builder.dynamic_objects: # A new shape may be given to a result. 
self.builder.add_dynamic_object(expr, t, parse=True) if self.builder.is_dynamic_shape(shape): # sets shape coming from the original model # we must not set the existing shape as static, # if it was dynamic before self.builder.set_shape(r, shape, set_if_more_precise=False) elif self.builder.has_rank(r): assert len(shape) == self.builder.get_rank(r), ( f"Rank already set for {r!r}, " f"but rank={self.builder.get_rank(r)} " f"differs for shape={shape!r}{self.builder.get_debug_msg()}" ) else: self.builder.set_rank(r, len(shape)) if r in output_sets: description.append(f"{r}:{dtype}:{shape}".replace(" ", "")) elif isinstance(v, self.torch.SymInt): # this is a shape self.builder.set_shape(r, (1,)) self.builder.set_type(r, TensorProto.INT64) self.builder.make_dynamic_object(r, v) elif isinstance(v, self.torch.SymBool): # this is a shape self.builder.set_shape(r, (1,)) self.builder.set_type(r, TensorProto.BOOL) self.builder.make_dynamic_object(r, v) elif isinstance(v, self.torch.SymFloat): # this is a shape self.builder.set_shape(r, (1,)) self.builder.set_type(r, TensorProto.FLOAT) self.builder.make_dynamic_object(r, v) elif v is None: continue elif isinstance(v, list) and len(v) > 0: if len(v) == len(r) and r[0].endswith("#0"): # Operator Split was used instead of SplitToSequence. for r_, v_ in zip(r, v): self.builder.set_type(r_, torch_dtype_to_onnx_dtype(v_.dtype)) shape = tuple(v_.shape) if self.builder.is_dynamic_shape(shape): self.builder.set_shape(r_, shape, set_if_more_precise=False) elif self.builder.has_rank(r_): assert len(shape) == self.builder.get_rank(r_), ( f"Rank already set for {r_!r}, " f"but rank={self.builder.get_rank(r_)} " f"differs for shape={shape!r}" f"{self.builder.get_debug_msg()}" ) else: self.builder.set_rank(r, len(shape)) else: # This is coming from the sequence. 
dtype = list(set(_.dtype for _ in v)) assert len(dtype) == 1, ( f"Only sequence of tensors of the same type are allowed " f"but dtype={dtype}{self.builder.get_debug_msg()}" ) itype = torch_dtype_to_onnx_dtype(dtype[0]) self.builder.set_sequence( r, itype, shapes=tuple( tuple(map(self.builder._torch_sym_int_to_str, _.shape)) for _ in v ), ) else: raise TypeError( f"Unexpected type in node {node!r}, r={r!r}, " f"type(val)={type(v)}{self.builder.get_debug_msg()}" f"\n----\nval={val}" ) if exa is not None and not isinstance(exa, tuple): if hasattr(exa, "dtype"): # a tensor description.append(f"~{exa.dtype}:{exa.shape}".replace(" ", "")) else: # a SymInt description.append(f"~SumInt:{exa!r}".replace(" ", "")) if last_node is not None and description: last_node.doc_string += "\n".join(description) def _interpret_sub_module( self, sub_module, args, kwargs, source_node=None, local_domain=None ): from .onnx_export import _make_builder_interpreter assert not kwargs, ( f"This functionality is not implemented kwargs={string_type(kwargs)}" f"{self.get_debug_msg()}" ) if args is None: new_args = None else: new_args = [] for a in args: if isinstance(a, self.torch.fx.Node): name = a.name dtype = self.builder.get_type(name) if self.builder.has_type(name) else 0 shape = ( self.builder.get_shape(name) if self.builder.has_shape(name) else ( self.builder.make_new_dynamic_shape( self.builder.get_rank(name), prefix=name ) if self.builder.has_rank(name) else None ) ) new_args.append(VirtualTensor(name=name, dtype=dtype, shape=shape)) elif isinstance(a, self.torch.Tensor): new_args.append(a) else: raise NotImplementedError( f"Unable to process argument {type(a)}{self.get_debug_msg()}" ) if hasattr(sub_module, "graph") and isinstance(sub_module, self.torch.fx.GraphModule): gm = sub_module elif ( hasattr(sub_module, "graph") and isinstance(sub_module, self.torch.nn.Module) and sub_module.__class__.__name__ == "InterpreterModule" ): gm = sub_module else: # https://pytorch.org/docs/stable/fx.html tracer_class = self.torch.fx.Tracer graph = tracer_class().trace(sub_module) # Let's propulate with type if new_args: ii = 0 for node in graph.nodes: if node.op == "placeholder": if ii >= len(new_args) or "val" in node.meta: ii += 1 continue ag = new_args[ii] if isinstance(ag, VirtualTensor): node.meta["val"] = ag else: node.meta["example_value"] = ag ii += 1 gm = self.torch.fx.GraphModule(sub_module, graph) graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter( gm, args=None if new_args is None else tuple(new_args), kwargs=None if kwargs is None else kwargs, as_function=True, target_opset=self.builder.opsets, optimization_options=self.builder.optimization_options, verbose=max(0, self.builder.verbose - 1), dispatcher=self.dispatcher, raise_list=self.builder.raise_list, # dynamic shapes applies on the inner graph, not on the subgraph # dynamic_shapes=self.builder.dynamic_shapes, export_options=self.export_options, optimize_submodules=self.optimize_submodules, function_options=self.function_options, local_domain=local_domain, submodule_naming=self.submodule_naming, parameter_naming=self.parameter_naming, module_name=( None if (self.module_name is None or source_node is None) else ( source_node.target if self.module_name == "" else f"{self.module_name}.{source_node.target}" ) ), ) assert mask_outputs is None or all( mask_outputs ), f"Unexpected value for mask_outputs={mask_outputs}{self.get_debug_msg()}" # We register the dynamic elements in case the submodule is using them. 
for k, v in self.builder.dynamic_objects.items(): # We assume the list of dynamic objects is valid. if not self.builder.has_name(k): builder.add_dynamic_object(k, v, check_tokens=False) if self.builder.has_type(k): builder.set_type(k, self.builder.get_type(k)) if self.builder.has_shape(k): builder.set_shape(k, self.builder.get_shape(k)) if self.preserved_modules and hasattr(self, "named_modules"): assert ( source_node is not None ), f"For this option, source_node cannot be None{self.builder.get_debug_msg()}" module_name = source_node.target if module_name in self.named_modules: module_child = self.named_modules[module_name] interpreter.register_named_modules( self, None, dict(module_child.named_modules()) ) builder.process(graph_module, interpreter) assert builder.outputs, f"No output detected for node={source_node}, graph={gm}" # processing args, kwargs fx_args, fx_kwargs = self._fill_in_default_kwargs(source_node) args = [getattr(i, "name", i) for i in fx_args] kwargs = [getattr(i, "name", i) for i in fx_kwargs] # looking at the sample example val = source_node.meta.get("val", None) if val is not None and isinstance(val, tuple): n_outputs = len(val) output_names = [f"{source_node.name}#{i}" for i in range(n_outputs)] elif self.preserved_modules and val is not None and isinstance(val, list): n_outputs = len(val) output_names = [f"{source_node.name}#{i}" for i in range(n_outputs)] val = tuple(val) else: output_names = [source_node.name] if val is None: val = source_node.meta.get("example_value", None) if val is not None and not isinstance(val, tuple): val = (val,) # if not none if val is not None: if self.preserved_modules and len(val) == 1 and isinstance(val[0], list): # submodules with multiple outputs assert len(val[0]) == len(builder.outputs), ( f"Output mismatch {len(val[0])} != {len(builder.outputs)}, " f"source_node.name={source_node.name!r}, target={source_node.target!r}" f"type(val)={string_type(val)}, " f"builder.outputs={string_type(builder.outputs)}" f"{self.builder.get_debug_msg()}" ) # Shapes and types are set outside this function when the final node is added. else: # regular node assert len(val) == len(builder.outputs), ( f"Output mismatch {len(val)} != {len(builder.outputs)}, " f"source_node.name={source_node.name!r}, target={source_node.target!r}" f"type(val)={string_type(val)}, " f"builder.outputs={string_type(builder.outputs)}" f"{self.builder.get_debug_msg()}" ) for i in range(len(val)): name = builder.outputs[i].name if not builder.has_shape(name): builder.set_shape(name, val[i].shape) if not builder.has_type(name): builder.set_type(name, val[i].dtype) if isinstance(val[i], self.builder.torch.Tensor): self.builder.set_shapes_types( source_node.name, "call_module", (val[i].dtype, val[i].shape) ) elif isinstance(val[i], (self.builder.torch.SymInt)): self.builder.set_shapes_types( source_node.name, "call_module", (self.builder.torch.SymInt, tuple()), ) elif isinstance(val[i], (self.builder.torch.SymFloat)): self.builder.set_shapes_types( source_node.name, "call_module", (self.builder.torch.SymFloat, tuple()), ) else: # We could use the informations stored in the builder. pass return builder, args, kwargs, output_names

    def get_submodule_name(
        self, module_name: str, module: "torch.nn.Module"  # noqa: F821
    ) -> str:
        """
        Gets a submodule name, simple but unique.
        """
        assert self.submodule_naming, "submodule_naming is null"
        assert self.parameter_naming, "parameter_naming is null"
        return self.submodule_naming(module_name, module)
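
    # Illustrative sketch, not part of the original source: a possible
    # ``submodule_naming`` callable with the signature used above, mapping the
    # dotted module path and the module class to a unique name.
    #
    #     def submodule_naming(module_name: str, module: torch.nn.Module) -> str:
    #         return f"{module.__class__.__name__}_{module_name.replace('.', '_')}"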
[docs] def call_module(self, node: "torch.fx.Node"): # noqa: F821 """ Called for a module. """ def raise_msg(): return ( f"node={node}\n--\nnode.__dict__={pprint.pformat(node.__dict__)}" f"\n--\n{pprint.pformat(node.meta)}\n---\n{dir(node)}" f"\n---GRAPH\n{type(node.graph)}\n---GRAPH\n{node.graph}" f"\n---GRAPH\n{node.graph.__dict__}\n---GRAPH\n{dir(node.graph)}" f"\n---GRAPH.MODULE\n{type(node.graph.owning_module)}" f"\n---GRAPH.MODULE\n{id(node.graph.owning_module)}" f"\n---GRAPH.MODULE\n{node.graph.owning_module}" # f"\n---GRAPH.MODULE\n{node.graph.owning_module.__dict__}" f"\n---GRAPH.MODULE\n{dir(node.graph.owning_module)}" f"\nVALUES\n{pprint.pformat(self.example_values_)}" ) owning_module = node.graph.owning_module assert owning_module is not None, f"owning_module is None\n{raise_msg()}" sub_module = owning_module.get_submodule(node.target) assert isinstance( sub_module, self.torch.nn.Module ), f"Not implemented for type {type(sub_module)}.\n{raise_msg()}" if self.builder.verbose > 1: print(f"[DynamoInterpreter-{self._hash()}.call_module] class [{type(sub_module)}]") print( f"[DynamoInterpreter-{self._hash()}.call_module] with " f"node.args={string_type(node.args)}]" ) print( f"[DynamoInterpreter-{self._hash()}.call_module] with " f"kwargs={string_type(node.kwargs)}]" ) # This function is meant to be used later. if "." in self.builder.local_domain: root, n = self.builder.local_domain.split(".") n = int(n) + 1 else: root, n = self.builder.local_domain, 0 self.builder._check_constants("before-_interpret_sub_module") builder, args, kwargs, output_names = self._interpret_sub_module( sub_module, node.args, node.kwargs, source_node=node, local_domain=f"{root}.{n}" ) self.builder._check_constants("after-_interpret_sub_module") assert kwargs is None or len(kwargs) == 0, ( f"args={string_type(args)}, kwargs={string_type(kwargs)} " f"is not implemented yet{self.builder.get_debug_msg()}" ) name = sub_module.__class__.__name__ local_function_name = None if sub_module.__class__.__name__ == "InterpreterModule": # a local function is added. assert node.target in self.named_modules, ( f"Unable to find module name {node.target!r} in " f"{sorted(self.named_modules)}{self.builder.get_debug_msg()}" ) m = self.named_modules[node.target] if type(m) in self.preserved_modules: # Which name to give the submodule? # The class, the module name, ...? local_function_name = name = self.get_submodule_name(node.target, m) self.builder._check_constants("before-make_nodes") # let's create a function under the appropriate name self.builder.make_nodes( builder, args, output_names, prefix=f"_sub_{name}_", function_options=FunctionOptions( name=local_function_name, domain=LOCAL_DOMAIN, export_as_function=True, return_initializer=True, move_initializer_to_constant=self.function_options.move_initializer_to_constant, external_threshold=self.function_options.external_threshold, merge_allowed=self.function_options.merge_allowed, rename_allowed=self.function_options.rename_allowed, ), optimize=self.optimize_submodules, ) self.builder._check_constants("after-make_nodes") if len(output_names) == len(builder.outputs): # One output, both tensor for name, out_name in zip(builder.output_names, output_names): if builder.has_type(name): self.builder.set_type(out_name, builder.get_type(name)) if builder.has_shape(name): existing_shape = builder.get_shape(name) # We need to move any dynamic objects necessary from the submodules # to the parent module. 
self.builder.register_dynamic_objects_from_shape(existing_shape) self.builder.set_shape(out_name, existing_shape) elif builder.has_rank(name): self.builder.set_rank(out_name, builder.get_rank(name)) elif len(output_names) == 1 and len(builder.outputs) > 1: # The module outputs more than one output itypes, shapes, ranks = [], [], [] for name in builder.output_names: itypes.append(builder.get_type(name) if builder.has_type(name) else None) shapes.append(builder.get_shape(name) if builder.has_shape(name) else None) ranks.append(builder.get_rank(name) if builder.has_rank(name) else None) self.builder.set_sequence( output_names[0], tuple(itypes), shapes=tuple(shapes), ranks=ranks ) else: raise AssertionError( f"Unexpected number of outputs, output_names={output_names}, " f"len(builder.outputs)={len(builder.outputs)}, " f"builder.output_names={builder.output_names}" f"{builder.get_debug_msg()}\n--\n--\n--" f"{self.builder.get_debug_msg()}\n------\n" ) else: # nodes are inserted inline self.builder._check_constants("before-make_nodes(2)") self.builder.make_nodes(builder, args, output_names, prefix=f"_sub_{name}_") self.builder._check_constants("after-make_nodes(2)") return output_names