Source code for experimental_experiment.xoptim.patterns_investigation.element_wise
from typing import List, Optional
from onnx import NodeProto
from ..patterns_api import MatchResult, PatternOptimization
[docs]
class BinaryInvestigation(PatternOptimization):
"""
Looks into
"""
_ops = {"Add", "Div", "Mul", "Sub"}
[docs]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type not in self._ops:
return self.none()
left = g.node_before(node.input[0])
right = g.node_before(node.input[1])
if left is None and right is None:
return self.none()
nodes = [node, left, right]
if self.verbose:
print(f"[{self.__class__.__name__}] {self.report(g, *nodes)}")
return self.none()
def _str(cls, g, node):
if node.op_type in cls._ops:
sh1 = g.get_shape(node.input[0]) if g.has_shape(node.input[0]) else ("?",)
sh2 = g.get_shape(node.input[1]) if g.has_shape(node.input[1]) else ("?",)
if len(sh1) == 0:
sh1 = (1,)
if len(sh2) == 0:
sh2 = (1,)
sh1 = "x".join(map(str, sh1))
sh2 = "x".join(map(str, sh2))
return f"{node.op_type}({sh1}, {sh2})"
return f"{node.op_type}(...)"
def report(
cls,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
left: Optional[NodeProto],
right: Optional[NodeProto],
):
rows = [cls._str(g, node)]
rows.append(f"[{cls._str(g, left)}]" if left is not None else "[?]")
rows.append(f"[{cls._str(g, right)}]" if right is not None else "[?]")
return " --- ".join(rows)