yobx.xoptim.patterns.onnx_attention

class yobx.xoptim.patterns.onnx_attention.AttentionGQAPattern(verbose: int = 0, priority: int = 2)

Fuses LocalAttention into Attention. The opset must be >= 23, the first version that defines the Attention operator.

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_key(["key FLOAT(a, 2, c, 8)"])
    I_mask(["mask BOOL(a, 1, c, c+h)"])
    I_value(["value FLOAT(a, 2, c, 8)"])
    I_past_key(["past_key FLOAT(a, 2, h, 8)"])
    I_query(["query FLOAT(a, 4, c, 8)"])
    I_past_value(["past_value FLOAT(a, 2, h, 8)"])

    Concat_0[["Concat(., ., axis=2)"]]
    Concat_1[["Concat(., ., axis=2)"]]
    Unsqueeze_2[["Unsqueeze(., [2])"]]
    Expand_3[["Expand(., [1, 1, 2, 1, 1])"]]
    Reshape_4[["Reshape(., [0, 4, -1, 8])"]]
    Unsqueeze_5[["Unsqueeze(., [2])"]]
    Expand_6[["Expand(., [1, 1, 2, 1, 1])"]]
    Reshape_7[["Reshape(., [0, 4, -1, 8])"]]
    Attention_8[["Attention(., ., ., .)"]]

    I_past_key -->|"FLOAT(a, 2, h, 8)"| Concat_0
    I_key -->|"FLOAT(a, 2, c, 8)"| Concat_0
    I_past_value -->|"FLOAT(a, 2, h, 8)"| Concat_1
    I_value -->|"FLOAT(a, 2, c, 8)"| Concat_1
    Concat_0 -->|"FLOAT(a, 2, c+h, 8)"| Unsqueeze_2
    Unsqueeze_2 -->|"FLOAT(a, 2, 1, c+h, 8)"| Expand_3
    Expand_3 -->|"FLOAT(a, 2, 2, c+h, 8)"| Reshape_4
    Concat_1 -->|"FLOAT(a, 2, c+h, 8)"| Unsqueeze_5
    Unsqueeze_5 -->|"FLOAT(a, 2, 1, c+h, 8)"| Expand_6
    Expand_6 -->|"FLOAT(a, 2, 2, c+h, 8)"| Reshape_7
    I_query -->|"FLOAT(a, 4, c, 8)"| Attention_8
    Reshape_4 -->|"FLOAT(a, 4, c+h, 8)"| Attention_8
    Reshape_7 -->|"FLOAT(a, 4, c+h, 8)"| Attention_8
    I_mask -->|"BOOL(a, 1, c, c+h)"| Attention_8

    O_present_value(["present_value FLOAT(a, 2, c+h, 8)"])
    Concat_1 --> O_present_value
    O_present_key(["present_key FLOAT(a, 2, c+h, 8)"])
    Concat_0 --> O_present_key
    O_Y(["Y FLOAT(a, 4, c_, 8)"])
    Attention_8 --> O_Y

    class I_key,I_mask,I_value,I_past_key,I_query,I_past_value ioNode
    class O_present_value,O_present_key,O_Y ioNode
    class Concat_0,Concat_1,Unsqueeze_2,Expand_3,Reshape_4,Unsqueeze_5 opNode
    class Expand_6,Reshape_7,Attention_8 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_key(["key FLOAT(a, 2, c, 8)"])
    I_mask(["mask BOOL(a, 1, c, c+h)"])
    I_value(["value FLOAT(a, 2, c, 8)"])
    I_past_key(["past_key FLOAT(a, 2, h, 8)"])
    I_query(["query FLOAT(a, 4, c, 8)"])
    I_past_value(["past_value FLOAT(a, 2, h, 8)"])

    Attention_0[["Attention(., ., ., ., ., .)"]]

    I_query -->|"FLOAT(a, 4, c, 8)"| Attention_0
    I_key -->|"FLOAT(a, 2, c, 8)"| Attention_0
    I_value -->|"FLOAT(a, 2, c, 8)"| Attention_0
    I_mask -->|"BOOL(a, 1, c, c+h)"| Attention_0
    I_past_key -->|"FLOAT(a, 2, h, 8)"| Attention_0
    I_past_value -->|"FLOAT(a, 2, h, 8)"| Attention_0

    O_present_value(["present_value FLOAT(a, 2, c+h, 8)"])
    Attention_0 --> O_present_value
    O_present_key(["present_key FLOAT(a, 2, c+h, 8)"])
    Attention_0 --> O_present_key
    O_Y(["Y FLOAT(a, 4, c_, 8)"])
    Attention_0 --> O_Y

    class I_key,I_mask,I_value,I_past_key,I_query,I_past_value ioNode
    class O_present_value,O_present_key,O_Y ioNode
    class Attention_0 opNode
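
The shape annotations in the first diagram can be traced by hand. The sketch below is a minimal, pure-Python illustration of the shape rules of Concat, Unsqueeze, Expand and Reshape for this case only. The helper names are hypothetical and belong neither to yobx nor to onnx, and the symbolic dimensions a, c, h are replaced by sample integers.

```python
from math import prod

# Hypothetical helpers: simplified shape rules, sufficient for the diagram above.

def concat_shape(left, right, axis):
    """Concat adds the sizes along `axis`; the other dimensions must agree."""
    out = list(left)
    out[axis] = left[axis] + right[axis]
    return tuple(out)

def unsqueeze_shape(shape, axis):
    """Unsqueeze inserts a dimension of size 1 at position `axis`."""
    return shape[:axis] + (1,) + shape[axis:]

def expand_shape(shape, target):
    """Expand broadcasts; assumes equal ranks and compatible dimensions."""
    return tuple(max(s, t) for s, t in zip(shape, target))

def reshape_shape(shape, target):
    """Reshape: 0 copies the input dimension, -1 is inferred from the element count."""
    out = [shape[i] if t == 0 else t for i, t in enumerate(target)]
    if -1 in out:
        known = prod(d for d in out if d != -1)
        out[out.index(-1)] = prod(shape) // known
    return tuple(out)

a, c, h = 3, 5, 7  # sample values for the symbolic dimensions
past_key, key = (a, 2, h, 8), (a, 2, c, 8)

present_key = concat_shape(past_key, key, axis=2)        # (a, 2, c+h, 8)
unsqueezed = unsqueeze_shape(present_key, 2)             # (a, 2, 1, c+h, 8)
expanded = expand_shape(unsqueezed, (1, 1, 2, 1, 1))     # (a, 2, 2, c+h, 8)
attention_key = reshape_shape(expanded, (0, 4, -1, 8))   # (a, 4, c+h, 8)
print(present_key, attention_key)
```

After the fusion, Attention receives the 2-head key and the 4-head query directly, so none of the three intermediate tensors appears in the graph.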
    
apply(g: GraphBuilder, keys_concat_node: NodeProto, values_concat_node: NodeProto, gqa_unsqueeze: NodeProto | None, gqa_expand: NodeProto | None, gqa_reshape: NodeProto | None, gqa_unsqueeze_v: NodeProto | None, gqa_expand_v: NodeProto | None, gqa_reshape_v: NodeProto | None, local_attention_gqa: NodeProto | None) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, some of them can be None. The method returns the new nodes. The optimizer considers that any node given to this method is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.
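
The removal and insertion contract described above can be illustrated with a toy model. This is a sketch only: nodes are plain strings, `toy_apply` and `toy_rewrite` are hypothetical names, and the real optimizer works on NodeProto objects and decides where the new nodes are inserted.

```python
def toy_apply(concat, unsqueeze, expand, reshape, attention):
    """Hypothetical apply: fuses everything into one node but keeps `concat`."""
    # `unsqueeze`, `expand` and `reshape` may be None when the match did not find them.
    return [concat, "FusedAttention"]

def toy_rewrite(graph_nodes, matched_nodes, apply_fn):
    """Drops every matched node from the graph, then adds the nodes apply_fn returns."""
    new_nodes = apply_fn(*matched_nodes)
    removed = {n for n in matched_nodes if n is not None}
    kept = [n for n in graph_nodes if n not in removed]
    # Toy choice: new nodes are appended at the end.
    return kept + new_nodes

graph = ["Concat_0", "Unsqueeze_2", "Expand_3", "Reshape_4", "Attention_8", "Other_9"]
matched = ["Concat_0", "Unsqueeze_2", "Expand_3", "Reshape_4", "Attention_8"]
result = toy_rewrite(graph, matched, toy_apply)
print(result)
```

`Concat_0` survives only because `toy_apply` returns it again; every other matched node disappears from the result.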

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched – usually unused; the list of matches already found for nodes matching a pattern

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer will automatically determine the position of the new nodes.
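
A minimal sketch of this contract follows. It is not the yobx implementation: the graph is a plain dictionary mapping a node name to its operator type and to the names of the nodes feeding it, and `toy_match` is a hypothetical function that returns a tuple where a real pattern would return a MatchResult.

```python
def toy_match(graph, name):
    """Looks for Unsqueeze -> Expand -> Reshape ending at `name`; never modifies `graph`."""
    chain = []
    current = name
    for expected in ("Reshape", "Expand", "Unsqueeze"):
        node = graph.get(current)
        if node is None or node["op_type"] != expected:
            return None  # no match for this starting node
        chain.append(current)
        current = node["inputs"][0]  # walk back to the node producing the first input
    # Nodes involved in the rewriting, listed from first to last.
    return tuple(reversed(chain))

graph = {
    "Concat_0": {"op_type": "Concat", "inputs": ["past_key"]},
    "Unsqueeze_2": {"op_type": "Unsqueeze", "inputs": ["Concat_0"]},
    "Expand_3": {"op_type": "Expand", "inputs": ["Unsqueeze_2"]},
    "Reshape_4": {"op_type": "Reshape", "inputs": ["Expand_3"]},
}
print(toy_match(graph, "Reshape_4"))
print(toy_match(graph, "Expand_3"))
```

The function only reads the graph. All changes are deferred to the rewriting function, which is what allows the optimizer to collect several matches before applying any of them.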

class yobx.xoptim.patterns.onnx_attention.FunctionAttentionGQAPattern(verbose: int = 0, priority: int = 0)

Merges ONNX nodes equivalent to a repeat-interleave followed by the function LocalAttention into LocalAttentionGQA (GQA stands for GroupQueryAttention).

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])

    Constant_0[["Constant() -#gt; init7_s5_1_1_2_1_1"]]
    Constant_1[["Constant() -#gt; init7_s4_0_8_-1_32"]]
    Constant_2[["Constant() -#gt; init1_s_::RSh1"]]
    Unsqueeze_3[["Unsqueeze(., [2])"]]
    Expand_4[["Expand(., .)"]]
    Reshape_5[["Reshape(., .)"]]
    Unsqueeze_6[["Unsqueeze(., [2])"]]
    Expand_7[["Expand(., .)"]]
    Reshape_8[["Reshape(., .)"]]
    LocalAttentionSW_to1_9[["intermediate.LocalAttentionSW_to1(., ., ., ., .)"]]

    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_3
    Unsqueeze_3 --> Expand_4
    Constant_0 -->|"INT64(5)"| Expand_4
    Expand_4 --> Reshape_5
    Constant_1 -->|"INT64(4)"| Reshape_5
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_6
    Unsqueeze_6 --> Expand_7
    Constant_0 -->|"INT64(5)"| Expand_7
    Expand_7 --> Reshape_8
    Constant_1 -->|"INT64(4)"| Reshape_8
    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionSW_to1_9
    Reshape_5 --> LocalAttentionSW_to1_9
    Reshape_8 --> LocalAttentionSW_to1_9
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionSW_to1_9
    Constant_2 -->|"FLOAT(1)"| LocalAttentionSW_to1_9

    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionSW_to1_9 --> O_output_0

    class I_cat,I_init1_s___RSh1,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_cat_1,I_query,O_output_0 ioNode
    class Constant_0,Constant_1,Constant_2 constNode
    class Unsqueeze_3,Expand_4,Reshape_5,Unsqueeze_6,Expand_7 opNode
    class Reshape_8,LocalAttentionSW_to1_9 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])

    LocalAttentionGQASW_to1_0[["intermediate.LocalAttentionGQASW_to1(
    ., ., ., ., ., ., .)"]]

    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionGQASW_to1_0
    I_init1_s___RSh1 -->|"FLOAT(1)"| LocalAttentionGQASW_to1_0
    I_init7_s5_1_1_2_1_1 -->|"INT64(5)"| LocalAttentionGQASW_to1_0
    I_init7_s4_0_8__1_32 -->|"INT64(4)"| LocalAttentionGQASW_to1_0

    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionGQASW_to1_0 --> O_output_0

    class I_cat,I_init1_s___RSh1,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_cat_1,I_query,O_output_0 ioNode
    class LocalAttentionGQASW_to1_0 opNode
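
The Unsqueeze, Expand and Reshape chain removed by this pattern is the ONNX spelling of a repeat-interleave on the head axis. The sketch below checks that on head labels rather than on tensors. Both function names are hypothetical, and the code is an illustration, not what the pattern executes.

```python
def unsqueeze_expand_reshape(heads, repeats):
    """Mimics Unsqueeze(axis=2) + Expand + Reshape on the head axis only."""
    # Unsqueeze + Expand: every key/value head gets a new axis of size `repeats`.
    expanded = [[head] * repeats for head in heads]
    # Reshape merges (kv_heads, repeats) into a single axis in row-major order.
    return [head for group in expanded for head in group]

def repeat_interleave(heads, repeats):
    """Reference: query head i reads key/value head i // repeats."""
    return [heads[i // repeats] for i in range(len(heads) * repeats)]

kv_heads = ["kv0", "kv1", "kv2", "kv3"]   # 4 key/value heads, as in the diagram
query_heads = 8                            # heads of the query in the diagram
repeats = query_heads // len(kv_heads)     # 2, the third value of the Expand shape
print(unsqueeze_expand_reshape(kv_heads, repeats))
```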
    
apply(g: GraphBuilder, gqa_unsqueeze: NodeProto, gqa_expand: NodeProto, gqa_reshape: NodeProto, gqa_unsqueeze_v: NodeProto, gqa_expand_v: NodeProto, gqa_reshape_v: NodeProto, attn: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, some of them can be None. The method returns the new nodes. The optimizer considers that any node given to this method is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched – usually unused; the list of matches already found for nodes matching a pattern

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer will automatically determine the position of the new nodes.

class yobx.xoptim.patterns.onnx_attention.FunctionAttentionPattern(verbose: int = 0, priority: int = 0)

Merges Attention nodes into a local function. That includes a version for GroupQueryAttention (see second pattern).

Main Pattern

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_values(["values FLOAT(av, bv, cv, dv)"])
    I_keys(["keys FLOAT(ak, bk, ck, dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, bq, cq, dq)"])

    Constant_0[["Constant() -#gt; scale_sqrt"]]
    Mul_1[["Mul(., .)"]]
    Mul_2[["Mul(., .)"]]
    Transpose_3[["Transpose(., perm=[0, 1, 3, 2])"]]
    MatMul_4[["MatMul(., .)"]]
    Where_5[["Where(., [0.0], [-inf])"]]
    Add_6[["Add(., .)"]]
    Softmax_7[["Softmax(., axis=-1)"]]
    IsNaN_8[["IsNaN(.)"]]
    Where_9[["Where(., [0.0], .)"]]
    MatMul_10[["MatMul(., .)"]]

    I_query -->|"FLOAT(aq, bq, cq, dq)"| Mul_1
    Constant_0 -->|"FLOAT(1)"| Mul_1
    I_keys -->|"FLOAT(ak, bk, ck, dk)"| Mul_2
    Constant_0 -->|"FLOAT(1)"| Mul_2
    Mul_2 -->|"FLOAT(ak, bk, ck, dk)"| Transpose_3
    Mul_1 -->|"FLOAT(aq, bq, cq, dq)"| MatMul_4
    Transpose_3 -->|"FLOAT(ak, bk, dk, ck)"| MatMul_4
    I_mask -->|"BOOL(am, bm, cm, dm)"| Where_5
    MatMul_4 -->|"FLOAT(aq^ak, bq^bk, cq, ck)"| Add_6
    Where_5 -->|"FLOAT(am, bm, cm, dm)"| Add_6
    Add_6 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Softmax_7
    Softmax_7 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| IsNaN_8
    IsNaN_8 -->|"BOOL(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Where_9
    Softmax_7 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| Where_9
    Where_9 -->|"FLOAT(aq^ak^am, bq^bk^bm, cq^cm, ck^dm)"| MatMul_10
    I_values -->|"FLOAT(av, bv, cv, dv)"| MatMul_10

    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    MatMul_10 --> O_Y

    class I_values,I_keys,I_scale_sqrt,I_mask,I_query,O_Y ioNode
    class Constant_0 constNode
    class Mul_1,Mul_2,Transpose_3,MatMul_4,Where_5,Add_6,Softmax_7 opNode
    class IsNaN_8,Where_9,MatMul_10 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_values(["values FLOAT(av, bv, cv, dv)"])
    I_keys(["keys FLOAT(ak, bk, ck, dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, bq, cq, dq)"])

    LocalAttention_to1_0[["intermediate.LocalAttention_to1(., ., ., ., .)"]]

    I_query -->|"FLOAT(aq, bq, cq, dq)"| LocalAttention_to1_0
    I_keys -->|"FLOAT(ak, bk, ck, dk)"| LocalAttention_to1_0
    I_values -->|"FLOAT(av, bv, cv, dv)"| LocalAttention_to1_0
    I_mask -->|"BOOL(am, bm, cm, dm)"| LocalAttention_to1_0
    I_scale_sqrt -->|"FLOAT(1)"| LocalAttention_to1_0

    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    LocalAttention_to1_0 --> O_Y

    class I_values,I_keys,I_scale_sqrt,I_mask,I_query,O_Y ioNode
    class LocalAttention_to1_0 opNode
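
The computation captured by the main pattern can be written out for a single batch entry and a single head. The sketch below uses plain Python lists. `local_attention` is a hypothetical name chosen here, and the code only mirrors the nodes of the diagram; it is not the body of the generated local function.

```python
import math

def local_attention(query, keys, values, mask, scale_sqrt):
    """query: [cq][d], keys: [ck][d], values: [ck][dv], mask: [cq][ck] of bool."""
    q = [[x * scale_sqrt for x in row] for row in query]   # Mul_1
    k = [[x * scale_sqrt for x in row] for row in keys]    # Mul_2
    output = []
    for q_row, mask_row in zip(q, mask):
        # MatMul_4 against the transposed keys, plus Add_6 of Where_5(mask, 0, -inf).
        scores = [
            sum(a * b for a, b in zip(q_row, k_row)) + (0.0 if keep else -math.inf)
            for k_row, keep in zip(k, mask_row)
        ]
        # Softmax_7; a fully masked row gives -inf - (-inf) = NaN everywhere.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        # IsNaN_8 + Where_9: replace NaN by 0.
        probs = [0.0 if math.isnan(p) else p for p in probs]
        # MatMul_10 with the values.
        width = len(values[0])
        output.append([sum(p * v[j] for p, v in zip(probs, values)) for j in range(width)])
    return output

query = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
mask = [[True, False], [False, False]]   # second query row is fully masked
Y = local_attention(query, keys, values, mask, scale_sqrt=2 ** -0.25)
print(Y)
```

Because both the query and the keys are multiplied by scale_sqrt, the scores end up scaled by its square. The IsNaN and Where pair is what turns a fully masked row into zeros instead of NaN.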
    

GroupQueryAttention (GQA)

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])
    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])

    Constant_0[["Constant() -#gt; init1_s_::RSh1"]]
    Constant_1[["Constant() -#gt; init7_s5_1_1_2_1_1"]]
    Constant_2[["Constant() -#gt; init7_s4_0_8_-1_32"]]
    Mul_3[["Mul(., .)"]]
    Unsqueeze_4[["Unsqueeze(., [2])"]]
    Mul_5[["Mul(., [0.4204482])"]]
    Expand_6[["Expand(., .)"]]
    Reshape_7[["Reshape(., .)"]]
    Transpose_8[["Transpose(., perm=[0, 1, 3, 2])"]]
    MatMul_9[["MatMul(., .)"]]
    w10[["Where(., [-inf], .)"]]
    s11[["Softmax(., axis=-1)"]]
    nan12[["IsNaN(.)"]]
    w13[["Where(., 0.0, .)"]]
    Unsqueeze_14[["Unsqueeze(., [2])"]]
    Expand_15[["Expand(., .)"]]
    Reshape_16[["Reshape(., .)"]]
    mm17[["MatMul(., .)"]]

    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| Mul_3
    Constant_0 -->|"FLOAT(1)"| Mul_3
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_4
    Unsqueeze_4 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Mul_5
    Mul_5 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Expand_6
    Constant_1 -->|"INT64(5)"| Expand_6
    Expand_6 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Reshape_7
    Constant_2 -->|"INT64(4)"| Reshape_7
    Reshape_7 -->|"FLOAT(batch, 8, 128*(past_length+seq_length)//256, 32)"| Transpose_8
    Mul_3 -->|"FLOAT(batch, 8, seq_length, 32)"| MatMul_9
    Transpose_8 -->|"FLOAT(batch, 8, 32, 128*(past_length+seq_length)//256)"| MatMul_9
    I_to -->|"BOOL(seq_length, total_length)"| w10
    MatMul_9 -->|"FLOAT(batch, 8, seq_length, 128*(past_length+seq_length)//256)"| w10
    w10 -->|"FLOAT(batch, 8, seq_length,
    total_length^128*(past_length+seq_length)//256)"| s11
    s11 -->|"FLOAT(batch, 8, seq_length,
    total_length^128*(past_length+seq_length)//256)"| nan12
    nan12 -->|"BOOL(batch, 8, seq_length,
    total_length^128*(past_length+seq_length)//256)"| w13
    s11 -->|"FLOAT(batch, 8, seq_length,
    total_length^128*(past_length+seq_length)//256)"| w13
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| Unsqueeze_14
    Unsqueeze_14 -->|"FLOAT(batch, 4, 1, past_length+seq_length, 32)"| Expand_15
    Constant_1 -->|"INT64(5)"| Expand_15
    Expand_15 -->|"FLOAT(batch, 4, 2, past_length+seq_length, 32)"| Reshape_16
    Constant_2 -->|"INT64(4)"| Reshape_16
    w13 -->|"FLOAT(batch, 8, seq_length,
    total_length^128*(past_length+seq_length)//256)"| mm17
    Reshape_16 -->|"FLOAT(batch, 8, past_length+seq_length, 32)"| mm17

    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    mm17 --> O_output_0

    class I_init1_s___RSh1,I_query,I_cat_1,I_cat,I_to ioNode
    class I_init7_s4_0_8__1_32,I_init7_s5_1_1_2_1_1,O_output_0 ioNode
    class Constant_0,Constant_1,Constant_2 constNode
    class Mul_3,Unsqueeze_4,Mul_5,Expand_6,Reshape_7,Transpose_8 opNode
    class MatMul_9,w10,s11,nan12,w13,Unsqueeze_14,Expand_15 opNode
    class Reshape_16,mm17 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_query(["query FLOAT(batch, 8, seq_length, 32)"])
    I_cat_1(["cat_1 FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_cat(["cat FLOAT(batch, 4, past_length+seq_length, 32)"])
    I_to(["to BOOL(seq_length, total_length)"])
    I_init7_s4_0_8__1_32(["init7_s4_0_8_-1_32 INT64(4)"])
    I_init7_s5_1_1_2_1_1(["init7_s5_1_1_2_1_1 INT64(5)"])
    I_init1_s___RSh1(["init1_s_::RSh1 FLOAT(1)"])

    LocalAttentionGQASW_to1_0[["intermediate.LocalAttentionGQASW_to1(
    ., ., ., ., ., ., .)"]]

    I_query -->|"FLOAT(batch, 8, seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_cat_1 -->|"FLOAT(batch, 4, past_length+seq_length, 32)"| LocalAttentionGQASW_to1_0
    I_to -->|"BOOL(seq_length, total_length)"| LocalAttentionGQASW_to1_0
    I_init1_s___RSh1 -->|"FLOAT(1)"| LocalAttentionGQASW_to1_0
    I_init7_s5_1_1_2_1_1 -->|"INT64(5)"| LocalAttentionGQASW_to1_0
    I_init7_s4_0_8__1_32 -->|"INT64(4)"| LocalAttentionGQASW_to1_0

    O_output_0(["output_0 FLOAT(batch, 8, seq_length, 32)"])
    LocalAttentionGQASW_to1_0 --> O_output_0

    class I_query,I_cat_1,I_cat,I_to,I_init7_s4_0_8__1_32 ioNode
    class I_init7_s5_1_1_2_1_1,I_init1_s___RSh1,O_output_0 ioNode
    class LocalAttentionGQASW_to1_0 opNode
    

3D Pattern

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_values_t(["values_t FLOAT(av, 8, cv, 64)"])
    I_keys(["keys FLOAT(ak, ck, bk*dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_shape0(["shape0 INT64(4)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, cq, bq*dq)"])

    Constant_0[["Constant() -#gt; scale_sqrt"]]
    Constant_1[["Constant() -#gt; shape0"]]
    Mul_2[["Mul(., .)"]]
    Transpose_3[["Transpose(., perm=[0, 2, 1, 3])"]]
    Reshape_4[["Reshape(., .)"]]
    Mul_5[["Mul(., .)"]]
    Reshape_6[["Reshape(., .)"]]
    Transpose_7[["Transpose(., perm=[0, 2, 3, 1])"]]
    MatMul_8[["MatMul(., .)"]]
    Where_9[["Where(., [0.0], [-inf])"]]
    Add_10[["Add(., .)"]]
    s11[["Softmax(., axis=-1)"]]
    nan12[["IsNaN(.)"]]
    w13[["Where(., [0.0], .)"]]
    MatMul_14[["MatMul(., .)"]]

    I_query -->|"FLOAT(aq, cq, bq*dq)"| Mul_2
    Constant_0 -->|"FLOAT(1)"| Mul_2
    Reshape_4 -->|"FLOAT(aq, cq, 8, 64)"| Transpose_3
    Mul_2 -->|"FLOAT(aq, cq, bq*dq)"| Reshape_4
    Constant_1 -->|"INT64(4)"| Reshape_4
    I_keys -->|"FLOAT(ak, ck, bk*dk)"| Mul_5
    Constant_0 -->|"FLOAT(1)"| Mul_5
    Mul_5 -->|"FLOAT(ak, ck, bk*dk)"| Reshape_6
    Constant_1 -->|"INT64(4)"| Reshape_6
    Reshape_6 -->|"FLOAT(ak, ck, 8, 64)"| Transpose_7
    Transpose_3 --> MatMul_8
    Transpose_7 -->|"FLOAT(ak, 8, 64, ck)"| MatMul_8
    I_mask -->|"BOOL(am, bm, cm, dm)"| Where_9
    MatMul_8 --> Add_10
    Where_9 -->|"FLOAT(am, bm, cm, dm)"| Add_10
    Add_10 --> s11
    s11 --> nan12
    nan12 --> w13
    s11 --> w13
    w13 --> MatMul_14
    I_values_t -->|"FLOAT(av, 8, cv, 64)"| MatMul_14

    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    MatMul_14 --> O_Y

    class I_values_t,I_keys,I_scale_sqrt,I_shape0,I_mask,I_query,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Mul_2,Transpose_3,Reshape_4,Mul_5,Reshape_6,Transpose_7 opNode
    class MatMul_8,Where_9,Add_10,s11,nan12,w13,MatMul_14 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_values_t(["values_t FLOAT(av, 8, cv, 64)"])
    I_keys(["keys FLOAT(ak, ck, bk*dk)"])
    I_scale_sqrt(["scale_sqrt FLOAT(1)"])
    I_shape0(["shape0 INT64(4)"])
    I_mask(["mask BOOL(am, bm, cm, dm)"])
    I_query(["query FLOAT(aq, cq, bq*dq)"])

    Reshape_0[["Reshape(., .)"]]
    Reshape_1[["Reshape(., .)"]]
    Transpose_2[["Transpose(., perm=[0, 2, 1, 3])"]]
    Transpose_3[["Transpose(., perm=[0, 2, 1, 3])"]]
    LocalAttention_to1_4[["intermediate.LocalAttention_to1(., ., ., ., .)"]]

    I_query -->|"FLOAT(aq, cq, bq*dq)"| Reshape_0
    I_shape0 -->|"INT64(4)"| Reshape_0
    I_keys -->|"FLOAT(ak, ck, bk*dk)"| Reshape_1
    I_shape0 -->|"INT64(4)"| Reshape_1
    Reshape_0 --> Transpose_2
    Reshape_1 --> Transpose_3
    Transpose_2 --> LocalAttention_to1_4
    Transpose_3 --> LocalAttention_to1_4
    I_values_t -->|"FLOAT(av, 8, cv, 64)"| LocalAttention_to1_4
    I_mask -->|"BOOL(am, bm, cm, dm)"| LocalAttention_to1_4
    I_scale_sqrt -->|"FLOAT(1)"| LocalAttention_to1_4

    O_Y(["Y FLOAT(ay, by, cy, dy)"])
    LocalAttention_to1_4 --> O_Y

    class I_values_t,I_keys,I_scale_sqrt,I_shape0,I_mask,I_query,O_Y ioNode
    class Reshape_0,Reshape_1,Transpose_2,Transpose_3,LocalAttention_to1_4 opNode
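
In the 3D pattern the keys go through Transpose(perm=[0, 2, 3, 1]) before the fusion and Transpose(perm=[0, 2, 1, 3]) after it. The two agree if the local function swaps the last two axes of the keys itself, as Transpose_3 does in the main pattern. The sketch below composes the permutations to check this; `apply_perm` and `compose` are hypothetical helpers.

```python
def apply_perm(shape, perm):
    """Shape after a Transpose: output axis i takes input axis perm[i]."""
    return tuple(shape[p] for p in perm)

def compose(first, second):
    """Single permutation equivalent to Transpose(first) followed by Transpose(second)."""
    return [first[p] for p in second]

heads_first = [0, 2, 1, 3]   # (a, c, heads, d) -> (a, heads, c, d)
swap_last = [0, 1, 3, 2]     # (a, heads, c, d) -> (a, heads, d, c)
combined = compose(heads_first, swap_last)

shape = (2, 5, 8, 64)        # sample values for (ak, ck, 8, 64) after the Reshape
print(combined, apply_perm(shape, combined))
```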
    
apply(g: GraphBuilder, mul1: NodeProto, transpose_mul1: NodeProto | None, reshape_mul1: NodeProto | None, gqa_unsqueeze: NodeProto | None, mul2: NodeProto, reshape_mul2: NodeProto | None, gqa_expand: NodeProto | None, gqa_reshape: NodeProto | None, transpose: NodeProto | None, mat_qk: NodeProto, where_node: NodeProto, add_node: NodeProto | None, softmax: NodeProto, isnan: NodeProto, where: NodeProto, gqa_unsqueeze_v: NodeProto | None, gqa_expand_v: NodeProto | None, gqa_reshape_v: NodeProto | None, mat_qkv: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, some of them can be None. The method returns the new nodes. The optimizer considers that any node given to this method is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched – usually unused; the list of matches already found for nodes matching a pattern

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer will automatically determine the position of the new nodes.