# Source code for onnx_diagnostic.reference.ort_evaluator

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from onnx import (
    GraphProto,
    FunctionProto,
    ModelProto,
    NodeProto,
    TypeProto,
    ValueInfoProto,
    helper as oh,
    load,
)
from onnx.defs import onnx_opset_version
import onnxruntime
from ..helpers import string_type
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
from ..helpers.ort_session import (
    InferenceSessionForTorch,
    InferenceSessionForNumpy,
    _InferenceSession,
)

# Concrete protobuf classes accepted as a model description;
# kept as a tuple so it can be used directly in isinstance checks.
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
# Same set of classes expressed as a typing alias for annotations.
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]


class OnnxruntimeEvaluator:
    """
    Loads an onnx model and then executes its nodes one by one
    with onnxruntime. This class is mostly meant for debugging.

    :param proto: proto, filename, or another evaluator whose proto is reused
    :param session_options: options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable nvidia events
    :param enable_profiling: forwarded as is to every inference session
        created to run a node
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-training API
    :param verbose: verbosity
    :param local_functions: additional local functions, indexed by
        ``(domain, name)``, given as protos or as evaluators
    :param ir_version: ir version to use when unknown
    :param opsets: opsets to use when unknown
    """

    def __init__(
        self,
        proto: Union[str, Proto, "OnnxruntimeEvaluator"],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: bool = False,
        verbose: int = 0,
        local_functions: Optional[
            Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
        ] = None,
        ir_version: int = 10,
        opsets: Optional[Union[int, Dict[str, int]]] = None,
    ):
        if isinstance(proto, str):
            self.proto: Proto = load(proto)
        elif isinstance(proto, OnnxruntimeEvaluator):
            assert isinstance(
                proto.proto, PROTO
            ), f"Unexpected type for proto.proto {type(proto.proto)}"
            self.proto = proto.proto
        else:
            self.proto = proto
        assert isinstance(
            self.proto, PROTO
        ), f"Unexpected type for self.proto {type(self.proto)}"

        # One entry per (node, input signature): the model built for the node
        # and the object able to run it.
        self._cache: Dict[
            Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
        ] = {}
        self.ir_version = ir_version
        self.opsets = opsets
        # Forwarded unchanged to every session or sub-evaluator created later.
        self.session_kwargs: Dict[str, Any] = dict(
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )

        # A ModelProto stores nodes and initializers in ``.graph``,
        # GraphProto and FunctionProto hold the nodes directly.
        container = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        if isinstance(self.proto, NodeProto):
            self.nodes = [self.proto]
        else:
            self.nodes = list(container.node)
        # ``getattr`` with a default so that a bare GraphProto contributes its
        # initializers too while FunctionProto / NodeProto (which have none)
        # yield an empty dictionary.
        self.rt_inits_ = {
            init.name: to_array_extended(init)
            for init in getattr(container, "initializer", ())
        }
        self.rt_nodes_ = self.nodes.copy()
        self.verbose = verbose
        self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = (  # noqa: UP037
            {(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
            if hasattr(self.proto, "functions")
            else {}
        )
        if local_functions:
            # Raw protos are wrapped so that every value exposes ``.proto``,
            # which _get_sess_local relies on.
            for key, fct in local_functions.items():
                self.local_functions[key] = (
                    fct if isinstance(fct, OnnxruntimeEvaluator) else self.__class__(fct)
                )

    @property
    def input_names(self) -> List[str]:
        "Returns input names."
        if isinstance(self.proto, NodeProto):
            return list(self.nodes[0].input)
        container = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        # Graph inputs are ValueInfoProto (with a name), function inputs are strings.
        return [getattr(o, "name", o) for o in container.input]

    @property
    def output_names(self) -> List[str]:
        "Returns output names."
        if isinstance(self.proto, NodeProto):
            return list(self.nodes[0].output)
        container = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        # Graph outputs are ValueInfoProto (with a name), function outputs are strings.
        return [getattr(o, "name", o) for o in container.output]

    @property
    def input_types(self) -> List[TypeProto]:
        "Returns input types, only available for a ModelProto or a GraphProto."
        if not isinstance(self.proto, (ModelProto, GraphProto)):
            raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
        g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        return [i.type for i in g.input]

    @property
    def output_types(self) -> List[TypeProto]:
        "Returns output types, only available for a ModelProto or a GraphProto."
        if not isinstance(self.proto, (ModelProto, GraphProto)):
            raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
        g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        return [i.type for i in g.output]

    def _log_arg(self, a: Any) -> Any:
        """
        Converts a logged argument into something short enough to print.

        Strings and numbers are returned unchanged, objects with a shape are
        summarized (range of values, or first elements when ``verbose >= 4``),
        objects with an ``append`` method are processed element by element.
        """
        if isinstance(a, (str, int, float)):
            return a
        # Only objects exposing ``detach`` are asked for their device.
        device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
        if hasattr(a, "shape"):
            if self.verbose < 4:  # noqa: PLR2004
                if 0 in tuple(a.shape):
                    # min/max are undefined on an empty tensor and would raise.
                    return f"{device}{a.dtype}:{a.shape}:empty"
                return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
            elements = a.ravel().tolist()
            if len(elements) > 10:  # noqa: PLR2004
                elements = elements[:10]
                return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
            return f"{device}{a.dtype}:{a.shape}:{elements}"
        if hasattr(a, "append"):
            return ", ".join(map(self._log_arg, a))
        return a

    def _log(self, level: int, pattern: str, *args: Any) -> None:
        "Prints ``pattern % args`` when ``level`` is strictly below the verbosity."
        if level < self.verbose:
            new_args = [self._log_arg(a) for a in args]
            print(pattern % tuple(new_args))

    def _is_local_function(self, node: NodeProto) -> bool:
        "Tells if the node calls one of the registered local functions."
        return (node.domain, node.op_type) in self.local_functions

    def run(
        self,
        outputs: Optional[List[str]],
        feed_inputs: Dict[str, Any],
        intermediate: bool = False,
    ) -> Union[Dict[str, Any], List[Any]]:
        """
        Runs the model node by node.

        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :param intermediate: returns all results instead of the required outputs
        :return: outputs, as a list if *intermediate* is False,
            as a dictionary holding every known result if *intermediate* is True
        :raises RuntimeError: if a node input or a required output is unknown
        """
        if outputs is None:
            outputs = self.output_names
        results: Dict[str, Any] = self.rt_inits_.copy()

        for k, v in self.rt_inits_.items():
            self._log(2, " +C %s: %s", k, v)
        for k, v in feed_inputs.items():
            self._log(2, " +I %s: %s", k, v)
            results[k] = v

        for node in self.rt_nodes_:
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i != "" and i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            # An empty name is an omitted optional input.
            inputs = [(results[i] if i != "" else None) for i in node.input]
            # A dedicated name keeps the *outputs* argument intact.
            if node.op_type == "If" and node.domain == "":
                node_outputs = self._run_if(node, inputs, results)
            elif self._is_local_function(node):
                node_outputs = self._run_local(node, inputs, results)
            else:
                node_outputs = self._run(node, inputs, results)
            for name, value in zip(node.output, node_outputs):
                if name == "":
                    continue
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                assert isinstance(name, str), f"unexpected type for name {type(name)}"
                results[name] = value

        if intermediate:
            return results

        requested = [name for name in outputs if name != ""]
        for name in requested:
            if name not in results:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
                )
        return [results[name] for name in requested]

    def _make_model_proto(
        self,
        nodes: Sequence[NodeProto],
        vinputs: Sequence[ValueInfoProto],
        voutputs: Sequence[ValueInfoProto],
    ) -> ModelProto:
        """
        Wraps nodes into a standalone ModelProto.

        The ir version, the functions and the opsets are taken from ``self.proto``
        when it has them. Otherwise ``self.ir_version`` and ``self.opsets`` are
        used, and finally the latest opset known to onnx for the main domain.
        """
        onx = oh.make_model(
            oh.make_graph(nodes, "-", vinputs, voutputs),
            ir_version=getattr(self.proto, "ir_version", self.ir_version),
            functions=getattr(self.proto, "functions", None),
        )
        # make_model adds its own default opset, it is replaced below.
        del onx.opset_import[:]
        if hasattr(self.proto, "opset_import"):
            onx.opset_import.extend(self.proto.opset_import)
        elif self.opsets:
            if isinstance(self.opsets, int):
                onx.opset_import.append(oh.make_opsetid("", self.opsets))
            else:
                onx.opset_import.extend(
                    [oh.make_opsetid(k, v) for k, v in self.opsets.items()]
                )
        else:
            onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
        return onx

    def _get_sess(
        self, node: NodeProto, inputs: List[Any]
    ) -> Tuple[ModelProto, _InferenceSession]:
        """
        Builds a model containing a single node and the session running it.

        Input types and shapes come from the actual values, outputs are left
        untyped. The numpy session is chosen as soon as one input is a numpy
        array, the torch session otherwise.

        :raises RuntimeError: if onnxruntime rejects the model
        """
        unique_names = set()
        vinputs = []
        for i, it in zip(node.input, inputs):
            # Skips omitted inputs and names already declared.
            if i == "" or i in unique_names:
                continue
            unique_names.add(i)
            value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
            vinputs.append(value)

        # no need to run shape inference
        voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]

        onx = self._make_model_proto([node], vinputs, voutputs)
        cls = (
            InferenceSessionForNumpy
            if any(isinstance(i, np.ndarray) for i in inputs)
            else InferenceSessionForTorch
        )
        try:
            sess = cls(onx, **self.session_kwargs)
        except (
            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
            onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
        ) as e:
            raise RuntimeError(
                f"Unable to infer a session with inputs\n{string_type(inputs)}"
                f"\ndue to {e}\n{pretty_onnx(onx)}"
            ) from e
        return onx, sess

    def _get_sess_if(
        self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
    ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
        """
        Builds a model and an evaluator for one branch of an ``If`` node.

        The node inputs and every result of *context* are declared as inputs
        of the new model so the branch can read values from the outer scope.

        :raises RuntimeError: if the node has no attribute named *branch*
        """
        unique_names = set()
        vinputs = []
        for i, it in zip(node.input, inputs):
            if i == "" or i in unique_names:
                continue
            unique_names.add(i)
            value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
            vinputs.append(value)

        for i, v in context.items():
            if i not in unique_names:
                unique_names.add(i)
                value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
                vinputs.append(value)

        graph = None
        for att in node.attribute:
            if att.name == branch:
                graph = att.g
        if graph is None:
            raise RuntimeError(
                f"Unable to find attribute {branch!r} in node {node.op_type!r}, "
                f"attributes are {[att.name for att in node.attribute]}"
            )

        voutputs = graph.output
        onx = self._make_model_proto(graph.node, vinputs, voutputs)
        sess = OnnxruntimeEvaluator(
            onx,
            local_functions=self.local_functions,
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
            **self.session_kwargs,
        )
        return onx, sess

    def _get_sess_local(
        self, node: NodeProto, inputs: List[Any]
    ) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
        """
        Returns the proto of the local function called by *node* and a new
        evaluator for it sharing this evaluator's settings and local functions.
        """
        ev = self.local_functions[node.domain, node.op_type]
        sess = OnnxruntimeEvaluator(
            ev,
            local_functions=self.local_functions,
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
            **self.session_kwargs,
        )
        return ev.proto, sess

    def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
        """Runs a node with a session cached per node and input signature."""
        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess(node, inputs)
            self._cache[key] = onx, sess

        feeds = dict(zip(node.input, inputs))
        if "" in feeds:
            # Placeholder value for an omitted optional input.
            feeds[""] = np.array([0], dtype=np.float32)

        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        outputs = list(sess.run(None, feeds))
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs

    def _run_if(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs a node ``If``: evaluates the branch selected by its first input."""
        feeds = dict(zip(node.input, inputs))
        # The branch may use any result already computed in the outer scope.
        feeds.update(results)
        name = "then_branch" if feeds[node.input[0]] else "else_branch"

        key = (id(node), name)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess_if(node, name, inputs, results)
            self._cache[key] = onx, sess

        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        outputs = sess.run(None, feeds)
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs

    def _run_local(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs a node calling a local function."""
        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess_local(node, inputs)
            self._cache[key] = onx, sess

        assert len(node.input) == len(sess.input_names), (
            f"Input mismatch: input_names={sess.input_names}, "
            f"node.input={list(node.input)}, "
            f"type(self.proto)={type(self.proto)}, and node=\n{node}"
        )
        # Inputs are matched by position: going through a name -> name mapping
        # would drop values when the same tensor is given to several inputs.
        feeds = dict(zip(sess.input_names, inputs))
        if "" in feeds:
            feeds[""] = np.array([0], dtype=np.float32)

        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        outputs = sess.run(None, feeds)
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs