yobx.xoptim.patterns.onnx_attention
- class yobx.xoptim.patterns.onnx_attention.AttentionGQAPattern(verbose: int = 0, priority: int = 2)
Fuses LocalAttention into Attention. Opset must be >= 23 to do so.
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_key(["key FLOAT(a, 2, c, 8)"])
    I_mask(["mask BOOL(a, 1, c, c+h)"])
    I_value(["value FLOAT(a, 2, c, 8)"])
    I_past_key(["past_key FLOAT(a, 2, h, 8)"])
    I_query(["query FLOAT(a, 4, c, 8)"])
    I_past_value(["past_value FLOAT(a, 2, h, 8)"])
    Concat_0[["Concat(., ., axis=2)"]]
    Concat_1[["Concat(., ., axis=2)"]]
    Unsqueeze_2[["Unsqueeze(., [2])"]]
    Expand_3[["Expand(., [1, 1, 2, 1, 1])"]]
    Reshape_4[["Reshape(., [0, 4, -1, 8])"]]
    Unsqueeze_5[["Unsqueeze(., [2])"]]
    Expand_6[["Expand(., [1, 1, 2, 1, 1])"]]
    Reshape_7[["Reshape(., [0, 4, -1, 8])"]]
    Attention_8[["Attention(., ., ., .)"]]
    I_past_key -->|"FLOAT(a, 2, h, 8)"| Concat_0
    I_key -->|"FLOAT(a, 2, c, 8)"| Concat_0
    I_past_value -->|"FLOAT(a, 2, h, 8)"| Concat_1
    I_value -->|"FLOAT(a, 2, c, 8)"| Concat_1
    Concat_0 -->|"FLOAT(a, 2, c+h, 8)"| Unsqueeze_2
    Unsqueeze_2 -->|"FLOAT(a, 2, 1, c+h, 8)"| Expand_3
    Expand_3 -->|"FLOAT(a, 2, 2, c+h, 8)"| Reshape_4
    Concat_1 -->|"FLOAT(a, 2, c+h, 8)"| Unsqueeze_5
    Unsqueeze_5 -->|"FLOAT(a, 2, 1, c+h, 8)"| Expand_6
    Expand_6 -->|"FLOAT(a, 2, 2, c+h, 8)"| Reshape_7
    I_query -->|"FLOAT(a, 4, c, 8)"| Attention_8
    Reshape_4 -->|"FLOAT(a, 4, c+h, 8)"| Attention_8
    Reshape_7 -->|"FLOAT(a, 4, c+h, 8)"| Attention_8
    I_mask -->|"BOOL(a, 1, c, c+h)"| Attention_8
    O_present_value(["present_value FLOAT(a, 2, c+h, 8)"])
    Concat_1 --> O_present_value
    O_present_key(["present_key FLOAT(a, 2, c+h, 8)"])
    Concat_0 --> O_present_key
    O_Y(["Y FLOAT(a, 4, c_, 8)"])
    Attention_8 --> O_Y
    class I_key,I_mask,I_value,I_past_key,I_query,I_past_value ioNode
    class O_present_value,O_present_key,O_Y ioNode
    class Concat_0,Concat_1,Unsqueeze_2,Expand_3,Reshape_4,Unsqueeze_5 opNode
    class Expand_6,Reshape_7,Attention_8 opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_key(["key FLOAT(a, 2, c, 8)"])
    I_mask(["mask BOOL(a, 1, c, c+h)"])
    I_value(["value FLOAT(a, 2, c, 8)"])
    I_past_key(["past_key FLOAT(a, 2, h, 8)"])
    I_query(["query FLOAT(a, 4, c, 8)"])
    I_past_value(["past_value FLOAT(a, 2, h, 8)"])
    Attention_0[["Attention(., ., ., ., ., .)"]]
    I_query -->|"FLOAT(a, 4, c, 8)"| Attention_0
    I_key -->|"FLOAT(a, 2, c, 8)"| Attention_0
    I_value -->|"FLOAT(a, 2, c, 8)"| Attention_0
    I_mask -->|"BOOL(a, 1, c, c+h)"| Attention_0
    I_past_key -->|"FLOAT(a, 2, h, 8)"| Attention_0
    I_past_value -->|"FLOAT(a, 2, h, 8)"| Attention_0
    O_present_value(["present_value FLOAT(a, 2, c+h, 8)"])
    Attention_0 --> O_present_value
    O_present_key(["present_key FLOAT(a, 2, c+h, 8)"])
    Attention_0 --> O_present_key
    O_Y(["Y FLOAT(a, 4, c_, 8)"])
    Attention_0 --> O_Y
    class I_key,I_mask,I_value,I_past_key,I_query,I_past_value ioNode
    class O_present_value,O_present_key,O_Y ioNode
    class Attention_0 opNode
- apply(g: GraphBuilder, keys_concat_node: NodeProto, values_concat_node: NodeProto, gqa_unsqueeze: NodeProto | None, gqa_expand: NodeProto | None, gqa_reshape: NodeProto | None, gqa_unsqueeze_v: NodeProto | None, gqa_expand_v: NodeProto | None, gqa_reshape_v: NodeProto | None, local_attention_gqa: NodeProto | None) → List[NodeProto]
The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, method match can include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and that any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the nodes already matched by a pattern
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
- a list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.
- a function doing the rewriting (usually method apply of the pattern class).
- an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
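The match/apply contract described above can be illustrated with a toy optimizer. This is only a sketch written in plain Python: `Node`, `Match`, `FuseNegNeg` and `run_once` are hypothetical stand-ins invented for the illustration and are not part of yobx or onnx. It shows the three rules stated in the text: match only reads the graph, every node handed to apply is removed, and every node returned by apply is added.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto."""
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class Match:
    """Hypothetical stand-in for MatchResult."""
    nodes: List[Optional[Node]]        # nodes involved, entries may be None
    apply: Callable[..., List[Node]]   # function doing the rewriting
    insert_at: Optional[Node] = None   # optional insertion point


class FuseNegNeg:
    """Toy pattern: Neg(Neg(x)) is rewritten into Identity(x)."""

    def match(self, graph: List[Node], node: Node) -> Optional[Match]:
        # Read-only: the graph is inspected, never modified.
        if node.op_type != "Neg":
            return None
        producer = next((n for n in graph if node.inputs[0] in n.outputs), None)
        if producer is None or producer.op_type != "Neg":
            return None
        # The inner Neg must not feed any other node.
        consumers = [n for n in graph if producer.outputs[0] in n.inputs]
        if len(consumers) != 1:
            return None
        return Match([producer, node], self.apply, insert_at=node)

    def apply(self, graph: List[Node], inner: Node, outer: Node) -> List[Node]:
        # Every node received is removed, every node returned is added.
        return [Node("Identity", [inner.inputs[0]], [outer.outputs[0]])]


def run_once(graph: List[Node], pattern) -> List[Node]:
    """Applies the first match found, following the remove/add rule."""
    for node in graph:
        m = pattern.match(graph, node)
        if m is None:
            continue
        new_nodes = m.apply(graph, *m.nodes)
        removed = {id(n) for n in m.nodes if n is not None}
        pos = graph.index(m.insert_at or node)
        before = [n for n in graph[:pos] if id(n) not in removed]
        after = [n for n in graph[pos:] if id(n) not in removed]
        return before + new_nodes + after
    return graph


graph = [
    Node("Neg", ["x"], ["a"]),
    Node("Neg", ["a"], ["b"]),
    Node("Relu", ["b"], ["y"]),
]
result = run_once(graph, FuseNegNeg())
print([n.op_type for n in result])
```

A node that apply wants to keep would simply be included in the returned list, which is how the rule "if a received node must be kept, it must be added to the list of returned nodes" plays out in this toy model.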
- class yobx.xoptim.patterns.onnx_attention.FunctionAttentionGQAPattern(verbose: int = 0, priority: int = 0)
Merges ONNX nodes equivalent to repeat interleave followed by function LocalAttention into LocalAttentionGQA (GQA for GroupQueryAttention).
Model with nodes to be fused:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    Constant_0[["Constant() -#gt; init7_s5_1_1_2_1_1"]]
    Constant_1[["Constant() -#gt; init7_s4_0_8_-1_32"]]
    Constant_2[["Constant() -#gt; init1_s_::RSh1"]]
    Unsqueeze_3[["Unsqueeze(., [2])"]]
    Expand_4[["Expand(., .)"]]
    Reshape_5[["Reshape(., .)"]]
    Unsqueeze_6[["Unsqueeze(., [2])"]]
    Expand_7[["Expand(., .)"]]
    Reshape_8[["Reshape(., .)"]]
    LocalAttentionSW_to1_9[["intermediate.LocalAttentionSW_to1(., ., ., ., .)"]]
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_3
    Unsqueeze_3 --> Expand_4
    Constant_0 -->|"INT64(5)"| Expand_4
    Expand_4 --> Reshape_5
    Constant_1 -->|"INT64(4)"| Reshape_5
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_6
    Unsqueeze_6 --> Expand_7
    Constant_0 -->|"INT64(5)"| Expand_7
    Expand_7 --> Reshape_8
    Constant_1 -->|"INT64(4)"| Reshape_8
    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionSW_to1_9
    Reshape_5 --> LocalAttentionSW_to1_9
    Reshape_8 --> LocalAttentionSW_to1_9
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionSW_to1_9
    Constant_2 -->|"FLOAT(1)"| LocalAttentionSW_to1_9
    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionSW_to1_9 --> O_output_0
    class I_cat,I_init1_s___RSh1,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_cat_1,I_query,O_output_0 ioNode
    class Constant_0,Constant_1,Constant_2 constNode
    class Unsqueeze_3,Expand_4,Reshape_5,Unsqueeze_6,Expand_7 opNode
    class Reshape_8,LocalAttentionSW_to1_9 opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionGQASW_to1_0[["intermediate.LocalAttentionGQASW_to1(., ., ., ., ., ., .)"]]
    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionGQASW_to1_0
    I_init1_s___RSh1 -->|"FLOAT(1)"| LocalAttentionGQASW_to1_0
    I_init7_s5_1_1_2_1_1 -->|"INT64(5)"| LocalAttentionGQASW_to1_0
    I_init7_s4_0_8__1_32 -->|"INT64(4)"| LocalAttentionGQASW_to1_0
    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionGQASW_to1_0 --> O_output_0
    class I_cat,I_init1_s___RSh1,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_cat_1,I_query,O_output_0 ioNode
    class LocalAttentionGQASW_to1_0 opNode
- apply(g: GraphBuilder, gqa_unsqueeze: NodeProto, gqa_expand: NodeProto, gqa_reshape: NodeProto, gqa_unsqueeze_v: NodeProto, gqa_expand_v: NodeProto, gqa_reshape_v: NodeProto, attn: NodeProto) → List[NodeProto]
The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, method match can include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and that any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the nodes already matched by a pattern
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
- a list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.
- a function doing the rewriting (usually method apply of the pattern class).
- an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
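The "repeat interleave" equivalence that FunctionAttentionGQAPattern relies on can be checked numerically. The sketch below uses NumPy rather than ONNX and assumes the shapes shown in the graph (4 key/value heads, 2 query heads per key/value head, head size 32); the variable names are chosen for the illustration only. It reproduces Unsqueeze(., [2]), Expand(., [1, 1, 2, 1, 1]) and Reshape(., [0, 8, -1, 32]) and compares the result with a plain repetition of every head.

```python
import numpy as np

batch, kv_heads, groups, total_len, head_dim = 2, 4, 2, 5, 32
rng = np.random.default_rng(0)
cat = rng.standard_normal((batch, kv_heads, total_len, head_dim)).astype(np.float32)

# Unsqueeze(., [2]) -> (batch, 4, 1, total_len, 32)
unsqueezed = np.expand_dims(cat, axis=2)
# Expand(., [1, 1, 2, 1, 1]) broadcasts to (batch, 4, 2, total_len, 32)
expand_shape = np.broadcast_shapes(unsqueezed.shape, (1, 1, groups, 1, 1))
expanded = np.broadcast_to(unsqueezed, expand_shape)
# Reshape(., [0, 8, -1, 32]) -> (batch, 8, total_len, 32), 0 keeps the input dimension
reshaped = expanded.reshape(batch, kv_heads * groups, -1, head_dim)

# Same values as repeating every key/value head `groups` times along the head axis.
repeated = np.repeat(cat, groups, axis=1)
print(reshaped.shape, np.array_equal(reshaped, repeated))
```

Because the three nodes only duplicate data, a fused operator can read the original key/value tensors directly and skip the materialized copies, which is presumably the benefit of the fusion.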
- class yobx.xoptim.patterns.onnx_attention.FunctionAttentionPattern(verbose: int = 0, priority: int = 0)
Merges Attention nodes into a local function. That includes a version for GroupQueryAttention (see second pattern).
Main Pattern
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_values(["values FLOAT(av, bv, cv, dv)"])
    I_keys(["keys FLOAT(ak, bk, ck, dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, bq, cq, dq)"])
    Constant_0[["Constant() -#gt; scale_sqrt"]]
    Mul_1[["Mul(., .)"]]
    Mul_2[["Mul(., .)"]]
    Transpose_3[["Transpose(., perm=[0, 1, 3, 2])"]]
    MatMul_4[["MatMul(., .)"]]
    Where_5[["Where(., [0.0], [-inf])"]]
    Add_6[["Add(., .)"]]
    Softmax_7[["Softmax(., axis=-1)"]]
    IsNaN_8[["IsNaN(.)"]]
    Where_9[["Where(., [0.0], .)"]]
    MatMul_10[["MatMul(., .)"]]
    I_query -->|"FLOAT(aq, bq, cq, dq)"| Mul_1
    Constant_0 -->|"FLOAT(1)"| Mul_1
    I_keys -->|"FLOAT(ak, bk, ck, dk)"| Mul_2
    Constant_0 -->|"FLOAT(1)"| Mul_2
    Mul_2 -->|"FLOAT(ak, bk, ck, dk)"| Transpose_3
    Mul_1 -->|"FLOAT(aq, bq, cq, dq)"| MatMul_4
    Transpose_3 -->|"FLOAT(ak, bk, dk, ck)"| MatMul_4
    I_mask -->|"BOOL(am, bm, cm, dm)"| Where_5
    MatMul_4 -->|"FLOAT(aq^ak, bq^bk, cq, ck)"| Add_6
    Where_5 -->|"FLOAT(am, bm, cm, dm)"| Add_6
    Add_6 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Softmax_7
    Softmax_7 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| IsNaN_8
    IsNaN_8 -->|"BOOL(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Where_9
    Softmax_7 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Where_9
    Where_9 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| MatMul_10
    I_values -->|"FLOAT(av, bv, cv, dv)"| MatMul_10
    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    MatMul_10 --> O_Y
    class I_values,I_keys,I_scale_sqrt,I_mask,I_query,O_Y ioNode
    class Constant_0 constNode
    class Mul_1,Mul_2,Transpose_3,MatMul_4,Where_5,Add_6,Softmax_7 opNode
    class IsNaN_8,Where_9,MatMul_10 opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_values(["values FLOAT(av, bv, cv, dv)"])
    I_keys(["keys FLOAT(ak, bk, ck, dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, bq, cq, dq)"])
    LocalAttention_to1_0[["intermediate.LocalAttention_to1(., ., ., ., .)"]]
    I_query -->|"FLOAT(aq, bq, cq, dq)"| LocalAttention_to1_0
    I_keys -->|"FLOAT(ak, bk, ck, dk)"| LocalAttention_to1_0
    I_values -->|"FLOAT(av, bv, cv, dv)"| LocalAttention_to1_0
    I_mask -->|"BOOL(am, bm, cm, dm)"| LocalAttention_to1_0
    I_scale_sqrt -->|"FLOAT(1)"| LocalAttention_to1_0
    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    LocalAttention_to1_0 --> O_Y
    class I_values,I_keys,I_scale_sqrt,I_mask,I_query,O_Y ioNode
    class LocalAttention_to1_0 opNode
GroupQueryAttention (GQA)
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    Constant_0[["Constant() -#gt; init1_s_::RSh1"]]
    Constant_1[["Constant() -#gt; init7_s5_1_1_2_1_1"]]
    Constant_2[["Constant() -#gt; init7_s4_0_8_-1_32"]]
    Mul_3[["Mul(., .)"]]
    Unsqueeze_4[["Unsqueeze(., [2])"]]
    Mul_5[["Mul(., [0.4204482])"]]
    Expand_6[["Expand(., .)"]]
    Reshape_7[["Reshape(., .)"]]
    Transpose_8[["Transpose(., perm=[0, 1, 3, 2])"]]
    MatMul_9[["MatMul(., .)"]]
    w10[["Where(., [-inf], .)"]]
    s11[["Softmax(., axis=-1)"]]
    nan12[["IsNaN(.)"]]
    w13[["Where(., 0.0, .)"]]
    Unsqueeze_14[["Unsqueeze(., [2])"]]
    Expand_15[["Expand(., .)"]]
    Reshape_16[["Reshape(., .)"]]
    mm17[["MatMul(., .)"]]
    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| Mul_3
    Constant_0 -->|"FLOAT(1)"| Mul_3
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_4
    Unsqueeze_4 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Mul_5
    Mul_5 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Expand_6
    Constant_1 -->|"INT64(5)"| Expand_6
    Expand_6 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Reshape_7
    Constant_2 -->|"INT64(4)"| Reshape_7
    Reshape_7 -->|"FLOAT(batch, 8, 128*(past_length+seq_length)//256, 32)"| Transpose_8
    Mul_3 -->|"FLOAT(batch, 8, seq_length, 32)"| MatMul_9
    Transpose_8 -->|"FLOAT(batch, 8, 32, 128*(past_length+seq_length)//256)"| MatMul_9
    I_to -->|"BOOL(seq_length, total_length)"| w10
    MatMul_9 -->|"FLOAT(batch, 8, seq_length, 128*(past_length+seq_length)//256)"| w10
    w10 -->|"FLOAT(batch, 8, seq_length, total_length^128*(past_length+seq_length)//256)"| s11
    s11 -->|"FLOAT(batch, 8, seq_length, total_length^128*(past_length+seq_length)//256)"| nan12
    nan12 -->|"BOOL(batch, 8, seq_length, total_length^128*(past_length+seq_length)//256)"| w13
    s11 -->|"FLOAT(batch, 8, seq_length, total_length^128*(past_length+seq_length)//256)"| w13
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_14
    Unsqueeze_14 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Expand_15
    Constant_1 -->|"INT64(5)"| Expand_15
    Expand_15 -->|"FLOAT(batch, 4, 2, past_length+seq_length, 32)"| Reshape_16
    Constant_2 -->|"INT64(4)"| Reshape_16
    w13 -->|"FLOAT(batch, 8, seq_length, total_length^128*(past_length+seq_length)//256)"| mm17
    Reshape_16 -->|"FLOAT(batch, 8, past_length+seq_length, 32)"| mm17
    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    mm17 --> O_output_0
    class I_init1_s___RSh1,I_query,I_cat_1,I_cat,I_to ioNode
    class I_init7_s4_0_8__1_32,I_init7_s5_1_1_2_1_1,O_output_0 ioNode
    class Constant_0,Constant_1,Constant_2 constNode
    class Mul_3,Unsqueeze_4,Mul_5,Expand_6,Reshape_7,Transpose_8 opNode
    class MatMul_9,w10,s11,nan12,w13,Unsqueeze_14,Expand_15 opNode
    class Reshape_16,mm17 opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    LocalAttentionGQASW_to1_0[["intermediate.LocalAttentionGQASW_to1(., ., ., ., ., ., .)"]]
    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionGQASW_to1_0
    I_init1_s___RSh1 -->|"FLOAT(1)"| LocalAttentionGQASW_to1_0
    I_init7_s5_1_1_2_1_1 -->|"INT64(5)"| LocalAttentionGQASW_to1_0
    I_init7_s4_0_8__1_32 -->|"INT64(4)"| LocalAttentionGQASW_to1_0
    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionGQASW_to1_0 --> O_output_0
    class I_query,I_cat_1,I_cat,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_init1_s___RSh1,O_output_0 ioNode
    class LocalAttentionGQASW_to1_0 opNode
3D Pattern
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_values_t(["values_t FLOAT(av, 8, cv, 64)"])
    I_keys(["keys FLOAT(ak, ck, bk*dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_shape0(["shape0 INT64(4)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, cq, bq*dq)"])
    Constant_0[["Constant() -#gt; scale_sqrt"]]
    Constant_1[["Constant() -#gt; shape0"]]
    Mul_2[["Mul(., .)"]]
    Transpose_3[["Transpose(., perm=[0, 2, 1, 3])"]]
    Reshape_4[["Reshape(., .)"]]
    Mul_5[["Mul(., .)"]]
    Reshape_6[["Reshape(., .)"]]
    Transpose_7[["Transpose(., perm=[0, 2, 3, 1])"]]
    MatMul_8[["MatMul(., .)"]]
    Where_9[["Where(., [0.0], [-inf])"]]
    Add_10[["Add(., .)"]]
    s11[["Softmax(., axis=-1)"]]
    nan12[["IsNaN(.)"]]
    w13[["Where(., [0.0], .)"]]
    MatMul_14[["MatMul(., .)"]]
    I_query -->|"FLOAT(aq, cq, bq*dq)"| Mul_2
    Constant_0 -->|"FLOAT(1)"| Mul_2
    Reshape_4 -->|"FLOAT(aq, cq, 8, 64)"| Transpose_3
    Mul_2 -->|"FLOAT(aq, cq, bq*dq)"| Reshape_4
    Constant_1 -->|"INT64(4)"| Reshape_4
    I_keys -->|"FLOAT(ak, ck, bk*dk)"| Mul_5
    Constant_0 -->|"FLOAT(1)"| Mul_5
    Mul_5 -->|"FLOAT(ak, ck, bk*dk)"| Reshape_6
    Constant_1 -->|"INT64(4)"| Reshape_6
    Reshape_6 -->|"FLOAT(ak, ck, 8, 64)"| Transpose_7
    Transpose_3 --> MatMul_8
    Transpose_7 -->|"FLOAT(ak, 8, 64, ck)"| MatMul_8
    I_mask -->|"BOOL(am, bm, cm, dm)"| Where_9
    MatMul_8 --> Add_10
    Where_9 -->|"FLOAT(am, bm, cm, dm)"| Add_10
    Add_10 --> s11
    s11 --> nan12
    nan12 --> w13
    s11 --> w13
    w13 --> MatMul_14
    I_values_t -->|"FLOAT(av, 8, cv, 64)"| MatMul_14
    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    MatMul_14 --> O_Y
    class I_values_t,I_keys,I_scale_sqrt,I_shape0,I_mask,I_query,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Mul_2,Transpose_3,Reshape_4,Mul_5,Reshape_6,Transpose_7 opNode
    class MatMul_8,Where_9,Add_10,s11,nan12,w13,MatMul_14 opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_values_t(["values_t FLOAT(av, 8, cv, 64)"])
    I_keys(["keys FLOAT(ak, ck, bk*dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_shape0(["shape0 INT64(4)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, cq, bq*dq)"])
    Reshape_0[["Reshape(., .)"]]
    Reshape_1[["Reshape(., .)"]]
    Transpose_2[["Transpose(., perm=[0, 2, 1, 3])"]]
    Transpose_3[["Transpose(., perm=[0, 2, 1, 3])"]]
    LocalAttention_to1_4[["intermediate.LocalAttention_to1(., ., ., ., .)"]]
    I_query -->|"FLOAT(aq, cq, bq*dq)"| Reshape_0
    I_shape0 -->|"INT64(4)"| Reshape_0
    I_keys -->|"FLOAT(ak, ck, bk*dk)"| Reshape_1
    I_shape0 -->|"INT64(4)"| Reshape_1
    Reshape_0 --> Transpose_2
    Reshape_1 --> Transpose_3
    Transpose_2 --> LocalAttention_to1_4
    Transpose_3 --> LocalAttention_to1_4
    I_values_t -->|"FLOAT(av, 8, cv, 64)"| LocalAttention_to1_4
    I_mask -->|"BOOL(am, bm, cm, dm)"| LocalAttention_to1_4
    I_scale_sqrt -->|"FLOAT(1)"| LocalAttention_to1_4
    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    LocalAttention_to1_4 --> O_Y
    class I_values_t,I_keys,I_scale_sqrt,I_shape0,I_mask,I_query,O_Y ioNode
    class Reshape_0,Reshape_1,Transpose_2,Transpose_3,LocalAttention_to1_4 opNode
- apply(g: GraphBuilder, mul1: NodeProto, transpose_mul1: NodeProto | None, reshape_mul1: NodeProto | None, gqa_unsqueeze: NodeProto | None, mul2: NodeProto, reshape_mul2: NodeProto | None, gqa_expand: NodeProto | None, gqa_reshape: NodeProto | None, transpose: NodeProto | None, mat_qk: NodeProto, where_node: NodeProto, add_node: NodeProto | None, softmax: NodeProto, isnan: NodeProto, where: NodeProto, gqa_unsqueeze_v: NodeProto | None, gqa_expand_v: NodeProto | None, gqa_reshape_v: NodeProto | None, mat_qkv: NodeProto) → List[NodeProto]
The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, method match can include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and that any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the nodes already matched by a pattern
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
- a list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.
- a function doing the rewriting (usually method apply of the pattern class).
- an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
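For reference, the computation performed by the Main Pattern subgraph of FunctionAttentionPattern can be transcribed node by node into NumPy. This is a sketch based only on the graph shown for that pattern, not the library's implementation; the function name `local_attention` and the test shapes are assumptions made for the illustration. Query and keys are each multiplied by scale_sqrt, so the product is scaled by scale_sqrt squared. The constant 0.4204482 in the GQA graph is consistent with this reading, since it equals 32 ** -0.25 for a head size of 32.

```python
import numpy as np


def local_attention(query, keys, values, mask, scale_sqrt):
    """NumPy transcription of the Main Pattern subgraph (a sketch)."""
    q = query * scale_sqrt                                     # Mul
    k = np.transpose(keys * scale_sqrt, (0, 1, 3, 2))          # Mul, Transpose
    scores = q @ k                                             # MatMul
    bias = np.where(mask, 0.0, -np.inf).astype(scores.dtype)   # Where
    scores = scores + bias                                     # Add
    # Softmax(axis=-1): a fully masked row gives -inf - (-inf) = NaN.
    with np.errstate(invalid="ignore"):
        e = np.exp(scores - scores.max(axis=-1, keepdims=True))
        probs = e / e.sum(axis=-1, keepdims=True)
    probs = np.where(np.isnan(probs), 0.0, probs)              # IsNaN, Where
    return probs @ values                                      # MatMul


rng = np.random.default_rng(0)
head_dim = 8
q = rng.standard_normal((1, 2, 3, head_dim)).astype(np.float32)
k = rng.standard_normal((1, 2, 4, head_dim)).astype(np.float32)
v = rng.standard_normal((1, 2, 4, head_dim)).astype(np.float32)

mask = np.tril(np.ones((3, 4), dtype=bool))  # row i attends to keys 0..i
mask[2, :] = False                           # fully masked row to exercise IsNaN/Where
scale_sqrt = np.float32(head_dim ** -0.25)

y = local_attention(q, k, v, mask, scale_sqrt)
print(y.shape)
```

The IsNaN/Where pair is what turns a fully masked row into zeros instead of NaN. The 3D pattern differs only by the Reshape and Transpose nodes that bring query and keys to this 4D layout before the same computation.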