Source code for onnx_diagnostic.reference.ort_evaluator

from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
from onnx import (
    AttributeProto,
    GraphProto,
    FunctionProto,
    ModelProto,
    NodeProto,
    TypeProto,
    ValueInfoProto,
    helper as oh,
    load,
    save as onnx_save,
    shape_inference as shi,
)
from onnx.defs import onnx_opset_version
import onnxruntime
from ..helpers import string_type
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
from ..helpers.ort_session import (
    InferenceSessionForTorch,
    InferenceSessionForNumpy,
    _InferenceSession,
)
from .evaluator import ExtendedReferenceEvaluator

PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]


class OnnxruntimeEvaluator:
    """
    This class loads an ONNX model and then executes its nodes one by one
    with onnxruntime. It is mostly meant for debugging.

    :param proto: proto or filename
    :param session_options: session options
    :param providers: ``None``, ``"CPU"``, ``"CUDA"`` or a list of providers
    :param nvtx: enable nvidia events
    :param enable_profiling: see :class:`onnxruntime.SessionOptions`
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use onnxruntime-training API
    :param verbose: verbosity
    :param local_functions: additional local functions
    :param ir_version: ir version to use when unknown
    :param opsets: opsets to use when unknown
    :param whole: if True, do not split the model node by node
    """

    def __init__(
        self,
        proto: Union[str, Proto, "OnnxruntimeEvaluator"],
        session_options: Optional[onnxruntime.SessionOptions] = None,
        providers: Optional[Union[str, List[str]]] = None,
        nvtx: bool = False,
        enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
        log_severity_level: Optional[int] = None,
        log_verbosity_level: Optional[int] = None,
        optimized_model_filepath: Optional[str] = None,
        disable_aot_function_inlining: Optional[bool] = None,
        use_training_api: bool = False,
        verbose: int = 0,
        local_functions: Optional[
            Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
        ] = None,
        ir_version: int = 10,
        opsets: Optional[Union[int, Dict[str, int]]] = None,
        whole: bool = False,
    ):
        if isinstance(proto, str):
            self.proto: Proto = load(proto)
        elif isinstance(proto, OnnxruntimeEvaluator):
            assert isinstance(
                proto.proto, PROTO
            ), f"Unexpected type for proto.proto {type(proto.proto)}"
            self.proto = proto.proto
        else:
            self.proto = proto
        assert isinstance(
            self.proto, PROTO
        ), f"Unexpected type for self.proto {type(self.proto)}"
        self._cache: Dict[
            Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
        ] = {}
        self.ir_version = ir_version
        self.opsets = opsets
        self.session_kwargs: Dict[str, Any] = dict(
            session_options=session_options,
            providers=providers,
            nvtx=nvtx,
            enable_profiling=enable_profiling,
            graph_optimization_level=graph_optimization_level,
            log_severity_level=log_severity_level,
            log_verbosity_level=log_verbosity_level,
            optimized_model_filepath=optimized_model_filepath,
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )
        self.verbose = verbose
        self.sess_: Optional[_InferenceSession] = None
        if whole:
            self.nodes: Optional[List[NodeProto]] = None
            self.rt_inits_: Optional[Dict[str, Any]] = None
            self.rt_nodes_: Optional[List[NodeProto]] = None
        else:
            self.nodes = (
                [self.proto]
                if isinstance(self.proto, NodeProto)
                else (
                    list(
                        self.proto.graph.node
                        if hasattr(self.proto, "graph")
                        else self.proto.node
                    )
                )
            )
            self.rt_inits_ = (
                {
                    init.name: to_array_extended(init)
                    for init in self.proto.graph.initializer
                }
                if hasattr(self.proto, "graph")
                else {}
            )
            self.rt_nodes_ = self.nodes.copy()
        self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = (  # noqa: UP037
            {(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
            if hasattr(self.proto, "functions")
            else {}
        )
        if local_functions:
            self.local_functions.update(local_functions)
        self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}

    @property
    def input_names(self) -> List[str]:
        "Returns input names."
        assert self.proto, "self.proto is empty"
        if isinstance(self.proto, NodeProto):
            assert isinstance(
                self.nodes, list
            ), f"Unexpected type {type(self.nodes)} for self.nodes"
            return self.nodes[0].input
        return [
            getattr(o, "name", o)
            for o in (
                self.proto.graph.input if hasattr(self.proto, "graph") else self.proto.input
            )
        ]

    @property
    def output_names(self) -> List[str]:
        "Returns output names."
        assert self.proto, "self.proto is empty"
        if isinstance(self.proto, NodeProto):
            assert isinstance(
                self.nodes, list
            ), f"Unexpected type {type(self.nodes)} for self.nodes"
            return self.nodes[0].output
        return [
            getattr(o, "name", o)
            for o in (
                self.proto.graph.output if hasattr(self.proto, "graph") else self.proto.output
            )
        ]

    @property
    def input_types(self) -> List[TypeProto]:
        "Returns input types."
        if not isinstance(self.proto, (ModelProto, GraphProto)):
            raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
        g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        return [i.type for i in g.input]

    @property
    def output_types(self) -> List[TypeProto]:
        "Returns output types."
        if not isinstance(self.proto, (ModelProto, GraphProto)):
            raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
        g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
        return [i.type for i in g.output]

    def _log_arg(self, a: Any) -> Any:
        if isinstance(a, (str, int, float)):
            return a
        device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
        if hasattr(a, "shape"):
            if self.verbose < 4:  # noqa: PLR2004
                return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
            elements = a.ravel().tolist()
            if len(elements) > 10:  # noqa: PLR2004
                elements = elements[:10]
                return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
            return f"{device}{a.dtype}:{a.shape}:{elements}"
        if hasattr(a, "append"):
            return ", ".join(map(self._log_arg, a))
        return a

    def _log(self, level: int, pattern: str, *args: Any) -> None:
        if level < self.verbose:
            new_args = [self._log_arg(a) for a in args]
            print(pattern % tuple(new_args))

    def _is_local_function(self, node: NodeProto) -> bool:
        return (node.domain, node.op_type) in self.local_functions
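    # A hedged usage sketch (``model`` is a hypothetical ModelProto whose only
    # input is a float tensor named "X"); it shows the intended entry points,
    # the constructor above and :meth:`run` below:
    #
    #     ev = OnnxruntimeEvaluator(model, verbose=2)
    #     feeds = {"X": np.random.rand(2, 3).astype(np.float32)}
    #     last = ev.run(None, feeds)                      # list of final outputs
    #     every = ev.run(None, feeds, intermediate=True)  # dict of all results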
    def run(
        self,
        outputs: Optional[List[str]],
        feed_inputs: Dict[str, Any],
        intermediate: bool = False,
    ) -> Union[Dict[str, Any], List[Any]]:
        """
        Runs the model. It works with numpy arrays as well as torch tensors.

        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :param intermediate: returns all outputs instead of the last ones
        :return: outputs, as a list if *intermediate* is False,
            as a dictionary if *intermediate* is True
        """
        if self.rt_nodes_ is None:
            # runs the model as a whole
            if self.sess_ is None:
                assert self.proto, "self.proto is empty"
                _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
            assert self.sess_, "mypy not happy"
            return self.sess_.run(outputs, feed_inputs)

        if outputs is None:
            outputs = self.output_names
        results: Dict[str, Any] = (self.rt_inits_ or {}).copy()
        for k, v in results.items():
            self._log(2, " +C %s: %s", k, v)
        for k, v in feed_inputs.items():
            assert not isinstance(v, str), f"Unexpected type str for {k!r}"
            self._log(2, " +I %s: %s", k, v)
            results[k] = v

        for i_node, node in enumerate(self.rt_nodes_ or []):
            self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
            for i in node.input:
                if i != "" and i not in results:
                    raise RuntimeError(
                        f"Unable to find input {i!r} in known results {sorted(results)}, "
                        f"self.rt_inits_ has {sorted((self.rt_inits_ or {}))}, "
                        f"feed_inputs has {sorted(feed_inputs)}."
                    )
            inputs = [(results[i] if i != "" else None) for i in node.input]
            if node.op_type == "If" and node.domain == "":
                node_outputs = self._run_if(node, inputs, results)
            elif node.op_type in {"Scan", "Loop"} and node.domain == "":
                node_outputs = self._run_scan(node, inputs, results)
            elif self._is_local_function(node):
                node_outputs = self._run_local(node, inputs, results)
            else:
                node_outputs = self._run(node, inputs, results)
            for name, value in zip(node.output, node_outputs):
                if name == "":
                    continue
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                assert isinstance(name, str), f"unexpected type for name {type(name)}"
                results[name] = value
            if not intermediate:
                self._clean_unused_inplace(i_node, node, results)

        if intermediate:
            return results
        for name in outputs:
            if name == "":
                continue
            if name not in results:
                raise RuntimeError(
                    f"Unable to find output name {name!r} "
                    f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
                )
        return [results[name] for name in outputs if name != ""]
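    # Debugging sketch (assumed workflow, not part of the API): with
    # ``intermediate=True``, :meth:`run` returns every computed result, which
    # makes it easy to diff this evaluator against another runtime.
    # ``ref_results`` (name -> expected array) and ``model`` are hypothetical:
    #
    #     got = OnnxruntimeEvaluator(model).run(None, feeds, intermediate=True)
    #     for name, expected in ref_results.items():
    #         if name in got:
    #             np.testing.assert_allclose(expected, got[name], rtol=1e-4)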
    def _build_garbage_collector(self) -> Dict[str, int]:
        """
        Memorizes when results are no longer needed.
        Returns a dictionary mapping every result name to the index
        of the last node using it.
        """
        needed = {}
        for i, node in enumerate(self.rt_nodes_ or []):
            for name in node.input:
                needed[name] = i
            if node.op_type in {"Scan", "If", "Loop"}:
                hidden = self._get_hidden_node_inputs(node)
                for name in hidden:
                    needed[name] = i
        if isinstance(self.proto, ModelProto):
            for o in self.proto.graph.output:
                needed[o.name] = len(self.rt_nodes_ or [])
        elif isinstance(self.proto, GraphProto):
            for o in self.proto.output:
                needed[o.name] = len(self.rt_nodes_ or [])
        elif isinstance(self.proto, FunctionProto):
            for o in self.proto.output:
                needed[o] = len(self.rt_nodes_ or [])
        return needed

    def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
        """
        Cleans all results not needed anymore.
        Some models require freeing memory to be able to run.
        """
        if not self.garbage_collector:
            return
        for name in node.input:
            if self.garbage_collector[name] == i_node and name in results:
                if self.verbose:
                    t = results[name]
                    print(f" - deletes: {name} - {t.dtype}:{t.shape}")
                del results[name]
        if node.op_type in {"Scan", "If", "Loop"}:
            hidden = self._get_hidden_node_inputs(node)
            for name in hidden:
                if self.garbage_collector[name] == i_node and name in results:
                    if self.verbose:
                        t = results[name]
                        print(f" - deletes: {name} - {t.dtype}:{t.shape}")
                    del results[name]

    def _make_model_proto(
        self,
        nodes: Sequence[NodeProto],
        vinputs: Sequence[ValueInfoProto],
        voutputs: Sequence[ValueInfoProto],
    ) -> ModelProto:
        onx = oh.make_model(
            oh.make_graph(nodes, "-", vinputs, voutputs),
            ir_version=getattr(self.proto, "ir_version", self.ir_version),
            functions=getattr(self.proto, "functions", None),
        )
        del onx.opset_import[:]
        if hasattr(self.proto, "opset_import"):
            onx.opset_import.extend(self.proto.opset_import)
        elif self.opsets:
            if isinstance(self.opsets, int):
                onx.opset_import.append(oh.make_opsetid("", self.opsets))
            else:
                onx.opset_import.extend(
                    [oh.make_opsetid(k, v) for k, v in self.opsets.items()]
                )
        else:
            onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
        # Running shape inference here helps detecting bugs early.
        onx = shi.infer_shapes(onx)
        return onx

    @classmethod
    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
        """
        Returns the hidden inputs (inputs coming from an upper context)
        used by a subgraph.
        """
        hidden = set()
        memo = set(i.name for i in graph.initializer)
        memo |= set(i.name for i in graph.sparse_initializer)
        for node in graph.node:
            for i in node.input:
                if i not in memo:
                    hidden.add(i)
            for att in node.attribute:
                if att.type == AttributeProto.GRAPH and att.g:
                    hid = cls._get_hidden_inputs(att.g)
                    less = set(h for h in hid if h not in memo)
                    hidden |= less
            memo |= set(node.output)
        return hidden

    @classmethod
    def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
        """Calls :meth:`_get_hidden_inputs` on every graph attribute."""
        if node.op_type not in {"Loop", "Scan", "If"}:
            return set()
        hidden = set()
        for att in node.attribute:
            if att.type == AttributeProto.GRAPH:
                hidden |= cls._get_hidden_inputs(att.g)
        return hidden - (hidden & set(node.input))

    def _get_sess(
        self, node: Union[ModelProto, NodeProto], inputs: List[Any]
    ) -> Tuple[ModelProto, _InferenceSession]:
        if isinstance(node, ModelProto):
            onx = node
        else:
            assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
            if node.op_type == "Constant":
                # Evaluates the constant to recover its exact type and shape.
                ref = ExtendedReferenceEvaluator(node)
                cst = ref.run(None, {})[0]
                vinputs: List[ValueInfoProto] = []
                voutputs = [
                    oh.make_tensor_value_info(
                        node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                    )
                ]
            else:
                unique_names = set()
                vinputs = []
                for i, it in zip(node.input, inputs):
                    if i == "" or i in unique_names:
                        continue
                    unique_names.add(i)
                    value = oh.make_tensor_value_info(
                        i, dtype_to_tensor_dtype(it.dtype), it.shape
                    )
                    vinputs.append(value)
                # no need to run shape inference
                voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
            onx = self._make_model_proto([node], vinputs, voutputs)

        cls = (
            InferenceSessionForNumpy
            if any(isinstance(i, np.ndarray) for i in inputs)
            else InferenceSessionForTorch
        )
        try:
            sess = cls(onx, **self.session_kwargs)
        except (
            onnxruntime.capi.onnxruntime_pybind11_state.Fail,
            onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
            onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
        ) as e:
            onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
            raise RuntimeError(
                f"Unable to infer a session with inputs\n{string_type(inputs)}"
                f"\ndue to {e}\n{pretty_onnx(onx)}"
            ) from e
        return onx, sess

    def _get_sess_init_subgraph(
        self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
    ) -> List[ValueInfoProto]:
        unique_names = set()
        vinputs = []
        for i, it in zip(node.input, inputs):
            if i == "" or i in unique_names:
                continue
            unique_names.add(i)
            value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
            vinputs.append(value)
        reduced_set = self._get_hidden_inputs(g)
        for i, v in context.items():
            if i in reduced_set and i not in unique_names:
                unique_names.add(i)
                value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
                vinputs.append(value)
        return vinputs

    def _get_sess_if(
        self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
    ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
        g = None
        for att in node.attribute:
            if att.name == branch:
                g = att.g
        assert g, f"Missing attribute {branch!r}"
        vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
        voutputs = g.output
        identities = [
            oh.make_node("Identity", [iname], [ginput.name])
            for iname, ginput in zip(node.input, g.input)
        ]
        onx = self._make_model_proto([*identities, *g.node], vinputs, voutputs)
        sess = OnnxruntimeEvaluator(
            onx,
            local_functions=self.local_functions,
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
            **self.session_kwargs,
        )
        return onx, sess

    def _get_sess_local(
        self, node: NodeProto, inputs: List[Any]
    ) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
        ev = self.local_functions[node.domain, node.op_type]
        sess = OnnxruntimeEvaluator(
            ev,
            local_functions=self.local_functions,
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
            **self.session_kwargs,
        )
        return ev.proto, sess

    def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
        """Runs a node."""
        # The cache key combines the node identity with the input types and shapes.
        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess(node, inputs)
            self._cache[key] = onx, sess
        feeds = dict(zip(node.input, inputs))
        if "" in feeds:
            feeds[""] = np.array([0], dtype=np.float32)
        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        outputs = list(sess.run(None, feeds))
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs

    def _run_if(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs an If node."""
        feeds = dict(zip(node.input, inputs))
        feeds.update(results)
        if feeds[node.input[0]]:
            name = "then_branch"
        else:
            name = "else_branch"
        key = (id(node), name)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess_if(node, name, inputs, results)
            self._cache[key] = onx, sess
        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        feeds = {name: results[name] for name in sess.input_names}
        outputs = sess.run(None, feeds)
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs

    def _get_sess_scan(
        self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
    ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
        g = None
        for att in node.attribute:
            if att.name == branch:
                g = att.g
        assert g, f"Missing attribute {branch!r}"
        vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
        begin = 0 if node.op_type == "Scan" else 1
        voutputs = []
        for name, _goutput in zip(node.output, g.output[begin:]):
            v = ValueInfoProto()
            # v.ParseFromString(goutput.SerializeToString())
            v.name = name
            voutputs.append(v)
        # identities = []
        # for iname, ginput in zip(node.input, g.input):
        #     identities.append(oh.make_node("Identity", [iname], [ginput.name]))
        onx = self._make_model_proto([node], vinputs, voutputs)
        sess = OnnxruntimeEvaluator(
            onx,
            local_functions=self.local_functions,
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
            whole=True,
            **self.session_kwargs,
        )
        return onx, sess

    def _run_scan(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs a Scan or Loop node."""
        feeds = dict(zip(node.input, inputs))
        feeds.update(results)
        name = "body"
        key = (id(node), name)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess_scan(node, name, inputs, results)
            self._cache[key] = onx, sess
        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        feeds = {name: results[name] for name in sess.input_names}
        outputs = sess.run(None, feeds)
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs

    def _run_local(
        self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
    ) -> List[Any]:
        """Runs a node calling a local function."""
        types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
        key = (id(node), *types)
        if key in self._cache:
            sess = self._cache[key][1]
        else:
            onx, sess = self._get_sess_local(node, inputs)
            self._cache[key] = onx, sess
        replace = dict(zip(node.input, sess.input_names))
        assert len(node.input) == len(sess.input_names), (
            f"Input mismatch: input_names={sess.input_names}, "
            f"replace={replace}, "
            f"type(self.proto)={type(self.proto)}, and node=\n{node}"
        )
        feeds = {replace[i]: v for i, v in zip(node.input, inputs)}
        if "" in feeds:
            feeds[""] = np.array([0], dtype=np.float32)
        assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
        outputs = sess.run(None, feeds)
        assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
        return outputs
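

if __name__ == "__main__":
    # A minimal, hedged end-to-end sketch (not part of the module API): build a
    # tiny model with onnx.helper and evaluate it node by node.  The names "X",
    # "Y", "two_x" and the Add/Abs nodes are illustrative only; opset 18 is an
    # arbitrary choice supported by recent onnxruntime releases.
    from onnx import TensorProto

    _model = oh.make_model(
        oh.make_graph(
            [
                oh.make_node("Add", ["X", "X"], ["two_x"]),
                oh.make_node("Abs", ["two_x"], ["Y"]),
            ],
            "demo",
            [oh.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])],
            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
    )
    _ev = OnnxruntimeEvaluator(_model, verbose=1)
    # Prints the final outputs; intermediate=True would return "two_x" as well.
    print(_ev.run(None, {"X": np.random.rand(2, 3).astype(np.float32)}))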