Interesting Helpers

ONNX Serialization of Nested Structures with Tensors

The goal is to serialize arbitrarily nested Python structures of tensors into the ONNX format. This relies on MiniOnnxBuilder; the example MiniOnnxBuilder: serialize tensors to an ONNX model shows how to use it.

MiniOnnxBuilder creates minimal ONNX models whose only purpose is to store tensors as initializers and return them when the model is executed. The model has no inputs: running it simply replays the stored values. This is useful for:

  • capturing intermediate activations or model weights for debugging,

  • persisting arbitrary nested Python structures (dicts, tuples, lists, torch tensors) in a standard, portable format,

  • sharing small test fixtures without committing raw binary files.

The class exposes three methods to add outputs:

  • append_output_initializer — stores a single tensor (numpy array or torch tensor). When randomize=True the values are replaced by a random-number generator node, keeping shape and dtype but discarding the original data to reduce model size.

  • append_output_sequence — wraps a list of tensors into an ONNX Sequence output.

  • append_output_dict — stores a dict of tensors as two outputs (keys and values).

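Two higher-level helpers build on top of MiniOnnxBuilder to handle arbitrary nesting automatically:

  • create_onnx_model_from_input_tensors serializes a nested structure of tensors into an ONNX model,

  • create_input_tensors_from_onnx_model rebuilds the original structure from that model.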

The following snippet round-trips a small dictionary that mixes a NumPy array and a torch tensor:

<<<

import numpy as np
import torch
from yobx.helpers.mini_onnx_builder import (
    create_onnx_model_from_input_tensors,
    create_input_tensors_from_onnx_model,
)

inputs = {
    "ids": np.array([1, 2, 3], dtype=np.int64),
    "hidden": torch.zeros(2, 4, dtype=torch.float32),
}
proto = create_onnx_model_from_input_tensors(inputs)
restored = create_input_tensors_from_onnx_model(proto)
print("keys:", list(restored.keys()))
for k, v in restored.items():
    # restored values may come back as numpy arrays or torch tensors
    arr = v if isinstance(v, np.ndarray) else v.numpy()
    print(f"  {k}: shape={arr.shape}, dtype={arr.dtype}")

>>>

    keys: ['ids', 'hidden']
      ids: shape=(3,), dtype=int64
      hidden: shape=(2, 4), dtype=float32

The higher-level helpers also handle deeply nested structures. The next snippet serializes a dict whose values include a flat tensor and a list of tensors (typical for past-key-value caches in transformer models):

<<<

import numpy as np
from yobx.helpers.mini_onnx_builder import (
    create_onnx_model_from_input_tensors,
    create_input_tensors_from_onnx_model,
)

inputs = {
    "input_ids": np.array([[1, 2, 3]], dtype=np.int64),
    "past_key_values": [
        np.zeros((1, 4, 6, 8), dtype=np.float32),  # layer 0 keys
        np.zeros((1, 4, 6, 8), dtype=np.float32),  # layer 0 values
    ],
}

proto = create_onnx_model_from_input_tensors(inputs)
restored = create_input_tensors_from_onnx_model(proto)

print("top-level keys:", list(restored.keys()))
print("input_ids     :", restored["input_ids"].shape)
print(
    "past_key_values is a",
    type(restored["past_key_values"]).__name__,
    "of length",
    len(restored["past_key_values"]),
)
for i, arr in enumerate(restored["past_key_values"]):
    print(f"  [{i}]: shape={arr.shape}")

>>>

    top-level keys: ['input_ids', 'past_key_values']
    input_ids     : (1, 3)
    past_key_values is a list of length 2
      [0]: shape=(1, 4, 6, 8)
      [1]: shape=(1, 4, 6, 8)
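This section does not say how the two helpers encode the nesting. One plausible approach, sketched below with the hypothetical functions flatten_nested and unflatten_nested (they are not part of yobx), is to flatten the structure into uniquely named tensors, each of which could become one stored output, and to keep a separate skeleton recording the dict, list and tuple layout. How such a skeleton would be stored inside the ONNX model is left open here.

```python
import numpy as np


def flatten_nested(obj, prefix="root"):
    """Hypothetical helper: return ([(name, array), ...], skeleton)."""
    if isinstance(obj, np.ndarray):
        return [(prefix, obj)], "tensor"
    if isinstance(obj, dict):
        pairs, children = [], {}
        for key, value in obj.items():
            sub_pairs, sub_skeleton = flatten_nested(value, f"{prefix}.{key}")
            pairs.extend(sub_pairs)
            children[key] = sub_skeleton
        return pairs, ("dict", children)
    if isinstance(obj, (list, tuple)):
        pairs, children = [], []
        for i, value in enumerate(obj):
            sub_pairs, sub_skeleton = flatten_nested(value, f"{prefix}.{i}")
            pairs.extend(sub_pairs)
            children.append(sub_skeleton)
        return pairs, (type(obj).__name__, children)
    raise TypeError(f"unsupported type: {type(obj).__name__}")


def unflatten_nested(skeleton, arrays):
    """Hypothetical helper: rebuild the structure, consuming arrays in flatten order."""
    remaining = iter(arrays)

    def build(skel):
        if skel == "tensor":
            return next(remaining)
        kind, children = skel
        if kind == "dict":
            return {key: build(child) for key, child in children.items()}
        items = [build(child) for child in children]
        return tuple(items) if kind == "tuple" else items

    return build(skeleton)


inputs = {
    "input_ids": np.array([[1, 2, 3]], dtype=np.int64),
    "past_key_values": [
        np.zeros((1, 4, 6, 8), dtype=np.float32),
        np.zeros((1, 4, 6, 8), dtype=np.float32),
    ],
}
pairs, skeleton = flatten_nested(inputs)
names = [name for name, _ in pairs]
print(names)
# ['root.input_ids', 'root.past_key_values.0', 'root.past_key_values.1']
restored = unflatten_nested(skeleton, [array for _, array in pairs])
```

Keeping the skeleton separate from the tensors means the tensors themselves can stay a flat list of named outputs, which maps naturally onto the outputs of an input-less ONNX model.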

ONNX Graph Visualization

to_dot converts an onnx.ModelProto into a DOT string suitable for rendering with Graphviz.

<<<

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from yobx.helpers.dot_helper import to_dot

TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("MatMul", ["added", "W"], ["mm"]),
            oh.make_node("Relu", ["mm"], ["Z"]),
        ],
        "add_matmul_relu",
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 4]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 4]),
        ],
        [oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 2])],
        [
            onh.from_array(
                np.zeros((4, 2), dtype=np.float32),
                name="W",
            )
        ],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)
dot = to_dot(model)
print(dot[:200], "...")

>>>

    digraph {
      graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
      node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
      edge [arrowhead=v ...

The resulting DOT source can be rendered directly in the documentation with the gdot directive from sphinx-runpython:

    digraph {
      graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
      node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
      edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
      I_0 [label="X\nFLOAT(batch,seq,4)", fillcolor="#aaeeaa"];
      I_1 [label="Y\nFLOAT(batch,seq,4)", fillcolor="#aaeeaa"];
      i_2 [label="W\nFLOAT(4, 2)", fillcolor="#cccc00"];
      Add_3 [label="Add(., .)", fillcolor="#cccccc"];
      MatMul_4 [label="MatMul(., .)", fillcolor="#ee9999"];
      Relu_5 [label="Relu(.)", fillcolor="#cccccc"];
      I_0 -> Add_3 [label="FLOAT(batch,seq,4)"];
      I_1 -> Add_3 [label="FLOAT(batch,seq,4)"];
      Add_3 -> MatMul_4 [label="FLOAT(batch,seq,4)"];
      i_2 -> MatMul_4 [label="FLOAT(4, 2)"];
      MatMul_4 -> Relu_5 [label="FLOAT(batch,seq,2)"];
      O_6 [label="Z\nFLOAT(batch,seq,2)", fillcolor="#aaaaee"];
      Relu_5 -> O_6;
    }

See also

ONNX Graph Visualization with to_dot — sphinx-gallery example demonstrating to_dot on a simple hand-built model.