ONNX Inspection Helpers#

yobx ships a collection of utilities for reading and understanding ONNX protobuf structures. They live in yobx.helpers.onnx_helper and yobx.helpers.helper and are primarily used internally by the builder, optimizer, and evaluator, but they are equally useful for debugging, testing, and quick exploration.

Printing models and nodes#

pretty_onnx() converts any ONNX protobuf object to a compact, human-readable text representation. It works on onnx.ModelProto, onnx.GraphProto, onnx.FunctionProto, onnx.NodeProto, onnx.TensorProto, onnx.ValueInfoProto, and onnx.AttributeProto.

<<<

import onnx.helper as oh
import onnx
from yobx.helpers.onnx_helper import pretty_onnx

TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Relu", ["X"], ["Y"]),
            oh.make_node("Transpose", ["Y"], ["Z"], perm=[1, 0]),
        ],
        "relu_transpose",
        [oh.make_tensor_value_info("X", TFLOAT, [None, 4])],
        [oh.make_tensor_value_info("Z", TFLOAT, [4, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)
print(pretty_onnx(model))

>>>

    opset: domain='' version=18
    input: name='X' type=dtype('float32') shape=['', 4]
    Relu(X) -> Y
      Transpose(Y, perm=[1,0]) -> Z
    output: name='Z' type=dtype('float32') shape=[4, '']

For a single node, pass with_attributes=True to include attribute values:

<<<

import onnx.helper as oh
from yobx.helpers.onnx_helper import pretty_onnx

node = oh.make_node("Transpose", ["X"], ["Y"], perm=[1, 0, 2])
print(pretty_onnx(node, with_attributes=True))

>>>

    Transpose(X) -> Y  ---  perm=[1, 0, 2]

Inspecting shapes and types#

string_type() produces a concise one-line description of any Python object, including nested structures of torch.Tensor and numpy.ndarray. It is used throughout the library for error messages and debug logging.

<<<

import numpy as np
from yobx.helpers import string_type

obj = [
    np.zeros((2, 4), dtype=np.float32),
    np.zeros((2, 4), dtype=np.float32),
]
print(string_type(obj, with_shape=True))

>>>

    #2[A1s2x4,A1s2x4]
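One plausible reading of this notation: `#2[...]` is a list of two elements, and `A1s2x4` is an array whose ONNX dtype code is 1 (FLOAT, i.e. float32) with shape 2x4. The tiny encoder below is a hypothetical sketch of that interpretation, not the library's implementation; it only handles float32 arrays and lists:

```python
import numpy as np

# ONNX TensorProto dtype codes; only FLOAT (float32 -> 1) is needed here.
DTYPE_CODES = {np.dtype("float32"): 1}


def compact(obj) -> str:
    # "#n[...]" for lists, "A<dtype code>s<shape>" for arrays.
    if isinstance(obj, list):
        return f"#{len(obj)}[{','.join(compact(o) for o in obj)}]"
    if isinstance(obj, np.ndarray):
        shape = "x".join(str(d) for d in obj.shape)
        return f"A{DTYPE_CODES[obj.dtype]}s{shape}"
    return type(obj).__name__


obj = [
    np.zeros((2, 4), dtype=np.float32),
    np.zeros((2, 4), dtype=np.float32),
]
print(compact(obj))  # #2[A1s2x4,A1s2x4]
```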

Comparing numerical outputs#

max_diff() computes the maximum absolute and relative differences between two nested structures. It is the primary validation utility used when checking that an exported ONNX model produces the same outputs as the original PyTorch model.

import numpy as np
from yobx.helpers import max_diff

ref = {"logits": np.array([[1.0, 2.0, 3.0]])}
got = {"logits": np.array([[1.0, 2.0001, 3.0]])}

diff = max_diff(ref, got)
print("max absolute diff:", diff["abs"])
print("max relative diff:", diff["rel"])
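To make the two quantities concrete, here is a minimal numpy sketch for flat arrays. It is an illustration only: the exact return structure of max_diff and its precise relative-difference formula are assumptions, including the epsilon guard against division by zero:

```python
import numpy as np


def max_abs_rel_diff(ref, got) -> dict:
    # Hypothetical sketch: maximum elementwise absolute difference, and a
    # relative difference guarded against division by zero.
    ref = np.asarray(ref, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    d = np.abs(ref - got)
    return {
        "abs": float(d.max()),
        "rel": float((d / (np.abs(ref) + 1e-12)).max()),
    }


print(max_abs_rel_diff(np.array([1.0, 2.0]), np.array([1.0, 2.5])))
```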

Walking subgraphs#

Several ONNX operators (If, Loop, Scan, SequenceMap) embed subgraphs as attributes. Traversing a model fully therefore requires recursing into these subgraphs. enumerate_subgraphs() yields every embedded onnx.GraphProto in depth-first order, including the top-level graph:

<<<

import onnx
import onnx.helper as oh
from yobx.helpers.onnx_helper import enumerate_subgraphs

TFLOAT = onnx.TensorProto.FLOAT
then_graph = oh.make_graph(
    [oh.make_node("Relu", ["X"], ["Y"])],
    "then",
    [],
    [oh.make_tensor_value_info("Y", TFLOAT, [None, 4])],
)
else_graph = oh.make_graph(
    [oh.make_node("Abs", ["X"], ["Y"])],
    "else",
    [],
    [oh.make_tensor_value_info("Y", TFLOAT, [None, 4])],
)
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node(
                "If",
                ["cond"],
                ["Y"],
                then_branch=then_graph,
                else_branch=else_graph,
            )
        ],
        "if_model",
        [
            oh.make_tensor_value_info("cond", onnx.TensorProto.BOOL, []),
            oh.make_tensor_value_info("X", TFLOAT, [None, 4]),
        ],
        [oh.make_tensor_value_info("Y", TFLOAT, [None, 4])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)
graphs = list(enumerate_subgraphs(model.graph))
print("number of graphs (including main):", len(graphs))
print("graph names:", [g.name for g in graphs])

>>>

    number of graphs (including main): 3
    graph names: ['if_model', 'else', 'then']

Flattening nested inputs#

flatten_object() recursively flattens any nested Python structure (dicts, lists, tuples, torch.Tensor, numpy.ndarray) into a single flat list of leaf tensors. This is useful when assembling the flat input list expected by onnxruntime.InferenceSession:

<<<

import numpy as np
from yobx.helpers.helper import flatten_object

inputs = {
    "input_ids": np.array([[1, 2, 3]], dtype=np.int64),
    "past_kv": [np.zeros((1, 4, 6, 8)), np.zeros((1, 4, 6, 8))],
}
flat = flatten_object(inputs, drop_keys=True)
print("number of leaf tensors:", len(flat))
for i, t in enumerate(flat):
    print(f"  [{i}]: shape={t.shape}, dtype={t.dtype}")

>>>

    number of leaf tensors: 3
      [0]: shape=(1, 3), dtype=int64
      [1]: shape=(1, 4, 6, 8), dtype=float64
      [2]: shape=(1, 4, 6, 8), dtype=float64
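The traversal is depth-first over values, with dict values visited in insertion order, as the output above suggests. A minimal hypothetical sketch of that traversal (ignoring the drop_keys flag and torch.Tensor leaves) could read:

```python
import numpy as np


def flatten(obj):
    # Depth-first flattening of dict/list/tuple containers into leaf values.
    if isinstance(obj, dict):
        return [leaf for value in obj.values() for leaf in flatten(value)]
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in flatten(item)]
    return [obj]


inputs = {
    "input_ids": np.array([[1, 2, 3]], dtype=np.int64),
    "past_kv": [np.zeros((1, 4, 6, 8)), np.zeros((1, 4, 6, 8))],
}
print([t.shape for t in flatten(inputs)])
# [(1, 3), (1, 4, 6, 8), (1, 4, 6, 8)]
```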

See also#

- Interesting Helpers: MiniOnnxBuilder for serializing nested tensor structures to ONNX initializers.
- Evaluators: evaluator classes that consume ONNX models and produce outputs compatible with the utilities described above.