import inspect
from typing import List, Optional
import numpy as np
from onnx import NodeProto
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.numpy_helper import from_array
from ..patterns_api import MatchResult, PatternOptimization
class LayerNormalizationPattern(PatternOptimization):
    """
    Fuses the nodes of a decomposed layer normalization into one
    ``LayerNormalization`` node.

    The matched chain is::

        ReduceMean(X) -> Sub(X, mean) -> Pow(., 2) -> ReduceMean
            -> [Add(., eps)] -> Sqrt -> Div(Sub output, .)

    The pattern is anchored on the second ``ReduceMean`` (the variance).
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is the variance ``ReduceMean`` of a layer normalization.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches already found (not used here)
        :return: a :class:`MatchResult` listing
            ``[red, sub, pow, node, add, sqrt, div]`` (``add`` may be None)
            or the value returned by ``self.none`` when nothing matches
        """
        if node.op_type != "ReduceMean" or node.domain != "":
            return self.none()
        if len(node.input) != 2:
            # Not defined for older opset than 18.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        axis = g.get_computed_constant(node.input[1])
        axes = axis.tolist()
        if axes != [-1]:
            # Otherwise the reduced axes must be exactly the trailing ones.
            if not g.has_rank(node.input[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            rk = g.get_rank(node.input[0])
            if axes != list(range(rk - len(axes), rk)):
                return self.none(node, inspect.currentframe().f_lineno)

        # Nodes before the variance: Pow(Sub(X, ReduceMean(X)), 2).
        pow_node = g.node_before(node.input[0])
        if pow_node is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if pow_node.op_type != "Pow" or len(g.next_nodes(pow_node.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            not g.is_constant_scalar(pow_node.input[1], broadcast=True)
            or g.get_constant_scalar(pow_node.input[1], broadcast=True) != 2
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        sub = g.node_before(pow_node.input[0])
        if sub is None:
            return self.none(node, inspect.currentframe().f_lineno)
        # The centered value feeds exactly two nodes (Pow is one of them).
        if sub.op_type != "Sub" or len(g.next_nodes(sub.output[0])) != 2:
            return self.none(node, inspect.currentframe().f_lineno)

        red = g.node_before(sub.input[1])
        if red is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if red.op_type != "ReduceMean" or len(g.next_nodes(red.output[0])) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(red.input) != 2:
            # The axes input is read below, it must be there.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(red.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        axis2 = g.get_computed_constant(red.input[1])
        if axes != axis2.tolist():
            return self.none(node, inspect.currentframe().f_lineno)
        if sub.input[0] != red.input[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        kp = g.get_attribute(red, "keepdims", exc=False)
        if kp is None or kp.i != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        # Nodes after the variance: [Add(eps)] -> Sqrt -> Div.
        consumers = g.next_nodes(node.output[0])
        if len(consumers) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if consumers[0].op_type == "Add":
            add = consumers[0]
            if not g.is_constant_scalar(add.input[1], broadcast=True):
                return self.none(node, inspect.currentframe().f_lineno)
            sqrt_nodes = g.next_nodes(add.output[0])
        else:
            # No epsilon: the only consumer must be the Sqrt itself.
            add = None
            sqrt_nodes = consumers
        if len(sqrt_nodes) != 1 or sqrt_nodes[0].op_type != "Sqrt":
            return self.none(node, inspect.currentframe().f_lineno)
        sqrt = sqrt_nodes[0]

        div_nodes = g.next_nodes(sqrt.output[0])
        if len(div_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        div = div_nodes[0]
        if div.op_type == "Div":
            if len(g.next_nodes(div.input[1])) != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            if div.input[0] != sub.output[0]:
                return self.none(node, inspect.currentframe().f_lineno)
        elif div.op_type == "Reciprocal":
            # NOTE(review): this node consumes the Sqrt output, so its first
            # input is expected to be the Sqrt output and not the Sub output;
            # this branch probably never matches -- confirm before relying on it.
            if div.input[0] != sub.output[0]:
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self, [red, sub, pow_node, node, add, sqrt, div], self.apply, insert_at=node
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        red: NodeProto,
        sub: NodeProto,
        pow: NodeProto,
        node: NodeProto,
        add: Optional[NodeProto],
        sqrt: NodeProto,
        div: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the nodes replacing the matched ones.

        :param g: graph builder
        :param red: ``ReduceMean`` computing the mean
        :param sub: ``Sub`` centering the input
        :param pow: ``Pow`` squaring the centered input
        :param node: ``ReduceMean`` computing the variance
        :param add: optional ``Add`` holding epsilon
        :param sqrt: ``Sqrt`` node
        :param div: last node of the pattern, its output is reused
        :return: the new nodes, ``LayerNormalization`` is the last one
        """
        prefix = self.__class__.__name__
        data = red.input[0]
        dtype = tensor_dtype_to_np_dtype(g.get_type(data))
        axis = g.get_computed_constant(red.input[1]).tolist()
        scale = None
        bias = None
        new_nodes = []

        if axis == [-1]:
            ly_axis = -1
            if g.has_shape(data):
                known_shape = g.get_shape(data)
                if isinstance(known_shape[-1], int):
                    # Static last dimension: scale and bias become initializers.
                    scale = g.make_initializer(
                        "", np.ones((known_shape[-1],), dtype=dtype)
                    )
                    bias = g.make_initializer(
                        "", np.zeros((known_shape[-1],), dtype=dtype)
                    )
        else:
            ly_axis = min(axis)

        if scale is None:
            # Scale (ones) and bias (zeros) are built at runtime from the
            # trailing dimensions of the input, starting at the first
            # normalized axis.
            shape_name = g.unique_name(f"{prefix}_Sh_{data}")
            new_nodes.append(
                g.make_node(
                    "Shape",
                    [data],
                    [shape_name],
                    start=ly_axis,
                    name=f"{prefix}--{red.name}",
                )
            )
            scale = g.unique_name(f"{prefix}_Sc_{data}")
            new_nodes.append(
                g.make_node(
                    "ConstantOfShape",
                    [shape_name],
                    [scale],
                    name=f"{prefix}--{red.name}",
                    value=from_array(np.array([1], dtype=dtype)),
                )
            )
            bias = g.unique_name(f"{prefix}_Bi_{data}")
            new_nodes.append(
                g.make_node(
                    "ConstantOfShape",
                    [shape_name],
                    [bias],
                    name=f"{prefix}--{red.name}",
                    value=from_array(np.array([0], dtype=dtype)),
                )
            )

        # Without an Add node, a tiny fixed epsilon is used.
        if add is not None:
            eps = g.get_constant_scalar(add.input[1], broadcast=True)
        else:
            eps = 9.999999960041972e-13

        new_nodes.append(
            g.make_node(
                "LayerNormalization",
                [data, scale, bias],
                [div.output[0]],
                epsilon=float(eps),
                name=f"{prefix}--{node.name}",
                doc_string=node.doc_string,
                stash_type=1,  # itype,
                axis=ly_axis,
            )
        )
        return new_nodes
class LayerNormalizationScalePattern(PatternOptimization):
    """
    Fuses a ``LayerNormalization`` with the ``Mul`` (scale) and the optional
    ``Add`` (bias) following it into a single ``LayerNormalization``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a ``LayerNormalization`` followed by ``Mul``
        and possibly ``Add``.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches already found (not used here)
        :return: a :class:`MatchResult` listing ``[node, mul_node, add_node]``
            (``add_node`` may be None) or the value returned by ``self.none``
        """
        if node.op_type != "LayerNormalization" or node.domain != "":
            return self.none()
        if len(node.output) != 1:
            # Only the single-output form is handled.
            return self.none(node, inspect.currentframe().f_lineno)

        nodes = g.next_nodes(node.output[0])
        if len(nodes) != 1 or nodes[0].op_type != "Mul":
            return self.none(node, inspect.currentframe().f_lineno)
        mul_node = nodes[0]

        nodes = g.next_nodes(mul_node.output[0])
        if len(nodes) == 0:
            # The Mul output has no consumer node.
            return MatchResult(
                self, [node, mul_node, None], self.apply, insert_at=mul_node
            )

        # The other operand of Mul must have the same shape as the scale.
        index = 1 if mul_node.input[0] == node.output[0] else 0
        if not g.has_shape(mul_node.input[index]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        scale_shape = g.get_shape(node.input[1])
        if g.get_shape(mul_node.input[index]) != scale_shape:
            return self.none(node, inspect.currentframe().f_lineno)

        if len(nodes) != 1 or nodes[0].op_type != "Add":
            # Only the Mul is fused.
            return MatchResult(
                self, [node, mul_node, None], self.apply, insert_at=nodes[0]
            )

        add_node = nodes[0]
        index = 1 if add_node.input[0] == mul_node.output[0] else 0
        if not g.has_shape(add_node.input[index]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_shape(add_node.input[index]) != scale_shape:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(
            self, [node, mul_node, add_node], self.apply, insert_at=add_node
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        ln_node: NodeProto,
        mul_node: NodeProto,
        add_node: Optional[NodeProto],
    ) -> List[NodeProto]:
        """
        Builds the fused ``LayerNormalization``.

        With ``n`` the normalized input, ``s``, ``b`` the existing scale and
        bias, ``m`` the operand of ``Mul`` and ``a`` the operand of ``Add``::

            (n * s + b) * m + a = n * (s * m) + (b * m + a)

        so the new scale is ``s * m`` and the new bias is ``b * m + a``
        (terms coming from a missing bias or a missing ``Add`` are dropped).

        :param g: graph builder
        :param ln_node: the ``LayerNormalization`` node
        :param mul_node: the ``Mul`` node following it
        :param add_node: the optional ``Add`` node following ``Mul``
        :return: the new nodes, ``LayerNormalization`` is the last one
        """
        prefix = self.__class__.__name__
        nname = f"{prefix}--{ln_node.name}"
        new_nodes = []

        # Operand of Mul which is not the LayerNormalization output.
        scale = (
            mul_node.input[1]
            if mul_node.input[0] == ln_node.output[0]
            else mul_node.input[0]
        )

        # scale: reuse the Mul operand directly if the existing scale is one.
        new_scale = None
        if g.is_constant_scalar(ln_node.input[1], broadcast=True):
            if g.get_constant_scalar(ln_node.input[1], broadcast=True) == 1:
                new_scale = scale
        if new_scale is None:
            new_scale = g.unique_name(f"{prefix}_{ln_node.input[1]}")
            new_nodes.append(
                g.make_node("Mul", [ln_node.input[1], scale], [new_scale], name=nname)
            )

        # bias: an empty name means the optional input was omitted.
        existing_bias = None
        if len(ln_node.input) > 2 and ln_node.input[2]:
            existing_bias = ln_node.input[2]

        if add_node is not None:
            add_cst = (
                add_node.input[1]
                if add_node.input[0] == mul_node.output[0]
                else add_node.input[0]
            )
            if existing_bias is None:
                new_bias = add_cst
            else:
                # new_bias = existing_bias * mul_cst + add_cst
                temp = g.unique_name(f"{prefix}_{ln_node.input[1]}")
                new_bias = g.unique_name(f"{prefix}_{ln_node.input[1]}")
                new_nodes.extend(
                    [
                        g.make_node("Mul", [scale, existing_bias], [temp], name=nname),
                        g.make_node("Add", [temp, add_cst], [new_bias], name=nname),
                    ]
                )
        elif existing_bias is None:
            new_bias = None
        elif (
            g.is_constant_scalar(existing_bias, broadcast=True)
            and g.get_constant_scalar(existing_bias, broadcast=True) == 0
        ):
            # A null bias is not changed by the multiplication.
            new_bias = existing_bias
        else:
            # The existing bias went through Mul as well, it must be scaled:
            # new_bias = existing_bias * mul_cst
            new_bias = g.unique_name(f"{prefix}_{existing_bias}")
            new_nodes.append(
                g.make_node("Mul", [scale, existing_bias], [new_bias], name=nname)
            )

        # Attributes explicitly set on the original node are propagated.
        kwargs = {}
        for att_name, field in (("axis", "i"), ("epsilon", "f"), ("stash_type", "i")):
            att = g.get_attribute(ln_node, att_name, exc=None)
            if att:
                kwargs[att_name] = getattr(att, field)

        inputs = [ln_node.input[0], new_scale]
        if new_bias:
            inputs.append(new_bias)
        last_node = add_node if add_node is not None else mul_node
        new_node = g.make_node(
            "LayerNormalization",
            inputs,
            [last_node.output[0]],
            name=nname,
            doc_string=ln_node.doc_string,
            **kwargs,
        )
        return [*new_nodes, new_node]
class CastLayerNormalizationCastPattern(PatternOptimization):
    """
    Checks that a Cast is really needed around LayerNormalization.

    ``Cast -> (Simplified)LayerNormalization -> Cast`` is replaced by the
    normalization applied on the original tensor when the first Cast
    converts to the type given by ``stash_type`` and the second one
    converts back to the original type.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a normalization surrounded by two removable Cast.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches already found (not used here)
        :return: a :class:`MatchResult` listing
            ``[cast_before, node, cast_after]`` or the value returned by
            ``self.none``
        """
        if node.op_type not in (
            "LayerNormalization",
            "SimplifiedLayerNormalization",
        ) or node.domain not in ("", "com.microsoft"):
            return self.none()
        if len(node.output) > 1 and g.is_used(node.output[1]):
            # The second output is used, the node is left unchanged.
            return self.none(node, inspect.currentframe().f_lineno)

        # When the attribute is not set, 1 is used as the stash type.
        stash_type = g.get_attribute(node, "stash_type", exc=False)
        stash_itype = 1 if stash_type is None else stash_type.i

        cast_before = g.node_before(node.input[0])
        if (
            cast_before is None
            or cast_before.op_type != "Cast"
            or cast_before.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_attribute(cast_before, "to").i != stash_itype:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node.input[0]):
            # The casted tensor is needed somewhere else.
            return self.none(node, inspect.currentframe().f_lineno)

        cast_afters = g.next_nodes(node.output[0])
        if len(cast_afters) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        cast_after = cast_afters[0]
        if cast_after.op_type != "Cast" or cast_after.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        # The second Cast must restore the type the first one started from.
        if g.get_attribute(cast_after, "to").i != g.get_type(cast_before.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self, [cast_before, node, cast_after], self.apply, insert_at=node
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        cast_before: NodeProto,
        node: NodeProto,
        cast_after: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the normalization working on the original type.

        Every other input of the normalization is cast to the type of the
        tensor feeding the first Cast. An input with an empty name (omitted
        optional input) is kept as is: there is nothing to cast.

        :param g: graph builder
        :param cast_before: Cast node before the normalization
        :param node: normalization node
        :param cast_after: Cast node after the normalization
        :return: the Cast nodes for the other inputs followed by the new
            normalization node
        """
        prefix = self.__class__.__name__
        itype = g.get_type(cast_before.input[0])
        other = []
        nodes = []
        for i in node.input[1:]:
            if not i:
                # Omitted optional input, its position is preserved.
                other.append(i)
                continue
            name = g.unique_name(f"{prefix}_{i}")
            other.append(name)
            nodes.append(
                g.make_node(
                    "Cast",
                    [i],
                    [name],
                    to=itype,
                    name=f"{prefix}--cast--{node.name}",
                )
            )

        new_node = g.make_node(
            node.op_type,
            [cast_before.input[0], *other],
            [cast_after.output[0], *node.output[1:]],
            name=f"{prefix}--{node.name}",
            doc_string=node.doc_string,
            domain=node.domain,
        )
        # Attributes (axis, epsilon, stash_type, ...) are copied unchanged.
        new_node.attribute.extend(node.attribute)
        return [*nodes, new_node]
class BatchNormalizationPattern(PatternOptimization):
    """
    Checks that a BatchNormalization is really needed.

    The node is replaced by ``Identity`` when epsilon is null, scale and
    variance are constants equal to one, bias and mean are constants equal
    to zero.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a ``BatchNormalization`` doing nothing.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches already found (not used here)
        :return: a :class:`MatchResult` listing ``[node]`` or the value
            returned by ``self.none``
        """
        if node.op_type != "BatchNormalization" or node.domain != "":
            return self.none()
        # The additional outputs must not be consumed.
        if len(node.output) > 1 and g.next_nodes(node.output[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if len(node.output) > 2 and g.next_nodes(node.output[2]):
            return self.none(node, inspect.currentframe().f_lineno)

        momentum = 0.9
        epsilon = 1e-5
        training_mode = 0
        for att in node.attribute:
            if att.name == "momentum":
                momentum = att.f
            elif att.name == "epsilon":
                epsilon = att.f
            elif att.name == "training_mode":
                training_mode = att.i
        if training_mode and momentum != 0:
            return self.none(node, inspect.currentframe().f_lineno)
        if epsilon != 0:
            return self.none(node, inspect.currentframe().f_lineno)

        if len(node.input) < 5:
            # Inputs 1 to 4 are read below.
            return self.none(node, inspect.currentframe().f_lineno)
        # Inputs 1 to 4 must all be constants; the values of inputs 2, 3
        # and inputs 1, 4 are checked below against 0 and 1 respectively.
        for name in node.input[1:5]:
            if not g.is_constant(name):
                return self.none(node, inspect.currentframe().f_lineno)

        # biases: inputs 2 and 3 must be filled with zeros
        for z in node.input[2:4]:
            cst = g.get_computed_constant(z)
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if not (cst.min() == cst.max() == 0):
                return self.none(node, inspect.currentframe().f_lineno)

        # scales: inputs 1 and 4 must be filled with ones
        for z in [node.input[1], node.input[4]]:
            cst = g.get_computed_constant(z)
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if not (cst.min() == cst.max() == 1):
                return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces the node by ``Identity`` between its first input and
        its first output.

        :param g: graph builder
        :param node: the ``BatchNormalization`` node
        :return: a list with the ``Identity`` node
        """
        identity = g.make_node(
            "Identity",
            node.input[:1],
            node.output[:1],
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [identity]
class BatchNormalizationTrainingPattern(PatternOptimization):
    """
    Checks that a BatchNormalization in training mode can be avoided.

    The node is replaced by the explicit computation
    ``(X - mean(X)) / sqrt(var(X) + epsilon) * scale + bias`` where mean
    and variance are reduced over every axis but axis 1.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a ``BatchNormalization`` which can be decomposed.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches already found (not used here)
        :return: a :class:`MatchResult` listing ``[node]`` or the value
            returned by ``self.none``
        """
        if node.op_type != "BatchNormalization" or node.domain != "":
            return self.none()
        if g.main_opset < 18:
            return self.none(node, inspect.currentframe().f_lineno)
        # The rank must be known and at least 2: axis 1 is kept.
        if not g.has_rank(node.input[0]) or g.get_rank(node.input[0]) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        # Method apply reads the rank of the scale and the bias.
        if not g.has_rank(node.input[1]) or not g.has_rank(node.input[2]):
            return self.none(node, inspect.currentframe().f_lineno)
        # The additional outputs must not be consumed.
        if len(node.output) > 1 and g.next_nodes(node.output[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if len(node.output) > 2 and g.next_nodes(node.output[2]):
            return self.none(node, inspect.currentframe().f_lineno)

        momentum = 0.9
        training_mode = 0
        for att in node.attribute:
            if att.name == "momentum":
                momentum = att.f
            elif att.name == "training_mode":
                training_mode = att.i
        if not training_mode and momentum != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the nodes computing the normalization explicitly.

        :param g: graph builder
        :param node: the ``BatchNormalization`` node
        :return: the new nodes, the last one produces ``node.output[0]``
        """
        prefix = self.__class__.__name__
        nname = f"{prefix}--{node.name}"
        data, in_scale, in_bias = node.input[0], node.input[1], node.input[2]
        rk = g.get_rank(data)

        # Every axis but axis 1 is reduced.
        init_axes = g.make_initializer(
            "", np.array([a for a in range(rk) if a != 1], dtype=np.int64)
        )

        mean_name = g.unique_name(f"{prefix}_mean_{data}")
        mean = g.make_node(
            "ReduceMean", [data, init_axes], [mean_name], keepdims=1, name=nname
        )
        centered_name = g.unique_name(f"{prefix}_center_{data}")
        sub = g.make_node("Sub", [data, mean_name], [centered_name], name=nname)
        x2 = g.unique_name(f"{prefix}_x2_{data}")
        mul2 = g.make_node("Mul", [centered_name, centered_name], [x2], name=nname)
        var_name = g.unique_name(f"{prefix}_var_{data}")
        var = g.make_node(
            "ReduceMean", [x2, init_axes], [var_name], keepdims=1, name=nname
        )

        # sqrt(var + epsilon)
        dtype = tensor_dtype_to_np_dtype(g.get_type(data))
        epsilon = g.get_attributes_with_default(node, epsilon=1e-5)["epsilon"]
        init_epsilon = g.make_initializer("", np.array([epsilon], dtype=dtype))
        vare_name = g.unique_name(f"{prefix}_vareps_{data}")
        add = g.make_node("Add", [var_name, init_epsilon], [vare_name], name=nname)
        std_name = g.unique_name(f"{prefix}_vareps_{data}")
        sqrt = g.make_node("Sqrt", [vare_name], [std_name], name=nname)

        # 1D scale and bias are reshaped into (1, -1, 1, ...) to be
        # broadcast along axis 1.
        target = [1] * rk
        target[1] = -1
        new_shape = g.make_initializer("", np.array(target, dtype=np.int64))
        if g.get_rank(in_scale) == 1:
            scale_name = g.unique_name(f"{prefix}_scale_{in_scale}")
            scale = g.make_node(
                "Reshape", [in_scale, new_shape], [scale_name], name=nname
            )
        else:
            scale_name = in_scale
            scale = None
        if g.get_rank(in_bias) == 1:
            bias_name = g.unique_name(f"{prefix}_bias_{in_bias}")
            bias = g.make_node("Reshape", [in_bias, new_shape], [bias_name], name=nname)
        else:
            bias_name = in_bias
            bias = None

        scaled_name = g.unique_name(f"{prefix}_scaled_{in_scale}")
        scaled = g.make_node(
            "Div", [centered_name, std_name], [scaled_name], name=nname
        )
        scaled2_name = g.unique_name(f"{prefix}_scaled2_{in_bias}")
        scaled2 = g.make_node(
            "Mul", [scaled_name, scale_name], [scaled2_name], name=nname
        )
        final = g.make_node(
            "Add", [scaled2_name, bias_name], [node.output[0]], name=nname
        )

        # The optional Reshape nodes are dropped when they were not created.
        candidates = [
            mean, sub, mul2, var, add, sqrt, scale, bias, scaled, scaled2, final
        ]
        return [n for n in candidates if n is not None]