yobx.xoptim.patterns_ort.linear_attention

class yobx.xoptim.patterns_ort.linear_attention.LinearAttentionPattern(verbose: int = 0, priority: int = 2)

Fuses a linear-attention recurrent state update into com.microsoft.LinearAttention.

The pattern supports two update_rule variants ('linear' and 'gated'), which correspond exactly to the update_rule attribute of the ORT contrib op.

Inputs expected by the pattern (query, key, value and decay are 3-D packed, i.e. batch-first with heads folded into the last dimension; past_state is 4-D):

  • query – FLOAT(B, T, H_q * d_k)

  • key – FLOAT(B, T, H_kv * d_k)

  • value – FLOAT(B, T, H_kv * d_v)

  • past_state (optional) – FLOAT(B, H_kv, d_k, d_v)

  • decay (optional, gated only) – FLOAT(B, T, H_kv * d_k) or FLOAT(B, T, H_kv)
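As an illustration of this input contract, the hypothetical helper below checks a set of packed shapes and recovers the per-head sizes. It assumes the head counts h_q and h_kv are known up front (in the unfused graph they appear as constants in the Reshape nodes), because a packed last dimension such as H * d cannot be split without them. The function name and its error messages are illustrative only and are not part of yobx.

```python
def check_packed_shapes(q_shape, k_shape, v_shape, h_q, h_kv, decay_shape=None):
    """Validate packed 3-D input shapes and return (d_k, d_v).

    Hypothetical helper: h_q and h_kv must be supplied because a packed
    last dimension H * d cannot be split into H and d on its own.
    """
    (b, t, q_dim), (b_k, t_k, k_dim), (b_v, t_v, v_dim) = q_shape, k_shape, v_shape
    if not (b == b_k == b_v and t == t_k == t_v):
        raise ValueError("batch and time dimensions must match")
    if q_dim % h_q or k_dim % h_kv or v_dim % h_kv:
        raise ValueError("packed dimensions must be divisible by the head counts")
    d_k, d_v = q_dim // h_q, v_dim // h_kv
    if k_dim // h_kv != d_k:
        raise ValueError("query and key must share the head size d_k")
    # decay is either per key dimension or per head (gated rule only).
    if decay_shape is not None and tuple(decay_shape) not in (
        (b, t, h_kv * d_k),
        (b, t, h_kv),
    ):
        raise ValueError("decay must be (B, T, H_kv * d_k) or (B, T, H_kv)")
    return d_k, d_v
```

For example, packed shapes (2, 1, 32), (2, 1, 16) and (2, 1, 32) with h_q = 4 and h_kv = 2 give d_k = 8 and d_v = 16.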

Update rules (where ⊗ denotes outer product):

  • 'linear': S_t = S_{t-1} + k_t ⊗ v_t

  • 'gated': S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t

followed in all cases by: o_t = scale * q_t^T S_t
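A minimal NumPy sketch of the two recurrences, meant as a correctness oracle for fused outputs rather than a fast implementation. It works on the 4-D layout [B, H, T, d], assumes H_q == H_kv (no grouped-query broadcasting), assumes an all-zero initial state when past_state is omitted, and expects decay already unpacked to (B, H, T, d_k) for per-key-dimension gating or (B, H, T) for per-head gating. The name linear_attention_ref is hypothetical.

```python
import numpy as np


def linear_attention_ref(q, k, v, scale, past_state=None, decay=None,
                         update_rule="linear"):
    """Reference recurrence on the 4-D layout (hypothetical helper).

    q, k: (B, H, T, d_k); v: (B, H, T, d_v); past_state: (B, H, d_k, d_v).
    decay ('gated' only): (B, H, T, d_k) per key dim, or (B, H, T) per head.
    Returns output (B, H, T, d_v) and present_state (B, H, d_k, d_v).
    """
    if update_rule not in ("linear", "gated"):
        raise ValueError(f"unknown update_rule {update_rule!r}")
    if update_rule == "gated" and decay is None:
        raise ValueError("the 'gated' rule needs a decay input")
    b, h, t_len, d_k = k.shape
    d_v = v.shape[-1]
    # Assumption: a missing past_state means an all-zero initial state.
    state = (np.zeros((b, h, d_k, d_v), dtype=v.dtype)
             if past_state is None else past_state.copy())
    outputs = []
    for t in range(t_len):
        if update_rule == "gated":
            g = decay[:, :, t]
            # Per-head gate -> (B, H, 1, 1); per-key-dim gate -> (B, H, d_k, 1).
            g = g[..., None, None] if g.ndim == 2 else g[..., None]
            state = np.exp(g) * state
        # k_t ⊗ v_t: (B, H, d_k, 1) * (B, H, 1, d_v) -> (B, H, d_k, d_v)
        state = state + k[:, :, t, :, None] * v[:, :, t, None, :]
        # o_t = scale * q_t^T S_t -> (B, H, d_v)
        outputs.append(scale * np.einsum("bhk,bhkv->bhv", q[:, :, t], state))
    return np.stack(outputs, axis=2), state
```

The loop follows the update-then-read order of the formulas: the state is decayed (gated rule only), the outer product k_t ⊗ v_t is added, and only then is q_t read against S_t. With a zero decay, the gated rule reduces to the linear one.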

The pattern operates on the 4-D internal representation obtained by unpacking and transposing the 3-D packed inputs: [B, T, H * d] → Reshape → [B, T, H, d] → Transpose → [B, H, T, d]
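The round trip between the packed and 4-D layouts can be sketched with NumPy as follows. NumPy's reshape has no counterpart to the ONNX Reshape convention where 0 means "copy this dimension", so the sketch spells B and T out explicitly. The names unpack and repack are illustrative, not yobx functions.

```python
import numpy as np


def unpack(x, num_heads):
    """[B, T, H * d] -> Reshape -> [B, T, H, d] -> Transpose -> [B, H, T, d]."""
    b, t, hd = x.shape
    if hd % num_heads:
        raise ValueError("last dimension must be divisible by num_heads")
    return x.reshape(b, t, num_heads, hd // num_heads).transpose(0, 2, 1, 3)


def repack(x):
    """Inverse: [B, H, T, d] -> Transpose -> [B, T, H, d] -> Reshape -> [B, T, H * d]."""
    b, h, t, d = x.shape
    return x.transpose(0, 2, 1, 3).reshape(b, t, h * d)
```

Applied to an array of shape (2, 3, 8) with two heads, unpack yields shape (2, 2, 3, 4), head 1 holds the last four features of each time step, and repack restores the original array.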

For the decoding case (T = 1), the sequence/time dimension is squeezed out before the core computation and unsqueezed again afterwards.

Model with nodes to be fused ('linear' rule, T = 1):

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_query_3d(["query FLOAT(B, 1, H_q*d_k)"])
    I_key_3d(["key FLOAT(B, 1, H_kv*d_k)"])
    I_value_3d(["value FLOAT(B, 1, H_kv*d_v)"])
    I_past_state(["past_state FLOAT(B, H_kv, d_k, d_v)"])

    Reshape_q[["Reshape(., [0, 1, H_q, d_k])"]]
    Reshape_k[["Reshape(., [0, 1, H_kv, d_k])"]]
    Reshape_v[["Reshape(., [0, 1, H_kv, d_v])"]]
    Transpose_q[["Transpose(., perm=[0,2,1,3])"]]
    Transpose_k[["Transpose(., perm=[0,2,1,3])"]]
    Transpose_v[["Transpose(., perm=[0,2,1,3])"]]
    Squeeze_q[["Squeeze(., [2])"]]
    Squeeze_k[["Squeeze(., [2])"]]
    Squeeze_v[["Squeeze(., [2])"]]
    Unsqueeze_k[["Unsqueeze(., [-1])"]]
    Unsqueeze_v[["Unsqueeze(., [-2])"]]
    Mul_kv[["Mul(., .)"]]
    Add_state[["Add(past_state, kv)"]]
    Unsqueeze_q[["Unsqueeze(., [-2])"]]
    MatMul_out[["MatMul(., .)"]]
    Squeeze_out[["Squeeze(., [-2])"]]
    Mul_scale[["Mul(., scale)"]]
    Unsqueeze_out[["Unsqueeze(., [2])"]]
    Transpose_out[["Transpose(., perm=[0,2,1,3])"]]
    Reshape_out[["Reshape(., [0, -1, H_q*d_v])"]]

    I_query_3d --> Reshape_q --> Transpose_q --> Squeeze_q
    I_key_3d --> Reshape_k --> Transpose_k --> Squeeze_k
    I_value_3d --> Reshape_v --> Transpose_v --> Squeeze_v
    Squeeze_k --> Unsqueeze_k
    Squeeze_v --> Unsqueeze_v
    Unsqueeze_k --> Mul_kv
    Unsqueeze_v --> Mul_kv
    I_past_state --> Add_state
    Mul_kv --> Add_state
    Squeeze_q --> Unsqueeze_q
    Unsqueeze_q --> MatMul_out
    Add_state --> MatMul_out
    MatMul_out --> Squeeze_out --> Mul_scale
    Mul_scale --> Unsqueeze_out --> Transpose_out --> Reshape_out

    O_output(["output FLOAT(B, 1, H_q*d_v)"])
    Reshape_out --> O_output
    O_state(["present_state FLOAT(B, H_kv, d_k, d_v)"])
    Add_state --> O_state

    class I_query_3d,I_key_3d,I_value_3d,I_past_state,O_output,O_state ioNode
    class Reshape_q,Reshape_k,Reshape_v,Transpose_q,Transpose_k,Transpose_v opNode
    class Squeeze_q,Squeeze_k,Squeeze_v,Unsqueeze_k,Unsqueeze_v opNode
    class Mul_kv,Add_state,Unsqueeze_q,MatMul_out,Squeeze_out,Mul_scale opNode
    class Unsqueeze_out,Transpose_out,Reshape_out opNode
    
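To make the diagram concrete, the following NumPy function mirrors the unfused subgraph node by node, with each comment naming the ONNX operators it stands in for. It is a hypothetical test oracle, not code taken from yobx, and it assumes h_q == h_kv so that the final MatMul pairs each query head with its own state.

```python
import numpy as np


def unfused_linear_decode(query, key, value, past_state, h_q, h_kv, scale):
    """NumPy mirror of the unfused 'linear', T = 1 subgraph (hypothetical)."""
    b = query.shape[0]
    d_k = query.shape[-1] // h_q
    d_v = value.shape[-1] // h_kv
    perm = (0, 2, 1, 3)
    # Reshape -> Transpose -> Squeeze(axis 2): [B, 1, H*d] -> [B, H, d]
    q = query.reshape(b, 1, h_q, d_k).transpose(perm).squeeze(2)
    k = key.reshape(b, 1, h_kv, d_k).transpose(perm).squeeze(2)
    v = value.reshape(b, 1, h_kv, d_v).transpose(perm).squeeze(2)
    # Unsqueeze(k, -1) * Unsqueeze(v, -2): outer product [B, H, d_k, d_v]
    kv = np.expand_dims(k, -1) * np.expand_dims(v, -2)
    # Add(past_state, kv): this tensor is also the present_state output
    present_state = past_state + kv
    # Unsqueeze(q, -2) @ state -> [B, H, 1, d_v]; Squeeze(-2); Mul(scale)
    o = np.squeeze(np.expand_dims(q, -2) @ present_state, -2) * scale
    # Unsqueeze(axis 2) -> Transpose -> Reshape: [B, H, d_v] -> [B, 1, H*d_v]
    output = np.expand_dims(o, 2).transpose(perm).reshape(b, -1, h_q * d_v)
    return output, present_state
```

Comparing this function's results with the fused node's outputs on random inputs is one way to check that the rewrite preserves semantics.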

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_query_3d(["query FLOAT(B, 1, H_q*d_k)"])
    I_key_3d(["key FLOAT(B, 1, H_kv*d_k)"])
    I_value_3d(["value FLOAT(B, 1, H_kv*d_v)"])
    I_past_state(["past_state FLOAT(B, H_kv, d_k, d_v)"])

    LinearAttention_0[["com.microsoft.LinearAttention(., ., ., .)"]]

    I_query_3d --> LinearAttention_0
    I_key_3d --> LinearAttention_0
    I_value_3d --> LinearAttention_0
    I_past_state --> LinearAttention_0

    O_output(["output FLOAT(B, 1, H_q*d_v)"])
    LinearAttention_0 --> O_output
    O_state(["present_state FLOAT(B, H_kv, d_k, d_v)"])
    LinearAttention_0 --> O_state

    class I_query_3d,I_key_3d,I_value_3d,I_past_state,O_output,O_state ioNode
    class LinearAttention_0 opNode
    
apply(g: GraphBuilder, *nodes: NodeProto) → List[NodeProto]

Not called directly; the match closure handles dispatch.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Attempts to match starting from a scalar-scale Mul anchor node.