yobx.xoptim.patterns.onnx_rotary
- class yobx.xoptim.patterns.onnx_rotary.FunctionCausalMaskMulAddPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Fuses nodes matching CausalMaskMulAdd (a causal mask index built with Mul and Add) into a local function.
<<<
from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCausalMaskMulAddPattern,
)

pat = FunctionCausalMaskMulAddPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))
>>>
Pattern cannot be constructed without the matched nodes.
Model with nodes to be fused:
d1 and d2 (both INT64(1)) are each squeezed to a scalar and passed to Range(0, ., 1). The range built from d1 is unsqueezed on axes [0, 1, 2], giving shape (1, 1, 1, d1); the range built from d2 is unsqueezed on axes [1, 2, 3], giving shape (d2, 1, 1, 1), and multiplied by N (INT64(1)). Add combines both results into yyc INT64(c, 1, 1, b).
Outcome of the fusion:
A single node intermediate.CausalMaskMulAdd(d1, d2, N) produces yyc INT64(c, 1, 1, b).
- apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range1: NodeProto, range2: NodeProto, rg_unsqueeze1: NodeProto, rg_unsqueeze2: NodeProto, mul: NodeProto, add: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer modifies them. It receives the nodes returned by method match as positional arguments; since match returns a list, some of them may be None. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. A received node that must be kept has to be included in the returned list.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
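A toy model can make the apply contract concrete. In this sketch, Node and rewrite are hypothetical names (not part of yobx or onnx), and inserting the new nodes where the last matched node stood is an assumption made for illustration; GraphBuilderPatternOptimization has its own placement logic.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto, for illustration only."""
    op_type: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


def rewrite(graph: List[Node], matched: List[Optional[Node]],
            apply: Callable[..., List[Node]]) -> List[Node]:
    """Toy optimizer step following the documented apply contract."""
    # apply receives the matched nodes positionally; optional ones may be None
    new_nodes = apply(*matched)
    removed = {id(n) for n in matched if n is not None}
    # every node given to apply is dropped from the graph ...
    last = max(i for i, n in enumerate(graph) if id(n) in removed)
    head = [n for n in graph[:last] if id(n) not in removed]
    # ... and every returned node is added where the last matched node was
    return head + new_nodes + graph[last + 1:]


neg = Node("Neg", ("x",), ("nx",))
cat = Node("Concat", ("nx", "x"), ("y",))
out = Node("Identity", ("y",), ("z",))


def fuse(neg_node: Node, concat_node: Node) -> List[Node]:
    # both matched nodes are replaced by one hypothetical fused node
    return [Node("NegConcat", ("x",), ("y",))]


def fuse_keep_neg(neg_node: Node, concat_node: Node) -> List[Node]:
    # neg_node is still needed, so it must be returned to stay in the graph
    return [neg_node, Node("NegConcat", ("x",), ("y",))]


print([n.op_type for n in rewrite([neg, cat, out], [neg, cat], fuse)])
# ['NegConcat', 'Identity']
print([n.op_type for n in rewrite([neg, cat, out], [neg, cat], fuse_keep_neg)])
# ['Neg', 'NegConcat', 'Identity']
```

The second call shows why a node that must survive has to appear in the returned list: otherwise it is dropped along with the rest of the match.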
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
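The computation fused by FunctionCausalMaskMulAddPattern can be reproduced in NumPy for intuition. This is a minimal sketch that follows the documented node sequence (Squeeze, Range, Unsqueeze, Mul, Add); causal_mask_mul_add is a hypothetical helper, not the body of intermediate.CausalMaskMulAdd.

```python
import numpy as np


def causal_mask_mul_add(d1: np.ndarray, d2: np.ndarray, n: np.ndarray) -> np.ndarray:
    """NumPy sketch of Squeeze/Range/Unsqueeze/Mul/Add (hypothetical helper)."""
    # Squeeze + Range(0, ., 1) on each INT64(1) dimension
    r1 = np.arange(0, int(d1.squeeze()), 1, dtype=np.int64)
    r2 = np.arange(0, int(d2.squeeze()), 1, dtype=np.int64)
    # Unsqueeze(., [0, 1, 2]) -> (1, 1, 1, d1); Unsqueeze(., [1, 2, 3]) -> (d2, 1, 1, 1)
    u1 = r1.reshape(1, 1, 1, -1)
    u2 = r2.reshape(-1, 1, 1, 1)
    # Mul by N, then Add; broadcasting gives shape (d2, 1, 1, d1)
    return u1 + u2 * n


y = causal_mask_mul_add(np.array([4]), np.array([3]), np.array([4]))
print(y.shape)  # (3, 1, 1, 4)
print(y.reshape(3, 4))
```

Entry [i, 0, 0, j] equals i * N + j; with N equal to d1, as in this example, the values are the flat indices 0 to 11.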
- class yobx.xoptim.patterns.onnx_rotary.FunctionCausalMaskPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Fuses nodes matching a (shifted) CausalMask into a local function, intermediate.ShiftedCausalMask.
<<<
from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCausalMaskPattern,
)

pat = FunctionCausalMaskPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))
>>>
Pattern cannot be constructed without the matched nodes.
Model with nodes to be fused:
d1 and d2 (INT64(1)) are squeezed. Range(0, d2, 1) is unsqueezed on axes [0, 1, 2] to shape (1, 1, 1, d2). Range(d1, d2, 1) is unsqueezed on axes [0, 1, 3] to shape (1, 1, d2-d1, 1), and the constant initi (INT64(1)) is subtracted from it with Sub. Greater compares the first tensor with the second and produces yc BOOL(1, 1, b-a, b). The squeezed d2 is also an output, nd2 INT64().
Outcome of the fusion:
Squeeze(d2) is kept to produce nd2, and a single node intermediate.ShiftedCausalMask(d1, d2, initi) produces yc BOOL(1, 1, b-a, b).
- apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range1: NodeProto, range2: NodeProto, rg_unsqueeze1: NodeProto, rg_unsqueeze2: NodeProto, sub2: NodeProto, less_or_equal: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer modifies them. It receives the nodes returned by method match as positional arguments; since match returns a list, some of them may be None. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. A received node that must be kept has to be included in the returned list.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
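A NumPy sketch of the mask built by the unfused FunctionCausalMaskPattern nodes, for intuition. It assumes the input order of Range, Sub and Greater follows the order of the diagram edges, and it treats the constant initi as a parameter because its value is not shown; shifted_causal_mask is a hypothetical helper, not the body of intermediate.ShiftedCausalMask.

```python
import numpy as np


def shifted_causal_mask(d1: np.ndarray, d2: np.ndarray, shift: np.ndarray) -> np.ndarray:
    """NumPy sketch of the subgraph fused by FunctionCausalMaskPattern (hypothetical helper)."""
    start = int(d1.squeeze())
    stop = int(d2.squeeze())
    # Range(0, d2, 1) + Unsqueeze(., [0, 1, 2]) -> (1, 1, 1, d2): column indices
    cols = np.arange(0, stop, 1, dtype=np.int64).reshape(1, 1, 1, -1)
    # Range(d1, d2, 1) + Unsqueeze(., [0, 1, 3]) -> (1, 1, d2 - d1, 1): row positions
    rows = np.arange(start, stop, 1, dtype=np.int64).reshape(1, 1, -1, 1)
    # Sub(rows, shift) then Greater(cols, .) -> BOOL(1, 1, d2 - d1, d2)
    return cols > (rows - shift)


mask = shifted_causal_mask(np.array([2]), np.array([5]), np.array([1]))
print(mask.shape)  # (1, 1, 3, 5)
print(mask[0, 0].astype(int))
```

With d1=2, d2=5 and a shift of 1, row i (position d1 + i) is True exactly for the columns j > d1 + i - 1.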
- class yobx.xoptim.patterns.onnx_rotary.FunctionCosSinCachePattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Fuses nodes into a local function, intermediate.CosSinCacheWithRange, to simplify the creation of cos/sin caches in LLMs.
<<<
from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCosSinCachePattern,
)

pat = FunctionCosSinCachePattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))
>>>
Pattern cannot be constructed without the matched nodes.
Model with nodes to be fused:
dim1 and dim2 (INT64(1)) are squeezed and passed to Range(dim1, dim2, 1). The result is unsqueezed on axes [0, 1], cast to FLOAT and reshaped with [0, -1, 1] into FLOAT(1, n, 1), then multiplied by weights FLOAT(1, 1, a) into FLOAT(1, n, a). Cos and Sin of this product give _onx_cos_mul_weights and _onx_sin_mul_weights, both FLOAT(1, dim2-dim1, a).
Outcome of the fusion:
A single node intermediate.CosSinCacheWithRange(dim1, dim2, weights) produces both _onx_cos_mul_weights and _onx_sin_mul_weights.
- apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range_node: NodeProto, unsq_or_cast_node: NodeProto, cast_or_unsq_node: NodeProto, reshape_node: NodeProto, mul_node: NodeProto, cos: NodeProto, cos_cast: NodeProto | None, sin: NodeProto, sin_cast: NodeProto | None) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer modifies them. It receives the nodes returned by method match as positional arguments; since match returns a list, some of them may be None. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. A received node that must be kept has to be included in the returned list.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
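For intuition, a NumPy sketch of the cos/sin cache computation handled by FunctionCosSinCachePattern, following the documented node sequence. cos_sin_cache_with_range is a hypothetical helper (not the body of intermediate.CosSinCacheWithRange), and the example weights are assumed values in the style of the usual rotary inverse frequencies.

```python
import numpy as np


def cos_sin_cache_with_range(dim1: np.ndarray, dim2: np.ndarray, weights: np.ndarray):
    """NumPy sketch of the nodes fused by FunctionCosSinCachePattern (hypothetical helper)."""
    # Squeeze + Range(dim1, dim2, 1): positions dim1 .. dim2 - 1
    positions = np.arange(int(dim1.squeeze()), int(dim2.squeeze()), 1, dtype=np.int64)
    # Unsqueeze(., [0, 1]) + Cast(FLOAT) -> (1, 1, n)
    positions = positions.reshape(1, 1, -1).astype(np.float32)
    # Reshape(., [0, -1, 1]) -> (1, n, 1)
    positions = positions.reshape(1, -1, 1)
    # Mul with weights (1, 1, a) broadcasts to (1, n, a)
    freqs = weights * positions
    return np.cos(freqs), np.sin(freqs)


# assumed weights: rotary-style inverse frequencies for a head size of 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, 8, 2, dtype=np.float32) / 8))
weights = inv_freq.reshape(1, 1, -1)
cos, sin = cos_sin_cache_with_range(np.array([0]), np.array([6]), weights)
print(cos.shape, sin.shape)  # (1, 6, 4) (1, 6, 4)
```

Row p of each cache holds cos(p * w) and sin(p * w) for every weight w, so the fused function only needs dim1, dim2 and weights as inputs.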
- class yobx.xoptim.patterns.onnx_rotary.FunctionHalfRotaryEmbeddingPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Fuses nodes matching a rotate-half RotaryEmbedding(23) computation into a local function, intermediate.HalfRotaryEmbedding.
<<<
from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionHalfRotaryEmbeddingPattern,
)

pat = FunctionHalfRotaryEmbeddingPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))
>>>
Pattern cannot be constructed without the matched nodes.
Model with nodes to be fused:
X FLOAT(a, b, c, d) is split in two on the last axis, into parts of sizes CeilToInt(d,2) and d-CeilToInt(d,2). The second part is negated and concatenated in front of the first one (axis=-1); the result is multiplied by m1, X is multiplied by m2 (both FLOAT(c, d)), and Add sums the two products into Y FLOAT(a, b, c, d).
Outcome of the fusion:
A single node intermediate.HalfRotaryEmbedding(X, m2, m1) produces Y FLOAT(a, b, c, d).
- apply(g: GraphBuilder, split_node: NodeProto, neg_node: NodeProto, concat_node: NodeProto, mul1_node: NodeProto, mul2_node: NodeProto, add_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer modifies them. It receives the nodes returned by method match as positional arguments; since match returns a list, some of them may be None. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. A received node that must be kept has to be included in the returned list.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
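A NumPy sketch of the Split/Neg/Concat/Mul/Add chain fused by FunctionHalfRotaryEmbeddingPattern. Reading m2 as cos and m1 as sin is an assumption based on the usual rotary formula x * cos + rotate_half(x) * sin; half_rotary_embedding is a hypothetical helper, not the body of intermediate.HalfRotaryEmbedding.

```python
import numpy as np


def half_rotary_embedding(x: np.ndarray, m2: np.ndarray, m1: np.ndarray) -> np.ndarray:
    """NumPy sketch of the Split/Neg/Concat/Mul/Add chain (hypothetical helper)."""
    # Split(., axis=-1) into two outputs of sizes ceil(d / 2) and d - ceil(d / 2)
    half = (x.shape[-1] + 1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Neg on the second output, then Concat(-x2, x1, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)
    # Mul(rotated, m1) + Mul(X, m2), with m1, m2 of shape (c, d) broadcast over (a, b)
    return rotated * m1 + x * m2


theta = 0.3
x = np.array([[[[1.0, 0.0]]]])              # (a, b, c, d) = (1, 1, 1, 2)
m2 = np.full((1, 2), np.cos(theta))         # assumed to play the role of cos
m1 = np.full((1, 2), np.sin(theta))         # assumed to play the role of sin
y = half_rotary_embedding(x, m2, m1)
print(y[0, 0, 0])  # [cos(0.3), sin(0.3)]: (1, 0) rotated by 0.3 rad
```

With d = 2 the chain reduces to a plain 2D rotation, which is why the fused node only needs X, m2 and m1.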
- class yobx.xoptim.patterns.onnx_rotary.RotaryConcatPartPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Optimizes the following pattern:
X FLOAT(a, b, c, d) is sliced on axis 3 into X[..., 0:256] and X[..., 256:512], and the first slice is negated. Each part is concatenated on axis 3 with a ConstantOfShape([2, 2, 1024, 256]) tensor, and the two FLOAT(2, 2, 1024, 512) results are added into Y.
Model with nodes to be fused:
X FLOAT(a, 16) is split on axis 1 (split sizes given by the constant split) into x1 and x2, both FLOAT(a, 8), and x1 is negated into nx1. zero = ConstantOfShape(shape), FLOAT(3, 8), pads both parts: Concat(nx1, zero, axis=1) and Concat(zero, x2, axis=1) are added to produce Y FLOAT(a, 16). zero, x1, nx1 and x2 are also graph outputs.
Outcome of the fusion:
The Add and both padding Concat nodes are replaced by a single Concat(nx1, x2, axis=1) producing Y FLOAT(a, 16); ConstantOfShape, Split and Neg remain, so zero, x1, nx1 and x2 are still produced.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
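The identity behind RotaryConcatPartPattern can be checked numerically: adding two concatenations whose zero blocks sit in complementary positions gives the concatenation of the non-zero parts. This NumPy sketch assumes split sizes [8, 8] and a ConstantOfShape value of 0 (consistent with the output name zero); it illustrates the rewrite, not the pattern's matching code.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16)).astype(np.float32)   # X: FLOAT(a, 16) with a = 3
x1, x2 = np.split(x, [8], axis=1)                      # Split(X, [8, 8], axis=1), assumed sizes
nx1 = -x1                                              # Neg
zero = np.zeros((3, 8), dtype=np.float32)              # ConstantOfShape(shape), assumed value 0

# Before: each part padded with zeros by Concat, then Add
before = np.concatenate([nx1, zero], axis=1) + np.concatenate([zero, x2], axis=1)
# After: a single Concat of the two parts, no ConstantOfShape padding, no Add
after = np.concatenate([nx1, x2], axis=1)
print(np.array_equal(before, after))  # True
```

The equality only holds when each zero block has exactly the shape of the part it pads and the constant is really zero, which is the kind of condition match presumably has to verify before rewriting.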
- class yobx.xoptim.patterns.onnx_rotary.RotaryEmbeddingPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Fuses nodes matching RotaryEmbedding(23).
Model with nodes to be fused:
X FLOAT(a, 2, c, d) is split on the last axis with sizes [4, 6]. The first part goes through intermediate.HalfRotaryEmbedding together with two Concat(axis=-1) nodes, one fed by m2 and one fed by m1 (both FLOAT(1, 1, c, e)); the result is concatenated on the last axis with the second part to produce Y FLOAT(a, b, c, d).
Outcome of the fusion:
Shape(X, start=0, end=1) extracts the first dimension a, and Concat(., [1, 1], axis=0) turns it into the 3-element shape (a, 1, 1). m2 and m1 are squeezed on axis 1 to FLOAT(1, c, e) and expanded to FLOAT(a, c, e). A standard RotaryEmbedding(X, expanded m2, expanded m1) node then produces Y FLOAT(a, b, c, d).
- apply(g: GraphBuilder, concat_cos: NodeProto, concat_sin: NodeProto, split_node: NodeProto, half_node: NodeProto, concat_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer modifies them. It receives the nodes returned by method match as positional arguments; since match returns a list, some of them may be None. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. A received node that must be kept has to be included in the returned list.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the node the matching starts from; match must determine whether some nodes around it belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.
matched – usually unused; it holds the matches (MatchResult) already found.
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of the nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by another pattern optimizer;
a function doing the rewriting (usually method apply of the pattern class);
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer determines the position of the new nodes automatically.
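To see why RotaryEmbeddingPattern is valid, this NumPy sketch compares the unfused computation with the non-interleaved rotary formula as we read it in the ONNX RotaryEmbedding specification (first half x1 * cos - x2 * sin, second half x1 * sin + x2 * cos, remaining channels passed through). Assumptions: each cache Concat joins the same tensor twice (the diagram shows a single edge), m2 plays the role of cos and m1 of sin (their positions as the second and third inputs of RotaryEmbedding in the fused graph), and all helper names are hypothetical. The caches keep a (1, 1, c, e) layout here rather than the (a, c, e) layout produced by Squeeze and Expand.

```python
import numpy as np


def half_rotary(x, cos, sin):
    """Rotate-half form: x * cos + Concat(-x2, x1) * sin."""
    half = (x.shape[-1] + 1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([-x2, x1], axis=-1) * sin + x * cos


def before_fusion(x, m2, m1, rot_dim):
    """Split + HalfRotaryEmbedding + Concat, as in the unfused graph (sketch)."""
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]   # Split(X, [rot_dim, d - rot_dim], axis=-1)
    cos = np.concatenate([m2, m2], axis=-1)              # assumed Concat(m2, m2, axis=-1)
    sin = np.concatenate([m1, m1], axis=-1)              # assumed Concat(m1, m1, axis=-1)
    return np.concatenate([half_rotary(x_rot, cos, sin), x_pass], axis=-1)


def rotary_non_interleaved(x, cos, sin, rot_dim):
    """Non-interleaved rotary formula on the first rot_dim channels (sketch)."""
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = np.split(x_rot, 2, axis=-1)
    real = cos * x1 - sin * x2
    imag = sin * x1 + cos * x2
    return np.concatenate([real, imag, x_pass], axis=-1)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 3, 10))          # (a, heads, c, d), split sizes [4, 6]
angles = rng.uniform(size=(1, 1, 3, 2))         # (1, 1, c, e) with e = 4 / 2
m2, m1 = np.cos(angles), np.sin(angles)
print(np.allclose(before_fusion(x, m2, m1, 4), rotary_non_interleaved(x, m2, m1, 4)))  # True
```

Duplicating each half-size cache with Concat is what lets the rotate-half form match the operator, which only needs rot_dim / 2 cache values per position.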