yobx.xoptim.patterns_ort.moe

class yobx.xoptim.patterns_ort.moe.MoEPattern(verbose: int = 0, priority: int = 2)

Fuses the Mixture-of-Experts (MoE) computation pattern into a single com.microsoft.MoE node.

The pattern matches a standard top-k expert dispatch with two FC layers and an element-wise activation between them. The routing probabilities must already be computed (e.g. via Softmax) before the pattern.
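As a reading aid, the unfused computation for the k=1, relu case with both biases can be sketched in NumPy. This is a minimal sketch that follows the shapes and operators of the diagram in this section, not code taken from yobx; the function names `softmax` and `moe_reference` are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def moe_reference(x, router_logits, fc1_w, fc1_b, fc2_w, fc2_b):
    """Hypothetical reference for the matched sub-graph (k=1, relu).

    Shapes: x (T, H), router_logits (T, E), fc1_w (E, I, H),
    fc1_b (E, I), fc2_w (E, H, I), fc2_b (E, H).
    """
    probs = softmax(router_logits, axis=-1)       # Softmax (optional in the pattern)
    ids = probs.argmax(axis=-1)                   # TopK indices, reshaped to (T,)
    weights = probs.max(axis=-1).reshape(-1, 1)   # TopK values, reshaped to (T, 1)

    w1 = fc1_w[ids].transpose(0, 2, 1)            # Gather + Transpose -> (T, H, I)
    h = (x[:, None, :] @ w1)[:, 0, :]             # Unsqueeze, MatMul, Squeeze -> (T, I)
    h = np.maximum(h + fc1_b[ids], 0)             # Add gathered bias, then Relu

    w2 = fc2_w[ids].transpose(0, 2, 1)            # Gather + Transpose -> (T, I, H)
    out = (h[:, None, :] @ w2)[:, 0, :]           # Unsqueeze, MatMul, Squeeze -> (T, H)
    out = out + fc2_b[ids]                        # Add gathered bias
    return out * weights                          # Mul by the routing weight
```

Each token is processed only by the expert its router selects, and the result is scaled by that expert's routing probability.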

Model with nodes to be fused (k=1, relu, both biases present):

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_input(["input FLOAT(T, H)"])
    I_rp(["router_logits FLOAT(T, E)"])
    I_fc1w(["fc1_w FLOAT(E, I, H)"])
    I_fc1b(["fc1_b FLOAT(E, I)"])
    I_fc2w(["fc2_w FLOAT(E, H, I)"])
    I_fc2b(["fc2_b FLOAT(E, H)"])

    Softmax_0[["Softmax(., axis=-1) [optional]"]]
    TopK_0[["TopK(., k)"]]
    Reshape_ids[["Reshape(top_indices, (T,)) or Reshape(top_indices, (-1,))"]]
    Reshape_w[["Reshape(top_weights, (T, 1))"]]
    Gather_fc1w[["Gather(fc1_w, flat_ids, 0)"]]
    Gather_fc1b[["Gather(fc1_b, flat_ids, 0)"]]
    Transpose_fc1w[["Transpose(., [0,2,1])"]]
    Unsqueeze_in[["Unsqueeze(input, [1])"]]
    MatMul_fc1[["MatMul(., .)"]]
    Squeeze_fc1[["Squeeze(., [1])"]]
    Add_fc1[["Add(., .)"]]
    Activation_0[["Relu/Gelu/Silu(.)"]]
    Gather_fc2w[["Gather(fc2_w, flat_ids, 0)"]]
    Gather_fc2b[["Gather(fc2_b, flat_ids, 0)"]]
    Transpose_fc2w[["Transpose(., [0,2,1])"]]
    Unsqueeze_h1[["Unsqueeze(., [1])"]]
    MatMul_fc2[["MatMul(., .)"]]
    Squeeze_fc2[["Squeeze(., [1])"]]
    Add_fc2[["Add(., .)"]]
    Mul_out[["Mul(., .)"]]

    I_rp -->|"FLOAT(T, E)"| Softmax_0
    Softmax_0 -->|"FLOAT(T, E)"| TopK_0
    TopK_0 -->|"weights FLOAT(T, 1)"| Reshape_w
    TopK_0 -->|"indices INT64(T, 1)"| Reshape_ids
    Reshape_ids -->|"INT64(T,)"| Gather_fc1w
    Reshape_ids -->|"INT64(T,)"| Gather_fc1b
    Reshape_ids -->|"INT64(T,)"| Gather_fc2w
    Reshape_ids -->|"INT64(T,)"| Gather_fc2b
    I_fc1w --> Gather_fc1w
    I_fc1b --> Gather_fc1b
    I_fc2w --> Gather_fc2w
    I_fc2b --> Gather_fc2b
    Gather_fc1w --> Transpose_fc1w
    I_input --> Unsqueeze_in
    Unsqueeze_in --> MatMul_fc1
    Transpose_fc1w --> MatMul_fc1
    MatMul_fc1 --> Squeeze_fc1
    Squeeze_fc1 --> Add_fc1
    Gather_fc1b --> Add_fc1
    Add_fc1 --> Activation_0
    Gather_fc2w --> Transpose_fc2w
    Activation_0 --> Unsqueeze_h1
    Unsqueeze_h1 --> MatMul_fc2
    Transpose_fc2w --> MatMul_fc2
    MatMul_fc2 --> Squeeze_fc2
    Squeeze_fc2 --> Add_fc2
    Gather_fc2b --> Add_fc2
    Add_fc2 --> Mul_out
    Reshape_w --> Mul_out

    O_out(["output FLOAT(T, H)"])
    Mul_out --> O_out

    class I_input,I_rp,I_fc1w,I_fc1b,I_fc2w,I_fc2b,O_out ioNode
    class TopK_0,Reshape_ids,Reshape_w,Gather_fc1w,Gather_fc1b opNode
    class Transpose_fc1w,Unsqueeze_in,MatMul_fc1,Squeeze_fc1,Add_fc1 opNode
    class Activation_0,Gather_fc2w,Gather_fc2b,Transpose_fc2w opNode
    class Unsqueeze_h1,MatMul_fc2,Squeeze_fc2,Add_fc2,Mul_out opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_input(["input FLOAT(T, H)"])
    I_rp(["router_logits FLOAT(T, E)"])
    I_fc1w(["fc1_w FLOAT(E, I, H)"])
    I_fc1b(["fc1_b FLOAT(E, I)"])
    I_fc2w(["fc2_w FLOAT(E, H, I)"])
    I_fc2b(["fc2_b FLOAT(E, H)"])

    MoE_0[["com.microsoft.MoE(., ., ., ., ., .)"]]

    I_input --> MoE_0
    I_rp --> MoE_0
    I_fc1w --> MoE_0
    I_fc1b --> MoE_0
    I_fc2w --> MoE_0
    I_fc2b --> MoE_0

    O_out(["output FLOAT(T, H)"])
    MoE_0 --> O_out

    class I_input,I_rp,I_fc1w,I_fc1b,I_fc2w,I_fc2b,O_out ioNode
    class MoE_0 opNode
    
apply(g: GraphBuilder, topk_node: NodeProto, ids_reshape: NodeProto, routing_reshape: NodeProto, input_unsqueeze: NodeProto, fc1_w_gather: NodeProto, fc1_w_transpose: NodeProto, fc1_matmul: NodeProto, fc1_squeeze: NodeProto, fc1_bias_gather: NodeProto | None, fc1_add: NodeProto | None, act_node: NodeProto, fc1_act_unsqueeze: NodeProto, fc2_w_gather: NodeProto, fc2_w_transpose: NodeProto, fc2_matmul: NodeProto, fc2_squeeze: NodeProto, fc2_bias_gather: NodeProto | None, fc2_add: NodeProto | None, mul_node: NodeProto) -> List[NodeProto]

Replaces the matched expert-computation sub-graph with one MoE node.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None

Determines which nodes around node can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization. It holds all the existing nodes and can return information about types, shapes, and the nodes before or after a given node.

  • node – the node the matching starts from. The method must determine whether nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, it explores the graph as far as needed and checks whatever conditions it requires.

  • matched – usually unused; the list of matches already found

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The MatchResult must contain:

  • A list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by other pattern optimizers.

  • A function doing the rewriting (usually the method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
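The contract above can be illustrated with a toy matcher written in plain Python. Everything in this sketch is hypothetical: `ToyNode`, `ToyMatch`, `match_neg_neg` and `apply_neg_neg` are stand-ins that only mimic the three parts of a match result (nodes, rewriting function, insertion point) and are not the yobx classes or their signatures.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyNode:
    # Hypothetical stand-in for NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyMatch:
    # Hypothetical stand-in for MatchResult.
    nodes: List[ToyNode]                   # nodes involved in the rewriting
    apply: Callable[..., List[ToyNode]]    # function doing the rewriting
    insert_at: Optional[ToyNode] = None    # where the new nodes can be inserted


def apply_neg_neg(neg1: ToyNode, neg2: ToyNode) -> List[ToyNode]:
    # Neg(Neg(x)) is x, so both nodes collapse into one Identity.
    return [ToyNode("Identity", neg1.inputs, neg2.outputs)]


def match_neg_neg(graph: List[ToyNode], node: ToyNode) -> Optional[ToyMatch]:
    # Read-only exploration: the graph is never modified here.
    if node.op_type != "Neg":
        return None
    producers = [n for n in graph if node.inputs[0] in n.outputs]
    if len(producers) != 1 or producers[0].op_type != "Neg":
        return None
    # The intermediate result must not be consumed anywhere else.
    consumers = [n for n in graph if producers[0].outputs[0] in n.inputs]
    if len(consumers) != 1:
        return None
    return ToyMatch([producers[0], node], apply_neg_neg, insert_at=node)


graph = [ToyNode("Neg", ["x"], ["a"]), ToyNode("Neg", ["a"], ["y"])]
match = match_neg_neg(graph, graph[1])
new_nodes = match.apply(*match.nodes) if match else []
```

The matcher only inspects the graph and returns either None or a description of the rewrite; applying the rewrite is a separate step.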