import inspect
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
from onnx import NodeProto
from ...xbuilder._onnx_helper import (
element_wise_binary_op_types,
element_wise_op_cmp_types,
unary_like_op_types,
)
from ...xbuilder._shape_helper import all_int, DYNAMIC_SHAPE
from ..patterns_api import MatchResult, PatternOptimization
class ExpandPattern(PatternOptimization):
"""Checks that a Expand is really needed."""
def __init__(self, verbose: int = 0, priority: int = 0):
super().__init__(verbose, priority)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape = g.get_shape(node.input[0])
if not all_int(shape):
return self.none(node, inspect.currentframe().f_lineno)
if not g.is_constant(node.input[1]):
# It may be a symbolic shape.
return self.none(node, inspect.currentframe().f_lineno)
value = g.get_computed_constant(node.input[1])
if value is None:
return self.none(node, inspect.currentframe().f_lineno)
with g.builder.maybe_disable_fake_tensor_mode():
new_shape = tuple(int(i) for i in value)
if shape != new_shape:
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
) -> List[NodeProto]:
new_node = g.make_node(
"Identity",
node.input,
node.output,
name=f"{self.__class__.__name__}--{node.name}",
doc_string=node.doc_string,
)
return [new_node]
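# Illustrative sketch (not part of the optimizer, numpy is only used for the
# demonstration): an Expand whose target shape equals the input shape changes
# nothing, which is why ExpandPattern replaces it with Identity.
def _sketch_expand_to_same_shape_is_identity():
    import numpy as np

    x = np.arange(6).reshape(2, 3)
    # Expanding a tensor to its own shape returns the same values.
    assert np.array_equal(np.broadcast_to(x, (2, 3)), x)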
class ExpandBroadcastPattern(PatternOptimization):
"""
    Checks that an Expand is really needed before an element-wise operator.
The objective is to save one allocation and let the next operator
do the expansion by broadcasting one input.
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape = g.get_shape(node.input[0])
if not all_int(shape):
return self.none(node, inspect.currentframe().f_lineno)
if not g.is_constant(node.input[1]):
# It may be a symbolic shape.
return self.none(node, inspect.currentframe().f_lineno)
value = g.get_computed_constant(node.input[1])
if value is None:
return self.none(node, inspect.currentframe().f_lineno)
with g.builder.maybe_disable_fake_tensor_mode():
new_shape = tuple(int(i) for i in value)
if g.is_used_more_than_once(node.output[0]):
            # The output is used more than once, not handled right now.
return self.none(node, inspect.currentframe().f_lineno)
next_nodes = g.next_nodes(node.output[0])
assert len(next_nodes) == 1, "The previous test should have cleared out this case."
next_node = next_nodes[0]
if next_node.op_type not in self._op_types or next_node.domain != "":
# Not an element wise operator.
return self.none(node, inspect.currentframe().f_lineno)
other = next_node.input[1 if next_node.input[0] == node.output[0] else 0]
if not g.has_shape(other):
return self.none(node, inspect.currentframe().f_lineno)
other_shape = g.get_shape(other)
if new_shape != other_shape:
# Expand does not expand to the shape of the other element.
return self.none(node, inspect.currentframe().f_lineno)
if len(shape) != len(other_shape):
# Different ranks.
return self.none(node, inspect.currentframe().f_lineno)
for a, b in zip(shape, other_shape):
if not (a == b or a == 1 or b == 1):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, next_node], self.apply, insert_at=next_node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
next_node: NodeProto,
) -> List[NodeProto]:
if next_node.input[0] == node.output[0]:
inputs = [node.input[0], next_node.input[1]]
else:
inputs = [next_node.input[0], node.input[0]]
return [
g.make_node(
next_node.op_type,
inputs,
next_node.output,
name=f"{self.__class__.__name__}--{node.name}",
doc_string=next_node.doc_string,
)
]
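# Illustrative sketch (not part of the optimizer): ExpandBroadcastPattern drops
# an Expand feeding an element-wise operator because the operator broadcasts
# its inputs itself. numpy follows the same broadcasting rules as ONNX here.
def _sketch_elementwise_broadcast_makes_expand_redundant():
    import numpy as np

    x = np.ones((2, 1))                            # input of the Expand
    y = np.arange(6, dtype=np.float64).reshape(2, 3)
    explicit = np.broadcast_to(x, (2, 3)) + y      # Expand then Add
    implicit = x + y                               # Add broadcasts directly
    assert np.array_equal(explicit, implicit)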
class ShapeBasedExpandBroadcastPattern(PatternOptimization):
"""
Similar to
:class:`experimental_experiment.xoptim.patterns.onnx_expand.ExpandBroadcastPattern`,
    but it allows dynamic shapes as well. It does not look at the second
    argument of Expand; it just infers that the Expand is not needed for
    the binary operator that follows.
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
@classmethod
def _is_compatible_shapes_for_expand(
cls,
shape_left: DYNAMIC_SHAPE,
shape_right: DYNAMIC_SHAPE,
output_shape: Optional[DYNAMIC_SHAPE],
) -> bool:
"""
        Checks that broadcasting the two input shapes yields ``output_shape``.
        In that case, no Expand node is needed.
"""
if output_shape is None:
return False
if max(len(shape_left), len(shape_right) if shape_right else 0) < len(output_shape):
return False
# Align shapes
if len(shape_left) < len(shape_right):
shape_left = (1,) * (len(shape_right) - len(shape_left)) + shape_left
elif len(shape_left) > len(shape_right):
shape_right = (1,) * (len(shape_left) - len(shape_right)) + shape_right
for left, right, out in zip(shape_left, shape_right, output_shape):
if isinstance(left, int):
if isinstance(right, int):
# static right
if left == 1:
if right != out:
return False
elif right == 1:
if left != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic right
if left == 1:
if right != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic left
if isinstance(right, int):
# static right
if right == 1:
if left != out:
return False
else:
if left != right or left != out or right != out:
return False
else:
# dynamic right
if left != right or left != out or right != out:
return False
return True
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type not in self._op_types or node.domain != "":
return self.none()
if (
not g.has_shape(node.output[0])
or not g.has_shape(node.input[0])
or not g.has_shape(node.input[1])
):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
# At least one expand.
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
if self._is_compatible_shapes_for_expand(
shape_left, shape_right, g.get_shape_renamed(node.output[0])
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match] {shape_left} "
f"{node.op_type} {shape_right} -> {g.get_shape_renamed(node.output[0])}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
# We could end up with the following case.
# shape_left = (1, 1, 'seq_length', 'cache_length + seq_length')
# shape_right = (1, 1, 'seq_length', 'cache_length + seq_length')
# output_shape = ('batch', 1, 'seq_length', 'cache_length + seq_length')
        # When this happens, it could also be caught by another pattern.
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
return [
*nodes,
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
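# Illustrative sketch (not part of the optimizer): the shape comparison accepts
# symbolic dimensions (strings); the Expand can be removed when broadcasting
# the un-expanded inputs already produces the output shape.
def _sketch_shape_based_expand_broadcast():
    ok = ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
        ("batch", 1, "seq"), (1, 8, "seq"), ("batch", 8, "seq")
    )
    assert ok
    # A symbolic dimension that must grow to reach the output cannot be verified.
    not_ok = ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
        (1, 1, "seq"), (1, 1, "seq"), ("batch", 1, "seq")
    )
    assert not not_ok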
class ExpandSwapPattern(PatternOptimization):
"""
    Tries to move an Expand node forward in the graph.
    Expand + Exp can be changed into Exp + Expand.
    Then Exp applies to a tensor of a smaller or equal size.
"""
_op_types = unary_like_op_types()
_other_types = {"NegXplus1", "ReplaceZero", "Pow"}
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
assert g.is_used(node.output[0]), (
f"The match should not even begin, {node.output[0]!r} "
f"is not used among {node.output} and type={node.op_type!r}"
)
if g.is_used_more_than_once(node.output[0]):
            # The output is used more than once, so the Expand probably has to be kept.
return self.none(node, inspect.currentframe().f_lineno)
next_nodes = g.next_nodes(node.output[0])
assert len(next_nodes) == 1, "The previous test should have cleared out this case."
next_node = next_nodes[0]
if next_node.op_type not in self._other_types and (
next_node.op_type not in self._op_types or next_node.domain != ""
):
            # Not a unary operator.
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node, next_node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
node: NodeProto,
next_node: NodeProto,
) -> List[NodeProto]:
# We need to create a new name for the intermediate results.
# The optimizer cannot reuse an existing name if the new result
# has a different shape.
new_name = g.unique_name(f"{self.__class__.__name__}_{node.input[0]}")
unary = g.make_node(
next_node.op_type,
[node.input[0], *next_node.input[1:]],
[new_name],
name=f"{self.__class__.__name__}--{node.name}",
domain=next_node.domain,
doc_string=next_node.doc_string,
)
unary.attribute.extend(next_node.attribute)
expand = g.make_node(
node.op_type, # Expand
[new_name, node.input[1]],
[next_node.output[0]],
name=f"{self.__class__.__name__}--{node.name}",
doc_string=node.doc_string,
)
return [unary, expand]
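# Illustrative sketch (not part of the optimizer): a unary operator commutes
# with Expand, so applying it before the Expand touches fewer elements. numpy
# illustrates the equivalence for Exp.
def _sketch_unary_commutes_with_expand():
    import numpy as np

    x = np.random.rand(2, 1)
    before = np.exp(np.broadcast_to(x, (2, 3)))  # Expand then Exp: 6 values
    after = np.broadcast_to(np.exp(x), (2, 3))   # Exp then Expand: 2 values
    assert np.allclose(before, after)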
class ShapeBasedStaticExpandPattern(PatternOptimization):
"""
    Compares input and output shapes to tell if the Expand
    can use a constant as its second input.
"""
def __init__(self, verbose: int = 0, priority: int = 0):
super().__init__(verbose, priority)
@classmethod
def _find_expand_shape(
cls, sh1: Tuple[Union[str, int], ...], sh2: Tuple[Union[str, int], ...]
    ) -> Optional[Tuple[int, ...]]:
expand_shape = []
for s1, s2 in zip(sh1, sh2):
if s1 == s2:
expand_shape.append(1)
continue
if not isinstance(s1, int) or not isinstance(s2, int):
return None
expand_shape.append(s2)
return tuple(expand_shape)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if g.is_constant(node.input[1]):
            # The shape argument is already a constant, nothing to do.
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
sh1 = g.get_shape_renamed(node.input[0])
sh2 = g.get_shape_renamed(node.output[0])
if len(sh1) != len(sh2):
# We ignore that case for the time being.
return self.none(node, inspect.currentframe().f_lineno)
expand_shape = self._find_expand_shape(sh1, sh2)
if expand_shape is None:
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
        expand_node: NodeProto,
) -> List[NodeProto]:
        expand_shape = self._find_expand_shape(
            g.get_shape_renamed(expand_node.input[0]),
            g.get_shape_renamed(expand_node.output[0]),
        )
new_shape = g.make_initializer(
"",
np.array(expand_shape, dtype=np.int64),
source=f"{self.__class__.__name__}.m1",
)
return [
g.make_node(
"Expand",
                [expand_node.input[0], new_shape],
                expand_node.output,
                name=f"{self.__class__.__name__}--{expand_node.name}",
                doc_string=expand_node.doc_string,
)
]
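# Illustrative sketch (not part of the optimizer): _find_expand_shape keeps 1
# where input and output dimensions agree and the static output dimension
# where they differ, so the Expand can take a constant second input.
def _sketch_find_expand_shape():
    sh = ShapeBasedStaticExpandPattern._find_expand_shape(("batch", 1, 4), ("batch", 8, 4))
    assert sh == (1, 8, 1)
    # A mismatch on symbolic dimensions cannot be turned into a constant.
    assert ShapeBasedStaticExpandPattern._find_expand_shape(("a", 4), ("b", 4)) is None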
class ShapeBasedExpandSwapPattern(PatternOptimization):
"""
    Tries to move an Expand node forward in the graph
    for a binary operator. The code is similar to
    :class:`experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern`.
"""
_op_types = element_wise_binary_op_types() | element_wise_op_cmp_types()
@classmethod
def _broadcast_shape(
cls,
before_expand_shape: DYNAMIC_SHAPE,
other_term_shape: DYNAMIC_SHAPE,
exc: bool = False,
) -> Optional[DYNAMIC_SHAPE]:
if len(before_expand_shape) != len(other_term_shape):
d = abs(len(before_expand_shape) - len(other_term_shape))
if len(before_expand_shape) < len(other_term_shape):
before_expand_shape = (1,) * d + before_expand_shape
else:
other_term_shape = (1,) * d + other_term_shape
if len(before_expand_shape) != len(other_term_shape):
assert not exc, (
f"Unable to produce a broadcasted shape from "
f"{before_expand_shape} and {other_term_shape}"
)
return None
res = []
for a, b in zip(before_expand_shape, other_term_shape):
if a == b:
res.append(a)
elif a == 1:
res.append(b)
elif b == 1:
res.append(a)
else:
assert not exc, (
f"Unable to produce a broadcasted shape from "
f"{before_expand_shape} and {other_term_shape}"
)
return None
return tuple(res)
@classmethod
def _get_compatible_expand_shape_for_expand_swap(
cls,
before_expand_shape: DYNAMIC_SHAPE,
expanded_shape: DYNAMIC_SHAPE,
other_term_shape: DYNAMIC_SHAPE,
other_expanded_shape: Optional[DYNAMIC_SHAPE],
output_shape: DYNAMIC_SHAPE,
) -> Optional[DYNAMIC_SHAPE]:
"""
        Something like this should work.
        The function returns a shape or None if it is not possible.
.. code-block:: python
_get_compatible_expand_shape_for_expand_swap(
("batch", 1, 1, 1),
("batch", 1, "seq_length", "cache_length+seq_length"),
(1,),
None,
("batch", 1, "seq_length", "cache_length+seq_length"),
)
>>> ("batch", 1, "seq_length", "cache_length+seq_length")
"""
if other_expanded_shape is not None and (
other_expanded_shape != expanded_shape
or expanded_shape != output_shape
or len(before_expand_shape) != len(other_term_shape)
):
return None
if before_expand_shape == expanded_shape or expanded_shape == other_term_shape:
# This pattern is not meant for that.
return None
if output_shape != expanded_shape:
return None
if (
other_expanded_shape is None
and not ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
before_expand_shape,
other_term_shape,
cls._broadcast_shape(before_expand_shape, other_term_shape, exc=False),
)
):
return None
if (
other_expanded_shape is not None
and not ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
before_expand_shape,
other_term_shape,
cls._broadcast_shape(before_expand_shape, other_term_shape, exc=False),
)
):
return None
if other_expanded_shape is None:
return "expand_arg"
max_dim = cls._broadcast_shape(before_expand_shape, other_term_shape)
if max_dim == output_shape:
# Expand is not necessary at all.
return None
return tuple(1 if a == b else 0 for a, b in zip(max_dim, output_shape))
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type not in self._op_types or node.domain != "":
return self.none()
if (
not g.has_shape(node.output[0])
or not g.has_shape(node.input[0])
or not g.has_shape(node.input[1])
):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
if None in before:
# Only one expand
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
before_expand_shape = shape_right if node_left is None else shape_left
expanded_shape = (
g.get_shape_renamed(node_right.output[0])
if node_left is None
else g.get_shape_renamed(node_left.output[0])
)
other_term_shape = shape_left if node_left is None else shape_right
output_shape = g.get_shape_renamed(node.output[0])
if self._get_compatible_expand_shape_for_expand_swap(
before_expand_shape, expanded_shape, other_term_shape, None, output_shape
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match.1] {shape_left} "
f"{node.op_type} {shape_right} -> {output_shape}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
        # Both inputs are produced by an Expand node.
node_left, node_right = before
if node_left.input[1] != node_right.input[1]:
            # It could also work when the two Expand nodes have different
            # shape arguments, but the check needed for that case is not implemented.
return self.none(node, inspect.currentframe().f_lineno)
shape_left = g.get_shape_renamed(node_left.input[0])
shape_right = g.get_shape_renamed(node_right.input[0])
output_shape = g.get_shape_renamed(node.output[0])
expand_arg = self._get_compatible_expand_shape_for_expand_swap(
shape_left,
g.get_shape_renamed(node.input[0]),
shape_right,
g.get_shape_renamed(node.input[1]),
output_shape,
)
if expand_arg:
if self.verbose:
print(
f"[{self.__class__.__name__}.match.2] {shape_left} "
f"{node.op_type} {shape_right} -> {output_shape} with "
f"expand_arg={expand_arg}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
new_name = g.unique_name(f"{self.__class__.__name__}_{binary_node.output[0]}")
nodes.append(
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
[new_name],
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
)
        # One or two Expand nodes: the rewriting is the same since they share the shape argument.
return [
*nodes,
g.make_node(
"Expand",
[
new_name,
expand_left.input[1] if expand_right is None else expand_right.input[1],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
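# Illustrative sketch (not part of the optimizer): _broadcast_shape aligns the
# ranks then applies the broadcasting rule dimension by dimension, keeping
# symbolic dimensions unchanged.
def _sketch_broadcast_shape():
    out = ShapeBasedExpandSwapPattern._broadcast_shape(("batch", 1, 1), (1, "seq", 8))
    assert out == ("batch", "seq", 8)
    # Incompatible static dimensions cannot be broadcast together.
    assert ShapeBasedExpandSwapPattern._broadcast_shape((2, 3), (2, 4)) is None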
class ShapeBasedExpandBroadcastMatMulPattern(PatternOptimization):
"""
Similar to
:class:`experimental_experiment.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern`,
but works only with MatMul.
"""
@classmethod
def _is_compatible_shapes_for_expand(
cls,
shape_left: DYNAMIC_SHAPE,
shape_right: DYNAMIC_SHAPE,
output_shape: Optional[DYNAMIC_SHAPE],
) -> bool:
"""
        Checks that broadcasting the batch dimensions (all axes but the last two)
        of the two input shapes yields those of ``output_shape``.
        In that case, no Expand node is needed.
"""
if output_shape is None:
return False
if len(shape_left) < 2 or len(shape_right) < 2 or len(output_shape) < 2:
return False
return ShapeBasedExpandBroadcastPattern._is_compatible_shapes_for_expand(
shape_left[:-2], shape_right[:-2], output_shape[:-2]
)
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "MatMul" or node.domain != "":
return self.none()
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
node_left = g.node_before(node.input[0])
node_right = g.node_before(node.input[1])
before = [
None if n is None or n.op_type != "Expand" else n for n in [node_left, node_right]
]
if before == [None, None]:
return self.none(node, inspect.currentframe().f_lineno)
# At least one expand.
node_left, node_right = before
shape_left = g.get_shape_renamed(
node.input[0] if node_left is None else node_left.input[0]
)
shape_right = g.get_shape_renamed(
node.input[1] if node_right is None else node_right.input[0]
)
if self._is_compatible_shapes_for_expand(
shape_left, shape_right, g.get_shape_renamed(node.output[0])
):
if self.verbose:
print(
f"[{self.__class__.__name__}.match] {shape_left} "
f"{node.op_type} {shape_right} -> {g.get_shape_renamed(node.output[0])}"
)
return MatchResult(self, [node_left, node_right, node], self.apply)
# We could end up with the following case.
# shape_left = (1, 1, 'seq_length', 'cache_length + seq_length')
# shape_right = (1, 1, 'seq_length', 'cache_length + seq_length')
# output_shape = ('batch', 1, 'seq_length', 'cache_length + seq_length')
        # When this happens, it could also be caught by another pattern.
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_left: NodeProto,
expand_right: NodeProto,
binary_node: NodeProto,
) -> List[NodeProto]:
nodes = []
if expand_left is not None and g.is_used_more_than_once(expand_left.output[0]):
nodes.append(expand_left)
if expand_right is not None and g.is_used_more_than_once(expand_right.output[0]):
nodes.append(expand_right)
assert (
not binary_node.attribute
), f"Binary operator should not have any attribute, binary_node={binary_node}"
return [
*nodes,
g.make_node(
binary_node.op_type,
[
binary_node.input[0] if expand_left is None else expand_left.input[0],
binary_node.input[1] if expand_right is None else expand_right.input[0],
],
binary_node.output,
name=f"{self.__class__.__name__}--{binary_node.name}",
doc_string=binary_node.doc_string,
),
]
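# Illustrative sketch (not part of the optimizer): MatMul only broadcasts the
# batch dimensions (every axis but the last two), which is why the helper
# strips the last two dimensions before checking. numpy follows the same rule.
def _sketch_matmul_batch_broadcast():
    import numpy as np

    a = np.random.rand(1, 2, 3)                    # input of the Expand
    b = np.random.rand(4, 3, 5)
    explicit = np.broadcast_to(a, (4, 2, 3)) @ b   # Expand then MatMul
    implicit = a @ b                               # MatMul broadcasts batch dims
    assert np.allclose(explicit, implicit)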
class ShapeBasedExpandCastWhereSwapPattern(PatternOptimization):
"""Rewrites Where(Cast(X), X, cond)."""
@classmethod
def _compatible_shapes(
cls,
cond: DYNAMIC_SHAPE,
cst: DYNAMIC_SHAPE,
output: DYNAMIC_SHAPE,
before: DYNAMIC_SHAPE,
    ) -> bool:
if cond != output:
return False
if len(before) < len(output):
before = (1,) * (len(output) - len(before)) + before
if len(cst) < len(output):
cst = (1,) * (len(output) - len(cst)) + cst
out = ShapeBasedExpandSwapPattern._broadcast_shape(before, cst)
        if out is None or len(out) != len(output) or len(out) != len(before):
return False
return all(not (o != e and o != b) for b, o, e in zip(before, out, output))
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Where" or node.domain != "":
return self.none()
if g.is_used_more_than_once(node.input[0]):
return self.none()
cast_node = g.node_before(node.input[0])
if cast_node is None or cast_node.op_type != "Cast" or cast_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
if cast_node.input[0] not in node.input[1:]:
return self.none(node, inspect.currentframe().f_lineno)
expand_node = g.node_before(cast_node.input[0])
if expand_node is None or expand_node.op_type != "Expand" or expand_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
nodes = g.next_nodes(cast_node.input[0])
if len(nodes) != 2:
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[2]):
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(expand_node.input[0]):
return self.none(node, inspect.currentframe().f_lineno)
same_index = list(node.input).index(cast_node.input[0])
if self._compatible_shapes(
g.get_shape_renamed(node.input[0]),
g.get_shape_renamed(node.input[3 - same_index]),
g.get_shape_renamed(node.output[0]),
g.get_shape_renamed(expand_node.input[0]),
):
return MatchResult(self, [expand_node, cast_node, node], self.apply)
return self.none(node, inspect.currentframe().f_lineno)
def apply(
self,
g: "GraphBuilder", # noqa: F821
expand_node: NodeProto,
cast_node: NodeProto,
where_node: NodeProto,
) -> List[NodeProto]:
to = g.get_attribute(cast_node, "to").i
pos_index = list(where_node.input).index(expand_node.output[0])
cast_output = g.unique_name(f"{self.__class__.__name__}_{cast_node.output[0]}")
where_output = g.unique_name(f"{self.__class__.__name__}_{where_node.output[0]}")
return [
g.make_node(
cast_node.op_type,
[expand_node.input[0]],
[cast_output],
to=to,
name=f"{self.__class__.__name__}--{cast_node.name}",
doc_string=cast_node.doc_string,
),
g.make_node(
where_node.op_type,
(
[cast_output, expand_node.input[0], where_node.input[2]]
if pos_index == 1
else [cast_output, where_node.input[1], expand_node.input[0]]
),
[where_output],
name=f"{self.__class__.__name__}--{where_node.name}",
doc_string=where_node.doc_string,
),
g.make_node(
expand_node.op_type,
[where_output, expand_node.input[1]],
[where_node.output[0]],
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
),
]
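# Illustrative sketch (not part of the optimizer): when the same expanded
# tensor feeds both the Cast producing the Where condition and one Where
# branch, the Expand can be applied after the Where instead. numpy shows both
# orders agree on a small example.
def _sketch_expand_cast_where_swap():
    import numpy as np

    x = np.array([[0.0], [2.0]])    # input of the Expand
    y = np.array([[5.0], [7.0]])    # the other branch of the Where
    target = (2, 3)
    before = np.where(
        np.broadcast_to(x, target).astype(bool), np.broadcast_to(x, target), y
    )
    after = np.broadcast_to(np.where(x.astype(bool), x, y), target)
    assert np.array_equal(before, after)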
class ShapeBasedConcatExpandPattern(PatternOptimization):
"""Rewrites Expand(X, concat(...)) if possible."""
@classmethod
def _compatible_shapes(
cls,
g: "GraphBuilderPatternOptimization", # noqa: F821
shape: DYNAMIC_SHAPE,
expanded_shape: DYNAMIC_SHAPE,
concat_input: Sequence[str],
) -> Optional[int]:
if len(shape) != len(expanded_shape) or len(expanded_shape) != len(concat_input):
return None
position = []
for i, (a, b) in enumerate(zip(shape, expanded_shape)):
if a == b:
continue
position.append(i)
if len(position) != 1:
# It might be Identity but this should be caught by another pattern.
return None
return position[0]
def match(
self,
g: "GraphBuilderPatternOptimization", # noqa: F821
node: NodeProto,
matched: List[MatchResult],
) -> Optional[MatchResult]:
if node.op_type != "Expand" or node.domain != "":
return self.none()
if g.is_used_more_than_once(node.input[1]):
return self.none(node, inspect.currentframe().f_lineno)
if g.is_constant(node.input[1]):
            # The shape is already constant, nothing to rewrite.
return self.none(node, inspect.currentframe().f_lineno)
concat_node = g.node_before(node.input[1])
if concat_node is None or concat_node.op_type != "Concat" or concat_node.domain != "":
return self.none(node, inspect.currentframe().f_lineno)
if not g.has_shape(node.input[0]) or not g.has_shape(node.output[0]):
return self.none(node, inspect.currentframe().f_lineno)
shape1 = g.get_shape_renamed(node.input[0])
shape2 = g.get_shape_renamed(node.output[0])
index = self._compatible_shapes(g, shape1, shape2, concat_node.input)
if index is None:
return self.none(node, inspect.currentframe().f_lineno)
        # If every other value is already the constant 1, the rewrite would change nothing.
if all(
(i == index or (g.is_constant(name) and g.get_constant_scalar(name) == 1))
for i, name in enumerate(concat_node.input)
):
return self.none(node, inspect.currentframe().f_lineno)
return MatchResult(self, [concat_node, node], self.apply, insert_at=node)
def apply(
self,
g: "GraphBuilder", # noqa: F821
concat_node: NodeProto,
expand_node: NodeProto,
) -> List[NodeProto]:
shape1 = g.get_shape_renamed(expand_node.input[0])
shape2 = g.get_shape_renamed(expand_node.output[0])
index = self._compatible_shapes(g, shape1, shape2, concat_node.input)
init1 = g.make_initializer(
g.unique_name("init7_1"), g.ONE, source="ShapeBasedConcatExpandPattern.1"
)
new_input = [
(iname if i == index else init1) for i, iname in enumerate(concat_node.input)
]
new_name = g.unique_name(concat_node.output[0])
return [
g.make_node(
"Concat",
new_input,
[new_name],
axis=0,
name=f"{self.__class__.__name__}--{concat_node.name}",
doc_string=concat_node.doc_string,
),
g.make_node(
"Expand",
[expand_node.input[0], new_name],
expand_node.output,
name=f"{self.__class__.__name__}--{expand_node.name}",
doc_string=expand_node.doc_string,
),
]
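# Illustrative sketch (not part of the optimizer): ONNX Expand broadcasts, so a
# 1 in the shape argument means "keep the input dimension". Replacing every
# Concat entry that matches the input shape with the constant 1 therefore
# produces the same result; only the dimension that really changes is kept.
def _sketch_expand_shape_with_ones():
    import numpy as np

    x = np.arange(6).reshape(2, 3, 1)
    full = np.broadcast_to(x, (2, 3, 4))
    # Emulate Expand(x, (1, 1, 4)) with numpy broadcasting rules.
    ones = np.broadcast_to(x, np.broadcast_shapes(x.shape, (1, 1, 4)))
    assert np.array_equal(full, ones)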