from typing import Any, List, Optional, Union, Tuple
import onnx
import torch
from ...api import TensorLike
from ...helpers import string_type
from ...helpers.torch_helper import to_tensor
[docs]
class OpRunValue(TensorLike):
    """
    Base class of every value handled by the runtime.

    A value is either a tensor or a sequence; subclasses tell which one
    through :meth:`is_sequence`.
    """

    # Attribute names shared by the subclasses.
    __slots__ = ("cached", "is_constant", "sequence", "tensor")

    @classmethod
    def is_sequence(cls) -> bool:
        """
        Tells if the value is a sequence.

        :return: never returns here, subclasses must override this method
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("is_sequence must be overwritten.")
[docs]
class OpRunTensor(OpRunValue):
    """
    Wrapper around a tensor.

    :param tensor: torch.Tensor
    :param is_constant: is it a constant
    :param may_cpu: if True, the tensor is moved to CPU when it is
        one-dimensional, holds fewer than 8 elements, has dtype int64
        and ``get_device()`` returns a non negative index
    """

    def __init__(self, tensor, is_constant: bool = False, may_cpu: bool = False):
        assert isinstance(tensor, torch.Tensor), (
            f"Unexpected type {type(tensor)}, "
            f"__name__={getattr(tensor, '__name__', 'no name')}"
        )
        # -666666 is the value OpRunSequence stores in its placeholder tensor;
        # presumably this rejects wrapping that placeholder as a real tensor.
        assert tensor is None or tensor.numel() != 1 or tensor.item() != -666666
        self.tensor = (
            tensor.cpu()
            if may_cpu
            and len(tensor.shape) == 1
            and tensor.numel() < 8
            and tensor.dtype == torch.int64
            and tensor.get_device() >= 0
            else tensor
        )
        self.is_constant = is_constant
        # Filled lazily by as_tuple_int, only when the value is a constant.
        self.cached: Optional[Tuple[int, ...]] = None

    @classmethod
    def is_sequence(cls) -> bool:
        """Tells if it is sequence, always False for a tensor."""
        return False

    def to(self, to: Any) -> "OpRunTensor":
        """
        Changes the device.

        :param to: anything accepted by ``torch.Tensor.to``
        :return: a new wrapper around the converted tensor, the constant
            flag of this instance is carried over
        """
        # Converting a constant does not make it less constant:
        # forward the flag instead of silently resetting it.
        return OpRunTensor(self.tensor.to(to), is_constant=self.is_constant)

    def string_type(self) -> str:
        """
        Returns information about the value as a string.
        Constants are wrapped into ``CST(...)``.
        """
        s = string_type(self.tensor, with_shape=True, with_min_max=True, with_device=True)
        if self.is_constant:
            return f"CST({s})"
        return s

    def __repr__(self) -> str:
        """usual"""
        if self.is_constant:
            return (
                f"{self.__class__.__name__}"
                f"({string_type(self.tensor, with_shape=True)}, is_constant=True)"
            )
        return f"{self.__class__.__name__}({string_type(self.tensor, with_shape=True)})"

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Returns either a tensor or a sequence, here the wrapped tensor."""
        return self.tensor

    @property
    def shape(self):
        """shape (torch shape)"""
        return self.tensor.shape

    @property
    def dtype(self):
        """dtype (torch dtype)"""
        return self.tensor.dtype

    def _tensor_as_tuple_int(self) -> Tuple[int, ...]:
        # Iterates over the first dimension and converts every item with int().
        return tuple(map(int, self.tensor))

    def numel(self) -> int:
        """Returns the number of elements, 0 if there is no tensor."""
        return 0 if self.tensor is None else self.tensor.numel()

    def get_device(self) -> int:
        """Returns the device id, -1 if there is no tensor."""
        return -1 if self.tensor is None else self.tensor.get_device()

    @property
    def device(self):
        """Returns the device, -1 if there is no tensor."""
        return -1 if self.tensor is None else self.tensor.device

    @property
    def as_tuple_int(self) -> Tuple[int, ...]:
        """
        value as int

        The conversion is computed once and cached when the value is a
        constant, it is recomputed at every call otherwise.
        """
        if self.is_constant:
            if self.cached is None:
                self.cached = self._tensor_as_tuple_int()
            return self.cached
        return self._tensor_as_tuple_int()

    def copy(self) -> "OpRunTensor":
        """
        Shallow copy.

        :return: a new instance of the same class sharing the same tensor
            and the same constant flag
        """
        clone = self.__class__(self.tensor)
        # Set after construction so that subclasses with a narrower
        # constructor signature keep working.
        clone.is_constant = self.is_constant
        return clone
[docs]
class OpRunSequence(OpRunValue):
    """
    Defines a sequence.

    :param sequence: list of torch tensors, an empty list is used
        if it is None or empty
    :param dtype: element type reported by :attr:`dtype`
    """

    def __init__(
        self, sequence: Optional[List[torch.Tensor]] = None, dtype: torch.dtype = torch.float32
    ):
        # Placeholder scalar, its dtype is what the dtype property returns.
        self.tensor = torch.tensor(-666666, dtype=dtype)
        self.is_shape = False
        self.sequence = sequence or []
        self.cached: Optional[Tuple[int, ...]] = None
        assert all(
            isinstance(s, torch.Tensor) for s in self.sequence
        ), f"Unexpected type in sequence {[type(s) for s in self.sequence]}"

    @property
    def dtype(self):
        """dtype (torch dtype)"""
        return self.tensor.dtype

    @property
    def tensor_or_sequence(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Returns either a tensor or a sequence, here the list of tensors."""
        return self.sequence

    @classmethod
    def is_sequence(cls) -> bool:
        """Tells if it is sequence, always True."""
        return True

    def insert_at(
        self, tensor: OpRunTensor, position: Optional[OpRunTensor] = None
    ) -> "OpRunSequence":
        """
        Inserts a value at a given position.

        :param tensor: wrapped tensor to insert
        :param position: wrapped single-element tensor giving the index,
            the value is appended if it is None
        :return: a new sequence with the same dtype, this instance
            is left unchanged
        """
        assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor"
        # Work on a copy of the list so the current sequence is not modified.
        items = self.sequence.copy()
        if position is None:
            items.append(tensor.tensor)
        else:
            items.insert(int(position.tensor.item()), tensor.tensor)
        # Keep the element type of the source sequence instead of
        # falling back to the constructor default.
        return OpRunSequence(items, dtype=self.dtype)

    def copy(self) -> "OpRunSequence":
        """Shallow copy, the list of tensors is shared with the original."""
        return self.__class__(self.sequence, dtype=self.dtype)

    def string_type(self) -> str:
        """Returns a string which can be printed."""
        return string_type(self.sequence, with_shape=True)
[docs]
class OpRun:
    """
    Main class. Every kernel should inherit from it.
    It does not copy the proto.

    :param node: node the kernel implements
    :param version: opset version, if None, it is extracted from the class
        name which must then follow the pattern ``<name>_<version>``
    """

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return False

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return False

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        assert isinstance(
            node, onnx.NodeProto
        ), f"node must be a NodeProto but node is {type(node)}"
        self.op_type = node.op_type
        self.domain = node.domain
        self.input = node.input
        self.output = node.output
        if version is None:
            # The class name is expected to end with "_<version>".
            name = self.__class__.__name__.split("_")
            assert (
                len(name) == 2
            ), f"Cannot guess version from name={self.__class__.__name__!r}"
            version = int(name[1])
        self.version = version
        self.name = node.name

    def __str__(self) -> str:
        """usual"""
        if self.domain:
            return (
                f"{self.op_type}[{self.domain}]({', '.join(self.input)}) "
                f"-> {', '.join(self.output)}"
            )
        return f"{self.op_type}({', '.join(self.input)}) -> {', '.join(self.output)}"

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        """
        Kernel implementation.

        :raises NotImplementedError: always, every kernel must override it
        """
        raise NotImplementedError(
            f"Method run is not implemented for kernel {self.__class__.__name__!r}"
        )

    def _find_attribute(self, node: onnx.NodeProto, name: str):
        # Returns the first attribute called *name* or None if there is none.
        for att in node.attribute:
            if att.name == name:
                return att
        return None

    def get_attribute_float(
        self, node: onnx.NodeProto, name: str, default_value: Optional[float] = None
    ) -> Optional[float]:
        """
        Returns an attribute as a float.

        :param node: NodeProto
        :param name: name
        :param default_value: value returned if the attribute is missing
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else float(att.f)

    def get_attribute_int(
        self, node: onnx.NodeProto, name: str, default_value: Optional[int] = None
    ) -> Optional[int]:
        """
        Returns an attribute as an int.

        :param node: NodeProto
        :param name: name
        :param default_value: value returned if the attribute is missing
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else int(att.i)

    def get_attribute_ints(
        self, node: onnx.NodeProto, name: str, default_value: Optional[Tuple[int, ...]] = None
    ) -> Optional[Tuple[int, ...]]:
        """
        Returns an attribute as a tuple of ints.

        :param node: NodeProto
        :param name: name
        :param default_value: value returned if the attribute is missing
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else tuple(map(int, att.ints))

    def get_attribute_string(
        self, node: onnx.NodeProto, name: str, default_value: Optional[str] = None
    ) -> Optional[str]:
        """
        Returns an attribute as a string (bytes decoded with utf-8).

        :param node: NodeProto
        :param name: name
        :param default_value: value returned if the attribute is missing
        :return: value
        """
        att = self._find_attribute(node, name)
        return default_value if att is None else att.s.decode("utf-8")

    def get_attribute_tensor(self, node: onnx.NodeProto, name: str) -> Optional[torch.Tensor]:
        """
        Returns an attribute as a torch tensor.

        :param node: NodeProto
        :param name: name
        :return: value, None if the attribute is missing
        """
        att = self._find_attribute(node, name)
        if att is None:
            return None
        return to_tensor(att.t)

    def same_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Puts all tensors on the same device.

        The tensors are returned unchanged if they all share the same device
        index, otherwise every tensor is moved to the device of the tensor
        with the highest ``get_device()`` index.

        :param tensors: tensors to align, possibly none
        :return: tuple of tensors, empty if no tensor was given
        """
        if not tensors:
            # Nothing to align; max() below would fail on an empty list.
            return ()
        devices = [t.get_device() for t in tensors]
        if len(set(devices)) == 1:
            return tuple(tensors)
        index = devices.index(max(devices))
        device = tensors[index].device
        return tuple(t.to(device) for t in tensors)
[docs]
class OpRunFunction(OpRun):
    """
    Defines a kernel based on a local function.

    The kernel keeps a reference to the evaluator given as *runtime* and
    forwards every call to it.

    :param runtime: evaluator executing the function, it must expose
        ``input_names`` and ``run_with_values``
    :param node: node calling the function
    :param version: version, forwarded to :class:`OpRun`
    """

    def __init__(
        self,
        runtime: "onnx_diagnostic.reference.TorchOnnxEvaluator",  # noqa: F821
        node: onnx.NodeProto,
        version: Optional[int] = None,
    ):
        super().__init__(node, version)
        # The evaluator does the actual work in run.
        self.runtime = runtime
        self.input_names = runtime.input_names

    def run(
        self, *args: Optional[OpRunValue]
    ) -> Union[OpRunValue, Tuple[Optional[OpRunValue], ...]]:
        """Runs the function by delegating to the stored evaluator."""
        results = self.runtime.run_with_values(*args)
        return results