Intermediate results with (ONNX) ReferenceEvaluator

Let’s assume onnxruntime crashes without telling why or where. The first thing to do is to locate where. For that, we run a Python runtime, which executes the model node by node until it fails.

A failing model

The issue here is a Cast operator trying to convert a result into a non-existent type.

import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.reference import ExtendedReferenceEvaluator

TFLOAT = onnx.TensorProto.FLOAT

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
            oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
            oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
            oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
            oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
        ],
        "-nd-",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
            oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
        [from_array_extended(np.array([1], dtype=np.float32), name="one")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

We check that it fails.

try:
    onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
    print(e)
[ONNXRuntimeError] : 1 : FAIL : Node (failing) Op (Cast) [TypeInferenceError] Attribute to does not specify a valid type in .

ExtendedReferenceEvaluator

This class extends onnx.reference.ReferenceEvaluator with operators that are not part of the ONNX standard but are defined by onnxruntime. verbose=10 tells the class to print as much as possible, verbose=0 prints nothing, and intermediate values give more or less verbosity.

ref = ExtendedReferenceEvaluator(model, verbose=10)
feeds = dict(
    X=np.random.rand(3, 4).astype(np.float32), Y=np.random.rand(3, 4).astype(np.float32)
)
try:
    ref.run(None, feeds)
except Exception as e:
    print("ERROR", type(e), e)
 +C one: float32:(1,):[1.0]
 +I X: float32:(3, 4):0.20402809977531433,0.9137852191925049,0.8619707226753235,0.5603037476539612,0.38157159090042114...
 +I Y: float32:(3, 4):0.15071871876716614,0.9215355515480042,0.9482867121696472,0.12428570538759232,0.45832890272140503...
Mul(X, Y) -> xy
 + xy: float32:(3, 4):0.030750853940844536,0.8420855402946472,0.8173953890800476,0.06963774561882019,0.17488528788089752...
Sigmoid(xy) -> sy
 + sy: float32:(3, 4):0.5076870918273926,0.698904275894165,0.6936831474304199,0.5174024105072021,0.5436102151870728...
Add(sy, one) -> C
 + C: float32:(3, 4):1.5076870918273926,1.698904275894165,1.69368314743042,1.5174024105072021,1.5436102151870728...
Cast(C) -> X999
ERROR <class 'KeyError'> 999

We can see it runs until it reaches Cast and stops. The error message is not always easy to interpret; it is improved from time to time. This runtime is useful when a model fails for a numerical reason. It is also possible to insert prints into the Python code to get more information, or to debug it if needed.

doc.plot_legend("Python Runtime\nfor ONNX", "ExtendedReferenceEvaluator", "lightgrey")
plot failing reference evaluator

Total running time of the script: (0 minutes 0.075 seconds)

Related examples

Intermediate results with onnxruntime

Find where a model is failing by running submodels

Find and fix an export issue due to dynamic shapes

Gallery generated by Sphinx-Gallery