from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
from onnx import (
GraphProto,
FunctionProto,
ModelProto,
NodeProto,
TypeProto,
ValueInfoProto,
helper as oh,
load,
)
from onnx.defs import onnx_opset_version
import onnxruntime
from ..helpers import string_type
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
from ..helpers.ort_session import (
InferenceSessionForTorch,
InferenceSessionForNumpy,
_InferenceSession,
)
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
class OnnxruntimeEvaluator:
"""
    This class loads an ONNX model and then executes its nodes one by one
    with onnxruntime. It is mostly meant for debugging.

    :param proto: proto or filename
    :param session_options: onnxruntime session options
    :param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
    :param nvtx: enable NVTX (NVIDIA) profiling events
    :param enable_profiling: enable onnxruntime profiling
    :param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
    :param log_severity_level: see :class:`onnxruntime.SessionOptions`
    :param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
    :param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
    :param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
    :param use_training_api: use the onnxruntime-training API
    :param verbose: verbosity level
    :param local_functions: additional local functions
    :param ir_version: IR version to use when it cannot be guessed
    :param opsets: opsets to use when they cannot be guessed
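
    A minimal usage sketch, assuming a small model built with
    :func:`onnx.helper.make_model` (the one-node graph and the names
    `X` and `Y` below are illustrative only)::

        import numpy as np
        from onnx import TensorProto, helper as oh

        # illustrative one-node model: Y = -X
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("Neg", ["X"], ["Y"])],
                "g",
                [oh.make_tensor_value_info("X", TensorProto.FLOAT, [2])],
                [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
        )
        ev = OnnxruntimeEvaluator(model, verbose=1)
        got = ev.run(None, {"X": np.array([1.5, -2.0], dtype=np.float32)})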
"""
def __init__(
self,
proto: Union[str, Proto, "OnnxruntimeEvaluator"],
session_options: Optional[onnxruntime.SessionOptions] = None,
providers: Optional[Union[str, List[str]]] = None,
nvtx: bool = False,
enable_profiling: bool = False,
        graph_optimization_level: Optional[
            Union[onnxruntime.GraphOptimizationLevel, bool]
        ] = None,
log_severity_level: Optional[int] = None,
log_verbosity_level: Optional[int] = None,
optimized_model_filepath: Optional[str] = None,
disable_aot_function_inlining: Optional[bool] = None,
use_training_api: bool = False,
verbose: int = 0,
local_functions: Optional[
Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
] = None,
ir_version: int = 10,
opsets: Optional[Union[int, Dict[str, int]]] = None,
):
if isinstance(proto, str):
self.proto: Proto = load(proto)
elif isinstance(proto, OnnxruntimeEvaluator):
assert isinstance(
proto.proto, PROTO
), f"Unexpected type for proto.proto {type(proto.proto)}"
self.proto = proto.proto
else:
self.proto = proto
assert isinstance(
self.proto, PROTO
), f"Unexpected type for self.proto {type(self.proto)}"
self._cache: Dict[
Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]] # noqa: UP037
] = {}
self.ir_version = ir_version
self.opsets = opsets
self.session_kwargs: Dict[str, Any] = dict(
session_options=session_options,
providers=providers,
nvtx=nvtx,
enable_profiling=enable_profiling,
graph_optimization_level=graph_optimization_level,
log_severity_level=log_severity_level,
log_verbosity_level=log_verbosity_level,
optimized_model_filepath=optimized_model_filepath,
disable_aot_function_inlining=disable_aot_function_inlining,
use_training_api=use_training_api,
)
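        # Normalize the list of nodes to execute: a NodeProto runs alone,
        # a ModelProto contributes its graph's nodes, and a GraphProto or
        # FunctionProto contributes its own nodes.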
self.nodes = (
[self.proto]
if isinstance(self.proto, NodeProto)
else (
list(
self.proto.graph.node if hasattr(self.proto, "graph") else self.proto.node
)
)
)
self.rt_inits_ = (
{init.name: to_array_extended(init) for init in self.proto.graph.initializer}
if hasattr(self.proto, "graph")
else {}
)
self.rt_nodes_ = self.nodes.copy()
self.verbose = verbose
self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = ( # noqa: UP037
{(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
if hasattr(self.proto, "functions")
else {}
)
if local_functions:
self.local_functions.update(local_functions)
@property
def input_names(self) -> List[str]:
"Returns input names."
if isinstance(self.proto, NodeProto):
            return list(self.nodes[0].input)
return [
getattr(o, "name", o)
for o in (
self.proto.graph.input if hasattr(self.proto, "graph") else self.proto.input
)
]
@property
def output_names(self) -> List[str]:
"Returns output names."
if isinstance(self.proto, NodeProto):
            return list(self.nodes[0].output)
return [
getattr(o, "name", o)
for o in (
self.proto.graph.output if hasattr(self.proto, "graph") else self.proto.output
)
]
@property
def input_types(self) -> List[TypeProto]:
"Returns input types."
if not isinstance(self.proto, (ModelProto, GraphProto)):
raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
return [i.type for i in g.input]
@property
def output_types(self) -> List[TypeProto]:
"Returns output types."
if not isinstance(self.proto, (ModelProto, GraphProto)):
raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
return [i.type for i in g.output]
def _log_arg(self, a: Any) -> Any:
if isinstance(a, (str, int, float)):
return a
device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
if hasattr(a, "shape"):
if self.verbose < 4: # noqa: PLR2004
return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
elements = a.ravel().tolist()
if len(elements) > 10: # noqa: PLR2004
elements = elements[:10]
return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
return f"{device}{a.dtype}:{a.shape}:{elements}"
if hasattr(a, "append"):
return ", ".join(map(self._log_arg, a))
return a
def _log(self, level: int, pattern: str, *args: Any) -> None:
if level < self.verbose:
new_args = [self._log_arg(a) for a in args]
print(pattern % tuple(new_args))
def _is_local_function(self, node: NodeProto) -> bool:
return (node.domain, node.op_type) in self.local_functions
def run(
self,
outputs: Optional[List[str]],
feed_inputs: Dict[str, Any],
intermediate: bool = False,
) -> Union[Dict[str, Any], List[Any]]:
"""
Runs the model.
It only works with numpy arrays.
        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :param intermediate: if True, returns every intermediate result
            instead of only the model outputs
        :return: the outputs as a list if *intermediate* is False,
            or a dictionary of all computed results if *intermediate* is True
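
        Example (`ev` is an instance of this class and `x` a numpy array
        matching the model input `X`; illustrative names only)::

            got = ev.run(None, {"X": x})                       # list of outputs
            everything = ev.run(None, {"X": x}, intermediate=True)  # dict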
"""
if outputs is None:
outputs = self.output_names
results: Dict[str, Any] = self.rt_inits_.copy()
for k, v in self.rt_inits_.items():
self._log(2, " +C %s: %s", k, v)
for k, v in feed_inputs.items():
self._log(2, " +I %s: %s", k, v)
results[k] = v
for node in self.rt_nodes_:
self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
for i in node.input:
if i != "" and i not in results:
raise RuntimeError(
f"Unable to find input {i!r} in known results {sorted(results)}, "
f"self.rt_inits_ has {sorted(self.rt_inits_)}, "
f"feed_inputs has {sorted(feed_inputs)}."
)
inputs = [(results[i] if i != "" else None) for i in node.input]
            if node.op_type == "If" and node.domain == "":
                node_outputs = self._run_if(node, inputs, results)
            elif self._is_local_function(node):
                node_outputs = self._run_local(node, inputs, results)
            else:
                node_outputs = self._run(node, inputs, results)
            for name, value in zip(node.output, node_outputs):
if name == "":
continue
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
assert isinstance(name, str), f"unexpected type for name {type(name)}"
results[name] = value
if intermediate:
return results
        output_names = outputs
for name in output_names:
if name == "":
continue
if name not in results:
raise RuntimeError(
f"Unable to find output name {name!r} "
f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
)
return [results[name] for name in output_names if name != ""]
def _make_model_proto(
self,
nodes: Sequence[NodeProto],
vinputs: Sequence[ValueInfoProto],
voutputs: Sequence[ValueInfoProto],
) -> ModelProto:
onx = oh.make_model(
oh.make_graph(nodes, "-", vinputs, voutputs),
ir_version=getattr(self.proto, "ir_version", self.ir_version),
functions=getattr(self.proto, "functions", None),
)
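        # make_model adds a default opset import; replace it with the opsets
        # of the wrapped proto when present, else with self.opsets, else with
        # the opset version of the installed onnx package.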
del onx.opset_import[:]
if hasattr(self.proto, "opset_import"):
onx.opset_import.extend(self.proto.opset_import)
elif self.opsets:
if isinstance(self.opsets, int):
onx.opset_import.append(oh.make_opsetid("", self.opsets))
else:
onx.opset_import.extend(
[oh.make_opsetid(k, v) for k, v in self.opsets.items()]
)
else:
onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
return onx
def _get_sess(
self, node: NodeProto, inputs: List[Any]
) -> Tuple[ModelProto, _InferenceSession]:
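        # Build a one-node model around `node`: inputs are typed after the
        # runtime values, outputs carry an empty TypeProto so onnxruntime
        # infers their types. A numpy-based session is picked when any input
        # is a numpy array, otherwise a torch-based session is used.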
unique_names = set()
vinputs = []
for i, it in zip(node.input, inputs):
if i == "" or i in unique_names:
continue
unique_names.add(i)
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
vinputs.append(value)
# no need to run shape inference
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
onx = self._make_model_proto([node], vinputs, voutputs)
cls = (
InferenceSessionForNumpy
if any(isinstance(i, np.ndarray) for i in inputs)
else InferenceSessionForTorch
)
try:
sess = cls(onx, **self.session_kwargs)
except (
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
) as e:
raise RuntimeError(
f"Unable to infer a session with inputs\n{string_type(inputs)}"
f"\ndue to {e}\n{pretty_onnx(onx)}"
) from e
return onx, sess
def _get_sess_if(
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
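        # The subgraph of an If node may use values captured from the outer
        # scope, so every known result in `context` is promoted to an explicit
        # input of the extracted branch model.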
unique_names = set()
vinputs = []
for i, it in zip(node.input, inputs):
if i == "" or i in unique_names:
continue
unique_names.add(i)
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
vinputs.append(value)
for i, v in context.items():
if i not in unique_names:
unique_names.add(i)
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
vinputs.append(value)
for att in node.attribute:
if att.name == branch:
g = att.g
voutputs = g.output
onx = self._make_model_proto(g.node, vinputs, voutputs)
sess = OnnxruntimeEvaluator(
onx,
local_functions=self.local_functions,
verbose=self.verbose,
ir_version=self.ir_version,
opsets=self.opsets,
**self.session_kwargs,
)
return onx, sess
def _get_sess_local(
self, node: NodeProto, inputs: List[Any]
) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
ev = self.local_functions[node.domain, node.op_type]
sess = OnnxruntimeEvaluator(
ev,
local_functions=self.local_functions,
verbose=self.verbose,
ir_version=self.ir_version,
opsets=self.opsets,
**self.session_kwargs,
)
return ev.proto, sess
def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
"""Runs a node."""
types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
key = (id(node), *types)
if key in self._cache:
sess = self._cache[key][1]
else:
onx, sess = self._get_sess(node, inputs)
self._cache[key] = onx, sess
feeds = dict(zip(node.input, inputs))
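        # An empty input name denotes an unset optional input; replace its
        # None placeholder with a small dummy tensor so the feeds stay valid.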
if "" in feeds:
feeds[""] = np.array([0], dtype=np.float32)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
outputs = list(sess.run(None, feeds))
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs
def _run_if(
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
"""Runs a node if."""
feeds = dict(zip(node.input, inputs))
feeds.update(results)
if feeds[node.input[0]]:
name = "then_branch"
else:
name = "else_branch"
key = (id(node), name)
if key in self._cache:
sess = self._cache[key][1]
else:
            onx, sess = self._get_sess_if(node, name, inputs, results)
            self._cache[key] = onx, sess
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
outputs = sess.run(None, feeds)
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs
def _run_local(
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
"""Runs a node."""
types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
key = (id(node), *types)
if key in self._cache:
sess = self._cache[key][1]
else:
onx, sess = self._get_sess_local(node, inputs)
self._cache[key] = onx, sess
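        # A function's formal input names differ from the call-site names:
        # remap the call-site inputs onto the function's own input names.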
replace = dict(zip(node.input, sess.input_names))
assert len(node.input) == len(sess.input_names), (
f"Input mismatch: input_names={sess.input_names}, "
f"replace={replace}, "
f"type(self.proto)={type(self.proto)}, and node=\n{node}"
)
feeds = {replace[i]: v for i, v in zip(node.input, inputs)}
if "" in feeds:
feeds[""] = np.array([0], dtype=np.float32)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
outputs = sess.run(None, feeds)
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs