import inspect
from typing import List, Optional
from onnx import NodeProto
from ..patterns_api import MatchResult, PatternOptimization
[docs]
class FusedMatMulDivPattern(PatternOptimization):
"""
Replaces the Matmul, Div into FusedMatMul.
"""
def __init__(self, verbose: int = 0, priority: int = 2):
super().__init__(verbose, priority)
[docs]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if (node.op_type != "MatMul" or node.domain != "") and (
node.op_type != "FusedMatMul" or node.domain != "com.microsoft"
):
return self.none()
next_nodes = g.next_nodes(node.output[0])
if len(next_nodes) != 1:
return self.none(node, inspect.currentframe().f_lineno)
op_type = next_nodes[0].op_type
if op_type not in ("Mul", "Div"):
return self.none(node, inspect.currentframe().f_lineno)
if not g.is_constant_scalar(next_nodes[0].input[1]):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, next_nodes[0]], self.apply, insert_at=next_nodes[0])
[docs]
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
node_div: NodeProto,
) -> List[NodeProto]:
alpha = 1.0
atts = []
if node.op_type == "FusedMatMul":
for att in node.attribute:
if att.name == "alpha":
alpha *= att.f
else:
atts.append(att)
cst = g.get_computed_constant(node_div.input[1])
scale = float(cst if len(cst.shape) == 0 else cst[0])
if node_div.op_type == "Div":
alpha /= scale
else:
alpha *= scale
mm = g.make_node(
"FusedMatMul",
node.input,
node_div.output,
domain="com.microsoft",
alpha=alpha,
name=f"{self.__class__.__name__}--{node.name}",
)
if atts:
mm.attribute.extend(atts)
return [mm]
[docs]
class FusedMatMulPattern(PatternOptimization):
"""
Replaces the sequence Transpose, Matmul into FusedMatMul.
"""
def __init__(self, verbose: int = 0, priority: int = 2):
super().__init__(verbose, priority)
[docs]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if (node.op_type != "MatMul" or node.domain != "") and (
node.op_type != "FusedMatMul" or node.domain != "com.microsoft"
):
return self.none()
if node.op_type == "FusedMatMul":
transA = g.get_attribute(node, "transA", exc=False) or 0
transB = g.get_attribute(node, "transB", exc=False) or 0
if transA != transB:
# one side is already transposed.
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_rank(node.input[0]) or not g.has_rank(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if g.get_rank(node.input[0]) < 2 or g.get_rank(node.input[1]) < 2:
return self.none(node, inspect.currentframe().f_lineno)
if g.get_rank(node.input[0]) <= 2 and g.get_rank(node.input[1]) <= 2:
# Regular Gemm.
return self.none(node, inspect.currentframe().f_lineno)
nodes_before = [g.node_before(node.input[0]), g.node_before(node.input[1])]
ns = [
(n if n is not None and n.op_type == "Transpose" and n.domain == "" else None)
for n in nodes_before
]
if len([_ for _ in ns if _ is not None]) == 0:
return self.none(node, inspect.currentframe().f_lineno)
if g.has_processor("CUDA"):
nns = []
for n in ns:
if n is None:
nns.append(n)
continue
if g.is_used_more_than_once(n.output[0]):
nns.append(None)
continue
nns.append(n)
if len([_ for _ in ns if _ is not None]) == 0:
return self.none(node, inspect.currentframe().f_lineno)
ns = nns
hints = []
found = False
nns = []
for n in ns:
if n is None:
nns.append(None)
continue
perm = list(g.get_attribute(n, "perm").ints)
expecting = list(range(len(perm)))
expecting[-2], expecting[-1] = expecting[-1], expecting[-2]
if perm != expecting:
hints.append(dict(expecting=expecting, perm=perm))
nns.append(None)
continue
found = True
nns.append(n)
ns = nns
if not found:
# unexpected transpose
return self.none(node, inspect.currentframe().f_lineno, lambda: f"hints={hints}")
# At this stage, one or two inputs are transposed before being used.
# MatMul or Gemm are operating on 2D tensors.
nodes = [*ns, node]
if nodes[0] is not None and nodes[1] is not None:
# Both are available, we only transpose one.
nodes[0] = None
if not g.is_used_more_than_once(node.output[0]):
next_node = g.next_node(node.output[0])
if (
next_node.op_type in {"Div", "Mul"}
and next_node.domain == ""
and g.is_constant_scalar(next_node.input[1])
):
# The node can be fused with matmul
nodes.append(next_node)
return MatchResult(self, nodes, self.apply)
[docs]
def apply(
self,
g: "GraphBuilder", # noqa: F821
node_before_left: Optional[NodeProto],
node_before_right: Optional[NodeProto],
node: NodeProto,
scale: Optional[NodeProto] = None,
) -> List[NodeProto]:
inputs = [
(node.input[0] if node_before_left is None else node_before_left.input[0]),
(node.input[1] if node_before_right is None else node_before_right.input[0]),
*node.input[2:],
]
transA = 0 if node_before_left is None else 1
transB = 0 if node_before_right is None else 1
transBatchA = 0
transBatchB = 0
keep = []
for att in node.attribute:
if att.name in {"alpha", "beta"}:
keep.append(att)
elif att.name == "transA":
transA = (att.i + transA) % 2
elif att.name == "transB":
transB = (att.i + transB) % 2
elif att.name == "transBatchA":
transBatchA = att.i
elif att.name == "transBatchB":
transBatchB = att.i
else:
raise NotImplementedError(
f"Unexpected attribute {att.name!r}={att} for node={node}"
)
kwargs = dict(
transA=transA,
transB=transB,
transBatchA=transBatchA,
transBatchB=transBatchB,
)
if scale is not None:
# Let's include the scale as well
cst = g.get_computed_constant(scale.input[1])
value = float(cst[0] if cst.shape == (1,) else cst)
assert scale.op_type in {
"Div",
"Mul",
}, f"Match did not check next_node type {scale.op_type!r}"
alpha = value if scale.op_type == "Mul" else (1.0 / value)
kwargs["alpha"] = alpha
output = scale.output[0]
else:
output = node.output[0]
new_node = g.make_node(
"FusedMatMul",
inputs,
[output],
name=f"{self.__class__.__name__}--{node.name}",
doc_string=node.doc_string,
domain="com.microsoft",
**kwargs,
)
new_node.attribute.extend(keep)
res = [new_node]
if node_before_left is not None and g.is_used_more_than_once(
node_before_left.output[0]
):
# This is not efficient on CUDA.
res.append(node_before_left)
if node_before_right is not None and g.is_used_more_than_once(
node_before_right.output[0]
):
# This is not efficient on CUDA.
res.append(node_before_right)
return res
[docs]
class FusedMatMulx2Pattern(PatternOptimization):
"""
Replaces the sequence Div by a scalar consumed by two FusedMatMul.
"""
def __init__(self, verbose: int = 0, priority: int = 3):
super().__init__(verbose, priority)
[docs]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if (node.op_type not in "MatMul" or node.domain != "") and (
node.op_type != "FusedMatMul" or node.domain != "com.microsoft"
):
return self.none()
div_node = None
for name in node.input:
n = g.node_before(name)
if n is None:
continue
if n.op_type not in {"Mul", "Div"} or n.domain != "":
continue
if not g.is_constant_scalar(n.input[1]):
continue
div_node = n
break
if div_node is None:
return self.none(node, inspect.currentframe().f_lineno)
next_nodes = g.next_nodes(div_node.output[0])
op_types = [n.op_type for n in next_nodes]
if any(t not in {"FusedMatMul", "MatMul"} for t in op_types):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [div_node, *next_nodes], self.apply)
[docs]
def apply(
self,
g: "GraphBuilder", # noqa: F821
div_node: Optional[NodeProto],
*mnodes: Optional[NodeProto],
) -> List[NodeProto]:
cst = g.get_constant_scalar(div_node.input[1])
if div_node.op_type == "Div":
cst = 1.0 / cst
new_nodes = []
for node in mnodes:
alpha = 1.0
atts = []
for att in node.attribute:
if att.name == "alpha":
alpha = float(att.f)
else:
atts.append(att)
new_inputs = [
(div_node.input[0] if i == div_node.output[0] else i) for i in node.input
]
alpha *= cst
new_node = g.make_node(
"FusedMatMul",
new_inputs,
node.output,
domain="com.microsoft",
alpha=alpha,
name=f"{self.__class__.__name__}--{node.name}",
)
if atts:
new_node.attribute.extend(atts)
new_nodes.append(new_node)
return new_nodes
[docs]
class FusedMatMulTransposePattern(PatternOptimization):
"""
Replaces the sequence (Fused)Matmul(A,B) + Transpose
into FusedMatMul(B.T, A.T).
"""
def __init__(self, verbose: int = 0, priority: int = 3):
super().__init__(verbose, priority)
[docs]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if (node.op_type not in "MatMul" or node.domain != "") and (
node.op_type != "FusedMatMul" or node.domain != "com.microsoft"
):
return self.none()
next_nodes = g.next_nodes(node.output[0])
if (
len(next_nodes) != 1
or next_nodes[0].op_type != "Transpose"
or next_nodes[0].domain != ""
):
return self.none(node, inspect.currentframe().f_lineno)
transpose_node = next_nodes[0]
perm = list(g.get_attribute(transpose_node, "perm").ints)
if len(perm) > 2:
if perm[:-2] != list(range(len(perm) - 2)):
return self.none(node, inspect.currentframe().f_lineno)
if perm[-2:] != [len(perm) - 1, len(perm) - 2]:
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, transpose_node], self.apply, insert_at=node)
[docs]
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
transpose_node: NodeProto,
) -> List[NodeProto]:
default_values = dict(transA=0, transB=0, transBatchA=0, transBatchB=0, alpha=1.0)
kwargs = g.get_attributes_with_default(node, **default_values)
kwargs["transA"], kwargs["transB"] = 1 - kwargs["transB"], 1 - kwargs["transA"]
remove = []
for k in kwargs:
if kwargs[k] == default_values[k]:
remove.append(k)
for r in remove:
del kwargs[r]
new_node = g.make_node(
"FusedMatMul",
[node.input[1], node.input[0]],
transpose_node.output,
domain="com.microsoft",
name=f"{self.__class__.__name__}--{node.name}",
**kwargs,
)
return [new_node]