Source code for yobx.reference.evaluator

from typing import Any, Dict, List, Optional, Union
from onnx import FunctionProto, ModelProto, NodeProto, TypeProto
from onnx.defs import get_schema
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from .ops.op_attention import Attention
from .ops.op_bias_softmax import BiasSoftmax
from .ops.op_complex import ComplexModule, ToComplex
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_qlinear_average_pool import QLinearAveragePool
from .ops.op_qlinear_conv import QLinearConv
from .ops.op_quick_gelu import QuickGelu
from .ops.op_skip_layer_normalization import SkipLayerNormalization


[docs] class ExtendedReferenceEvaluator(ReferenceEvaluator): """ Extends :class:`onnx.reference.ReferenceEvaluator` with additional operator kernels for non-standard domains such as ``com.microsoft``. The evaluator allows testing scenarios outside what a standard ONNX backend can handle, such as optimization patterns that rely on ONNX Runtime contrib operators (e.g. :class:`FusedMatMul <yobx.reference.ops.op_fused_matmul.FusedMatMul>`, :class:`QuickGelu <yobx.reference.ops.op_quick_gelu.QuickGelu>`). **Basic usage** — run an ONNX model with standard operators: .. runpython:: :showcode: import numpy as np import onnx.helper as oh import onnx from yobx.reference import ExtendedReferenceEvaluator TFLOAT = onnx.TensorProto.FLOAT model = oh.make_model( oh.make_graph( [oh.make_node("Add", ["X", "Y"], ["Z"])], "add_graph", [ oh.make_tensor_value_info("X", TFLOAT, [None, None]), oh.make_tensor_value_info("Y", TFLOAT, [None, None]), ], [oh.make_tensor_value_info("Z", TFLOAT, [None, None])], ), opset_imports=[oh.make_opsetid("", 18)], ir_version=10, ) ref = ExtendedReferenceEvaluator(model) x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) (result,) = ref.run(None, {"X": x, "Y": x}) print(result) **Using contrib operators** — run a ``com.microsoft`` operator: .. runpython:: :showcode: import numpy as np import onnx.helper as oh import onnx from yobx.reference import ExtendedReferenceEvaluator TFLOAT = onnx.TensorProto.FLOAT model = oh.make_model( oh.make_graph( [oh.make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")], "fused_mm", [ oh.make_tensor_value_info("X", TFLOAT, None), oh.make_tensor_value_info("Y", TFLOAT, None), ], [oh.make_tensor_value_info("Z", TFLOAT, None)], ), opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], ) ref = ExtendedReferenceEvaluator(model) a = np.arange(4, dtype=np.float32).reshape(2, 2) (result,) = ref.run(None, {"X": a, "Y": a}) print(result) **Adding custom operators** — pass extra :class:`OpRun <onnx.reference.op_run.OpRun>` subclasses via ``new_ops``: .. runpython:: :showcode: import numpy as np import onnx.helper as oh import onnx from onnx.reference.op_run import OpRun from yobx.reference import ExtendedReferenceEvaluator TFLOAT = onnx.TensorProto.FLOAT class MyCustomOp(OpRun): op_domain = "my.domain" def _run(self, X): return (X * 2,) model = oh.make_model( oh.make_graph( [oh.make_node("MyCustomOp", ["X"], ["Z"], domain="my.domain")], "custom_graph", [oh.make_tensor_value_info("X", TFLOAT, [None])], [oh.make_tensor_value_info("Z", TFLOAT, [None])], ), opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("my.domain", 1)], ir_version=10, ) ref = ExtendedReferenceEvaluator(model, new_ops=[MyCustomOp]) x = np.array([1.0, 2.0, 3.0], dtype=np.float32) (result,) = ref.run(None, {"X": x}) print(result) The ``new_ops`` list is *merged* with :attr:`default_ops`; you do not need to re-list the built-in contrib operators. The class overloads or adds the following operators by default: .. runpython:: :showcode: import pprint from yobx.reference import ExtendedReferenceEvaluator pprint.pprint(ExtendedReferenceEvaluator.default_ops) """ default_ops: List[type[OpRun]] = [ Attention, BiasSoftmax, ComplexModule, FusedMatMul, MemcpyFromHost, MemcpyToHost, QLinearConv, QLinearAveragePool, QuickGelu, SkipLayerNormalization, ToComplex, ] @staticmethod def filter_ops(proto, new_ops, opsets): if opsets is None and isinstance(proto, (ModelProto, FunctionProto)): opsets = {d.domain: d.version for d in proto.opset_import} best = {} renamed = {} for cl in new_ops: if "_" not in cl.__name__: continue vers = cl.__name__.split("_") try: v = int(vers[-1]) except ValueError: # not a version continue if opsets is not None and v > opsets.get(cl.op_domain, 1): continue renamed[cl.__name__] = cl key = cl.op_domain, "_".join(vers[:-1]) if key not in best or best[key][0] < v: best[key] = (v, cl) modified = [] for cl in new_ops: if cl.__name__ not in renamed: modified.append(cl) for k, v in best.items(): atts = {"domain": k[0]} bases = (v[1],) if not hasattr(v[1], "op_schema"): atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain) new_cl = type(k[1], bases, atts) modified.append(new_cl) new_ops = modified return new_ops def __init__( self, proto: Any, opsets: Optional[Dict[str, int]] = None, functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None, verbose: int = 0, new_ops: Optional[List[type[OpRun]]] = None, **kwargs, ): if new_ops is None: new_ops = ExtendedReferenceEvaluator.default_ops else: new_ops = new_ops.copy() new_ops.extend(ExtendedReferenceEvaluator.default_ops) new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets) ReferenceEvaluator.__init__( self, proto, opsets=opsets, functions=functions, verbose=verbose, new_ops=new_ops, **kwargs, )
[docs] def run(self, *args, **kwargs): """See :meth:`onnx.reference.ReferenceEvaluator.run`.""" if len(args) == 1 and isinstance(args[0], list): feeds = dict(zip(self.input_names, args[0])) return self.run(None, feeds, **kwargs) if isinstance(self.proto_, FunctionProto): return self._run_function(*args, **kwargs) return ReferenceEvaluator.run(self, *args, **kwargs)
def _load_impl(self, node: NodeProto, input_types: TypeProto | None = None) -> Any: res = super()._load_impl(node, input_types) assert ( not hasattr(res, "op_domain") or res.op_domain == node.domain ), f"Domain mismatch {res.op_domain!r} != {node.domain} for node={node}" return res def _run_function( self, output_names, feed_inputs: Dict[str, Any], attributes: Optional[Dict[str, Any]] = None, intermediate: bool = False, ) -> Union[Dict[str, Any], List[Any]]: # type: ignore if output_names is None: output_names = self.output_names # step 1: inputs and initializers results = {"": None} # optional input results.update(self.rt_inits_) # type: ignore[arg-type] results.update(feed_inputs) for k, v in self.rt_inits_.items(): self._log(2, " +C %s: %s", k, v) # type: ignore[arg-type] for k, v in feed_inputs.items(): self._log(2, " +I %s: %s", k, v) # type: ignore[arg-type] # step 2: execute nodes for node in self.rt_nodes_: self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output) for i in node.input: if i not in results: raise RuntimeError( f"Unable to find input {i!r} in known results {sorted(results)}, " f"self.rt_inits_ has {sorted(self.rt_inits_)}, " f"feed_inputs has {sorted(feed_inputs)}." ) inputs = [results[i] for i in node.input] linked_attributes = {} if node.has_linked_attribute and attributes: linked_attributes["linked_attributes"] = attributes if node.need_context(): outputs = node.run(*inputs, context=results, **linked_attributes) else: outputs = node.run(*inputs, **linked_attributes) for name, value in zip(node.output, outputs): self._log(2, " + %s: %s", name, value) # type: ignore[arg-type] results[name] = value # return the results if intermediate: return results for name in output_names: if name not in results: raise RuntimeError( f"Unable to find output name {name!r} " f"in {sorted(results)}, proto is\n{self.proto_}" ) return [results[name] for name in output_names]