import inspect
from typing import List, Optional, Tuple, Union
import numpy as np
from onnx import NodeProto
from ...xbuilder._onnx_helper import element_wise_binary_op_types
from ...xbuilder._shape_helper import all_int, DYNAMIC_SHAPE, STATIC_SHAPE
from ..patterns_api import MatchResult, PatternOptimization
class ReshapePattern(PatternOptimization):
    """
    Checks that a Reshape is really needed.

    A ``Reshape`` is replaced by an ``Identity`` when the shape of its input
    is known, only contains integers and is equal to the constant shape
    the node reshapes to.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape = g.get_shape(node.input[0])
        if not all_int(shape):
            # Symbolic dimensions cannot be compared to the constant shape.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            # It may be a symbolic shape.
            return self.none(node, inspect.currentframe().f_lineno)
        value = g.get_computed_constant(node.input[1])
        if value is None:
            return self.none(node, inspect.currentframe().f_lineno)
        new_shape = tuple(int(i) for i in value)
        if tuple(shape) != new_shape:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """Replaces the Reshape by an Identity applied to its data input."""
        new_node = g.make_node(
            "Identity",
            # Identity has exactly one input, the shape input of the Reshape
            # must not be forwarded.
            [node.input[0]],
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]
class ReduceReshapePattern(PatternOptimization):
    """
    Replaces the sequence Reduce* + Reshape when the Reshape is only
    introduced to remove the dimensions kept because ``keepdims=1``.
    Both nodes are replaced by a single Reduce* node with ``keepdims=0``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if not node.op_type.startswith("Reduce") or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        att = g.get_attribute(node, "keepdims", exc=False)
        keepdims = 1 if att is None else att.i
        if keepdims == 0:
            # not keeping the dimension so Reshape means to restore them.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_rank(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        rank = g.get_rank(node.input[0])

        if len(node.input) == 2:
            # axes given as an input (recent opsets)
            if not g.is_constant(node.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            axes = tuple(int(a) for a in cst)
        else:
            # axes given as an attribute (older opsets) or missing
            att = g.get_attribute(node, "axes", exc=False)
            axes = tuple(range(rank)) if att is None else tuple(att.ints)

        if not axes:
            noop = g.get_attribute(node, "noop_with_empty_axes", exc=False)
            if noop is not None and noop.i:
                # The reduction does nothing and apply would lose that attribute.
                return self.none(node, inspect.currentframe().f_lineno)
            # Empty axes means every axis is reduced.
            axes = tuple(range(rank))
        # Negative axes count from the end, they must be normalized before
        # being compared to positions in the input shape.
        set_axes = {a + rank if a < 0 else a for a in axes}

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_rank(next_node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        out_rank = g.get_rank(next_node.output[0])
        if rank != out_rank + len(set_axes):
            return self.none(node, inspect.currentframe().f_lineno)
        if out_rank > 1:
            # With rank 0 or 1, the output shape is fully determined by the rank
            # and the number of elements; otherwise the shapes must be compared.
            if not g.has_shape(node.input[0]) or not g.has_shape(next_node.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            shape = g.get_shape(node.input[0])
            reduced_shape = tuple(s for i, s in enumerate(shape) if i not in set_axes)
            # get_shape returns a tuple, both sides must be tuples to compare them.
            reshaped_shape = tuple(g.get_shape(next_node.output[0]))
            if reduced_shape != reshaped_shape:
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, next_node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """Creates the Reduce* node with ``keepdims=0`` producing the Reshape output."""
        axes = g.get_attribute(node, "axes", exc=False)
        if axes is None:
            # axes is an input (or missing), it is forwarded as is.
            new_node = g.make_node(
                node.op_type,
                node.input,
                next_node.output,
                keepdims=0,
                name=f"{self.__class__.__name__}--{node.name}",
                doc_string=node.doc_string,
            )
            return [new_node]
        # older opset
        new_node = g.make_node(
            node.op_type,
            node.input,
            next_node.output,
            keepdims=0,
            axes=list(axes.ints),
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]
class ReshapeReshapePattern(PatternOptimization):
    """Replaces the sequence Reshape, Reshape by Reshape."""

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_constant(node.input[1]):
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if -1 in cst.tolist():
                # Then we only allow it the shape is static.
                if not g.is_constant(next_node.input[1]):
                    return self.none(node, inspect.currentframe().f_lineno)
                cst = g.get_computed_constant(next_node.input[1])
                # An empty shape (reshape into a scalar) is static as well,
                # min() cannot be called on it.
                if cst is None or (cst.size > 0 and cst.min() <= 0):
                    return self.none(node, inspect.currentframe().f_lineno)
        if (
            not g.has_rank(node.input[0])
            or not g.has_rank(next_node.output[0])
            or not g.has_rank(node.output[0])
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        sh1 = g.builder.value_as_shape(node.input[1])
        sh2 = g.builder.value_as_shape(next_node.input[1])
        if (sh2 is None or (-1 in sh2 and 0 not in sh2)) and (sh1 is None or -1 in sh1):
            return self.none(node, inspect.currentframe().f_lineno)
        # If g.get_rank(node.input[0]) != g.get_rank(next_node.output[0]),
        # the bet is, when the shape is not a constant, then using 0 is not really
        # useful. Since 0 is only valid for ONNX, 0 should not be found
        # in a non constant shape used to reshape.
        # If it is a constant that should be ok too.
        if (
            not g.has_shape(node.input[0])
            or not g.has_shape(node.output[0])
            or not g.has_shape(next_node.output[0])
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_constant(next_node.input[1]):
            cst = g.get_computed_constant(next_node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            applicable = self._applicable_reshape(
                g.get_shape(node.input[0]), g.get_shape(node.output[0]), cst
            )
            if applicable is None:
                rank_in = g.get_rank(node.input[0])
                rank_mid = g.get_rank(node.output[0])
                rank_out = g.get_rank(next_node.output[0])
                # Only the branch of apply which gathers the missing dimensions
                # from a non constant first shape does not rely on
                # _applicable_reshape, every other case must be rejected.
                if g.is_constant(node.input[1]) or rank_mid != rank_out or rank_in == rank_out:
                    return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, next_node], self.apply, insert_at=next_node)

    @classmethod
    def _applicable_reshape(
        cls, shape1: DYNAMIC_SHAPE, shape2: DYNAMIC_SHAPE, att: STATIC_SHAPE
    ) -> Optional[STATIC_SHAPE]:
        """
        Rewrites the second shape *att* so that it can be applied directly to
        the input of the first Reshape.

        :param shape1: shape of the input of the first Reshape
        :param shape2: shape of the output of the first Reshape, a 0 in *att*
            copies the dimension at the same position of this shape
        :param att: constant shape given to the second Reshape
        :return: the new shape or None if it cannot be rewritten
        """
        new_shape = []
        m1 = False
        for i, s in enumerate(att):
            if s == 0:
                if m1:
                    return None
                if i >= len(shape2):
                    return None
                # 0 copies the dimension of the second Reshape's input.
                new_shape.append(shape2[i])
            elif s > 0:
                new_shape.append(s)
            elif m1:
                return None
            else:
                # -1
                m1 = True
                new_shape.append(None)
        if tuple(new_shape) == tuple(shape1):
            # The final shape is the input shape: every 0 copies the same
            # dimension whether it is applied before or after the first Reshape.
            return tuple(att)
        # something needs to change
        list_att = list(map(int, att))
        if list_att.count(0) > 1 or (-1 in list_att and 0 in list_att):
            return None
        # A single 0 and no -1: the dimension can be inferred from the others.
        return tuple((-1 if s == 0 else s) for s in list_att)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """Creates one Reshape equivalent to the two, possibly with nodes computing its shape."""
        rank_changes = g.get_rank(node.input[0]) != g.get_rank(next_node.output[0])
        second_input = next_node.input[1]
        pre_nodes = []
        if (
            rank_changes
            and g.get_rank(node.output[0]) == g.get_rank(next_node.output[0])
            and g.is_constant(next_node.input[1])
        ):
            cst = tuple(g.get_computed_constant(next_node.input[1]))
            if 0 in cst:
                if g.is_constant(node.input[1]):
                    shape0 = tuple(g.get_computed_constant(node.input[1]))
                    assert len(shape0) == len(cst), (
                        f"This should be true due to the first test but cst={cst}, "
                        f"shape0={shape0}"
                    )
                    # Every 0 of the second shape takes the value of the first shape.
                    new_shape = [(s if s != 0 else s0) for s, s0 in zip(cst, shape0)]
                    assert (
                        new_shape.count(-1) <= 1
                    ), f"new_shape={new_shape} has two -1. This is not possible."
                    second_input = g.make_initializer(
                        "",
                        np.array(new_shape, dtype=np.int64),
                        source="ReshapeReshapePattern.new_shape.1",
                    )
                else:
                    # This code has one loop hole. It could produce shapes with two -1.
                    # Let's extract the missing information.
                    names = []
                    for axis, dim in enumerate(cst):
                        if dim == 0:
                            d_name = g.unique_name(f"{next_node.input[0]}--dim{axis}")
                            d_init = g.make_initializer(
                                "",
                                np.array([axis], dtype=np.int64),
                                source=f"ReshapeReshapePattern.axis.{axis}.1",
                            )
                            pre_nodes.append(
                                g.make_node(
                                    "Gather",
                                    [node.input[1], d_init],
                                    [d_name],
                                    axis=0,
                                    name=f"{next_node.name}--axis{axis}",
                                )
                            )
                            names.append(d_name)
                        else:
                            d_init = g.make_initializer(
                                "",
                                np.array([dim], dtype=np.int64),
                                source=f"ReshapeReshapePattern.axis.{axis}.2",
                            )
                            names.append(d_init)
                    second_input = g.unique_name(f"{next_node.input[0]}--concat")
                    pre_nodes.append(
                        g.make_node(
                            "Concat",
                            names,
                            [second_input],
                            axis=0,
                            name=f"{next_node.name}--concat",
                        )
                    )
        elif g.is_constant(next_node.input[1]):
            cst = tuple(map(int, g.get_computed_constant(next_node.input[1])))
            cst2 = self._applicable_reshape(
                g.get_shape(node.input[0]), g.get_shape(node.output[0]), cst
            )
            assert cst2 is not None, (
                f"match should have rejected this case, cst={cst}, "
                f"pattern {self.__class__.__name__}"
            )
            if cst2 != cst:
                second_input = g.make_initializer(
                    "",
                    np.array(cst2, dtype=np.int64),
                    source="ReshapeReshapePattern.new_shape.3",
                )
        new_node = g.make_node(
            "Reshape",
            [node.input[0], second_input],
            next_node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=next_node.doc_string,
        )
        return [*pre_nodes, new_node]
class Reshape2Of3Pattern(PatternOptimization):
    """
    Replaces the reshapes around element-wise operators.
    It can be 3 or 2 out of 3.

    The two inputs and the output of a binary element-wise operator
    (without broadcasting) are considered. When at least two of them
    are connected to a Reshape (of the default domain) and those reshapes
    share the same outer shape, the operator is computed on the outer
    shape and a Reshape is only kept where it is needed.
    """

    _op_types = element_wise_binary_op_types()

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in self._op_types or node.domain != "":
            return self.none()
        if (
            not g.has_shape(node.output[0])
            or not g.has_shape(node.input[0])
            or not g.has_shape(node.input[1])
        ):
            # Shapes are missing. They should be populated as much as possible.
            return self.none(node, inspect.currentframe().f_lineno)
        shape_out = g.get_shape(node.output[0])
        shape_in = g.get_shape(node.input[0]), g.get_shape(node.input[1])
        if not (shape_out == shape_in[0] == shape_in[1]):
            # Broadcasting is involved.
            return self.none(node, inspect.currentframe().f_lineno)
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) > 1 or (len(next_nodes) == 0 and not g.is_output(node.output[0])):
            return self.none(node, inspect.currentframe().f_lineno)

        def _onnx_reshape(candidate: Optional[NodeProto]) -> Optional[NodeProto]:
            # Only a Reshape from the default domain follows the ONNX semantic.
            if candidate is None or candidate.op_type != "Reshape" or candidate.domain != "":
                return None
            return candidate

        node_left = _onnx_reshape(g.node_before(node.input[0]))
        node_right = _onnx_reshape(g.node_before(node.input[1]))
        next_node = _onnx_reshape(next_nodes[0] if next_nodes else None)
        if next_node is not None and next_node.input[0] != node.output[0]:
            # The output is the shape given to the Reshape, not the reshaped data.
            next_node = None
        n_reshape = sum(n is not None for n in (node_left, node_right, next_node))
        if n_reshape < 2:
            return self.none(node, inspect.currentframe().f_lineno)

        shapes = [
            (
                None
                if (node_left is None or not g.has_shape(node_left.input[0]))
                else g.get_shape(node_left.input[0])
            ),
            (
                None
                if (node_right is None or not g.has_shape(node_right.input[0]))
                else g.get_shape(node_right.input[0])
            ),
            (
                None
                if (next_node is None or not g.has_shape(next_node.output[0]))
                else g.get_shape(next_node.output[0])
            ),
        ]
        ranks = [
            (
                None
                if (node_left is None or not g.has_rank(node_left.input[0]))
                else g.get_rank(node_left.input[0])
            ),
            (
                None
                if (node_right is None or not g.has_rank(node_right.input[0]))
                else g.get_rank(node_right.input[0])
            ),
            (
                None
                if (next_node is None or not g.has_rank(next_node.output[0]))
                else g.get_rank(next_node.output[0])
            ),
        ]
        all_shapes = [_ for _ in shapes if _ is not None]
        all_ranks = [_ for _ in ranks if _ is not None]
        if len(set(all_shapes)) != 1 or len(set(all_ranks)) != 1 or len(all_shapes) < 2:
            # Not the same shapes.
            return self.none(node, inspect.currentframe().f_lineno)
        nodes = [node_left, node_right, next_node, node]
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_left: NodeProto,
        node_right: NodeProto,
        next_node: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Computes the operator on the outer shape. Missing reshapes are added
        on the side which had none, matched reshapes still used elsewhere are kept.
        """
        # shape going from the outer shape to the shape of the operator
        compute_shape_name = node_left.input[1] if node_right is None else node_right.input[1]
        # shape going from the shape of the operator to the outer shape
        final_shape_name = compute_shape_name if next_node is None else next_node.input[1]
        res = []
        # node left
        if node_left is None:
            left_name = g.unique_name(f"{self.__class__.__name__}L_{node.input[0]}")
            res.append(
                g.make_node(
                    "Reshape",
                    [node.input[0], final_shape_name],
                    [left_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
        elif g.is_used_more_than_once(node_left.output[0]):
            res.append(node_left)
            left_name = node_left.input[0]
        else:
            left_name = node_left.input[0]
        # node right
        if node_right is None:
            right_name = g.unique_name(f"{self.__class__.__name__}R_{node.input[1]}")
            res.append(
                g.make_node(
                    "Reshape",
                    [node.input[1], final_shape_name],
                    [right_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
        elif g.is_used_more_than_once(node_right.output[0]):
            res.append(node_right)
            right_name = node_right.input[0]
        else:
            right_name = node_right.input[0]
        # node and next node
        if next_node is None:
            # Reshape is needed.
            new_name = g.unique_name(f"{self.__class__.__name__}L_{node.output[0]}")
            res.extend(
                [
                    g.make_node(
                        node.op_type,
                        [left_name, right_name],
                        [new_name],
                        name=f"{self.__class__.__name__}--{node.name}",
                    ),
                    g.make_node(
                        "Reshape",
                        [new_name, final_shape_name],
                        [node.output[0]],
                        name=f"{self.__class__.__name__}--{node.name}",
                    ),
                ]
            )
        else:
            main_node = g.make_node(
                node.op_type,
                [left_name, right_name],
                [next_node.output[0]],
                name=f"{self.__class__.__name__}--{node.name}",
            )
            res.append(main_node)
            if g.is_used_more_than_once(node.output[0]):
                # The intermediate result is still needed (graph output for example).
                res.append(
                    g.make_node(
                        "Reshape",
                        [main_node.output[0], compute_shape_name],
                        [node.output[0]],
                        name=f"{self.__class__.__name__}--{node.name}",
                    )
                )
        return res
class ReshapeReshapeBinaryPattern(PatternOptimization):
    """
    Moves two reshape operators beyond a binary operator
    if it is possible.

    Both inputs of the binary operator must be produced by Reshape nodes
    using the same constant shape on inputs with the same shape. The operator
    is then computed before the reshape and a single Reshape follows it.
    """

    _op_types = element_wise_binary_op_types()

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in self._op_types or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.input[0]) or g.is_used_more_than_once(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        left, right = g.node_before(node.input[0]), g.node_before(node.input[1])
        if left is None or left.op_type != "Reshape" or left.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if right is None or right.op_type != "Reshape" or right.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(left.input[1]) or not g.is_constant(right.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        value_left = g.get_computed_constant(left.input[1])
        value_right = g.get_computed_constant(right.input[1])
        if value_left is None or value_right is None:
            # The constants could not be computed.
            return self.none(node, inspect.currentframe().f_lineno)
        if value_left.tolist() != value_right.tolist():
            return self.none(node, inspect.currentframe().f_lineno)
        shape1 = g.get_shape(left.input[0]) if g.has_shape(left.input[0]) else None
        shape2 = g.get_shape(right.input[0]) if g.has_shape(right.input[0]) else None
        if shape1 is None or shape2 is None or shape1 != shape2:
            return self.none(node, inspect.currentframe().f_lineno)
        # If there is not broadcast involved then it is ok.
        # At this stage, we know shapes are equal before the reshaped operators
        # and the same reshape is applied. So checking the output shape
        # is not necessary.
        return MatchResult(self, [left, right, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        left: NodeProto,
        right: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """Computes the operator on the original inputs, then reshapes its result."""
        new_node = g.make_node(
            node.op_type,
            [left.input[0], right.input[0]],
            name=f"{self.__class__.__name__}--{node.name}",
        )
        reshape_node = g.make_node(
            "Reshape",
            [new_node.output[0], left.input[1]],
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node, reshape_node]
class ConcatReshapePattern(PatternOptimization):
    """
    Tries to reduce the number of nodes in the sequence Concat + Reshape
    by replacing one of the dimension by -1.

    The shape given to the Reshape must be built by a Concat whose inputs are
    constants (without -1) or produced by Shape nodes, with at most one input
    produced by another operator. That input, or the last single-element
    Shape output if every non-constant input comes from Shape, is replaced
    by -1 and the dimension is inferred by the Reshape.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    @classmethod
    def _input_to_replace(
        cls, g: "GraphBuilder", concat: NodeProto  # noqa: F821
    ) -> Optional[int]:
        """
        Returns the position of the Concat input to replace by -1,
        None if no input can be replaced. The replaced input must be
        a single dimension, otherwise the rank of the new shape changes.
        """
        last_shape = None
        for position, name in enumerate(concat.input):
            if g.is_constant(name):
                continue
            producer = g.node_before(name)
            if producer is None:
                continue
            single = g.has_shape(name) and g.get_shape(name) == (1,)
            if producer.op_type != "Shape":
                # The only dimension not coming from a Shape node.
                return position if single else None
            if single:
                last_shape = position
        return last_shape

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        allowzero = g.get_attribute(node, "allowzero", exc=False)
        if allowzero is not None and allowzero.i:
            # With allowzero=1, 0 is a real dimension, it cannot be combined
            # with -1 and the new Reshape would not carry the attribute.
            return self.none(node, inspect.currentframe().f_lineno)
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat" or gen.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        op_types = {}
        for i in gen.input:
            if g.is_constant(i):
                cst = g.get_computed_constant(i)
                if cst is None:
                    return self.none(node, inspect.currentframe().f_lineno)
                if -1 in cst.tolist():
                    return self.none(node, inspect.currentframe().f_lineno)
            else:
                p = g.node_before(i)
                if p is None:
                    return self.none(node, inspect.currentframe().f_lineno)
                op_types[p.op_type] = op_types.get(p.op_type, 0) + 1
        if "Shape" not in op_types:
            # Only constants or no dimension coming from Shape.
            return self.none(node, inspect.currentframe().f_lineno)
        if sum(op_types.values()) - op_types["Shape"] > 1:
            # More than one dimension would have to be replaced by -1.
            return self.none(node, inspect.currentframe().f_lineno)
        if self._input_to_replace(g, gen) is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node.input[1]):
            # Not really safe to do the replacement.
            return MatchResult(self, [gen, node], self.apply)
        return MatchResult(self, [gen, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        concat: NodeProto,
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """Rebuilds the Concat with one dimension replaced by -1 and reshapes with it."""
        position = self._input_to_replace(g, concat)
        assert (
            position is not None
        ), f"No input to replace, pattern {self.__class__.__name__} should not have matched."
        m1 = g.make_initializer(
            "",
            np.array([-1], dtype=np.int64),
            source="ConcatReshapePattern.m1",
        )
        inputs = list(concat.input)
        inputs[position] = m1
        keep_concat = g.is_used_more_than_once(concat.output[0])
        new_output = g.unique_name(f"{concat.output[0]}--concat")
        res = [
            g.make_node(
                "Concat",
                inputs,
                [new_output],
                name=f"{self.__class__.__name__}--{concat.name}",
                doc_string=concat.doc_string,
                axis=0,
            ),
            g.make_node(
                "Reshape",
                [reshape.input[0], new_output],
                reshape.output,
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            ),
        ]
        if keep_concat:
            return [concat, *res]
        return res
class StaticConcatReshapePattern(PatternOptimization):
    """
    Tries to reduce the number of nodes in the sequence Concat + Reshape
    by replacing one of the dimension by -1.

    Every input of the Concat must be a constant (without -1) except one
    input whose shape is ``(1,)``. That input is replaced by -1 and the
    dimension is inferred by the Reshape.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        allowzero = g.get_attribute(node, "allowzero", exc=False)
        if allowzero is not None and allowzero.i:
            # With allowzero=1, 0 is a real dimension, it cannot be combined
            # with -1 and the new Reshape would not carry the attribute.
            return self.none(node, inspect.currentframe().f_lineno)
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat" or gen.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        not_cst = []
        for i in gen.input:
            if g.is_constant(i):
                cst = g.get_computed_constant(i)
                if cst is None:
                    return self.none(node, inspect.currentframe().f_lineno)
                if -1 in cst.tolist():
                    return self.none(node, inspect.currentframe().f_lineno)
            elif g.has_shape(i) and g.get_shape(i) == (1,):
                not_cst.append(i)
            else:
                return self.none(node, inspect.currentframe().f_lineno)
        if len(not_cst) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(node.input[1]):
            # Not really safe to do the replacement.
            return MatchResult(self, [gen, node], self.apply)
        return MatchResult(self, [gen, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        concat: NodeProto,
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """Rebuilds the Concat with the only non constant input replaced by -1."""
        m1 = g.make_initializer(
            "",
            np.array([-1], dtype=np.int64),
            source="ConcatReshapePattern.m1",
        )
        inputs = []
        done = False
        for i in concat.input:
            if g.is_constant(i):
                inputs.append(i)
                continue
            if g.has_shape(i) and g.get_shape(i) == (1,):
                assert not done, f"-1 was already added, input {i!r} cannot be replaced by -1."
                inputs.append(m1)
                done = True
                continue
            raise RuntimeError(
                f"The pattern was allowed but input {i!r} "
                f"is not a constant and its shape is not (1,)."
            )
        assert (
            done
        ), f"-1 was not inserted, pattern {self.__class__.__name__} should not have matched."
        keep_concat = g.is_used_more_than_once(concat.output[0])
        new_output = g.unique_name(f"{concat.output[0]}--concat")
        res = [
            g.make_node(
                "Concat",
                inputs,
                [new_output],
                name=f"{self.__class__.__name__}--{concat.name}",
                doc_string=concat.doc_string,
                axis=0,
            ),
            g.make_node(
                "Reshape",
                [reshape.input[0], new_output],
                reshape.output,
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            ),
        ]
        if keep_concat:
            return [concat, *res]
        return res
class ShapeBasedEditDistanceReshapePattern(PatternOptimization):
    """
    Tries to reduce the number of nodes in the sequence Concat + Reshape
    by replacing one of the dimension by -1 or 0.
    The pattern tries to align shape information to infer a static shape.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    @classmethod
    def _prod(cls, sequence):
        """Returns the product of the dimensions, None if one is not an integer."""
        p = 1
        for s in sequence:
            if not isinstance(s, int):
                return None
            p *= s
        return p

    @classmethod
    def _align_shapes(
        cls, s1: DYNAMIC_SHAPE, s2: Tuple[Union[str, int], ...]
    ) -> Optional[Tuple[int, ...]]:
        """
        Compute the edit distance (Levenshtein distance) between two shapes and
        tries to align them in order to return a reshape argument with only integers.

        Segments of up to four dimensions can be aligned together when their
        products are equal or, with a penalty, when they hold at most one
        symbolic dimension on each side. Returns None when no alignment costs
        less than 1, when more than one -1 would be needed, or when one
        shape is empty.
        """
        assert all(
            isinstance(s, (int, str)) and s != -1 for s in s1
        ), f"Unsupported shape s1={s1}"
        assert all(
            isinstance(s, (int, str)) and s != -1 for s in s2
        ), f"Unsupported shape s2={s2}"
        if len(s1) == 0 or len(s2) == 0:
            # Nothing to align, the path reconstruction below needs one step.
            return None
        # penalty for an alignment involving a symbolic dimension
        eps = 0.5
        # mat[i, j]: cost of aligning s1[:i] with s2[:j], unreachable cells keep
        # a value larger than any accepted cost
        mat = np.full((len(s1) + 1, len(s2) + 1), max(len(s1), len(s2)) + 10, dtype=np.float32)
        mat[0, 0] = 0
        # predecessor[i, j] = (ki, kj, i - ki, j - kj): the last aligned segments
        # s1[i - ki:i], s2[j - kj:j] and the cell they come from
        predecessor = {}
        for i in range(1, len(s1) + 1):
            for j in range(1, len(s2) + 1):
                # one dimension aligned with one dimension
                c_cmp = mat[i - 1, j - 1] + (
                    0
                    if s1[i - 1] == s2[j - 1]
                    else (1 if isinstance(s1[i - 1], int) and isinstance(s2[j - 1], int) else eps)
                )
                options = [(c_cmp, (1, 1, i - 1, j - 1))]
                # segments of up to four dimensions on each side
                for ki in range(1, 5):
                    if i < ki:
                        break
                    ss1 = s1[i - ki : i]
                    vi = cls._prod(ss1)
                    for kj in range(1, 5):
                        if kj == 1 and ki == 1:
                            continue
                        # both segments start at the first dimension or neither does
                        if i - ki == 0 and j - kj != 0:
                            continue
                        if i - ki != 0 and j - kj == 0:
                            continue
                        if j < kj:
                            break
                        ss2 = s2[j - kj : j]
                        vj = cls._prod(ss2)
                        if vi is None or vj is None:
                            c1 = sum(isinstance(_, str) for _ in ss1)
                            c2 = sum(isinstance(_, str) for _ in ss2)
                            if c1 <= 1 and c2 <= 1:
                                options.append(
                                    (mat[i - ki, j - kj] + eps, (ki, kj, i - ki, j - kj))
                                )
                        elif vi == vj:
                            options.append((mat[i - ki, j - kj], (ki, kj, i - ki, j - kj)))
                mini = min(options)
                mat[i, j], predecessor[i, j] = mini
        # computed
        if mat[len(s1), len(s2)] >= 1:
            # No possible equivalence.
            return None
        last = predecessor[len(s1), len(s2)]
        path = []
        while last[2:] in predecessor:
            path.append(last)
            last = predecessor[last[2:]]
        path.append(last)
        # new_shape grows by len(sh2) at every step, so len(new_shape) == pj
        new_shape = []
        mone = 0
        for di, dj, pi, pj in reversed(path):
            sh1, sh2 = s1[pi : pi + di], s2[pj : pj + dj]
            if all(isinstance(_, int) for _ in sh2):
                new_shape.extend(sh2)
            elif all(isinstance(_, int) for _ in sh1):
                if len(sh2) == 1:
                    new_shape.append(cls._prod(sh1))
                else:
                    return None
            elif len(sh1) == len(sh2) == 1:
                # They are equal and both strings
                if pi == pj and s1[pi] == s2[pj]:
                    # 0 copies the input dimension at the same position
                    new_shape.append(0)
                else:
                    new_shape.append(-1)
                    mone += 1
            else:
                for i in sh2:
                    if isinstance(i, str):
                        new_shape.append(-1)
                        mone += 1
                    else:
                        new_shape.append(i)
        if mone > 1:
            return None
        assert (
            None not in new_shape
        ), f"Unexpected inputs: new_shape={new_shape}, shape1={s1}, shape2={s2}"
        return tuple(new_shape)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        # Cheap structural check first, the alignment is only computed
        # for a shape produced by a Concat.
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        sh1 = g.get_shape_renamed(node.input[0])
        sh2 = g.get_shape_renamed(node.output[0])
        aligned_reshape = self._align_shapes(sh1, sh2)
        if aligned_reshape is None:
            return self.none(node, inspect.currentframe().f_lineno)
        assert len(aligned_reshape) == g.get_rank(node.output[0]), (
            f"Issue with input shape {sh1}, output shape={sh2}, "
            f"proposed new_shape {aligned_reshape}"
        )
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """Replaces the computed shape by the constant shape found by the alignment."""
        aligned_reshape = self._align_shapes(
            g.get_shape_renamed(reshape.input[0]), g.get_shape_renamed(reshape.output[0])
        )
        new_shape = g.make_initializer(
            "",
            np.array(aligned_reshape, dtype=np.int64),
            source="EditDistanceReshapePattern.m1",
        )
        return [
            g.make_node(
                "Reshape",
                [reshape.input[0], new_shape],
                [reshape.output[0]],
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            )
        ]
class ShapeBasedReshapeIsSqueezePattern(PatternOptimization):
    """
    Replaces a Reshape by a Squeeze or an Unsqueeze if possible: the output
    shape must be the input shape without its dimensions equal to 1 (Squeeze)
    or the input shape must be the output shape without its dimensions
    equal to 1 (Unsqueeze).
    It is only available for opset >= 18.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    @classmethod
    def _squeeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: Tuple[Union[str, int], ...]
    ) -> Tuple[Optional[str], Optional[Tuple[int, ...]]]:
        """
        Returns ``(op_type, axes)`` with op_type in ``Squeeze``, ``Unsqueeze``
        or ``(None, None)`` if the Reshape from *s1* to *s2* is not one of them.
        """
        if s1 == s2:
            return None, None
        sh1 = tuple(s for s in s1 if s != 1)
        sh2 = tuple(s for s in s2 if s != 1)
        if sh1 != sh2:
            return None, None
        if len(s1) < len(s2):
            op_type = "Unsqueeze"
            axes = cls._find_unsqueeze_axes(s1, s2)
        else:
            op_type = "Squeeze"
            axes = cls._find_squeeze_axes(s1, s2)
        if axes is None:
            return None, None
        return op_type, axes

    @classmethod
    def _find_squeeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: DYNAMIC_SHAPE
    ) -> Optional[Tuple[int, ...]]:
        """Positions of the dimensions 1 of *s1* if removing all of them gives *s2*."""
        sh1 = tuple(s for s in s1 if s != 1)
        if sh1 != s2:
            return None
        return tuple(i for i, s in enumerate(s1) if s == 1)

    @classmethod
    def _find_unsqueeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: DYNAMIC_SHAPE
    ) -> Optional[Tuple[int, ...]]:
        """Positions of the dimensions 1 to insert into *s1* to obtain *s2*."""
        return cls._find_squeeze_axes(s2, s1)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if g.main_opset < 18:
            return self.none()
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        sh1 = g.get_shape_renamed(node.input[0])
        sh2 = g.get_shape_renamed(node.output[0])
        op_type, _axes = self._squeeze_axes(sh1, sh2)
        if op_type is None:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """Creates the Squeeze or Unsqueeze node with the axes as a constant input."""
        op_type, axes = self._squeeze_axes(
            g.get_shape_renamed(reshape.input[0]), g.get_shape_renamed(reshape.output[0])
        )
        new_axes = g.make_initializer(
            "",
            np.array(axes, dtype=np.int64),
            source="ReshapeIsSqueezePattern.m1",
        )
        return [
            g.make_node(
                op_type,
                [reshape.input[0], new_axes],
                reshape.output,
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            )
        ]