# Source code for onnx_diagnostic.reference.torch_ops._op_run

from typing import Any, List, Optional, Union, Tuple
import onnx
import torch
from ...api import TensorLike
from ...helpers import string_type
from ...helpers.torch_helper import to_tensor


class OpRunValue(TensorLike):
    """
    Defines a value for the runtime, a tensor or a sequence.

    Abstract base of :class:`OpRunTensor` (wraps a single tensor) and
    :class:`OpRunSequence` (wraps a list of tensors).
    """

    # Restricts instances to the attribute set shared by the subclasses.
    __slots__ = ("cached", "is_constant", "sequence", "tensor")

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        raise NotImplementedError("is_sequence must be overwritten.")
class OpRunTensor(OpRunValue):
    """
    Wrapper around a tensor.

    :param tensor: torch.Tensor
    :param is_constant: is it a constant
    :param may_cpu: change the device the tensor is if more appropriate
    """

    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
        assert isinstance(tensor, torch.Tensor), (
            f"Unexpected type {type(tensor)}, "
            f"__name__={getattr(tensor, '__name__', 'no name')}"
        )
        # -666666 is the placeholder scalar OpRunSequence stores in its own
        # ``tensor`` attribute; a genuine one-element tensor holding it here
        # would indicate a mix-up between the two wrappers.
        assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
        # may_cpu: a short 1-D int64 tensor living on a device (likely
        # shape-like data — verify with callers) is relocated to CPU.
        relocate = (
            may_cpu
            and len(tensor.shape) == 1
            and tensor.numel() < 8
            and tensor.dtype == torch.int64
            and tensor.get_device() >= 0
        )
        self.tensor = tensor.cpu() if relocate else tensor
        self.is_constant = is_constant
        # lazily filled by ``as_tuple_int`` when the value is constant
        self.cached: Optional[Tuple[int, ...]] = None

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        return False

    def to(self, to: Any) -> "OpRunTensor":
        "Changes the device."
        moved = self.tensor.to(to)
        return OpRunTensor(moved)

    def string_type(self) -> str:
        "Returns information about the value as a string."
        s = string_type(self.tensor, with_shape=True, with_min_max=True, with_device=True)
        return f"CST({s})" if self.is_constant else s

    def __repr__(self) -> str:
        "usual"
        txt = string_type(self.tensor, with_shape=True)
        if not self.is_constant:
            return f"{self.__class__.__name__}({txt})"
        return f"{self.__class__.__name__}({txt}, is_constant=True)"

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        "Returns either a tensor or a sequence."
        return self.tensor

    @property
    def shape(self):
        "shape (torch shape)"
        return self.tensor.shape

    @property
    def dtype(self):
        "dtype (torch dtype)"
        return self.tensor.dtype

    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
        # materializes the (1-D) tensor as a tuple of python ints
        return tuple(int(i) for i in self.tensor)

    def numel(self) -> int:
        "Returns the number of elements."
        return self.tensor.numel() if self.tensor is not None else 0

    def get_device(self) -> int:
        "Returns the device id."
        return self.tensor.get_device() if self.tensor is not None else -1

    @property
    def device(self):
        "Returns the device."
        return self.tensor.device if self.tensor is not None else -1

    @property
    def as_tuple_int(self) -> Tuple[int, ...]:
        "value as int"
        if not self.is_constant:
            return self._tensor_as_tuple_int()
        # constants are memoized: computed once, then served from the cache
        if self.cached is None:
            self.cached = self._tensor_as_tuple_int()
        return self.cached

    def copy(self) -> "OpRunTensor":
        "Shallow copy."
        return self.__class__(self.tensor)
class OpRunSequence(OpRunValue):
    """
    Defines a sequence of tensors.

    :param sequence: initial list of tensors, None means an empty sequence
    :param dtype: element dtype, stored on a placeholder tensor so that
        ``self.dtype`` can be answered even when the sequence is empty
    """

    def __init__(
        self, sequence: Optional[List[torch.Tensor]] = None, dtype: torch.dtype = torch.float32
    ):
        # The placeholder tensor only carries the dtype; -666666 is the marker
        # value OpRunTensor asserts against to catch wrapper mix-ups.
        self.tensor = torch.tensor(-666666, dtype=dtype)
        self.is_shape = False
        self.sequence = sequence or []
        self.cached: Optional[Tuple[int, ...]] = None
        assert all(
            isinstance(s, torch.Tensor) for s in self.sequence
        ), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"

    @property
    def dtype(self):
        "dtype (torch dtype)"
        return self.tensor.dtype

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        "Returns either a tensor or a sequence."
        return self.sequence

    @classmethod
    def is_sequence(cls) -> bool:
        "Tells if it is sequence."
        return True

    def insert_at(
        self, tensor: "OpRunTensor", position: Optional["OpRunTensor"] = None
    ) -> "OpRunSequence":
        """
        Inserts a value at a given position.

        :param tensor: wrapped tensor to insert — an :class:`OpRunTensor`,
            not a raw ``torch.Tensor`` (the runtime check below enforces it,
            the former annotation was wrong)
        :param position: optional scalar tensor holding the insertion index,
            the value is appended at the end when None
        :return: a new sequence, the current one is left untouched
        """
        assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
        new_seq = OpRunSequence()
        # copy the list so the original sequence is not mutated
        seq = self.sequence.copy()
        new_seq.sequence = seq
        if position is None:
            seq.append(tensor.tensor)
        else:
            seq.insert(int(position.tensor.item()), tensor.tensor)
        return new_seq

    def copy(self) -> "OpRunSequence":
        "Shallow copy (the list of tensors is shared)."
        return self.__class__(self.sequence, dtype=self.dtype)

    def string_type(self) -> str:
        "Returns a string which can be printed."
        return string_type(self.sequence, with_shape=True)
class OpRun:
    """
    Main class. Every kernel should inherit from it.
    It does not copy the proto.
    """

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return False

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return False

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        assert isinstance(
            node, onnx.NodeProto
        ), f"node must be a NodeProto but node is {type(node)}"
        self.op_type = node.op_type
        self.domain = node.domain
        self.input = node.input
        self.output = node.output
        if version is None:
            # Kernel classes are expected to be named ``<OpType>_<opset>``;
            # recover the opset version from the class name.
            name = self.__class__.__name__.split("_")
            assert (
                len(name) == 2
            ), f"Cannot guess version from name={self.__class__.__name__!r}"
            version = int(name[1])
        self.version = version
        self.name = node.name

    def __str__(self) -> str:
        "usual"
        if self.domain:
            return (
                f"{self.op_type}[{self.domain}]({', '.join(self.input)}) "
                f"-> {', '.join(self.output)}"
            )
        return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        "Kernel implementation."
        raise NotImplementedError(
            f"Method run is not implemented for kernel {self.__class__.__name__!r}"
        )

    def _find_attribute(self, node: onnx.NodeProto, name: str):
        # linear scan over the node attributes, None when absent
        for att in node.attribute:
            if att.name == name:
                return att
        return None

    def get_attribute_float(
        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
    ) -> Optional[float]:
        """
        Returns an attribute as a float.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else float(att.f)

    def get_attribute_int(
        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
    ) -> Optional[int]:
        """
        Returns an attribute as an int.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else int(att.i)

    def get_attribute_ints(
        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
    ) -> Optional[Tuple[int, ...]]:
        """
        Returns an attribute as a tuple of ints.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else tuple(map(int, att.ints))

    def get_attribute_string(
        self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
    ) -> Optional[str]:
        """
        Returns an attribute as a string.

        :param node: NodeProto
        :param name: name
        :param default_value: default_value
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else att.s.decode("utf-8")

    def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
        """
        Returns an attribute as a torch tensor.

        :param node: NodeProto
        :param name: name
        :return: value or None when the attribute is missing
        """
        att = self._find_attribute(node, name)
        if att is None:
            return None
        return to_tensor(att.t)

    def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Puts all tensors on the same device (the highest device id wins)."""
        devices = [t.get_device() for t in tensors]
        if len(set(devices)) == 1:
            # already on a single device, nothing to move
            return tuple(tensors)
        index = devices.index(max(devices))
        device = tensors[index].device
        return tuple(t.to(device) for t in tensors)
class OpRunFunction(OpRun):
    """
    Defines a kernel based on a local function: execution is delegated to a
    nested evaluator holding the function body.

    :param runtime: nested :class:`TorchOnnxEvaluator` the node execution
        is delegated to
    :param node: node calling the function
    :param version: opset version, guessed from the class name when None
        (see :class:`OpRun`)
    """

    def __init__(
        self,
        runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator",  # noqa: F821
        node: onnx.NodeProto,
        version: Optional[int] = None,
    ):
        super().__init__(node, version)
        self.runtime = runtime
        # NOTE(review): mirrors the nested runtime's input names; presumably
        # used to bind positional args to the function inputs — verify in the
        # evaluator.
        self.input_names = runtime.input_names

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        "Kernel implementation: forwards the inputs to the nested runtime."
        return self.runtime.run_with_values(*args)