Intermediate results with onnxruntime

Example Intermediate results with (ONNX) ReferenceEvaluator demonstrated how to run a python runtime on a model, but that runtime can be very slow and may show some discrepancies when the model is meant to run on a provider other than CPU. Let’s use OnnxruntimeEvaluator instead. It splits the model into nodes and runs them one at a time until the whole model succeeds or one node fails. This class converts every node into a single-node model based on the types discovered during the execution. It relies on InferenceSessionForTorch or InferenceSessionForNumpy for the execution. This example uses torch tensors and bfloat16.
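The node-by-node strategy can be sketched in plain Python. This is a minimal sketch under stated assumptions: OnnxruntimeEvaluator builds a one-node ONNX model per node and runs it with onnxruntime, while the sketch below replaces that with hypothetical Python stand-in kernels (`KERNELS`, `run_step_by_step`) so that it stays self-contained. It keeps every intermediate result and stops at the first node that raises, reporting its name.

```python
import math

# Toy graph mirroring the failing model: (name, op_type, inputs, outputs, attributes).
NODES = [
    ("n0", "Mul", ["X", "Y"], ["xy"], {}),
    ("n1", "Sigmoid", ["xy"], ["sy"], {}),
    ("n2", "Add", ["sy", "one"], ["C"], {}),
    ("failing", "Cast", ["C"], ["X999"], {"to": 999}),
]


def _cast(x, to):
    # Hypothetical toy kernel: only FLOAT (1) and DOUBLE (11) are accepted here.
    if to not in (1, 11):
        raise TypeError(f"Attribute to={to} does not specify a valid type")
    return float(x)


# Plain Python stand-ins for onnxruntime kernels (an assumption of this sketch).
KERNELS = {
    "Mul": lambda a, b, **kw: a * b,
    "Sigmoid": lambda a, **kw: 1.0 / (1.0 + math.exp(-a)),
    "Add": lambda a, b, **kw: a + b,
    "Cast": lambda a, to, **kw: _cast(a, to),
}


def run_step_by_step(nodes, feeds):
    """Runs nodes one at a time; returns (results so far, name of failing node or None)."""
    results = dict(feeds)
    for name, op_type, inputs, outputs, attrs in nodes:
        args = [results[i] for i in inputs]
        try:
            value = KERNELS[op_type](*args, **attrs)
        except Exception as e:
            # Stop at the first failure and keep everything computed before it.
            print(f"{op_type}({', '.join(inputs)}) failed in node {name!r}: {e}")
            return results, name
        results[outputs[0]] = value
    return results, None


# The initializer "one" is passed as a feed to keep the sketch short.
results, failed = run_step_by_step(NODES, {"X": 0.5, "Y": 0.25, "one": 1.0})
print(failed, sorted(results))
```

Keeping the partial results is what makes this approach useful for debugging: everything computed before the failing node remains available for inspection.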

A failing model

The issue here is a Cast operator trying to convert a result into a non-existent type.
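The attribute `to` of Cast must be one of the codes of `onnx.TensorProto.DataType`, and 999 is not one of them. The following is a minimal static check written in plain Python, assuming nodes are described as dicts (a hypothetical layout, not `onnx.NodeProto`). The table only lists the codes up to BFLOAT16, so newer types added by recent ONNX releases would be wrongly reported by this sketch.

```python
# Subset of onnx.TensorProto.DataType codes; newer ONNX releases define more
# (float8 and 4-bit variants), so this table is intentionally partial.
KNOWN_DTYPES = {
    1: "FLOAT", 2: "UINT8", 3: "INT8", 4: "UINT16", 5: "INT16",
    6: "INT32", 7: "INT64", 8: "STRING", 9: "BOOL", 10: "FLOAT16",
    11: "DOUBLE", 12: "UINT32", 13: "UINT64", 14: "COMPLEX64",
    15: "COMPLEX128", 16: "BFLOAT16",
}


def invalid_casts(nodes):
    """Returns the names of Cast nodes whose 'to' attribute is not a known type code."""
    return [
        node["name"]
        for node in nodes
        if node["op_type"] == "Cast" and node["attrs"].get("to") not in KNOWN_DTYPES
    ]


nodes = [
    {"name": "n2", "op_type": "Add", "attrs": {}},
    {"name": "failing", "op_type": "Cast", "attrs": {"to": 999}},
    {"name": "ok", "op_type": "Cast", "attrs": {"to": 16}},  # 16 is BFLOAT16
]
print(invalid_casts(nodes))  # ['failing']
```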

import onnx
import onnx.helper as oh
import torch
import onnxruntime
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_cuda
from onnx_diagnostic.helpers.onnx_helper import from_array_extended
from onnx_diagnostic.reference import OnnxruntimeEvaluator

TBFLOAT16 = onnx.TensorProto.BFLOAT16

model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Mul", ["X", "Y"], ["xy"], name="n0"),
            oh.make_node("Sigmoid", ["xy"], ["sy"], name="n1"),
            oh.make_node("Add", ["sy", "one"], ["C"], name="n2"),
            # to=999 is not a valid TensorProto data type: this node makes the model invalid.
            oh.make_node("Cast", ["C"], ["X999"], to=999, name="failing"),
            oh.make_node("CastLike", ["X999", "Y"], ["Z"], name="n4"),
        ],
        "-nd-",
        [
            oh.make_tensor_value_info("X", TBFLOAT16, ["a", "b", "c"]),
            oh.make_tensor_value_info("Y", TBFLOAT16, ["a", "b", "c"]),
        ],
        [oh.make_tensor_value_info("Z", TBFLOAT16, ["a", "b", "c"])],
        [from_array_extended(torch.tensor([1], dtype=torch.bfloat16), name="one")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

We check that it fails.

try:
    onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
except onnxruntime.capi.onnxruntime_pybind11_state.Fail as e:
    print(e)
[ONNXRuntimeError] : 1 : FAIL : Node (failing) Op (Cast) [TypeInferenceError] Attribute to does not specify a valid type in .
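This message names the failing node and its operator type. A minimal sketch to extract both with a regular expression follows; the pattern is an assumption based only on the message above, and other kinds of onnxruntime errors may be worded differently, in which case the function returns None.

```python
import re

# Matches the "Node (<name>) Op (<op_type>)" fragment of the message above.
_NODE_OP = re.compile(r"Node \((?P<node>[^)]*)\) Op \((?P<op>[^)]*)\)")


def failing_node(message):
    """Returns (node_name, op_type) if the message names a node, else None."""
    m = _NODE_OP.search(message)
    return (m.group("node"), m.group("op")) if m else None


msg = (
    "[ONNXRuntimeError] : 1 : FAIL : Node (failing) Op (Cast) "
    "[TypeInferenceError] Attribute to does not specify a valid type in ."
)
print(failing_node(msg))  # ('failing', 'Cast')
```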

OnnxruntimeEvaluator

This class extends onnx.reference.ReferenceEvaluator with operators outside the standard but defined by onnxruntime. verbose=10 tells the class to print as much as possible, verbose=0 prints nothing, and values in between give more or less detail.

ref = OnnxruntimeEvaluator(model, verbose=10)
feeds = dict(
    X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
)
try:
    ref.run(None, feeds)
except Exception as e:
    print("ERROR", type(e), e)
 +C one: A:bfloat16:(1,):[1.0]
 +I X: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.33984375,0.05078125,0.83203125,0.1328125,0.4765625,0.61328125,0.875,0.3515625,0.90625,0.55078125...
 +I Y: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.54296875,0.56640625,0.6015625,0.3359375,0.2109375,0.140625,0.375,0.125,0.14453125,0.71875...
Mul(X, Y) -> xy
ERROR <class 'onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented'> [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for Mul(14) node with name 'n0'
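The trace above was printed with verbose=10. A common way to implement such a parameter is to attach a level to every message and print it only when verbose reaches that level. The sketch below illustrates this convention with hypothetical helpers (`make_tracer`, `replay`); the thresholds 1 and 10 are illustrative choices, not the ones used by OnnxruntimeEvaluator.

```python
def make_tracer(verbose, sink=print):
    """Returns trace(level, message), which emits message only if verbose >= level."""

    def trace(level, message):
        if verbose >= level:
            sink(message)

    return trace


def replay(verbose, steps):
    """Replays (node line, value line) pairs and returns the lines that would be shown."""
    shown = []
    trace = make_tracer(verbose, shown.append)
    for node_line, value_line in steps:
        trace(1, node_line)    # which node runs (illustrative level)
        trace(10, value_line)  # a summary of the value it produces (illustrative level)
    return shown


STEPS = [
    ("Mul(X, Y) -> xy", " + xy: torch.bfloat16:torch.Size([3, 4])"),
    ("Sigmoid(xy) -> sy", " + sy: torch.bfloat16:torch.Size([3, 4])"),
]
for level in (0, 1, 10):
    print(level, len(replay(level, STEPS)))
```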

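onnxruntime may not support bfloat16 on CPU: here, it finds no CPU implementation of Mul for this type. See onnxruntime kernels for the types supported by each provider. The next step runs the same model on CUDA when a GPU is available.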
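The CUDA run passes `providers="cuda"`, and its error message lists `['CUDAExecutionProvider', 'CPUExecutionProvider']`, which suggests that short names are expanded and CPU is kept as a fallback. The sketch below reproduces that behaviour with a hypothetical helper `resolve_providers`; it is an assumption about the mapping, not the library's code.

```python
# Hypothetical mapping from short names to onnxruntime provider names.
_ALIASES = {"cuda": "CUDAExecutionProvider", "cpu": "CPUExecutionProvider"}


def resolve_providers(providers):
    """Expands short provider names and appends CPU as a last-resort fallback."""
    if isinstance(providers, str):
        providers = [providers]
    names = [_ALIASES.get(p.lower(), p) for p in providers]
    if "CPUExecutionProvider" not in names:
        names.append("CPUExecutionProvider")
    return names


print(resolve_providers("cuda"))  # ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

Listing CPU last keeps the session usable for nodes the GPU provider does not implement, although, as the CPU run shows, the CPU provider may itself lack bfloat16 kernels.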

if has_cuda():
    ref = OnnxruntimeEvaluator(model, providers="cuda", verbose=10)
    feeds = dict(
        X=torch.rand((3, 4), dtype=torch.bfloat16), Y=torch.rand((3, 4), dtype=torch.bfloat16)
    )
    try:
        ref.run(None, feeds)
    except Exception as e:
        print("ERROR", type(e), e)
 +C one: A:bfloat16:(1,):[1.0]
 +I X: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.67578125,0.77734375,0.26953125,0.45703125,0.8671875,0.33984375,0.50390625,0.70703125,0.1015625,0.328125...
 +I Y: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.98828125,0.38671875,0.20703125,0.51171875,0.765625,0.09765625,0.62890625,0.75,0.953125,0.1171875...
Mul(X, Y) -> xy
 + xy: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.66796875,0.30078125,0.055908203125,0.2333984375,0.6640625,0.033203125,0.31640625,0.53125,0.0966796875,0.03857421875...
 - deletes: X - torch.bfloat16:torch.Size([3, 4])
Sigmoid(xy) -> sy
 + sy: T:D-1:torch.bfloat16:torch.Size([3, 4]):0.66015625,0.57421875,0.515625,0.5546875,0.66015625,0.5078125,0.578125,0.62890625,0.5234375,0.51171875...
 - deletes: xy - torch.bfloat16:torch.Size([3, 4])
Add(sy, one) -> C
 + C: A:bfloat16:(3, 4):1.65625,1.578125,1.515625,1.5546875,1.65625,1.5078125,1.578125,1.625,1.5234375,1.515625...
 - deletes: sy - torch.bfloat16:torch.Size([3, 4])
 - deletes: one - bfloat16:(1,)
Cast(C) -> X999
ERROR <class 'RuntimeError'> Unable to create a session stored in '_debug_InferenceSession_last_failure.onnx'), providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
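The ` - deletes:` lines in this trace show intermediate results being released once no remaining node reads them: X is deleted after Mul, but Y is not because CastLike still needs it. A minimal sketch of how such release points could be computed from the last use of every input follows; it illustrates the idea under that assumption and is not the evaluator's actual implementation.

```python
def release_points(nodes, graph_outputs):
    """Maps node index -> names that can be freed right after that node runs."""
    last = {}
    for i, (_, inputs, _) in enumerate(nodes):
        for name in inputs:
            last[name] = i  # a later use overrides an earlier one
    frees = {}
    for name, i in last.items():
        if name not in graph_outputs:  # graph outputs must survive the run
            frees.setdefault(i, []).append(name)
    return frees


NODES = [
    ("Mul", ["X", "Y"], ["xy"]),
    ("Sigmoid", ["xy"], ["sy"]),
    ("Add", ["sy", "one"], ["C"]),
    ("Cast", ["C"], ["X999"]),
    ("CastLike", ["X999", "Y"], ["Z"]),
]
frees = release_points(NODES, {"Z"})
print(frees[0], frees[1], frees[2])  # ['X'] ['xy'] ['sy', 'one']
```

The first three entries match the deletions printed after Mul, Sigmoid and Add in the trace.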

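We can see that the evaluator runs until it reaches Cast and then stops. Here, the message points to '_debug_InferenceSession_last_failure.onnx', the file in which the model that could not be loaded was saved, so it can be inspected separately. The error message is not always easy to interpret and is improved from time to time. This runtime is most useful when a model fails for a numerical reason. It is also possible to insert prints in the python code to display more information or to debug if needed.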
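A typical numerical reason is the limited precision of bfloat16: it keeps the sign and the 8 exponent bits of float32 but only 7 explicit mantissa bits. The sketch below rounds a Python float to bfloat16 with round-to-nearest-even using only `struct`; NaN and infinity are not handled. It reproduces some values printed by the CUDA trace for `Add(sy, one)`.

```python
import struct


def to_bfloat16(x):
    """Rounds x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 is the upper 16 bits of an IEEE float32. NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit so that exact ties round to an even result.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bf16 = ((bits + rounding) >> 16) & 0xFFFF
    return struct.unpack("<f", struct.pack("<I", bf16 << 16))[0]


# sy = 0.66015625 and sy = 0.62890625 appear in the trace; adding one gives values
# exactly halfway between two bfloat16 numbers, which round to even.
print(to_bfloat16(0.66015625 + 1.0))  # 1.65625
print(to_bfloat16(0.62890625 + 1.0))  # 1.625
```

The same rounding happens after every operator, which is one way discrepancies accumulate when a model computed in bfloat16 is compared with a float32 reference.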

doc.plot_legend("onnxruntime\nrunning\nstep by step", "OnnxruntimeEvaluator", "lightgrey")
plot failing onnxruntime evaluator


Related examples

Intermediate results with (ONNX) ReferenceEvaluator

Find where a model is failing by running submodels

Dumps intermediate results of a torch model
