yobx.xoptim.patterns_ort.linear_attention
- class yobx.xoptim.patterns_ort.linear_attention.LinearAttentionPattern(verbose: int = 0, priority: int = 2)
Fuses a linear-attention recurrent state update into com.microsoft.LinearAttention.

The pattern supports two update_rule variants ('linear' and 'gated'), which correspond exactly to the update_rule attribute of the ORT contrib op.

Inputs expected by the pattern (all 3-D packed, i.e. batch-first with heads folded into the last dimension):

  - query – FLOAT(B, T, H_q * d_k)
  - key – FLOAT(B, T, H_kv * d_k)
  - value – FLOAT(B, T, H_kv * d_v)
  - past_state (optional) – FLOAT(B, H_kv, d_k, d_v)
  - decay (optional, gated only) – FLOAT(B, T, H_kv * d_k) or FLOAT(B, T, H_kv)

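To make the packed layout concrete, here is a minimal NumPy sketch that builds inputs with the shapes listed above and unpacks them into per-head form using the same Reshape and Transpose (perm=[0,2,1,3]) steps that the matched graph contains. The sizes and the helper name `unpack` are hypothetical, chosen only for illustration; nothing here calls yobx or ONNX Runtime.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
B, T, H_q, H_kv, d_k, d_v = 2, 3, 4, 4, 8, 16

rng = np.random.default_rng(0)
query = rng.standard_normal((B, T, H_q * d_k)).astype(np.float32)
key = rng.standard_normal((B, T, H_kv * d_k)).astype(np.float32)
value = rng.standard_normal((B, T, H_kv * d_v)).astype(np.float32)
past_state = np.zeros((B, H_kv, d_k, d_v), dtype=np.float32)


def unpack(x, num_heads):
    """[B, T, H * d] -> [B, H, T, d]: reshape, then transpose with perm [0, 2, 1, 3]."""
    b, t, hd = x.shape
    return x.reshape(b, t, num_heads, hd // num_heads).transpose(0, 2, 1, 3)


q4 = unpack(query, H_q)    # [B, H_q, T, d_k]
k4 = unpack(key, H_kv)     # [B, H_kv, T, d_k]
v4 = unpack(value, H_kv)   # [B, H_kv, T, d_v]
```

With this layout, head h of the query occupies the contiguous slice `query[:, :, h * d_k:(h + 1) * d_k]` of the packed last dimension.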
Update rules (where ⊗ denotes outer product):
  - 'linear': S_t = S_{t-1} + k_t ⊗ v_t
  - 'gated': S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t

followed in all cases by:
  o_t = scale * q_t^T S_t

The pattern operates on the 4-D internal representation obtained after unpacking and transposing the 3-D packed inputs:

  [B, T, H * d] → Reshape → Transpose → [B, H, T, d]

For the decoding case (T = 1) the sequence/time dimension is squeezed before the core computation and unsqueezed afterwards.

Model with nodes to be fused ('linear' rule, T = 1):

  graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_query_3d(["query FLOAT(B, 1, H_q*d_k)"])
    I_key_3d(["key FLOAT(B, 1, H_kv*d_k)"])
    I_value_3d(["value FLOAT(B, 1, H_kv*d_v)"])
    I_past_state(["past_state FLOAT(B, H_kv, d_k, d_v)"])
    Reshape_q[["Reshape(., [0, 1, H_q, d_k])"]]
    Reshape_k[["Reshape(., [0, 1, H_kv, d_k])"]]
    Reshape_v[["Reshape(., [0, 1, H_kv, d_v])"]]
    Transpose_q[["Transpose(., perm=[0,2,1,3])"]]
    Transpose_k[["Transpose(., perm=[0,2,1,3])"]]
    Transpose_v[["Transpose(., perm=[0,2,1,3])"]]
    Squeeze_q[["Squeeze(., [2])"]]
    Squeeze_k[["Squeeze(., [2])"]]
    Squeeze_v[["Squeeze(., [2])"]]
    Unsqueeze_k[["Unsqueeze(., [-1])"]]
    Unsqueeze_v[["Unsqueeze(., [-2])"]]
    Mul_kv[["Mul(., .)"]]
    Add_state[["Add(past_state, kv)"]]
    Unsqueeze_q[["Unsqueeze(., [-2])"]]
    MatMul_out[["MatMul(., .)"]]
    Squeeze_out[["Squeeze(., [-2])"]]
    Mul_scale[["Mul(., scale)"]]
    Unsqueeze_out[["Unsqueeze(., [2])"]]
    Transpose_out[["Transpose(., perm=[0,2,1,3])"]]
    Reshape_out[["Reshape(., [0, -1, H_q*d_v])"]]
    I_query_3d --> Reshape_q --> Transpose_q --> Squeeze_q
    I_key_3d --> Reshape_k --> Transpose_k --> Squeeze_k
    I_value_3d --> Reshape_v --> Transpose_v --> Squeeze_v
    Squeeze_k --> Unsqueeze_k
    Squeeze_v --> Unsqueeze_v
    Unsqueeze_k --> Mul_kv
    Unsqueeze_v --> Mul_kv
    I_past_state --> Add_state
    Mul_kv --> Add_state
    Squeeze_q --> Unsqueeze_q
    Unsqueeze_q --> MatMul_out
    Add_state --> MatMul_out
    MatMul_out --> Squeeze_out --> Mul_scale
    Mul_scale --> Unsqueeze_out --> Transpose_out --> Reshape_out
    O_output(["output FLOAT(B, 1, H_q*d_v)"])
    Reshape_out --> O_output
    O_state(["present_state FLOAT(B, H_kv, d_k, d_v)"])
    Add_state --> O_state
    class I_query_3d,I_key_3d,I_value_3d,I_past_state,O_output,O_state ioNode
    class Reshape_q,Reshape_k,Reshape_v,Transpose_q,Transpose_k,Transpose_v opNode
    class Squeeze_q,Squeeze_k,Squeeze_v,Unsqueeze_k,Unsqueeze_v opNode
    class Mul_kv,Add_state,Unsqueeze_q,MatMul_out,Squeeze_out,Mul_scale opNode
    class Unsqueeze_out,Transpose_out,Reshape_out opNode

Outcome of the fusion:

  graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_query_3d(["query FLOAT(B, 1, H_q*d_k)"])
    I_key_3d(["key FLOAT(B, 1, H_kv*d_k)"])
    I_value_3d(["value FLOAT(B, 1, H_kv*d_v)"])
    I_past_state(["past_state FLOAT(B, H_kv, d_k, d_v)"])
    LinearAttention_0[["com.microsoft.LinearAttention(., ., ., .)"]]
    I_query_3d --> LinearAttention_0
    I_key_3d --> LinearAttention_0
    I_value_3d --> LinearAttention_0
    I_past_state --> LinearAttention_0
    O_output(["output FLOAT(B, 1, H_q*d_v)"])
    LinearAttention_0 --> O_output
    O_state(["present_state FLOAT(B, H_kv, d_k, d_v)"])
    LinearAttention_0 --> O_state
    class I_query_3d,I_key_3d,I_value_3d,I_past_state,O_output,O_state ioNode
    class LinearAttention_0 opNode

- apply(g: GraphBuilder, *nodes: NodeProto) → List[NodeProto]
Not called directly; the match closure handles dispatch.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Attempts to match starting from a scalar-scale Mul anchor node.
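
As a sanity check on the update rules documented for this pattern, the following NumPy sketch implements the recurrence step by step and, for the decoding case, mirrors the node sequence of the T = 1 subgraph (Unsqueeze, Mul, Add, MatMul, Squeeze, Mul). It is an illustrative reference only, not the ORT kernel or the yobx implementation. The function names are hypothetical, tensors are taken in the unpacked [B, H, T, d] form, the sketch assumes H_q = H_kv, and the decay is assumed to be a per-head log-gate of shape [B, H, T].

```python
import numpy as np


def linear_attention_reference(q, k, v, past_state, scale, decay=None):
    """Step-by-step reference for the 'linear' and 'gated' update rules.

    q, k: [B, H, T, d_k]; v: [B, H, T, d_v]; past_state: [B, H, d_k, d_v].
    decay: assumed per-head log-gate g of shape [B, H, T], or None for 'linear'.
    Returns (output [B, H, T, d_v], present_state [B, H, d_k, d_v]).
    """
    state = past_state.copy()
    outputs = []
    for t in range(q.shape[2]):
        if decay is not None:
            # 'gated': scale the previous state by exp(g_t) before the update.
            state = np.exp(decay[:, :, t])[:, :, None, None] * state
        # Outer product k_t (x) v_t, shape [B, H, d_k, d_v].
        state = state + k[:, :, t, :, None] * v[:, :, t, None, :]
        # o_t = scale * q_t^T S_t, shape [B, H, d_v].
        outputs.append(scale * np.einsum("bhk,bhkv->bhv", q[:, :, t], state))
    return np.stack(outputs, axis=2), state


def decode_step_like_graph(q, k, v, past_state, scale):
    """Mirror of the T = 1 subgraph after the Squeeze nodes ('linear' rule).

    q, k: [B, H, d_k]; v: [B, H, d_v]; past_state: [B, H, d_k, d_v].
    """
    kv = k[..., :, None] * v[..., None, :]         # Unsqueeze_k, Unsqueeze_v, Mul_kv
    state = past_state + kv                        # Add_state
    out = (q[..., None, :] @ state).squeeze(-2)    # Unsqueeze_q, MatMul_out, Squeeze_out
    return out * scale, state                      # Mul_scale


rng = np.random.default_rng(0)
B, H, d_k, d_v = 1, 2, 3, 5
q1 = rng.standard_normal((B, H, 1, d_k))
k1 = rng.standard_normal((B, H, 1, d_k))
v1 = rng.standard_normal((B, H, 1, d_v))
s0 = rng.standard_normal((B, H, d_k, d_v))

ref_out, ref_state = linear_attention_reference(q1, k1, v1, s0, scale=0.5)
g_out, g_state = decode_step_like_graph(q1[:, :, 0], k1[:, :, 0], v1[:, :, 0], s0, 0.5)
print(np.allclose(ref_out[:, :, 0], g_out), np.allclose(ref_state, g_state))
```

Both comparisons should agree up to floating-point tolerance, which is the equivalence the fusion relies on for the 'linear' rule with T = 1.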