.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_extended_reference_evaluator.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_extended_reference_evaluator.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_extended_reference_evaluator.py:

.. _l-plot-extended-reference-evaluator:

ExtendedReferenceEvaluator: running models with contrib operators
=================================================================

:class:`ExtendedReferenceEvaluator <yobx.reference.ExtendedReferenceEvaluator>`
extends :class:`onnx.reference.ReferenceEvaluator` with additional operator
kernels for non-standard domains such as ``com.microsoft``. This makes it
possible to execute and test ONNX models that contain ONNX Runtime contrib
operators (e.g. ``FusedMatMul``, ``QuickGelu``) without requiring a full
ONNX Runtime installation just to unit-test an optimization pattern.

This example shows:

1. Running a model with standard ONNX operators.
2. Running a model that uses the ``FusedMatMul`` contrib operator.
3. Running a model that uses the ``QuickGelu`` contrib operator.
4. Adding a custom operator implementation via ``new_ops``.

.. GENERATED FROM PYTHON SOURCE LINES 24-32

.. code-block:: Python

    import numpy as np
    import onnx
    import onnx.helper as oh

    from yobx.reference import ExtendedReferenceEvaluator

    TFLOAT = onnx.TensorProto.FLOAT

.. GENERATED FROM PYTHON SOURCE LINES 33-39

1. Standard ONNX operators
--------------------------

:class:`ExtendedReferenceEvaluator <yobx.reference.ExtendedReferenceEvaluator>`
is a drop-in replacement for :class:`onnx.reference.ReferenceEvaluator`.
Any model that runs with the standard evaluator also runs here.

.. GENERATED FROM PYTHON SOURCE LINES 39-60

.. code-block:: Python

    model_add = oh.make_model(
        oh.make_graph(
            [oh.make_node("Add", ["X", "Y"], ["Z"])],
            "add_graph",
            [
                oh.make_tensor_value_info("X", TFLOAT, [None, None]),
                oh.make_tensor_value_info("Y", TFLOAT, [None, None]),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, [None, None])],
        ),
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=10,
    )

    x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    ref = ExtendedReferenceEvaluator(model_add)
    (result,) = ref.run(None, {"X": x, "Y": x})
    print("Add result:\n", result)
    assert np.allclose(result, x + x)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Add result:
     [[2. 4.]
     [6. 8.]]

.. GENERATED FROM PYTHON SOURCE LINES 61-68

2. FusedMatMul (com.microsoft contrib operator)
-----------------------------------------------

``FusedMatMul`` is an ONNX Runtime contrib operator that fuses a matrix
multiplication with optional transpositions. The standard
:class:`onnx.reference.ReferenceEvaluator` does not know about it, but
:class:`ExtendedReferenceEvaluator <yobx.reference.ExtendedReferenceEvaluator>` does.

.. GENERATED FROM PYTHON SOURCE LINES 68-88

.. code-block:: Python

    model_fmm = oh.make_model(
        oh.make_graph(
            [oh.make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")],
            "fused_matmul_graph",
            [
                oh.make_tensor_value_info("X", TFLOAT, None),
                oh.make_tensor_value_info("Y", TFLOAT, None),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
    )

    a = np.arange(4, dtype=np.float32).reshape(2, 2)
    ref_fmm = ExtendedReferenceEvaluator(model_fmm)
    (z,) = ref_fmm.run(None, {"X": a, "Y": a})
    print("FusedMatMul result:\n", z)
    assert np.allclose(z, a @ a)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    FusedMatMul result:
     [[ 2.  3.]
     [ 6. 11.]]

.. GENERATED FROM PYTHON SOURCE LINES 89-90

With ``transA=1`` the first operand is transposed before the multiplication.

.. GENERATED FROM PYTHON SOURCE LINES 90-109

.. code-block:: Python

    model_fmm_t = oh.make_model(
        oh.make_graph(
            [oh.make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft", transA=1)],
            "fused_matmul_transA_graph",
            [
                oh.make_tensor_value_info("X", TFLOAT, None),
                oh.make_tensor_value_info("Y", TFLOAT, None),
            ],
            [oh.make_tensor_value_info("Z", TFLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
    )

    ref_fmm_t = ExtendedReferenceEvaluator(model_fmm_t)
    (z_t,) = ref_fmm_t.run(None, {"X": a, "Y": a})
    print("FusedMatMul(transA=1) result:\n", z_t)
    assert np.allclose(z_t, a.T @ a)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    FusedMatMul(transA=1) result:
     [[ 4.  6.]
     [ 6. 10.]]

.. GENERATED FROM PYTHON SOURCE LINES 110-115

3. QuickGelu (com.microsoft contrib operator)
---------------------------------------------

``QuickGelu`` applies the sigmoid-gated activation
``x * sigmoid(alpha * x)`` element-wise.

.. GENERATED FROM PYTHON SOURCE LINES 115-132

.. code-block:: Python

    model_gelu = oh.make_model(
        oh.make_graph(
            [oh.make_node("QuickGelu", ["X"], ["Z"], domain="com.microsoft", alpha=1.702)],
            "quick_gelu_graph",
            [oh.make_tensor_value_info("X", TFLOAT, None)],
            [oh.make_tensor_value_info("Z", TFLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)],
    )

    x_gelu = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
    ref_gelu = ExtendedReferenceEvaluator(model_gelu)
    (z_gelu,) = ref_gelu.run(None, {"X": x_gelu})
    print("QuickGelu result:", z_gelu)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    QuickGelu result: [-0.06434137 -0.15420423  0.          0.84579575  1.9356586 ]

.. GENERATED FROM PYTHON SOURCE LINES 133-140

4. Adding a custom operator via ``new_ops``
-------------------------------------------

Any :class:`OpRun <onnx.reference.op_run.OpRun>` subclass can be passed
through the ``new_ops`` argument. The built-in :attr:`default_ops
<yobx.reference.ExtendedReferenceEvaluator.default_ops>` are always merged
in automatically, so you only need to list your additions.

.. GENERATED FROM PYTHON SOURCE LINES 140-170

.. code-block:: Python

    from onnx.reference.op_run import OpRun  # noqa: E402


    class Scale(OpRun):
        """Multiplies every element of X by a constant *factor*."""

        op_domain = "my.domain"

        def _run(self, X, factor=2.0):  # type: ignore[override]
            return (X * np.float32(factor),)


    model_custom = oh.make_model(
        oh.make_graph(
            [oh.make_node("Scale", ["X"], ["Z"], domain="my.domain", factor=3.0)],
            "scale_graph",
            [oh.make_tensor_value_info("X", TFLOAT, [None])],
            [oh.make_tensor_value_info("Z", TFLOAT, [None])],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("my.domain", 1)],
        ir_version=10,
    )

    x_s = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    ref_custom = ExtendedReferenceEvaluator(model_custom, new_ops=[Scale])
    (z_s,) = ref_custom.run(None, {"X": x_s})
    print("Scale(factor=3) result:", z_s)
    assert np.allclose(z_s, x_s * 3.0)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Scale(factor=3) result: [3. 6. 9.]

.. GENERATED FROM PYTHON SOURCE LINES 171-176

5. Listing the default operators
--------------------------------

:attr:`default_ops <yobx.reference.ExtendedReferenceEvaluator.default_ops>`
shows all operator implementations that are registered automatically.

.. GENERATED FROM PYTHON SOURCE LINES 176-180

.. code-block:: Python

    import pprint  # noqa: E402

    pprint.pprint(ExtendedReferenceEvaluator.default_ops)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [, , , , , , , , , , ]

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.041 seconds)


.. _sphx_glr_download_auto_examples_plot_extended_reference_evaluator.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_extended_reference_evaluator.ipynb <plot_extended_reference_evaluator.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_extended_reference_evaluator.py <plot_extended_reference_evaluator.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_extended_reference_evaluator.zip <plot_extended_reference_evaluator.zip>`

.. include:: plot_extended_reference_evaluator.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
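As a cross-check that is independent of ``yobx`` and of any ONNX tooling, the
two contrib-operator formulas exercised above can be reproduced with plain
NumPy. This is a sketch, not the kernels' actual implementation: the
``quick_gelu`` and ``fused_matmul`` helpers below are hypothetical names, and
the ``alpha`` scaling parameter of ``fused_matmul`` mirrors the attribute ONNX
Runtime documents for ``FusedMatMul`` (it is not used in the example above).

```python
import numpy as np


def quick_gelu(x: np.ndarray, alpha: float = 1.702) -> np.ndarray:
    # QuickGelu: x * sigmoid(alpha * x), applied element-wise.
    return x * (1.0 / (1.0 + np.exp(-alpha * x)))


def fused_matmul(a: np.ndarray, b: np.ndarray,
                 trans_a: bool = False, trans_b: bool = False,
                 alpha: float = 1.0) -> np.ndarray:
    # FusedMatMul: matmul with the optional transposes (and a scalar
    # scaling factor) folded into a single operation.
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    return alpha * (a @ b)


x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)
print(quick_gelu(x))  # matches the QuickGelu output above

m = np.arange(4, dtype=np.float32).reshape(2, 2)
print(fused_matmul(m, m, trans_a=True))  # matches FusedMatMul(transA=1) above
```

Comparing these reference formulas against the evaluator's outputs is exactly
the kind of unit test the extended evaluator is meant to enable without an
ONNX Runtime installation.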