# Source code for experimental_experiment.xoptim.patterns.onnx_reshape

import inspect
from typing import List, Optional
from onnx import NodeProto
from ...xbuilder._onnx_helper import element_wise_binary_op_types
from ...xbuilder._shape_helper import all_int
from ..patterns_api import MatchResult, PatternOptimization


class ReshapePattern(PatternOptimization):
    """
    Checks that a Reshape is really needed.

    A standard-domain Reshape whose constant target shape is equal to the
    fully static shape of its input does nothing and is replaced by an
    Identity node.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a Reshape (main ONNX domain) when its input shape is made of
        integers only and equals the constant requested shape.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` or the value returned by ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape = g.get_shape(node.input[0])
        if not all_int(shape):
            # Symbolic dimensions cannot be compared with a constant shape.
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(node.input[1]):
            # It may be a symbolic shape.
            return self.none(node, inspect.currentframe().f_lineno)
        value = g.get_computed_constant(node.input[1])
        if value is None:
            return self.none(node, inspect.currentframe().f_lineno)
        new_shape = tuple(int(i) for i in value)
        if shape != new_shape:
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces the useless Reshape by an Identity.

        Only the data input is forwarded: Identity takes exactly one input,
        the shape tensor (``node.input[1]``) must not be passed to it.
        """
        new_node = g.make_node(
            "Identity",
            [node.input[0]],
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]


class ReduceReshapePattern(PatternOptimization):
    """
    Replaces the sequence Reduce* Reshape if the reshape is only introduced
    to deal with a dimension kept because keepdims=1.
    The pair is replaced by the same Reduce* operator with keepdims=0.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a Reduce* node (keepdims=1) consumed only by a Reshape which
        removes exactly the reduced dimensions.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` or the value returned by ``self.none``
        """
        if not node.op_type.startswith("Reduce") or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        att = g.get_attribute(node, "keepdims", exc=False)
        keepdims = 1 if att is None else att.i
        if keepdims == 0:
            # not keeping the dimension so Reshape means to restore them.
            return self.none(node, inspect.currentframe().f_lineno)

        if len(node.input) == 2:
            # axes given as an input
            if not g.is_constant(node.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            axes = tuple(int(a) for a in cst)
        else:
            # axes given as an attribute (older opsets), all axes if missing
            att = g.get_attribute(node, "axes", exc=False)
            axes = tuple(range(g.get_rank(node.input[0]))) if att is None else tuple(att.ints)
        if not axes:
            # Empty axes means "reduce everything" or "no-op"
            # (noop_with_empty_axes): switching to keepdims=0 could change
            # the output rank, the rewriting is not safe.
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)

        rank = g.get_rank(node.input[0])
        out_rank = g.get_rank(next_node.output[0])
        if rank != out_rank + len(axes):
            return self.none(node, inspect.currentframe().f_lineno)
        if out_rank > 1:
            # With more than one remaining dimension, the Reshape could also
            # permute dimension sizes, so the shapes must be compared.
            if not g.has_shape(node.input[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            # Negative axes are normalized to compare them with positions.
            set_axes = {a + rank if a < 0 else a for a in axes}
            shape = g.get_shape(node.input[0])
            reduced_shape = tuple(s for i, s in enumerate(shape) if i not in set_axes)
            reshaped_shape = tuple(g.get_shape(next_node.output[0]))
            if reduced_shape != reshaped_shape:
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, next_node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """
        Creates the same Reduce* operator with keepdims=0 producing
        the output of the removed Reshape.
        """
        axes = g.get_attribute(node, "axes", exc=False)
        if axes is None:
            # axes (if any) is an input, node.input is forwarded as is
            new_node = g.make_node(
                node.op_type,
                node.input,
                next_node.output,
                keepdims=0,
                name=f"{self.__class__.__name__}--{node.name}",
                doc_string=node.doc_string,
            )
            return [new_node]

        # older opset
        new_node = g.make_node(
            node.op_type,
            node.input,
            next_node.output,
            keepdims=0,
            axes=list(axes.ints),
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=node.doc_string,
        )
        return [new_node]


class ReshapeReshapePattern(PatternOptimization):
    """
    Replaces the sequence Reshape, Reshape by Reshape.
    ``Reshape(Reshape(x, s1), s2)`` becomes ``Reshape(x, s2)``.
    """

    def __init__(self, verbose: int = 0, priority: int = 0):
        super().__init__(verbose, priority)

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a Reshape (main domain) whose only consumer is another
        Reshape (main domain) using it as data input.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` or the value returned by ``self.none``
        """
        if node.op_type != "Reshape" or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if next_node.input[0] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_constant(node.input[1]):
            cst = g.get_computed_constant(node.input[1])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if -1 in cst.tolist():
                # Then we only allow it the shape is static.
                if not g.is_constant(next_node.input[1]):
                    return self.none(node, inspect.currentframe().f_lineno)
                cst = g.get_computed_constant(next_node.input[1])
                if cst is None or cst.min() <= 0:
                    return self.none(node, inspect.currentframe().f_lineno)

        if g.is_constant(next_node.input[1]):
            second = g.get_computed_constant(next_node.input[1])
            if second is not None and 0 in second.tolist():
                # A 0 in the shape copies the dimension of the Reshape input,
                # which is the intermediate tensor here: once merged, it would
                # copy a dimension of x instead.
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, next_node], self.apply, insert_at=next_node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        """Creates a single Reshape from the first input to the second shape."""
        new_node = g.make_node(
            "Reshape",
            [node.input[0], next_node.input[1]],
            next_node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=next_node.doc_string,
        )
        return [new_node]


class Reshape2Of3Pattern(PatternOptimization):
    """
    Replaces the reshapes around element-wise operators.
    It can be 3 or 2 out of 3: when at least two of the tensors around an
    element-wise binary operator (left input, right input, output) go through
    a Reshape and the unreshaped tensors all share the same shape, the
    operator is computed on the unreshaped tensors and a Reshape is added
    on the remaining side if needed.
    """

    _op_types = element_wise_binary_op_types()

    @staticmethod
    def _standard_reshape(candidate: Optional[NodeProto]) -> Optional[NodeProto]:
        """Returns *candidate* if it is a Reshape of the main domain, None otherwise."""
        if candidate is None or candidate.op_type != "Reshape" or candidate.domain != "":
            return None
        return candidate

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches an element-wise binary operator without broadcasting
        surrounded by at least two compatible Reshape nodes.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` or the value returned by ``self.none``
        """
        if node.op_type not in self._op_types or node.domain != "":
            return self.none()
        if (
            not g.has_shape(node.output[0])
            or not g.has_shape(node.input[0])
            or not g.has_shape(node.input[1])
        ):
            # Shapes are missing. They should be populated as much as possible.
            return self.none(node, inspect.currentframe().f_lineno)
        shape_out = g.get_shape(node.output[0])
        shape_in = g.get_shape(node.input[0]), g.get_shape(node.input[1])
        if not (shape_out == shape_in[0] == shape_in[1]):
            # Broadcasting is involved.
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) > 1 or (len(next_nodes) == 0 and not g.is_output(node.output[0])):
            return self.none(node, inspect.currentframe().f_lineno)

        # Only Reshape nodes of the main domain are considered,
        # any other neighbour is ignored (set to None).
        node_left = self._standard_reshape(g.node_before(node.input[0]))
        node_right = self._standard_reshape(g.node_before(node.input[1]))
        next_node = self._standard_reshape(next_nodes[0]) if next_nodes else None
        if next_node is not None and next_node.input[0] != node.output[0]:
            # The output is used as the shape of that Reshape, not its data.
            next_node = None

        n_reshape = sum(n is not None for n in (node_left, node_right, next_node))
        if n_reshape < 2:
            return self.none(node, inspect.currentframe().f_lineno)

        # Shapes of the tensors before the input reshapes / after the output reshape.
        names = []
        if node_left is not None:
            names.append(node_left.input[0])
        if node_right is not None:
            names.append(node_right.input[0])
        if next_node is not None:
            names.append(next_node.output[0])
        if not all(g.has_shape(name) for name in names):
            # An unknown shape means the rewriting cannot be proved correct.
            return self.none(node, inspect.currentframe().f_lineno)
        if len({g.get_shape(name) for name in names}) != 1:
            # Not the same shapes.
            return self.none(node, inspect.currentframe().f_lineno)

        nodes = [node_left, node_right, next_node, node]
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_left: NodeProto,
        node_right: NodeProto,
        next_node: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Rewrites the operator on the unreshaped tensors. Any of
        ``node_left``, ``node_right``, ``next_node`` may be None (at most one).
        """
        # Shape name the input reshapes produce (the operator's current shape).
        compute_shape_name = node_left.input[1] if node_right is None else node_right.input[1]
        # Shape name used for the reshape added on the missing side.
        final_shape_name = compute_shape_name if next_node is None else next_node.input[1]
        res = []

        # node left
        if node_left is None:
            left_name = g.unique_name(f"{self.__class__.__name__}L_{node.input[0]}")
            res.append(
                g.make_node(
                    "Reshape",
                    [node.input[0], final_shape_name],
                    [left_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
        elif g.is_used_more_than_once(node_left.output[0]):
            # The reshape is still needed by other consumers.
            res.append(node_left)
            left_name = node_left.input[0]
        else:
            left_name = node_left.input[0]

        # node right
        if node_right is None:
            right_name = g.unique_name(f"{self.__class__.__name__}R_{node.input[1]}")
            res.append(
                g.make_node(
                    "Reshape",
                    [node.input[1], final_shape_name],
                    [right_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                )
            )
        elif g.is_used_more_than_once(node_right.output[0]):
            # The reshape is still needed by other consumers.
            res.append(node_right)
            right_name = node_right.input[0]
        else:
            right_name = node_right.input[0]

        # node and next node
        if next_node is None:
            # Reshape is needed.
            new_name = g.unique_name(f"{self.__class__.__name__}L_{node.output[0]}")
            res.extend(
                [
                    g.make_node(
                        node.op_type,
                        [left_name, right_name],
                        [new_name],
                        name=f"{self.__class__.__name__}--{node.name}",
                    ),
                    g.make_node(
                        "Reshape",
                        [new_name, final_shape_name],
                        [node.output[0]],
                        name=f"{self.__class__.__name__}--{node.name}",
                    ),
                ]
            )
        else:
            main_node = g.make_node(
                node.op_type,
                [left_name, right_name],
                [next_node.output[0]],
                name=f"{self.__class__.__name__}--{node.name}",
            )
            res.append(main_node)

            if g.is_used_more_than_once(node.output[0]):
                # The original output is still required elsewhere.
                res.append(
                    g.make_node(
                        "Reshape",
                        [main_node.output[0], compute_shape_name],
                        [node.output[0]],
                        name=f"{self.__class__.__name__}--{node.name}",
                    )
                )
        return res


class ReshapeReshapeBinaryPattern(PatternOptimization):
    """
    Moves two reshape operators beyond a binary operator if it is possible:
    ``Op(Reshape(a, s), Reshape(b, s))`` becomes ``Reshape(Op(a, b), s)``
    when ``a`` and ``b`` have the same known shape and both reshapes use
    the same constant shape.
    """

    _op_types = element_wise_binary_op_types()

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches an element-wise binary operator whose two inputs are produced
        by Reshape nodes used only once, sharing the same constant shape and
        applied on tensors of identical known shapes.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (unused here)
        :return: a :class:`MatchResult` or the value returned by ``self.none``
        """
        if node.domain != "" or node.op_type not in self._op_types:
            return self.none()
        first_input, second_input = node.input[0], node.input[1]
        if g.is_used_more_than_once(first_input) or g.is_used_more_than_once(second_input):
            return self.none(node, inspect.currentframe().f_lineno)

        left = g.node_before(first_input)
        right = g.node_before(second_input)
        for producer in (left, right):
            if producer is None or producer.op_type != "Reshape" or producer.domain != "":
                return self.none(node, inspect.currentframe().f_lineno)

        if not (g.is_constant(left.input[1]) and g.is_constant(right.input[1])):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        left_target = g.get_computed_constant(left.input[1]).tolist()
        right_target = g.get_computed_constant(right.input[1]).tolist()
        if left_target != right_target:
            return self.none(node, inspect.currentframe().f_lineno)

        source_shapes = []
        for producer in (left, right):
            data_name = producer.input[0]
            source_shapes.append(g.get_shape(data_name) if g.has_shape(data_name) else None)
        lhs_shape, rhs_shape = source_shapes
        if lhs_shape is None or rhs_shape is None or lhs_shape != rhs_shape:
            return self.none(node, inspect.currentframe().f_lineno)

        # No broadcasting is involved: the shapes before the reshapes are
        # equal and the same reshape is applied, so the output shape does not
        # need to be checked.
        return MatchResult(self, [left, right, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        left: NodeProto,
        right: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        """
        Computes the binary operator on the unreshaped inputs, then applies
        the (shared) reshape once on its result.
        """
        node_name = f"{self.__class__.__name__}--{node.name}"
        binary_node = g.make_node(
            node.op_type,
            [left.input[0], right.input[0]],
            name=node_name,
        )
        moved_reshape = g.make_node(
            "Reshape",
            [binary_node.output[0], left.input[1]],
            node.output,
            name=node_name,
            doc_string=node.doc_string,
        )
        return [binary_node, moved_reshape]