from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
from onnx import (
AttributeProto,
GraphProto,
FunctionProto,
ModelProto,
NodeProto,
TypeProto,
ValueInfoProto,
helper as oh,
load,
save as onnx_save,
shape_inference as shi,
)
from onnx.defs import onnx_opset_version
import onnxruntime
from ..helpers import string_type
from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
from ..helpers.ort_session import (
InferenceSessionForTorch,
InferenceSessionForNumpy,
_InferenceSession,
)
from .evaluator import ExtendedReferenceEvaluator
PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
[docs]
class OnnxruntimeEvaluator:
"""
This class loads an onnx model and the executes one by one the nodes
with onnxruntime. This class is mostly meant for debugging.
:param proto: proto or filename
:param session_options: options
:param providers: providers
:param nvtx: enable nvidia events
:param providers: `None`, `"CPU"`, `"CUDA"` or a list of providers
:param graph_optimization_level: see :class:`onnxruntime.SessionOptions`
:param log_severity_level: see :class:`onnxruntime.SessionOptions`
:param log_verbosity_level: see :class:`onnxruntime.SessionOptions`
:param optimized_model_filepath: see :class:`onnxruntime.SessionOptions`
:param disable_aot_function_inlining: see :class:`onnxruntime.SessionOptions`
:param use_training_api: use onnxruntime-traning API
:param verbose: verbosity
:param local_functions: additional local function
:param ir_version: ir version to use when unknown
:param opsets: opsets to use when unknown
:param whole: if True, do not split node by node
"""
def __init__(
self,
proto: Union[str, Proto, "OnnxruntimeEvaluator"],
session_options: Optional[onnxruntime.SessionOptions] = None,
providers: Optional[Union[str, List[str]]] = None,
nvtx: bool = False,
enable_profiling: bool = False,
graph_optimization_level: Union[onnxruntime.GraphOptimizationLevel, bool] = None,
log_severity_level: Optional[int] = None,
log_verbosity_level: Optional[int] = None,
optimized_model_filepath: Optional[str] = None,
disable_aot_function_inlining: Optional[bool] = None,
use_training_api: bool = False,
verbose: int = 0,
local_functions: Optional[
Dict[Tuple[str, str], Union[Proto, "OnnxruntimeEvaluator"]]
] = None,
ir_version: int = 10,
opsets: Optional[Union[int, Dict[str, int]]] = None,
whole: bool = False,
):
if isinstance(proto, str):
self.proto: Proto = load(proto)
elif isinstance(proto, OnnxruntimeEvaluator):
assert isinstance(
proto.proto, PROTO
), f"Unexpected type for proto.proto {type(proto.proto)}"
self.proto = proto.proto
else:
self.proto = proto
assert isinstance(
self.proto, PROTO
), f"Unexpected type for self.proto {type(self.proto)}"
self._cache: Dict[
Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]] # noqa: UP037
] = {}
self.ir_version = ir_version
self.opsets = opsets
self.session_kwargs: Dict[str, Any] = dict(
session_options=session_options,
providers=providers,
nvtx=nvtx,
enable_profiling=enable_profiling,
graph_optimization_level=graph_optimization_level,
log_severity_level=log_severity_level,
log_verbosity_level=log_verbosity_level,
optimized_model_filepath=optimized_model_filepath,
disable_aot_function_inlining=disable_aot_function_inlining,
use_training_api=use_training_api,
)
self.verbose = verbose
self.sess_: Optional[_InferenceSession] = None
if whole:
self.nodes: Optional[List[NodeProto]] = None
self.rt_inits_: Optional[Dict[str, Any]] = None
self.rt_nodes_: Optional[List[NodeProto]] = None
else:
self.nodes = (
[self.proto]
if isinstance(self.proto, NodeProto)
else (
list(
self.proto.graph.node
if hasattr(self.proto, "graph")
else self.proto.node
)
)
)
self.rt_inits_ = (
{init.name: to_array_extended(init) for init in self.proto.graph.initializer}
if hasattr(self.proto, "graph")
else {}
)
self.rt_nodes_ = self.nodes.copy()
self.local_functions: Dict[Tuple[str, str], "OnnxruntimeEvaluator"] = ( # noqa: UP037
{(f.domain, f.name): self.__class__(f) for f in self.proto.functions}
if hasattr(self.proto, "functions")
else {}
)
if local_functions:
self.local_functions.update(local_functions)
self.garbage_collector = self._build_garbage_collector() if self.rt_nodes_ else {}
@property
def input_names(self) -> List[str]:
"Returns input names."
assert self.proto, "self.proto is empty"
if isinstance(self.proto, NodeProto):
assert isinstance(
self.nodes, list
), f"Unexpected type {type(self.nodes)} for self.nodes"
return self.nodes[0].input
return [
getattr(o, "name", o)
for o in (
self.proto.graph.input if hasattr(self.proto, "graph") else self.proto.input
)
]
@property
def output_names(self) -> List[str]:
"Returns output names."
assert self.proto, "self.proto is empty"
if isinstance(self.proto, NodeProto):
assert isinstance(
self.nodes, list
), f"Unexpected type {type(self.nodes)} for self.nodes"
return self.nodes[0].output
return [
getattr(o, "name", o)
for o in (
self.proto.graph.output if hasattr(self.proto, "graph") else self.proto.output
)
]
@property
def input_types(self) -> List[TypeProto]:
"Returns input types."
if not isinstance(self.proto, (ModelProto, GraphProto)):
raise ValueError(f"Cannot guess input types for type {type(self.proto)}")
g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
return [i.type for i in g.input]
@property
def output_types(self) -> List[TypeProto]:
"Returns output types."
if not isinstance(self.proto, (ModelProto, GraphProto)):
raise ValueError(f"Cannot guess output types for type {type(self.proto)}")
g = self.proto.graph if hasattr(self.proto, "graph") else self.proto
return [i.type for i in g.output]
def _log_arg(self, a: Any) -> Any:
if isinstance(a, (str, int, float)):
return a
device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
if hasattr(a, "shape"):
if self.verbose < 4: # noqa: PLR2004
return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
elements = a.ravel().tolist()
if len(elements) > 10: # noqa: PLR2004
elements = elements[:10]
return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
return f"{device}{a.dtype}:{a.shape}:{elements}"
if hasattr(a, "append"):
return ", ".join(map(self._log_arg, a))
return a
def _log(self, level: int, pattern: str, *args: Any) -> None:
if level < self.verbose:
new_args = [self._log_arg(a) for a in args]
print(pattern % tuple(new_args))
def _is_local_function(self, node: NodeProto) -> bool:
return (node.domain, node.op_type) in self.local_functions
[docs]
def run(
self,
outputs: Optional[List[str]],
feed_inputs: Dict[str, Any],
intermediate: bool = False,
) -> Union[Dict[str, Any], List[Any]]:
"""
Runs the model.
It only works with numpy arrays.
:param outputs: required outputs or None for all
:param feed_inputs: inputs
:param intermediate: returns all output instead of the last ones
:return: outputs, as a list if return_all is False,
as a dictionary if return_all is True
"""
if self.rt_nodes_ is None:
# runs a whole
if self.sess_ is None:
assert self.proto, "self.proto is empty"
_, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
assert self.sess_, "mypy not happy"
return self.sess_.run(outputs, feed_inputs)
if outputs is None:
outputs = self.output_names
results: Dict[str, Any] = (self.rt_inits_ or {}).copy()
for k, v in results.items():
self._log(2, " +C %s: %s", k, v)
for k, v in feed_inputs.items():
assert not isinstance(v, str), f"Unexpected type str for {k!r}"
self._log(2, " +I %s: %s", k, v)
results[k] = v
for i_node, node in enumerate(self.rt_nodes_ or []):
self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output)
for i in node.input:
if i != "" and i not in results:
raise RuntimeError(
f"Unable to find input {i!r} in known results {sorted(results)}, "
f"self.rt_inits_ has {sorted((self.rt_inits_ or {}))}, "
f"feed_inputs has {sorted(feed_inputs)}."
)
inputs = [(results[i] if i != "" else None) for i in node.input]
if node.op_type == "If" and node.domain == "":
outputs = self._run_if(node, inputs, results)
elif node.op_type in {"Scan", "Loop"} and node.domain == "":
outputs = self._run_scan(node, inputs, results)
elif self._is_local_function(node):
outputs = self._run_local(node, inputs, results)
else:
outputs = self._run(node, inputs, results)
for name, value in zip(node.output, outputs):
if name == "":
continue
self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
assert isinstance(name, str), f"unexpected type for name {type(name)}"
results[name] = value
if not intermediate:
self._clean_unused_inplace(i_node, node, results)
if intermediate:
return results
output_names = self.output_names
for name in output_names:
if name == "":
continue
if name not in results:
raise RuntimeError(
f"Unable to find output name {name!r} "
f"in {sorted(results)}, proto is\n{pretty_onnx(self.proto)}"
)
return [results[name] for name in output_names if name != ""]
def _build_garbage_collector(self) -> Dict[str, int]:
"""
Memorizes the results not needed anymore for every node.
Returns a dictionary with the last node using the results.
"""
needed = {}
for i, node in enumerate(self.rt_nodes_ or []):
for name in node.input:
needed[name] = i
if node.op_type in {"Scan", "If", "Loop"}:
hidden = self._get_hidden_node_inputs(node)
for name in hidden:
needed[name] = i
if isinstance(self.proto, ModelProto):
for o in self.proto.graph.output:
needed[o.name] = len(self.rt_nodes_ or [])
elif isinstance(self.proto, GraphProto):
for o in self.proto.output:
needed[o.name] = len(self.rt_nodes_ or [])
elif isinstance(self.proto, FunctionProto):
for o in self.proto.output:
needed[o] = len(self.rt_nodes_ or [])
return needed
def _clean_unused_inplace(self, i_node: int, node: NodeProto, results: Dict[str, Any]):
"""
Cleans all results not needed anymore. Some models requires to clean the memory
to be able to run.
"""
if not self.garbage_collector:
return
for name in node.input:
if self.garbage_collector[name] == i_node and name in results:
if self.verbose:
t = results[name]
print(f" - deletes: {name} - {t.dtype}:{t.shape}")
del results[name]
if node.op_type in {"Scan", "If", "Loop"}:
hidden = self._get_hidden_node_inputs(node)
for name in hidden:
if self.garbage_collector[name] == i_node and name in results:
if self.verbose:
t = results[name]
print(f" - deletes: {name} - {t.dtype}:{t.shape}")
del results[name]
def _make_model_proto(
self,
nodes: Sequence[NodeProto],
vinputs: Sequence[ValueInfoProto],
voutputs: Sequence[ValueInfoProto],
) -> ModelProto:
onx = oh.make_model(
oh.make_graph(nodes, "-", vinputs, voutputs),
ir_version=getattr(self.proto, "ir_version", self.ir_version),
functions=getattr(self.proto, "functions", None),
)
del onx.opset_import[:]
if hasattr(self.proto, "opset_import"):
onx.opset_import.extend(self.proto.opset_import)
elif self.opsets:
if isinstance(self.opsets, int):
onx.opset_import.append(oh.make_opsetid("", self.opsets))
else:
onx.opset_import.extend(
[oh.make_opsetid(k, v) for k, v in self.opsets.items()]
)
else:
onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
# That helps fixing bugs.
onx = shi.infer_shapes(onx)
return onx
@classmethod
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
"""
Returns the hidden inputs (inputs coming from an upper context)
used by a subgraph.
"""
hidden = set()
memo = set(i.name for i in graph.initializer)
memo |= set(i.name for i in graph.sparse_initializer)
for node in graph.node:
for i in node.input:
if i not in memo:
hidden.add(i)
for att in node.attribute:
if att.type == AttributeProto.GRAPH and att.g:
hid = self._get_hidden_inputs(att.g)
less = set(h for h in hid if h not in memo)
hidden |= less
memo |= set(node.output)
return hidden
@classmethod
def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
"""Calls multiple _get_hidden_inputs on every attribute."""
if node.op_type not in {"Loop", "Scan", "If"}:
return set()
hidden = set()
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
hidden |= self._get_hidden_inputs(att.g)
return hidden - (hidden & set(node.input))
def _get_sess(
self, node: Union[ModelProto, NodeProto], inputs: List[Any]
) -> Tuple[ModelProto, _InferenceSession]:
if isinstance(node, ModelProto):
onx = node
else:
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
if node.op_type == "Constant":
# We force the type to be a boolean.
ref = ExtendedReferenceEvaluator(node)
cst = ref.run(None, {})[0]
vinputs: List[ValueInfoProto] = []
voutputs = [
oh.make_tensor_value_info(
node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
)
]
else:
unique_names = set()
vinputs = []
for i, it in zip(node.input, inputs):
if i == "" or i in unique_names:
continue
unique_names.add(i)
value = oh.make_tensor_value_info(
i, dtype_to_tensor_dtype(it.dtype), it.shape
)
vinputs.append(value)
# no need to run shape inference
voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
onx = self._make_model_proto([node], vinputs, voutputs)
cls = (
InferenceSessionForNumpy
if any(isinstance(i, np.ndarray) for i in inputs)
else InferenceSessionForTorch
)
try:
sess = cls(onx, **self.session_kwargs)
except (
onnxruntime.capi.onnxruntime_pybind11_state.Fail,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument,
) as e:
onnx_save(onx, "_debug_OnnxruntimeEvaluator_last_failure.onnx")
raise RuntimeError(
f"Unable to infer a session with inputs\n{string_type(inputs)}"
f"\ndue to {e}\n{pretty_onnx(onx)}"
) from e
return onx, sess
def _get_sess_init_subgraph(
self, node: NodeProto, inputs: List[Any], context: Dict[str, Any], g: GraphProto
) -> List[Any]:
unique_names = set()
vinputs = []
for i, it in zip(node.input, inputs):
if i == "" or i in unique_names:
continue
unique_names.add(i)
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
vinputs.append(value)
reduced_set = self._get_hidden_inputs(g)
for i, v in context.items():
if i in reduced_set and i not in unique_names:
unique_names.add(i)
value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
vinputs.append(value)
return vinputs
def _get_sess_if(
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
g = None
for att in node.attribute:
if att.name == branch:
g = att.g
assert g, f"Missing attribute {branch!r}"
vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
voutputs = g.output
identities = [
oh.make_node("Identity", [iname], [ginput.name])
for iname, ginput in zip(node.input, g.input)
]
onx = self._make_model_proto([*identities, *g.node], vinputs, voutputs)
sess = OnnxruntimeEvaluator(
onx,
local_functions=self.local_functions,
verbose=self.verbose,
ir_version=self.ir_version,
opsets=self.opsets,
**self.session_kwargs,
)
return onx, sess
def _get_sess_local(
self, node: NodeProto, inputs: List[Any]
) -> Tuple[FunctionProto, "OnnxruntimeEvaluator"]:
ev = self.local_functions[node.domain, node.op_type]
sess = OnnxruntimeEvaluator(
ev,
local_functions=self.local_functions,
verbose=self.verbose,
ir_version=self.ir_version,
opsets=self.opsets,
**self.session_kwargs,
)
return ev.proto, sess
def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
"""Runs a node."""
types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
key = (id(node), *types)
if key in self._cache:
sess = self._cache[key][1]
else:
onx, sess = self._get_sess(node, inputs)
self._cache[key] = onx, sess
feeds = dict(zip(node.input, inputs))
if "" in feeds:
feeds[""] = np.array([0], dtype=np.float32)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
outputs = list(sess.run(None, feeds))
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs
def _run_if(
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
"""Runs a node If."""
feeds = dict(zip(node.input, inputs))
feeds.update(results)
if feeds[node.input[0]]:
name = "then_branch"
else:
name = "else_branch"
key = (id(node), name)
if key in self._cache:
sess = self._cache[key][1]
else:
self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
feeds = {name: results[name] for name in sess.input_names}
outputs = sess.run(None, feeds)
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs
def _get_sess_scan(
self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
g = None
for att in node.attribute:
if att.name == branch:
g = att.g
assert g, f"Missing attribute {branch!r}"
vinputs = self._get_sess_init_subgraph(node, inputs, context, g)
begin = 0 if node.op_type == "Scan" else 1
voutputs = []
for name, _goutput in zip(node.output, g.output[begin:]):
v = ValueInfoProto()
# v.ParseFromString(goutput.SerializeToString())
v.name = name
voutputs.append(v)
# identities = []
# for iname, ginput in zip(node.input, g.input):
# identities.append(oh.make_node("Identity", [iname], [ginput.name]))
onx = self._make_model_proto([node], vinputs, voutputs)
sess = OnnxruntimeEvaluator(
onx,
local_functions=self.local_functions,
verbose=self.verbose,
ir_version=self.ir_version,
opsets=self.opsets,
whole=True,
**self.session_kwargs,
)
return onx, sess
def _run_scan(
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
"""Runs a node Scan."""
feeds = dict(zip(node.input, inputs))
feeds.update(results)
name = "body"
key = (id(node), name)
if key in self._cache:
sess = self._cache[key][1]
else:
self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
feeds = {name: results[name] for name in sess.input_names}
outputs = sess.run(None, feeds)
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs
def _run_local(
self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
) -> List[Any]:
"""Runs a node."""
types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
key = (id(node), *types)
if key in self._cache:
sess = self._cache[key][1]
else:
onx, sess = self._get_sess_local(node, inputs)
self._cache[key] = onx, sess
replace = dict(zip(node.input, sess.input_names))
assert len(node.input) == len(sess.input_names), (
f"Input mismatch: input_names={sess.input_names}, "
f"replace={replace}, "
f"type(self.proto)={type(self.proto)}, and node=\n{node}"
)
feeds = {replace[i]: v for i, v in zip(node.input, inputs)}
if "" in feeds:
feeds[""] = np.array([0], dtype=np.float32)
assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
outputs = sess.run(None, feeds)
assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
return outputs