import inspect
import os
import pprint
import textwrap
from collections import Counter
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import numpy as np
from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto
from ..xbuilder._dtype_helper import string_to_elem_type
[docs]
class MatchResult:
    """
    Holds the result of a successful pattern match.

    :param pattern: object detecting the pattern
    :param nodes: nodes to be replaced
    :param apply: node computing the replacements
    :param insert_at: insert the new nodes at this point if specified
    """

    def __init__(
        self,
        pattern: "PatternOptimization",
        nodes: List["NodeProto"],
        apply: Callable,
        insert_at: Optional["NodeProto"] = None,
    ):
        self.pattern = pattern
        self.nodes = nodes
        self.apply = apply
        self.insert_at = insert_at

    def to_string(self, short: bool = True) -> str:
        """Renders the match as text, *short* omits inputs and outputs."""
        types = [nd.op_type for nd in self.nodes if nd is not None]
        if short:
            return f"MatchResult: {self.pattern} replaces {types}"
        inputs = set()
        outputs = set()
        for nd in self.nodes:
            if nd is None:
                continue
            inputs |= set(nd.input)
            outputs |= set(nd.output)
        return (
            f"MatchResult: {self.pattern} replaces {types}, "
            f"inputs: {inputs}, outputs: {outputs}"
        )

    def __str__(self) -> str:
        return self.to_string(short=True)

    def debug_string(self, g: Optional["GraphBuilder"] = None) -> str:  # noqa: F821
        """
        Returns a string showing the matched nodes, with type/shape details
        appended when a graph builder *g* is provided.
        """

        def _describe(name, g=g):
            # type:shape when the shape is known, type:rank otherwise
            if g.has_shape(name):
                return f"{name}:{g.get_type(name)}:{g.get_shape(name)}"
            return f"{name}:{g.get_type(name)}:R{g.get_rank(name)}"

        rows = []
        for index, nd in enumerate(self.nodes):
            rows.append(
                f"{index} -"
                if nd is None
                else f"{index} - {nd.op_type}({nd.input}) -> {nd.output}"
            )
        if g:
            rows.append("--------")
            for index, nd in enumerate(self.nodes):
                if nd is None:
                    rows.append(f"{index} -")
                    continue
                rows.append(
                    f"{index} - {nd.op_type}({', '.join(map(_describe, nd.input))}) "
                    f"-> {', '.join(map(_describe, nd.output))}"
                )
        return "\n".join(rows)
[docs]
class PatternOptimization:
    """
    Defines an optimization pattern.

    Function match should return None if the match does not happen
    or better ``self.none(node, inspect.currentframe().f_lineno)``.
    That allows the user to know which line rejected a specific pattern
    by setting environment variable ``LOG_PATTERN_OPTIMIZE=10``.
    An environment variable equal to the class name can be set as well to
    track this specific pattern.

    :param verbose: determines the verbosity, it can also be determined by setting
        environment variable ``LOG_PATTERN_OPTIMIZE=10``
    :param priority: at each iteration,
        all patterns whose priority is below one threshold
        are executed, if none of them matches, the priority is increased
    :param min_opset: can be applied if main opset is > min_opset
    """

    def __init__(self, verbose: int = 0, priority: int = 1, min_opset: int = 1):
        # Verbosity is the maximum of the explicit argument, the global
        # LOG_PATTERN_OPTIMIZE variable and a variable named after the class.
        global_level = int(os.environ.get("LOG_PATTERN_OPTIMIZE", "0"))
        class_level = int(os.environ.get(self.__class__.__name__, "0"))
        self.verbose = max(verbose, global_level, class_level)
        self.priority = priority
        self.min_opset = min_opset

    def __str__(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    def __eq__(self, o: "PatternOptimization"):
        """
        Basic comparison based on the class name.
        """
        return type(o) == type(self)  # noqa: E721

    def enumerate_matches(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
    ) -> Iterator:
        """
        Enumerates all the matches this pattern finds in graph *g*.
        """
        if g.main_opset < self.min_opset:
            return
        matched = []
        # g.iter_nodes() iterates on g.builder.nodes: ->
        # too slow to have a secondary iterator
        for node in g.builder.nodes:
            # Checking the first output before walking the rest saves ~10%:
            # most nodes have a single output and it is usually used.
            used = g.is_used(node.output[0]) or any(
                g.is_used(o) for o in node.output[1:]
            )
            if not used:
                # We avoid processing a node which is not used.
                continue
            res = self.match(g, node, matched)
            if res:
                matched.append(res)
                yield res

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: "NodeProto",
        matched: List["MatchResult"],
    ) -> Optional["MatchResult"]:
        """
        Determines nodes around *node* which can be rewritten.

        :param g: is a :class:`GraphBuilderPatternOptimization
            <experimental_experiment.xoptim.GraphBuilderPatternOptimization>`,
            it holds all the existing nodes, is able to return any information
            about type, shape, the node before, the node after another one.
        :param node: the matching must determine if some nodes around this one
            are part of set of nodes this pattern optimizer can rewrite.
            From there, the function explores wherever it needs,
            checking any condition it needs.
        :param matched: usually unused, the list of nodes already matching
            a pattern

        The method must not modify the graph.
        The method returns None if no match is found or an instance
        of class :class:`MatchResult
        <experimental_experiment.xoptim.MatchResult>`. It must contain:

        * a list of nodes involved in the rewriting. It does not mean all
          of them will be removed but all of them are needed to do the rewriting
          and must not be impacted by other pattern optimizer.
        * A function doing the rewriting (usually method *apply* of the pattern class).
        * An existing node where the rewritten nodes can be inserted.
          Knowing it makes it faster to rewrite. If not specified, the optimizer
          will automatically determine the position of the new nodes.
        """
        raise NotImplementedError(
            f"This function must be overloaded in class {self.__class__}."
        )

    def _debug_print(self) -> str:
        # Subclasses may override to expose matching state; nothing by default.
        return ""

    def none(
        self,
        node: Optional["NodeProto"] = None,
        lineno: Optional[int] = None,
        msg: Optional[Union[Callable, str]] = None,
    ):
        """
        It may be useful to know which reason made a pattern matching fail.
        Instead of returning None, method *match* can return the following
        expression:

        ::

            return self.none(node, inspect.currentframe().f_lineno)

        By setting the verbosity (see next Section), the user may then know
        which lines in the code returned None and which condition failed.
        """
        if not (node and self.verbose):
            return
        if msg is None:
            text = ""
        elif callable(msg):
            text = msg()
        else:
            text = msg
        if text:
            text = f"\n{text}"
        header = (
            f"[{self.__class__.__name__}.match] NONE - line: {lineno}:"
            f"{os.path.split(self.__class__.__module__)[-1]}, "
            f"op_type={node.op_type}, name={node.name}{text}"
        )
        if self.verbose >= 10 and hasattr(self, "_debug"):
            extra = self._debug_print()
            if extra:
                extra = f"\n{textwrap.indent(extra, ' ')}"
            print(f"{header}{extra}")
        elif self.verbose >= 9:
            print(header)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        *nodes: Sequence["NodeProto"],
    ) -> List["NodeProto"]:
        """
        The method does the rewriting. It assumes it can happen.
        It takes a list of nodes impacted by the rewriting and assumes no other
        pattern optimizer will modify them. It receives the list of nodes
        returned by method *match*. Since it is a list of arguments, method
        *match* can include None values. The method returns the new nodes.
        The optimizer considers that any node given to this function is removed
        from the graph, and any node returned by it is added.
        If a received node must be kept, it must be added to the list of returned nodes.

        :param nodes: nodes returned by method *match*, they are then removed
        :return: nodes to add to graph.
        """
        raise NotImplementedError(
            f"This function must be overloaded in class {self.__class__.__name__!r}."
        )
[docs]
class EasyPatternOptimization(PatternOptimization):
"""
Implements a pattern optimization for quick experimentation.
The current implementation does not match on domain name.
It does not compares attributes either.
The environment variable ``AMBIGUITIES=1`` can be set to one to
raise an exception when this case happens.
"""
def __init__(self, verbose: int = 0, priority: int = 0, min_opset: int = 1):
    """Creates the pattern; the default priority is 0, unlike the base class."""
    super().__init__(verbose=verbose, priority=priority, min_opset=min_opset)
    # AMBIGUITIES=1 turns ambiguous matches into hard errors.
    self._debug_ambiguities = int(os.environ.get("AMBIGUITIES", 0)) == 1
    # Values stored with add_validate_param, read back when the pattern is applied.
    self._validate_parameters = {}
    # Cache of built pattern graphs, keyed by (kind, sorted opsets).
    self._cache = {}
[docs]
def add_validate_param(self, key: str, value: Any):
    """
    Stores a value to retrieve when apply_pattern is called.

    :param key: name under which *value* is stored
    :param value: any value useful at rewriting time
    """
    self._validate_parameters[key] = value
def get_validate_param(self, key: str) -> Any:
    """Returns a value previously stored with :meth:`add_validate_param`."""
    params = self._validate_parameters
    assert key in params, f"Unable to find key {key!r} in {sorted(params)}"
    return params[key]
[docs]
def match_pattern(
    self,
    g: "GraphBuilder",  # noqa: F821
    *args: List[str],
    **kwargs: Dict[str, Any],
):
    """
    Builds the pattern to match.

    Subclasses describe the subgraph to recognize by implementing this
    method; it is traced once and cached per opset.
    """
    raise NotImplementedError(
        f"Class {self.__class__.__name__!r} must overwrite method match_pattern."
    )
def _build_pattern(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    fct: Callable,
) -> "GraphBuilderPatternOptimization":  # noqa: F821
    """
    Traces *fct* (match_pattern or apply_pattern) to build the pattern
    graph: every positional parameter of *fct* becomes a graph input and
    the returned result name(s) become graph outputs.
    """
    from .graph_builder_optim import GraphBuilderPatternOptimization

    kwargs = {}
    args = []
    # There should be a better way.
    sig = inspect.signature(fct)
    anns = []
    for i, p in enumerate(sig.parameters.values()):
        if i == 0:
            # First parameter is the graph builder itself, not an input.
            continue
        if p.default is not inspect._empty:
            # an attribute
            kwargs[p.name] = p.default
        else:
            args.append(p.name)
            anns.append(p.annotation)
    assert len(kwargs) == 0, f"Attributes are not supported yet but kwargs={kwargs}"
    g2 = g.builder.empty_copy(as_function=True, constant_size=2**30)
    for name, ann in zip(args, anns):
        if ann is None or ann is str or ann is inspect._empty:
            g2.make_tensor_input(name, 0, None, False, marker=f"_build_pattern1_{name}")
            # Type is unknown
            g2.set_type(name, -1)
            continue
        # A string annotation names an element type, e.g. "FLOAT".
        assert isinstance(
            ann, str
        ), f"Annotation for {name!r} must be a string or None but ann={ann!r}"
        itype = string_to_elem_type(ann)
        g2.make_tensor_input(name, itype, None, False, marker=f"_build_pattern2_{name}")
    # Calling fct records the pattern nodes into g2; the parameter names
    # themselves are passed as the input result names.
    output = fct(g2, *args, **kwargs)
    if isinstance(output, str):
        g2.make_tensor_output(output, 0, None, is_dimension=False)
    else:
        for name in output:
            g2.make_tensor_output(name, 0, None, is_dimension=False)
    pat = GraphBuilderPatternOptimization(
        g2, verbose=max(0, g.verbose - 1), processor=g.processor
    )
    pat._build()
    return pat
def _get_match_pattern(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
) -> "GraphBuilderPatternOptimization":  # noqa: F821
    """Returns the match pattern for the current opsets, building it once."""
    key = 0, tuple(sorted(g.opsets.items()))
    if key not in self._cache:
        self._cache[key] = self._build_pattern(g, self.match_pattern)
    return self._cache[key]
def _get_apply_pattern(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
) -> "GraphBuilderPatternOptimization":  # noqa: F821
    """Returns the apply pattern for the current opsets, building it once."""
    key = 1, tuple(sorted(g.opsets.items()))
    if key not in self._cache:
        self._cache[key] = self._build_pattern(g, self.apply_pattern)
    return self._cache[key]
[docs]
def display_pattern(self, g, fct) -> str:
    """
    Shows the pattern to match or to apply, one node per line.
    """
    pat = self._build_pattern(g, fct)
    rows = [
        f"{fct.__name__}({', '.join(pat.input_names)}) -> {', '.join(pat.output_names)}"
    ]
    rows.extend(
        f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
        for node in pat.nodes
    )
    return "\n".join(rows)
def _match_backward(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    node: NodeProto,
    pat: "GraphBuilderPatternOptimization",  # noqa: F821
    marked: Dict[int, Tuple[NodeProto, NodeProto]],
    pair_results_names: Dict[str, str],
    stacked: List[int],
    n: NodeProto,
    pn: NodeProto,
) -> Optional[int]:
    """
    Matches backward: walks from the pair *(n, pn)* to the predecessors
    and marks every pattern predecessor with its graph counterpart.

    :param g: graph
    :param node: root node (the node the match began with,
        used only for debugging)
    :param pat: pattern
    :param marked: nodes of the pattern marked as already matched
    :param pair_results_names: pattern result name -> graph result name,
        used to detect ambiguous mappings
    :param stacked: next node to look into
    :param n: node coming from the graph
    :param pn: node coming from the pattern
    :return: number of matched nodes, None or False to indicate a failed match
    """
    res = 0
    # predecessors
    if len(n.input) != len(pn.input):
        # not the same number of inputs
        self._hint(
            "BACKWARD: not the same number of inputs",
            "-- pattern",
            pn,
            "-- model",
            n,
        )
        return self.none(node, inspect.currentframe().f_lineno)
    pattern_input_names = set(pat.input_names)
    for nr, pnr in zip(n.input, pn.input):
        # An intermediate result of the pattern must have the same number of
        # consumers in the graph, otherwise part of the matched computation
        # is used outside the pattern (constants are exempt).
        if (
            pnr not in pattern_input_names
            and not g.is_constant(nr)
            and len(g.next_nodes(nr)) != len(pat.next_nodes(pnr))
        ):
            self._hint(
                "BACKWARD: one input is used outside the pattern",
                "-- pattern input and pattern node",
                pnr,
                pn,
                "-- model input and model node",
                nr,
                n,
                "-- len(pat.next_nodes(pnr))",
                len(pat.next_nodes(pnr)),
                *pat.next_nodes(pnr),
                type(pn),
                "-- len(g.next_nodes(nr)))",
                len(g.next_nodes(nr)),
                *g.next_nodes(nr),
                type(n),
            )
            return self.none(node, inspect.currentframe().f_lineno)
    for i, pi in zip(n.input, pn.input):
        ppred = pat.node_before(pi)
        if ppred is None:
            # ppred is None means the pattern ends here.
            continue
        pred = g.node_before(i)
        if pred is None:
            # No node in the graph.
            self._hint(
                "BACKWARD: no node in the graph",
                "-- pred",
                pred,
                "-- ppred",
                ppred,
            )
            return self.none(node, inspect.currentframe().f_lineno)
        if pred.op_type != ppred.op_type or len(pred.input) != len(ppred.input):
            # Distinct type
            self._hint(
                "BACKWARD: distinct types or distinct number of inputs",
                "-- pred",
                pred,
                "-- ppred",
                ppred,
            )
            return self.none(node, inspect.currentframe().f_lineno)
        # matching backward
        key = id(ppred)
        if key not in marked:
            marked[key] = pred, ppred
            stacked.append(key)
            res += 1
    if self.verbose > 5 and res > 0:
        print(f"[EasyPatternOptimization._match_backward] add {res} nodes")
    return res
def _match_forward(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    node: NodeProto,
    pat: "GraphBuilderPatternOptimization",  # noqa: F821
    marked: Dict[int, Tuple[NodeProto, NodeProto]],
    pair_results_names: Dict[str, str],
    stacked: List[int],
    n: Union[NodeProto, str],
    pn: Union[NodeProto, str],
) -> Optional[int]:
    """
    Matches forward: walks from the pair *(n, pn)* to the consumers of
    their outputs and marks every pattern successor with its graph
    counterpart.

    :param g: graph
    :param node: root node (the node the match began with,
        used only for debugging)
    :param pat: pattern
    :param marked: nodes of the pattern marked as already matched
    :param pair_results_names: pattern result name -> graph result name,
        used to detect ambiguous mappings
    :param stacked: next node to look into
    :param n: node coming from the graph,
        it can be a string to start from a result
    :param pn: node coming from the pattern,
        it can be a string to start from a result
    :return: number of matched nodes to continue,
        None or False to indicate a failed match
    """
    res = 0
    # successors
    if isinstance(n, NodeProto) and isinstance(pn, NodeProto):
        if len(n.output) != len(pn.output):
            # not the same number of outputs
            self._hint(
                "FORWARD: not the same number of outputs",
                "-- pattern",
                pn,
                "-- model",
                n,
            )
            return self.none(node, inspect.currentframe().f_lineno)
        matched_results = list(zip(n.output, pn.output))
    elif isinstance(n, str) and isinstance(pn, str):
        matched_results = [(n, pn)]
    else:
        raise AssertionError(f"Unexpected types for n: {type(n)} and pn: {type(pn)}.")
    for o, op in matched_results:
        ns = g.next_nodes(o)
        pns = pat.next_nodes(op)
        if len(pns) == 0:
            # The pattern has no node forward, the matching stops.
            continue
        if len(ns) < len(pns):
            # Not enough nodes in the graph to match the pattern,
            # the result is known.
            self._hint(
                "FORWARD: not enough nodes in the graph to match the pattern",
                "-- o",
                o,
                "-- po",
                op,
                "-- len(ns)",
                len(ns),
                "-- len(pns)",
                len(pns),
            )
            return self.none(node, inspect.currentframe().f_lineno)
        # Here comes the fun part, there is the same number of successors or more
        # nodes in the graph to match with the pattern.
        # And we have to handle the nodes already marked as found.
        # Hopefully, there is only one option.
        if len(ns) == len(pns) == 1:
            # Let's deal with the simple case
            if ns[0].op_type != pns[0].op_type or len(ns[0].input) != len(pns[0].input):
                self._hint(
                    "FORWARD: distinct types or distinct number of inputs",
                    "-- pred",
                    ns[0],
                    "-- ppred",
                    pns[0],
                )
                return self.none(node, inspect.currentframe().f_lineno)
            amb = self._has_ambiguities(pair_results_names, ns[0], pns[0])
            if amb:
                self._hint(
                    "BACKWARD: ambiguities with names",
                    "-- ambiguities",
                    ns[0],
                    pns[0],
                    "-- pairs",
                    pair_results_names,
                    "-- pattern",
                    self._pattern_to_string(g),
                )
                if self._debug_ambiguities:
                    raise AssertionError(
                        f"An ambiguities was detected, ns[0]="
                        f"{g.builder.pretty_node(ns[0], short=True)}, "
                        f"pns[0]={g.builder.pretty_node(pns[0], short=True)},\n"
                        f"pairs={pprint.pformat(pair_results_names)}\n-- pattern -- \n"
                        f"{self._pattern_to_string(g)}\n-- graph --\n"
                        f"{g.builder.pretty_text()}"
                    )
                return self.none(node, inspect.currentframe().f_lineno)
            key = id(pns[0])
            if key not in marked:
                marked[key] = ns[0], pns[0]
                self._update_ambiguities(pair_results_names, ns[0], pns[0])
                stacked.append(key)
                res += 1
            continue
        # Let's remove the nodes already marked.
        p_marked = [_ for _ in pns if id(_) not in marked]
        id_marked = [id(marked[id(_)][0]) for _ in pns if id(_) in marked]
        assert len(id_marked) + len(p_marked) == len(pns), (
            f"Unexpected, id_marked={id_marked}, "
            f"id_p_marked={set(map(id, p_marked))}, "
            f"pns_ids={set(map(id, pns))}, "
            f"ns_ids={set(map(id, ns))}, o={o!r}, op={op!r}, "
            f"n.op_type={n.op_type!r}, "
            f"n.output={n.output}, np.output={pn.output}, "
            f"ns_types={set(_.op_type for _ in ns)}, "
            f"pns_types={set(_.op_type for _ in pns)}"
        )
        free = [_ for _ in ns if id(_) not in id_marked]
        if len(p_marked) == 0:
            # Everything is already marked.
            continue
        if len(free) < len(p_marked):
            # Not enough successors to match the remaining patterns.
            return self.none(node, inspect.currentframe().f_lineno)
        if len(p_marked) == len(free) == 1:
            # Only one option again.
            if p_marked[0].op_type != free[0].op_type or len(p_marked[0].input) != len(
                free[0].input
            ):
                self._hint(
                    "FORWARD: distinct types or distinct number of inputs",
                    "-- pred",
                    p_marked[0],
                    "-- ppred",
                    free[0],
                )
                return self.none(node, inspect.currentframe().f_lineno)
            amb = self._has_ambiguities(pair_results_names, free[0], p_marked[0])
            if amb:
                self._hint(
                    "FORWARD: ambiguities with names",
                    "-- ambiguities",
                    free[0],
                    p_marked[0],
                    "-- pairs",
                    pair_results_names,
                )
                if self._debug_ambiguities:
                    raise AssertionError(
                        f"An ambiguities was detected, free[0]="
                        f"{g.builder.pretty_node(free[0], short=True)}, "
                        f"p_marked[0]={g.builder.pretty_node(p_marked[0], short=True)}, "
                        f"pairs={pprint.pformat(pair_results_names)}\n-- pattern -- \n"
                        f"{self._pattern_to_string(g)}\n-- graph --\n"
                        f"{g.builder.pretty_text()}"
                    )
                return self.none(node, inspect.currentframe().f_lineno)
            key = id(p_marked[0])
            if key not in marked:
                marked[key] = free[0], p_marked[0]
                self._update_ambiguities(
                    pair_results_names,
                    free[0],
                    p_marked[0],
                    debug_msg=lambda: textwrap.indent(
                        self.display_pattern(g, self.match_pattern), " "
                    ),
                )
                stacked.append(key)
                res += 1
            continue
        # And now another fun part, let's try to handle the case when there
        # is only one option, matching on node type only returns one option.
        expected_op_type = [_.op_type for _ in p_marked]
        ec = Counter(expected_op_type)
        gc = Counter(_.op_type for _ in free)
        if len(ec) != len(gc) or set(ec) != set(gc):
            # number of unique operator types is different.
            self._hint(
                "FORWARD: number of unique operator types is different",
                "-- pattern",
                ec,
                pn,
                "-- model",
                gc,
                n,
                "-- model-marked",
                id_marked,
            )
            return self.none(node, inspect.currentframe().f_lineno)
        for k, v in ec.items():
            if gc[k] < v:
                # Not enough types to match.
                return self.none(node, inspect.currentframe().f_lineno)
        # At this stage, we know matching the types is possible.
        # We first mark whatever is possible.
        ptype_to_node = {_.op_type: _ for _ in p_marked}
        gtype_to_node = {_.op_type: _ for _ in free}
        missing = []
        for k, v in ec.items():
            if gc[k] == v == 1:
                # Exactly one candidate on each side: safe to pair them.
                key = id(ptype_to_node[k])
                amb = self._has_ambiguities(
                    pair_results_names, gtype_to_node[k], ptype_to_node[k]
                )
                if not amb and key not in marked:
                    self._update_ambiguities(
                        pair_results_names, gtype_to_node[k], ptype_to_node[k]
                    )
                    marked[key] = gtype_to_node[k], ptype_to_node[k]
                    stacked.append(key)
                    res += 1
            else:
                missing.append(k)
        if not missing:
            continue
        # At this stage, there are mutiple options for matching. We can:
        # 1. make assumptions and continue
        # 2. mark the node as incomplete matching, we could end up stuck anyway.
        # NOTE(review): ``assert True`` never fires, so the multi-option case
        # currently falls through silently — presumably a disabled placeholder
        # for future work; confirm whether it should raise instead.
        assert True, (
            f"There are more than one option, this will be implemented later, "
            f"ec={ec}, gc={gc}"
        )
    if self.verbose > 5 and res > 0:
        print(f"[EasyPatternOptimization._match_forward] add {res} nodes")
    return res
def _debug_print(self) -> str:
    """Formats the content of ``self._debug`` for verbose logging."""
    if not hasattr(self, "_debug"):
        return ""

    def _shorten(s):
        # keep short names as is, elide the middle of long ones
        return s if len(s) <= 30 else f"{s[:15]}...{s[-15:]}"

    def _render(n, full=False):
        if not isinstance(n, NodeProto):
            return str(n)
        if full:
            return (
                f"{n.op_type}({', '.join(map(_shorten, n.input))}) "
                f"-> ({', '.join(map(_shorten, n.output))})"
            )
        return f"{n.op_type}({','.join(map(_shorten, n.input))})"

    rows = []
    for k, v in sorted(self._debug.items()):
        if k == "stacked":
            rows.append(f"len({k})={len(v)}:{v}")
        elif k == "iteration":
            rows.append(f"{k}={v}")
        elif k == "marked":
            rows.append(f"--marked-- #{len(v)}")
            for i, tu in v.items():
                rows.append(f" {_render(tu[0])} ~ {_render(tu[1])} [{id(tu[0])}-{i}]")
        elif k == "hint":
            rows.append(f"--hint--: {v[0]}")
            for i in v[1:]:
                rows.append(" " + _render(i, full=True))
        elif k not in {"node", "pattern", "pattern_node", "pattern_nodes"}:
            rows.append(f"-- not shown {k}")
    return "\n".join(rows)
def _hint(self, *args: Sequence[Any]):
    """
    Add debugging information to help users.

    The stored hint is rendered by :meth:`_debug_print` when a match fails.
    """
    if self.verbose >= 5:
        # ``self._debug`` is only created by ``match`` when verbose > 5, so at
        # verbose == 5 exactly it may not exist yet: create it lazily instead
        # of raising AttributeError.
        if not hasattr(self, "_debug"):
            self._debug = {}
        self._debug["hint"] = args
[docs]
def validate_mapping(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    deleted_nodes: List["NodeProto"],
    pattern_nodes: Optional[List["NodeProto"]] = None,
) -> bool:
    """
    Validates the mapping. Hook for subclasses; accepts everything by default.

    :param g: GraphBuilder
    :param deleted_nodes: matched nodes from the model (to be deleted)
    :param pattern_nodes: matched nodes coming from the pattern
    :return: validate the mapping or not, default is True
    """
    return True
[docs]
def validate_attribute_mapping(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    deleted_nodes: List[NodeProto],
    pattern_nodes: Optional[List[NodeProto]] = None,
) -> bool:
    """
    Validates the mapping of the attributes: every attribute present on a
    pattern node must exist on the matched graph node with the same type
    and value.

    :param g: GraphBuilder
    :param deleted_nodes: matched nodes from the model (to be deleted)
    :param pattern_nodes: matched nodes coming from the pattern
    :return: validate the mapping or not, default is True
    """
    assert len(deleted_nodes) == len(pattern_nodes), (
        f"Mismatched number of nodes len(deleted_nodes)={len(deleted_nodes)}, "
        f"len(pattern_nodes)={len(pattern_nodes)}"
    )
    for i, (node, pat_node) in enumerate(zip(deleted_nodes, pattern_nodes)):
        # NOTE(review): the ``or node.domain != pat_node.domain`` clause makes
        # this assertion pass whenever the domains differ, even with different
        # op_types — confirm this leniency is intended.
        assert node.op_type == pat_node.op_type or node.domain != pat_node.domain, (
            f"Node type mismatch at position {i}, {node.op_type!r} != "
            f"{pat_node.op_type!r} or {node.domain!r} != {pat_node.domain!r}"
        )
        in_graph = {att.name: att for att in node.attribute}
        for att in pat_node.attribute:
            if att.name not in in_graph:
                # The pattern node carries an attribute the graph node lacks.
                if self.verbose >= 5:
                    print(
                        f"[EasyPatternOptimization.validate_attribute_mapping] failed "
                        f"attribute {att.name!r} (missing), nodes: "
                        f"{g.builder.pretty_node(node, short=True)} / "
                        f"{g.builder.pretty_node(pat_node, short=True)}"
                    )
                return False
            n_att = in_graph[att.name]
            if att.type != n_att.type:
                if self.verbose >= 5:
                    print(
                        f"[EasyPatternOptimization.validate_attribute_mapping] failed "
                        f"attribute {att.name!r} (type), "
                        f"nodes: {g.builder.pretty_node(node, short=True)} / "
                        f"{g.builder.pretty_node(pat_node, short=True)}"
                    )
                return False
            if att.type == AttributeProto.INT and att.i != n_att.i:
                # NOTE(review): when the integer values differ but the
                # axis/Split/Concat special case below does not apply, the
                # mismatch is silently accepted (there is no ``return False``
                # on that path) — confirm whether this is intentional.
                if (
                    att.name == "axis"
                    and node.op_type in {"Split", "Concat"}
                    and g.has_rank(node.input[0])
                ):
                    # Let's compare negative value.
                    rk = g.get_rank(node.input[0])
                    i1 = (att.i + rk) % rk
                    i2 = (n_att.i + rk) % rk
                    if i1 != i2:
                        if self.verbose >= 5:
                            print(
                                f"[EasyPatternOptimization.validate_attribute_mapping] failed "
                                f"attribute {att.name!r} (value int), nodes: "
                                f"{g.builder.pretty_node(node, short=True)} / "
                                f"{g.builder.pretty_node(pat_node, short=True)}"
                            )
                        return False
            if att.type == AttributeProto.FLOAT and att.f != n_att.f:
                if self.verbose >= 5:
                    print(
                        f"[EasyPatternOptimization.validate_attribute_mapping] failed "
                        f"attribute {att.name!r} (value float), nodes: "
                        f"{g.builder.pretty_node(node, short=True)} / "
                        f"{g.builder.pretty_node(pat_node, short=True)}"
                    )
                return False
            if att.type == AttributeProto.STRING and att.s != n_att.s:
                if self.verbose >= 5:
                    print(
                        f"[EasyPatternOptimization.validate_attribute_mapping] "
                        f"failed attribute {att.name!r} (value string), "
                        f"nodes: {g.builder.pretty_node(node, short=True)} / "
                        f"{g.builder.pretty_node(pat_node, short=True)}"
                    )
                return False
            # Other attribute kinds (tensors, graphs, ...) are not compared yet.
            assert att.type in {
                AttributeProto.INT,
                AttributeProto.FLOAT,
                AttributeProto.STRING,
            }, (
                f"Attribute comparison not implemented for data_type={att.type}, "
                f"att={att} in node {pat_node}"
            )
    return True
def _update_ambiguities(
    self,
    pair_results_names,
    node: "NodeProto",
    pattern_node: "NodeProto",
    debug_msg: Callable = lambda: "",
):
    """
    Registers the pattern-name -> graph-name pairings implied by matching
    *node* (graph) against *pattern_node* (pattern).

    :param pair_results_names: mapping updated in place
    :param node: node coming from the graph
    :param pattern_node: node coming from the pattern
    :param debug_msg: callable returning extra context for the assertion
        message; the previous default was the ``typing.Callable`` object
        itself, which is not callable and raised a TypeError (instead of
        the intended AssertionError) whenever the assertion fired.
    """
    pairs = (
        *zip(node.input, pattern_node.input),
        *zip(node.output, pattern_node.output),
    )
    for a, b in pairs:
        if b in pair_results_names:
            assert pair_results_names[b] == a, (
                f"Ambiguity {b!r} is mapped to {pair_results_names[b]!r} and {a!r} "
                f"pair_results_names={pair_results_names}, pattern is\n"
                f"{debug_msg()}"
            )
        else:
            pair_results_names[b] = a
def _has_ambiguities(
    self, pair_results_names, node: "NodeProto", pattern_node: "NodeProto"
) -> bool:
    """
    Tells if pairing *node* with *pattern_node* contradicts a name
    mapping already stored in *pair_results_names*.
    """
    candidates = (
        *zip(node.input, pattern_node.input),
        *zip(node.output, pattern_node.output),
    )
    return any(
        b in pair_results_names and pair_results_names[b] != a for a, b in candidates
    )
def _pattern_to_string(self, g: "GraphBuilder"):  # noqa: F821
    """Returns the match pattern rendered as an indented string."""
    text = self.display_pattern(g, self.match_pattern)
    return textwrap.indent(text, " ")
[docs]
def match(
    self,
    g: "GraphBuilderPatternOptimization",  # noqa: F821
    node: NodeProto,
    matched: List["MatchResult"],
) -> Optional["MatchResult"]:
    """
    Matches the traced pattern against the graph, starting from *node*.

    The last node of the pattern is paired with *node*, then the pairing is
    propagated backward and forward until every pattern node is marked or a
    contradiction is found.

    :param g: graph builder holding the model
    :param node: candidate node for the last pattern node
    :param matched: matches already found (unused here)
    :return: a :class:`MatchResult` or None
    """
    pat = self._get_match_pattern(g)
    # Let's match the first node.
    # Then we need to match successors and predecessors.
    p_node = pat.nodes[-1]  # the last one
    if node.op_type != p_node.op_type:
        # The first node does not have the same type.
        return self.none()
    if len(node.input) != len(p_node.input):
        return self.none(node, inspect.currentframe().f_lineno)
    check_ids = set(id(n) for n in pat.nodes)
    if self.verbose > 5:
        print(
            f"[EasyPatternOptimization.match] -- starts with "
            f"{node.op_type}({', '.join(node.input)})"
        )
    if self.verbose >= 10:
        print("[EasyPatternOptimization.match] match pattern")
        print(self._pattern_to_string(g))
    pair_results_names = {}
    self._update_ambiguities(pair_results_names, node, p_node)
    marked = {id(p_node): (node, p_node)}
    stacked = [id(p_node)]
    iteration = 0
    if self.verbose > 5:
        self._debug = dict(
            pattern=pat,
            marked=marked,
            stacked=stacked,
            iteration=iteration,
            node=node,
            pattern_node=p_node,
            pattern_nodes=pat.nodes,
        )
    # to avoid infinite loops.
    max_iter = len(pat.nodes) * 2
    while stacked and iteration < max_iter:
        assert all(id(b[1]) in check_ids for b in marked.values()), (
            f"At least one id is not part of the pattern ids={check_ids}, "
            f"marked={set(id(b[1]) for b in marked.values())}"
        )
        iteration += 1
        if self.verbose > 5:
            print(
                f"[EasyPatternOptimization.match] iteration={iteration} "
                f"n_marked={len(marked)}, n_stacked={len(stacked)}, "
                f"marked_types={Counter(_[1].op_type for _ in marked.values())}"
            )
        idn = stacked.pop()
        n, pn = marked[idn]
        fall_back_candidates = None
        if any(pat.node_before(i) is not None for i in pn.input):
            # There are backward nodes in the pattern.
            res = self._match_backward(
                g, node, pat, marked, pair_results_names, stacked, n, pn
            )
            if res is None:
                if self.verbose > 5:
                    print("[EasyPatternOptimization.match] done. backward failed.")
                return res
        else:
            # We check then if an input or pn has an unmatched node.
            for x in pn.input:
                psuccessors = pat.next_nodes(x)
                if len(psuccessors) == 1:
                    # It is itself.
                    continue
                for pnn in psuccessors:
                    if id(pnn) not in marked:
                        # One unmarked node is consuming the input.
                        # The potential list of candidates.
                        fall_back_candidates = list(zip(n.input, pn.input))
                        break
        assert all(id(b[1]) in check_ids for b in marked.values()), (
            f"At least one id is not part of the pattern ids={check_ids}, "
            f"marked={set(id(b[1]) for b in marked.values())}"
        )
        res = self._match_forward(g, node, pat, marked, pair_results_names, stacked, n, pn)
        if res is None:
            if self.verbose > 5:
                print("[EasyPatternOptimization.match] done. forward failed.")
            return res
        if res == 0 and fall_back_candidates:
            # No backward possible, no forward either.
            # We make sure that one of pattern inputs is not linked to another
            # node in the pattern itself.
            for candidate in fall_back_candidates:
                res = self._match_forward(
                    g, node, pat, marked, pair_results_names, stacked, *candidate
                )
                if res is None or res == 0:
                    continue
                break
        assert all(id(b[1]) in check_ids for b in marked.values()), (
            f"At least one id is not part of the pattern ids={check_ids}, "
            f"marked={set(id(b[1]) for b in marked.values())}"
        )
        if self.verbose > 5:
            self._debug["iteration"] = iteration
    if iteration >= max_iter and stacked:
        # BUG FIX: this used to call the non-existent ``self.hint`` (the
        # method is named ``_hint``), raising AttributeError when the
        # iteration cap was reached, and the message lacked its f-prefix.
        self._hint(f"reached {iteration}>={max_iter} iterations")
        return self.none(node, inspect.currentframe().f_lineno)
    # At this point, the pattern is matched but let's make sure.
    assert len(stacked) == 0, f"There are still {len(stacked)} nodes to explore."
    if len(marked) != len(pat.nodes):
        # This should matched in most cases but when there are
        # multiple outputs,
        self._hint(
            "MATCH: not enough matched nodes",
            "-- len(marked)",
            len(marked),
            "-- len(pat.nodes)",
            len(pat.nodes),
        )
        return self.none(node, inspect.currentframe().f_lineno)
    # We order the matched nodes in the same order than the pattern
    # to let next functions to be able to build the matching again.
    matched_nodes = [marked[id(pn_)][0] for pn_ in pat.nodes]
    if not self.validate_attribute_mapping(g, matched_nodes, pat.nodes):
        if self.verbose >= 2:
            print(
                f"[EasyPatternOptimization.match] attribute validation failed-1 "
                f"{len(marked)} marked nodes with {iteration} iterations"
            )
        return None
    if not self.validate_mapping(g, matched_nodes, pat.nodes):
        if self.verbose >= 2:
            print(
                f"[EasyPatternOptimization.match] validation failed-2 "
                f"{len(marked)} marked nodes with {iteration} iterations"
            )
        return None
    if self.verbose > 5:
        print(
            f"[EasyPatternOptimization.match] done = matched. "
            f"{len(marked)} marked nodes with {iteration} iterations"
        )
    if self.verbose >= 10:
        for node, pat_node in zip(matched_nodes, pat.nodes):
            sleft = f"{node.op_type}({node.input})->{node.output}"
            print(
                f" {sleft}{' ' * (60 - len(sleft))}"
                f"MATCHED {pat_node.op_type}"
                f"({pat_node.input})->{pat_node.output}"
            )
    return MatchResult(self, matched_nodes, self.apply)
[docs]
def apply_pattern(self, g: "GraphBuilder", *args, **kwargs):  # noqa: F821
    """
    Applies the replacement.

    Subclasses describe the replacement subgraph here; it is traced the
    same way as :meth:`match_pattern`.
    """
    raise NotImplementedError(
        f"Class {self.__class__.__name__!r} must overwrite method 'apply_pattern'."
    )
[docs]
def apply(
    self,
    g: "GraphBuilder",  # noqa: F821
    *nodes: Sequence[NodeProto],
) -> List[NodeProto]:
    """
    Creates the new nodes replacing the matched ones.

    :param g: graph builder the optimization is applied to
    :param nodes: matched nodes, in the same order as the nodes of the
        match pattern
    :return: new nodes to insert into the graph
    """
    # The match pattern is rebuilt (cached by _get_match_pattern) so the
    # matched graph nodes can be paired with the pattern nodes again.
    pat = self._get_match_pattern(g)
    assert len(nodes) == len(pat.nodes), (
        f"Mismatch matched nodes pattern has {len(pat.nodes)} != {len(nodes)} = "
        f"the number of matched nodes"
    )
    new_pat = self._build_pattern(g, self.apply_pattern)
    assert len(new_pat.inputs) == len(pat.inputs), (
        f"Not the same number of inputs, matched inputs={len(new_pat.inputs)}, "
        f"got {len(pat.inputs)} in the applied pattern."
    )
    assert len(new_pat.outputs) == len(pat.outputs), (
        f"Not the same number of outputs, matched outputs={new_pat.output_names}, "
        f"got {pat.output_names} in the applied pattern."
    )

    if g.verbose > 5:
        print(
            f"[EasyPatternOptimization.apply] replace {len(nodes)} nodes: "
            f"{self.display_pattern(g, self.apply_pattern)}"
        )

    # Maps every input/output name of the match pattern to the
    # corresponding name in the applied pattern (paired by position).
    matched_pattern_to_applied_pattern = {}
    for i, j in zip(pat.input_names, new_pat.input_names):
        matched_pattern_to_applied_pattern[i] = j
    for i, j in zip(pat.output_names, new_pat.output_names):
        matched_pattern_to_applied_pattern[i] = j

    # Maps every input/output name of the match pattern to the actual
    # name it took in the graph, deduced from the matched nodes.
    matched_pattern_to_graph_name = {}
    input_names = set(pat.input_names)
    output_names = set(pat.output_names)

    matched_pairs = list(zip(nodes, pat.nodes))
    for gn, pn in matched_pairs:
        assert (
            gn.op_type == pn.op_type
        ), f"Unexpected type mismatch {gn.op_type!r} != {pn.op_type!r}"
        assert len(gn.input) == len(
            pn.input
        ), f"Unexpected number of inputs for type {gn.op_type}"
        for a, b in zip(gn.input, pn.input):
            if b not in input_names or b == "":
                # optional input or not an interesting input
                continue
            if b in matched_pattern_to_graph_name:
                # The same pattern name must always resolve to the same
                # graph name, otherwise the match is ambiguous.
                assert matched_pattern_to_graph_name[b] == a, (
                    f"Ambiguities, pattern name {b!r} means "
                    f"{a!r} or {matched_pattern_to_graph_name[b]!r}"
                )
            else:
                matched_pattern_to_graph_name[b] = a

        assert len(gn.output) == len(
            pn.output
        ), f"Unexpected number of outputs for type {gn.op_type}"
        for a, b in zip(gn.output, pn.output):
            if b not in output_names or b == "":
                # Only final outputs are interesting.
                continue
            assert a != "", f"{a!r} cannot be optional"
            if b in matched_pattern_to_graph_name:
                assert matched_pattern_to_graph_name[b] == a, (
                    f"Ambiguities, pattern name {b!r} means "
                    f"{a!r} or {matched_pattern_to_graph_name[b]}"
                )
            else:
                matched_pattern_to_graph_name[b] = a

    # replacements: applied pattern name -> actual name in the graph.
    replacements = {}
    for k, v in matched_pattern_to_graph_name.items():
        replacements[matched_pattern_to_applied_pattern[k]] = v

    # Creation of the new initializers.
    for name, init in new_pat.builder.initializers_dict.items():
        # We add them to the graph, they will be removed if unused.
        new_name = g.make_initializer(
            name, init, source=f"EasyPatternOptimization.init/from({name})"
        )
        # BUG FIX: the mapping must go from the applied-pattern name to
        # the graph name, the same direction as every other entry of
        # ``replacements`` consumed below (``replacements[i]`` for a
        # pattern input ``i``). The previous reversed assignment
        # (``replacements[new_name] = name``) only worked when
        # ``make_initializer`` kept the requested name unchanged.
        replacements[name] = new_name

    # Creation of the new nodes.
    new_nodes = []
    for node in new_pat.nodes:
        new_inputs = []
        for i in node.input:
            assert i in replacements, f"Unable to find {i!r} in {replacements}"
            ni = replacements[i]
            new_inputs.append(ni)
        new_outputs = []
        for o in node.output:
            if o in replacements:
                new_outputs.append(replacements[o])
            else:
                # We give it a new name.
                n = g.unique_name(o)
                replacements[o] = n
                new_outputs.append(n)

        if (
            node.op_type == "Constant"
            and node.domain == ""
            and len(node.attribute) == 1
            and node.attribute[0].name == "value"
        ):
            value = node.attribute[0].t
            size = np.prod(value.dims)
            if size >= g.builder.optimization_options.constant_size:
                # We check the size to convert it into initializer if needed.
                name = g.make_initializer(
                    new_outputs[0],
                    value,
                    source=f"EasyPatternOptimization.constant/from({new_outputs[0]})",
                )
                assert name == new_outputs[0], f"Name mismatch {name} != {new_outputs[0]}"
                continue

        new_node = g.make_node(
            node.op_type,
            new_inputs,
            new_outputs,
            domain=node.domain,
            name=node.name,
        )
        new_node.attribute.extend(node.attribute)
        new_nodes.append(new_node)

    if g.verbose > 5:
        print(f"[EasyPatternOptimization.apply] done with {len(new_nodes)} nodes")

    self.post_apply_pattern(g, *nodes)
    return new_nodes
[docs]
def post_apply_pattern(self, g, *nodes):
    """
    Hook executed right after the pattern was applied.

    Does nothing by default; subclasses may overload it to run an
    extra step on the replaced nodes.
    """
    return None
[docs]
class OnnxEasyPatternOptimization(EasyPatternOptimization):
    """
    Pattern matching driven by two onnx models instead of two methods.

    :param match_model: model expressing the pattern to match
    :param apply_model: model expression the replacement pattern
    """

    def __init__(
        self,
        match_model: Union[ModelProto, FunctionProto],
        apply_model: Union[ModelProto, FunctionProto],
        verbose: int = 0,
    ):
        super().__init__(verbose=verbose)
        self._match_model = match_model
        self._apply_model = apply_model

    def _build_pattern(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        fct: Callable,
    ) -> "GraphBuilderPatternOptimization":  # noqa: F821
        # Picks the stored model matching the requested method
        # (== is required: bound methods compare equal, not identical).
        if fct == self.match_pattern:
            model = self._match_model
        elif fct == self.apply_pattern:
            model = self._apply_model
        else:
            raise AssertionError(f"Cannot return pattern for unknown method {fct!r}.")

        from ..xbuilder import GraphBuilder
        from .graph_builder_optim import GraphBuilderPatternOptimization

        builder = GraphBuilder(model)
        pattern_graph = GraphBuilderPatternOptimization(
            builder, verbose=max(0, builder.verbose - 1), processor=g.processor
        )
        pattern_graph._build()
        return pattern_graph
[docs]
def make_pattern_from_onnx(
    name: str,
    match_model: Union[ModelProto, FunctionProto],
    apply_model: Union[ModelProto, FunctionProto],
    verbose: int = 0,
):
    """
    Dynamically builds a subclass of
    :class:`OnnxEasyPatternOptimization` called *name* and returns an
    instance of it.

    :param name: class name
    :param match_model: model expressing the pattern to match
    :param apply_model: model expression the replacement pattern
    :param verbose: verbosity
    :return: instance of a new class
    """
    new_class = type(name, (OnnxEasyPatternOptimization,), {})
    instance = new_class(match_model, apply_model, verbose=verbose)
    return instance