yobx.reference

ExtendedReferenceEvaluator

class yobx.reference.ExtendedReferenceEvaluator(proto: Any, opsets: Dict[str, int] | None = None, functions: List[ReferenceEvaluator | FunctionProto] | None = None, verbose: int = 0, new_ops: List[type[OpRun]] | None = None, **kwargs)[source]

Extends onnx.reference.ReferenceEvaluator with additional operator kernels for non-standard domains such as com.microsoft.

The evaluator makes it possible to test scenarios that a backend implementing only the standard ONNX operators cannot handle, such as optimization patterns that rely on ONNX Runtime contrib operators (e.g. FusedMatMul, QuickGelu).

Basic usage — run an ONNX model with standard operators:

<<<

import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator

# A single-node graph computing Z = X + Y on float32 matrices whose
# dimensions are left unspecified.
FLOAT = onnx.TensorProto.FLOAT
graph = oh.make_graph(
    [oh.make_node("Add", ["X", "Y"], ["Z"])],
    "add_graph",
    [oh.make_tensor_value_info(name, FLOAT, [None, None]) for name in ("X", "Y")],
    [oh.make_tensor_value_info("Z", FLOAT, [None, None])],
)
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=10)

# Feed the same matrix to both inputs; run() returns one array per output.
ref = ExtendedReferenceEvaluator(model)
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
outputs = ref.run(None, {"X": x, "Y": x})
result = outputs[0]
print(result)

>>>

    [[2. 4.]
     [6. 8.]]

Using contrib operators — run a com.microsoft operator:

<<<

import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
    oh.make_graph(
        # FusedMatMul is an ONNX Runtime contrib operator from the com.microsoft
        # domain; with its default attributes it behaves as a plain matrix product.
        [oh.make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
        "fused_mm",
        [
            oh.make_tensor_value_info("X", TFLOAT, None),
            oh.make_tensor_value_info("Y", TFLOAT, None),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, None)],
    ),
    # Every domain used by a node needs an opset import, including com.microsoft.
    opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
    # Pin the IR version like the other examples so the model does not depend on
    # the default IR version of the installed onnx package.
    ir_version=10,
)
ref = ExtendedReferenceEvaluator(model)
a = np.arange(4, dtype=np.float32).reshape(2, 2)
(result,) = ref.run(None, {"X": a, "Y": a})
print(result)  # same values as a @ a

>>>

    [[ 2.  3.]
     [ 6. 11.]]

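Adding custom operators — pass extra OpRun subclasses via new_ops. The class name must match the node's operator type, and its op_domain attribute must match the node's domain: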

<<<

import numpy as np
import onnx
import onnx.helper as oh
from onnx.reference.op_run import OpRun
from yobx.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT


class MyCustomOp(OpRun):
    # The class name is the operator type and op_domain is its domain; both
    # must match the node built below.
    op_domain = "my.domain"

    def _run(self, X):
        # The kernel returns a tuple holding one array per node output.
        doubled = X * 2
        return (doubled,)


node = oh.make_node("MyCustomOp", ["X"], ["Z"], domain="my.domain")
graph = oh.make_graph(
    [node],
    "custom_graph",
    [oh.make_tensor_value_info("X", TFLOAT, [None])],
    [oh.make_tensor_value_info("Z", TFLOAT, [None])],
)
opsets = [oh.make_opsetid("", 18), oh.make_opsetid("my.domain", 1)]
model = oh.make_model(graph, opset_imports=opsets, ir_version=10)

# Register the custom kernel next to the built-in default_ops.
ref = ExtendedReferenceEvaluator(model, new_ops=[MyCustomOp])
feeds = {"X": np.array([1.0, 2.0, 3.0], dtype=np.float32)}
(result,) = ref.run(None, feeds)
print(result)

>>>

    [2. 4. 6.]

The new_ops list is merged with default_ops; you do not need to re-list the built-in contrib operators.

By default, the class adds the following operators (or overrides the existing implementations of them):

<<<

from pprint import pprint

from yobx.reference import ExtendedReferenceEvaluator

# default_ops is read from the class itself, so no model is needed here.
pprint(ExtendedReferenceEvaluator.default_ops)

>>>

    [<class 'yobx.reference.ops.op_attention.Attention'>,
     <class 'yobx.reference.ops.op_bias_softmax.BiasSoftmax'>,
     <class 'yobx.reference.ops.op_complex.ComplexModule'>,
     <class 'yobx.reference.ops.op_fused_matmul.FusedMatMul'>,
     <class 'yobx.reference.ops.op_memcpy_host.MemcpyFromHost'>,
     <class 'yobx.reference.ops.op_memcpy_host.MemcpyToHost'>,
     <class 'yobx.reference.ops.op_qlinear_conv.QLinearConv'>,
     <class 'yobx.reference.ops.op_qlinear_average_pool.QLinearAveragePool'>,
     <class 'yobx.reference.ops.op_quick_gelu.QuickGelu'>,
     <class 'yobx.reference.ops.op_skip_layer_normalization.SkipLayerNormalization'>,
     <class 'yobx.reference.ops.op_complex.ToComplex'>]

run(*args, **kwargs)[source]

See onnx.reference.ReferenceEvaluator.run().