experimental_experiment.xoptim.unfused

experimental_experiment.xoptim.unfused.unfused_nodes(onx: GraphBuilder | ModelProto, input_names: List[str], output_names: List[str], as_proto: bool = False) -> List[NodeProto] | ModelProto [source]

Extracts from a model all the nodes starting from input_names and producing output_names.

This function does not handle subgraphs.

Parameters:
  • onx – model, either a GraphBuilder or a ModelProto

  • input_names – list of inputs to consider

  • output_names – list of outputs to consider

  • as_proto – produces a ModelProto instead of a list of nodes

Returns:

list of involved nodes, or a ModelProto containing them if as_proto is True

Example:

import onnx
from experimental_experiment.xoptim.unfused import unfused_nodes
from experimental_experiment.helpers import pretty_onnx


# Both models are expected to come from the same original model: the names of
# the inputs/outputs of a fused node found in ``optimized`` are looked up in
# ``not_optimized``, so tensor names must match between the two.
# Replace the placeholder paths with real files.
print("-- loading optimized model")
optimized = onnx.load("<optimized_model.onnx>")

print("-- loading not optimized model")
not_optimized = onnx.load("<not_optimized_model.onnx>")
print("-- done")


def look_for_pattern(
    not_optimized: onnx.ModelProto, optimized: onnx.ModelProto, node_type: str
):
    """
    Finds the first node of type ``node_type`` in the optimized model, then
    extracts from the non-optimized model the nodes going from that node's
    inputs to its outputs, i.e. the nodes it replaced.

    The extracted nodes are saved into ``unfused_{node_type}.onnx`` and
    printed.

    :param not_optimized: model before optimization
    :param optimized: model after optimization, containing fused nodes
    :param node_type: operator type of the fused node to look for
        (e.g. ``"Attention"``)
    :raises ValueError: if the optimized model has no node of type
        ``node_type`` or if that node has no named input
    """
    print()
    print(f"-- looking for fused nodes, type={node_type!r}")

    # Only the first matching node is analysed.
    fused_node = next(
        (node for node in optimized.graph.node if node.op_type == node_type), None
    )
    if fused_node is None:
        raise ValueError(f"No fused node of type {node_type!r} was found.")

    # Empty names denote omitted optional inputs/outputs; skip them.
    input_names = [i for i in fused_node.input if i]
    output_names = [o for o in fused_node.output if o]
    if not input_names:
        raise ValueError(f"The fused node of type {node_type!r} has no named input.")
    highlight = set(input_names) | set(output_names)

    print("-- fused_node")
    print(pretty_onnx(fused_node, with_attributes=True, highlight=highlight))
    print(f"-- input_names={input_names}")
    print(f"-- output_names={output_names}")
    print("--")

    print("-- looking for fused nodes")
    fused = unfused_nodes(not_optimized, input_names, output_names, as_proto=True)

    print("-- save onnx")
    onnx.save(fused, f"unfused_{node_type}.onnx")

    # With as_proto=True the result is a ModelProto: it has no len() and is
    # not iterable, the extracted nodes live in its graph.
    nodes = fused.graph.node
    print("--")
    print(f"-- found {len(nodes)} nodes")
    for node in nodes:
        print(pretty_onnx(node, with_attributes=True, highlight=highlight))
    print("--")


# Fused operators whose original (unfused) subgraphs are extracted, in order.
for fused_type in ("Attention", "SkipLayerNormalization"):
    look_for_pattern(not_optimized, optimized, fused_type)