Source code for onnx_array_api.ort.ort_tensors

from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
from onnx import ModelProto, TensorProto
from onnx.defs import onnx_opset_version
from onnxruntime import InferenceSession, RunOptions, get_available_providers
from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice
from onnxruntime.capi._pybind_state import OrtMemType
from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument

from ..npx.npx_tensors import EagerTensor, JitTensor
from ..npx.npx_types import DType, TensorType


class OrtTensor:
    """
    Default backend based on :class:`onnxruntime.InferenceSession`.
    Data is not copied.

    :param tensor: a :epkg:`C_OrtValue`, another :class:`OrtTensor`
        (the underlying :epkg:`C_OrtValue` is shared) or a numpy array
        (wrapped on CPU without copy)
    :param _hold: :epkg:`onnxruntime` does not copy the data if it comes
        from a numpy array on CPU, and it does not hold any reference on it.
        *_hold* stores the underlying numpy array hosting the data of an
        OrtTensor built from it. It ensures the garbage collector does
        not remove it while this instance is alive.
    """

    # Device descriptors accepted by `ortvalue_from_numpy`:
    # the CPU and the CUDA device with id 0.
    CPU = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0)
    CUDA0 = C_OrtDevice(C_OrtDevice.cuda(), OrtMemType.DEFAULT, 0)
    # Execution providers given to every InferenceSession, restricted to
    # the ones available in this onnxruntime build, CUDA listed first.
    providers = [
        c
        for c in ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if c in get_available_providers()
    ]

    @staticmethod
    def from_array(
        value: np.ndarray, device: Optional[C_OrtDevice] = None
    ) -> "OrtTensor":
        """
        Creates an instance of :class:`OrtTensor` from a numpy array.
        Relies on `ortvalue_from_numpy`.
        A copy of the data in the Numpy object is held by the
        :epkg:`C_OrtValue` only if the device is **not cpu**.
        On CPU, the :epkg:`C_OrtValue` points to the numpy buffer without
        holding a reference to it. The returned instance keeps `value`
        alive through its `_hold` attribute. Expressions such as
        `from_array(x.copy())` or `from_array(x.astype(np.float32))`
        therefore stay valid as long as the returned `OrtTensor` is alive.
        The raw :epkg:`C_OrtValue` returned by property `value` must not
        outlive the `OrtTensor` it comes from.

        :param value: value
        :param device: CPU, GPU, value such as `OrtTensor.CPU`,
            `OrtTensor.CUDA0`, defaults to `OrtTensor.CPU`
        :return: instance of :class:`OrtTensor`
        """
        if device is None:
            device = OrtTensor.CPU
        return OrtTensor(C_OrtValue.ortvalue_from_numpy(value, device), _hold=value)

    def numpy(self) -> np.ndarray:
        """
        Converts the :epkg:`OrtValue` into numpy array.
        """
        return self._tensor.numpy()

    class Evaluator:
        """
        Wraps class :class:`onnxruntime.InferenceSession`
        to have a signature closer to python function.

        :param tensor_class: class tensor such as
            :class:`NumpyTensor <onnx_array_api.npx.npx_numpy_tensors.NumpyTensor>`,
            its attribute `providers` is given to the InferenceSession
        :param input_names: input names
        :param onx: onnx model, its output type may be modified in place
            (see below)
        :param f: unused except in error messages
        """

        def __init__(
            self,
            tensor_class: type,
            input_names: List[str],
            onx: ModelProto,
            f: Optional[Callable] = None,
        ):
            try:
                self.ref = InferenceSession(
                    onx.SerializeToString(),
                    providers=tensor_class.providers,
                )
            except InvalidArgument as e:
                if (
                    len(onx.graph.output) == 1
                    and onx.graph.output[0].type.tensor_type.elem_type
                    == TensorProto.UNDEFINED
                    and len(onx.graph.input) > 0
                ):
                    # ShapeInference cannot use python function for unknown
                    # node type. Let's give the only output the same type
                    # as the first input and retry.
                    # NOTE: this modifies the caller's model in place.
                    onx.graph.output[0].type.tensor_type.elem_type = onx.graph.input[
                        0
                    ].type.tensor_type.elem_type
                    self.ref = InferenceSession(
                        onx.SerializeToString(),
                        providers=tensor_class.providers,
                    )
                else:
                    # Small models are printed in the error message.
                    if len(onx.graph.node) <= 3:
                        raise RuntimeError(
                            f"Unable to create an InferenceSession with model {onx}."
                        ) from e
                    raise
            self.input_names = input_names
            self.tensor_class = tensor_class
            self.output_names = [output.name for output in self.ref.get_outputs()]
            self.run_options = RunOptions()
            self._f = f

        def run(self, *inputs: "OrtTensor") -> List["OrtTensor"]:
            """
            Executes the function.

            :param inputs: function inputs, as many as *input_names*,
                each of them exposing a `value` (a :epkg:`C_OrtValue`)
            :return: outputs, wrapped with the class of the first input,
                or with *tensor_class* if the function has no input
            :raises ValueError: if the number of inputs does not match
                the number of input names
            """
            if len(inputs) != len(self.input_names):
                raise ValueError(
                    f"Expected {len(self.input_names)} inputs but got "
                    f"len(inputs)={len(inputs)}, f={self._f}."
                )
            feeds = {name: inp.value for name, inp in zip(self.input_names, inputs)}
            # `_sess` is the C session: it accepts C_OrtValue directly
            # without wrapping them into the python OrtValue class.
            res = self.ref._sess.run_with_ort_values(
                feeds, self.output_names, self.run_options
            )
            # A function without any input has no tensor to take the
            # class from: fall back on the class given to the evaluator.
            wrap = inputs[0].__class__ if inputs else self.tensor_class
            return [wrap(r) for r in res]

    def __init__(
        self,
        tensor: Union[C_OrtValue, "OrtTensor", np.ndarray],
        _hold: Optional[np.ndarray] = None,
    ):
        if isinstance(tensor, C_OrtValue):
            self._tensor = tensor
            self._hold = _hold
        elif isinstance(tensor, OrtTensor):
            self._tensor = tensor._tensor
            # The C_OrtValue is shared with `tensor`. Keep the buffer that
            # `tensor` was holding alive too. Otherwise it could be garbage
            # collected while this instance still points to its data.
            self._hold = _hold if _hold is not None else tensor._hold
        elif isinstance(tensor, np.ndarray):
            if _hold is not None:
                raise RuntimeError(
                    "tensor cannot be a numpy array and _hold be not None."
                )
            self._tensor = C_OrtValue.ortvalue_from_numpy(tensor, OrtTensor.CPU)
            # The C_OrtValue does not copy the data on CPU: keep the array.
            self._hold = tensor
        else:
            raise ValueError(f"An OrtValue is expected not {type(tensor)}.")

    def __repr__(self) -> str:
        "usual"
        return f"{self.__class__.__name__}(OrtTensor.from_array({self.numpy()!r}))"

    @property
    def device_name(self):
        "Returns the device name reported by the underlying C_OrtValue."
        return self._tensor.device_name()

    @property
    def ndim(self):
        "Returns the number of dimensions."
        return len(self.shape)

    @property
    def shape(self) -> Tuple[int, ...]:
        "Returns the shape of the tensor."
        return tuple(self._tensor.shape())

    @property
    def dtype(self) -> DType:
        "Returns the element type of this tensor."
        return DType(self._tensor.element_type())

    @property
    def key(self) -> Any:
        "Unique key for a tensor of the same type."
        return (self.dtype, len(self.shape))

    @property
    def value(self) -> C_OrtValue:
        """
        Returns the underlying :epkg:`C_OrtValue` of this tensor.
        It does not keep alive the numpy buffer it may point to.
        """
        return self._tensor

    @property
    def tensor_type(self) -> TensorType:
        "Returns the tensor type of this tensor."
        return TensorType[self.dtype]

    @property
    def dims(self):
        """
        Returns the dimensions of the tensor.
        First dimension is the batch dimension if the tensor
        has more than one dimension. It is always left undefined.
        """
        shape = self.shape
        if len(shape) <= 1:
            # a scalar (len==0) or a 1D tensor
            return shape
        return (None, *shape[1:])

    def tensor_type_dims(self, name: str) -> TensorType:
        """
        Returns the tensor type of this tensor.
        This property is used to define a key used to cache a jitted function.
        Same keys means same ONNX graph.
        Different keys usually means same ONNX graph but different
        input shapes.

        :param name: name of the constraint
        """
        dt = self.dtype
        return TensorType[dt, self.dims, name]

    @classmethod
    def create_function(
        cls: Any, input_names: List[str], onx: ModelProto, f: Callable
    ) -> Callable:
        """
        Creates a python function calling the onnx backend
        used by this class.

        :param input_names: input names of the model
        :param onx: onnx model
        :param f: original function, only used in error messages
        :return: python function, an instance of :class:`OrtTensor.Evaluator`
            whose method `run` executes the model
        """
        return cls.Evaluator(cls, input_names, onx, f=f)
class OrtCommon:
    """
    Helpers shared by the jit and the eager onnxruntime tensor classes.
    """

    @classmethod
    def get_opsets(cls, opsets):
        """
        Completes the opsets so that domain ``com.microsoft`` is declared.

        :param opsets: dictionary ``{domain: version}`` or None
        :return: if *opsets* is None, a new dictionary with the main domain
            set to the installed onnx opset capped at 18 and
            ``com.microsoft`` set to 1. If *opsets* already declares
            ``com.microsoft``, *opsets* itself, unchanged. Otherwise a copy
            of *opsets* with ``com.microsoft`` set to 1 (the argument is
            not modified).
        """
        if opsets is None:
            main_version = min(onnx_opset_version(), 18)
            return {"": main_version, "com.microsoft": 1}
        if "com.microsoft" not in opsets:
            completed = opsets.copy()
            completed["com.microsoft"] = 1
            return completed
        return opsets

    @classmethod
    def get_ir_version(cls, ir_version):
        """
        Returns the IR version to use, never above 8.

        :param ir_version: requested IR version or None
        :return: 8 if *ir_version* is None, ``min(ir_version, 8)`` otherwise
        """
        return 8 if ir_version is None else min(ir_version, 8)
class EagerOrtTensor(OrtTensor, OrtCommon, EagerTensor):
    """
    Eager-mode tensor using :epkg:`onnxruntime` as a backend.
    """

    def __array_namespace__(self, api_version: Optional[str] = None):
        """
        Returns the module implementing the array API functions
        for this tensor class.

        :param api_version: requested version of the array API,
            only `None` and `"2022.12"` are accepted
        :return: module `onnx_array_api.array_api.onnx_ort`
        :raises ValueError: for any other *api_version*
        """
        supported = api_version is None or api_version == "2022.12"
        if not supported:
            raise ValueError(
                f"Unable to return an implementation for api_version={api_version!r}."
            )
        # Imported lazily, presumably to avoid a circular import with the
        # array_api package (TODO confirm).
        from onnx_array_api.array_api import onnx_ort

        return onnx_ort
class JitOrtTensor(OrtTensor, OrtCommon, JitTensor):
    """
    Jit-mode tensor using :epkg:`onnxruntime` as a backend.
    All its behaviour is inherited from the base classes.
    """