yobx.xoptim.patterns_ort.decoder_attention
- class yobx.xoptim.patterns_ort.decoder_attention.DecoderAttentionPattern(verbose: int = 0, priority: int = 2)
Fuses a sequence-first decoder cross-attention or self-attention computation into com.microsoft.DecoderAttention.

The operator expects inputs in sequence-first format (S, B, H) (sequence length, batch size, hidden size) and separate weight matrices:

- query – (S, B, H)
- key – (T, B, H) (same as query for self-attention)
- q_weight – (H, H)
- kv_weight – (H, 2*H) (K and V weights concatenated)
- bias – (3*H,) (Q, K, V biases concatenated)
- static_kv – bool scalar: True for cross-attention, False for self-attention
- use_past – bool scalar: False (no KV-cache in this pattern)
- has_layer_state – bool scalar: False
- has_key_padding_mask – bool scalar: False
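The packed layout of kv_weight and bias can be illustrated with NumPy. This is a minimal sketch, not the pattern's own code: the helper name `pack_decoder_attention_weights` is hypothetical, and it assumes K comes before V in kv_weight and that the bias order is Q, K, V, following the order in which the inputs are described above.

```python
import numpy as np


def pack_decoder_attention_weights(k_weight, v_weight, q_bias, k_bias, v_bias):
    """Pack separate projection parameters into the fused layout.

    Assumption: K precedes V in kv_weight, and the bias order is Q, K, V.
    """
    kv_weight = np.concatenate([k_weight, v_weight], axis=1)  # (H, 2*H)
    bias = np.concatenate([q_bias, k_bias, v_bias], axis=0)   # (3*H,)
    return kv_weight, bias


H = 4
rng = np.random.default_rng(0)
k_weight = rng.standard_normal((H, H))
v_weight = rng.standard_normal((H, H))
q_bias, k_bias, v_bias = (rng.standard_normal(H) for _ in range(3))

kv_weight, bias = pack_decoder_attention_weights(
    k_weight, v_weight, q_bias, k_bias, v_bias
)
print(kv_weight.shape, bias.shape)  # (4, 8) (12,)
```

With this layout, a single `key @ kv_weight` product yields the K projection in the first H columns and the V projection in the last H columns.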
Cross-attention is detected when the source of the Q projection differs from the source of the K/V projections.
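The detection rule can be sketched as a comparison of the tensors feeding the three projection MatMul nodes. The `Node` type and `is_cross_attention` function below are toy stand-ins written for illustration, not part of yobx or ONNX. The sketch also assumes that a K/V source mismatch is reported as an error, whereas the real pattern may simply decline to match.

```python
from typing import NamedTuple, Tuple


class Node(NamedTuple):
    """Toy stand-in for an ONNX node: operator type plus input names."""

    op_type: str
    inputs: Tuple[str, ...]


def is_cross_attention(q_proj: Node, k_proj: Node, v_proj: Node) -> bool:
    """Return True when Q is projected from a different tensor than K and V."""
    q_src, k_src, v_src = q_proj.inputs[0], k_proj.inputs[0], v_proj.inputs[0]
    if k_src != v_src:
        # Assumption: K and V must share one source for the fusion to apply.
        raise ValueError("K and V projections must share the same source")
    return q_src != k_src


# Cross-attention: Q reads "query", K and V read "key".
static_kv_cross = is_cross_attention(
    Node("MatMul", ("query", "q_weight")),
    Node("MatMul", ("key", "k_weight")),
    Node("MatMul", ("key", "v_weight")),
)
# Self-attention: all three projections read "query".
static_kv_self = is_cross_attention(
    Node("MatMul", ("query", "q_weight")),
    Node("MatMul", ("query", "k_weight")),
    Node("MatMul", ("query", "v_weight")),
)
print(static_kv_cross, static_kv_self)  # True False
```

The boolean result corresponds to the static_kv input described above: True for cross-attention, False for self-attention.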
Model with nodes to be fused (seq-first cross-attention, no cache, no mask):
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_query(["query FLOAT(S, B, H)"])
    I_key(["key FLOAT(T, B, H)"])
    W_q(["q_weight FLOAT(H, H)"])
    W_k(["k_weight FLOAT(H, H)"])
    W_v(["v_weight FLOAT(H, H)"])
    B_q(["q_bias FLOAT(H)"])
    B_k(["k_bias FLOAT(H)"])
    B_v(["v_bias FLOAT(H)"])
    MM_q[["MatMul(query, q_weight)"]]
    Add_q[["Add(mm_q, q_bias)"]]
    Re_q[["Reshape(., [0,0,N,d])"]]
    Tr_q[["Transpose(., perm=[1,2,0,3])"]]
    MM_k[["MatMul(key, k_weight)"]]
    Add_k[["Add(mm_k, k_bias)"]]
    Re_k[["Reshape(., [0,0,N,d])"]]
    Tr_k[["Transpose(., perm=[1,2,0,3])"]]
    Tr_kt[["Transpose(., perm=[0,1,3,2])"]]
    MM_v[["MatMul(key, v_weight)"]]
    Add_v[["Add(mm_v, v_bias)"]]
    Re_v[["Reshape(., [0,0,N,d])"]]
    Tr_v[["Transpose(., perm=[1,2,0,3])"]]
    Mul_scale[["Mul(Q, scale)"]]
    MM_qk[["MatMul(Q_scaled, K_T)"]]
    Softmax[["Softmax(., axis=-1)"]]
    MM_qkv[["MatMul(attn_probs, V)"]]
    Tr_out[["Transpose(., perm=[2,0,1,3])"]]
    Re_out[["Reshape(., [0,0,-1])"]]
    I_query --> MM_q --> Add_q --> Re_q --> Tr_q
    W_q --> MM_q
    B_q --> Add_q
    I_key --> MM_k --> Add_k --> Re_k --> Tr_k --> Tr_kt
    W_k --> MM_k
    B_k --> Add_k
    I_key --> MM_v --> Add_v --> Re_v --> Tr_v
    W_v --> MM_v
    B_v --> Add_v
    Tr_q --> Mul_scale
    Mul_scale --> MM_qk
    Tr_kt --> MM_qk
    MM_qk --> Softmax --> MM_qkv
    Tr_v --> MM_qkv
    MM_qkv --> Tr_out --> Re_out
    O_output(["output FLOAT(S, B, H)"])
    Re_out --> O_output
    class I_query,I_key,O_output ioNode
    class W_q,W_k,W_v,B_q,B_k,B_v initNode
    class MM_q,Add_q,Re_q,Tr_q,MM_k,Add_k,Re_k,Tr_k,Tr_kt opNode
    class MM_v,Add_v,Re_v,Tr_v,Mul_scale,MM_qk,Softmax,MM_qkv opNode
    class Tr_out,Re_out opNode

Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_query(["query FLOAT(S, B, H)"])
    I_key(["key FLOAT(T, B, H)"])
    W_q(["q_weight FLOAT(H, H)"])
    W_kv(["kv_weight FLOAT(H, 2*H)"])
    B_qkv(["bias FLOAT(3*H)"])
    C_static_kv(["static_kv BOOL()"])
    C_use_past(["use_past BOOL()"])
    C_has_state(["has_layer_state BOOL()"])
    C_has_mask(["has_key_padding_mask BOOL()"])
    DecoderAttention_0[["com.microsoft.DecoderAttention(., ., ., ., ., , , , ., ., ., .)"]]
    I_query --> DecoderAttention_0
    I_key --> DecoderAttention_0
    W_q --> DecoderAttention_0
    W_kv --> DecoderAttention_0
    B_qkv --> DecoderAttention_0
    C_static_kv --> DecoderAttention_0
    C_use_past --> DecoderAttention_0
    C_has_state --> DecoderAttention_0
    C_has_mask --> DecoderAttention_0
    O_output(["output FLOAT(S, B, H)"])
    DecoderAttention_0 --> O_output
    class I_query,I_key,O_output ioNode
    class W_q,W_kv,B_qkv,C_static_kv,C_use_past,C_has_state,C_has_mask initNode
    class DecoderAttention_0 opNode

- apply(g: GraphBuilder, *nodes: NodeProto) -> List[NodeProto]
Not called directly; the match closure handles dispatch.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Attempts to match the pattern starting from the output Reshape node.
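To make concrete what the matched subgraph computes, the following NumPy sketch replays the unfused nodes from the diagram above in order, ending at the output Reshape that match starts from. It is an illustration rather than the library's implementation: the function name is hypothetical, and the value of scale is not stated in the diagram, so the usual 1/sqrt(d) is assumed.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def unfused_decoder_attention(query, key, q_w, k_w, v_w, q_b, k_b, v_b, num_heads):
    """Sequence-first attention: query is (S, B, H), key is (T, B, H)."""
    S, B, H = query.shape
    d = H // num_heads
    scale = 1.0 / np.sqrt(d)  # assumption: standard attention scaling

    def project(x, w, b):
        y = x @ w + b                                # MatMul + Add
        y = y.reshape(x.shape[0], B, num_heads, d)   # Reshape(., [0,0,N,d])
        return y.transpose(1, 2, 0, 3)               # Transpose -> (B, N, L, d)

    q = project(query, q_w, q_b)                     # (B, N, S, d)
    k_t = project(key, k_w, k_b).transpose(0, 1, 3, 2)  # (B, N, d, T)
    v = project(key, v_w, v_b)                       # (B, N, T, d)

    probs = softmax((q * scale) @ k_t, axis=-1)      # (B, N, S, T)
    ctx = probs @ v                                  # (B, N, S, d)
    # Transpose(., perm=[2,0,1,3]) then Reshape(., [0,0,-1])
    return ctx.transpose(2, 0, 1, 3).reshape(S, B, -1)


S, T, B, H, N = 3, 5, 2, 8, 2
rng = np.random.default_rng(0)
query = rng.standard_normal((S, B, H))
key = rng.standard_normal((T, B, H))
q_w, k_w, v_w = (rng.standard_normal((H, H)) for _ in range(3))
q_b, k_b, v_b = (rng.standard_normal(H) for _ in range(3))

out = unfused_decoder_attention(query, key, q_w, k_w, v_w, q_b, k_b, v_b, N)
print(out.shape)  # (3, 2, 8)
```

Passing the same array as both query and key gives the self-attention case; passing different arrays gives the cross-attention case shown in the diagram.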