Source code for experimental_experiment.xoptim.patterns.onnx_matmul

import inspect
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np
from onnx import NodeProto
from ...xshape._shape_helper import (
    compatible_shapes,
    compatible_dimensions,
    is_static_shape,
    all_int,
)
from ..patterns_api import MatchResult, PatternOptimization


class MatMulAddPattern(PatternOptimization):
    """
    Replaces the sequence MatMul, Add by Gemm.
    By default, no reshape is allowed: the fusion only happens
    when the first input has two dimensions.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
        )
        inputs.append(
            oh.make_tensor_value_info("X1", onnx.TensorProto.FLOAT, shape=("a", "b", 3))
        )
        inputs.append(oh.make_tensor_value_info("X2", onnx.TensorProto.FLOAT, shape=(3, "d")))
        nodes.append(oh.make_node("MatMul", ["X1", "X2"], ["Y"]))
        nodes.append(oh.make_node("Add", ["Y", "B"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
        )
        inputs.append(
            oh.make_tensor_value_info("X1", onnx.TensorProto.FLOAT, shape=("a", "b", 3))
        )
        inputs.append(oh.make_tensor_value_info("X2", onnx.TensorProto.FLOAT, shape=(3, "d")))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s2_-1_3"],
                value=onh.from_array(np.array([-1, 3], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_-1"],
                value=onh.from_array(np.array([-1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node("Reshape", ["X1", "init7_s2_-1_3"], ["MatMulAddPattern--X1"])
        )
        nodes.append(oh.make_node("Shape", ["B"], ["MatMulAddPattern--X12"], start=-1))
        nodes.append(
            oh.make_node(
                "Concat",
                ["init7_s1_-1", "MatMulAddPattern--X12"],
                ["MatMulAddPattern--X13"],
                axis=0,
            )
        )
        nodes.append(
            oh.make_node("Reshape", ["B", "MatMulAddPattern--X13"], ["MatMulAddPattern--X14"])
        )
        nodes.append(
            oh.make_node("Shape", ["X1"], ["MatMulAddPattern--X16"], end=-1, start=0)
        )
        nodes.append(
            oh.make_node(
                "Concat",
                ["MatMulAddPattern--X16", "init7_s1_-1"],
                ["MatMulAddPattern--X17"],
                axis=0,
            )
        )
        nodes.append(
            oh.make_node(
                "Gemm",
                ["MatMulAddPattern--X1", "X2", "MatMulAddPattern--X14"],
                ["MatMulAddPattern--X15"],
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape", ["MatMulAddPattern--X15", "MatMulAddPattern--X17"], ["Z"]
            )
        )
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "d"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def __init__(self, verbose: int = 0, priority: int = 3, allow_reshape: bool = False):
        super().__init__(verbose, priority)
        self.allow_reshape = allow_reshape

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in {"MatMul", "Gemm"} or node.domain != "":
            return self.none()
        if not g.has_rank(node.input[0]) or not g.has_rank(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_rank(node.input[0]) < 2 or g.get_rank(node.input[1]) != 2:
            # If node.op_type is Gemm, this condition is useless.
            # If node.op_type is MatMul, we reshape the matrix if
            # its rank is > 2 and the last dimension is known,
            # but then no bias should be allowed as Gemm does not
            # support broadcast.
            return self.none(node, inspect.currentframe().f_lineno)
        if not self.allow_reshape and g.get_rank(node.input[0]) != 2:
            # No reshape is allowed.
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_rank(node.input[0]) > 2:
            sh1 = g.get_shape(node.input[0]) if g.has_shape(node.input[0]) else None
            sh2 = g.get_shape(node.input[1]) if g.has_shape(node.input[1]) else None
            if (sh1 is None or not isinstance(sh1[-1], int)) and (
                sh2 is None or not isinstance(sh2[0], int)
            ):
                # unknown k for the matrix multiplication
                return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        add_node = next_nodes[0]
        if add_node.op_type != "Add":
            return self.none(node, inspect.currentframe().f_lineno)

        # Gemm does not allow broadcasting.
        bias2 = add_node.input[0 if add_node.input[1] == node.output[0] else 1]
        if not g.has_shape(node.input[1]) or not g.has_shape(bias2):
            return self.none(node, inspect.currentframe().f_lineno)
        transB = (
            g.get_attributes_with_default(node, transB=0).get("transB", 0)
            if node.op_type == "Gemm"
            else 0
        )
        shape_2 = g.get_shape(node.input[1])
        last_dim = shape_2[-1 - transB]
        shape_bias = g.get_shape(bias2)
        if last_dim != shape_bias[-1]:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(shape_bias) > 1:
            shape_node_out = (
                g.get_shape(node.output[0]) if g.has_shape(node.output[0]) else None
            )
            if shape_node_out is not None:
                if len(shape_node_out) != len(shape_bias):
                    return self.none(node, inspect.currentframe().f_lineno)
                elif shape_node_out != shape_bias:
                    return self.none(node, inspect.currentframe().f_lineno)
            elif min(shape_bias[:-1]) <= 1:
                return self.none(node, inspect.currentframe().f_lineno)
        if add_node.input[0] == add_node.input[1]:
            return self.none(node, inspect.currentframe().f_lineno)

        if node.op_type == "MatMul" or len(node.input) == 2:
            return MatchResult(self, [node, add_node], self.apply, insert_at=add_node)

        bias = node.input[2]
        if (
            not g.has_shape(bias)
            or not g.has_shape(bias2)
            or g.get_shape(bias) != g.get_shape(bias2)
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node, add_node], self.apply, insert_at=add_node)

    def _apply_matmmul(
        self,
        g: "GraphBuilder",  # noqa: F821
        matmul_node: NodeProto,
        add_node: NodeProto,
    ) -> List[NodeProto]:
        bias2 = add_node.input[0 if add_node.input[1] == matmul_node.output[0] else 1]
        if g.get_rank(matmul_node.input[0]) > 2:
            rk_bias = g.get_rank(bias2)
            # get k
            sh1 = (
                g.get_shape(matmul_node.input[0])
                if g.has_shape(matmul_node.input[0])
                else None
            )
            sh2 = (
                g.get_shape(matmul_node.input[1])
                if g.has_shape(matmul_node.input[1])
                else None
            )
            k = sh1[-1] if sh1 is not None and isinstance(sh1[-1], int) else sh2[0]
            new_shape = g.make_initializer(
                "",
                np.array([-1, k], dtype=np.int64),
                source="MatMulAddPattern.new_shape.1",
            )
            reshaped = g.unique_name(f"{self.__class__.__name__}--{matmul_node.input[0]}")
            reshape_node = g.make_node(
                "Reshape",
                [matmul_node.input[0], new_shape],
                [reshaped],
                name=f"{self.__class__.__name__}--{matmul_node.name}",
            )
            reshape_nodes = [reshape_node]

            if rk_bias > 2:
                if g.has_shape(bias2) and isinstance(g.get_shape(bias2)[-1], int):
                    new_shape_bias = g.make_initializer(
                        "",
                        np.array([-1, g.get_shape(bias2)[-1]], dtype=np.int64),
                        source="MatMulAddPattern.new_shape.3",
                    )
                else:
                    that_shape_bias = g.unique_name(
                        f"{self.__class__.__name__}--{matmul_node.input[0]}"
                    )
                    reshape_nodes.append(
                        g.make_node(
                            "Shape",
                            [bias2],
                            [that_shape_bias],
                            start=-1,
                            name=f"{self.__class__.__name__}--{matmul_node.name}",
                        )
                    )
                    new_shape_bias = g.unique_name(
                        f"{self.__class__.__name__}--{matmul_node.input[0]}"
                    )
                    minus1 = g.make_initializer(
                        "",
                        g.MINUS_ONE,
                        source="MatMulAddPattern.new_shape.7",
                    )
                    reshape_nodes.append(
                        g.make_node(
                            "Concat",
                            [minus1, that_shape_bias],
                            [new_shape_bias],
                            axis=0,
                            name=f"{self.__class__.__name__}--{matmul_node.name}",
                        )
                    )
                reshaped_bias = g.unique_name(
                    f"{self.__class__.__name__}--{matmul_node.input[0]}"
                )
                reshape_bias_node = g.make_node(
                    "Reshape",
                    [bias2, new_shape_bias],
                    [reshaped_bias],
                    name=f"{self.__class__.__name__}--{matmul_node.name}",
                )
                reshape_nodes.append(reshape_bias_node)
                bias_gemm_name = reshaped_bias
            else:
                bias_gemm_name = bias2

            inputs = [reshaped, matmul_node.input[1]]
            unshaped = g.unique_name(f"{self.__class__.__name__}--{matmul_node.input[0]}")
            outputs = [unshaped]

            # last reshape
            if g.has_shape(matmul_node.input[0]) and all_int(
                g.get_shape(matmul_node.input[0])
            ):
                shape_back = g.make_initializer(
                    "",
                    np.array([*g.get_shape(matmul_node.input[0])[:-1], -1], dtype=np.int64),
                    source="MatMulAddPattern.new_shape.2",
                )
            else:
                # We extract the shape.
                that_shape = g.unique_name(
                    f"{self.__class__.__name__}--{matmul_node.input[0]}"
                )
                reshape_nodes.append(
                    g.make_node(
                        "Shape",
                        [matmul_node.input[0]],
                        [that_shape],
                        start=0,
                        end=-1,
                        name=f"{self.__class__.__name__}--{matmul_node.name}",
                    )
                )
                shape_back = g.unique_name(
                    f"{self.__class__.__name__}--{matmul_node.input[0]}"
                )
                minus1 = g.make_initializer(
                    "",
                    g.MINUS_ONE,
                    source="MatMulAddPattern.new_shape.3",
                )
                reshape_nodes.append(
                    g.make_node(
                        "Concat",
                        [that_shape, minus1],
                        [shape_back],
                        axis=0,
                        name=f"{self.__class__.__name__}--{matmul_node.name}",
                    )
                )
            reshape_back = g.make_node(
                "Reshape",
                [unshaped, shape_back],
                add_node.output,
                name=f"{self.__class__.__name__}--{matmul_node.name}",
            )
        else:
            inputs = matmul_node.input
            outputs = add_node.output
            reshape_node = None
            bias_gemm_name = bias2

        new_node = g.make_node(
            "Gemm",
            [*inputs, bias_gemm_name],
            outputs,
            name=f"{self.__class__.__name__}--{matmul_node.name}",
            doc_string=matmul_node.doc_string,
        )
        if matmul_node.op_type == "Gemm":
            new_node.attribute.extend(matmul_node.attribute)
        if reshape_node:
            return [*reshape_nodes, new_node, reshape_back]
        return [new_node]

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        matmul_node: NodeProto,
        add_node: NodeProto,
    ) -> List[NodeProto]:
        if matmul_node.op_type == "MatMul" or len(matmul_node.input) == 2:
            return self._apply_matmmul(g, matmul_node, add_node)

        bias2 = add_node.input[0 if add_node.input[1] == matmul_node.output[0] else 1]
        # There are two biases, they need to be added first.
        bias_all = g.unique_name(f"{self.__class__.__name__}--{matmul_node.input[2]}")
        new_add_node = g.make_node(
            "Add",
            [bias2, matmul_node.input[2]],
            [bias_all],
            name=f"{self.__class__.__name__}--{matmul_node.name}",
        )
        new_node = g.make_node(
            "Gemm",
            [*matmul_node.input[:2], bias_all],
            add_node.output,
            name=f"{self.__class__.__name__}--{matmul_node.name}",
            doc_string=matmul_node.doc_string,
        )
        new_node.attribute.extend(matmul_node.attribute)
        return [new_add_node, new_node]
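
# Illustrative sketch (not part of the original module): a plain numpy check of
# the algebraic identity behind MatMulAddPattern, namely that flattening the
# leading dimensions, running a 2D Gemm with a flattened bias, and restoring
# the shape gives the same result as MatMul followed by Add.
def _sketch_matmul_add_equivalence():
    import numpy as np

    rng = np.random.default_rng(0)
    x1 = rng.standard_normal((4, 5, 3)).astype(np.float32)  # (a, b, k)
    x2 = rng.standard_normal((3, 7)).astype(np.float32)  # (k, d)
    b = rng.standard_normal((4, 5, 7)).astype(np.float32)  # (a, b, d)

    expected = x1 @ x2 + b  # MatMul + Add
    # Gemm path: reshape to 2D, multiply, add the flattened bias, reshape back.
    fused = (x1.reshape(-1, 3) @ x2 + b.reshape(-1, 7)).reshape(4, 5, -1)
    assert np.allclose(expected, fused)
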

class GemmTransposePattern(PatternOptimization):
    """
    Replaces Gemm(., constant) by Gemm(., constant', transB=1).

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape=(3, 2)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 3)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["B"],
                value=onh.from_array(
                    np.array(
                        [
                            [0.0, 0.1666666716337204],
                            [0.3333333432674408, 0.5],
                            [0.6666666865348816, 0.8333333134651184],
                        ],
                        dtype=np.float32,
                    ),
                    name="value",
                ),
            )
        )
        nodes.append(oh.make_node("Gemm", ["X", "B"], ["Z"]))
        outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(2, 2)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape=(3, 2)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 3)))
        nodes.append(
            oh.make_node("Transpose", ["B"], ["GemmTransposePattern--B"], perm=[1, 0])
        )
        nodes.append(
            oh.make_node("Gemm", ["X", "GemmTransposePattern--B"], ["Z"], transB=1)
        )
        outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(2, 2)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "Gemm" or node.domain != "":
            return self.none()
        if not g.is_constant(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if node.op_type == "Gemm":
            atts = g.get_attributes_with_default(node, transA=0, transB=0, beta=1.0)
            if atts.get("beta", 1) != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            if atts.get("transB", 0) or atts.get("transA", 0):
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
    ) -> List[NodeProto]:
        tr = g.unique_name(f"{self.__class__.__name__}--{node.input[1]}")
        return [
            g.make_node(
                "Transpose",
                [node.input[1]],
                [tr],
                perm=[1, 0],
                name=f"{self.__class__.__name__}--{node.name}",
                doc_string=node.doc_string,
            ),
            g.make_node(
                "Gemm",
                [node.input[0], tr, *node.input[2:]],
                node.output,
                transB=1,
                name=f"{self.__class__.__name__}--{node.name}",
                doc_string=node.doc_string,
            ),
        ]
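
# Illustrative sketch (not part of the original module): the identity used by
# GemmTransposePattern. Gemm with transB=1 multiplies by the transposed second
# input, so storing the transposed constant leaves the result unchanged.
def _sketch_gemm_transb_equivalence():
    import numpy as np

    def gemm(a, b, trans_b=0):
        # Minimal model of ONNX Gemm with alpha=beta=1 and no bias.
        return a @ (b.T if trans_b else b)

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3)).astype(np.float32)
    b = rng.standard_normal((3, 2)).astype(np.float32)
    assert np.allclose(gemm(x, b), gemm(x, b.T, trans_b=1))
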

class MatMulReshape2Of3Pattern(PatternOptimization):
    """
    Replaces the reshapes around a matmul. It can be 3 or 2 out of 3.
    It is similar to
    :class:`experimental_experiment.xoptim.patterns.onnx_reshape.Reshape2Of3Pattern`.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("div_5", onnx.TensorProto.FLOAT, shape=(13, 4, 7, 7))
        )
        inputs.append(
            oh.make_tensor_value_info("transpose_23", onnx.TensorProto.FLOAT, shape=(52, 7, 8))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_13_4_7_8", onnx.TensorProto.INT64, shape=(4,))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s3_52_7_7"],
                value=onh.from_array(np.array([52, 7, 7], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s4_13_4_7_8"],
                value=onh.from_array(np.array([13, 4, 7, 8], dtype=np.int64), name="value"),
            )
        )
        nodes.append(oh.make_node("Reshape", ["div_5", "init7_s3_52_7_7"], ["view_83"]))
        nodes.append(oh.make_node("MatMul", ["view_83", "transpose_23"], ["bmm_11"]))
        nodes.append(oh.make_node("Reshape", ["bmm_11", "init7_s4_13_4_7_8"], ["view_85"]))
        outputs.append(
            oh.make_tensor_value_info("view_85", onnx.TensorProto.FLOAT, shape=(13, 4, 7, 8))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("div_5", onnx.TensorProto.FLOAT, shape=(13, 4, 7, 7))
        )
        inputs.append(
            oh.make_tensor_value_info("transpose_23", onnx.TensorProto.FLOAT, shape=(52, 7, 8))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_13_4_7_8", onnx.TensorProto.INT64, shape=(4,))
        )
        nodes.append(
            oh.make_node(
                "Reshape",
                ["transpose_23", "init7_s4_13_4_7_8"],
                ["MatMulReshape2Of3PatternL_view_83"],
            )
        )
        nodes.append(
            oh.make_node("MatMul", ["div_5", "MatMulReshape2Of3PatternL_view_83"], ["view_85"])
        )
        outputs.append(
            oh.make_tensor_value_info("view_85", onnx.TensorProto.FLOAT, shape=(13, 4, 7, 8))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def same_size(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        sh1: Tuple[int, ...],
        sh2: Tuple[int, ...],
        constraints: Dict[str, Set[Union[int, str]]],
    ) -> bool:
        # We cannot handle every case.
        if is_static_shape(sh1) and is_static_shape(sh2):
            return np.prod(sh1) == np.prod(sh2)
        if sh1 == sh2:
            return True
        # The constraints should be applied here.
        return False

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if (node.op_type != "MatMul" or node.domain != "") and (
            node.op_type != "FusedMatMul" or node.domain != "com.microsoft"
        ):
            return self.none()
        if node.op_type == "FusedMatMul":
            tA = g.get_attribute(node, "transBatchA", exc=False)
            if tA is not None and tA.i != 0:
                return self.none(node, inspect.currentframe().f_lineno)
            tB = g.get_attribute(node, "transBatchB", exc=False)
            if tB is not None and tB.i != 0:
                return self.none(node, inspect.currentframe().f_lineno)

        if (
            not g.has_shape(node.output[0])
            or not g.has_shape(node.input[0])
            or not g.has_shape(node.input[1])
        ):
            # Shapes are missing. They should be populated as much as possible.
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) > 1 or (len(next_nodes) == 0 and not g.is_output(node.output[0])):
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = None if len(next_nodes) == 0 else next_nodes[0]

        node_left = g.node_before(node.input[0])
        node_right = g.node_before(node.input[1])
        type_left = None if node_left is None else node_left.op_type
        type_right = None if node_right is None else node_right.op_type
        type_out = None if next_node is None else next_node.op_type
        types = [type_left, type_right, type_out]
        n_reshape = len([_ for _ in types if _ == "Reshape"])
        if n_reshape < 2:
            return self.none(node, inspect.currentframe().f_lineno)

        if node_left is not None and node_left.op_type != "Reshape":
            node_left = None
        if node_right is not None and node_right.op_type != "Reshape":
            node_right = None
        if next_node is not None and next_node.op_type != "Reshape":
            next_node = None

        if node_left is not None and not g.has_shape(node_left.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if node_right is not None and not g.has_shape(node_right.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        shape_left_left = None if node_left is None else g.get_shape(node_left.input[0])
        shape_right_right = None if node_right is None else g.get_shape(node_right.input[0])
        shape_left = g.get_shape(node.input[0])
        shape_right = g.get_shape(node.input[1])

        if (
            shape_left_left is not None
            and not self.same_size(
                g, shape_left[-2:], shape_left_left[-2:], g.get_registered_constraints()
            )
        ) or (
            shape_right_right is not None
            and not self.same_size(
                g, shape_right[-2:], shape_right_right[-2:], g.get_registered_constraints()
            )
        ):
            # the last dimensions are not the same
            return self.none(node, inspect.currentframe().f_lineno)

        the_shape_left = shape_left_left or shape_left
        the_shape_right = shape_right_right or shape_right
        if not is_static_shape(the_shape_left) or not is_static_shape(the_shape_right):
            return self.none(node, inspect.currentframe().f_lineno)

        if not self.same_size(
            g, the_shape_left[:-2], the_shape_right[:-2], g.get_registered_constraints()
        ):
            # the first dimensions are not the same
            return self.none(node, inspect.currentframe().f_lineno)

        if next_node is not None:
            next_shape = g.get_shape(next_node.output[0])
            matmul_shape = (*the_shape_left[:-1], shape_right[-1])
            if matmul_shape[-2:] != next_shape[-2:] and not self.same_size(
                g, matmul_shape[:-2], next_shape[:-2], g.get_registered_constraints()
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            first_dims = {next_shape[:-2], the_shape_left[:-2], the_shape_right[:-2]}
            if len(first_dims) == 3:
                # All shapes are different. It is not worth it.
                return self.none(node, inspect.currentframe().f_lineno)
            if len(next_shape) != len(the_shape_left) and len(next_shape) != len(
                the_shape_right
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            if matmul_shape[-1] != next_shape[-1]:
                # 1x9x64, 1x64x9 -> 1x9x9 -> 1x81
                # The last dimension changed.
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            if len(the_shape_left) != len(the_shape_right):
                return self.none(node, inspect.currentframe().f_lineno)

        # The pattern does not handle the reshape after the matmul,
        # ReshapeReshapePattern will do it.
        nodes = [node_left, node_right, node, next_node]
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_left: Optional[NodeProto],
        node_right: Optional[NodeProto],
        node: NodeProto,
        next_node: Optional[NodeProto],
    ) -> List[NodeProto]:
        res = []
        shape_left_left = None if node_left is None else g.get_shape(node_left.input[0])
        shape_right_right = None if node_right is None else g.get_shape(node_right.input[0])
        shape_left = g.get_shape(node.input[0])
        shape_right = g.get_shape(node.input[1])

        the_shape_left = shape_left_left or shape_left
        the_shape_right = shape_right_right or shape_right

        # If the first dimensions are not the same, we may assume
        # the size is the same but a reshape is still needed.
        add_right, add_left = False, False
        one_more_reshape = the_shape_left[:-2] != the_shape_right[:-2]
        if one_more_reshape:
            expected_shape = g.get_shape(
                node.output[0] if next_node is None else next_node.output[0]
            )
            assert node_left is not None or node_right is not None, (
                f"Shapes are not consistent, one Reshape node should be there, "
                f"node.name={node.name!r}, "
                f"shape_left={shape_left}, shape_right={shape_right}, "
                f"the_shape_left={the_shape_left}, "
                f"the_shape_right={the_shape_right}, "
                f"node_left is None={node_left is None}, "
                f"node_right is None={node_right is None}, "
                f"next_node is None={next_node is None}, "
                f"expected_shape={expected_shape}"
            )
            if node_left is not None and the_shape_left[:-2] != expected_shape[:-2]:
                add_left = True
            elif node_right is not None and the_shape_right[:-2] != expected_shape[:-2]:
                add_right = True
            elif node_left is not None and node_right is not None:
                raise AssertionError(
                    f"Case still not implemented, shapes are not consistent, "
                    f"one Reshape node should be there, "
                    f"node.name={node.name!r}, "
                    f"shape_left={shape_left}, shape_right={shape_right}, "
                    f"the_shape_left={the_shape_left}, "
                    f"the_shape_right={the_shape_right}, "
                    f"node_left is None={node_left is None}, "
                    f"node_right is None={node_right is None}, "
                    f"next_node is None={next_node is None}, "
                    f"expected_shape={expected_shape}"
                )

        # node left
        if node_left is None:
            expected_shape = the_shape_right[:-2] + shape_left[-2:]
            if the_shape_left != expected_shape:
                shape_name = g.make_initializer(
                    "",
                    np.array(expected_shape, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.apply.shape.1",
                )
                left_name = g.unique_name(f"{self.__class__.__name__}L_{node.input[0]}")
                res.append(
                    g.make_node(
                        "Reshape",
                        [node.input[0], shape_name],
                        [left_name],
                        name=f"{self.__class__.__name__}--{node.name}",
                    )
                )
            else:
                left_name = node.input[0]
        elif g.is_used_more_than_once(node_left.output[0]):
            res.append(node_left)
            left_name = node_left.input[0]
        else:
            left_name = node_left.input[0]

        # node right
        if node_right is None:
            expected_shape = the_shape_left[:-2] + shape_right[-2:]
            if the_shape_right != expected_shape:
                shape_name = g.make_initializer(
                    "",
                    np.array(expected_shape, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.apply.shape.2",
                )
                right_name = g.unique_name(f"{self.__class__.__name__}L_{node.input[0]}")
                res.append(
                    g.make_node(
                        "Reshape",
                        [node.input[1], shape_name],
                        [right_name],
                        name=f"{self.__class__.__name__}--{node.name}",
                    )
                )
            else:
                right_name = node.input[1]
        elif g.is_used_more_than_once(node_right.output[0]):
            res.append(node_right)
            right_name = node_right.input[0]
        else:
            right_name = node_right.input[0]

        if next_node is None:
            assert not add_right and not add_left, (
                f"add_right={add_right}, add_left={add_left} "
                f"are not implemented yet in this case."
            )
            # A reshape is needed.
            previous_shape = (*shape_left[:-1], shape_right[-1])
            new_shape = (*the_shape_left[:-1], the_shape_right[-1])
            if previous_shape != new_shape:
                new_name = g.unique_name(f"{self.__class__.__name__}L_{node.output[0]}")
                previous_shape_name = g.make_initializer(
                    "",
                    np.array(previous_shape, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.shape.3",
                )
                mm = g.make_node(
                    node.op_type,
                    [left_name, right_name],
                    [new_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                    domain=node.domain,
                )
                if node.attribute:
                    mm.attribute.extend(node.attribute)
                res.extend(
                    [
                        mm,
                        g.make_node(
                            "Reshape",
                            [new_name, previous_shape_name],
                            [node.output[0]],
                            name=f"{self.__class__.__name__}--{node.name}",
                        ),
                    ]
                )
            else:
                mm = g.make_node(
                    node.op_type,
                    [left_name, right_name],
                    [node.output[0]],
                    name=f"{self.__class__.__name__}--{node.name}",
                    domain=node.domain,
                )
                if node.attribute:
                    mm.attribute.extend(node.attribute)
                res.append(mm)
        else:
            if add_left:
                new_left_name = g.unique_name(f"{self.__class__.__name__}AL_{left_name}")
                new_sh = (
                    g.get_shape(next_node.output[0])[:-2] + g.get_shape(node.input[0])[-2:]
                )
                sh = g.make_initializer(
                    "",
                    np.array(new_sh, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.apply.sh.1",
                )
                add = g.make_node(
                    "Reshape",
                    [left_name, sh],
                    [new_left_name],
                    name=f"{self.__class__.__name__}--AL--{node.name}",
                )
                res.append(add)
                left_name = new_left_name
            if add_right:
                new_right_name = g.unique_name(f"{self.__class__.__name__}AR_{right_name}")
                new_sh = (
                    g.get_shape(next_node.output[0])[:-2] + g.get_shape(node.input[1])[-2:]
                )
                sh = g.make_initializer(
                    "",
                    np.array(new_sh, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.apply.sh.2",
                )
                add = g.make_node(
                    "Reshape",
                    [right_name, sh],
                    [new_right_name],
                    name=f"{self.__class__.__name__}--AR--{node.name}",
                )
                res.append(add)
                right_name = new_right_name

            main_node = g.make_node(
                node.op_type,
                [left_name, right_name],
                [next_node.output[0]],
                name=f"{self.__class__.__name__}--{node.name}",
                domain=node.domain,
            )
            if node.attribute:
                main_node.attribute.extend(node.attribute)
            res.append(main_node)

            if g.is_used_more_than_once(node.output[0]):
                previous_shape = (*shape_left[:-1], shape_right[-1])
                previous_shape_name = g.make_initializer(
                    "",
                    np.array(previous_shape, dtype=np.int64),
                    source="MatMulReshape2Of3Pattern.apply.shape.4",
                )
                res.append(
                    g.make_node(
                        "Reshape",
                        [main_node.output[0], previous_shape_name],
                        [node.output[0]],
                        name=f"{self.__class__.__name__}--{node.name}",
                    )
                )
        return res
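
# Illustrative sketch (not part of the original module): reshaping the batch
# dimensions around a batched MatMul does not change the computed values as
# long as the two last dimensions are untouched, which is the invariant
# MatMulReshape2Of3Pattern relies on.
def _sketch_reshape_around_matmul():
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((13, 4, 7, 7)).astype(np.float32)
    y = rng.standard_normal((52, 7, 8)).astype(np.float32)

    # Original path: Reshape, MatMul, Reshape.
    original = (x.reshape(52, 7, 7) @ y).reshape(13, 4, 7, 8)
    # Rewritten path: reshape the other operand instead, no reshape afterwards.
    rewritten = x @ y.reshape(13, 4, 7, 8)
    assert np.allclose(original, rewritten)
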

class MulMulMatMulPattern(PatternOptimization):
    """
    Replaces ``MatMul(a * c, b * d)``, where c and d are constant scalars,
    by ``MatMul(a, b) * (c * d)``.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(16, 64)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(32, 16)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["c"],
                value=onh.from_array(
                    np.array([0.4000000059604645], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["d"],
                value=onh.from_array(
                    np.array([0.6000000238418579], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(oh.make_node("Mul", ["X", "c"], ["a"]))
        nodes.append(oh.make_node("Mul", ["d", "Y"], ["b"]))
        nodes.append(oh.make_node("MatMul", ["a", "b"], ["z"]))
        outputs.append(oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, shape=(32, 64)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(16, 64)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(32, 16)))
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init1_s1_"],
                value=onh.from_array(
                    np.array([0.24000000953674316], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(oh.make_node("MatMul", ["X", "Y"], ["MulMulMatMulPattern_z"]))
        nodes.append(oh.make_node("Mul", ["MulMulMatMulPattern_z", "init1_s1_"], ["z"]))
        outputs.append(oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, shape=(32, 64)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "MatMul" or node.domain != "":
            return self.none()
        node_before = [g.node_before(i) for i in node.input]
        if None in node_before:
            return self.none(node, inspect.currentframe().f_lineno)
        types = set(_.op_type for _ in node_before)
        if types != {"Mul"}:
            return self.none(node, inspect.currentframe().f_lineno)

        cst = [i for i in [*node_before[0].input, *node_before[1].input] if g.is_constant(i)]
        if len(cst) != 2 or not all(g.is_constant_scalar(c) for c in cst):
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [*node_before, node], self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        mul1: NodeProto,
        mul2: NodeProto,
        node: NodeProto,
    ) -> List[NodeProto]:
        cst = [i for i in [*mul1.input, *mul2.input] if g.is_constant(i)]
        not_cst = [i for i in [*mul1.input, *mul2.input] if i not in cst]
        assert len(cst) == 2, f"impossible cst={cst!r}"
        assert len(not_cst) == 2, f"impossible not_cst={not_cst!r}"

        cs = [g.get_computed_constant(c) for c in cst]
        c = (cs[0] * cs[1]).astype(cs[0].dtype)
        ccc = g.make_initializer("", c, source="MulMulMatMulPattern.apply.ccc")
        mul_name = g.unique_name(f"{self.__class__.__name__}_{node.output[0]}")
        return [
            g.make_node(
                "MatMul",
                not_cst,
                [mul_name],
                name=f"{self.__class__.__name__}--{node.name}-1",
            ),
            g.make_node(
                "Mul",
                [mul_name, ccc],
                node.output,
                name=f"{self.__class__.__name__}--{node.name}-2",
            ),
        ]
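
# Illustrative sketch (not part of the original module): scalar factors can be
# pulled out of a MatMul, which is the identity behind MulMulMatMulPattern:
# (X * c) @ (Y * d) == (X @ Y) * (c * d) for scalar constants c and d.
def _sketch_mul_mul_matmul():
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((32, 16)).astype(np.float32)
    y = rng.standard_normal((16, 64)).astype(np.float32)
    c = np.float32(0.4)
    d = np.float32(0.6)
    assert np.allclose((x * c) @ (d * y), (x @ y) * (c * d), atol=1e-5)
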

class ReshapeMatMulReshapePattern(PatternOptimization):
    """
    Replaces the sequence Reshape, MatMul, Reshape by MatMul.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("xu2", onnx.TensorProto.FLOAT, shape=(1, 1, 32, 128))
        )
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(3, 5, 128, 64))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["shape1"],
                value=onh.from_array(np.array([1, 32, 128], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["shape2"],
                value=onh.from_array(np.array([15, 128, 64], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["shape3"],
                value=onh.from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="value"),
            )
        )
        nodes.append(oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]))
        nodes.append(oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]))
        nodes.append(oh.make_node("MatMul", ["xm1", "xm2c"], ["xm"]))
        nodes.append(oh.make_node("Reshape", ["xm", "shape3"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(3, 5, 32, 64))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("xu2", onnx.TensorProto.FLOAT, shape=(1, 1, 32, 128))
        )
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(3, 5, 128, 64))
        )
        nodes.append(oh.make_node("MatMul", ["xu2", "Y"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(3, 5, 32, 64))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type != "MatMul" or node.domain != "":
            return self.none()
        if g.is_used_more_than_once(node.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) == 0:
            return self.none(node, inspect.currentframe().f_lineno)
        next_node = next_nodes[0]
        if next_node.op_type != "Reshape" or next_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)

        node_before_left = g.node_before(node.input[0])
        node_before_right = g.node_before(node.input[1])
        if node_before_left is None or node_before_right is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            node_before_left.op_type != "Reshape"
            or node_before_left.domain != ""
            or node_before_right.op_type != "Reshape"
            or node_before_right.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # condition on shapes
        if not g.is_constant(node_before_left.input[1]):
            return
        shape_left = tuple(int(i) for i in g.get_computed_constant(node_before_left.input[1]))
        if not g.is_constant(node_before_right.input[1]):
            return
        shape_right = tuple(
            int(i) for i in g.get_computed_constant(node_before_right.input[1])
        )
        if not g.is_constant(next_node.input[1]):
            return
        shape_final = tuple(int(i) for i in g.get_computed_constant(next_node.input[1]))
        if len(shape_final) < 4:
            return self.none(node, inspect.currentframe().f_lineno)
        ndim = len(shape_final)
        if len(shape_left) != 3 or len(shape_right) != 3:
            return self.none(node, inspect.currentframe().f_lineno)

        mshape_left = g.get_shape(node_before_left.input[0])
        mshape_right = g.get_shape(node_before_right.input[0])
        if len(mshape_left) != ndim or len(mshape_right) != ndim:
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            not compatible_shapes(mshape_left[-2:], shape_left[-2:])
            or not compatible_shapes(mshape_right[-2:], shape_right[-2:])
            or not compatible_dimensions(
                mshape_left[-1], shape_left[-1], mshape_right[-2], shape_right[-2]
            )
        ):
            return self.none(node, inspect.currentframe().f_lineno)

        # At this stage, both Reshape nodes before MatMul reduce the rank by one
        # without changing the two last dimensions,
        # and the Reshape after restores it. They can safely be removed.
        if g.verbose > 3:
            print(
                f"[ReshapeMatMulReshapePattern] compatible shapes: "
                f"mshape_left={mshape_left} "
                f"shape_left={shape_left} | mshape_right={mshape_right} "
                f"shape_right={shape_right}"
            )

        return MatchResult(
            self,
            [node_before_left, node_before_right, node, next_node],
            self.apply,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_before_left: NodeProto,
        node_before_right: NodeProto,
        node: NodeProto,
        next_node: NodeProto,
    ) -> List[NodeProto]:
        res = []
        if g.is_used_more_than_once(node_before_left.output[0]):
            res.append(node_before_left)
        if g.is_used_more_than_once(node_before_right.output[0]):
            res.append(node_before_right)
        new_node = g.make_node(
            "MatMul",
            [node_before_left.input[0], node_before_right.input[0]],
            next_node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            doc_string=next_node.doc_string,
        )
        res.append(new_node)
        return res
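
# Illustrative sketch (not part of the original module): ONNX MatMul broadcasts
# the batch dimensions, so the Reshape/MatMul/Reshape chain removed by
# ReshapeMatMulReshapePattern computes the same values as a direct MatMul.
def _sketch_reshape_matmul_reshape():
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((1, 1, 32, 128)).astype(np.float32)
    y = rng.standard_normal((3, 5, 128, 64)).astype(np.float32)

    # Original path with explicit reshapes (the batch of size 1 broadcasts).
    xm = x.reshape(1, 32, 128) @ y.reshape(15, 128, 64)  # (15, 32, 64)
    original = xm.reshape(3, 5, 32, 64)
    # Rewritten path: numpy broadcasts the leading dimensions like ONNX MatMul.
    rewritten = x @ y
    assert np.allclose(original, rewritten, atol=1e-5)
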

class TransposeMatMulPattern(PatternOptimization):
    """
    Replaces the sequence Transpose, MatMul or Gemm by Gemm.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(128, 64)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(128, 32)))
        nodes.append(oh.make_node("Transpose", ["X"], ["xm1"], perm=[1, 0]))
        nodes.append(oh.make_node("MatMul", ["xm1", "Y"], ["Z"]))
        outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(32, 64)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(128, 64)))
        inputs.append(oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(128, 32)))
        nodes.append(oh.make_node("Gemm", ["X", "Y"], ["Z"], transA=1, transB=0))
        outputs.append(oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(32, 64)))
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        if node.op_type not in {"MatMul", "Gemm"} or node.domain != "":
            return self.none()
        if not g.has_rank(node.input[0]) or not g.has_rank(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_rank(node.input[0]) != 2 or g.get_rank(node.input[1]) != 2:
            return self.none(node, inspect.currentframe().f_lineno)

        nodes_before = [g.node_before(node.input[0]), g.node_before(node.input[1])]
        ns = [
            (n if n is not None and n.op_type == "Transpose" and n.domain == "" else None)
            for n in nodes_before
        ]
        if len([_ for _ in ns if _ is not None]) == 0:
            return self.none(node, inspect.currentframe().f_lineno)

        if g.has_processor("CUDA"):
            nns = []
            for n in ns:
                if n is None:
                    nns.append(n)
                    continue
                if g.is_used_more_than_once(n.output[0]):
                    nns.append(None)
                    continue
                nns.append(n)
            if len([_ for _ in nns if _ is not None]) == 0:
                return self.none(node, inspect.currentframe().f_lineno)
            ns = nns

        for n in ns:
            if n is None:
                continue
            perm = tuple(g.get_attribute(n, "perm").ints)
            if perm != (1, 0):
                # unexpected transpose
                return self.none(node, inspect.currentframe().f_lineno)

        if len([_ for _ in ns if _ is not None]) == 0:
            return self.none(node, inspect.currentframe().f_lineno)

        # At this stage, one or two inputs are transposed before being used.
        # MatMul or Gemm are operating on 2D tensors.
        nodes = [*ns, node]
        if node.op_type == "Gemm":
            if nodes[1] is not None:  # nodes_before_right
                atts = g.get_attributes_with_default(node, transA=0, transB=0)
                if atts.get("transB", 0) != atts.get("transA", 0) and g.is_constant(
                    node.input[1]
                ):
                    # it is better to do constant folding rather than changing transB
                    return self.none(node, inspect.currentframe().f_lineno)
            if nodes[0] is not None:  # nodes_before_left
                atts = g.get_attributes_with_default(node, transA=0, transB=0)
                if atts.get("transB", 0) != atts.get("transA", 0) and g.is_constant(
                    node.input[0]
                ):
                    # it is better to do constant folding rather than changing transA
                    return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node_before_left: Optional[NodeProto],
        node_before_right: Optional[NodeProto],
        node: NodeProto,
    ) -> List[NodeProto]:
        inputs = [
            (node.input[0] if node_before_left is None else node_before_left.input[0]),
            (node.input[1] if node_before_right is None else node_before_right.input[0]),
            *node.input[2:],
        ]
        transA = 0 if node_before_left is None else 1
        transB = 0 if node_before_right is None else 1
        keep = []
        for att in node.attribute:
            if att.name in {"alpha", "beta"}:
                keep.append(att)
            elif att.name == "transA":
                transA = (att.i + transA) % 2
            elif att.name == "transB":
                transB = (att.i + transB) % 2
            else:
                raise NotImplementedError(
                    f"Unexpected attribute {att.name!r}={att} for node={node}"
                )

        new_node = g.make_node(
            "Gemm",
            inputs,
            node.output,
            name=f"{self.__class__.__name__}--{node.name}",
            transA=transA,
            transB=transB,
            doc_string=node.doc_string,
        )
        new_node.attribute.extend(keep)
        res = [new_node]
        if node_before_left is not None and g.is_used_more_than_once(
            node_before_left.output[0]
        ):
            # This is not efficient on CUDA.
            res.append(node_before_left)
        if node_before_right is not None and g.is_used_more_than_once(
            node_before_right.output[0]
        ):
            # This is not efficient on CUDA.
            res.append(node_before_right)
        return res
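
# Illustrative sketch (not part of the original module): folding a Transpose
# into Gemm through the transA/transB attributes, the rewrite performed by
# TransposeMatMulPattern, modelled here with numpy.
def _sketch_transpose_matmul():
    import numpy as np

    def gemm(a, b, trans_a=0, trans_b=0):
        # Minimal model of ONNX Gemm with alpha=beta=1 and no bias.
        return (a.T if trans_a else a) @ (b.T if trans_b else b)

    rng = np.random.default_rng(0)
    x = rng.standard_normal((128, 32)).astype(np.float32)
    y = rng.standard_normal((128, 64)).astype(np.float32)
    assert np.allclose(x.T @ y, gemm(x, y, trans_a=1))
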

class TransposeReshapeMatMulPattern(PatternOptimization):
    """
    Replaces the sequence Transpose, Reshape, MatMul by
    Reshape, Transpose, MatMul if possible. Another optimizer
    will optimize this sequence by using Gemm or better.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("yts", onnx.TensorProto.FLOAT, shape=(2, 2, 7, 3))
        )
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(4, 3, 7)))
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 2, 5, 7))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["shape"],
                value=onh.from_array(np.array([2, 2, 7, 3], dtype=np.int64), name="value"),
            )
        )
        nodes.append(oh.make_node("MatMul", ["X", "yts"], ["Z"]))
        nodes.append(oh.make_node("Reshape", ["yt", "shape"], ["yts"]))
        nodes.append(oh.make_node("Transpose", ["Y"], ["yt"], perm=[0, 2, 1]))
        outputs.append(
            oh.make_tensor_value_info("yts", onnx.TensorProto.FLOAT, shape=(2, 2, 7, 3))
        )
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(2, 2, 5, 3))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 26),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("yts", onnx.TensorProto.FLOAT, shape=(2, 2, 7, 3))
        )
        inputs.append(oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(4, 3, 7)))
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(2, 2, 5, 7))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s4_2_2_3_7"],
                value=onh.from_array(np.array([2, 2, 3, 7], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape", ["Y", "init7_s4_2_2_3_7"], ["TransposeReshapeMatMulPatternL_Y"]
            )
        )
        nodes.append(
            oh.make_node(
                "Transpose", ["TransposeReshapeMatMulPatternL_Y"], ["yts"], perm=[0, 1, 3, 2]
            )
        )
        nodes.append(oh.make_node("MatMul", ["X", "yts"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("yts", onnx.TensorProto.FLOAT, shape=(2, 2, 7, 3))
        )
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=(2, 2, 5, 3))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def check_transpose_node(self, g: "GraphBuilder", name: str) -> bool:  # noqa: F821
        if g.is_used_more_than_once(name):
            return False
        node = g.node_before(name)
        if node is None or node.op_type != "Reshape":
            return False
        if g.is_used_more_than_once(node.input[0]):
            return False
        node_node = g.node_before(node.input[0])
        if node_node is None or node_node.op_type != "Transpose":
            return False
        perm = tuple(g.get_attribute(node_node, "perm").ints)
        id_perm = tuple(range(len(perm)))
        if perm[:-2] != id_perm[:-2] or (perm[-1], perm[-2]) != id_perm[-2:]:
            return False
        return True

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
        left_first: bool = True,
    ) -> Optional[MatchResult]:
        if node.op_type != "MatMul" or node.domain != "":
            return self.none()

        left = self.check_transpose_node(g, node.input[0])
        right = self.check_transpose_node(g, node.input[1])
        if left and left_first:
            # even if right is ok, it will be handled by another call to the optimizer
            side = "left"
        elif right:
            side = "right"
        else:
            return self.none(node, inspect.currentframe().f_lineno)

        if side == "left":
            node_left = g.node_before(node.input[0])
            node_left_tr = g.node_before(node_left.input[0])
            node_right = None
            node_right_tr = None
            shape_name = node_left.input[1]
        else:
            node_left = None
            node_left_tr = None
            node_right = g.node_before(node.input[1])
            node_right_tr = g.node_before(node_right.input[0])
            shape_name = node_right.input[1]

        if not g.is_constant(shape_name):
            if left_first and right:
                return self.match(g, node, matched, left_first=False)
            return self.none(node, inspect.currentframe().f_lineno)

        shape_before = g.get_shape((node_left or node_right).input[0])
        shape_after = g.get_shape((node_left or node_right).output[0])
        if shape_before[-2:] != shape_after[-2:]:
            # the reshape modifies the two last dimensions
            if left_first and right:
                return self.match(g, node, matched, left_first=False)
            return self.none(node, inspect.currentframe().f_lineno)

        return MatchResult(
            self,
            [node, node_left, node_left_tr, node_right, node_right_tr],
            self.apply,
            insert_at=node,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        node: NodeProto,
        node_left: Optional[NodeProto],
        node_left_tr: Optional[NodeProto],
        node_right: Optional[NodeProto],
        node_right_tr: Optional[NodeProto],
    ) -> List[NodeProto]:
        shape = list(g.get_computed_constant((node_left or node_right).input[1]))
        shape[-2], shape[-1] = shape[-1], shape[-2]
        shape_name = g.make_initializer(
            "",
            np.array(shape, dtype=np.int64),
            source="TransposeReshapeMatMulPattern.apply.shape_name",
        )

        if node_right is None:
            # left side
            perm = list(range(g.get_rank(node.input[0])))
            perm[-2], perm[-1] = perm[-1], perm[-2]
            left_name = g.unique_name(f"{self.__class__.__name__}L_{node_left_tr.input[0]}")
            res = [
                g.make_node(
                    "Reshape",
                    [node_left_tr.input[0], shape_name],
                    [left_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                ),
                g.make_node(
                    "Transpose",
                    [left_name],
                    [node.input[0]],
                    perm=tuple(perm),
                    name=f"{self.__class__.__name__}--{node.name}",
                ),
                node,
            ]
        else:
            # right side
            perm = list(range(g.get_rank(node.input[1])))
            perm[-2], perm[-1] = perm[-1], perm[-2]
            right_name = g.unique_name(f"{self.__class__.__name__}L_{node_right_tr.input[0]}")
            res = [
                g.make_node(
                    "Reshape",
                    [node_right_tr.input[0], shape_name],
                    [right_name],
                    name=f"{self.__class__.__name__}--{node.name}",
                ),
                g.make_node(
                    "Transpose",
                    [right_name],
                    [node.input[1]],
                    perm=tuple(perm),
                    name=f"{self.__class__.__name__}--{node.name}",
                ),
                node,
            ]
        return res
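
# Illustrative sketch (not part of the original module): the swap performed by
# TransposeReshapeMatMulPattern. Transposing the two last axes and then
# reshaping equals reshaping to the shape with the two last dimensions swapped
# and transposing the two last axes afterwards, provided the reshape leaves
# the two last dimensions intact.
def _sketch_transpose_reshape_swap():
    import numpy as np

    rng = np.random.default_rng(0)
    y = rng.standard_normal((4, 3, 7)).astype(np.float32)

    original = y.transpose(0, 2, 1).reshape(2, 2, 7, 3)  # Transpose then Reshape
    rewritten = y.reshape(2, 2, 3, 7).transpose(0, 1, 3, 2)  # Reshape then Transpose
    assert np.allclose(original, rewritten)
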

class SwitchReshapeActivationPattern(PatternOptimization):
    """
    Switches Reshape and Gelu after a Gemm or a MatMul.
    Gelu can also be Exp, Elu, Relu, Tan, Tanh, Cos, Cosh, Sin, Sinh, Erf,
    LeakyRelu, PRelu, Selu, Softmax or Softplus.
    Reshape can also be Transpose.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(3, 2, 5, 6))
        )
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(3, 2, 6, 5))
        )
        nodes.append(oh.make_node("MatMul", ["X", "Y"], ["mm"]))
        nodes.append(oh.make_node("Transpose", ["mm"], ["tmm"], perm=[0, 2, 1, 3]))
        nodes.append(oh.make_node("Relu", ["tmm"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "c", "d"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=(3, 2, 5, 6))
        )
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=(3, 2, 6, 5))
        )
        nodes.append(
            oh.make_node("MatMul", ["X", "Y"], ["SwitchReshapeActivationPatternL_mm"])
        )
        nodes.append(
            oh.make_node(
                "Relu",
                ["SwitchReshapeActivationPatternL_mm"],
                ["SwitchReshapeActivationPatternL_tmm"],
            )
        )
        nodes.append(
            oh.make_node(
                "Transpose", ["SwitchReshapeActivationPatternL_tmm"], ["Z"], perm=[0, 2, 1, 3]
            )
        )
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "b", "c", "d"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
        left_first: bool = True,
    ) -> Optional[MatchResult]:
        if (
            node.op_type
            not in {
                "Cos",
                "Cosh",
                "Elu",
                "Erf",
                "Exp",
                "Gelu",
                "LeakyRelu",
                "PRelu",
                "Relu",
                "Selu",
                "Sin",
                "Sinh",
                "Softmax",
                "Softplus",
                "Tan",
                "Tanh",
            }
            or node.domain != ""
        ):
            return self.none()
        if g.is_used_more_than_once(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        before = g.node_before(node.input[0])
        if before is None or g.is_used_more_than_once(before.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        if before.op_type not in {"Reshape", "Transpose"} or before.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        before_before = g.node_before(before.input[0])
        if (
            before_before is None
            or before_before.op_type not in {"Gemm", "MatMul"}
            or before_before.domain != ""
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(
            self,
            [before_before, before, node],
            self.apply,
            insert_at=before if before.op_type == "Reshape" else before_before,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        mm_node: NodeProto,
        tr_node: NodeProto,
        f_node: NodeProto,
    ) -> List[NodeProto]:
        name1 = g.unique_name(f"{self.__class__.__name__}L_{mm_node.output[0]}")
        name2 = g.unique_name(f"{self.__class__.__name__}L_{tr_node.output[0]}")
        nodes = [
            g.make_node(
                mm_node.op_type,
                mm_node.input,
                [name1],
                domain=mm_node.domain,
                name=f"{self.__class__.__name__}--{mm_node.name}",
            ),
            g.make_node(
                f_node.op_type,
                [name1],
                [name2],
                domain=f_node.domain,
                name=f"{self.__class__.__name__}--{f_node.name}",
            ),
            g.make_node(
                tr_node.op_type,
                [name2, *tr_node.input[1:]],
                f_node.output,
                domain=tr_node.domain,
                name=f"{self.__class__.__name__}--{tr_node.name}",
            ),
        ]
        nodes[0].attribute.extend(mm_node.attribute)
        nodes[1].attribute.extend(f_node.attribute)
        nodes[2].attribute.extend(tr_node.attribute)
        return nodes
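
# Illustrative sketch (not part of the original module): an elementwise
# activation commutes with a pure data movement such as Transpose or Reshape,
# which is why SwitchReshapeActivationPattern can move the activation before it.
def _sketch_activation_transpose_swap():
    import numpy as np

    def relu(t):
        return np.maximum(t, 0)

    rng = np.random.default_rng(0)
    x = rng.standard_normal((3, 2, 5, 6)).astype(np.float32)
    assert np.allclose(relu(x.transpose(0, 2, 1, 3)), relu(x).transpose(0, 2, 1, 3))
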

class ShapeBasedMatMulToMulPattern(PatternOptimization):
    """
    A MatMul whose reduced dimension is 1 can be replaced by Mul with
    broadcasting. It makes it easier to detect optimization patterns
    involving Expand operators.

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 1, "c"))
        )
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b", 1))
        )
        nodes.append(oh.make_node("MatMul", ["X", "Y"], ["Zt"]))
        nodes.append(oh.make_node("Transpose", ["Zt"], ["Z"], perm=[0, 2, 1]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "c", "b"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 1, "c"))
        )
        inputs.append(
            oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape=("a", "b", 1))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["X-ZEROS2"],
                value=onh.from_array(np.array([0, 1, -1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["Y-ZEROS2"],
                value=onh.from_array(np.array([0, -1, 1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(oh.make_node("Reshape", ["X", "X-ZEROS2"], ["X2"]))
        nodes.append(oh.make_node("Reshape", ["Y", "Y-ZEROS2"], ["Y2"]))
        nodes.append(oh.make_node("Mul", ["X2", "Y2"], ["Z"]))
        outputs.append(
            oh.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, shape=("a", "c", "b"))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)
        print("DOT-SECTION", to_dot(model))
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
        left_first: bool = True,
    ) -> Optional[MatchResult]:
        if node.op_type != "MatMul" or node.domain != "":
            return self.none()
        if not g.has_shape(node.input[0]) or not g.has_shape(node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape1 = g.get_shape(node.input[0])
        shape2 = g.get_shape(node.input[1])
        if len(shape1) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(shape2) < 2:
            return self.none(node, inspect.currentframe().f_lineno)
        if shape1[-1] != 1 or shape2[-2] != 1:
            return self.none(node, inspect.currentframe().f_lineno)

        next_node = g.next_nodes(node.output[0])
        if len(next_node) == 1 and next_node[0].op_type == "Transpose":
            transpose = next_node[0]
            perm = tuple(g.get_attribute(transpose, "perm").ints)
            n = len(perm)
            if (
                g.has_rank(node.input[0])
                and g.has_rank(node.input[1])
                and g.get_rank(node.input[0]) >= 2
                and g.get_rank(node.input[1]) >= 2
                and perm[:-2] == tuple(range(n - 2))
                and perm[-2] == n - 1
                and perm[-1] == n - 2
            ):
                return MatchResult(self, [node, transpose], self.apply, insert_at=node)
        return MatchResult(self, [node, None], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        mm_node: NodeProto,
        transpose: Optional[NodeProto],
    ) -> List[NodeProto]:
        if transpose is None:
            return [
                g.make_node(
                    "Mul",
                    mm_node.input,
                    mm_node.output,
                    name=f"{self.__class__.__name__}--{mm_node.name}",
                )
            ]

        rk1 = g.get_rank(mm_node.input[0])
        rk2 = g.get_rank(mm_node.input[1])
        assert rk1 > 2, f"rank({mm_node.input[0]})=={rk1} > 2 - unexpected"
        assert rk2 > 2, f"rank({mm_node.input[1]})=={rk2} > 2 - unexpected"
        rsh1 = g.make_initializer(
            g.unique_name(f"{mm_node.input[0]}-ZEROS"),
            np.array([0] * (rk1 - 2) + [1, -1], dtype=np.int64),
            source=f"{self.__class__.__name__}.0",
        )
        rsh2 = g.make_initializer(
            g.unique_name(f"{mm_node.input[1]}-ZEROS"),
            np.array([0] * (rk2 - 2) + [-1, 1], dtype=np.int64),
            source=f"{self.__class__.__name__}.0",
        )
        new1 = g.unique_name(mm_node.input[0])
        new2 = g.unique_name(mm_node.input[1])
        return [
            g.make_node(
                "Reshape",
                [mm_node.input[0], rsh1],
                [new1],
                name=f"{self.__class__.__name__}--{mm_node.name}",
            ),
            g.make_node(
                "Reshape",
                [mm_node.input[1], rsh2],
                [new2],
                name=f"{self.__class__.__name__}--{mm_node.name}",
            ),
            g.make_node(
                "Mul",
                [new1, new2],
                transpose.output,
                name=f"{self.__class__.__name__}--{mm_node.name}",
            ),
        ]
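
# Illustrative sketch (not part of the original module): when the reduced
# dimension of a MatMul is 1, the product is an outer product, which
# broadcasting expresses directly; this is the identity used by
# ShapeBasedMatMulToMulPattern.
def _sketch_matmul_to_mul():
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 4, 1)).astype(np.float32)  # (a, b, 1)
    y = rng.standard_normal((2, 1, 5)).astype(np.float32)  # (a, 1, c)

    assert np.allclose(x @ y, x * y)  # MatMul == Mul with broadcasting
    # Transposed variant: reshape both operands so the Mul directly produces
    # the transposed layout, removing the Transpose node.
    transposed = (x @ y).transpose(0, 2, 1)  # (a, c, b)
    rewritten = x.reshape(2, 1, 4) * y.reshape(2, 5, 1)
    assert np.allclose(transposed, rewritten)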