Source code for experimental_experiment.xoptim.patterns.onnx_reshape

import inspect
from typing import List, Optional, Tuple, Union
import numpy as np
from onnx import NodeProto
from ...xbuilder._onnx_helper import element_wise_binary_op_types
from ...xbuilder._shape_helper import all_int, DYNAMIC_SHAPE, STATIC_SHAPE
from ..patterns_api import MatchResult, PatternOptimization


class ReshapePattern(PatternOptimization):
    """
    Replaces a Reshape by an Identity when it does not change the shape.

    The pattern only matches when the input shape is known and only made
    of integers, and when the requested shape is a constant strictly equal
    to that input shape (so it contains neither 0 nor -1).
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape which leaves the shape unchanged.

        :param g: graph builder used to query shapes and constants
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape = g.get_shape(node.input[0])
        if not all_int(shape):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            # It may be a symbolic shape.
            return self.none(node, inspect.currentframe().f_lineno)
        value = g.get_computed_constant(node.input[1])
        if value is None:
            return self.none(node, inspect.currentframe().f_lineno)
        new_shape = tuple(int(i) for i in value)
        if shape != new_shape:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the Identity node replacing the useless Reshape.

        :param g: graph builder used to create the node
        :param node: the matched Reshape
        :return: the list holding the new Identity node
        """
        new_node = g.make_node(
            "Identity",
            # Identity takes a single input: the data tensor only,
            # the shape input of the Reshape must be dropped.
            [node.input[0]],
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]
class ReduceReshapePattern(PatternOptimization):
    """
    Replaces the sequence Reduce*, Reshape by a single Reduce* with
    ``keepdims=0`` when the Reshape is only introduced to remove the
    dimensions kept because of ``keepdims=1``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reduce* (keepdims=1) only followed by a Reshape
        removing the reduced axes.

        :param g: graph builder used to query ranks, shapes and constants
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if not node.op_type.startswith("Reduce") or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        att = g.get_attribute(node, "keepdims", exc=False)
        keepdims = 1 if att is None else att.i
        if keepdims == 0:
            # not keeping the dimension so Reshape means to restore them.
            return self.none(node, inspect.currentframe().f_lineno)

        # The rank of the input is needed in every branch below.
        if not g.has_rank(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        rank = g.get_rank(node.input[0])

        if len(node.input) == 2:
            if not g.is_constant(node.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            axes = tuple(int(a) for a in cst)
        else:
            att = g.get_attribute(node, "axes", exc=False)
            axes = tuple(range(rank)) if att is None else tuple(att.ints)

        # Negative axes count from the end, they are made positive
        # so that they can be compared to the indices of the input shape.
        set_axes = {(a + rank if a < 0 else a) for a in axes}
        if not set_axes:
            # An empty list of axes either reduces everything or nothing
            # (noop_with_empty_axes), keepdims=0 would not be equivalent
            # to the Reshape in both cases.
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)

        if not g.has_rank(next_node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        out_rank = g.get_rank(next_node.output[0])
        if rank != out_rank + len(set_axes):
            return self.none(node, inspect.currentframe().f_lineno)

        if out_rank > 1:
            # More than one dimension remains: the Reshape must produce
            # exactly the input shape without the reduced axes.
            if not g.has_shape(node.input[0]) or not g.has_shape(next_node.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            shape = g.get_shape(node.input[0])
            reduced_shape = tuple(s for i, s in enumerate(shape) if i not in set_axes)
            reshaped_shape = tuple(g.get_shape(next_node.output[0]))
            if reduced_shape != reshaped_shape:
                return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node, next_node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the Reduce* node with ``keepdims=0`` replacing both nodes.

        :param g: graph builder used to create the node
        :param node: the matched Reduce* node
        :param next_node: the Reshape following it
        :return: the list holding the new Reduce* node
        """
        kwargs = dict(
            keepdims=0,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        axes = g.get_attribute(node, "axes", exc=False)
        if axes is not None:
            # older opset: axes is an attribute, not an input
            kwargs["axes"] = list(axes.ints)
        new_node = g.make_node(node.op_type, node.input, next_node.output, **kwargs)
        return [new_node]
class ReshapeReshapePattern(PatternOptimization):
    """Replaces the sequence Reshape, Reshape by Reshape."""

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape only followed by another Reshape
        and if both can be merged into a single one.

        :param g: graph builder used to query ranks, shapes and constants
        :param node: candidate node (the first Reshape)
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_constant(node.input[1]):
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if -1 in cst.tolist():
                # Then we only allow it if the second shape is static
                # and strictly positive.
                if not g.is_constant(next_node.input[1]):
                    return self.none(node, inspect.currentframe().f_lineno)
                cst = g.get_computed_constant(next_node.input[1])
                if cst is None or cst.min() <= 0:
                    return self.none(node, inspect.currentframe().f_lineno)

        if (
            not g.has_rank(node.input[0])
            or not g.has_rank(next_node.output[0])
            or not g.has_rank(node.output[0])
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        sh1 = g.builder.value_as_shape(node.input[1])
        sh2 = g.builder.value_as_shape(next_node.input[1])
        if (sh2 is None or (-1 in sh2 and 0 not in sh2)) and (sh1 is None or -1 in sh1):
            return self.none(node, inspect.currentframe().f_lineno)

        # If g.get_rank(node.input[0]) != g.get_rank(next_node.output[0]),
        # the bet is, when the shape is not a constant, then using 0 is not really
        # useful. Since 0 is only valid for ONNX, 0 should not be found
        # in a non constant shape used to reshape.
        # If it is a constant that should be ok too.

        if not g.has_shape(node.input[0]) or not g.has_shape(next_node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_constant(next_node.input[1]):
            next_cst = g.get_computed_constant(next_node.input[1])
            if next_cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            # Only the rank of the intermediate result was verified so far,
            # its shape is requested only if it is known.
            mid_shape = g.get_shape(node.output[0]) if g.has_shape(node.output[0]) else None
            applicable = self._applicable_reshape(
                g.get_shape(node.input[0]), mid_shape, next_cst
            )
            if not applicable and (
                g.is_constant(node.input[1])
                or g.get_rank(node.output[0]) != g.get_rank(next_node.output[0])
            ):
                return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node, next_node], self.apply, insert_at=next_node)

    @classmethod
    def _applicable_reshape(
        cls, shape1: DYNAMIC_SHAPE, shape2: DYNAMIC_SHAPE, att: STATIC_SHAPE
    ) -> Optional[STATIC_SHAPE]:
        """
        Returns the shape a single Reshape applied on a tensor of shape
        *shape1* should receive to replace the constant shape *att*,
        or None if no such shape can be produced.

        :param shape1: shape of the input of the first Reshape
        :param shape2: shape of the intermediate result, it is not used
        :param att: constant shape given to the second Reshape
        :return: the shape to use or None
        """
        new_shape = []
        m1 = False
        for i, s in enumerate(att):
            if s == 0:
                # 0 copies the dimension at the same position.
                if m1:
                    return None
                if i >= len(shape1):
                    return None
                new_shape.append(shape1[i])
            elif s > 0:
                new_shape.append(s)
            elif m1:
                # second -1
                return None
            else:
                # -1
                m1 = True
                new_shape.append(None)
        if tuple(new_shape) == shape1:
            return tuple(att)
        # something needs to change
        list_att = list(map(int, att))
        if list_att.count(0) > 1 or (-1 in list_att and 0 in list_att):
            return None
        # A single 0 and no -1: the copied dimension can be inferred.
        return tuple((-1 if s == 0 else s) for s in list_att)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the single Reshape replacing both nodes, and if needed
        the nodes computing its shape.

        :param g: graph builder used to create nodes and initializers
        :param node: the first Reshape
        :param next_node: the second Reshape
        :return: the new nodes, the merged Reshape is the last one
        """
        # True when the rank of the very first input differs from the rank
        # of the final output (see the comment in method match).
        rank_differs = g.get_rank(node.input[0]) != g.get_rank(next_node.output[0])
        second_input = next_node.input[1]
        pre_nodes = []
        if (
            rank_differs
            and g.get_rank(node.output[0]) == g.get_rank(next_node.output[0])
            and g.is_constant(next_node.input[1])
        ):
            cst = tuple(g.get_computed_constant(next_node.input[1]))
            if 0 in cst:
                # 0 refers to a dimension of the intermediate result,
                # it has to be replaced by the value the first Reshape used.
                if g.is_constant(node.input[1]):
                    shape0 = tuple(g.get_computed_constant(node.input[1]))
                    assert len(shape0) == len(cst), (
                        f"This should be true due to the first test but cst={cst}, "
                        f"shape0={shape0}"
                    )
                    new_shape = [(s if s != 0 else s0) for s, s0 in zip(cst, shape0)]
                    # NOTE(review): this condition always holds, it does not
                    # detect two -1 as the message says - to be confirmed.
                    assert (
                        len(new_shape) >= len([s for s in new_shape if s != 0]) - 1
                    ), f"new_shape={new_shape} has two -1. This is not possible."
                    second_input = g.make_initializer(
                        "",
                        np.array(new_shape, dtype=np.int64),
                        source="ReshapeReshapePattern.new_shape.1",
                    )
                else:
                    # This code has one loop hole. It could produce shapes with two -1.
                    # Let's extract the missing information.
                    names = []
                    for axis, dim in enumerate(cst):
                        if dim == 0:
                            # The dimension is read from the first shape.
                            d_name = g.unique_name(f"{next_node.input[0]}--dim{axis}")
                            d_init = g.make_initializer(
                                "",
                                np.array([axis], dtype=np.int64),
                                source=f"ReshapeReshapePattern.axis.{axis}.1",
                            )
                            pre_nodes.append(
                                g.make_node(
                                    "Gather",
                                    [node.input[1], d_init],
                                    [d_name],
                                    axis=0,
                                    name=f"{next_node.name}--axis{axis}",
                                )
                            )
                            names.append(d_name)
                        else:
                            d_init = g.make_initializer(
                                "",
                                np.array([dim], dtype=np.int64),
                                source=f"ReshapeReshapePattern.axis.{axis}.2",
                            )
                            names.append(d_init)
                    second_input = g.unique_name(f"{next_node.input[0]}--concat")
                    pre_nodes.append(
                        g.make_node(
                            "Concat",
                            names,
                            [second_input],
                            axis=0,
                            name=f"{next_node.name}--concat",
                        )
                    )
        elif g.is_constant(next_node.input[1]):
            cst = tuple(map(int, g.get_computed_constant(next_node.input[1])))
            mid_shape = g.get_shape(node.output[0]) if g.has_shape(node.output[0]) else None
            # NOTE(review): cst2 may be None if the shape is not applicable,
            # np.array(None, dtype=np.int64) then fails - to be confirmed
            # that method match excludes this case.
            cst2 = self._applicable_reshape(g.get_shape(node.input[0]), mid_shape, cst)
            if cst2 != cst:
                second_input = g.make_initializer(
                    "",
                    np.array(cst2, dtype=np.int64),
                    source="ReshapeReshapePattern.new_shape.3",
                )

        new_node = g.make_node(
            "Reshape",
            [node.input[0], second_input],
            next_node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=next_node.doc_string,
        )
        return [*pre_nodes, new_node]
class Reshape2Of3Pattern(PatternOptimization):
    """
    Moves the Reshape nodes surrounding an element-wise binary operator.
    At least two of the three neighbours (left input, right input, output)
    must be a Reshape.
    """

    _op_types = element_wise_binary_op_types()

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a binary operator without broadcasting
        surrounded by at least two Reshape sharing the same outer shape.

        :param g: graph builder used to query shapes and neighbours
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type not in self._op_types or node.domain != "":
            return self.none()

        needed = (node.output[0], node.input[0], node.input[1])
        if not all(g.has_shape(name) for name in needed):
            # Shapes are missing. They should be populated as much as possible.
            return self.none(node, inspect.currentframe().f_lineno)

        shape_out, shape_left, shape_right = (g.get_shape(name) for name in needed)
        if shape_out != shape_left or shape_left != shape_right:
            # Broadcasting is involved.
            return self.none(node, inspect.currentframe().f_lineno)

        consumers = g.next_nodes(node.output[0])
        if len(consumers) > 1 or (len(consumers) == 0 and not g.is_output(node.output[0])):
            return self.none(node, inspect.currentframe().f_lineno)

        next_node = consumers[0] if len(consumers) != 0 else None
        node_left = g.node_before(node.input[0])
        node_right = g.node_before(node.input[1])

        op_names = [
            None if neighbour is None else neighbour.op_type
            for neighbour in (node_left, node_right, next_node)
        ]
        op_names.append(node.op_type)
        if op_names.count("Reshape") < 2:
            return self.none(node, inspect.currentframe().f_lineno)

        # Only the neighbours which are a Reshape are kept.
        if node_left is not None and node_left.op_type != "Reshape":
            node_left = None
        if node_right is not None and node_right.op_type != "Reshape":
            node_right = None
        if next_node is not None and next_node.op_type != "Reshape":
            next_node = None

        # Names of the results on the far side of every kept Reshape.
        outer_names = []
        if node_left is not None:
            outer_names.append(node_left.input[0])
        if node_right is not None:
            outer_names.append(node_right.input[0])
        if next_node is not None:
            outer_names.append(next_node.output[0])

        all_shapes = [g.get_shape(name) for name in outer_names if g.has_shape(name)]
        all_ranks = [g.get_rank(name) for name in outer_names if g.has_rank(name)]
        if not (len(set(all_shapes)) == 1 and len(set(all_ranks)) == 1 and len(all_shapes) >= 2):
            # Not the same shapes.
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, [node_left, node_right, next_node, node], self.apply)

    def _rewire_operand(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        producer: Optional[NodeProto],
        index: int,
        side: str,
        shape_name: str,
        res: List[NodeProto],
    ) -> str:
        """
        Returns the name the new operator reads for input *index*.
        A Reshape is appended to *res* if this input has no Reshape
        before it, the existing Reshape is kept if it has other users.
        """
        if producer is None:
            name = g.unique_name(f"{self.__class__.__name__}{side}_{node.input[index]}")
            res.append(
                g.make_node(
                    "Reshape",
                    [node.input[index], shape_name],
                    [name],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
            return name
        if g.is_used_more_than_once(producer.output[0]):
            res.append(producer)
        return producer.input[0]

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_left: NodeProto,
        node_right: NodeProto,
        next_node: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the nodes applying the operator on the outer shape.

        :param g: graph builder used to create nodes and names
        :param node_left: Reshape producing the first input or None
        :param node_right: Reshape producing the second input or None
        :param next_node: Reshape consuming the output or None
        :param node: the binary operator
        :return: the new nodes
        """
        if node_right is None:
            compute_shape_name = node_left.input[1]
        else:
            compute_shape_name = node_right.input[1]
        if next_node is None:
            final_shape_name = compute_shape_name
        else:
            final_shape_name = next_node.input[1]

        res = []
        left_name = self._rewire_operand(g, node, node_left, 0, "L", final_shape_name, res)
        right_name = self._rewire_operand(g, node, node_right, 1, "R", final_shape_name, res)

        if next_node is None:
            # Reshape is needed.
            new_name = g.unique_name(f"{self.__class__.__name__}L_{node.output[0]}")
            operator = g.make_node(
                node.op_type,
                [left_name, right_name],
                [new_name],
                name=f"{self.__class__.__name__}--{node.name}",
            )
            back = g.make_node(
                "Reshape",
                [new_name, final_shape_name],
                [node.output[0]],
                name=f"{self.__class__.__name__}--{node.name}",
            )
            res.extend([operator, back])
            return res

        main_node = g.make_node(
            node.op_type,
            [left_name, right_name],
            [next_node.output[0]],
            name=f"{self.__class__.__name__}--{node.name}",
        )
        res.append(main_node)
        if g.is_used_more_than_once(node.output[0]):
            # The former output is still needed, it is restored with a Reshape.
            res.append(
                g.make_node(
                    "Reshape",
                    [main_node.output[0], compute_shape_name],
                    [node.output[0]],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
        return res
class ReshapeReshapeBinaryPattern(PatternOptimization):
    """
    Moves two Reshape nodes after a binary operator when both inputs
    have the same shape and receive the same constant shape.
    """

    _op_types = element_wise_binary_op_types()

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a binary operator whose two inputs come from
        two equivalent Reshape nodes.

        :param g: graph builder used to query shapes and constants
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type not in self._op_types or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.input[0]) or g.is_used_more_than_once(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        left = g.node_before(node.input[0])
        right = g.node_before(node.input[1])
        for producer in (left, right):
            if producer is None or producer.op_type != "Reshape" or producer.domain != "":
                return self.none(node, inspect.currentframe().f_lineno)

        if not (g.is_constant(left.input[1]) and g.is_constant(right.input[1])):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        target_left = g.get_computed_constant(left.input[1]).tolist()
        target_right = g.get_computed_constant(right.input[1]).tolist()
        if target_left != target_right:
            return self.none(node, inspect.currentframe().f_lineno)

        shape1 = g.get_shape(left.input[0]) if g.has_shape(left.input[0]) else None
        shape2 = g.get_shape(right.input[0]) if g.has_shape(right.input[0]) else None
        if shape1 is None or shape2 is None or shape1 != shape2:
            return self.none(node, inspect.currentframe().f_lineno)

        # If there is no broadcast involved then it is ok.
        # At this stage, we know shapes are equal before the reshaped operators
        # and the same reshape is applied. So checking the output shape
        # is not necessary.
        return MatchResult(self, [left, right, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        left: NodeProto,
        right: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the operator applied before reshaping and the Reshape after it.

        :param g: graph builder used to create the nodes
        :param left: Reshape producing the first input
        :param right: Reshape producing the second input
        :param node: the binary operator
        :return: the operator followed by the Reshape
        """
        node_name = f"{self.__class__.__name__}--{node.name}"
        operands = [left.input[0], right.input[0]]
        new_node = g.make_node(node.op_type, operands, name=node_name)
        reshape_node = g.make_node(
            "Reshape",
            [new_node.output[0], left.input[1]],
            node.output,
            name=node_name,
            doc_string=node.doc_string,
        )
        return [new_node, reshape_node]
class ConcatReshapePattern(PatternOptimization):
    """
    Tries to reduce the number of nodes in the sequence Concat + Reshape
    by replacing one of the dimensions by -1.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape whose shape comes from a Concat
        where one input can be replaced by -1.

        The non constant inputs of the Concat must all be produced by
        Shape nodes, or all but one of them.

        :param g: graph builder used to query constants and producers
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat":
            return self.none(node, inspect.currentframe().f_lineno)

        # Counts the producers of the non constant inputs by type.
        op_types = {}
        for i in gen.input:
            if g.is_constant(i):
                cst = g.get_computed_constant(i)
                if cst is None:
                    return self.none(node, inspect.currentframe().f_lineno)
                li = cst.tolist()
                if -1 in li:
                    # -1 is already used.
                    return self.none(node, inspect.currentframe().f_lineno)
            else:
                p = g.node_before(i)
                if p is None:
                    return self.none(node, inspect.currentframe().f_lineno)
                op_types[p.op_type] = op_types.get(p.op_type, 0) + 1

        if len(op_types) == 1:
            # only one operator
            op_type = next(iter(op_types))
            if op_type != "Shape":
                return self.none(node, inspect.currentframe().f_lineno)
            # Then we can replace any of the node by -1.
        elif len(op_types) == 2:
            if "Shape" not in op_types:
                return self.none(node, inspect.currentframe().f_lineno)
            total = sum(op_types.values())
            if op_types["Shape"] != total - 1:
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            # Either every input is a constant and there is nothing to
            # replace, or at least two inputs are not produced by a Shape
            # and method apply would insert more than one -1.
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_used_more_than_once(node.input[1]):
            # Not really safe to do the replacement.
            return MatchResult(self, [gen, node], self.apply)
        return MatchResult(self, [gen, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        concat: NodeProto,
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the new Concat using -1 and the Reshape consuming it.

        :param g: graph builder used to create nodes and initializers
        :param concat: the Concat producing the shape
        :param reshape: the Reshape
        :return: the new nodes, preceded by the former Concat
            if its output is used somewhere else
        """
        m1 = g.make_initializer(
            "",
            np.array([-1], dtype=np.int64),
            source="ConcatReshapePattern.m1",
        )
        inputs = []
        done = False
        last_shape = -1
        for i in concat.input:
            if g.is_constant(i):
                inputs.append(i)
                continue
            p = g.node_before(i)
            if p is None:
                inputs.append(i)
                continue
            if p.op_type != "Shape":
                # The input not coming from a Shape becomes -1.
                inputs.append(m1)
                done = True
                continue
            last_shape = len(inputs)
            inputs.append(i)
        if not done:
            # only shape: the last one becomes -1
            assert last_shape != -1, f"last_shape={last_shape} but done={done}, unexpected"
            inputs[last_shape] = m1

        keep_concat = g.is_used_more_than_once(concat.output[0])
        new_output = g.unique_name(f"{concat.output[0]}--concat")
        res = [
            g.make_node(
                "Concat",
                inputs,
                [new_output],
                name=f"{self.__class__.__name__}--{concat.name}",
                doc_string=concat.doc_string,
                axis=0,
            ),
            g.make_node(
                "Reshape",
                [reshape.input[0], new_output],
                reshape.output,
                name=f"{self.__class__.__name__}--{reshape.name}",
                # the documentation of the node it replaces
                doc_string=reshape.doc_string,
            ),
        ]
        if keep_concat:
            return [concat, *res]
        return res
class StaticConcatReshapePattern(PatternOptimization):
    """
    Simplifies the sequence Concat + Reshape when every input of the Concat
    is a constant except one of shape ``(1,)``. This input is replaced
    by -1.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape whose shape comes from a Concat
        with exactly one non constant input, of shape ``(1,)``.

        :param g: graph builder used to query constants and shapes
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat":
            return self.none(node, inspect.currentframe().f_lineno)

        not_cst = []
        for name in gen.input:
            if not g.is_constant(name):
                if g.has_shape(name) and g.get_shape(name) == (1,):
                    not_cst.append(name)
                    continue
                return self.none(node, inspect.currentframe().f_lineno)
            cst = g.get_computed_constant(name)
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if -1 in cst.tolist():
                return self.none(node, inspect.currentframe().f_lineno)

        if len(not_cst) != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_used_more_than_once(node.input[1]):
            # Not really safe to do the replacement.
            return MatchResult(self, [gen, node], self.apply)
        return MatchResult(self, [gen, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        concat: NodeProto,
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the new Concat using -1 and the Reshape consuming it.

        :param g: graph builder used to create nodes and initializers
        :param concat: the Concat producing the shape
        :param reshape: the Reshape
        :return: the new nodes, preceded by the former Concat
            if its output is used somewhere else
        """
        m1 = g.make_initializer(
            "",
            np.array([-1], dtype=np.int64),
            source="ConcatReshapePattern.m1",
        )
        inputs = []
        done = False
        for i in concat.input:
            if g.is_constant(i):
                inputs.append(i)
            elif g.has_shape(i) and g.get_shape(i) == (1,):
                assert not done, f"-1 was already added, input {i!r} cannot be replaced by -1."
                inputs.append(m1)
                done = True
            else:
                raise RuntimeError(
                    f"The pattern was allowed but input {i!r} "
                    f"is not a constant and its shape is not (1,)."
                )
        assert (
            done
        ), f"-1 was not inserted, pattern {self.__class__.__name__} should not have matched."

        keep_concat = g.is_used_more_than_once(concat.output[0])
        new_output = g.unique_name(f"{concat.output[0]}--concat")
        new_concat = g.make_node(
            "Concat",
            inputs,
            [new_output],
            name=f"{self.__class__.__name__}--{concat.name}",
            doc_string=concat.doc_string,
            axis=0,
        )
        new_reshape = g.make_node(
            "Reshape",
            [reshape.input[0], new_output],
            reshape.output,
            name=f"{self.__class__.__name__}--{reshape.name}",
            doc_string=concat.doc_string,
        )
        if keep_concat:
            return [concat, new_concat, new_reshape]
        return [new_concat, new_reshape]
class ShapeBasedEditDistanceReshapePattern(PatternOptimization):
    """
    Tries to reduce the number of nodes in the sequence Concat + Reshape
    by replacing one of the dimensions by -1 or 0.
    The pattern tries to align shape information to infer a static shape.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    @classmethod
    def _prod(cls, sequence):
        """
        Returns the product of the dimensions in *sequence*
        or None if one of them is not an integer.
        """
        p = 1
        for s in sequence:
            if not isinstance(s, int):
                return None
            p *= s
        return p

    @classmethod
    def _align_shapes(
        cls, s1: DYNAMIC_SHAPE, s2: Tuple[Union[str, int]]
    ) -> Optional[Tuple[int, ...]]:
        """
        Computes the edit distance (Levenshtein distance) between two shapes
        and tries to align them in order to return a reshape argument
        with only integers.

        :param s1: input shape, integers or names, no -1
        :param s2: output shape, integers or names, no -1
        :return: the shape to give to Reshape, it may contain 0 and
            at most one -1, or None if the shapes cannot be aligned
        """
        assert all(
            isinstance(s, (int, str)) and s != -1 for s in s1
        ), f"Unsupported shape s1={s1}"
        assert all(
            isinstance(s, (int, str)) and s != -1 for s in s2
        ), f"Unsupported shape s2={s2}"
        if len(s1) == 0 or len(s2) == 0:
            # Nothing to align, the backtracking below has no starting
            # point when both shapes are empty.
            return None

        eps = 0.5
        # mat[i, j] is the cost of aligning s1[:i] with s2[:j],
        # predecessor[i, j] stores (di, dj, i - di, j - dj), the sizes of
        # the last aligned groups and the cell they come from.
        mat = np.full((len(s1) + 1, len(s2) + 1), max(len(s1), len(s2)) + 10, dtype=np.float32)
        mat[0, 0] = 0
        predecessor = {}
        for i in range(1, len(s1) + 1):
            for j in range(1, len(s2) + 1):
                # One dimension aligned with one dimension.
                c_cmp = mat[i - 1, j - 1] + (
                    0
                    if s1[i - 1] == s2[j - 1]
                    else (1 if isinstance(s1[i - 1], int) and isinstance(s2[j - 1], int) else eps)
                )
                options = [(c_cmp, (1, 1, i - 1, j - 1))]
                # Groups of 1 to 4 consecutive dimensions on both sides.
                for ki in range(1, 5):
                    if i < ki:
                        break
                    ss1 = s1[i - ki : i]
                    vi = cls._prod(ss1)
                    for kj in range(1, 5):
                        if kj == 1 and ki == 1:
                            # already handled by c_cmp
                            continue
                        if i - ki == 0 and j - kj != 0:
                            continue
                        if i - ki != 0 and j - kj == 0:
                            continue
                        if j < kj:
                            break
                        ss2 = s2[j - kj : j]
                        vj = cls._prod(ss2)
                        if vi is None or vj is None:
                            # A name is involved, it is allowed with a small
                            # cost if every group has at most one name.
                            c1 = sum(isinstance(_, str) for _ in ss1)
                            c2 = sum(isinstance(_, str) for _ in ss2)
                            if c1 <= 1 and c2 <= 1:
                                options.append(
                                    (mat[i - ki, j - kj] + eps, (ki, kj, i - ki, j - kj))
                                )
                        elif vi == vj:
                            options.append((mat[i - ki, j - kj], (ki, kj, i - ki, j - kj)))
                mini = min(options)
                mat[i, j], predecessor[i, j] = mini

        # computed
        if mat[len(s1), len(s2)] >= 1:
            # No possible equivalence.
            return None

        # Retrieves the groups from the last cell to the first one.
        last = predecessor[len(s1), len(s2)]
        path = []
        while last[2:] in predecessor:
            path.append(last)
            last = predecessor[last[2:]]
        path.append(last)

        new_shape = []
        mone = 0
        for di, dj, pi, pj in reversed(path):
            sh1, sh2 = s1[pi : pi + di], s2[pj : pj + dj]
            if all(isinstance(_, int) for _ in sh2):
                new_shape.extend(sh2)
            elif all(isinstance(_, int) for _ in sh1):
                if len(sh2) == 1:
                    new_shape.append(cls._prod(sh1))
                else:
                    return None
            elif len(sh1) == len(sh2) == 1:
                # They are equal and both strings
                if pi == pj and s1[pi] == s2[pj]:
                    new_shape.append(0)
                else:
                    new_shape.append(-1)
                    mone += 1
            else:
                for d in sh2:
                    if isinstance(d, str):
                        new_shape.append(-1)
                        mone += 1
                    else:
                        new_shape.append(d)
        if mone > 1:
            # Only one dimension can be inferred.
            return None
        assert (
            None not in new_shape
        ), f"Unexpected inputs: new_shape={new_shape}, shape1={s1}, shape2={s2}"
        return tuple(new_shape)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape whose shape comes from a Concat
        and can be replaced by a constant shape.

        :param g: graph builder used to query shapes and producers
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        # The cheapest test comes first, the alignment is only computed
        # if the shape is produced by a Concat.
        gen = g.node_before(node.input[1])
        if gen is None or gen.op_type != "Concat":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        sh1 = g.get_shape_renamed(node.input[0])
        sh2 = g.get_shape_renamed(node.output[0])
        aligned_reshape = self._align_shapes(sh1, sh2)
        if aligned_reshape is None:
            return self.none(node, inspect.currentframe().f_lineno)
        assert len(aligned_reshape) == g.get_rank(node.output[0]), (
            f"Issue with input shape {sh1}, output shape={sh2}, "
            f"proposed new_shape {aligned_reshape}"
        )
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the Reshape using a constant shape.

        :param g: graph builder used to create the node and the initializer
        :param reshape: the matched Reshape
        :return: the list holding the new Reshape
        """
        aligned_reshape = self._align_shapes(
            g.get_shape_renamed(reshape.input[0]), g.get_shape_renamed(reshape.output[0])
        )
        new_shape = g.make_initializer(
            "",
            np.array(aligned_reshape, dtype=np.int64),
            source="EditDistanceReshapePattern.m1",
        )
        return [
            g.make_node(
                "Reshape",
                [reshape.input[0], new_shape],
                [reshape.output[0]],
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            )
        ]
class ShapeBasedReshapeIsSqueezePattern(PatternOptimization):
    """
    Replaces a Reshape by a Squeeze or an Unsqueeze if possible.
    The pattern is only applied when the main opset is at least 18.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    @classmethod
    def _squeeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: Tuple[Union[str, int]]
    ) -> Tuple[Optional[str], Optional[Tuple[int, ...]]]:
        """
        Tells which operator changes shape *s1* into shape *s2*.

        :param s1: input shape
        :param s2: output shape
        :return: ``("Squeeze", axes)`` or ``("Unsqueeze", axes)``,
            ``(None, None)`` if none of them applies or if both shapes
            are equal
        """
        if s1 == s2:
            return None, None
        # Both shapes must be equal once the dimensions equal to 1 are removed.
        sh1 = tuple(s for s in s1 if s != 1)
        sh2 = tuple(s for s in s2 if s != 1)
        if sh1 != sh2:
            return None, None
        if len(s1) < len(s2):
            op_type = "Unsqueeze"
            axes = cls._find_unsqueeze_axes(s1, s2)
        else:
            op_type = "Squeeze"
            axes = cls._find_squeeze_axes(s1, s2)
        if axes is None:
            return None, None
        return op_type, axes

    @classmethod
    def _find_squeeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: DYNAMIC_SHAPE
    ) -> Optional[Tuple[int, ...]]:
        """
        Returns the positions of the dimensions equal to 1 in *s1*
        if removing all of them gives *s2*, None otherwise.
        """
        sh1 = tuple(s for s in s1 if s != 1)
        if sh1 != s2:
            return None
        return tuple(i for i, s in enumerate(s1) if s == 1)

    @classmethod
    def _find_unsqueeze_axes(
        cls, s1: DYNAMIC_SHAPE, s2: DYNAMIC_SHAPE
    ) -> Optional[Tuple[int, ...]]:
        """
        Returns the positions of the dimensions equal to 1 in *s2*
        if removing all of them gives *s1*, None otherwise.
        """
        return cls._find_squeeze_axes(s2, s1)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a Reshape equivalent to a Squeeze or an Unsqueeze.

        :param g: graph builder used to query the opset and the shapes
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` or the result of ``self.none``
        """
        if g.main_opset < 18:
            return self.none()
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        sh1 = g.get_shape_renamed(node.input[0])
        sh2 = g.get_shape_renamed(node.output[0])
        op_type, _axes = self._squeeze_axes(sh1, sh2)
        if op_type is None:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        reshape: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the Squeeze or Unsqueeze node replacing the Reshape.
        The axes are given as a second input.

        :param g: graph builder used to create the node and the initializer
        :param reshape: the matched Reshape
        :return: the list holding the new node
        """
        op_type, axes = self._squeeze_axes(
            g.get_shape_renamed(reshape.input[0]), g.get_shape_renamed(reshape.output[0])
        )
        new_axes = g.make_initializer(
            "",
            np.array(axes, dtype=np.int64),
            source="ReshapeIsSqueezePattern.m1",
        )
        return [
            g.make_node(
                op_type,
                [reshape.input[0], new_axes],
                reshape.output,
                name=f"{self.__class__.__name__}--{reshape.name}",
                doc_string=reshape.doc_string,
            )
        ]