Evaluators#
yobx provides three evaluators for running ONNX models in Python.
They share a common interface (__init__(proto, ...) + run(outputs, feeds))
but differ in their backend, tensor type, and primary use case:
| Class | Backend | Tensor type | Best suited for |
|---|---|---|---|
| ExtendedReferenceEvaluator | onnx reference (Python) | NumPy | unit-testing, contrib ops, pure Python debugging |
| OnnxruntimeEvaluator | ONNX Runtime (node-by-node or whole) | NumPy or PyTorch | debugging ORT execution, inspecting intermediate results, or whole-model ORT inference |
| TorchReferenceEvaluator | PyTorch (Python) | PyTorch | GPU execution, memory-efficient evaluation, custom PyTorch ops |
Quick comparison
All three evaluators accept an onnx.ModelProto (or filename) and return a
list of outputs when called via run(None, feed_dict). The key
differences are:
ExtendedReferenceEvaluator — a pure Python, NumPy-based evaluator that extends onnx.reference.ReferenceEvaluator with extra kernels for non-standard domains (com.microsoft, ai.onnx.complex). No ONNX Runtime installation is required. Ideal for unit tests and operator prototyping.
OnnxruntimeEvaluator — executes each graph node individually through onnxruntime.InferenceSession. Because every node is run in isolation, it is easy to inspect every intermediate result and to compare them against a reference. Accepts both NumPy arrays and PyTorch tensors. Pass whole=True to skip node-by-node splitting and hand the complete model to a single ORT session (faster, but intermediate results are not accessible).
TorchReferenceEvaluator — runs every node with hand-written PyTorch kernels. Inputs and outputs are torch.Tensor. Supports CUDA via providers=["CUDAExecutionProvider"]. Well-suited for evaluating large models on the GPU, where keeping activations as PyTorch tensors avoids expensive NumPy round-trips.
ExtendedReferenceEvaluator#
yobx.reference.ExtendedReferenceEvaluator extends
onnx.reference.ReferenceEvaluator with additional operator kernels
for non-standard domains such as com.microsoft and ai.onnx.complex.
The standard onnx.reference.ReferenceEvaluator only knows about
operators defined in the ONNX standard. ONNX Runtime ships many contrib
operators (domain com.microsoft) that are widely used in production
models — for example FusedMatMul, QuickGelu and Attention.
ExtendedReferenceEvaluator makes it possible to
run and unit-test such models with pure Python, without requiring a full
ONNX Runtime installation.
Built-in operators#
The following table lists the operator implementations that are registered
automatically. They are available as default_ops.
| Class | Domain | Description |
|---|---|---|
| Attention | com.microsoft | Multi-head self-attention with optional mask |
| BiasSoftmax | com.microsoft | Softmax with an additive bias term |
| ComplexModule | ai.onnx.complex | Element-wise modulus of a complex tensor |
| FusedMatMul | com.microsoft | Matrix multiplication with optional transpositions (transA, transB) |
| MemcpyFromHost | (default) | Identity copy (device ↔ host no-op) |
| MemcpyToHost | (default) | Identity copy (device ↔ host no-op) |
| QLinearAveragePool | com.microsoft | Quantized average pooling |
| QLinearConv | com.microsoft | Quantized 2-D convolution |
| QuickGelu | com.microsoft | Gated sigmoid activation |
| SkipLayerNormalization | com.microsoft | Residual add followed by layer normalisation |
| ToComplex | ai.onnx.complex | Converts a real tensor into a complex tensor |
The full list at runtime can be printed with:
<<<
import pprint
from yobx.reference import ExtendedReferenceEvaluator
pprint.pprint(ExtendedReferenceEvaluator.default_ops)
>>>
[<class 'yobx.reference.ops.op__overwrite_gather.Gather'>,
<class 'yobx.reference.ops.op__overwrite_gather_elements.GatherElements'>,
<class 'yobx.reference.ops.op__overwrite_scatter_elements.ScatterElements'>,
<class 'yobx.reference.ops.op_attention.Attention'>,
<class 'yobx.reference.ops.op_bias_softmax.BiasSoftmax'>,
<class 'yobx.reference.ops.op_complex.ComplexModule'>,
<class 'yobx.reference.ops.op_fused_matmul.FusedMatMul'>,
<class 'yobx.reference.ops.op_memcpy_host.MemcpyFromHost'>,
<class 'yobx.reference.ops.op_memcpy_host.MemcpyToHost'>,
<class 'yobx.reference.ops.op_qlinear_conv.QLinearConv'>,
<class 'yobx.reference.ops.op_qlinear_average_pool.QLinearAveragePool'>,
<class 'yobx.reference.ops.op_quick_gelu.QuickGelu'>,
<class 'yobx.reference.ops.op_simplified_layer_normalization.SimplifiedLayerNormalization'>,
<class 'yobx.reference.ops.op_skip_layer_normalization.SkipLayerNormalization'>,
<class 'yobx.reference.ops.op_complex.ToComplex'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.AddAdd'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.AddMul'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.AddSharedInput'>,
<class 'yobx.reference.ops.op__extended_scatternd_of_shape.MaskedScatterNDOfShape'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.MulAdd'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.MulMul'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.MulSharedInput'>,
<class 'yobx.reference.ops.op__extended_mul_sigmoid.MulSigmoid'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.MulSub'>,
<class 'yobx.reference.ops.op__extended_negxplus1.NegXplus1'>,
<class 'yobx.reference.ops.op__extended_replace_zero.ReplaceZero'>,
<class 'yobx.reference.ops.op__extended_rotary.Rotary'>,
<class 'yobx.reference.ops.op__extended_scatternd_of_shape.ScatterNDOfShape'>,
<class 'yobx.reference.ops.op__extended_add_add_mul_mul.SubMul'>,
<class 'yobx.reference.ops.op__extended_transpose_cast.Transpose2DCastFP16'>,
<class 'yobx.reference.ops.op__extended_transpose_cast.Transpose2DCastFP32'>,
<class 'yobx.reference.ops.op__extended_tri_matrix.TriMatrix'>]
Basic usage#
ExtendedReferenceEvaluator is a drop-in replacement
for onnx.reference.ReferenceEvaluator. Any model that runs with the
standard evaluator also runs here.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[oh.make_node("Add", ["X", "Y"], ["Z"])],
"add_graph",
[
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
ref = ExtendedReferenceEvaluator(model)
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
(result,) = ref.run(None, {"X": x, "Y": x})
print(result)
>>>
[[2. 4.]
[6. 8.]]
Contrib operators#
Models that use ONNX Runtime contrib operators can be run directly.
The example below uses FusedMatMul — a com.microsoft operator that
fuses matrix multiplication with optional transposition of either operand.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node(
"FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft", transA=1
)
],
"fused_mm",
[
oh.make_tensor_value_info("X", TFLOAT, None),
oh.make_tensor_value_info("Y", TFLOAT, None),
],
[oh.make_tensor_value_info("Z", TFLOAT, None)],
),
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
ir_version=10,
)
ref = ExtendedReferenceEvaluator(model)
a = np.arange(4, dtype=np.float32).reshape(2, 2)
(result,) = ref.run(None, {"X": a, "Y": a})
print(result) # a.T @ a
>>>
[[ 4. 6.]
[ 6. 10.]]
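For 2-D inputs, the semantics of the kernel above can be sketched in plain NumPy. This is an illustrative sketch, not the registered implementation; alpha, transA, and transB are the contrib operator's attributes, with their defaults assumed here:

```python
import numpy as np

def fused_matmul(a, b, alpha=1.0, trans_a=0, trans_b=0):
    # Sketch of the FusedMatMul contrib op for 2-D operands:
    # optional transposition of either input, then a scaled matmul.
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    return alpha * (a @ b)

a = np.arange(4, dtype=np.float32).reshape(2, 2)
print(fused_matmul(a, a, trans_a=1))  # [[ 4.  6.] [ 6. 10.]], same as above
```

For batched (higher-rank) inputs the real operator transposes only the last two axes, which this 2-D sketch does not cover.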
Adding custom operators#
Pass extra OpRun subclasses through
the new_ops argument. They are merged with default_ops; you do
not need to re-list the built-in contrib operators.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from onnx.reference.op_run import OpRun
from yobx.reference import ExtendedReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
class MyCustomOp(OpRun):
op_domain = "my.domain"
def _run(self, X):
return (X * 2,)
model = oh.make_model(
oh.make_graph(
[oh.make_node("MyCustomOp", ["X"], ["Z"], domain="my.domain")],
"custom_graph",
[oh.make_tensor_value_info("X", TFLOAT, [None])],
[oh.make_tensor_value_info("Z", TFLOAT, [None])],
),
opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("my.domain", 1)],
ir_version=10,
)
ref = ExtendedReferenceEvaluator(model, new_ops=[MyCustomOp])
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
(result,) = ref.run(None, {"X": x})
print(result) # [2. 4. 6.]
>>>
[2. 4. 6.]
Inspecting intermediate results#
Pass verbose=10 to ExtendedReferenceEvaluator
to print every input, every intermediate result, and every output as the
model executes. This is useful for debugging incorrect outputs or
understanding how values flow through the graph.
The verbose parameter maps to the logging levels used internally by
onnx.reference.ReferenceEvaluator:
verbose=0 (default) — silent
verbose=2 — prints each node as it executes (NodeOp(inputs) -> outputs)
verbose=3 or higher — also prints the value of every input (+I), initializer constant (+C), and every intermediate or final result (+)
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference import ExtendedReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["T"]),
oh.make_node("Tanh", ["T"], ["Z"]),
],
"add_tanh",
[
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
ref = ExtendedReferenceEvaluator(model, verbose=10)
x = np.array([[1.0, -2.0], [3.0, -4.0]], dtype=np.float32)
(result,) = ref.run(None, {"X": x, "Y": x})
print("result:", result)
>>>
+I X: float32:(2, 2):[1.0, -2.0, 3.0, -4.0]
+I Y: float32:(2, 2):[1.0, -2.0, 3.0, -4.0]
Add(X, Y) -> T
+ T: float32:(2, 2):[2.0, -4.0, 6.0, -8.0]
Tanh(T) -> Z
+ Z: float32:(2, 2):[0.9640275835990906, -0.9993293285369873, 0.9999877214431763, -0.9999997615814209]
result: [[ 0.964 -0.999]
[ 1. -1. ]]
The lines prefixed with +I are model inputs; lines with +C are
initializer constants; and lines with + (after a node execution line)
are the intermediate or final outputs produced by that node.
Operator versioning#
When a model imports multiple versions of a domain (e.g. opset 13 and 17),
filter_ops
selects the best (highest version that does not exceed the model opset)
implementation from the new_ops list.
This mirrors the versioning convention used by
onnx.reference.ReferenceEvaluator itself: operator classes whose
names end in _<version> (e.g. MyOp_13, MyOp_17) are treated as
versioned alternatives and the most appropriate one is chosen automatically.
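The selection rule itself is simple. A pure-Python sketch of the convention (illustrative only, not the actual filter_ops implementation):

```python
def select_version(versions, model_opset):
    # Pick the highest registered version that does not exceed the
    # opset imported by the model, mirroring how MyOp_13 / MyOp_17
    # style classes are chosen.
    candidates = [v for v in versions if v <= model_opset]
    return max(candidates) if candidates else None

print(select_version([13, 17], 18))  # 17: highest version not above opset 18
print(select_version([13, 17], 15))  # 13: version 17 exceeds opset 15
```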
See also
ExtendedReferenceEvaluator: running models with contrib operators — sphinx-gallery example
demonstrating standard operators, FusedMatMul, QuickGelu, and
custom operator injection.
OnnxruntimeEvaluator#
yobx.reference.OnnxruntimeEvaluator executes an ONNX model
node by node using onnxruntime.InferenceSession as the kernel
backend. Each node is wrapped in a tiny single-node ONNX model and fed
into ORT individually, which means every intermediate tensor is accessible
after the run. This makes the class especially useful for:
debugging — comparing intermediate activations between two models or between ORT and a reference implementation;
mixed-precision inspection — examining how casts or quantisation layers change the values at each step;
portability — the same code runs on CPU or GPU simply by passing different
providers.
Basic usage#
The API mirrors onnx.reference.ReferenceEvaluator: pass an
onnx.ModelProto (or a filename) and call run().
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference.onnxruntime_evaluator import OnnxruntimeEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[oh.make_node("Add", ["X", "Y"], ["Z"])],
"add_graph",
[
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
ref = OnnxruntimeEvaluator(model)
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
(result,) = ref.run(None, {"X": x, "Y": x})
print(result)
>>>
[[2. 4.]
[6. 8.]]
Inspecting intermediate results#
Pass verbose=2 to print every node that executes together with its
inputs and outputs. Pass intermediate=True to run() to get back
a dictionary that maps every result name (inputs, constants, and
intermediate tensors) to its value.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference.onnxruntime_evaluator import OnnxruntimeEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["T"]),
oh.make_node("Tanh", ["T"], ["Z"]),
],
"add_tanh",
[
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
ref = OnnxruntimeEvaluator(model, verbose=2)
x = np.array([[1.0, -2.0], [3.0, -4.0]], dtype=np.float32)
all_results = ref.run(None, {"X": x, "Y": x}, intermediate=True)
for name, value in sorted(all_results.items()):
print(f"{name}: {value}")
>>>
Add(X, Y) -> T
Tanh(T) -> Z
T: [[ 2. -4.]
[ 6. -8.]]
X: [[ 1. -2.]
[ 3. -4.]]
Y: [[ 1. -2.]
[ 3. -4.]]
Z: [[ 0.964 -0.999]
[ 1. -1. ]]
Running the whole model at once#
By default OnnxruntimeEvaluator splits the graph into individual nodes
and runs each one separately (whole=False). Passing whole=True
hands the complete model to a single onnxruntime.InferenceSession
and is equivalent to calling ORT directly. This mode is faster but does
not allow intermediate result inspection.
<<<
import numpy as np
import onnx
import onnx.helper as oh
from yobx.reference.onnxruntime_evaluator import OnnxruntimeEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[oh.make_node("Sigmoid", ["X"], ["Z"])],
"sigmoid_graph",
[oh.make_tensor_value_info("X", TFLOAT, [None])],
[oh.make_tensor_value_info("Z", TFLOAT, [None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
ref = OnnxruntimeEvaluator(model, whole=True)
x = np.array([-1.0, 0.0, 1.0, 2.0], dtype=np.float32)
(result,) = ref.run(None, {"X": x})
print(result)
>>>
[0.269 0.5 0.731 0.881]
TorchReferenceEvaluator#
yobx.reference.TorchReferenceEvaluator is a pure-Python evaluator
that runs every ONNX node with hand-written torch kernels. Inputs
and outputs are torch.Tensor, which means:
there are no NumPy round-trips between nodes;
the model can be evaluated on CUDA by passing providers=["CUDAExecutionProvider"];
intermediate tensors are freed as soon as they are no longer needed, which reduces peak memory usage for large models.
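The early freeing relies on a last-use analysis over the graph. A minimal pure-Python sketch of the idea (not the actual implementation):

```python
def last_uses(nodes, graph_outputs):
    # nodes: list of (op_type, inputs, outputs) triples in execution order.
    # Returns {tensor_name: index of the node after which it can be freed}.
    last = {}
    for i, (_, inputs, _) in enumerate(nodes):
        for name in inputs:
            last[name] = i
    for name in graph_outputs:
        last.pop(name, None)  # model outputs must survive the run
    return last

nodes = [("Add", ["X", "Y"], ["T"]), ("Tanh", ["T"], ["Z"])]
print(last_uses(nodes, ["Z"]))  # {'X': 0, 'Y': 0, 'T': 1}
```

This matches the verbose trace shown further down: X and Y are cleaned right after Add runs, and T right after Tanh.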
The available kernels can be listed with
get_kernels().
Basic usage#
<<<
import onnx
import onnx.helper as oh
import torch
from yobx.helpers import string_type
from yobx.reference.torch_evaluator import TorchReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Sigmoid", ["Y"], ["sy"]),
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
oh.make_node("Mul", ["X", "ysy"], ["final"]),
],
"silu",
[
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
],
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
sess = TorchReferenceEvaluator(model)
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
result = sess.run(None, feeds)
print(string_type(result, with_shape=True, with_min_max=True))
>>>
#1[T1s4x5[0.008510555140674114,0.43604037165641785:A0.1840971229597926]]
Inspecting intermediate results#
Pass verbose=1 to print every kernel execution and every tensor freed
during the run. This lets you trace the exact execution order and see when
memory is reclaimed.
<<<
import onnx
import onnx.helper as oh
import torch
from yobx.helpers import string_type
from yobx.reference.torch_evaluator import TorchReferenceEvaluator
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["T"]),
oh.make_node("Tanh", ["T"], ["Z"]),
],
"add_tanh",
[
oh.make_tensor_value_info("X", TFLOAT, [None, None]),
oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
],
[oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
sess = TorchReferenceEvaluator(model, verbose=1)
feeds = dict(
X=torch.tensor([[1.0, -2.0], [3.0, -4.0]]),
Y=torch.tensor([[1.0, -2.0], [3.0, -4.0]]),
)
result = sess.run(None, feeds)
print(string_type(result, with_shape=True))
>>>
+I X: RuntimeValue(name='X', kind=5, shape=(2, 2), value=CT1s2x2[-4.0,3.0:A-0.5])
+I Y: RuntimeValue(name='Y', kind=5, shape=(2, 2), value=CT1s2x2[-4.0,3.0:A-0.5])
Add_1(X, Y) -> T
+R T: RuntimeValue(name='T', kind=1, shape=(2, 2), is_shape=False, value=CT1s2x2[-8.0,6.0:A-1.0])
- clean X
- clean Y
Tanh_6(T) -> Z
+R Z: RuntimeValue(name='Z', kind=9, shape=(2, 2), is_shape=False, value=CT1s2x2[-0.9999997615814209,0.9999877214431763:A-0.00882844626903534])
- clean T
++ outputs Z
- clean X
- clean Y
- clean Z
#1[T1s2x2]
Custom kernels#
A specific ONNX op can be replaced by passing a dictionary to
custom_kernels. The keys are (domain, op_type) tuples and the
values are subclasses of
OpRunKernel.
This is useful, for example, to delegate a single op to ONNX Runtime while
keeping the rest of the graph in PyTorch.
<<<
import numpy as np
import onnx
import onnx.helper as oh
import torch
from yobx.helpers import string_type
from yobx.reference.torch_evaluator import TorchReferenceEvaluator
from yobx.reference.torch_ops import OpRunKernel, OpRunTensor
TFLOAT = onnx.TensorProto.FLOAT
class SigmoidCPU(OpRunKernel):
"Custom Sigmoid that always runs on CPU."
def run(self, x):
t = x.tensor.cpu()
return OpRunTensor(torch.sigmoid(t).to(x.tensor.device))
model = oh.make_model(
oh.make_graph(
[oh.make_node("Sigmoid", ["X"], ["Z"])],
"sigmoid_graph",
[oh.make_tensor_value_info("X", TFLOAT, [None])],
[oh.make_tensor_value_info("Z", TFLOAT, [None])],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)
sess = TorchReferenceEvaluator(model, custom_kernels={("", "Sigmoid"): SigmoidCPU})
x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
(result,) = sess.run(None, {"X": x})
print(result)
>>>
tensor([0.2689, 0.5000, 0.7311, 0.8808])
Available kernels#
<<<
from yobx.reference.torch_evaluator import get_kernels
for k, v in sorted(get_kernels().items()):
domain, name, version = k
f = f"{name}({version})" if domain == "" else f"{name}[{domain}]({version})"
add = " " * max(25 - len(f), 0)
dd = " -- device dependent" if v.device_dependent() else ""
print(f"{f}{add} -- {v.__name__}{dd}")
>>>
Abs(1) -- Abs_1
Add(1) -- Add_1
And(1) -- And_1
AveragePool(11) -- AveragePool_11
Cast(6) -- Cast_6
CastLike(15) -- CastLike_15
Concat(1) -- Concat_1
ConcatFromSequence(11) -- ConcatFromSequence_11
ConstantOfShape(9) -- ConstantOfShape_9 -- device dependent
Conv(11) -- Conv_11
Cos(1) -- Cos_1
Div(1) -- Div_1
Equal(1) -- Equal_1
Erf(9) -- Erf_9
Exp(1) -- Exp_1
Expand(8) -- Expand_8
Gather(1) -- Gather_1
Greater(1) -- Greater_1
GreaterOrEqual(1) -- GreaterOrEqual_1
Identity(1) -- Identity_1
If(1) -- If_1
IsNaN(9) -- IsNaN_9
LayerNormalization(17) -- LayerNormalization_17
Less(1) -- Less_1
LessOrEqual(1) -- LessOrEqual_1
Log(1) -- Log_1
Loop(16) -- Loop_16
MatMul(1) -- MatMul_1
Mul(1) -- Mul_1
Neg(1) -- Neg_1
NonZero(13) -- NonZero_13
Not(1) -- Not_1
Or(1) -- Or_1
Pow(12) -- Pow_12
Range(11) -- Range_11 -- device dependent
Reciprocal(1) -- Reciprocal_1
ReduceMax(18) -- ReduceMax_18
ReduceMean(18) -- ReduceMean_18
ReduceMin(17) -- ReduceMin_17
ReduceMin(18) -- ReduceMin_18
ReduceSum(13) -- ReduceSum_13
Reshape(14) -- Reshape_14
ScatterND(16) -- ScatterND_16
SequenceEmpty(11) -- SequenceEmpty_11
SequenceInsert(11) -- SequenceInsert_11
Shape(15) -- Shape_15
Sigmoid(6) -- Sigmoid_6
Sin(1) -- Sin_1
Slice(13) -- Slice_13
Softmax(13) -- Softmax_13
Split(18) -- Split_18
Sqrt(1) -- Sqrt_1
Squeeze(13) -- Squeeze_13
Sub(1) -- Sub_1
Tanh(6) -- Tanh_6
Tile(6) -- Tile_6
Transpose(1) -- Transpose_1
Trilu(14) -- Trilu_14
Unsqueeze(13) -- Unsqueeze_13
Where(9) -- Where_9
See also
Comparing the three evaluators — sphinx-gallery example that runs the same model through all three evaluators and verifies the outputs agree.