Source code for experimental_experiment.xoptim.patterns.onnx_mul

import inspect
from enum import IntEnum
from typing import List, Optional
import numpy as np
from onnx import NodeProto
from ...xbuilder._shape_helper import DYNAMIC_SHAPE
from ..patterns_api import MatchResult, PatternOptimization


class MulMulMulScalarPattern(PatternOptimization):
    """
    Replaces ``(X op1 a) op (Y op2 b)`` by ``(X op Y) * c`` where ``op``,
    ``op1`` and ``op2`` are each ``Mul`` or ``Div``, ``a`` and ``b`` are
    scalar constants (shape ``()`` or ``(1,)``) and ``c`` is a new constant
    folded from ``a`` and ``b``. Three nodes become two.
    """

    @staticmethod
    def _is_integral(cst) -> bool:
        """
        Tells if a constant holds integer or boolean values.
        Only numpy dtypes are inspected, other constant types are not
        considered integral.
        """
        dtype = getattr(cst, "dtype", None)
        return isinstance(dtype, np.dtype) and dtype.kind in "iub"

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in {"Div", "Mul"} or node.domain != "":
            return self.none()
        # Both intermediate results are removed, nothing else may consume them.
        if g.is_used_more_than_once(node.input[0]) or g.is_used_more_than_once(
            node.input[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        node_left = g.node_before(node.input[0])
        if (
            node_left is None
            or node_left.op_type not in {"Div", "Mul"}
            or node_left.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        node_right = g.node_before(node.input[1])
        if (
            node_right is None
            or node_right.op_type not in {"Div", "Mul"}
            or node_right.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # checking for the constant (right input of both producers)
        if not g.is_constant(node_left.input[1]) or not g.is_constant(
            node_right.input[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        cst_left = g.get_computed_constant(node_left.input[1])
        cst_right = g.get_computed_constant(node_right.input[1])
        if cst_left.shape not in {tuple(), (1,)} or cst_right.shape not in {
            tuple(),
            (1,),
        }:
            return self.none(node, inspect.currentframe().f_lineno)

        # Turning a Div into a multiplication by a reciprocal is only exact
        # for floating types: an integer Div truncates and np.reciprocal of
        # an integer truncates to 0 for any value other than +/-1.
        has_div = "Div" in {node.op_type, node_left.op_type, node_right.op_type}
        if has_div and (self._is_integral(cst_left) or self._is_integral(cst_right)):
            return self.none(node, inspect.currentframe().f_lineno)

        nodes = [node, node_left, node_right]
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        node_left: NodeProto,
        node_right: NodeProto,
    ) -> List[NodeProto]:
        new_node = g.make_node(
            node.op_type,
            [node_left.input[0], node_right.input[0]],
            [g.unique_name(f"{self.__class__.__name__}--{node.output[0]}")],
            name=f"{self.__class__.__name__}--{node.name}",
        )

        cst_left = g.get_computed_constant(node_left.input[1])
        cst_right = g.get_computed_constant(node_right.input[1])
        # X / a == X * (1 / a)
        if node_left.op_type == "Div":
            cst_left = np.reciprocal(cst_left)
        # The right constant ends up in the denominator when exactly one of
        # the right producer and the outer node is a Div:
        #   (X * a) / (Y * b) == (X / Y) * (a / b)
        #   (X * a) / (Y / b) == (X / Y) * (a * b)
        #   (X * a) * (Y / b) == (X * Y) * (a / b)
        if (node_right.op_type == "Div") != (node.op_type == "Div"):
            cst_right = np.reciprocal(cst_right)

        if not isinstance(cst_left, np.ndarray):
            cst_left = np.array(cst_left)
        if not isinstance(cst_right, np.ndarray):
            cst_right = np.array(cst_right)
        assert (
            cst_left.dtype == cst_right.dtype
        ), f"Type mismatch left is {cst_left.dtype}, right is {cst_right.dtype}"
        new_value = cst_left * cst_right
        if not isinstance(new_value, np.ndarray):
            # the product of two 0-d arrays is a numpy scalar, not an array
            new_value = np.array(new_value)
        new_cst = g.make_initializer(
            "", new_value, source="MulMulMulScalarPattern.apply.new_cst"
        )
        new_node2 = g.make_node(
            "Mul",
            [new_node.output[0], new_cst],
            node.output,
            name=f"{self.__class__.__name__}--{node.name}-Cst",
        )
        return [new_node, new_node2]
class SwitchOrderBinaryPattern(PatternOptimization):
    """
    If it makes sense, switches the order of two multiplications or two
    additions when broadcasting lets the inner operator work on a much
    smaller tensor, e.g. ``(B + C) + A`` becomes ``(B + A) + C``.
    """

    class BroadcastType(IntEnum):
        """
        Kind of broadcast.
        """

        FALSE = 0
        TRUE = 1
        MAYBE = 2
        BOTH = 3

    @staticmethod
    def _is_candidate(g, op_type: str, other_node: Optional[NodeProto]) -> bool:
        """
        Tells if ``other_node`` can be merged into the reordering: same
        operator type in the main domain, known input shapes and an output
        consumed only by the current node.
        """
        return (
            other_node is not None
            and other_node.op_type == op_type
            and other_node.domain == ""
            and g.has_shape(other_node.input[0])
            and g.has_shape(other_node.input[1])
            and not g.is_used_more_than_once(other_node.output[0])
        )

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in {"Add", "Mul"} or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]) or not g.has_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        op_type = node.op_type
        left = g.node_before(node.input[0])
        right = g.node_before(node.input[1])
        if op_type not in {
            getattr(left, "op_type", None),
            getattr(right, "op_type", None),
        }:
            return self.none()

        shape_left = g.get_shape(node.input[0])
        shape_right = g.get_shape(node.input[1])

        # side 0: the left input is produced by the same operator,
        # side 1: the right input is. The left side is tried first.
        # ``switch_order`` is called with the same side as in ``apply``
        # so that both agree on the chosen case.
        for side, other_node in enumerate((left, right)):
            if not self._is_candidate(g, op_type, other_node):
                continue
            before_left = g.get_shape(other_node.input[0])
            before_right = g.get_shape(other_node.input[1])
            if (
                self.switch_order(
                    shape_left, shape_right, before_left, before_right, side
                )
                == 0
            ):
                continue
            nodes = [
                node,
                other_node if side == 0 else None,
                other_node if side == 1 else None,
            ]
            return MatchResult(self, nodes, self.apply, insert_at=node)

        return self.none(node, inspect.currentframe().f_lineno)

    def _align_shape(self, sh: DYNAMIC_SHAPE, rk: int) -> DYNAMIC_SHAPE:
        """
        Aligns shapes to the same size by prepending dimensions of size 1
        (numpy-style broadcasting).
        """
        sh = tuple(sh)
        if len(sh) >= rk:
            return sh
        return (1,) * (rk - len(sh)) + sh

    @staticmethod
    def _strictly_smallest(cases) -> Optional[int]:
        """
        Returns the index of the value strictly smaller than all the others,
        None if there is no such value (tie).
        """
        for i, value in enumerate(cases):
            if all(value < other for j, other in enumerate(cases) if j != i):
                return i
        return None

    def switch_order(
        self,
        shape_left: DYNAMIC_SHAPE,
        shape_right: DYNAMIC_SHAPE,
        shape_before_left: DYNAMIC_SHAPE,
        shape_before_right: DYNAMIC_SHAPE,
        side: int,
    ) -> int:
        """
        Tells if the order should be switched and how.

        ``side == 0`` indicates ``shape_left`` is the output of
        ``Op(shape_before_left, shape_before_right)``, ``side == 1`` indicates
        ``shape_right`` is. With ``B``, ``C`` the inputs of that inner
        operator and ``A`` the other input of the outer one:

        * Case 0: (B + C) + A:
          ``Op(Op(shape_before_left, shape_before_right), A)``, current order
        * Case 1: (B + A) + C:
          ``Op(Op(shape_before_left, A), shape_before_right)``
        * Case 2: (A + C) + B:
          ``Op(Op(A, shape_before_right), shape_before_left)``

        The function returns the case whose inner operator produces the
        smallest tensor, compared by rank first, then dimension by dimension
        when all three dimensions are static. It returns 0 (no change)
        when nothing can be decided.
        """
        if side == 1:
            return self.switch_order(
                shape_right, shape_left, shape_before_left, shape_before_right, 0
            )

        r_left = len(shape_left)
        r_right = len(shape_right)
        r_b_left = len(shape_before_left)
        r_b_right = len(shape_before_right)
        rk = max(r_left, r_right, r_b_left, r_b_right)
        assert max(r_left, r_right) == rk, (
            f"Inconsistencies with shapes (side={side}) shape_left={shape_left}, "
            f"shape_right={shape_right}, shape_before_left={shape_before_left}, "
            f"shape_before_right={shape_before_right}"
        )

        # rank of the intermediate result for each case
        rank_cases = [
            max(r_b_left, r_b_right),
            max(r_right, r_b_left),
            max(r_right, r_b_right),
        ]
        case = self._strictly_smallest(rank_cases)
        if case is not None:
            return case

        # Ranks cannot be used to determine if switch is recommended.
        rk = max(rank_cases)
        shape_right = self._align_shape(shape_right, rk)
        shape_before_left = self._align_shape(shape_before_left, rk)
        shape_before_right = self._align_shape(shape_before_right, rk)
        for b, c, a in zip(shape_before_left, shape_before_right, shape_right):
            if b == c == a:
                continue
            if isinstance(a, int) and isinstance(b, int) and isinstance(c, int):
                case = self._strictly_smallest([max(b, c), max(b, a), max(a, c)])
                if case is not None:
                    return case
            # Dynamic shapes are not implemented yet but they should
            # be handled here.

        # No change.
        return 0

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        node_left: NodeProto,
        node_right: NodeProto,
    ) -> List[NodeProto]:
        side = 1 if node_left is None else 0
        other_node = node_right if node_left is None else node_left
        assert (
            other_node.op_type == node.op_type
        ), f"Type mismatch {node.op_type} != {other_node.op_type}"
        shape_left = g.get_shape(node.input[0])
        shape_right = g.get_shape(node.input[1])
        before_left = g.get_shape(other_node.input[0])
        before_right = g.get_shape(other_node.input[1])
        case = self.switch_order(shape_left, shape_right, before_left, before_right, side)
        assert case in (1, 2), (
            f"case={case}, the matching should not have happened "
            f"(side={side}) shape_left={shape_left}, "
            f"shape_right={shape_right}, before_left={before_left}, "
            f"before_right={before_right}"
        )

        # (B op C) op A, A being the input of ``node`` not produced by other_node
        b_name, c_name = other_node.input[0], other_node.input[1]
        a_name = node.input[1] if side == 0 else node.input[0]
        # Case 1: (B op A) op C, Case 2: (C op A) op B
        first, last = (b_name, c_name) if case == 1 else (c_name, b_name)

        name = f"{self.__class__.__name__}--{node.name}"
        op1 = g.make_node(node.op_type, [first, a_name], name=name)
        op2 = g.make_node(
            node.op_type,
            [op1.output[0], last],
            [node.output[0]],
            doc_string=node.doc_string,
            name=name,
        )
        return [op1, op2]