yobx.xoptim.patterns_ort.moe
- class yobx.xoptim.patterns_ort.moe.MoEPattern(verbose: int = 0, priority: int = 2)
Fuses the Mixture-of-Experts (MoE) computation pattern into a single com.microsoft.MoE node.
The pattern matches a standard top-k expert dispatch with two FC layers and an element-wise activation between them. The routing probabilities must already be computed (e.g. via Softmax) before the pattern.
Model with nodes to be fused (k=1, relu, both biases present):
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_input(["input FLOAT(T, H)"])
    I_rp(["router_logits FLOAT(T, E)"])
    I_fc1w(["fc1_w FLOAT(E, I, H)"])
    I_fc1b(["fc1_b FLOAT(E, I)"])
    I_fc2w(["fc2_w FLOAT(E, H, I)"])
    I_fc2b(["fc2_b FLOAT(E, H)"])
    Softmax_0[["Softmax(., axis=-1) [optional]"]]
    TopK_0[["TopK(., k)"]]
    Reshape_ids[["Reshape(top_indices, (T,)) or Reshape(top_indices, (-1,))"]]
    Reshape_w[["Reshape(top_weights, (T, 1))"]]
    Gather_fc1w[["Gather(fc1_w, flat_ids, 0)"]]
    Gather_fc1b[["Gather(fc1_b, flat_ids, 0)"]]
    Transpose_fc1w[["Transpose(., [0,2,1])"]]
    Unsqueeze_in[["Unsqueeze(input, [1])"]]
    MatMul_fc1[["MatMul(., .)"]]
    Squeeze_fc1[["Squeeze(., [1])"]]
    Add_fc1[["Add(., .)"]]
    Activation_0[["Relu/Gelu/Silu(.)"]]
    Gather_fc2w[["Gather(fc2_w, flat_ids, 0)"]]
    Gather_fc2b[["Gather(fc2_b, flat_ids, 0)"]]
    Transpose_fc2w[["Transpose(., [0,2,1])"]]
    Unsqueeze_h1[["Unsqueeze(., [1])"]]
    MatMul_fc2[["MatMul(., .)"]]
    Squeeze_fc2[["Squeeze(., [1])"]]
    Add_fc2[["Add(., .)"]]
    Mul_out[["Mul(., .)"]]
    I_rp -->|"FLOAT(T, E)"| Softmax_0
    Softmax_0 -->|"FLOAT(T, E)"| TopK_0
    TopK_0 -->|"weights FLOAT(T, 1)"| Reshape_w
    TopK_0 -->|"indices INT64(T, 1)"| Reshape_ids
    Reshape_ids -->|"INT64(T,)"| Gather_fc1w
    Reshape_ids -->|"INT64(T,)"| Gather_fc1b
    Reshape_ids -->|"INT64(T,)"| Gather_fc2w
    Reshape_ids -->|"INT64(T,)"| Gather_fc2b
    I_fc1w --> Gather_fc1w
    I_fc1b --> Gather_fc1b
    I_fc2w --> Gather_fc2w
    I_fc2b --> Gather_fc2b
    Gather_fc1w --> Transpose_fc1w
    I_input --> Unsqueeze_in
    Unsqueeze_in --> MatMul_fc1
    Transpose_fc1w --> MatMul_fc1
    MatMul_fc1 --> Squeeze_fc1
    Squeeze_fc1 --> Add_fc1
    Gather_fc1b --> Add_fc1
    Add_fc1 --> Activation_0
    Gather_fc2w --> Transpose_fc2w
    Activation_0 --> Unsqueeze_h1
    Unsqueeze_h1 --> MatMul_fc2
    Transpose_fc2w --> MatMul_fc2
    MatMul_fc2 --> Squeeze_fc2
    Squeeze_fc2 --> Add_fc2
    Gather_fc2b --> Add_fc2
    Add_fc2 --> Mul_out
    Reshape_w --> Mul_out
    O_out(["output FLOAT(T, H)"])
    Mul_out --> O_out
    class I_input,I_rp,I_fc1w,I_fc1b,I_fc2w,I_fc2b,O_out ioNode
    class TopK_0,Reshape_ids,Reshape_w,Gather_fc1w,Gather_fc1b opNode
    class Transpose_fc1w,Unsqueeze_in,MatMul_fc1,Squeeze_fc1,Add_fc1 opNode
    class Activation_0,Gather_fc2w,Gather_fc2b,Transpose_fc2w opNode
    class Unsqueeze_h1,MatMul_fc2,Squeeze_fc2,Add_fc2,Mul_out opNode
Outcome of the fusion:
graph TD
    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333
    I_input(["input FLOAT(T, H)"])
    I_rp(["router_logits FLOAT(T, E)"])
    I_fc1w(["fc1_w FLOAT(E, I, H)"])
    I_fc1b(["fc1_b FLOAT(E, I)"])
    I_fc2w(["fc2_w FLOAT(E, H, I)"])
    I_fc2b(["fc2_b FLOAT(E, H)"])
    MoE_0[["com.microsoft.MoE(., ., ., ., ., .)"]]
    I_input --> MoE_0
    I_rp --> MoE_0
    I_fc1w --> MoE_0
    I_fc1b --> MoE_0
    I_fc2w --> MoE_0
    I_fc2b --> MoE_0
    O_out(["output FLOAT(T, H)"])
    MoE_0 --> O_out
    class I_input,I_rp,I_fc1w,I_fc1b,I_fc2w,I_fc2b,O_out ioNode
    class MoE_0 opNode
- apply(g: GraphBuilder, topk_node: NodeProto, ids_reshape: NodeProto, routing_reshape: NodeProto, input_unsqueeze: NodeProto, fc1_w_gather: NodeProto, fc1_w_transpose: NodeProto, fc1_matmul: NodeProto, fc1_squeeze: NodeProto, fc1_bias_gather: NodeProto | None, fc1_add: NodeProto | None, act_node: NodeProto, fc1_act_unsqueeze: NodeProto, fc2_w_gather: NodeProto, fc2_w_transpose: NodeProto, fc2_matmul: NodeProto, fc2_squeeze: NodeProto, fc2_bias_gather: NodeProto | None, fc2_add: NodeProto | None, mul_node: NodeProto) → List[NodeProto]
Replaces the matched expert-computation sub-graph with one MoE node.
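As a rough illustration of what the matched sub-graph computes, the first diagram above can be followed node by node in NumPy for the k=1, Relu, both-biases case. This is only a sketch under assumptions: the function name `moe_reference` and its argument names are invented for the example, Softmax is applied to `router_logits` as drawn in the diagram, and none of this is the yobx or onnxruntime implementation.

```python
import numpy as np


def moe_reference(x, router_logits, fc1_w, fc1_b, fc2_w, fc2_b):
    """Top-1 MoE with Relu, mirroring the unfused diagram (hypothetical helper).

    Shapes: x (T, H), router_logits (T, E), fc1_w (E, I, H), fc1_b (E, I),
    fc2_w (E, H, I), fc2_b (E, H). Returns an array of shape (T, H).
    """
    # Softmax(., axis=-1), computed in a numerically stable way
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

    # TopK(., k=1): selected expert per token and its routing weight
    top_indices = probs.argmax(axis=-1)[:, None]                    # (T, 1)
    top_weights = np.take_along_axis(probs, top_indices, axis=-1)   # (T, 1)
    flat_ids = top_indices.reshape(-1)                              # (T,)

    # FC1: Gather + Transpose, then Unsqueeze, MatMul, Squeeze, Add, Relu
    w1 = fc1_w[flat_ids].transpose(0, 2, 1)                         # (T, H, I)
    h = np.matmul(x[:, None, :], w1)[:, 0, :]                       # (T, I)
    h = np.maximum(h + fc1_b[flat_ids], 0)

    # FC2: same structure, without the activation
    w2 = fc2_w[flat_ids].transpose(0, 2, 1)                         # (T, I, H)
    y = np.matmul(h[:, None, :], w2)[:, 0, :] + fc2_b[flat_ids]     # (T, H)

    # Mul(., .): scale each token by its routing weight
    return y * top_weights.reshape(-1, 1)
```

A function of this kind can serve as a reference when checking that a model produces the same output before and after the fusion, within floating-point tolerance.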
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – is a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, the node before, or the node after another one.
node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it holds the nodes already matching a pattern.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
- A list of the nodes involved in the rewriting. Not all of them will necessarily be removed, but all of them are needed to do the rewriting and must not be impacted by another pattern optimizer.
- A function doing the rewriting (usually the method apply of the pattern class).
- An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
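The contract described above (read-only exploration around a node, None when nothing matches, otherwise a result carrying the nodes, the rewriting function, and an insertion point) can be sketched with a toy example. Everything below is hypothetical and independent of yobx: `ToyNode` and `ToyMatch` only stand in for `NodeProto` and `MatchResult`, the `producers` dictionary stands in for the graph queries a `GraphBuilderPatternOptimization` would offer, and `ToyFused` is not a real operator.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class ToyNode:  # hypothetical stand-in for onnx.NodeProto
    op_type: str
    inputs: List[str]
    output: str


@dataclass
class ToyMatch:  # hypothetical stand-in for MatchResult
    nodes: List[ToyNode]                 # nodes needed for the rewriting
    apply: Callable[..., List[ToyNode]]  # function doing the rewriting
    insert_at: Optional[ToyNode] = None  # where the new nodes can be inserted


def apply_fuse(unsqueeze: ToyNode, matmul: ToyNode, squeeze: ToyNode) -> List[ToyNode]:
    """Builds the replacement nodes; it receives the matched nodes in order."""
    fused = ToyNode("ToyFused", [unsqueeze.inputs[0], matmul.inputs[1]], squeeze.output)
    return [fused]


def match(producers: Dict[str, ToyNode], node: ToyNode) -> Optional[ToyMatch]:
    """Explores backwards from `node` without modifying anything."""
    if node.op_type != "Squeeze":
        return None
    matmul = producers.get(node.inputs[0])
    if matmul is None or matmul.op_type != "MatMul":
        return None
    unsqueeze = producers.get(matmul.inputs[0])
    if unsqueeze is None or unsqueeze.op_type != "Unsqueeze":
        return None
    return ToyMatch([unsqueeze, matmul, node], apply_fuse, insert_at=node)


# Usage: a three-node chain Unsqueeze -> MatMul -> Squeeze
nodes = [
    ToyNode("Unsqueeze", ["x"], "x3d"),
    ToyNode("MatMul", ["x3d", "w"], "y3d"),
    ToyNode("Squeeze", ["y3d"], "y"),
]
producers = {n.output: n for n in nodes}  # maps a result name to its producer
result = match(producers, nodes[-1])
new_nodes = result.apply(*result.nodes)
```

In this sketch the matcher only reads `producers`; the replacement is built by `apply_fuse` afterwards, which mirrors the separation between match and apply in the pattern class.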