Source code for experimental_experiment.xoptim.patterns.onnx_attention

import inspect
from typing import List, Optional, Tuple, Union
import numpy as np
from onnx import NodeProto, TensorProto
from ...helpers import tensor_dtype_to_np_dtype
from ...xbuilder import FunctionOptions, GraphBuilder
from ..patterns_api import MatchResult, PatternOptimization


class FunctionAttentionPattern(PatternOptimization):
    """
    Merges Attention nodes into a local function.
    That includes a version for GroupQueryAttention
    (see second pattern).

    The pattern is anchored on a ``Softmax(axis=-1)`` node and walks the graph
    backward (mask, ``MatMul(Q, K^T)``, scaling ``Mul`` nodes, optional
    repeat-interleave of the keys) and forward (``IsNaN``/``Where`` filter,
    final ``MatMul`` with the values).

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "values", onnx.TensorProto.FLOAT, shape=("av", "bv", "cv", "dv")
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "keys", onnx.TensorProto.FLOAT, shape=("ak", "bk", "ck", "dk")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("scale_sqrt", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "mask", onnx.TensorProto.BOOL, shape=("am", "bm", "cm", "dm")
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("aq", "bq", "cq", "dq")
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["scale_sqrt"],
                value=onh.from_array(
                    np.array([0.3162277638912201], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["zero"],
                value=onh.from_array(np.array([0.0], dtype=np.float32), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["minfty"],
                value=onh.from_array(np.array([-np.inf], dtype=np.float32), name="value"),
            )
        )
        nodes.append(oh.make_node("Mul", ["query", "scale_sqrt"], ["query_scaled"]))
        nodes.append(oh.make_node("Mul", ["keys", "scale_sqrt"], ["keys_scaled"]))
        nodes.append(
            oh.make_node(
                "Transpose", ["keys_scaled"], ["keys_scaled_t"], perm=[0, 1, 3, 2]
            )
        )
        nodes.append(oh.make_node("MatMul", ["query_scaled", "keys_scaled_t"], ["qk"]))
        nodes.append(oh.make_node("Where", ["mask", "zero", "minfty"], ["bias"]))
        nodes.append(oh.make_node("Add", ["qk", "bias"], ["qkb"]))
        nodes.append(oh.make_node("Softmax", ["qkb"], ["qkbs"], axis=-1))
        nodes.append(oh.make_node("IsNaN", ["qkbs"], ["nans"]))
        nodes.append(oh.make_node("Where", ["nans", "zero", "qkbs"], ["filt"]))
        nodes.append(oh.make_node("MatMul", ["filt", "values"], ["Y"]))
        outputs.append(
            oh.make_tensor_value_info(
                "Y", onnx.TensorProto.FLOAT, shape=("ay", "by", "cy", "dy")
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "values", onnx.TensorProto.FLOAT, shape=("av", "bv", "cv", "dv")
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "keys", onnx.TensorProto.FLOAT, shape=("ak", "bk", "ck", "dk")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("scale_sqrt", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "mask", onnx.TensorProto.BOOL, shape=("am", "bm", "cm", "dm")
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("aq", "bq", "cq", "dq")
            )
        )
        nodes.append(
            oh.make_node(
                "LocalAttention_to1",
                ["query", "keys", "values", "mask", "scale_sqrt"],
                ["Y"],
                domain="intermediate",
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "Y", onnx.TensorProto.FLOAT, shape=("ay", "by", "cy", "dy")
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    GroupQueryAttention (GQA):

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("init1_s_::RSh1", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat_1",
                onnx.TensorProto.FLOAT,
                shape=("batch", 4, "past_length+seq_length", 32),
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat", onnx.TensorProto.FLOAT, shape=("batch", 4, "past_length+seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "to", onnx.TensorProto.BOOL, shape=("seq_length", "total_length")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_0_8_-1_32", onnx.TensorProto.INT64, shape=(4,))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s5_1_1_2_1_1", onnx.TensorProto.INT64, shape=(5,))
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init1_s_::RSh1"],
                value=onh.from_array(
                    np.array([0.4204482138156891], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_2"],
                value=onh.from_array(np.array([2], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init1_s_::RSh12"],
                value=onh.from_array(
                    np.array([0.4204482138156891], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s5_1_1_2_1_1"],
                value=onh.from_array(np.array([1, 1, 2, 1, 1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s4_0_8_-1_32"],
                value=onh.from_array(np.array([0, 8, -1, 32], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init1_s1_"],
                value=onh.from_array(np.array([-np.inf], dtype=np.float32), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["c_lifted_tensor_0"],
                value=onh.from_array(np.array(0.0, dtype=np.float32), name="value"),
            )
        )
        nodes.append(oh.make_node("Mul", ["query", "init1_s_::RSh1"], ["_onx_mul_query"]))
        nodes.append(oh.make_node("Unsqueeze", ["cat", "init7_s1_2"], ["cat::UnSq2"]))
        nodes.append(
            oh.make_node(
                "Mul",
                ["cat::UnSq2", "init1_s_::RSh12"],
                ["ShapeBasedExpandSwapPattern_SwapUnaryPattern--repeat_interleave_1"],
            )
        )
        nodes.append(
            oh.make_node(
                "Expand",
                [
                    "ShapeBasedExpandSwapPattern_SwapUnaryPattern--repeat_interleave_1",
                    "init7_s5_1_1_2_1_1",
                ],
                ["SwapUnaryPattern--repeat_interleave_1"],
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape",
                ["SwapUnaryPattern--repeat_interleave_1", "init7_s4_0_8_-1_32"],
                ["SwapUnaryPattern--transpose"],
            )
        )
        nodes.append(
            oh.make_node(
                "Transpose",
                ["SwapUnaryPattern--transpose"],
                ["_onx_mul_transpose"],
                perm=[0, 1, 3, 2],
            )
        )
        nodes.append(
            oh.make_node("MatMul", ["_onx_mul_query", "_onx_mul_transpose"], ["matmul"])
        )
        nodes.append(
            oh.make_node("Where", ["to", "init1_s1_", "matmul"], ["masked_fill"])
        )
        nodes.append(oh.make_node("Softmax", ["masked_fill"], ["softmax"], axis=-1))
        nodes.append(oh.make_node("IsNaN", ["softmax"], ["isnan"]))
        nodes.append(
            oh.make_node("Where", ["isnan", "c_lifted_tensor_0", "softmax"], ["where"])
        )
        nodes.append(oh.make_node("Unsqueeze", ["cat_1", "init7_s1_2"], ["cat_1::UnSq2"]))
        nodes.append(
            oh.make_node(
                "Expand", ["cat_1::UnSq2", "init7_s5_1_1_2_1_1"], ["_onx_expand_cat_1::UnSq2"]
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape",
                ["_onx_expand_cat_1::UnSq2", "init7_s4_0_8_-1_32"],
                ["repeat_interleave"],
            )
        )
        nodes.append(oh.make_node("MatMul", ["where", "repeat_interleave"], ["output_0"]))
        outputs.append(
            oh.make_tensor_value_info(
                "output_0", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("init1_s_::RSh1", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat_1",
                onnx.TensorProto.FLOAT,
                shape=("batch", 4, "past_length+seq_length", 32),
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat", onnx.TensorProto.FLOAT, shape=("batch", 4, "past_length+seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "to", onnx.TensorProto.BOOL, shape=("seq_length", "total_length")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_0_8_-1_32", onnx.TensorProto.INT64, shape=(4,))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s5_1_1_2_1_1", onnx.TensorProto.INT64, shape=(5,))
        )
        nodes.append(
            oh.make_node(
                "LocalAttentionGQASW_to1",
                [
                    "query",
                    "cat",
                    "cat_1",
                    "to",
                    "init1_s_::RSh1",
                    "init7_s5_1_1_2_1_1",
                    "init7_s4_0_8_-1_32",
                ],
                ["output_0"],
                domain="intermediate",
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "output_0", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    # Prefix of the generated local function name and the domain it lives in.
    _operator_name = "LocalAttention"
    _domain_name = "intermediate"

    def __init__(self, verbose: int = 0, priority: int = 0):
        """Forwards *verbose* and *priority* to :class:`PatternOptimization`."""
        super().__init__(verbose, priority)

    def _find_index_inf(self, g, where_node):
        """
        Returns the index (1 or 2) of the first input of *where_node* which is
        a constant scalar equal to an infinite value, ``None`` if there is none.
        """
        for i in (1, 2):
            if g.is_constant_scalar(where_node.input[i]):
                cst = g.get_constant_scalar(where_node.input[i])
                if np.isinf(cst):
                    return i
        return None

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is the ``Softmax`` of an attention subgraph.

        :param g: graph being optimized
        :param node: candidate node, the pattern only starts from ``Softmax``
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` whose node list follows the order of
            the parameters of :meth:`apply` (``None`` for the optional nodes),
            or the value returned by ``self.none(...)`` if there is no match
        """
        if node.op_type != "Softmax" or node.domain != "" or g.main_opset < 18:
            return self.none()
        # Attribute axis is optional for Softmax, its default is -1 since opset 13
        # and this pattern requires opset >= 18.
        axis = g.get_attribute_with_default(node, "axis", -1)
        if axis != -1:
            return self.none(node, inspect.currentframe().f_lineno)

        # The input of Softmax may not be produced by any node
        # (the code below already expects node_before to return None sometimes).
        node_before = g.node_before(node.input[0])
        if node_before is None:
            return self.none(node, inspect.currentframe().f_lineno)

        if node_before.op_type == "Add":
            # Add(X, Where(mask, 0, -inf))
            add_node = node_before
            where_node = g.node_before(add_node.input[1])
            if where_node is None or where_node.op_type != "Where":
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant_scalar(where_node.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant_scalar(where_node.input[2]):
                return self.none(node, inspect.currentframe().f_lineno)
            cst_zero = g.get_constant_scalar(where_node.input[1])
            if cst_zero != 0:
                return self.none(node, inspect.currentframe().f_lineno)
            cst_inf = g.get_constant_scalar(where_node.input[2])
            # The local function always fills with -inf, +inf must be rejected
            # (same test as in the Where branch below).
            if not np.isinf(cst_inf) or cst_inf > 0:
                return self.none(node, inspect.currentframe().f_lineno)
            mat_qk = g.node_before(add_node.input[0])
            if mat_qk is None or mat_qk.op_type not in ("MatMul", "FusedMatMul"):
                return self.none(node, inspect.currentframe().f_lineno)
        elif node_before.op_type == "Where":
            # Where(mask, -inf, X) or Where(mask, X, -inf)
            add_node = None
            where_node = node_before
            if not g.is_constant_scalar(where_node.input[1]) and not g.is_constant_scalar(
                where_node.input[2]
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            inf_index = 1 if g.is_constant_scalar(where_node.input[1]) else 2
            cst_inf = g.get_constant_scalar(where_node.input[inf_index])
            if not np.isinf(cst_inf) or cst_inf > 0:
                return self.none(node, inspect.currentframe().f_lineno)
            # 3 - inf_index is the other data input of Where (2 <-> 1).
            mat_qk = g.node_before(where_node.input[3 - inf_index])
            if mat_qk is None or mat_qk.op_type not in ("MatMul", "FusedMatMul"):
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            return self.none(node, inspect.currentframe().f_lineno)

        # Left side of Q @ K^T: Mul(query, scale)
        mul1 = g.node_before(mat_qk.input[0])
        if mul1 is None or mul1.op_type != "Mul":
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(mul1.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)

        # Right side of Q @ K^T: the transposition is either explicit (MatMul)
        # or carried by the attributes of FusedMatMul.
        if mat_qk.op_type == "MatMul":
            transpose = g.node_before(mat_qk.input[1])
            if transpose is None or transpose.op_type != "Transpose":
                return self.none(node, inspect.currentframe().f_lineno)
            perm = g.get_attribute(transpose, "perm").ints
            if tuple(perm) != (0, 1, 3, 2):
                return self.none(node, inspect.currentframe().f_lineno)
            mul2 = g.node_before(transpose.input[0])
        else:
            # transA and transB both default to 0 for FusedMatMul: a missing
            # transB means the keys are *not* transposed.
            transA = g.get_attribute_with_default(mat_qk, "transA", 0)
            transB = g.get_attribute_with_default(mat_qk, "transB", 0)
            if transA != 0 or transB != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            # The local function has no equivalent for a scaling factor alpha.
            alpha = g.get_attribute_with_default(mat_qk, "alpha", 1.0)
            if alpha != 1:
                return self.none(node, inspect.currentframe().f_lineno)
            transpose = None
            mul2 = g.node_before(mat_qk.input[1])

        if mul2 is None:
            return self.none(node, inspect.currentframe().f_lineno)
        if mul2.op_type == "Mul":
            # This condition is verified for Attention or MultiHeadAttention.
            gqa_expand = gqa_reshape = gqa_unsqueeze = None
        elif mul2.op_type == "Reshape":
            # This condition is verified by GroupQueryAttention:
            # keys -> Unsqueeze -> Mul(scale) -> Expand -> Reshape
            gqa_reshape = mul2
            gqa_expand = g.node_before(gqa_reshape.input[0])
            if gqa_expand is None or gqa_expand.op_type != "Expand":
                return self.none(node, inspect.currentframe().f_lineno)
            mul2 = g.node_before(gqa_expand.input[0])
            if mul2 is None or mul2.op_type != "Mul":
                return self.none(node, inspect.currentframe().f_lineno)
            gqa_unsqueeze = g.node_before(mul2.input[0])
            if gqa_unsqueeze is None or gqa_unsqueeze.op_type != "Unsqueeze":
                return self.none(node, inspect.currentframe().f_lineno)

            # The constants driving the repetition must be known.
            if not g.is_constant(gqa_expand.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            exp_shape = g.get_computed_constant(gqa_expand.input[1])
            if tuple(exp_shape[:2]) != (1, 1) or tuple(exp_shape[3:]) != (1, 1):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant(gqa_unsqueeze.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            unsq_shape = g.get_computed_constant(gqa_unsqueeze.input[1])
            if tuple(unsq_shape) != (2,):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant(gqa_reshape.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            resh_shape = g.get_computed_constant(gqa_reshape.input[1])
            if resh_shape.size != 4:
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.has_shape(gqa_unsqueeze.input[0]) or not g.has_shape(
                gqa_reshape.output[0]
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            # Only dimension 1 is allowed to differ between
            # the original keys and the repeated keys.
            shape1 = g.get_shape_renamed(gqa_unsqueeze.input[0])
            shape2 = g.get_shape_renamed(gqa_reshape.output[0])
            if shape1[0] != shape2[0] or shape1[2] != shape2[2] or shape1[3] != shape2[3]:
                return self.none(
                    node,
                    inspect.currentframe().f_lineno,
                    msg=lambda: f"Shape mismatch {shape1=}, {shape2=}",
                )
        else:
            # No Attention, no MultiHeadAttention, no GroupQueryAttention
            return self.none(node, inspect.currentframe().f_lineno)

        # Query and keys must be scaled by the same constant,
        # the local function only receives one scale.
        if mul2.input[1] != mul1.input[1]:
            if not g.is_constant_scalar(mul1.input[1]) or not g.is_constant_scalar(
                mul2.input[1]
            ):
                return self.none(node, inspect.currentframe().f_lineno)
            cst1 = g.get_constant_scalar(mul1.input[1])
            cst2 = g.get_constant_scalar(mul2.input[1])
            if cst1 != cst2:
                return self.none(node, inspect.currentframe().f_lineno)

        # after softmax: Where(IsNaN(softmax), 0, softmax)
        next_nodes = g.next_nodes(node.output[0])
        if len(next_nodes) != 2:
            return self.none(node, inspect.currentframe().f_lineno)
        if {n.op_type for n in next_nodes} != {"Where", "IsNaN"}:
            return self.none(node, inspect.currentframe().f_lineno)
        isnan, where2 = next_nodes[:: (1 if next_nodes[0].op_type == "IsNaN" else -1)]
        if where2.input[0] != isnan.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        if where2.input[2] != node.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant_scalar(where2.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        cst = g.get_constant_scalar(where2.input[1])
        if cst != 0:
            return self.none(node, inspect.currentframe().f_lineno)

        mat_qkvs = g.next_nodes(where2.output[0])
        if len(mat_qkvs) != 1:
            return self.none(node, inspect.currentframe().f_lineno)
        mat_qkv = mat_qkvs[0]
        if mat_qkv.op_type != "MatMul":
            return self.none(node, inspect.currentframe().f_lineno)
        # The attention weights must be the left operand, apply takes
        # mat_qkv.input[1] as the values.
        if mat_qkv.input[0] != where2.output[0]:
            return self.none(node, inspect.currentframe().f_lineno)

        if gqa_reshape is not None:
            # We need to include the nodes repeating values,
            # the same one which repeated the keys.
            gqa_reshape_v = g.node_before(mat_qkv.input[1])
            if gqa_reshape_v is None or gqa_reshape_v.op_type != "Reshape":
                return self.none(node, inspect.currentframe().f_lineno)
            gqa_expand_v = g.node_before(gqa_reshape_v.input[0])
            if gqa_expand_v is None or gqa_expand_v.op_type != "Expand":
                return self.none(node, inspect.currentframe().f_lineno)
            gqa_unsqueeze_v = g.node_before(gqa_expand_v.input[0])
            if gqa_unsqueeze_v is None or gqa_unsqueeze_v.op_type != "Unsqueeze":
                return self.none(node, inspect.currentframe().f_lineno)

            # The constants on the values side must be known and equal
            # to the ones used for the keys.
            if not g.is_constant(gqa_expand_v.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            exp_shape_v = g.get_computed_constant(gqa_expand_v.input[1])
            if tuple(exp_shape) != tuple(exp_shape_v):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant(gqa_unsqueeze_v.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            unsq_shape_v = g.get_computed_constant(gqa_unsqueeze_v.input[1])
            if tuple(unsq_shape_v) != tuple(unsq_shape):
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.is_constant(gqa_reshape_v.input[1]):
                return self.none(node, inspect.currentframe().f_lineno)
            resh_shape_v = g.get_computed_constant(gqa_reshape_v.input[1])
            if tuple(resh_shape_v) != tuple(resh_shape):
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            gqa_expand_v = gqa_reshape_v = gqa_unsqueeze_v = None

        # The order must follow the signature of method apply.
        nodes = [
            mul1,
            gqa_unsqueeze,
            mul2,
            gqa_expand,
            gqa_reshape,
            transpose,
            mat_qk,
            where_node,
            add_node,
            node,
            isnan,
            where2,
            gqa_unsqueeze_v,
            gqa_expand_v,
            gqa_reshape_v,
            mat_qkv,
        ]
        # No intermediate result can be used outside the pattern.
        # The last node is skipped: its output is kept by the fused node.
        for n in nodes[:-1]:
            if n is None:
                continue
            if n.op_type == "Softmax":
                # Softmax legitimately feeds two nodes: IsNaN and Where.
                if len(g.next_nodes(n.output[0])) != 2:
                    return self.none(node, inspect.currentframe().f_lineno)
                continue
            if g.is_used_more_than_once(n.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        mul1: NodeProto,
        gqa_unsqueeze: Optional[NodeProto],
        mul2: NodeProto,
        gqa_expand: Optional[NodeProto],
        gqa_reshape: Optional[NodeProto],
        transpose: Optional[NodeProto],
        mat_qk: NodeProto,
        where_node: NodeProto,
        add_node: Optional[NodeProto],
        softmax: NodeProto,
        isnan: NodeProto,
        where: NodeProto,
        gqa_unsqueeze_v: Optional[NodeProto],
        gqa_expand_v: Optional[NodeProto],
        gqa_reshape_v: Optional[NodeProto],
        mat_qkv: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces the matched nodes by a single call to a local function.

        The parameters are the nodes returned by :meth:`match`, in the same
        order. The function name encodes the variant:
        ``LocalAttention[GQA|GQAsQ][SW][noT]_to<itype>``. The local function
        is registered in ``g.builder`` if it does not exist yet.

        :return: the list holding the new node
        """
        itype = g.get_type(mul1.input[1])
        suffix = []
        index_inf = self._find_index_inf(g, where_node)
        assert index_inf, (
            f"Could not find any inf in node {g.pretty_node(where_node)}, "
            f"the pattern {self.__class__.__name__} should not have matched."
        )
        # The infinite value in second position means Where(mask, -inf, X).
        switch_where = index_inf == 1
        if switch_where:
            suffix.append("SW")
        if transpose is None:
            assert (
                mat_qk.op_type == "FusedMatMul"
            ), f"transpose is None but mat_qk={g.pretty_node(mat_qk)}"
            suffix.append("noT")

        has_gqa = gqa_reshape is not None
        if has_gqa:
            gqa = "GQA" if gqa_reshape.op_type == "Reshape" else "GQAsQ"
            gqa_args = [gqa_expand.input[1], gqa_reshape.input[1]]
        else:
            gqa = ""
            gqa_args = []

        name = f"{self._operator_name}{gqa}{''.join(suffix)}_to{itype}"
        attention_nodes = [
            g.make_node(
                name,
                [
                    mul1.input[0],
                    gqa_unsqueeze.input[0] if has_gqa else mul2.input[0],
                    gqa_unsqueeze_v.input[0] if has_gqa else mat_qkv.input[1],
                    where_node.input[0],
                    mul1.input[1],
                    *gqa_args,
                ],
                [mat_qkv.output[0]],
                name=f"{self.__class__.__name__}--{softmax.name}",
                domain=self._domain_name,
            )
        ]

        # Creates the local function
        if not g.builder.has_local_function(name, domain=self._domain_name):
            self._add_local_function(
                g.builder,
                name,
                itype=itype,
                gqa=has_gqa,
                switch_where=switch_where,
                use_qga_squeeze=has_gqa and gqa_reshape.op_type == "Squeeze",
            )
        return attention_nodes

    @classmethod
    def _add_local_function(
        cls,
        g: GraphBuilder,
        name: str,
        itype: int,
        gqa: bool,
        switch_where: bool,
        use_qga_squeeze: bool,
    ):
        """
        Builds the local function implementing the attention and adds it to *g*.

        :param g: builder receiving the local function
        :param name: name of the function, the domain is ``cls._domain_name``
        :param itype: onnx element type used to create constants 0 and -inf
        :param gqa: adds inputs ``expand_shape`` and ``gqa_shape`` and repeats
            keys and values (Unsqueeze, Expand, Reshape or Squeeze)
        :param switch_where: if True, the mask input is named ``not_mask`` and
            the function computes ``Where(not_mask, -inf, qk)`` instead of
            ``Where(mask, qk, -inf)``
        :param use_qga_squeeze: uses ``Squeeze`` instead of ``Reshape`` after
            ``Expand`` (only used if *gqa* is true)
        """
        lg = GraphBuilder(g.main_opset, as_function=True)
        # The order of the inputs is the order used by the calling node.
        lg.make_tensor_input("query")
        lg.make_tensor_input("keys")
        lg.make_tensor_input("values")
        mask_name = "not_mask" if switch_where else "mask"
        lg.make_tensor_input(mask_name)
        lg.make_tensor_input("scale_sqrt")

        scaled_keys = lg.op.Mul("keys", "scale_sqrt", name=cls.__name__)
        if gqa:
            lg.make_tensor_input("expand_shape")
            lg.make_tensor_input("gqa_shape")
            two = np.array([2], dtype=np.int64)
            unsq_keys = lg.op.UnsqueezeAnyOpset(scaled_keys, two, name=cls.__name__)
            unsq_values = lg.op.UnsqueezeAnyOpset("values", two, name=cls.__name__)
            exp_keys = lg.op.Expand(unsq_keys, "expand_shape")
            exp_values = lg.op.Expand(unsq_values, "expand_shape")
            if use_qga_squeeze:
                resh_keys = lg.op.Squeeze(exp_keys, "gqa_shape")
                resh_values = lg.op.Squeeze(exp_values, "gqa_shape")
            else:
                resh_keys = lg.op.Reshape(exp_keys, "gqa_shape")
                resh_values = lg.op.Reshape(exp_values, "gqa_shape")
            scaled_keys = resh_keys
            values = resh_values
        else:
            values = "values"

        scaled_query = lg.op.Mul("query", "scale_sqrt", name=cls.__name__)
        scaled_keys_t = lg.op.Transpose(scaled_keys, perm=(0, 1, 3, 2), name=cls.__name__)
        qk = lg.op.MatMul(scaled_query, scaled_keys_t, name=cls.__name__)
        dtype = tensor_dtype_to_np_dtype(itype)
        zero = np.array([0], dtype=dtype)
        minfty = np.array([-np.inf], dtype=dtype)
        where_args = (minfty, qk) if switch_where else (qk, minfty)
        masked_qk = lg.op.Where(mask_name, *where_args, name=cls.__name__)
        softmax = lg.op.Softmax(masked_qk, axis=-1, name=cls.__name__)
        # Replaces the NaN values produced by Softmax by zero.
        filtered = lg.op.Where(
            lg.op.IsNaN(softmax, name=cls.__name__), zero, softmax, name=cls.__name__
        )
        lg.op.MatMul(filtered, values, outputs=["Y"], name=cls.__name__)
        lg.make_tensor_output("Y")

        function_options = FunctionOptions(
            export_as_function=True,
            name=name,
            domain=cls._domain_name,
            move_initializer_to_constant=True,
        )
        g.make_local_function(lg, function_options=function_options)
        assert g.has_local_function(
            name, domain=cls._domain_name
        ), f"The function {cls._domain_name}.{name} was not added to the builder."
class _CommonGQAMethods:
    """
    Methods shared by the patterns handling GroupQueryAttention.
    The class calls ``self.none``: it is meant to be mixed with a class
    providing it (the patterns using it also inherit from
    :class:`PatternOptimization`).
    """

    def _match_keys_or_values(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        keys_or_values: str,
    ) -> Optional[
        Tuple[NodeProto, NodeProto, NodeProto, Tuple[Tuple[Union[int, str], ...], ...]]
    ]:
        """
        Checks that result *keys_or_values* is produced by the sequence
        ``Unsqueeze -> Expand -> Reshape`` (or ``Squeeze``) repeating heads.

        :param g: graph being optimized
        :param node: node the pattern started from, only given to ``self.none``
        :param keys_or_values: name of the result to inspect
        :return: ``(unsqueeze, expand, reshape, (unsq_axes, expand_shape,
            reshape_shape))`` where the last three items are the constants
            converted into tuples, or the value returned by ``self.none(...)``
            if the sequence is not recognized
        """
        gqa_reshape = g.node_before(keys_or_values)
        if (
            not gqa_reshape
            or gqa_reshape.op_type not in ("Reshape", "Squeeze")
            or gqa_reshape.domain != ""
            or g.main_opset < 18
        ):
            return self.none(node, inspect.currentframe().f_lineno)
        # Squeeze can be defined without axes, there is no constant to read then.
        if len(gqa_reshape.input) < 2:
            return self.none(node, inspect.currentframe().f_lineno)

        # node_before returns None when the result is not produced by a node.
        gqa_expand = g.node_before(gqa_reshape.input[0])
        if gqa_expand is None or gqa_expand.op_type != "Expand":
            return self.none(node, inspect.currentframe().f_lineno)
        gqa_unsqueeze = g.node_before(gqa_expand.input[0])
        if gqa_unsqueeze is None or gqa_unsqueeze.op_type != "Unsqueeze":
            return self.none(node, inspect.currentframe().f_lineno)

        # Every constant must be known and compatible with a repetition
        # along the inserted axis 2 only.
        if not g.is_constant(gqa_expand.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        exp_shape = g.get_computed_constant(gqa_expand.input[1])
        if tuple(exp_shape[:2]) != (1, 1) or tuple(exp_shape[3:]) != (1, 1):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(gqa_unsqueeze.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        unsq_shape = g.get_computed_constant(gqa_unsqueeze.input[1])
        if tuple(unsq_shape) != (2,):
            return self.none(node, inspect.currentframe().f_lineno)
        if not g.is_constant(gqa_reshape.input[1]):
            return self.none(node, inspect.currentframe().f_lineno)
        resh_shape = g.get_computed_constant(gqa_reshape.input[1])
        # Reshape expects a 4D target shape, Squeeze a single axis.
        expected_size = 4 if gqa_reshape.op_type == "Reshape" else 1
        if resh_shape.size != expected_size:
            return self.none(node, inspect.currentframe().f_lineno)

        if not g.has_shape(gqa_unsqueeze.input[0]) or not g.has_shape(gqa_reshape.output[0]):
            return self.none(node, inspect.currentframe().f_lineno)
        # Only dimension 1 is allowed to change.
        shape1 = g.get_shape_renamed(gqa_unsqueeze.input[0])
        shape2 = g.get_shape_renamed(gqa_reshape.output[0])
        if shape1[0] != shape2[0] or shape1[2] != shape2[2] or shape1[3] != shape2[3]:
            return self.none(node, inspect.currentframe().f_lineno)
        return (
            gqa_unsqueeze,
            gqa_expand,
            gqa_reshape,
            (tuple(unsq_shape), tuple(exp_shape), tuple(resh_shape)),
        )
class FunctionAttentionGQAPattern(FunctionAttentionPattern, _CommonGQAMethods):
    """
    Merges onnx nodes equivalent to repeat interleave followed by function
    ``LocalAttention`` into ``LocalAttentionGQA`` (GQA for GroupQueryAttention).

    Model with nodes to be fused:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "cat", onnx.TensorProto.FLOAT, shape=("batch", 4, "past_length+seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init1_s_::RSh1", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "to", onnx.TensorProto.BOOL, shape=("seq_length", "total_length")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_0_8_-1_32", onnx.TensorProto.INT64, shape=(4,))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s5_1_1_2_1_1", onnx.TensorProto.INT64, shape=(5,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat_1",
                onnx.TensorProto.FLOAT,
                shape=("batch", 4, "past_length+seq_length", 32),
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s1_2"],
                value=onh.from_array(np.array([2], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s5_1_1_2_1_1"],
                value=onh.from_array(np.array([1, 1, 2, 1, 1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init7_s4_0_8_-1_32"],
                value=onh.from_array(np.array([0, 8, -1, 32], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["init1_s_::RSh1"],
                value=onh.from_array(
                    np.array([0.4204482138156891], dtype=np.float32), name="value"
                ),
            )
        )
        nodes.append(oh.make_node("Unsqueeze", ["cat", "init7_s1_2"], ["cat::UnSq2"]))
        nodes.append(
            oh.make_node(
                "Expand", ["cat::UnSq2", "init7_s5_1_1_2_1_1"], ["_onx_expand_cat::UnSq2"]
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape",
                ["_onx_expand_cat::UnSq2", "init7_s4_0_8_-1_32"],
                ["repeat_interleave_1"],
            )
        )
        nodes.append(oh.make_node("Unsqueeze", ["cat_1", "init7_s1_2"], ["cat_1::UnSq2"]))
        nodes.append(
            oh.make_node(
                "Expand", ["cat_1::UnSq2", "init7_s5_1_1_2_1_1"], ["_onx_expand_cat_1::UnSq2"]
            )
        )
        nodes.append(
            oh.make_node(
                "Reshape",
                ["_onx_expand_cat_1::UnSq2", "init7_s4_0_8_-1_32"],
                ["repeat_interleave"],
            )
        )
        nodes.append(
            oh.make_node(
                "LocalAttentionSW_to1",
                ["query", "repeat_interleave_1", "repeat_interleave", "to", "init1_s_::RSh1"],
                ["output_0"],
                domain="intermediate",
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "output_0", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 18),
            oh.make_opsetid("intermediate", 1),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info(
                "cat", onnx.TensorProto.FLOAT, shape=("batch", 4, "past_length+seq_length", 32)
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init1_s_::RSh1", onnx.TensorProto.FLOAT, shape=(1,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "to", onnx.TensorProto.BOOL, shape=("seq_length", "total_length")
            )
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s4_0_8_-1_32", onnx.TensorProto.INT64, shape=(4,))
        )
        inputs.append(
            oh.make_tensor_value_info("init7_s5_1_1_2_1_1", onnx.TensorProto.INT64, shape=(5,))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "cat_1",
                onnx.TensorProto.FLOAT,
                shape=("batch", 4, "past_length+seq_length", 32),
            )
        )
        inputs.append(
            oh.make_tensor_value_info(
                "query", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        nodes.append(
            oh.make_node(
                "LocalAttentionGQASW_to1",
                [
                    "query",
                    "cat",
                    "cat_1",
                    "to",
                    "init1_s_::RSh1",
                    "init7_s5_1_1_2_1_1",
                    "init7_s4_0_8_-1_32",
                ],
                ["output_0"],
                domain="intermediate",
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "output_0", onnx.TensorProto.FLOAT, shape=("batch", 8, "seq_length", 32)
            )
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    # Prefix of the functions already handling GQA, they must not be matched again.
    _operator_gqa_name = f"{FunctionAttentionPattern._operator_name}GQA"

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Tells if *node* is a call to a ``LocalAttention*`` function whose keys
        and values are both produced by the same repeat-interleave sequence.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: results already matched (not used here)
        :return: a :class:`MatchResult` whose node list follows the order of
            the parameters of :meth:`apply`, or the value returned by
            ``self.none(...)`` if there is no match
        """
        if (
            not node.op_type.startswith(FunctionAttentionPattern._operator_name)
            or node.op_type.startswith(FunctionAttentionGQAPattern._operator_gqa_name)
            or node.domain != FunctionAttentionGQAPattern._domain_name
        ):
            return self.none()
        # Inputs are (query, keys, values, ...), keys and values are required.
        if len(node.input) < 3:
            return self.none(node, inspect.currentframe().f_lineno)

        keys, values = node.input[1:3]
        matched_keys = self._match_keys_or_values(g, node, keys)
        if not matched_keys:
            return self.none(node, inspect.currentframe().f_lineno)
        matched_values = self._match_keys_or_values(g, node, values)
        if not matched_values:
            return self.none(node, inspect.currentframe().f_lineno)

        gqa_unsqueeze, gqa_expand, gqa_reshape, shapes = matched_keys
        gqa_unsqueeze_v, gqa_expand_v, gqa_reshape_v, shapes_v = matched_values
        # Keys and values must be repeated the same way: the fused function
        # only receives the constants coming from the keys side.
        if gqa_reshape.op_type != gqa_reshape_v.op_type:
            return self.none(node, inspect.currentframe().f_lineno)
        unsq_shape, exp_shape, resh_shape = shapes
        unsq_shape_v, exp_shape_v, resh_shape_v = shapes_v
        if unsq_shape_v != unsq_shape:
            return self.none(node, inspect.currentframe().f_lineno)
        if exp_shape != exp_shape_v:
            return self.none(node, inspect.currentframe().f_lineno)
        if resh_shape_v != resh_shape:
            return self.none(node, inspect.currentframe().f_lineno)

        # Final verification, let's check none of the nodes is used outside the pattern.
        nodes = [
            gqa_unsqueeze,
            gqa_expand,
            gqa_reshape,
            gqa_unsqueeze_v,
            gqa_expand_v,
            gqa_reshape_v,
            node,
        ]
        # The last node is skipped: its output is kept by the fused node.
        for n in nodes[:-1]:
            if n and g.is_used_more_than_once(n.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)
        return MatchResult(self, nodes, self.apply)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        gqa_unsqueeze: NodeProto,
        gqa_expand: NodeProto,
        gqa_reshape: NodeProto,
        gqa_unsqueeze_v: NodeProto,
        gqa_expand_v: NodeProto,
        gqa_reshape_v: NodeProto,
        attn: NodeProto,
    ) -> List[NodeProto]:
        """
        Replaces the repeat-interleave nodes and the call to ``LocalAttention*``
        by a single call to ``LocalAttentionGQA*``.

        The parameters are the nodes returned by :meth:`match`, in the same
        order. The suffix of the original function name (``SW``, ``_to<itype>``,
        ...) is kept. The local function is registered in ``g.builder``
        if it does not exist yet.

        :return: the list holding the new node
        """
        itype = g.get_type(gqa_unsqueeze.input[0])
        # The function name and the function body must describe the same
        # operator, both are derived from the keys side.
        use_squeeze = gqa_reshape.op_type == "Squeeze"
        gqa = "" if gqa_reshape.op_type == "Reshape" else "sQ"
        name = f"{self._operator_gqa_name}{gqa}{attn.op_type[len(self._operator_name):]}"

        attention_nodes = [
            g.make_node(
                name,
                [
                    attn.input[0],
                    gqa_unsqueeze.input[0],
                    gqa_unsqueeze_v.input[0],
                    # mask and scale, an empty name stands for a missing input
                    attn.input[3] if len(attn.input) > 3 else "",
                    attn.input[4] if len(attn.input) > 4 else "",
                    gqa_expand.input[1],
                    gqa_reshape.input[1],
                ],
                [attn.output[0]],
                name=f"{self.__class__.__name__}--{attn.name}",
                domain=self._domain_name,
            )
        ]

        # Creates the local function
        if not g.builder.has_local_function(name, domain=self._domain_name):
            self._add_local_function(
                g.builder,
                name,
                itype=itype,
                gqa=True,
                switch_where="SW" in attn.op_type,
                use_qga_squeeze=use_squeeze,
            )
        return attention_nodes
class AttentionGQAPattern(PatternOptimization, _CommonGQAMethods):
    """
    Fuses LocalAttention into Attention.
    Opset must be >= 23 to do so.
    The two Concat nodes building the present keys and values are
    absorbed by the fused node, which takes their inputs
    and produces their outputs.

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 24),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("key", onnx.TensorProto.FLOAT, shape=("a", 2, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info("mask", onnx.TensorProto.BOOL, shape=("a", 1, "c", "c+h"))
        )
        inputs.append(
            oh.make_tensor_value_info("value", onnx.TensorProto.FLOAT, shape=("a", 2, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "past_key", onnx.TensorProto.FLOAT, shape=("a", 2, "h", 8)
            )
        )
        inputs.append(
            oh.make_tensor_value_info("query", onnx.TensorProto.FLOAT, shape=("a", 4, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "past_value", onnx.TensorProto.FLOAT, shape=("a", 2, "h", 8)
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["two"],
                value=onh.from_array(np.array([2], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["t11211"],
                value=onh.from_array(np.array([1, 1, 2, 1, 1], dtype=np.int64), name="value"),
            )
        )
        nodes.append(
            oh.make_node(
                "Constant",
                [],
                ["resh"],
                value=onh.from_array(np.array([0, 4, -1, 8], dtype=np.int64), name="value"),
            )
        )
        nodes.append(oh.make_node("Concat", ["past_key", "key"], ["present_key"], axis=2))
        nodes.append(
            oh.make_node("Concat", ["past_value", "value"], ["present_value"], axis=2)
        )
        nodes.append(oh.make_node("Unsqueeze", ["present_key", "two"], ["key_u"]))
        nodes.append(oh.make_node("Expand", ["key_u", "t11211"], ["key_ue"]))
        nodes.append(oh.make_node("Reshape", ["key_ue", "resh"], ["key_ues"]))
        nodes.append(oh.make_node("Unsqueeze", ["present_value", "two"], ["value_u"]))
        nodes.append(oh.make_node("Expand", ["value_u", "t11211"], ["value_ue"]))
        nodes.append(oh.make_node("Reshape", ["value_ue", "resh"], ["value_ues"]))
        nodes.append(
            oh.make_node(
                "Attention",
                ["query", "key_ues", "value_ues", "mask"],
                ["Y"],
                scale=0.10999999940395355,
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "present_value", onnx.TensorProto.FLOAT, shape=("a", 2, "c+h", 8)
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "present_key", onnx.TensorProto.FLOAT, shape=("a", 2, "c+h", 8)
            )
        )
        outputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 4, "c_", 8))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))

    Outcome of the fusion:

    .. gdot::
        :script: DOT-SECTION
        :process:

        from experimental_experiment.doc import to_dot
        import numpy as np
        import ml_dtypes
        import onnx
        import onnx.helper as oh
        import onnx.numpy_helper as onh

        opset_imports = [
            oh.make_opsetid("", 24),
        ]
        inputs = []
        outputs = []
        nodes = []
        initializers = []
        sparse_initializers = []
        functions = []
        inputs.append(
            oh.make_tensor_value_info("key", onnx.TensorProto.FLOAT, shape=("a", 2, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info("mask", onnx.TensorProto.BOOL, shape=("a", 1, "c", "c+h"))
        )
        inputs.append(
            oh.make_tensor_value_info("value", onnx.TensorProto.FLOAT, shape=("a", 2, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "past_key", onnx.TensorProto.FLOAT, shape=("a", 2, "h", 8)
            )
        )
        inputs.append(
            oh.make_tensor_value_info("query", onnx.TensorProto.FLOAT, shape=("a", 4, "c", 8))
        )
        inputs.append(
            oh.make_tensor_value_info(
                "past_value", onnx.TensorProto.FLOAT, shape=("a", 2, "h", 8)
            )
        )
        nodes.append(
            oh.make_node(
                "Attention",
                ["query", "key", "value", "mask", "past_key", "past_value"],
                ["Y", "present_key", "present_value"],
                is_causal=0,
                scale=0.10999999940395355,
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "present_value", onnx.TensorProto.FLOAT, shape=("a", 2, "c+h", 8)
            )
        )
        outputs.append(
            oh.make_tensor_value_info(
                "present_key", onnx.TensorProto.FLOAT, shape=("a", 2, "c+h", 8)
            )
        )
        outputs.append(
            oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape=("a", 4, "c_", 8))
        )
        graph = oh.make_graph(
            nodes,
            "pattern",
            inputs,
            outputs,
            initializers,
            sparse_initializer=sparse_initializers,
        )
        model = oh.make_model(graph, functions=functions, opset_imports=opset_imports)

        print("DOT-SECTION", to_dot(model))
    """

    # Prefixes of the local GQA functions this pattern knows how to fuse.
    _prefixes_operator_name = (
        f"{FunctionAttentionGQAPattern._operator_gqa_name}SW_to",
        f"{FunctionAttentionGQAPattern._operator_gqa_name}SWsQ_to",
        f"{FunctionAttentionGQAPattern._operator_gqa_name}_to",
        f"{FunctionAttentionGQAPattern._operator_gqa_name}sQ_to",
    )

    def __init__(self, verbose: int = 0, priority: int = 2):
        super().__init__(verbose, priority)

    @staticmethod
    def _are_matching_concats(g, concats) -> bool:
        """
        Tells if both producers exist, are Concat nodes with exactly
        two inputs and share the same value for attribute *axis*.
        """
        if any(c is None for c in concats):
            return False
        if any(len(c.input) != 2 or c.op_type != "Concat" for c in concats):
            return False
        # NOTE(review): the axis is only checked to be the same on both sides,
        # not to be the sequence axis expected by Attention - confirm.
        axis_k, axis_v = (g.get_attribute_with_default(c, "axis", 0) for c in concats)
        return axis_k == axis_v

    def match(
        self,
        g: "GraphBuilderPatternOptimization",  # noqa: F821
        node: NodeProto,
        matched: List[MatchResult],
    ) -> Optional[MatchResult]:
        """
        Matches either a standard ``Attention`` node fed by GQA expansion chains
        or a local GQA function (7 inputs), keys and values being in both cases
        produced by two Concat nodes.

        :param g: graph being optimized
        :param node: candidate node
        :param matched: matches found so far (not used here)
        :return: a :class:`MatchResult` holding the two Concat nodes, six GQA nodes
            (all None for a local function) and the attention node, or None
        """
        if g.main_opset < 23:
            return self.none()

        is_onnx_attention = node.op_type == "Attention" and node.domain == ""
        is_local_gqa = (
            node.op_type.startswith(self._prefixes_operator_name)
            and node.domain == FunctionAttentionGQAPattern._domain_name
            and len(node.input) == 7
        )
        if not (is_onnx_attention or is_local_gqa) or len(node.output) > 1:
            return self.none()

        if len(node.input) > 3 and (
            not g.has_rank(node.input[3]) or g.get_rank(node.input[3]) < 2
        ):
            # The mask needs a known rank, at least 2.
            return self.none(node, inspect.currentframe().f_lineno)

        if node.op_type == "Attention":
            if len(node.input) < 3:
                # query, keys, values are needed.
                return self.none(node, inspect.currentframe().f_lineno)
            if any(node.input[4:]):
                # The node already receives inputs after the mask (past keys,
                # past values, ...), the fused node would silently replace them.
                return self.none(node, inspect.currentframe().f_lineno)
            if not g.has_rank(node.input[0]) or g.get_rank(node.input[0]) != 4:
                # Only 4D Attention, the rank must be known and equal to 4.
                return self.none(node, inspect.currentframe().f_lineno)

            # Node Attention, we still need to check if there is some GQA node.
            gqa_keys = self._match_keys_or_values(g, node, node.input[1])
            if not gqa_keys:
                return self.none(node, inspect.currentframe().f_lineno)
            gqa_values = self._match_keys_or_values(g, node, node.input[2])
            if not gqa_values:
                return self.none(node, inspect.currentframe().f_lineno)

            gqa_unsqueeze, gqa_expand, gqa_reshape, shapes = gqa_keys
            gqa_unsqueeze_v, gqa_expand_v, gqa_reshape_v, shapes_v = gqa_values
            unsq_shape, exp_shape, resh_shape = shapes
            unsq_shape_v, exp_shape_v, resh_shape_v = shapes_v

            # Keys and values must be expanded exactly the same way.
            if unsq_shape_v != unsq_shape:
                return self.none(node, inspect.currentframe().f_lineno)
            if exp_shape != exp_shape_v:
                return self.none(node, inspect.currentframe().f_lineno)
            if resh_shape_v != resh_shape:
                return self.none(node, inspect.currentframe().f_lineno)

            gqa_nodes = [
                gqa_unsqueeze,
                gqa_expand,
                gqa_reshape,
                gqa_unsqueeze_v,
                gqa_expand_v,
                gqa_reshape_v,
            ]

            concats = (
                g.node_before(gqa_unsqueeze.input[0]),
                g.node_before(gqa_unsqueeze_v.input[0]),
            )
            if not self._are_matching_concats(g, concats):
                return self.none(node, inspect.currentframe().f_lineno)
        else:
            # Local function
            keys, values = node.input[1:3]
            concats = g.node_before(keys), g.node_before(values)
            if not self._are_matching_concats(g, concats):
                return self.none(node, inspect.currentframe().f_lineno)

            # Input 4 is squared in method apply to get attribute scale.
            if not g.is_constant_scalar(node.input[4]):
                return self.none(node, inspect.currentframe().f_lineno)

            # Input 5 must be a constant looking like (1, 1, n, 1, 1).
            if not g.is_constant(node.input[5]):
                return self.none(node, inspect.currentframe().f_lineno)
            cst = g.get_computed_constant(node.input[5])
            if cst is None:
                return self.none(node, inspect.currentframe().f_lineno)
            cst = tuple(cst)
            if len(cst) < 4:
                return self.none(node, inspect.currentframe().f_lineno)
            if cst[:2] != cst[3:] or cst[:2] != (1, 1):
                return self.none(node, inspect.currentframe().f_lineno)

            if not g.is_constant(node.input[6]):
                return self.none(node, inspect.currentframe().f_lineno)
            shape_or_axis = g.get_computed_constant(node.input[6])
            if shape_or_axis is None:
                return self.none(node, inspect.currentframe().f_lineno)
            if "sQ_to" in node.op_type:
                # This is an axis for a Squeeze node.
                if not g.get_shape(node.input[1]):
                    # We need that shape to get kv_num_heads.
                    return self.none(node, inspect.currentframe().f_lineno)
            else:
                # This is a shape for a Reshape node.
                if shape_or_axis[1] <= 0:
                    return self.none(node, inspect.currentframe().f_lineno)
            gqa_nodes = [None] * 6

        # Final verification, let's check none of the GQA nodes is used outside
        # the pattern. The outputs of the Concat nodes are kept by the fused node,
        # they can be used anywhere else.
        for n in gqa_nodes:
            if n is not None and g.is_used_more_than_once(n.output[0]):
                return self.none(node, inspect.currentframe().f_lineno)

        nodes = [*concats, *gqa_nodes, node]
        return MatchResult(self, nodes, self.apply, insert_at=node)

    def apply(
        self,
        g: "GraphBuilder",  # noqa: F821
        keys_concat_node: NodeProto,
        values_concat_node: NodeProto,
        gqa_unsqueeze: Optional[NodeProto],
        gqa_expand: Optional[NodeProto],
        gqa_reshape: Optional[NodeProto],
        gqa_unsqueeze_v: Optional[NodeProto],
        gqa_expand_v: Optional[NodeProto],
        gqa_reshape_v: Optional[NodeProto],
        local_attention_gqa: Optional[NodeProto],
    ) -> List[NodeProto]:
        """
        Builds the fused ``Attention`` node.

        The second input of each Concat node becomes the key (or value),
        the first one becomes the past key (or past value), and the outputs
        of the Concat nodes become the second and third outputs of the new node.

        :param g: graph builder receiving the new nodes
        :param keys_concat_node: Concat node producing the keys
        :param values_concat_node: Concat node producing the values
        :param gqa_unsqueeze: removed GQA node (None for a local function)
        :param gqa_expand: removed GQA node (None for a local function)
        :param gqa_reshape: removed GQA node (None for a local function)
        :param gqa_unsqueeze_v: removed GQA node (None for a local function)
        :param gqa_expand_v: removed GQA node (None for a local function)
        :param gqa_reshape_v: removed GQA node (None for a local function)
        :param local_attention_gqa: matched Attention node or local function
        :return: the new nodes, an optional Not node followed by the Attention node
        :raises NotImplementedError: if the mask of a "SW" local function
            is not a boolean tensor
        """
        attn_inputs = local_attention_gqa.input
        query = attn_inputs[0]
        # The mask is optional for a standard Attention node.
        mask = attn_inputs[3] if len(attn_inputs) > 3 else ""

        attn_kwargs = {}
        if local_attention_gqa.op_type == "Attention":
            scale = g.get_attribute_with_default(local_attention_gqa, "scale", None)
            if scale is not None:
                attn_kwargs["scale"] = scale
            attn_kwargs["is_causal"] = g.get_attribute_with_default(
                local_attention_gqa, "is_causal", 0
            )
            # NOTE(review): the other attributes of the matched node
            # are not copied - confirm this is intended.
        else:
            # The local function receives scale ** 0.5.
            scale = g.get_constant_scalar(local_attention_gqa.input[4]) ** 2
            attn_kwargs["scale"] = scale

        # In case we need the 3D pattern.
        # expand_shape = g.get_computed_constant(local_attention_gqa.input[5])
        # repeat = int(expand_shape[2])
        # if "sQ_" in local_attention_gqa.op_type:
        #     k_shape = g.get_shape(local_attention_gqa.input[1])
        #     kv_num_heads = k_shape[1]
        # else:
        #     reshape_shape = g.get_computed_constant(local_attention_gqa.input[6])
        #     kv_num_heads = reshape_shape[1] // repeat
        #
        # num_heads = kv_num_heads * repeat

        nodes = []
        final_mask = mask
        if mask and "SW" in local_attention_gqa.op_type:
            # mask is not mask if SW
            if g.get_type(mask) != TensorProto.BOOL:
                raise NotImplementedError(
                    f"float mask is not implemented yet for pattern "
                    f"{self.__class__.__name__!r}"
                )
            final_mask = g.unique_name(f"{self.__class__.__name__}--{mask}")
            nodes.append(g._make_node("Not", [mask], [final_mask]))

        nodes.append(
            g._make_node(
                "Attention",
                [
                    query,
                    keys_concat_node.input[1],
                    values_concat_node.input[1],
                    final_mask,
                    keys_concat_node.input[0],
                    values_concat_node.input[0],
                ],
                [
                    local_attention_gqa.output[0],
                    keys_concat_node.output[0],
                    values_concat_node.output[0],
                ],
                # q_num_heads=num_heads,
                # kv_num_heads=kv_num_heads,
                **attn_kwargs,
            )
        )
        for node in nodes:
            if not node.name:
                node.name = g.builder.unique_node_name(
                    f"{self.__class__.__name__}--{local_attention_gqa.name}"
                )
        return nodes