ONNX Graph Visualization with to_mermaid

to_mermaid converts an onnx.ModelProto into a Mermaid flowchart TD string that can be rendered by any Mermaid-compatible viewer (e.g. GitHub Markdown, the Mermaid live editor, or Sphinx with the sphinxcontrib-mermaid extension).

The function:

  • assigns different CSS classes to different node kinds (inputs are green, initializers are yellow, operators are light-grey, outputs are light-blue),

  • inlines small scalar constants and 1-D initializers whose length is ≤ 9 directly onto the node label so the graph stays compact,

  • uses BasicShapeBuilder to label edges with their inferred dtype and shape, when available,

  • handles Scan / Loop / If sub-graphs by drawing dotted edges for outer-scope values consumed by the sub-graph.

The output is a plain Mermaid string; it can be embedded directly in Markdown or saved to a .mmd file.
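Because the result is an ordinary string, both uses need only the standard library. In this sketch, save_mmd and as_markdown_block are hypothetical helper names, and the short two-node flowchart stands in for the output of to_mermaid(model).

```python
import tempfile
from pathlib import Path

# Stand-in for to_mermaid(model); any Mermaid source string is handled the same way.
mermaid_src = 'flowchart TD\n    A["X"] --> B["Relu(.)"]'


def save_mmd(src: str, path: Path) -> Path:
    """Write the Mermaid source to a .mmd file as plain UTF-8 text."""
    path.write_text(src, encoding="utf-8")
    return path


def as_markdown_block(src: str) -> str:
    """Wrap the source in a fenced mermaid block, which GitHub renders as a diagram."""
    fence = "`" * 3  # avoids a literal triple backtick inside this listing
    return f"{fence}mermaid\n{src}\n{fence}\n"


with tempfile.TemporaryDirectory() as tmp:
    out = save_mmd(mermaid_src, Path(tmp) / "add_matmul_relu.mmd")
    print(out.read_text(encoding="utf-8") == mermaid_src)  # True

print(as_markdown_block(mermaid_src))
```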

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from IPython.display import HTML
from yobx.doc import draw_graph_mermaid
from yobx.helpers.mermaid_helper import to_mermaid

TFLOAT = onnx.TensorProto.FLOAT

Build a small model

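The graph performs the following operations:

  1. Add(X, Y): element-wise sum, producing added with shape (batch, seq, d), where d = 4.

  2. MatMul(added, W): projects the last dimension from d to d//2 = 2 using the (4, 2) initializer W, producing mm.

  3. Relu(mm): element-wise ReLU activation, producing the graph output Z.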
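The same three steps can be written directly in NumPy to check the shapes involved. This is only a sketch of the arithmetic: batch = 2 and seq = 3 are arbitrary values standing in for the symbolic dimensions, and the random W mirrors the random initializer used in the model.

```python
import numpy as np

batch, seq, d = 2, 3, 4  # arbitrary sizes for the symbolic dims; d = 4 as in the model
rng = np.random.default_rng(0)

X = rng.standard_normal((batch, seq, d)).astype(np.float32)
Y = rng.standard_normal((batch, seq, d)).astype(np.float32)
W = rng.standard_normal((d, d // 2)).astype(np.float32)

added = X + Y            # 1. Add: (batch, seq, d)
mm = added @ W           # 2. MatMul: (batch, seq, d) @ (d, d//2) -> (batch, seq, d//2)
Z = np.maximum(mm, 0.0)  # 3. Relu: negative entries become 0

print(added.shape, mm.shape, Z.shape)  # (2, 3, 4) (2, 3, 2) (2, 3, 2)
```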

# Three-node graph: Add -> MatMul -> Relu, with a random 4x2 weight initializer W.
model = oh.make_model(
    oh.make_graph(
        [
            oh.make_node("Add", ["X", "Y"], ["added"]),
            oh.make_node("MatMul", ["added", "W"], ["mm"]),
            oh.make_node("Relu", ["mm"], ["Z"]),
        ],
        "add_matmul_relu",
        # Graph inputs with symbolic "batch" and "seq" dimensions.
        [
            oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 4]),
            oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 4]),
        ],
        # Graph output.
        [oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 2])],
        # Initializer: the projection matrix W.
        [onh.from_array(np.random.randn(4, 2).astype(np.float32), name="W")],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=10,
)

Convert to Mermaid

mermaid_src = to_mermaid(model)
print(mermaid_src)
flowchart TD
    I_0["X\nFLOAT(batch,seq,4)"]:::input
    I_1["Y\nFLOAT(batch,seq,4)"]:::input
    i_2["W\nFLOAT(4, 2)"]:::init
    Add_3["Add(., .)"]:::op
    MatMul_4["MatMul(., .)"]:::op
    Relu_5["Relu(.)"]:::op
    I_0 -->|"FLOAT(batch,seq,4)"| Add_3
    I_1 -->|"FLOAT(batch,seq,4)"| Add_3
    Add_3 -->|"FLOAT(batch,seq,4)"| MatMul_4
    i_2 -->|"FLOAT(4, 2)"| MatMul_4
    MatMul_4 -->|"FLOAT(batch,seq,2)"| Relu_5
    O_6["Z\nFLOAT(batch,seq,2)"]:::output
    Relu_5 --> O_6
    classDef input fill:#aaeeaa,stroke:#00aa00,color:#000
    classDef init fill:#cccc00,stroke:#888800,color:#000
    classDef op fill:#cccccc,stroke:#666666,color:#000
    classDef output fill:#aaaaee,stroke:#0000aa,color:#000
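Because the output is plain text, its structure is easy to inspect programmatically. The hypothetical parse_flowchart helper below is a sketch that recognizes only the line shapes seen in this particular output (node definitions with a :::class suffix and solid --> edges, optionally labelled); it is not a general Mermaid parser and does not handle the dotted edges used for sub-graphs.

```python
import re

# Verbatim copy of the to_mermaid output printed above. The raw string keeps the
# literal backslash-n sequences that appear inside the node labels.
MERMAID_SRC = r"""flowchart TD
    I_0["X\nFLOAT(batch,seq,4)"]:::input
    I_1["Y\nFLOAT(batch,seq,4)"]:::input
    i_2["W\nFLOAT(4, 2)"]:::init
    Add_3["Add(., .)"]:::op
    MatMul_4["MatMul(., .)"]:::op
    Relu_5["Relu(.)"]:::op
    I_0 -->|"FLOAT(batch,seq,4)"| Add_3
    I_1 -->|"FLOAT(batch,seq,4)"| Add_3
    Add_3 -->|"FLOAT(batch,seq,4)"| MatMul_4
    i_2 -->|"FLOAT(4, 2)"| MatMul_4
    MatMul_4 -->|"FLOAT(batch,seq,2)"| Relu_5
    O_6["Z\nFLOAT(batch,seq,2)"]:::output
    Relu_5 --> O_6
    classDef input fill:#aaeeaa,stroke:#00aa00,color:#000
    classDef init fill:#cccc00,stroke:#888800,color:#000
    classDef op fill:#cccccc,stroke:#666666,color:#000
    classDef output fill:#aaaaee,stroke:#0000aa,color:#000
"""

# Node lines look like: ID["label"]:::class
NODE_RE = re.compile(r'^[ \t]*(\w+)\["(.*)"\]:::(\w+)[ \t]*$', re.MULTILINE)
# Solid edges, with or without a |"label"|; dotted edges (-.->) are not matched.
EDGE_RE = re.compile(r'^[ \t]*(\w+) -->(?:\|"([^"]*)"\|)? (\w+)[ \t]*$', re.MULTILINE)


def parse_flowchart(src: str):
    """Return ({node_id: (label, css_class)}, [(src_id, dst_id, edge_label or None)])."""
    nodes = {m.group(1): (m.group(2), m.group(3)) for m in NODE_RE.finditer(src)}
    edges = [(m.group(1), m.group(3), m.group(2)) for m in EDGE_RE.finditer(src)]
    return nodes, edges


nodes, edges = parse_flowchart(MERMAID_SRC)
for src_id, dst_id, label in edges:
    print(f"{src_id} -> {dst_id}: {label}")
```

In this output every edge into an operator carries a dtype/shape label; only the edge from Relu_5 into the output node O_6 is unlabelled.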

Display the graph

The diagram is rendered to SVG through the mermaid.ink online service (via mermaid-py), so this step needs network access. The SVG content is then wrapped in IPython.display.HTML so that Sphinx-Gallery captures and embeds it.

HTML(draw_graph_mermaid(model))

[Rendered diagram: X and Y feed Add(., .); its output and the initializer W feed MatMul(., .); Relu(.) then produces the output Z. Edges are labelled with their dtype and shape.]
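The mermaid.ink service renders a diagram whose source is encoded in the request URL. The sketch below builds such a URL with the standard library only. The https://mermaid.ink/svg/ endpoint and the URL-safe base64 encoding are assumptions based on the service's public usage, and mermaid-py or draw_graph_mermaid may encode the source differently (for example with compression); mermaid_ink_svg_url and decode_mermaid_ink_url are hypothetical names. No request is sent.

```python
import base64

MERMAID_INK_SVG = "https://mermaid.ink/svg/"  # assumed endpoint for SVG output


def mermaid_ink_svg_url(src: str) -> str:
    """Encode Mermaid source as URL-safe base64 and append it to the SVG endpoint."""
    encoded = base64.urlsafe_b64encode(src.encode("utf-8")).decode("ascii")
    return MERMAID_INK_SVG + encoded


def decode_mermaid_ink_url(url: str) -> str:
    """Inverse of mermaid_ink_svg_url, useful for checking what a URL contains."""
    return base64.urlsafe_b64decode(url[len(MERMAID_INK_SVG):]).decode("utf-8")


src = 'flowchart TD\n    A["X"] --> B["Relu(.)"]'
url = mermaid_ink_svg_url(src)
print(url)
```

Fetching that URL (for instance with urllib.request.urlopen) and wrapping the returned SVG text in IPython.display.HTML would approximate the display step above, assuming the service is reachable and accepts this encoding.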




Related examples

ONNX Graph Visualization with to_dot

Comparing the five ONNX translation APIs

Computed Shapes: Add + Concat + Reshape
