yobx.xoptim.patterns_ort.decoder_attention

class yobx.xoptim.patterns_ort.decoder_attention.DecoderAttentionPattern(verbose: int = 0, priority: int = 2)

Fuses a sequence-first decoder self-attention or cross-attention subgraph into a single com.microsoft.DecoderAttention node (an ONNX Runtime contrib operator).

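The operator works on sequence-first tensors of shape (S, B, H) (sequence length, batch size, hidden size). Besides the query and key activations, it takes a separate Q weight, a packed K/V weight, a packed bias and four boolean scalar flags: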

  • query (S, B, H)

  • key (T, B, H), where T is the key/value sequence length (the same tensor as query for self-attention)

  • q_weight (H, H)

  • kv_weight (H, 2*H): K and V weights concatenated along the last axis

  • bias (3*H,): Q, K and V biases concatenated

  • static_kv – bool scalar: True for cross-attention, False for self-attention

  • use_past – bool scalar: False (no KV-cache in this pattern)

  • has_layer_state – bool scalar: False

  • has_key_padding_mask – bool scalar: False
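The packed layout listed above can be illustrated with a short NumPy sketch. It assumes K precedes V inside kv_weight (concatenated along the last axis) and that bias is ordered Q, K, V, as the list states; `pack_decoder_attention_weights` is a hypothetical helper, not part of yobx or ONNX Runtime.

```python
import numpy as np


def pack_decoder_attention_weights(k_weight, v_weight, q_bias, k_bias, v_bias):
    """Hypothetical helper: pack separate K/V weights and Q/K/V biases.

    Assumes K comes before V in kv_weight and biases are ordered Q, K, V.
    """
    kv_weight = np.concatenate([k_weight, v_weight], axis=1)  # (H, 2*H)
    bias = np.concatenate([q_bias, k_bias, v_bias])  # (3*H,)
    return kv_weight, bias


rng = np.random.default_rng(0)
H, T, B = 8, 5, 2
k_w, v_w = rng.standard_normal((H, H)), rng.standard_normal((H, H))
q_b, k_b, v_b = (rng.standard_normal(H) for _ in range(3))
kv_weight, bias = pack_decoder_attention_weights(k_w, v_w, q_b, k_b, v_b)

# One MatMul against the packed weight yields the K and V projections side by side.
key = rng.standard_normal((T, B, H))  # sequence-first (T, B, H)
kv = key @ kv_weight + bias[H:]  # (T, B, 2*H); bias[H:] holds the K and V biases
k_proj, v_proj = kv[..., :H], kv[..., H:]
print(np.allclose(k_proj, key @ k_w + k_b), np.allclose(v_proj, key @ v_w + v_b))
# True True
```

Packing lets a single (H, 2*H) product replace the two separate K and V MatMuls of the unfused graph, since both read the same key tensor.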

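Cross-attention (static_kv = True) is detected when the tensor feeding the Q-projection MatMul differs from the tensor feeding the K and V projection MatMuls; when they are the same tensor, the fused node is emitted as self-attention (static_kv = False).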

Model with nodes to be fused (seq-first cross-attention, no cache, no mask):

    graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_query(["query FLOAT(S, B, H)"])
    I_key(["key FLOAT(T, B, H)"])
    W_q(["q_weight FLOAT(H, H)"])
    W_k(["k_weight FLOAT(H, H)"])
    W_v(["v_weight FLOAT(H, H)"])
    B_q(["q_bias FLOAT(H)"])
    B_k(["k_bias FLOAT(H)"])
    B_v(["v_bias FLOAT(H)"])

    MM_q[["MatMul(query, q_weight)"]]
    Add_q[["Add(mm_q, q_bias)"]]
    Re_q[["Reshape(., [0,0,N,d])"]]
    Tr_q[["Transpose(., perm=[1,2,0,3])"]]

    MM_k[["MatMul(key, k_weight)"]]
    Add_k[["Add(mm_k, k_bias)"]]
    Re_k[["Reshape(., [0,0,N,d])"]]
    Tr_k[["Transpose(., perm=[1,2,0,3])"]]
    Tr_kt[["Transpose(., perm=[0,1,3,2])"]]

    MM_v[["MatMul(key, v_weight)"]]
    Add_v[["Add(mm_v, v_bias)"]]
    Re_v[["Reshape(., [0,0,N,d])"]]
    Tr_v[["Transpose(., perm=[1,2,0,3])"]]

    Mul_scale[["Mul(Q, scale)"]]
    MM_qk[["MatMul(Q_scaled, K_T)"]]
    Softmax[["Softmax(., axis=-1)"]]
    MM_qkv[["MatMul(attn_probs, V)"]]

    Tr_out[["Transpose(., perm=[2,0,1,3])"]]
    Re_out[["Reshape(., [0,0,-1])"]]

    I_query --> MM_q --> Add_q --> Re_q --> Tr_q
    W_q --> MM_q
    B_q --> Add_q

    I_key --> MM_k --> Add_k --> Re_k --> Tr_k --> Tr_kt
    W_k --> MM_k
    B_k --> Add_k

    I_key --> MM_v --> Add_v --> Re_v --> Tr_v
    W_v --> MM_v
    B_v --> Add_v

    Tr_q --> Mul_scale
    Mul_scale --> MM_qk
    Tr_kt --> MM_qk
    MM_qk --> Softmax --> MM_qkv
    Tr_v --> MM_qkv

    MM_qkv --> Tr_out --> Re_out

    O_output(["output FLOAT(S, B, H)"])
    Re_out --> O_output

    class I_query,I_key,O_output ioNode
    class W_q,W_k,W_v,B_q,B_k,B_v initNode
    class MM_q,Add_q,Re_q,Tr_q,MM_k,Add_k,Re_k,Tr_k,Tr_kt opNode
    class MM_v,Add_v,Re_v,Tr_v,Mul_scale,MM_qk,Softmax,MM_qkv opNode
    class Tr_out,Re_out opNode
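To make the reshape/transpose bookkeeping in the diagram concrete, here is a NumPy sketch that replays the unfused subgraph node by node and checks it against a plain per-head loop. The diagram does not give the value of `scale`; the sketch assumes the conventional 1/sqrt(d), with d = H / N the head size. Function names are illustrative only.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerically stable softmax
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def unfused_attention(query, key, wq, wk, wv, bq, bk, bv, num_heads):
    """Replay the diagram: query (S, B, H), key (T, B, H) -> output (S, B, H)."""
    d = query.shape[-1] // num_heads  # head size

    def project(x, w, b):
        # MatMul -> Add -> Reshape([0, 0, N, d]) -> Transpose(perm=[1, 2, 0, 3])
        y = (x @ w + b).reshape(x.shape[0], x.shape[1], num_heads, d)
        return y.transpose(1, 2, 0, 3)  # (B, N, seq, d)

    q = project(query, wq, bq)  # (B, N, S, d)
    k_t = project(key, wk, bk).transpose(0, 1, 3, 2)  # (B, N, d, T)
    v = project(key, wv, bv)  # (B, N, T, d)
    scale = 1.0 / np.sqrt(d)  # assumption: the diagram only names `scale`
    probs = softmax((q * scale) @ k_t, axis=-1)  # (B, N, S, T)
    out = (probs @ v).transpose(2, 0, 1, 3)  # (S, B, N, d)
    return out.reshape(out.shape[0], out.shape[1], -1)  # Reshape([0, 0, -1])


def naive_attention(query, key, wq, wk, wv, bq, bk, bv, num_heads):
    """Reference: loop over batch and heads on 2-D slices."""
    S, B, H = query.shape
    d = H // num_heads
    Q, K, V = query @ wq + bq, key @ wk + bk, key @ wv + bv
    out = np.zeros_like(Q)
    for b in range(B):
        for n in range(num_heads):
            cols = slice(n * d, (n + 1) * d)  # columns belonging to head n
            scores = Q[:, b, cols] @ K[:, b, cols].T / np.sqrt(d)  # (S, T)
            out[:, b, cols] = softmax(scores) @ V[:, b, cols]
    return out


rng = np.random.default_rng(0)
S, T, B, H, N = 3, 4, 2, 8, 2
query, key = rng.standard_normal((S, B, H)), rng.standard_normal((T, B, H))
weights = [rng.standard_normal((H, H)) for _ in range(3)]
biases = [rng.standard_normal(H) for _ in range(3)]
out = unfused_attention(query, key, *weights, *biases, num_heads=N)
print(out.shape, np.allclose(out, naive_attention(query, key, *weights, *biases, num_heads=N)))
# (3, 2, 8) True
```

The final Transpose(perm=[2, 0, 1, 3]) followed by Reshape([0, 0, -1]) is what puts the heads back side by side in the hidden dimension, so the output returns to the sequence-first (S, B, H) layout that DecoderAttention produces directly.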
    

Outcome of the fusion:

    graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_query(["query FLOAT(S, B, H)"])
    I_key(["key FLOAT(T, B, H)"])
    W_q(["q_weight FLOAT(H, H)"])
    W_kv(["kv_weight FLOAT(H, 2*H)"])
    B_qkv(["bias FLOAT(3*H)"])
    C_static_kv(["static_kv BOOL()"])
    C_use_past(["use_past BOOL()"])
    C_has_state(["has_layer_state BOOL()"])
    C_has_mask(["has_key_padding_mask BOOL()"])

    DecoderAttention_0[["com.microsoft.DecoderAttention(., ., ., ., ., , , , ., ., ., .)"]]

    I_query --> DecoderAttention_0
    I_key --> DecoderAttention_0
    W_q --> DecoderAttention_0
    W_kv --> DecoderAttention_0
    B_qkv --> DecoderAttention_0
    C_static_kv --> DecoderAttention_0
    C_use_past --> DecoderAttention_0
    C_has_state --> DecoderAttention_0
    C_has_mask --> DecoderAttention_0

    O_output(["output FLOAT(S, B, H)"])
    DecoderAttention_0 --> O_output

    class I_query,I_key,O_output ioNode
    class W_q,W_kv,B_qkv,C_static_kv,C_use_past,C_has_state,C_has_mask initNode
    class DecoderAttention_0 opNode
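The blank entries in the fused node's label mark optional inputs that are left out. Below is a small pure-Python sketch of that input list. It assumes the slot order published in ONNX Runtime's contrib-operator documentation for com.microsoft.DecoderAttention (key_padding_mask, key_cache and value_cache sit between bias and static_kv); `decoder_attention_inputs` and the slot tuple are illustrative names, not yobx API.

```python
from typing import Dict, List

# Slot order assumed from ONNX Runtime's contrib-op documentation for
# com.microsoft.DecoderAttention; "" marks an omitted optional input.
DECODER_ATTENTION_SLOTS = (
    "query", "key", "q_weight", "kv_weight", "bias",
    "key_padding_mask", "key_cache", "value_cache",
    "static_kv", "use_past", "has_layer_state", "has_key_padding_mask",
)
# This pattern uses no mask and no KV-cache, so these three slots stay empty.
OMITTED = {"key_padding_mask", "key_cache", "value_cache"}


def decoder_attention_inputs(names: Dict[str, str]) -> List[str]:
    """Map each slot to a tensor name, or "" for an omitted optional input."""
    return ["" if slot in OMITTED else names[slot] for slot in DECODER_ATTENTION_SLOTS]


names = {slot: slot for slot in DECODER_ATTENTION_SLOTS if slot not in OMITTED}
inputs = decoder_attention_inputs(names)
# Render like the diagram label: "." for a connected input, blank for an omitted one.
print(", ".join("." if name else "" for name in inputs))
# ., ., ., ., ., , , , ., ., ., .
```

Such a list could, for example, be passed as the `inputs` argument of `onnx.helper.make_node("DecoderAttention", inputs, outputs, domain="com.microsoft", num_heads=...)`. ONNX Runtime's spec also declares a `num_heads` attribute, which this page does not mention and which presumably corresponds to N in the Reshape targets.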
    
apply(g: GraphBuilder, *nodes: NodeProto) -> List[NodeProto]

Not called directly; the match closure handles dispatch.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None

Attempts to match the pattern starting from the output Reshape node.
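A backward walk from the output Reshape can be sketched over a toy graph. This is a hypothetical illustration, not the yobx GraphBuilderPatternOptimization API: nodes live in a plain dict keyed by output name, and the sketch checks only the op types and the perm/axis attributes drawn in the first diagram. It ignores shapes, Reshape targets and the scale constant, which a real matcher would likely also need to validate. static_kv is derived by comparing the tensor feeding the Q-projection MatMul with the one shared by the K and V projections.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical toy graph: output name -> (op_type, input names, attributes).
Node = Tuple[str, List[str], dict]
Graph = Dict[str, Node]

# Projection chain read backward: Transpose(perm=[1,2,0,3]) <- Reshape <- Add <- MatMul.
PROJ = [("Transpose", {"perm": [1, 2, 0, 3]}), ("Reshape", {}), ("Add", {}), ("MatMul", {})]


def walk(graph: Graph, name: str, steps) -> Optional[Node]:
    """Follow input[0] upward, requiring each producer to match (op_type, attrs)."""
    node = None
    for op_type, attrs in steps:
        node = graph.get(name)
        if node is None or node[0] != op_type:
            return None
        if any(node[2].get(k) != v for k, v in attrs.items()):
            return None
        name = node[1][0]
    return node  # last matched node


def match_from_reshape(graph: Graph, output: str) -> Optional[dict]:
    tail = [("Reshape", {}), ("Transpose", {"perm": [2, 0, 1, 3]}), ("MatMul", {})]
    mm_qkv = walk(graph, output, tail)
    if mm_qkv is None:
        return None
    probs, v_name = mm_qkv[1]
    mm_qk = walk(graph, probs, [("Softmax", {"axis": -1}), ("MatMul", {})])
    if mm_qk is None:
        return None
    q_scaled, k_t = mm_qk[1]
    q_mm = walk(graph, q_scaled, [("Mul", {})] + PROJ)
    k_mm = walk(graph, k_t, [("Transpose", {"perm": [0, 1, 3, 2]})] + PROJ)
    v_mm = walk(graph, v_name, PROJ)
    if q_mm is None or k_mm is None or v_mm is None:
        return None
    q_src, k_src, v_src = q_mm[1][0], k_mm[1][0], v_mm[1][0]
    if k_src != v_src:  # the fused op has a single `key` input
        return None
    return {"query": q_src, "key": k_src, "static_kv": q_src != k_src}


# Build the cross-attention graph from the first diagram.
g: Graph = {}


def add_node(out, op_type, inputs, **attrs):
    g[out] = (op_type, inputs, attrs)


for p, src in (("q", "query"), ("k", "key"), ("v", "key")):
    add_node(f"mm_{p}", "MatMul", [src, f"{p}_weight"])
    add_node(f"add_{p}", "Add", [f"mm_{p}", f"{p}_bias"])
    add_node(f"re_{p}", "Reshape", [f"add_{p}", "shape_nd"])
    add_node(f"tr_{p}", "Transpose", [f"re_{p}"], perm=[1, 2, 0, 3])
add_node("k_t", "Transpose", ["tr_k"], perm=[0, 1, 3, 2])
add_node("q_scaled", "Mul", ["tr_q", "scale"])
add_node("scores", "MatMul", ["q_scaled", "k_t"])
add_node("probs", "Softmax", ["scores"], axis=-1)
add_node("ctx", "MatMul", ["probs", "tr_v"])
add_node("ctx_t", "Transpose", ["ctx"], perm=[2, 0, 1, 3])
add_node("output", "Reshape", ["ctx_t", "shape_out"])
print(match_from_reshape(g, "output"))
# {'query': 'query', 'key': 'key', 'static_kv': True}
```

Starting from the Reshape means every candidate is anchored at the single output of the subgraph, so each branch (Q, K, V) is reached exactly once by following producers upward.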