# Source code for experimental_experiment.xoptim.patterns_ort.llm_optim

import inspect
from typing import List, Optional, Sequence
import numpy as np
from onnx import NodeProto
from ...helpers import tensor_dtype_to_np_dtype
from ..patterns_api import MatchResult, PatternOptimization
from ..patterns.onnx_attention import FunctionAttentionPattern
from ..patterns.onnx_rotary import FunctionHalfRotaryEmbeddingPattern


class ContribRotaryEmbeddingPattern(PatternOptimization):
    """
    Very similar to
    :class:`experimental_experiment.xoptim.patterns.onnx_rotary.RotaryEmbeddingPattern`.

    Replaces the local function node handled by
    :class:`experimental_experiment.xoptim.patterns.onnx_rotary.FunctionHalfRotaryEmbeddingPattern`
    by the contrib operator ``com.microsoft.RotaryEmbedding`` when its cos and sin
    inputs are both ``Concat([x, x], axis=-1)``. Two layouts are recognized:

    * ``Split(X, axis=-1) -> rotary(first part) -> Concat(rotated, second part, axis=-1)``,
      the first split size becomes attribute ``rotary_embedding_dim``,
    * a single rotary node whose cos/sin caches cover the whole last dimension
      of its input.

    If cos and sin are computed by a ``CosSinCache*`` function node, the cache is
    recomputed for positions ``0..max(position_ids)`` and the original position ids
    are given to ``RotaryEmbedding``. Otherwise, position ids ``0..seq_length-1``
    are built from dimension 2 of the input.
    """

    _operator_name = FunctionHalfRotaryEmbeddingPattern._operator_name
    _domain_name = FunctionHalfRotaryEmbeddingPattern._domain_name

    def __init__(self, verbose: int = 0, priority: int = 2):
        super().__init__(verbose, priority)
        # (tensor name, shape) pairs collected while matching, only used to
        # produce a more informative assertion message in apply.
        self._info = []

    @staticmethod
    def _int_attribute(node: NodeProto, att_name: str, default: int) -> int:
        """Returns the integer attribute *att_name* of *node*, *default* if missing."""
        for att in node.attribute:
            if att.name == att_name:
                return att.i
        return default

    def _match_duplicated_concat(
        self, g: "GraphBuilderPatternOptimization", name: str  # noqa: F821
    ) -> Optional[NodeProto]:
        """
        Returns the node producing *name* if it is ``Concat([x, x], axis=-1)``
        in the default domain, None otherwise.
        """
        concat = g.node_before(name)
        if (
            concat is None
            or concat.op_type != "Concat"
            or concat.domain != ""
            or len(concat.input) != 2
            or concat.input[0] != concat.input[1]
            or g.get_attribute(concat, "axis").i != -1
        ):
            return None
        return concat

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Checks the rotary node and its cos/sin inputs, then dispatches to the
        path with a ``Split`` before the rotary node or to the path without one.
        """
        if node.op_type != self._operator_name or node.domain != self._domain_name:
            return self.none()
        if not g.has_shape(node.input[0]) or g.get_rank(node.input[0]) != 4:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.has_shape(node.input[1]) or not g.has_shape(node.input[2]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape_cos = g.get_shape(node.input[1])
        shape_sin = g.get_shape(node.input[2])
        if shape_cos != shape_sin:
            return self.none(node, inspect.currentframe().f_lineno)
        if len(shape_cos) != 4:
            return self.none(node, inspect.currentframe().f_lineno)
        if shape_cos[1] != 1:
            # shape_sin == shape_cos, checking one of them is enough.
            return self.none(node, inspect.currentframe().f_lineno)
        # The first dimension is not required to be 1: position_ids may be
        # involved in the computation of the cos/sin caches.

        concat_cos = self._match_duplicated_concat(g, node.input[1])
        if concat_cos is None:
            return self.none(node, inspect.currentframe().f_lineno)
        concat_sin = self._match_duplicated_concat(g, node.input[2])
        if concat_sin is None:
            return self.none(node, inspect.currentframe().f_lineno)

        if g.is_used_more_than_once(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)

        split_node = g.node_before(node.input[0])
        if split_node is None or split_node.op_type != "Split" or split_node.domain != "":
            return self._match_without_split(g, node, concat_cos, concat_sin)
        return self._match_with_split(g, node, split_node, concat_cos, concat_sin)

    def _match_without_split(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        concat_cos: NodeProto,
        concat_sin: NodeProto,
    ) -> Optional[MatchResult]:
        """
        Path with no split before and no concat after: the rotary node must
        rotate the whole last dimension, i.e. ``cos[-1] + sin[-1] == X.shape[-1]``.
        """
        if not g.has_shape(concat_cos.input[0]) or not g.has_shape(concat_sin.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        cos_shape = g.get_shape(concat_cos.input[0])
        sin_shape = g.get_shape(concat_sin.input[0])
        input_shape = g.get_shape(node.input[0])
        if not g.builder.evaluate_dimension_equality_with_constraints(
            input_shape[-1], cos_shape[-1], "+", sin_shape[-1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        self._info.append((node.input[0], input_shape))
        if not isinstance(input_shape[1], int):
            # Number of heads is not fixed.
            return self.none(node, inspect.currentframe().f_lineno)
        # No split before, no concat after but there could be still position ids.
        return self._match_last_part(
            g,
            concat_cos,
            concat_sin,
            None,
            node,
            None,
            comment="path with no split before, no concat after",
        )

    def _match_with_split(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        split_node: NodeProto,
        concat_cos: NodeProto,
        concat_sin: NodeProto,
    ) -> Optional[MatchResult]:
        """
        Path ``Split(X, [a, b], axis=-1) -> rotary(first part) ->
        Concat(rotated, second part, axis=-1)`` (partial rotary embedding).
        """
        if not g.has_shape(split_node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        shape_input = g.get_shape(split_node.input[0])
        if not isinstance(shape_input[1], int):
            # Not a fixed number of heads.
            return self.none(node, inspect.currentframe().f_lineno)
        # Only a split along the last axis is a partial rotary embedding.
        # ONNX Split uses axis=0 when the attribute is missing.
        axis = self._int_attribute(split_node, "axis", 0)
        if axis not in (-1, len(shape_input) - 1):
            return self.none(node, inspect.currentframe().f_lineno)
        # A Split relying on attribute num_outputs has no second input.
        if len(split_node.input) < 2 or not g.is_constant(split_node.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        cst = g.get_computed_constant(split_node.input[1])
        if cst is None or cst.shape != (2,):
            return self.none(node, inspect.currentframe().f_lineno)
        if split_node.output[0] != node.input[0]:
            # The rotation must apply to the first part of the split.
            return self.none(node, inspect.currentframe().f_lineno)

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        concat_node = next_nodes[0]
        if concat_node.op_type != "Concat" or concat_node.domain != "":
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            len(concat_node.input) != 2
            or concat_node.input[0] != node.output[0]
            or concat_node.input[1] != split_node.output[1]
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        if g.get_attribute(concat_node, "axis").i != -1:
            return self.none(node, inspect.currentframe().f_lineno)
        if g.is_used_more_than_once(split_node.output[1]):
            # split_node is removed, its second output may only feed concat_node.
            return self.none(node, inspect.currentframe().f_lineno)

        self._info.append((split_node.input[0], shape_input))
        return self._match_last_part(
            g,
            concat_cos,
            concat_sin,
            split_node,
            node,
            concat_node,
            comment="path with split before, concat after",
        )

    def _match_last_part(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        concat_cos: NodeProto,
        concat_sin: NodeProto,
        split_node: Optional[NodeProto],
        node: NodeProto,
        concat_node: Optional[NodeProto],
        comment: str,
    ) -> Optional[MatchResult]:
        """
        Final step: checks whether cos/sin come from a CosSinCache node, in which
        case the original position_ids are reused, or not, in which case apply
        builds default position ids.
        """
        common = self._find_common_ancestor(g, concat_cos, concat_sin)
        if common is not None and not common:
            # cos/sin are switched. The pattern cannot match.
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            common
            and common[0].op_type == "Mul"
            and {"Sin", "Cos"} & {n.op_type for n in common}
        ):
            # pattern FunctionCosSinCache has yet to be triggered first.
            return self.none(node, inspect.currentframe().f_lineno)
        if (
            common
            and common[0].op_type.startswith("CosSinCache")
            and common[0].domain == self._domain_name
        ):
            # Checks that position_ids exists and how the weights are given.
            cos_sin = common[0]
            if not g.has_shape(cos_sin.input[0]) or not g.has_shape(cos_sin.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            expand_node = g.node_before(cos_sin.input[1])
            if (
                expand_node is None
                or expand_node.op_type != "Expand"
                or expand_node.domain != ""
            ):
                # apply uses expand_node.input[0] as the unexpanded weights,
                # this only holds for an Expand node.
                return self.none(node, inspect.currentframe().f_lineno)
            shape_expand = g.builder.value_as_shape(expand_node.input[1])
            if shape_expand is None or len(shape_expand) != 3 or shape_expand[1:] != (1, 1):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.has_shape(expand_node.input[0]):
                return self.none(node, inspect.currentframe().f_lineno)
            if g.get_shape(expand_node.input[0])[0] != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            position_ids_shape = g.get_shape_renamed(cos_sin.input[0])
            weights_shape = g.get_shape_renamed(cos_sin.input[1])
            if (
                len(position_ids_shape) != 2
                or len(weights_shape) != 3
                or position_ids_shape[0] != weights_shape[0]
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            # The nodes computing cos/sin are part of the match.
            return MatchResult(
                self,
                [expand_node, concat_cos, concat_sin, split_node, node, concat_node, *common],
                self.apply,
                comment=f"{comment} / with CosSinCache",
            )
        # NOTE(review): apply squeezes axes 0 and 1 of the cos/sin caches on this
        # path, it assumes their first dimension is 1 at runtime.
        return MatchResult(
            self,
            [None, concat_cos, concat_sin, split_node, node, concat_node],
            self.apply,
            insert_at=None if g.is_used_more_than_once(concat_cos.output[0]) else concat_node,
            comment=f"{comment} / without CosSinCache",
        )

    def _find_common_ancestor(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        concat_cos: NodeProto,
        concat_sin: NodeProto,
    ) -> Optional[List[NodeProto]]:
        """
        Walks up at most 5 levels from *concat_cos* and *concat_sin* following
        the first input of every node until it finds one node with two outputs
        producing both branches.

        :return: None if no such node is found, an empty list if the node is
            found but its outputs are used in the wrong order (sin, cos),
            ``[common, cos_k, sin_k, ..., cos_1, sin_1]`` otherwise, where
            ``cos_1``/``sin_1`` feed the concat nodes and every pair is
            ordered (cos, sin) as expected by apply
        """
        anc_cos, anc_sin = concat_cos, concat_sin
        pairs = []
        for _it in range(5):
            if not anc_cos.input or not anc_sin.input:
                return None
            cos_name, sin_name = anc_cos.input[0], anc_sin.input[0]
            anc_cos = g.node_before(cos_name)
            anc_sin = g.node_before(sin_name)
            if anc_cos is None or anc_sin is None:
                return None
            if anc_cos is anc_sin and len(anc_cos.output) == 2:
                if cos_name != anc_cos.output[0] or sin_name != anc_cos.output[1]:
                    # cos/sin were switched, the pattern should not match at all.
                    return []
                result = [anc_cos]
                # From the common ancestor down to the concat nodes.
                for pcos, psin in reversed(pairs):
                    result.extend([pcos, psin])
                return result
            pairs.append((anc_cos, anc_sin))
        return None

    def _cos_sin_from_cache(
        self,
        g: "GraphBuilder",  # noqa: F821
        expand_node: Optional[NodeProto],
        prefix_nodes: Sequence[NodeProto],
        one_no_dim: str,
        zero_no_dim: str,
    ):
        """
        Recomputes the CosSinCache for positions ``0..max(position_ids)`` and
        replays the element-wise nodes found between the cache and the concat nodes.

        :return: (position_ids, cos name, sin name, new nodes)
        """
        cls_name = self.__class__.__name__
        assert expand_node is not None, "expand node is missing, pattern should not match"
        assert prefix_nodes[0].op_type.startswith(
            "CosSinCache"
        ), f"Unexpected first node {prefix_nodes[0]}"
        cos_sin = prefix_nodes[0]
        position_ids = cos_sin.input[0]
        (max_ids, max_ids_1, new_positions_ids, cos_out, sin_out, range_ids) = [
            g.unique_name(f"{cls_name}--{position_ids}") for _ in range(6)
        ]
        zero = g.make_initializer("", g.ZERO, source=f"{cls_name}.0")
        nodes = [
            g._make_node("ReduceMax", [position_ids], [max_ids], keepdims=0),
            g._make_node("Add", [max_ids, one_no_dim], [max_ids_1]),
            g._make_node("Range", [zero_no_dim, max_ids_1, one_no_dim], [range_ids]),
            g._make_node("Unsqueeze", [range_ids, zero], [new_positions_ids]),
            g._make_node(
                cos_sin.op_type,
                [new_positions_ids, expand_node.input[0]],
                [cos_out, sin_out],
                domain=cos_sin.domain,
            ),
        ]
        cos_cur, sin_cur = cos_out, sin_out
        for i in range(1, len(prefix_nodes), 2):
            ncos, nsin = prefix_nodes[i : i + 2]
            if ncos.op_type == "Concat":
                break
            rcos = g.unique_name(f"{cls_name}--{position_ids}")
            rsin = g.unique_name(f"{cls_name}--{position_ids}")
            # Each branch replays its own node (type, domain, extra inputs, attributes).
            new_cos = g._make_node(
                ncos.op_type, [cos_cur, *ncos.input[1:]], [rcos], domain=ncos.domain
            )
            new_sin = g._make_node(
                nsin.op_type, [sin_cur, *nsin.input[1:]], [rsin], domain=nsin.domain
            )
            new_cos.attribute.extend(ncos.attribute)
            new_sin.attribute.extend(nsin.attribute)
            nodes.extend([new_cos, new_sin])
            cos_cur, sin_cur = rcos, rsin
        return position_ids, cos_cur, sin_cur, nodes

    def _default_position_ids(
        self,
        g: "GraphBuilder",  # noqa: F821
        half_node: NodeProto,
        main_input: str,
        one_no_dim: str,
        zero_no_dim: str,
    ):
        """
        Builds position ids ``Expand(Range(0, seq_length), [batch, 1])`` where
        batch and seq_length are dimensions 0 and 2 of *main_input*.

        :return: (position_ids name, new nodes)
        """
        cls_name = self.__class__.__name__
        batch_name = g.unique_name(f"{cls_name}--{half_node.input[0]}--batch")
        one = g.make_initializer("", g.ONE, source=f"{cls_name}.1")
        position_ids = g.unique_name(f"{cls_name}--{half_node.input[0]}_position_ids")
        seq_length = g.unique_name(f"{cls_name}--{half_node.input[0]}--seq_length")
        seq_length_squeezed = g.unique_name(f"{cls_name}--{half_node.input[0]}--seqsq")
        exp_shape = g.unique_name(f"{cls_name}--{half_node.input[0]}_pshape")
        flat_pids = g.unique_name(f"{cls_name}--{half_node.input[0]}_flat_pids")
        nodes = [
            g._make_node("Shape", [main_input], [batch_name], start=0, end=1),
            g._make_node("Shape", [main_input], [seq_length], start=2, end=3),
            g._make_node("Squeeze", [seq_length], [seq_length_squeezed]),
            g._make_node("Range", [zero_no_dim, seq_length_squeezed, one_no_dim], [flat_pids]),
            g._make_node("Concat", [batch_name, one], [exp_shape], axis=0),
            g._make_node("Expand", [flat_pids, exp_shape], [position_ids]),
        ]
        return position_ids, nodes

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        expand_node: Optional[NodeProto],
        concat_cos: NodeProto,
        concat_sin: NodeProto,
        split_node: Optional[NodeProto],
        half_node: NodeProto,
        concat_node: Optional[NodeProto],
        *prefix_nodes: NodeProto,
    ) -> List[NodeProto]:
        """
        Builds the ``com.microsoft.RotaryEmbedding`` node and the nodes it needs.

        :param expand_node: Expand node feeding the CosSinCache node, None when
            cos/sin do not come from a CosSinCache node
        :param concat_cos: ``Concat([cos, cos], axis=-1)``
        :param concat_sin: ``Concat([sin, sin], axis=-1)``
        :param split_node: Split node before the rotary node, None if there is none
        :param half_node: the rotary local function node
        :param concat_node: Concat node after the rotary node, None if no split
        :param prefix_nodes: ``[CosSinCache, cos_k, sin_k, ..., cos_1, sin_1]``
            as returned by ``_find_common_ancestor``, empty without CosSinCache
        :return: the nodes replacing the matched ones
        """
        cls_name = self.__class__.__name__
        if split_node is None:
            # The rotation covers the whole last dimension.
            rotary_dim = None
            main_input = half_node.input[0]
            main_output = half_node.output[0]
        else:
            # Partial rotation: only the first part of the split is rotated.
            cst = g.get_computed_constant(split_node.input[1])
            rotary_dim = int(cst[0])
            main_input = split_node.input[0]
            main_output = concat_node.output[0]
        shape = g.get_shape(main_input)
        assert isinstance(shape[1], int), (
            f"Number of heads is not fixed, shape({main_input})={shape}, info={self._info}"
        )
        num_heads = shape[1]

        # Nodes still used elsewhere must be kept with everything they depend on.
        used_twice_cos = g.is_used_more_than_once(concat_cos.output[0])
        used_twice_sin = g.is_used_more_than_once(concat_sin.output[0])
        rotary_nodes = [expand_node, *prefix_nodes] if used_twice_cos or used_twice_sin else []
        if used_twice_cos:
            rotary_nodes.append(concat_cos)
        if used_twice_sin:
            rotary_nodes.append(concat_sin)

        zeroone = g.make_initializer(
            "", np.array([0, 1], dtype=np.int64), source=f"{cls_name}.01"
        )
        one_no_dim = g.make_initializer("", g.ONE_NO_DIM, source=f"{cls_name}.1d")
        zero_no_dim = g.make_initializer("", g.ZERO_NO_DIM, source=f"{cls_name}.0d")

        if prefix_nodes:
            position_ids, cos_input, sin_input, added_nodes = self._cos_sin_from_cache(
                g, expand_node, prefix_nodes, one_no_dim, zero_no_dim
            )
            range_nodes = []
        else:
            assert expand_node is None, f"Unexpected expand node {expand_node}"
            position_ids, range_nodes = self._default_position_ids(
                g, half_node, main_input, one_no_dim, zero_no_dim
            )
            cos_input, sin_input = concat_cos.input[0], concat_sin.input[0]
            added_nodes = []

        # RotaryEmbedding expects 2D caches: axes 0 and 1 are squeezed.
        cos_name = g.unique_name(f"{cls_name}--{half_node.input[1]}")
        sin_name = g.unique_name(f"{cls_name}--{half_node.input[2]}")
        rotary_nodes.extend(
            [
                *added_nodes,
                g._make_node("Squeeze", [cos_input, zeroone], [cos_name]),
                g._make_node("Squeeze", [sin_input, zeroone], [sin_name]),
                *range_nodes,
            ]
        )
        # expand_node is None when there is no CosSinCache node.
        rotary_nodes = [n for n in rotary_nodes if n is not None]
        for n in rotary_nodes:
            if not n.name:
                n.name = g.builder.unique_node_name(f"{cls_name}--{half_node.name}")

        kwargs = {} if rotary_dim is None else {"rotary_embedding_dim": rotary_dim}
        rotary_node = g.make_node(
            "RotaryEmbedding",
            [main_input, position_ids, cos_name, sin_name],
            [main_output],
            name=f"{cls_name}--{half_node.name}",
            num_heads=num_heads,
            domain="com.microsoft",
            **kwargs,
        )
        rotary_nodes.append(rotary_node)
        return rotary_nodes
class ContribRotaryEmbedding3DPattern(PatternOptimization):
    """
    Extension to
    :class:`experimental_experiment.xoptim.patterns_ort.llm_optim.ContribRotaryEmbeddingPattern`,
    turn the operator into a 3D operator including the transpose.

    ``RotaryEmbedding(Transpose(X, perm=[0, 2, 1, 3]))`` is rewritten into
    ``Transpose(Reshape(RotaryEmbedding(Reshape(X, [0, 0, -1])), [0, 0, -1, X.shape[3]]),
    perm=[0, 2, 1, 3])``.
    """

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches a ``com.microsoft.RotaryEmbedding`` node whose first input is
        only consumed by it and produced by ``Transpose(perm=[0, 2, 1, 3])``.
        """
        is_contrib_rotary = node.op_type == "RotaryEmbedding" and node.domain == "com.microsoft"
        if not is_contrib_rotary:
            return self.none()

        before = g.node_before(node.input[0])
        is_plain_transpose = (
            before is not None and before.op_type == "Transpose" and before.domain == ""
        )
        if not is_plain_transpose:
            return self.none(node, inspect.currentframe().f_lineno)
        if tuple(g.get_attribute(before, "perm").ints) != (0, 2, 1, 3):
            return self.none(node, inspect.currentframe().f_lineno)
        # The transposed tensor disappears, nothing else may consume it.
        if g.is_used_more_than_once(node.input[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, [before, node], self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        transpose: NodeProto,
        rotary: NodeProto,
    ) -> List[NodeProto]:
        """
        Merges the last two dimensions of the transpose input, applies the
        rotary embedding on the 3D tensor, restores the 4D shape using dimension 3
        of the transpose input and finally transposes the result.
        """
        src_name = transpose.input[0]
        cls_name = self.__class__.__name__

        head_dim = g.unique_name(f"{src_name}::Shape3")
        shape_4d = g.unique_name(f"{src_name}::Shape+1")
        shape_3d = g.make_initializer(
            "", np.array([0, 0, -1], dtype=np.int64), source=f"{cls_name}.00_1"
        )
        flat_input = g.unique_name(f"{src_name}::3D")
        flat_rotated = g.unique_name(f"{src_name}::3Dr")
        rotated_4d = g.unique_name(f"{src_name}::4D")

        flatten = g._make_node("Reshape", [src_name, shape_3d], [flat_input])
        rotary_3d = g._make_node(
            rotary.op_type,
            [flat_input, *rotary.input[1:]],
            [flat_rotated],
            domain=rotary.domain,
        )
        last_dim = g._make_node("Shape", [src_name], [head_dim], start=3)
        full_shape = g._make_node("Concat", [shape_3d, head_dim], [shape_4d], axis=0)
        unflatten = g._make_node("Reshape", [flat_rotated, shape_4d], [rotated_4d])
        back = g._make_node(
            "Transpose", [rotated_4d], [rotary.output[0]], perm=[0, 2, 1, 3]
        )
        if rotary.attribute:
            # Keeps num_heads, rotary_embedding_dim, ... of the original node.
            rotary_3d.attribute.extend(rotary.attribute)

        new_nodes = [flatten, rotary_3d, last_dim, full_shape, unflatten, back]
        for created in new_nodes:
            if not created.name:
                created.name = g.builder.unique_node_name(f"{cls_name}--{rotary.name}")
        return new_nodes
class MultiHeadAttention3DPattern(PatternOptimization):
    """
    Merges multiple nodes into MultiHeadAttention. It assumes pattern
    :class:`experimental_experiment.xoptim.patterns.onnx_attention.FunctionAttentionPattern`
    was triggered before.

    The matched subgraph is made of:

    * the attention function node with five inputs
      ``(query, keys, values, mask, scale)``, ``scale`` being a constant scalar,
    * ``Transpose(perm=[0, 2, 1, 3])`` producing the query,
    * ``Concat([past, Transpose(perm=[0, 2, 1, 3])(new)], axis=-2)`` producing
      the keys and the values,
    * a single ``Transpose(perm=[0, 2, 1, 3])`` consuming the attention output.

    It is replaced by ``com.microsoft.MultiHeadAttention`` receiving the query,
    keys and values with their last two dimensions merged, an attention bias
    ``Where(mask, 0, -inf)`` and the past keys/values; the outputs of the two
    Concat nodes become the present keys/values produced by MultiHeadAttention.
    """

    _prefix_operator_name = f"{FunctionAttentionPattern._operator_name}_to"

    @staticmethod
    def _is_transpose_0213(
        g: "GraphBuilderPatternOptimization", node: Optional[NodeProto]  # noqa: F821
    ) -> bool:
        """Tells if *node* is a standard ``Transpose`` with ``perm=[0, 2, 1, 3]``."""
        return (
            node is not None
            and node.op_type == "Transpose"
            and node.domain == ""
            and tuple(g.get_attribute(node, "perm").ints) == (0, 2, 1, 3)
        )

    @staticmethod
    def _is_past_concat(
        g: "GraphBuilderPatternOptimization", node: Optional[NodeProto]  # noqa: F821
    ) -> bool:
        """Tells if *node* is a standard ``Concat`` of two inputs along axis -2."""
        return (
            node is not None
            and node.op_type == "Concat"
            and node.domain == ""
            and len(node.input) == 2
            and g.get_attribute(node, "axis").i == -2
        )

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """Matches the subgraph described in the class documentation."""
        if (
            not node.op_type.startswith(self._prefix_operator_name)
            or node.domain != FunctionAttentionPattern._domain_name
            or len(node.input) != 5
        ):
            return self.none()
        if not g.is_constant_scalar(node.input[4]):
            return self.none(node, inspect.currentframe().f_lineno)

        q_transpose = g.node_before(node.input[0])
        if not self._is_transpose_0213(g, q_transpose):
            return self.none(node, inspect.currentframe().f_lineno)
        query = q_transpose.input[0]
        if not g.has_shape(query) or g.get_rank(query) != 4:
            return self.none(node, inspect.currentframe().f_lineno)
        q_shape = g.get_shape(query)
        num_heads = q_shape[2]
        if not isinstance(num_heads, int) or not isinstance(q_shape[-1], int):
            # apply needs a static number of heads and a static head size.
            return self.none(node, inspect.currentframe().f_lineno)

        kv_nodes = []
        for name in node.input[1:3]:
            concat = g.node_before(name)
            if not self._is_past_concat(g, concat):
                return self.none(node, inspect.currentframe().f_lineno)
            kv_transpose = g.node_before(concat.input[1])
            if not self._is_transpose_0213(g, kv_transpose):
                return self.none(node, inspect.currentframe().f_lineno)
            if g.has_shape(kv_transpose.input[0]):
                kv_heads = g.get_shape(kv_transpose.input[0])[2]
                if isinstance(kv_heads, int) and kv_heads != num_heads:
                    # MatMul broadcasting lets the attention function accept fewer
                    # heads for keys/values, MultiHeadAttention does not.
                    return self.none(node, inspect.currentframe().f_lineno)
            kv_nodes.append((kv_transpose, concat))
        (k_transpose, k_concat), (v_transpose, v_concat) = kv_nodes

        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 1 or not self._is_transpose_0213(g, next_nodes[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        transpose = next_nodes[0]

        # These intermediate results disappear after the rewriting.
        for n in [q_transpose, k_transpose, v_transpose, node]:
            if g.is_used_more_than_once(n.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(
            self,
            [q_transpose, k_transpose, k_concat, v_transpose, v_concat, node, transpose],
            self.apply,
        )

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        q_transpose: NodeProto,
        k_transpose: NodeProto,
        k_concat: NodeProto,
        v_transpose: NodeProto,
        v_concat: NodeProto,
        attention: NodeProto,
        transpose: NodeProto,
    ) -> List[NodeProto]:
        """Builds the MultiHeadAttention node and the reshapes around it."""
        query = q_transpose.input[0]
        keys = k_transpose.input[0]
        values = v_transpose.input[0]
        mask = attention.input[3]
        past_keys = k_concat.input[0]
        past_values = v_concat.input[0]
        num_heads = g.get_shape(query)[2]
        # The scale input is squared to obtain the MultiHeadAttention scale.
        scale = float(g.get_constant_scalar(attention.input[4])) ** 2

        dtype = tensor_dtype_to_np_dtype(g.get_type(query))
        zero = g.make_initializer(
            "", np.array([0], dtype=dtype), source=f"{self.__class__.__name__}.0"
        )
        minfty = g.make_initializer(
            "", np.array([-np.inf], dtype=dtype), source=f"{self.__class__.__name__}._inf"
        )
        init_00_1 = g.make_initializer(
            "",
            np.array([0, 0, -1], dtype=np.int64),
            source=f"{self.__class__.__name__}.00_1",
        )
        last = g.get_shape(query)[-1]
        init_00_1l = g.make_initializer(
            "",
            np.array([0, 0, -1, last], dtype=np.int64),
            source=f"{self.__class__.__name__}.00_1l",
        )

        r_query = g.unique_name(f"{self.__class__.__name__}--{query}")
        r_keys = g.unique_name(f"{self.__class__.__name__}--{keys}")
        r_values = g.unique_name(f"{self.__class__.__name__}--{values}")
        attention_bias = g.unique_name(f"{self.__class__.__name__}--{mask}")
        r_output = g.unique_name(f"{self.__class__.__name__}--{transpose.output[0]}")

        nodes = [
            # Merges the last two dimensions (heads, head size).
            g._make_node("Reshape", [query, init_00_1], [r_query]),
            g._make_node("Reshape", [keys, init_00_1], [r_keys]),
            g._make_node("Reshape", [values, init_00_1], [r_values]),
            # Turns the mask into an additive bias: 0 where True, -inf elsewhere.
            g._make_node("Where", [mask, zero, minfty], [attention_bias]),
            g._make_node(
                "MultiHeadAttention",
                [r_query, r_keys, r_values, "", "", attention_bias, past_keys, past_values],
                [r_output, k_concat.output[0], v_concat.output[0]],
                num_heads=num_heads,
                scale=scale,
                domain="com.microsoft",
            ),
            # Splits the last dimension back into (heads, head size).
            g._make_node("Reshape", [r_output, init_00_1l], [transpose.output[0]]),
        ]
        for node in nodes:
            if not node.name:
                node.name = g.builder.unique_node_name(
                    f"{self.__class__.__name__}--{attention.name}"
                )
        return nodes