import inspect
from typing import List, Optional
from onnx import NodeProto
from ..patterns_api import MatchResult, PatternOptimization, EasyPatternOptimization
from ..patterns.onnx_functions import GeluPattern
class BiasGeluPattern(PatternOptimization):
    """
    Replaces by ``y = BiasGelu(x, B)``::

        t = x + B
        y = (t * (Erf(t / 1.4140625) + 1)) * 0.5

    The matched chain is ``Add -> Div -> Erf -> Add -> Mul -> Mul`` where
    ``B`` is a constant and ``1.4140625`` (an approximation of sqrt(2)),
    ``1`` and ``0.5`` are constant scalars.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches the subgraph, starting from its ``Erf`` node.

        :param g: graph being optimized
        :param node: candidate node, must be an ``Erf`` from the default domain
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` holding the nodes
            ``[add, div, erf, add_1, mul, half]``, or the value of
            ``self.none(...)`` when the subgraph does not match
        """
        if node.op_type != "Erf" or node.domain != "":
            return self.none()
        # The Div output is removed by the rewrite, so it must only feed this Erf.
        if g.is_used_more_than_once(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        # t / 1.4140625 -> Erf. The producer must really be a Div: checking
        # only its second input would also accept e.g. Mul(t, 1.4140625).
        # The None guard covers an input without producer (assumed to be what
        # node_before returns for graph inputs / initializers).
        div = g.node_before(node.input[0])
        if div is None or div.op_type != "Div" or div.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            not g.is_constant_scalar(div.input[1])
            or g.get_constant_scalar(div.input[1]) != 1.4140625
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # t = x + B with a constant bias B.
        # NOTE(review): the shape of B is not checked; BiasGelu may require a
        # 1-D bias matching the last dimension of x — confirm.
        add = g.node_before(div.input[0])
        if add is None or add.op_type != "Add" or add.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(add.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        # t must be consumed exactly twice: by the Div and by the final Mul.
        add1_nexts = g.next_nodes(add.output[0])
        if len(add1_nexts) != 2:
            return self.none(node, inspect.currentframe().f_lineno)

        # Erf(...) + 1
        add_next = g.next_nodes(node.output[0])
        if len(add_next) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        add_1 = add_next[0]
        if add_1.op_type != "Add" or add_1.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            not g.is_constant_scalar(add_1.input[1])
            or g.get_constant_scalar(add_1.input[1]) != 1
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # t * (Erf(...) + 1)
        muls = g.next_nodes(add_1.output[0])
        if len(muls) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        mul = muls[0]
        if mul.op_type != "Mul" or mul.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if set(mul.input) != {add.output[0], add_1.output[0]}:
            return self.none(node, inspect.currentframe().f_lineno)

        # (...) * 0.5, the constant may be either operand of the Mul.
        halfs = g.next_nodes(mul.output[0])
        if len(halfs) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        half = halfs[0]
        if half.op_type != "Mul" or half.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        index = 1 if half.input[0] == mul.output[0] else 0
        if (
            not g.is_constant_scalar(half.input[index])
            or g.get_constant_scalar(half.input[index]) != 0.5
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self, [add, div, node, add_1, mul, half], self.apply, insert_at=node
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        add_node: NodeProto,
        div_node: NodeProto,
        erf_node: NodeProto,
        add_1_node: NodeProto,
        mul_node: NodeProto,
        half_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the single ``com.microsoft.BiasGelu`` node replacing the chain.

        The inputs of the first ``Add`` (``x``, ``B``) become the inputs of
        BiasGelu and the outputs of the final ``Mul`` become its outputs.
        """
        return [
            g.make_node(
                "BiasGelu",
                add_node.input,
                half_node.output,
                domain="com.microsoft",
                doc_string=erf_node.doc_string,
                name=f"{self.__class__.__name__}--{erf_node.name}",
            )
        ]
class GeluOrtPattern(GeluPattern):
    """
    Detects the decomposed version of Gelu with Tanh

    .. math::

        y = \\frac{x}{2} \\left(1 + \\tanh\\left(\\sqrt{\\frac{2}{\\pi}}
        (x + 0.044715 * x^3)\\right)\\right)

    This subclass only changes ``self.domain`` (``com.microsoft`` by
    default). NOTE(review): presumably :class:`GeluPattern` uses it as the
    domain of the ``Gelu`` node it produces — confirm.
    """

    def __init__(
        self,
        verbose: int = 0,
        priority: int = 0,
        min_opset: int = 1,
        domain: str = "com.microsoft",
    ):
        # The parent is configured first so that the domain set below is not
        # overwritten by whatever the parent constructor assigns.
        super().__init__(verbose, priority, min_opset=min_opset)
        self.domain = domain
class GeluErfPattern(EasyPatternOptimization):
    """
    Detects the decomposed version of Gelu with Erf::

        y = 0.5 * (x * (Erf(x / 1.4140625) + 1))

    and replaces it by ``Gelu(x)`` from domain ``com.microsoft``.
    """

    def __init__(self, verbose: int = 0, priority: int = 0, min_opset: int = 1):
        super().__init__(verbose, priority, min_opset=min_opset)

    def match_pattern(self, g: "GraphBuilder", x, cst2, one, c05):  # noqa: F821
        """
        Describes the subgraph to find. The nodes are created in the order
        Div, Erf, Add, Mul, Mul, which is the order ``validate_mapping``
        relies on when indexing ``deleted_nodes``.
        """
        scaled = g.op.Div(x, cst2)  # cst2 is expected to be 1.4140625
        erf_of_scaled = g.op.Erf(scaled)
        shifted = g.op.Add(erf_of_scaled, one)  # one is expected to be 1
        product = g.op.Mul(x, shifted)
        return g.op.Mul(c05, product)  # c05 is expected to be 0.5

    def apply_pattern(self, g: "GraphBuilder", x, cst2, one, c05):  # noqa: F821
        """Returns the replacement: a single ``com.microsoft.Gelu`` on ``x``."""
        return g.anyop.Gelu(x, domain="com.microsoft")

    def validate_mapping(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        deleted_nodes: List[NodeProto],
        pattern_nodes: Optional[List[NodeProto]] = None,
    ) -> bool:
        """
        Checks that the constants of the matched nodes have the expected
        scalar values (1.4140625, 1, 0.5).

        :return: True if they all match, otherwise the value of
            ``self.none(...)``
        """
        assert len(deleted_nodes) == 5, f"Unexpected pattern length {len(deleted_nodes)}"
        assert deleted_nodes[0].op_type == "Div", f"-- {deleted_nodes[0]}"
        divisor = deleted_nodes[0].input[1]
        assert deleted_nodes[2].op_type == "Add", f"-- {deleted_nodes[2]}"
        offset = deleted_nodes[2].input[1]
        assert deleted_nodes[4].op_type == "Mul", f"-- {deleted_nodes[4]}"
        halving = deleted_nodes[4].input[0]
        erf_node = deleted_nodes[1]

        # Each (input name, expected scalar value), checked in order.
        expectations = ((divisor, 1.4140625), (offset, 1), (halving, 0.5))
        for cst_name, expected in expectations:
            if not g.is_constant_scalar(cst_name) or g.get_constant_scalar(cst_name) != expected:
                return self.none(erf_node, inspect.currentframe().f_lineno)
        return True
class FastGeluPattern(PatternOptimization):
    """
    Replaces ``Gelu`` (default domain or ``com.microsoft``) by
    ``FastGelu`` from domain ``com.microsoft``.

    NOTE(review): FastGelu presumably computes the tanh approximation of
    Gelu, so this rewrite may change numerics for erf-based Gelu — confirm
    this pattern is only enabled when that is acceptable.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Accepts any ``Gelu`` node from the default or ``com.microsoft`` domain.

        :return: a :class:`MatchResult` with that single node, or the value
            of ``self.none()`` otherwise
        """
        is_gelu = node.op_type == "Gelu" and node.domain in ("", "com.microsoft")
        if not is_gelu:
            return self.none()
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        gelu_node: NodeProto,
    ) -> List[NodeProto]:
        """Builds the ``FastGelu`` node reusing the inputs/outputs of ``gelu_node``."""
        fast_gelu = g.make_node(
            "FastGelu",
            gelu_node.input,
            gelu_node.output,
            domain="com.microsoft",
            doc_string=gelu_node.doc_string,
            name=f"{self.__class__.__name__}--{gelu_node.name}",
        )
        return [fast_gelu]