ONNX Graph Visualization with to_mermaid
to_mermaid converts an onnx.ModelProto into a Mermaid flowchart TD string that can be rendered by any Mermaid-compatible viewer (e.g. GitHub Markdown, the Mermaid live editor, or Sphinx with the sphinxcontrib-mermaid extension).
The function:
assigns a different CSS class to each node kind (inputs are green, initializers are yellow, operators are light grey, outputs are light blue),
inlines small scalar constants and 1-D initializers whose length is ≤ 9 directly into the node label so the graph stays compact,
uses BasicShapeBuilder to annotate every edge with its inferred dtype and shape (when available),
handles Scan/Loop/If sub-graphs by drawing dotted edges for outer-scope values consumed by the sub-graph.
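To make the inlining rule concrete, here is a minimal sketch of the decision as stated above, written as a hypothetical helper should_inline that takes a tensor's dimensions. It is not the yobx implementation; the threshold constant is an assumption taken directly from the "≤ 9" wording, and scalars are assumed to be tensors with no dimensions.

```python
from typing import Sequence

# Hypothetical threshold mirroring the "length <= 9" rule described above.
MAX_INLINE_LENGTH = 9


def should_inline(dims: Sequence[int]) -> bool:
    """Return True if a constant or initializer of this shape would be inlined.

    Sketch of the stated rule only: scalars (no dimensions) are always
    inlined, 1-D tensors only when they have at most MAX_INLINE_LENGTH
    elements, and anything of rank 2 or higher gets its own node.
    """
    if len(dims) == 0:
        return True
    if len(dims) == 1:
        return dims[0] <= MAX_INLINE_LENGTH
    return False


print(should_inline(()))      # True  (scalar)
print(should_inline((9,)))    # True  (short 1-D)
print(should_inline((10,)))   # False (1-D, too long)
print(should_inline((4, 2)))  # False (2-D, like W)
```

Under this rule the 4x2 weight W used in the example below stays a separate initializer node (i_2 in the printed output) instead of being folded into the MatMul label.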
The output is a plain Mermaid string; it can be embedded directly in Markdown
or saved to a .mmd file.
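Both options need only the standard library. The sketch below uses two hypothetical helpers, save_mmd and as_markdown (neither is part of yobx): the first writes the string to a .mmd file, the second wraps it in a fenced code block tagged mermaid, which is the form GitHub Markdown renders as a diagram. A short hand-written diagram stands in for the output of to_mermaid.

```python
import tempfile
from pathlib import Path
from typing import Union

# Three backticks, built programmatically so the fence is easy to read here.
FENCE = "`" * 3


def save_mmd(mermaid_src: str, path: Union[str, Path]) -> Path:
    """Write a Mermaid diagram to a standalone .mmd file and return its path."""
    out = Path(path)
    out.write_text(mermaid_src, encoding="utf-8")
    return out


def as_markdown(mermaid_src: str) -> str:
    """Wrap a Mermaid diagram in a fenced block that GitHub Markdown renders."""
    return f"{FENCE}mermaid\n{mermaid_src.rstrip()}\n{FENCE}\n"


# Stand-in for to_mermaid(model): any Mermaid source string works the same way.
src = 'flowchart TD\n    A["X"] --> B["Relu(.)"]\n'

with tempfile.TemporaryDirectory() as tmp:
    path = save_mmd(src, Path(tmp) / "graph.mmd")
    print(path.read_text(encoding="utf-8") == src)

print(as_markdown(src))
```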
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from IPython.display import HTML
from yobx.doc import draw_graph_mermaid
from yobx.helpers.mermaid_helper import to_mermaid
TFLOAT = onnx.TensorProto.FLOAT
Build a small model
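The graph performs the following operations:
Add(X, Y): element-wise sum, producing added with shape (batch, seq, d).
MatMul(added, W): projects the last dimension from d to d//2, producing mm.
Relu(mm): element-wise ReLU activation, producing the output Z.
Here d = 4, so W has shape (4, 2) and Z has shape (batch, seq, 2).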
model = oh.make_model(
oh.make_graph(
[
oh.make_node("Add", ["X", "Y"], ["added"]),
oh.make_node("MatMul", ["added", "W"], ["mm"]),
oh.make_node("Relu", ["mm"], ["Z"]),
],
"add_matmul_relu",
[
oh.make_tensor_value_info("X", TFLOAT, ["batch", "seq", 4]),
oh.make_tensor_value_info("Y", TFLOAT, ["batch", "seq", 4]),
],
[oh.make_tensor_value_info("Z", TFLOAT, ["batch", "seq", 2])],
[onh.from_array(np.random.randn(4, 2).astype(np.float32), name="W")],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=10,
)
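The shapes declared above (X and Y as (batch, seq, 4), W as (4, 2), Z as (batch, seq, 2)) can be checked by propagating symbolic shapes through the three operators by hand. The sketch below does that for these three ops only, with a label helper that mimics the symbolic edge labels printed in the next section. It is a simplified stand-in (no broadcasting, 2-D right-hand MatMul operand only), not how BasicShapeBuilder works internally.

```python
from typing import Tuple, Union

Dim = Union[int, str]    # a dimension is either a fixed size or a symbolic name
Shape = Tuple[Dim, ...]


def add_shape(a: Shape, b: Shape) -> Shape:
    # Simplification: only identical shapes are accepted (no broadcasting).
    if a != b:
        raise ValueError(f"Add: shapes differ {a} vs {b}")
    return a


def matmul_shape(a: Shape, b: Shape) -> Shape:
    # (..., k) @ (k, n) -> (..., n); only a 2-D right-hand operand is handled.
    if len(b) != 2 or a[-1] != b[0]:
        raise ValueError(f"MatMul: incompatible shapes {a} and {b}")
    return a[:-1] + (b[1],)


def relu_shape(a: Shape) -> Shape:
    # Element-wise activation: the shape is unchanged.
    return a


def label(dtype: str, shape: Shape) -> str:
    # Mimics the symbolic edge labels, e.g. FLOAT(batch,seq,4).
    return f"{dtype}({','.join(str(d) for d in shape)})"


x = y = ("batch", "seq", 4)
w = (4, 2)
added = add_shape(x, y)
mm = matmul_shape(added, w)
z = relu_shape(mm)
print(label("FLOAT", added))  # FLOAT(batch,seq,4)
print(label("FLOAT", z))      # FLOAT(batch,seq,2)
```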
Convert to Mermaid
mermaid_src = to_mermaid(model)
print(mermaid_src)
flowchart TD
I_0["X\nFLOAT(batch,seq,4)"]:::input
I_1["Y\nFLOAT(batch,seq,4)"]:::input
i_2["W\nFLOAT(4, 2)"]:::init
Add_3["Add(., .)"]:::op
MatMul_4["MatMul(., .)"]:::op
Relu_5["Relu(.)"]:::op
I_0 -->|"FLOAT(batch,seq,4)"| Add_3
I_1 -->|"FLOAT(batch,seq,4)"| Add_3
Add_3 -->|"FLOAT(batch,seq,4)"| MatMul_4
i_2 -->|"FLOAT(4, 2)"| MatMul_4
MatMul_4 -->|"FLOAT(batch,seq,2)"| Relu_5
O_6["Z\nFLOAT(batch,seq,2)"]:::output
Relu_5 --> O_6
classDef input fill:#aaeeaa,stroke:#00aa00,color:#000
classDef init fill:#cccc00,stroke:#888800,color:#000
classDef op fill:#cccccc,stroke:#666666,color:#000
classDef output fill:#aaaaee,stroke:#0000aa,color:#000
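Every line of this output follows one of a few patterns: a node declaration ID["label"]:::class, a solid edge SRC --> DST with an optional |"dtype(shape)"| label, or a classDef style line. If the text needs post-processing, for example to count nodes per class, a small regex pass is enough. The sketch below is a hypothetical parser written against the output shown above; it ignores classDef lines and does not handle the dotted edges drawn for sub-graphs.

```python
import re
from collections import Counter
from typing import Dict, List, Tuple

# Node declaration, e.g.  I_0["X\nFLOAT(batch,seq,4)"]:::input
NODE_RE = re.compile(r'^\s*(\w+)\[".*"\]:::(\w+)\s*$')
# Solid edge with or without a label, e.g.  I_0 -->|"FLOAT(batch,seq,4)"| Add_3
EDGE_RE = re.compile(r'^\s*(\w+)\s*-->(?:\|"[^"]*"\|)?\s*(\w+)\s*$')


def summarize(mermaid_src: str) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
    """Return {node_id: css_class} and the list of (source, target) edges."""
    classes: Dict[str, str] = {}
    edges: List[Tuple[str, str]] = []
    for line in mermaid_src.splitlines():
        node = NODE_RE.match(line)
        if node:
            classes[node.group(1)] = node.group(2)
            continue
        edge = EDGE_RE.match(line)
        if edge:
            edges.append((edge.group(1), edge.group(2)))
    return classes, edges


# The Mermaid text printed above, pasted verbatim.
printed = r'''flowchart TD
I_0["X\nFLOAT(batch,seq,4)"]:::input
I_1["Y\nFLOAT(batch,seq,4)"]:::input
i_2["W\nFLOAT(4, 2)"]:::init
Add_3["Add(., .)"]:::op
MatMul_4["MatMul(., .)"]:::op
Relu_5["Relu(.)"]:::op
I_0 -->|"FLOAT(batch,seq,4)"| Add_3
I_1 -->|"FLOAT(batch,seq,4)"| Add_3
Add_3 -->|"FLOAT(batch,seq,4)"| MatMul_4
i_2 -->|"FLOAT(4, 2)"| MatMul_4
MatMul_4 -->|"FLOAT(batch,seq,2)"| Relu_5
O_6["Z\nFLOAT(batch,seq,2)"]:::output
Relu_5 --> O_6
classDef input fill:#aaeeaa,stroke:#00aa00,color:#000
classDef init fill:#cccc00,stroke:#888800,color:#000
classDef op fill:#cccccc,stroke:#666666,color:#000
classDef output fill:#aaaaee,stroke:#0000aa,color:#000
'''

classes, edges = summarize(printed)
print(Counter(classes.values()))
print(len(edges), "edges")
```

On this output the parser finds seven nodes (three op, two input, one init, one output) and six edges.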
Display the graph
The diagram is rendered to SVG via the mermaid.ink online service (through
mermaid-py) and displayed by wrapping the SVG content in
IPython.display.HTML so that sphinx-gallery captures and embeds it.
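For reference, the kind of URL a mermaid.ink request points at can be built with the standard library alone. The sketch below assumes the service's base64 scheme (https://mermaid.ink/svg/ followed by the base64-encoded diagram, with /img/ for a raster image); mermaid_ink_url is a hypothetical helper, and mermaid-py or draw_graph_mermaid may encode the request differently. Building the URL does not touch the network; only fetching it does.

```python
import base64

# Public rendering service mentioned above; the path scheme is an assumption.
MERMAID_INK = "https://mermaid.ink"


def mermaid_ink_url(mermaid_src: str, fmt: str = "svg") -> str:
    """Build a mermaid.ink URL for mermaid_src (hypothetical helper).

    The diagram is UTF-8 encoded, then URL-safe base64 encoded so it can be
    placed in the URL path. Nothing is fetched here.
    """
    if fmt not in ("svg", "img"):
        raise ValueError("fmt must be 'svg' or 'img'")
    encoded = base64.urlsafe_b64encode(mermaid_src.encode("utf-8")).decode("ascii")
    return f"{MERMAID_INK}/{fmt}/{encoded}"


diagram = 'flowchart TD\n    A["X"] --> B["Relu(.)"]\n'
url = mermaid_ink_url(diagram)
print(url)
```

Opening the printed URL in a browser, or fetching it with urllib.request.urlopen, should return the rendered SVG, provided network access is available.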
HTML(draw_graph_mermaid(model))