yaourt.reference.evaluator
- class yaourt.reference.evaluator.ExtendedReferenceEvaluator(proto: Any, opsets: Dict[str, int] | None = None, functions: List[ReferenceEvaluator | FunctionProto] | None = None, verbose: int = 0, new_ops: List[type[OpRun]] | None = None, **kwargs: Any)
Extends onnx.reference.ReferenceEvaluator with a richer API and support for versioned operator look-up. The evaluator is a drop-in replacement for onnx.reference.ReferenceEvaluator. It adds:

- Automatic version selection: when multiple versioned implementations of the same operator are provided (e.g. MyOp_13, MyOp_18), the evaluator picks the highest version that does not exceed the opset declared in the model.
- Convenient run shortcut: run(feeds) (a single list argument) is accepted in addition to the standard run(None, feeds) form.
- Function-proto support: onnx.FunctionProto models can be executed directly, with full support for linked attributes and intermediate result inspection.
- Domain-assertion guard: a runtime check verifies that every loaded implementation reports the same op_domain as the node it is serving, which helps catch configuration mistakes early.
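The domain-assertion guard amounts to an equality check between the domain of a node and the op_domain attribute of the class selected to run it. The sketch below is a hypothetical illustration, not yaourt's code: check_domain, DomainMismatchError and the stand-in kernel class are invented names, and it compares domains literally (it does not treat "" and "ai.onnx" as equivalent).

```python
class DomainMismatchError(RuntimeError):
    """Raised when a kernel does not serve the node's domain (hypothetical)."""


def check_domain(node_op_type: str, node_domain: str, impl: type) -> None:
    """Hypothetical guard: the kernel must report the same op_domain as the node."""
    impl_domain = getattr(impl, "op_domain", None)
    if impl_domain != node_domain:
        raise DomainMismatchError(
            f"{impl.__name__} reports op_domain={impl_domain!r} but node "
            f"{node_op_type!r} belongs to domain {node_domain!r}"
        )


# Stand-in kernel class (a plain class, not a real OpRun subclass).
class MyCustomOp:
    op_domain = "my.domain"


check_domain("MyCustomOp", "my.domain", MyCustomOp)  # matching domains: no error
try:
    check_domain("MyCustomOp", "other.domain", MyCustomOp)  # misconfigured
except DomainMismatchError as exc:
    print(exc)
```

Failing at load time, rather than when the node first runs, is what lets this kind of check surface a wrong op_domain before any data flows through the graph.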
default_ops lists the OpRun subclasses that are registered by default. This list is empty in the base class; subclasses or callers can populate it to add domain-specific kernels without requiring every user to pass new_ops explicitly.

Basic usage. Run a model with standard ONNX operators:
```python
import numpy as np
import onnx.helper as oh
import onnx
from yaourt.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT

# A one-node graph computing Z = X + Y with the default ONNX domain, opset 18.
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Add", ["X", "Y"], ["Z"])],
        "add_graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, [None, None]),
            oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

ref = ExtendedReferenceEvaluator(model)
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# Standard calling convention: None requests every output.
(result,) = ref.run(None, {"X": x, "Y": x})
print(result)
```
Convenience run. Pass the inputs as a list, which is zipped with input_names:

```python
import numpy as np
import onnx.helper as oh
import onnx
from yaourt.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Add", ["X", "Y"], ["Z"])],
        "add_graph",
        [
            oh.make_tensor_value_info("X", TFLOAT, [None, None]),
            oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

ref = ExtendedReferenceEvaluator(model)
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# Shortcut: the list is matched positionally with ref.input_names (["X", "Y"]).
(result,) = ref.run([x, x])
print(result)
```
Adding custom operators. Pass extra OpRun subclasses via new_ops:

```python
import numpy as np
import onnx.helper as oh
import onnx
from onnx.reference.op_run import OpRun
from yaourt.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT


class MyCustomOp(OpRun):
    # Must match the domain of the node this kernel serves.
    op_domain = "my.domain"

    def _run(self, X):
        # Kernels return a tuple with one entry per output.
        return (X * 2,)


model = oh.make_model(
    oh.make_graph(
        [oh.make_node("MyCustomOp", ["X"], ["Z"], domain="my.domain")],
        "custom_graph",
        [oh.make_tensor_value_info("X", TFLOAT, [None])],
        [oh.make_tensor_value_info("Z", TFLOAT, [None])],
    ),
    opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("my.domain", 1)],
    ir_version=10,
)

ref = ExtendedReferenceEvaluator(model, new_ops=[MyCustomOp])
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
(result,) = ref.run(None, {"X": x})
print(result)
```
The new_ops list is merged with default_ops; you do not need to re-list operators that are already in the default set.

Versioned operator selection. When multiple implementations of the same operator are provided with a trailing _<version> suffix, the evaluator automatically selects the highest version that does not exceed the opset declared in the model:

```python
from onnx.reference.op_run import OpRun
from yaourt.reference import ExtendedReferenceEvaluator


class MyOp_1(OpRun):
    op_domain = "custom"

    def _run(self, X):
        return (X,)


class MyOp_3(OpRun):
    op_domain = "custom"

    def _run(self, X):
        return (X * 3,)


# With new_ops=[MyOp_1, MyOp_3], only MyOp_1 is used when the model
# declares version 2 for the "custom" domain (version 3 exceeds 2).
```
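The selection rule can be sketched in plain Python. This is a hypothetical helper, not yaourt's filter_ops: select_versions groups classes by (op_domain, base name) after stripping a trailing _<int> suffix and keeps, per group, the highest version that does not exceed the declared opset. It assumes that classes without a suffix are kept unchanged, that a domain with no declared version keeps its highest implementation, and that a group with no eligible version is dropped; the library may behave differently in those edge cases.

```python
import re
from typing import Dict, List, Optional, Tuple

# "<Base>_<int>": the greedy base keeps inner underscores, e.g. "Conv_Transpose_13".
_VERSION_SUFFIX = re.compile(r"^(?P<base>.+)_(?P<version>\d+)$")


def select_versions(ops: List[type], opsets: Dict[str, int]) -> List[type]:
    """Hypothetical sketch of versioned operator selection."""
    kept: List[type] = []
    best: Dict[Tuple[str, str], Tuple[int, type]] = {}
    for cls in ops:
        match = _VERSION_SUFFIX.match(cls.__name__)
        if match is None:
            kept.append(cls)  # not versioned: nothing to choose between
            continue
        domain = getattr(cls, "op_domain", "")
        version = int(match.group("version"))
        limit: Optional[int] = opsets.get(domain)
        if limit is not None and version > limit:
            continue  # newer than the model allows
        key = (domain, match.group("base"))
        if key not in best or version > best[key][0]:
            best[key] = (version, cls)
    kept.extend(cls for _, cls in best.values())
    return kept


# Plain stand-ins with the same names and domain as the OpRun classes above.
class MyOp_1:
    op_domain = "custom"


class MyOp_3:
    op_domain = "custom"


print([c.__name__ for c in select_versions([MyOp_1, MyOp_3], {"custom": 2})])  # ['MyOp_1']
print([c.__name__ for c in select_versions([MyOp_1, MyOp_3], {"custom": 3})])  # ['MyOp_3']
```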
To list the operators the class overloads or adds by default, print default_ops:

```python
import pprint

from yaourt.reference import ExtendedReferenceEvaluator

pprint.pprint(ExtendedReferenceEvaluator.default_ops)
```
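One way a subclass could ship its own kernels is to override the class-level default_ops list, which the constructor then combines with new_ops. The sketch below mimics that pattern with plain stand-in classes. BaseEvaluator, MyEvaluator, merge_ops, DoubleOp and OtherOp are hypothetical names, and the rule that an entry in new_ops replaces a default with the same (op_domain, class name) is an assumption: the documentation only states that the two lists are merged.

```python
from typing import List, Optional


def merge_ops(defaults: List[type], new_ops: Optional[List[type]]) -> List[type]:
    """Hypothetical merge: defaults first, new_ops override same-key entries."""
    merged = {(getattr(c, "op_domain", ""), c.__name__): c for c in defaults}
    for cls in new_ops or []:
        merged[(getattr(cls, "op_domain", ""), cls.__name__)] = cls
    return list(merged.values())  # dicts keep insertion order


class BaseEvaluator:
    default_ops: List[type] = []  # empty in the base class

    def __init__(self, new_ops: Optional[List[type]] = None):
        # Class-level defaults are always included; callers only add to them.
        self.ops = merge_ops(type(self).default_ops, new_ops)


class DoubleOp:  # stand-in kernel
    op_domain = "my.domain"


class OtherOp:  # stand-in kernel
    op_domain = "my.domain"


class MyEvaluator(BaseEvaluator):
    default_ops = [DoubleOp]  # every MyEvaluator gets DoubleOp without new_ops


print([c.__name__ for c in BaseEvaluator().ops])                 # []
print([c.__name__ for c in MyEvaluator(new_ops=[OtherOp]).ops])  # ['DoubleOp', 'OtherOp']
```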
- __init__(proto: Any, opsets: Dict[str, int] | None = None, functions: List[ReferenceEvaluator | FunctionProto] | None = None, verbose: int = 0, new_ops: List[type[OpRun]] | None = None, **kwargs: Any)
- static filter_ops(proto: Any, new_ops: List[type[OpRun]], opsets: Dict[str, int] | None) → List[type[OpRun]]
Filters and deduplicates versioned operator implementations.
For each operator that has multiple versioned implementations (identified by a trailing _<int> suffix in the class name), keeps only the one with the highest version number that does not exceed the opset version declared in proto for that domain.

- Parameters:
  - proto – an ONNX ModelProto or FunctionProto, used to read the declared opset versions. May be None.
  - new_ops – list of OpRun subclasses to filter.
  - opsets – explicit opset map {domain: version}; takes precedence over the opsets embedded in proto when not None.
- Returns:
filtered list of operator implementations.
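Before comparing versions, filter_ops needs a {domain: version} map. A minimal sketch of that first step follows. resolve_opsets is a hypothetical helper, and it assumes that a non-None opsets argument replaces the proto's declarations entirely (the documentation only says it takes precedence). It reads opset_import, the repeated field that both ModelProto and FunctionProto expose, whose entries carry domain and version, and it tolerates proto=None.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def resolve_opsets(proto: Any, opsets: Optional[Dict[str, int]]) -> Dict[str, int]:
    """Hypothetical sketch: explicit opsets win, else read proto.opset_import."""
    if opsets is not None:
        return dict(opsets)  # copy so callers can't mutate the argument
    if proto is None:
        return {}
    return {op.domain: op.version for op in getattr(proto, "opset_import", [])}


# Stand-in for a ModelProto declaring "" -> 18 and "custom" -> 2.
fake_model = SimpleNamespace(
    opset_import=[
        SimpleNamespace(domain="", version=18),
        SimpleNamespace(domain="custom", version=2),
    ]
)

print(resolve_opsets(fake_model, None))           # {'': 18, 'custom': 2}
print(resolve_opsets(fake_model, {"custom": 3}))  # {'custom': 3}
print(resolve_opsets(None, None))                 # {}
```

Because the function only touches opset_import, domain and version, the same code should accept a real onnx.ModelProto or onnx.FunctionProto in place of the stand-in.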
- run(*args: Any, **kwargs: Any) → Any
Runs the model and returns the outputs.
Accepts both the standard run(output_names, feeds) calling convention and a convenience shortcut run(feeds), where feeds is a list of arrays that is zipped with input_names.

See onnx.reference.ReferenceEvaluator.run() for full parameter documentation.
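The two calling conventions can be reconciled by normalizing the arguments before delegating to the standard run(output_names, feeds). The helper below is a hypothetical sketch of that dispatch; normalize_run_args and fake_run are not part of yaourt or onnx. It assumes the shortcut is recognized by a single positional list argument. Unlike a bare zip(), it also rejects a list whose length differs from input_names, which is a design choice of the sketch.

```python
from typing import Any, Dict, Sequence, Tuple


def normalize_run_args(
    input_names: Sequence[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Hypothetical sketch: rewrite run(feeds_list) as run(None, feeds_dict)."""
    if len(args) == 1 and isinstance(args[0], list):
        values = args[0]
        if len(values) != len(input_names):
            # zip() would silently truncate; failing loudly is a choice made here.
            raise ValueError(f"expected {len(input_names)} inputs, got {len(values)}")
        return (None, dict(zip(input_names, values))), kwargs
    return args, kwargs  # already run(output_names, feeds, ...)


def fake_run(output_names, feeds, **kwargs):
    """Stand-in for the standard run(): reports what it was asked to evaluate."""
    return output_names, feeds


input_names = ["X", "Y"]
for call_args in (([1.0, 2.0],), (None, {"X": 1.0, "Y": 2.0})):
    args, kwargs = normalize_run_args(input_names, call_args, {})
    print(fake_run(*args, **kwargs))  # both print (None, {'X': 1.0, 'Y': 2.0})
```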