import sys
from typing import List, Union
import onnx
import onnx.helper as oh
from ..xbuilder import GraphBuilder
from ..helpers import from_array_extended
[docs]
def unfused_nodes(
onx: Union[GraphBuilder, onnx.ModelProto],
input_names: List[str],
output_names: List[str],
as_proto: bool = False,
) -> Union[List[onnx.NodeProto], onnx.ModelProto]:
"""
Extracts from a model all the nodes starting from ``input_names`` and
producing ``output_names``.
This function does not handle subgraphs.
:param onx: model
:param input_names: list of inputs to consider
:param output_names: list of outputs to consider
:param as_proto: produces a ModelProto instead of a list of nodes
:return: list of involved nodes
One code example:
.. code-block:: python
import onnx
from experimental_experiment.xoptim.unfused import unfused_nodes
from experimental_experiment.helpers import pretty_onnx
print("-- load optmized model")
optimized = onnx.load("<optimized_model.onnx>")
print("-- loading not optimized model")
not_optimized = onnx.load("<not_optimized_model.onnx>")
print("-- done")
def look_for_pattern(
not_optimized: onnx.ModelProto, optimized: onnx.ModelProto, node_type: str
):
print()
print(f"-- looking for fused nodes, type={node_type!r}")
fused_node = None
for node in optimized.graph.node:
if node.op_type == node_type:
fused_node = node
input_names = [i for i in node.input if i]
output_names = [o for o in node.output if o]
break
assert input_names, "No fused node was found."
print("-- fused_node")
print(
pretty_onnx(
fused_node, with_attributes=True,
highlight=set(input_names) | set(output_names),
)
)
print(f"-- input_names={input_names}")
print(f"-- output_names={output_names}")
print("--")
print("-- looking for fused nodes")
fused = unfused_nodes(not_optimized, input_names, output_names, as_proto=True)
print("-- save onnx")
onnx.save(fused, f"unfused_{node_type}.onnx")
print("--")
print(f"-- found {len(fused)} nodes")
for node in fused:
print(
pretty_onnx(
node, with_attributes=True,
highlight=set(input_names) | set(output_names),
)
)
print("--")
look_for_pattern(not_optimized, optimized, "Attention")
look_for_pattern(not_optimized, optimized, "SkipLayerNormalization")
"""
if isinstance(onx, onnx.ModelProto):
return unfused_nodes(GraphBuilder(onx), input_names, output_names, as_proto=as_proto)
assert isinstance(onx, GraphBuilder), f"Unexpected type {type(onx)} for onx."
# First step, we go backward.
nodes = []
needed = set(output_names)
not_needed = set(input_names)
for node in onx.nodes[::-1]:
so = set(node.output)
if so & needed:
nodes.append(node)
needed |= set(i for i in node.input if i not in not_needed)
needed -= needed & set(node.output)
elif so & not_needed == so:
not_needed |= set(node.input)
if not needed:
break
# Then we go forward.
keep = []
input_involded = set(input_names)
for node in reversed(nodes):
if set(node.input) & input_involded:
keep.append(node)
input_involded |= set(node.output)
if not as_proto:
return keep
# For a whole model, let's consider the missing inputs.
input_names_all = input_names.copy()
known = set(input_names)
all_names = set()
for node in keep:
all_names |= set(node.input)
all_names |= set(node.output)
for i in node.input:
if i not in known:
known.add(i)
input_names_all.append(i)
known |= set(node.output)
init = {}
shapes = {}
for name in [*input_names_all, *output_names]:
if onx.is_constant(name):
init[name] = onx.get_constant(name)
continue
shapes[name] = onx.make_tensor_value_info_from_name(name)
inits, large_inits = onx._build_initializers(
switch_low_high=sys.byteorder != "big",
large_model=False,
subset=set(init),
external_threshold=1024,
)
assert (
not large_inits
), f"Not yet implemeted with large initializers large_inits={set(large_inits)}"
# We convert constants to initializers.
init_names = set(i.name for i in inits)
for name in init:
if name in init_names:
continue
# A constant.
cst = onx.get_constant(name, computed_value=True)
inits.append(from_array_extended(cst, name=name))
model = oh.make_model(
oh.make_graph(
keep,
"unfused_nodes",
[shapes[n] for n in input_names_all if n not in init],
[shapes[n] for n in output_names],
),
ir_version=onx.ir_version,
opset_imports=[oh.make_opsetid(*o) for o in onx.opsets.items()],
)
model.graph.initializer.extend(inits)
model.graph.value_info.extend(
onx.make_tensor_value_info_from_name(n) for n in all_names if onx.has_shape(n)
)
return model