Source code for experimental_experiment.reference.ort_evaluator

from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from onnx import ModelProto, NodeProto, TensorProto, TypeProto, helper as oh, load
from onnx.numpy_helper import to_array
from ..helpers import pretty_onnx


class OrtEval:
    """
    This class loads an onnx model and then executes the nodes one by one
    with onnxruntime. It is mostly meant for debugging.

    :param proto: ModelProto or filename
    :param providers: providers
    :param options: session options
    :param verbose: verbosity
    :param whole: run the whole model instead of node by node
    :param incremental: run the model node by node, but for every node,
        executes the graph up to that node
    :param optimized_model_filepath: export the optimized graph
    """

    def __init__(
        self,
        proto: Union[str, ModelProto],
        providers: Optional[Union[str, List[str]]] = None,
        options: Optional["onnxruntime.SessionOptions"] = None,  # noqa: F821
        verbose: int = 0,
        whole: bool = False,
        incremental: bool = False,
        optimized_model_filepath: Optional[str] = None,
    ):
        self.session_options = options
        if self.session_options is None:
            import onnxruntime

            self.session_options = onnxruntime.SessionOptions()
        if optimized_model_filepath:
            self.session_options.optimized_model_filepath = optimized_model_filepath
        if verbose >= 30:
            import onnxruntime

            self.session_options.log_severity_level = 0
            self.session_options.log_verbosity_level = 0
            self.run_options = onnxruntime.RunOptions()
            self.run_options.log_severity_level = 0
            self.run_options.log_verbosity_level = 0
        else:
            import onnxruntime

            self.run_options = onnxruntime.RunOptions()
        if providers is None or providers in ("cpu", "CPU"):
            providers = ["CPUExecutionProvider"]
        elif providers in ("cuda", "CUDA"):
            providers = ["CUDAExecutionProvider"]
        elif not isinstance(providers, list) and providers.startswith("cuda"):
            device_id = 0 if ":" not in providers else int(providers.split(":")[1])
            providers = [
                ("CUDAExecutionProvider", {"device_id": device_id}),
                ("CPUExecutionProvider", {}),
            ]
        self.providers = providers
        self._cache = {}
        if isinstance(proto, str):
            proto = load(proto)
        assert isinstance(proto, ModelProto), f"Unexpected type {type(proto)}"
        self.proto = proto
        self.nodes = list(proto.graph.node)
        self.rt_inits_ = {init.name: to_array(init) for init in proto.graph.initializer}
        self.rt_nodes_ = list(self.proto.graph.node)
        self.verbose = verbose
        self.whole = whole
        self.incremental = incremental
        assert not whole or not incremental, (
            f"whole={whole} and incremental={incremental} "
            f"cannot be both True at the same time."
        )
        try:
            import torch
        except ImportError:
            return
        self.torch = torch
        self.TORCH_DTYPE_TO_NUMPY_DTYPE = {
            torch.float16: np.float16,
            torch.float32: np.float32,
            torch.float64: np.float64,
            torch.uint8: np.uint8,
            torch.int8: np.int8,
            torch.int16: np.int16,
            torch.int32: np.int32,
            torch.int64: np.int64,
            torch.bool: np.bool_,
        }
        self.TORCH_DTYPE_TO_ONNX_DTYPE = {
            torch.float16: TensorProto.FLOAT16,
            torch.bfloat16: TensorProto.BFLOAT16,
            torch.float32: TensorProto.FLOAT,
            torch.float64: TensorProto.DOUBLE,
            torch.uint8: TensorProto.UINT8,
            torch.int8: TensorProto.INT8,
            torch.int16: TensorProto.INT16,
            torch.int32: TensorProto.INT32,
            torch.int64: TensorProto.INT64,
            torch.bool: TensorProto.BOOL,
        }

    def _get_torch_dtype(self, dt: Any) -> "torch.dtype":  # noqa: F821
        if dt == np.bool_:
            return self.torch.bool
        if dt == np.float32:
            return self.torch.float32
        if dt == np.float16:
            return self.torch.float16
        if dt == np.int64:
            return self.torch.int64
        raise AssertionError(f"Unexpected type {type(dt)}: {dt}")

    def _get_itype(self, dt: Any) -> int:
        if isinstance(dt, int):
            return dt
        if dt in self.TORCH_DTYPE_TO_ONNX_DTYPE:
            return self.TORCH_DTYPE_TO_ONNX_DTYPE[dt]
        if dt in self.TORCH_DTYPE_TO_NUMPY_DTYPE:
            dt = self.TORCH_DTYPE_TO_NUMPY_DTYPE[dt]
        if dt == np.bool_:
            return TensorProto.BOOL
        if dt == np.float32:
            return TensorProto.FLOAT
        if dt == np.float16:
            return TensorProto.FLOAT16
        if dt == np.int64:
            return TensorProto.INT64
        return oh.np_dtype_to_tensor_dtype(dt)

    def _log_arg(self, a: Any) -> Any:
        if isinstance(a, (str, int, float)):
            return a
        if hasattr(a, "detach"):
            device = f"D{a.get_device()}:"
            a = a.detach().cpu().numpy()
        else:
            device = ""
        if isinstance(a, np.ndarray):
            if self.verbose < 4:  # noqa: PLR2004
                return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
            elements = a.ravel().tolist()
            if len(elements) > 5:  # noqa: PLR2004
                elements = elements[:5]
                return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
            return f"{device}{a.dtype}:{a.shape}:{elements}"
        if hasattr(a, "append"):
            return ", ".join(map(self._log_arg, a))
        return a

    def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
        if level < self.verbose:
            new_args = [self._log_arg(a) for a in args]
            print(pattern % tuple(new_args))
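
    # Usage sketch (not part of the class): constructing the evaluator in its
    # three modes. The file name "model.onnx" is a hypothetical placeholder.
    #
    #   ev = OrtEval("model.onnx", providers="cpu", verbose=1)  # node by node
    #   ev_whole = OrtEval("model.onnx", whole=True)            # one session, whole model
    #   ev_inc = OrtEval("model.onnx", incremental=True)        # graph rebuilt up to each node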

    def run(
        self,
        outputs: Optional[List[str]],
        feed_inputs: Dict[str, Any],
        intermediate: bool = False,
    ) -> Union[Dict[str, Any], List[Any]]:
        """
        Runs the model. It only works with numpy arrays.

        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :param intermediate: returns all outputs instead of the last ones
        :return: outputs, as a list if intermediate is False,
            as a dictionary if intermediate is True
        """
        if self.whole:
            if "" in self._cache:
                sess = self._cache[""]
            else:
                import onnxruntime

                sess = onnxruntime.InferenceSession(
                    self.proto.SerializeToString(),
                    self.session_options,
                    providers=self.providers,
                )
                self._cache[""] = sess
            return sess.run(outputs, feed_inputs, run_options=self.run_options)

        if outputs is None:
            outputs = [o.name for o in self.proto.graph.output]
        results = self.rt_inits_.copy()
        for k, v in self.rt_inits_.items():
            self._log(2, " +C %s: %s", k, v)
        for k, v in feed_inputs.items():
            self._log(2, " +I %s: %s", k, v)
            results[k] = v

        for node in self.rt_nodes_:
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i != "" and i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            inputs = [(results[i] if i != "" else None) for i in node.input]
            outputs = self._run(node, inputs, results)
            for name, value in zip(node.output, outputs):
                if name == "":
                    continue
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                results[name] = value

        if intermediate:
            return results
        output_names = [o.name for o in self.proto.graph.output]
        for name in output_names:
            if name == "":
                continue
            if name not in results:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(results)}, proto is\n{self.proto}"
                )
        return [results[name] for name in output_names if name != ""]
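
    # Usage sketch for run (hedged): numpy inputs in, a list of numpy outputs
    # out. The input name "X" and the shape are illustrative assumptions.
    #
    #   feeds = {"X": np.random.rand(2, 3).astype(np.float32)}
    #   outs = ev.run(None, feeds)                           # graph outputs
    #   everything = ev.run(None, feeds, intermediate=True)  # dict of all results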

    def _get_sess_incremental(
        self, node: NodeProto
    ) -> Tuple[ModelProto, "onnxruntime.InferenceSession"]:  # noqa: F821
        import onnxruntime
        from ..xbuilder import GraphBuilder, OptimizationOptions

        builder = GraphBuilder(
            self.proto, optimization_options=OptimizationOptions(patterns=None)
        )
        builder.select_outputs([n for n in node.output if n != ""])
        onx = builder.to_onnx()
        try:
            sess = onnxruntime.InferenceSession(
                onx.SerializeToString(), self.session_options, self.providers
            )
        except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
            raise AssertionError(
                f"Cannot create a session for node\n-----\n{node}\n----------\n{onx}"
            ) from e
        return onx, sess

    def _get_sess(
        self, node: NodeProto, inputs: List[Any]
    ) -> Tuple[ModelProto, "onnxruntime.InferenceSession"]:  # noqa: F821
        if self.incremental:
            return self._get_sess_incremental(node)

        import onnxruntime

        unique_names = set()
        vinputs = []
        for i, it in zip(node.input, inputs):
            if i == "" or i in unique_names:
                continue
            unique_names.add(i)
            value = oh.make_tensor_value_info(i, self._get_itype(it.dtype), it.shape)
            vinputs.append(value)

        voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]

        onx = oh.make_model(
            oh.make_graph([node], "node", vinputs, voutputs),
            ir_version=self.proto.ir_version,
            functions=self.proto.functions,
        )
        del onx.opset_import[:]
        onx.opset_import.extend(self.proto.opset_import)
        try:
            sess = onnxruntime.InferenceSession(
                onx.SerializeToString(), self.session_options, self.providers
            )
        except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
            raise RuntimeError(
                f"Unable to infer a session due to {e}\n{pretty_onnx(onx)}"
            ) from e
        return onx, sess

    def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
        """Runs a node."""
        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            self._cache[key] = onx, sess = self._get_sess(node, inputs)

        if self.incremental:
            # Inputs are the inputs of the model, not the node.
            feeds = {}
            for i in self.proto.graph.input:
                feeds[i.name] = results[i.name]
        else:
            feeds = dict(zip(node.input, inputs))
            if "" in feeds:
                feeds[""] = np.array([0], dtype=np.float32)

        outputs = sess.run(None, feeds, run_options=self.run_options)
        return outputs
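
    # Illustration (hedged) of the single-node model _get_sess assembles: the
    # node's inputs become typed value infos taken from the runtime tensors,
    # while outputs carry an empty TypeProto so onnxruntime infers their
    # types. The "Add" node and shapes below are assumptions for the example.
    #
    #   node = oh.make_node("Add", ["A", "B"], ["C"])
    #   vinputs = [
    #       oh.make_tensor_value_info(n, TensorProto.FLOAT, (2, 2)) for n in ("A", "B")
    #   ]
    #   voutputs = [oh.make_value_info("C", TypeProto())]
    #   onx = oh.make_model(oh.make_graph([node], "node", vinputs, voutputs))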

    def run_dlpack(
        self, outputs: Optional[List[str]], feed_inputs: Dict[str, Any]
    ) -> List[Any]:
        """
        Runs the model using :epkg:`run_with_ortvaluevector`.
        It only works with :class:`torch.Tensor`.

        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :return: outputs
        """
        if self.whole:
            from onnxruntime.capi import _pybind_state as ORTC

            if "" in self._cache:
                sess = self._cache[""]
            else:
                import onnxruntime

                sess = onnxruntime.InferenceSession(
                    self.proto.SerializeToString(),
                    self.session_options,
                    providers=self.providers,
                )
                self._cache[""] = sess

            input_names = [i.name for i in self.proto.graph.input]
            output_names = [i.name for i in self.proto.graph.output]
            inputs = [feed_inputs[i] for i in input_names]

            ortvalues, output_devices = self._get_ortvalues_from_torch_tensors(
                inputs, len(self.proto.graph.output)
            )
            ort_outputs = ORTC.OrtValueVector()
            sess.run_with_ortvaluevector(
                self.run_options,
                input_names,
                ortvalues,
                output_names,
                ort_outputs,
                output_devices,
            )
            pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
            return pth_outputs

        if outputs is None:
            outputs = [o.name for o in self.proto.graph.output]

        if not hasattr(self, "rt_inits_torch_"):
            import torch

            def torch_Tensor(v):
                if len(v.shape) == 0:
                    vv = torch_Tensor(np.expand_dims(v, 0))
                    return vv.squeeze()
                try:
                    return torch.Tensor(v)
                except RuntimeError as e:
                    raise RuntimeError(
                        f"Unable to convert {[v]} (type={type(v)}, "
                        f"dtype={v.dtype}, shape={v.shape}) into a torch Tensor."
                    ) from e

            self.rt_inits_torch_ = {
                k: torch_Tensor(v).to(self._get_torch_dtype(v.dtype))
                for k, v in self.rt_inits_.items()
            }
            if "CUDAExecutionProvider" in self.providers:
                ts = self.rt_inits_torch_
                self.rt_inits_torch_ = {}
                for k, v in ts.items():
                    if v.dtype in (torch.float32,) and len(v.shape) == 0:
                        pass
                    elif v.dtype not in (torch.int64, torch.bool):
                        v = v.cuda()
                    self.rt_inits_torch_[k] = v

        results = self.rt_inits_torch_.copy()
        for k, v in self.rt_inits_.items():
            self._log(2, " +C %s: %s", k, v)
        for k, v in feed_inputs.items():
            self._log(2, " +I %s: %s", k, v)
            results[k] = v

        for node in self.rt_nodes_:
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i != "" and i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            inputs = [(results[i] if i != "" else None) for i in node.input]
            outputs = self._run_dlpack(node, inputs, results)
            for name, value in zip(node.output, outputs):
                if name == "":
                    continue
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                results[name] = value

        output_names = [o.name for o in self.proto.graph.output]
        for name in output_names:
            if name not in results:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(results)}, proto is\n{self.proto}"
                )
        return [results[name] for name in output_names]
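
    # Usage sketch for run_dlpack (hedged): torch tensors go in and come back
    # through DLPack without a numpy round trip; names are illustrative.
    #
    #   import torch
    #   ev = OrtEval("model.onnx", providers="cuda")
    #   outs = ev.run_dlpack(None, {"X": torch.rand(2, 3, device="cuda")})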

    def _get_ortvalues_from_torch_tensors(
        self,
        tensors: Tuple["torch.Tensor", ...],  # noqa: F821
        n_outputs: int,
        log_set: Optional[List[Any]] = None,
    ) -> Tuple["OrtValueVector", List["OrtDevice"]]:  # noqa: F821
        import torch
        from onnxruntime.capi import _pybind_state as ORTC

        DEVICES = {
            -1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)
        }
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                DEVICES[i] = ORTC.OrtDevice(
                    ORTC.OrtDevice.cuda(), ORTC.OrtDevice.default_memory(), i
                )

        ortvalues = ORTC.OrtValueVector()
        ortvalues.reserve(len(tensors))

        dtypes = []
        shapes = []
        data_ptrs = []
        devices = []
        max_device = -1
        assert isinstance(max_device, int), f"unexpected type for device={max_device!r}"
        assert tensors is not None, "tensors cannot be None"
        new_tensors = []
        for pos, tensor in enumerate(tensors):
            if tensor is None:
                tensor = torch.Tensor(np.array([0], dtype=np.float32)).to(
                    "cuda" if max_device >= 0 else "cpu"
                )
            assert isinstance(tensor, torch.Tensor), f"Unexpected type {type(tensor)}"
            dtypes.append(self.TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
            shapes.append(tensor.size())
            data_ptrs.append(tensor.data_ptr())
            d = tensor.get_device()
            if self.verbose > 10:
                if log_set is None:
                    print(f" < p={pos} d={d} dtype={dtypes[-1]} shape={tensor.shape}")
            devices.append(DEVICES[d])
            new_tensors.append(tensor)
            max_device = max(max_device, tensor.get_device())

        if self.verbose > 10 and log_set:
            for pos, tensor in enumerate(log_set):
                if tensor is None:
                    continue
                d = tensor.get_device()
                dt = self.TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype]
                print(f" < p={pos} d={d} dtype={dt} shape={tensor.shape}")

        ortvalues.push_back_batch(new_tensors, data_ptrs, dtypes, shapes, devices)

        output_devices = []
        for i in range(n_outputs):
            dev = DEVICES[max_device]
            output_devices.append(dev)
            if self.verbose > 10:
                print(f" > p={i} d={max_device}")
        return ortvalues, output_devices

    def _ortvalues_to_torch_tensor(
        self,
        ortvalues: "onnxruntime.OrtValueVector",  # noqa: F821
    ) -> Tuple["torch.Tensor", ...]:  # noqa: F821
        if len(ortvalues) == 0:
            return tuple()

        from torch._C import _from_dlpack

        if all(ortvalues[i].has_value() for i in range(len(ortvalues))):
            res = ortvalues.to_dlpacks(_from_dlpack)
        else:
            res = []
            for i in range(len(ortvalues)):
                if ortvalues[i].has_value():
                    res.append(_from_dlpack(ortvalues[i].to_dlpack()))
                else:
                    res.append(None)
        return tuple(res)

    def _run_dlpack(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs a node."""
        from onnxruntime.capi import _pybind_state as ORTC

        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            self._cache[key] = onx, sess = self._get_sess(node, inputs)

        onames = [n for n in node.output if n != ""]

        if self.incremental:
            # Inputs are the inputs of the model, not the node.
            former_inputs = inputs
            inputs = []
            input_names = []
            for i in self.proto.graph.input:
                if i.name == "":
                    continue
                inputs.append(results[i.name])
                input_names.append(i.name)

            ortvalues, output_devices = self._get_ortvalues_from_torch_tensors(
                inputs, len(onames), log_set=former_inputs
            )
            assert len(onames) == len(output_devices), (
                f"Length mismatch {onames} but {len(output_devices)} devices, "
                f"node.output={node.output}."
            )
            ort_outputs = ORTC.OrtValueVector()
            sess.run_with_ortvaluevector(
                self.run_options,
                input_names,
                ortvalues,
                onames,
                ort_outputs,
                output_devices,
            )
            pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
            return pth_outputs

        ortvalues, output_devices = self._get_ortvalues_from_torch_tensors(
            inputs, len(onames)
        )
        assert len(onames) == len(output_devices), (
            f"Length mismatch {onames} but {len(output_devices)} devices, "
            f"node.output={node.output}."
        )
        ort_outputs = ORTC.OrtValueVector()
        sess.run_with_ortvaluevector(
            self.run_options,
            node.input,
            ortvalues,
            onames,
            ort_outputs,
            output_devices,
        )
        pth_outputs = self._ortvalues_to_torch_tensor(ort_outputs)
        return pth_outputs
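

if __name__ == "__main__":
    # Minimal end-to-end sketch (an illustrative demo, not part of the
    # module's API): build a tiny Add model with the onnx helpers imported
    # above and evaluate it node by node on CPU. Names, shapes, and the
    # opset version are assumptions; run with ``python -m`` since the module
    # uses relative imports.
    X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [2])
    Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [2])
    Z = oh.make_tensor_value_info("Z", TensorProto.FLOAT, [2])
    model = oh.make_model(
        oh.make_graph([oh.make_node("Add", ["X", "Y"], ["Z"])], "demo", [X, Y], [Z]),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    ev = OrtEval(model, verbose=1)
    print(
        ev.run(
            None,
            {
                "X": np.ones(2, dtype=np.float32),
                "Y": np.ones(2, dtype=np.float32),
            },
        )
    )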