yobx.xoptim.patterns.onnx_rotary

class yobx.xoptim.patterns.onnx_rotary.FunctionCausalMaskMulAddPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Fuses nodes matching CausalMask into a local function.

<<<

from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCausalMaskMulAddPattern,
)

pat = FunctionCausalMaskMulAddPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))

>>>

    Pattern cannot be constructed without the matched nodes.

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_d1(["d1 INT64(1)"])
    I_N(["N INT64(1)"])
    I_d2(["d2 INT64(1)"])

    Squeeze_0[["Squeeze(.)"]]
    Squeeze_1[["Squeeze(.)"]]
    Range_2[["Range(0, ., 1)"]]
    Range_3[["Range(0, ., 1)"]]
    Unsqueeze_4[["Unsqueeze(., [0, 1, 2])"]]
    Unsqueeze_5[["Unsqueeze(., [1, 2, 3])"]]
    Mul_6[["Mul(., .)"]]
    Add_7[["Add(., .)"]]

    I_d1 -->|"INT64(1)"| Squeeze_0
    I_d2 -->|"INT64(1)"| Squeeze_1
    Squeeze_0 -->|"INT64()"| Range_2
    Squeeze_1 -->|"INT64()"| Range_3
    Range_2 -->|"INT64(NEWDIM_range_0)"| Unsqueeze_4
    Range_3 -->|"INT64(NEWDIM_range_1)"| Unsqueeze_5
    Unsqueeze_5 -->|"INT64(NEWDIM_range_1, 1, 1, 1)"| Mul_6
    I_N -->|"INT64(1)"| Mul_6
    Unsqueeze_4 -->|"INT64(1, 1, 1, NEWDIM_range_0)"| Add_7
    Mul_6 -->|"INT64(NEWDIM_range_1, 1, 1, 1)"| Add_7

    O_yyc(["yyc INT64(c, 1, 1, b)"])
    Add_7 --> O_yyc

    class I_d1,I_N,I_d2,O_yyc ioNode
    class Squeeze_0,Squeeze_1,Range_2,Range_3,Unsqueeze_4,Unsqueeze_5,Mul_6,Add_7 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_d1(["d1 INT64(1)"])
    I_N(["N INT64(1)"])
    I_d2(["d2 INT64(1)"])

    CausalMaskMulAdd_0[["intermediate.CausalMaskMulAdd(., ., .)"]]

    I_d1 -->|"INT64(1)"| CausalMaskMulAdd_0
    I_d2 -->|"INT64(1)"| CausalMaskMulAdd_0
    I_N -->|"INT64(1)"| CausalMaskMulAdd_0

    O_yyc(["yyc INT64(c, 1, 1, b)"])
    CausalMaskMulAdd_0 --> O_yyc

    class I_d1,I_N,I_d2,O_yyc ioNode
    class CausalMaskMulAdd_0 opNode
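
As a reading aid, the computation carried out by the eight fused nodes can be sketched in NumPy. This is an illustration derived from the diagram above, not the library's implementation: the function name `causal_mask_mul_add` is hypothetical, and the inputs are assumed to be one-element INT64 arrays as shown in the diagram.

```python
import numpy as np


def causal_mask_mul_add(d1, d2, n):
    """NumPy sketch of the Squeeze/Range/Unsqueeze/Mul/Add chain (hypothetical helper)."""
    # Squeeze(.) followed by Range(0, ., 1) on each dimension input.
    r1 = np.arange(0, int(np.squeeze(d1)), 1, dtype=np.int64)
    r2 = np.arange(0, int(np.squeeze(d2)), 1, dtype=np.int64)
    # Unsqueeze(., [0, 1, 2]) gives (1, 1, 1, n1); Unsqueeze(., [1, 2, 3]) gives (n2, 1, 1, 1).
    u1 = r1.reshape(1, 1, 1, -1)
    u2 = r2.reshape(-1, 1, 1, 1)
    # Mul by N, then Add with broadcasting: the result has shape (n2, 1, 1, n1).
    return u1 + u2 * n


mask = causal_mask_mul_add(
    np.array([3], dtype=np.int64),
    np.array([2], dtype=np.int64),
    np.array([10], dtype=np.int64),
)
print(mask.shape)
```

Each row of the result is the first range offset by a multiple of `N`, which matches the output shape `(c, 1, 1, b)` given in the diagram.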
    
apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range1: NodeProto, range2: NodeProto, rg_unsqueeze1: NodeProto, rg_unsqueeze2: NodeProto, mul: NodeProto, add: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
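
The contract between match and apply can be illustrated with a toy, self-contained sketch. Every name in it (`Node`, `ToyMatch`, `NegNegPattern`, `optimize_once`) is hypothetical and stands in for the real classes; it does not use the yobx API and only mirrors the rules stated above: match does not modify the graph, matched nodes are removed, and the nodes returned by apply are added.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a NodeProto with a single output."""
    op_type: str
    inputs: List[str]
    output: str


@dataclass
class ToyMatch:
    """Hypothetical stand-in for MatchResult: matched nodes plus the rewriting function."""
    nodes: List[Optional[Node]]
    apply: Callable[..., List[Node]]


class NegNegPattern:
    """Toy pattern rewriting Neg(Neg(x)) into Identity(x)."""

    def match(self, graph: List[Node], node: Node) -> Optional[ToyMatch]:
        # Read-only exploration around `node`: the graph is never modified here.
        if node.op_type != "Neg":
            return None
        producer = next((n for n in graph if n.output == node.inputs[0]), None)
        if producer is None or producer.op_type != "Neg":
            return None
        # The intermediate result must not be used anywhere else.
        if sum(producer.output in n.inputs for n in graph) != 1:
            return None
        return ToyMatch([producer, node], self.apply)

    def apply(self, graph: List[Node], first: Node, second: Node) -> List[Node]:
        # Both received nodes are dropped; only the returned node is added.
        return [Node("Identity", [first.inputs[0]], second.output)]


def optimize_once(graph: List[Node], pattern: NegNegPattern) -> List[Node]:
    """Applies the first match: removes matched nodes, inserts the returned ones."""
    for node in graph:
        found = pattern.match(graph, node)
        if found is None:
            continue
        matched = [n for n in found.nodes if n is not None]
        new_nodes = found.apply(graph, *found.nodes)
        result: List[Node] = []
        for n in graph:
            if n is matched[-1]:
                result.extend(new_nodes)  # insert where the last matched node was
            elif not any(n is m for m in matched):
                result.append(n)
        return result
    return graph


graph = [Node("Neg", ["x"], "a"), Node("Neg", ["a"], "y"), Node("Relu", ["y"], "z")]
optimized = optimize_once(graph, NegNegPattern())
print([n.op_type for n in optimized])
```

Inserting the new nodes at the position of the last matched node keeps the toy graph topologically sorted; the real optimizer may choose the position differently.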

class yobx.xoptim.patterns.onnx_rotary.FunctionCausalMaskPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Fuses nodes matching CausalMask into a local function.

<<<

from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCausalMaskPattern,
)

pat = FunctionCausalMaskPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))

>>>

    Pattern cannot be constructed without the matched nodes.

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_d1(["d1 INT64(1)"])
    I_initi(["initi INT64(1)"])
    I_d2(["d2 INT64(1)"])

    Constant_0[["Constant() -#gt; initi"]]
    Squeeze_1[["Squeeze(.)"]]
    Squeeze_2[["Squeeze(.)"]]
    Range_3[["Range(0, ., 1)"]]
    Range_4[["Range(., ., 1)"]]
    Unsqueeze_5[["Unsqueeze(., [0, 1, 2])"]]
    Unsqueeze_6[["Unsqueeze(., [0, 1, 3])"]]
    Sub_7[["Sub(., .)"]]
    Greater_8[["Greater(., .)"]]

    I_d1 -->|"INT64(1)"| Squeeze_1
    I_d2 -->|"INT64(1)"| Squeeze_2
    Squeeze_2 -->|"INT64()"| Range_3
    Squeeze_1 -->|"INT64()"| Range_4
    Squeeze_2 -->|"INT64()"| Range_4
    Range_3 -->|"INT64(NEWDIM_range_0)"| Unsqueeze_5
    Range_4 -->|"INT64(NEWDIM_range_1)"| Unsqueeze_6
    Unsqueeze_6 -->|"INT64(1, 1, NEWDIM_range_1, 1)"| Sub_7
    Constant_0 -->|"INT64(1)"| Sub_7
    Unsqueeze_5 -->|"INT64(1, 1, 1, NEWDIM_range_0)"| Greater_8
    Sub_7 -->|"INT64(1, 1, NEWDIM_range_1, 1)"| Greater_8

    O_nd2(["nd2 INT64()"])
    Squeeze_2 --> O_nd2
    O_yc(["yc BOOL(1, 1, b-a, b)"])
    Greater_8 --> O_yc

    class I_d1,I_initi,I_d2,O_nd2,O_yc ioNode
    class Constant_0 constNode
    class Squeeze_1,Squeeze_2,Range_3,Range_4,Unsqueeze_5,Unsqueeze_6,Sub_7 opNode
    class Greater_8 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_d1(["d1 INT64(1)"])
    I_initi(["initi INT64(1)"])
    I_d2(["d2 INT64(1)"])

    Squeeze_0[["Squeeze(.)"]]
    ShiftedCausalMask_1[["intermediate.ShiftedCausalMask(., ., .)"]]

    I_d2 -->|"INT64(1)"| Squeeze_0
    I_d1 -->|"INT64(1)"| ShiftedCausalMask_1
    I_d2 -->|"INT64(1)"| ShiftedCausalMask_1
    I_initi -->|"INT64(1)"| ShiftedCausalMask_1

    O_nd2(["nd2 INT64()"])
    Squeeze_0 --> O_nd2
    O_yc(["yc BOOL(1, 1, b-a, b)"])
    ShiftedCausalMask_1 --> O_yc

    class I_d1,I_initi,I_d2,O_nd2,O_yc ioNode
    class Squeeze_0,ShiftedCausalMask_1 opNode
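
The mask produced by the fused nodes can be sketched in NumPy as follows. This is an illustration based on the diagram, not the library's code: the name `shifted_causal_mask` is hypothetical, and the operand order of Sub and Greater is assumed to follow the order of the edges in the diagram.

```python
import numpy as np


def shifted_causal_mask(d1, d2, initi):
    """NumPy sketch of the Squeeze/Range/Unsqueeze/Sub/Greater chain (hypothetical helper)."""
    a = int(np.squeeze(d1))
    b = int(np.squeeze(d2))
    # Range(0, b, 1) then Unsqueeze(., [0, 1, 2]): shape (1, 1, 1, b).
    cols = np.arange(0, b, 1, dtype=np.int64).reshape(1, 1, 1, -1)
    # Range(a, b, 1) then Unsqueeze(., [0, 1, 3]): shape (1, 1, b - a, 1).
    rows = np.arange(a, b, 1, dtype=np.int64).reshape(1, 1, -1, 1)
    # Assumed operand order: Sub(rows, initi), then Greater(cols, rows - initi).
    return cols > (rows - initi)


mask = shifted_causal_mask(
    np.array([2], dtype=np.int64),
    np.array([4], dtype=np.int64),
    np.array([1], dtype=np.int64),
)
print(mask.shape)
```

The resulting shape `(1, 1, b - a, b)` agrees with the output `yc BOOL(1, 1, b-a, b)` shown in the diagram.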
    
apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range1: NodeProto, range2: NodeProto, rg_unsqueeze1: NodeProto, rg_unsqueeze2: NodeProto, sub2: NodeProto, less_or_equal: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_rotary.FunctionCosSinCachePattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Fuses nodes to simplify the creation of cos/sin caches in LLM.

<<<

from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionCosSinCachePattern,
)

pat = FunctionCosSinCachePattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))

>>>

    Pattern cannot be constructed without the matched nodes.

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_weights(["weights FLOAT(1, 1, a)"])
    I_dim1(["dim1 INT64(1)"])
    I_dim2(["dim2 INT64(1)"])

    Squeeze_0[["Squeeze(.)"]]
    Squeeze_1[["Squeeze(.)"]]
    Range_2[["Range(., ., 1)"]]
    Unsqueeze_3[["Unsqueeze(., [0, 1])"]]
    Cast_4[["Cast(., to=FLOAT)"]]
    Reshape_5[["Reshape(., [0, -1, 1])"]]
    Mul_6[["Mul(., .)"]]
    Cos_7[["Cos(.)"]]
    Sin_8[["Sin(.)"]]

    I_dim1 -->|"INT64(1)"| Squeeze_0
    I_dim2 -->|"INT64(1)"| Squeeze_1
    Squeeze_0 -->|"INT64()"| Range_2
    Squeeze_1 -->|"INT64()"| Range_2
    Range_2 -->|"INT64(NEWDIM_range_0)"| Unsqueeze_3
    Unsqueeze_3 -->|"INT64(1, 1, NEWDIM_range_0)"| Cast_4
    Cast_4 -->|"FLOAT(1, 1, NEWDIM_range_0)"| Reshape_5
    I_weights -->|"FLOAT(1, 1, a)"| Mul_6
    Reshape_5 -->|"FLOAT(1, NEWDIM_range_0, 1)"| Mul_6
    Mul_6 -->|"FLOAT(1, NEWDIM_range_0, a)"| Cos_7
    Mul_6 -->|"FLOAT(1, NEWDIM_range_0, a)"| Sin_8

    O__onx_sin_mul_weights(["_onx_sin_mul_weights FLOAT(1, dim2-dim1, a)"])
    Sin_8 --> O__onx_sin_mul_weights
    O__onx_cos_mul_weights(["_onx_cos_mul_weights FLOAT(1, dim2-dim1, a)"])
    Cos_7 --> O__onx_cos_mul_weights

    class I_weights,I_dim1,I_dim2,O__onx_sin_mul_weights,O__onx_cos_mul_weights ioNode
    class Squeeze_0,Squeeze_1,Range_2,Unsqueeze_3,Cast_4,Reshape_5 opNode
    class Mul_6,Cos_7,Sin_8 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_weights(["weights FLOAT(1, 1, a)"])
    I_dim1(["dim1 INT64(1)"])
    I_dim2(["dim2 INT64(1)"])

    CosSinCacheWithRange_0[["intermediate.CosSinCacheWithRange(., ., .)"]]

    I_dim1 -->|"INT64(1)"| CosSinCacheWithRange_0
    I_dim2 -->|"INT64(1)"| CosSinCacheWithRange_0
    I_weights -->|"FLOAT(1, 1, a)"| CosSinCacheWithRange_0

    O__onx_sin_mul_weights(["_onx_sin_mul_weights FLOAT(1, dim2-dim1, a)"])
    CosSinCacheWithRange_0 --> O__onx_sin_mul_weights
    O__onx_cos_mul_weights(["_onx_cos_mul_weights FLOAT(1, dim2-dim1, a)"])
    CosSinCacheWithRange_0 --> O__onx_cos_mul_weights

    class I_weights,I_dim1,I_dim2,O__onx_sin_mul_weights,O__onx_cos_mul_weights ioNode
    class CosSinCacheWithRange_0 opNode
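
A NumPy sketch of the nine fused nodes may help to read the diagram. It is an illustration only: the name `cos_sin_cache_with_range` and the `(cos, sin)` return order are choices made for this sketch and are not taken from the library.

```python
import numpy as np


def cos_sin_cache_with_range(dim1, dim2, weights):
    """NumPy sketch of the Range/Unsqueeze/Cast/Reshape/Mul/Cos/Sin chain (hypothetical helper)."""
    start = int(np.squeeze(dim1))
    stop = int(np.squeeze(dim2))
    # Range(start, stop, 1), then Unsqueeze(., [0, 1]): shape (1, 1, n).
    positions = np.arange(start, stop, 1, dtype=np.int64).reshape(1, 1, -1)
    # Cast to FLOAT, then Reshape(., [0, -1, 1]): shape (1, n, 1).
    positions = positions.astype(np.float32).reshape(1, -1, 1)
    # Mul with weights of shape (1, 1, a) broadcasts to (1, n, a).
    angles = positions * weights
    return np.cos(angles), np.sin(angles)


weights = np.full((1, 1, 2), 0.5, dtype=np.float32)
cos, sin = cos_sin_cache_with_range(
    np.array([1], dtype=np.int64), np.array([4], dtype=np.int64), weights
)
print(cos.shape, sin.shape)
```

Both outputs have shape `(1, dim2 - dim1, a)`, as in the diagram.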
    
apply(g: GraphBuilder, dim_squeeze1: NodeProto, dim_squeeze2: NodeProto, range_node: NodeProto, unsq_or_cast_node: NodeProto, cast_or_unsq_node: NodeProto, reshape_node: NodeProto, mul_node: NodeProto, cos: NodeProto, cos_cast: NodeProto | None, sin: NodeProto, sin_cast: NodeProto | None) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_rotary.FunctionHalfRotaryEmbeddingPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Fuses nodes matching half RotaryEmbedding(23) into a local function.

<<<

from yobx.xbuilder import GraphBuilder
from yobx.xoptim import GraphBuilderPatternOptimization
from yobx.xoptim.patterns import (
    FunctionHalfRotaryEmbeddingPattern,
)

pat = FunctionHalfRotaryEmbeddingPattern()
g = GraphBuilderPatternOptimization(GraphBuilder(18))
print(pat._pattern_to_string(g))

>>>

    Pattern cannot be constructed without the matched nodes.

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_m2(["m2 FLOAT(c, d)"])
    I_X(["X FLOAT(a, b, c, d)"])
    I_m1(["m1 FLOAT(c, d)"])

    Split_0[["Split(., axis=-1)"]]
    Neg_1[["Neg(.)"]]
    Concat_2[["Concat(., ., axis=-1)"]]
    Mul_3[["Mul(., .)"]]
    Mul_4[["Mul(., .)"]]
    Add_5[["Add(., .)"]]

    I_X -->|"FLOAT(a, b, c, d)"| Split_0
    Split_0 -->|"FLOAT(a, b, c, d-CeilToInt(d,2))"| Neg_1
    Neg_1 -->|"FLOAT(a, b, c, d-CeilToInt(d,2))"| Concat_2
    Split_0 -->|"FLOAT(a, b, c, CeilToInt(d,2))"| Concat_2
    Concat_2 -->|"FLOAT(a, b, c, d)"| Mul_3
    I_m1 -->|"FLOAT(c, d)"| Mul_3
    I_X -->|"FLOAT(a, b, c, d)"| Mul_4
    I_m2 -->|"FLOAT(c, d)"| Mul_4
    Mul_3 -->|"FLOAT(a, b, c, d)"| Add_5
    Mul_4 -->|"FLOAT(a, b, c, d)"| Add_5

    O_Y(["Y FLOAT(a, b, c, d)"])
    Add_5 --> O_Y

    class I_m2,I_X,I_m1,O_Y ioNode
    class Split_0,Neg_1,Concat_2,Mul_3,Mul_4,Add_5 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_m2(["m2 FLOAT(c, d)"])
    I_X(["X FLOAT(a, b, c, d)"])
    I_m1(["m1 FLOAT(c, d)"])

    HalfRotaryEmbedding_0[["intermediate.HalfRotaryEmbedding(., ., .)"]]

    I_X -->|"FLOAT(a, b, c, d)"| HalfRotaryEmbedding_0
    I_m2 -->|"FLOAT(c, d)"| HalfRotaryEmbedding_0
    I_m1 -->|"FLOAT(c, d)"| HalfRotaryEmbedding_0

    O_Y(["Y FLOAT(a, b, c, d)"])
    HalfRotaryEmbedding_0 --> O_Y

    class I_m2,I_X,I_m1,O_Y ioNode
    class HalfRotaryEmbedding_0 opNode
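
The six fused nodes compute `X * m2 + rotate_half(X) * m1`, which can be sketched in NumPy. This is an illustration derived from the diagram: the function name `half_rotary_embedding` is hypothetical, and the sketch assumes that Neg receives the second output of Split, as the edge shape `d-CeilToInt(d,2)` suggests.

```python
import numpy as np


def half_rotary_embedding(x, m2, m1):
    """NumPy sketch of the Split/Neg/Concat/Mul/Mul/Add chain (hypothetical helper)."""
    d = x.shape[-1]
    half = (d + 1) // 2  # CeilToInt(d, 2): size of the first Split output
    x1, x2 = x[..., :half], x[..., half:]
    # Concat(Neg(x2), x1, axis=-1): the "rotate half" operation.
    rotated = np.concatenate([-x2, x1], axis=-1)
    # Mul(rotated, m1) + Mul(X, m2)
    return rotated * m1 + x * m2


x = np.arange(4, dtype=np.float32).reshape(1, 1, 1, 4)
m1 = np.ones((1, 4), dtype=np.float32)   # plays the role of the sine term
m2 = np.zeros((1, 4), dtype=np.float32)  # plays the role of the cosine term
y = half_rotary_embedding(x, m2, m1)
print(y[0, 0, 0])
```

With `m2` set to zero and `m1` set to one, the output reduces to the rotated input, which makes the effect of the Split, Neg and Concat nodes easy to check.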
    
apply(g: GraphBuilder, split_node: NodeProto, neg_node: NodeProto, concat_node: NodeProto, mul1_node: NodeProto, mul2_node: NodeProto, add_node: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_rotary.RotaryConcatPartPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Optimizes the following pattern:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(a, b, c, d)"])

    ConstantOfShape_0[["ConstantOfShape([2, 2, 1024, 256])"]]
    ConstantOfShape_1[["ConstantOfShape([2, 2, 1024, 256])"]]
    Slice_2[["Slice(., [256], [512], [3])"]]
    Concat_3[["Concat(., ., axis=3)"]]
    Slice_4[["Slice(., [0], [256], [3])"]]
    Neg_5[["Neg(.)"]]
    Concat_6[["Concat(., ., axis=3)"]]
    Add_7[["Add(., .)"]]

    I_X -->|"FLOAT(a, b, c, d)"| Slice_2
    ConstantOfShape_0 -->|"FLOAT(2, 2, 1024, 256)"| Concat_3
    Slice_2 -->|"FLOAT(a, b, c, 256)"| Concat_3
    I_X -->|"FLOAT(a, b, c, d)"| Slice_4
    Slice_4 -->|"FLOAT(a, b, c, 256)"| Neg_5
    ConstantOfShape_0 -->|"FLOAT(2, 2, 1024, 256)"| Concat_6
    Neg_5 -->|"FLOAT(a, b, c, 256)"| Concat_6
    Concat_3 -->|"FLOAT(2, 2, 1024, 512)"| Add_7
    Concat_6 -->|"FLOAT(2, 2, 1024, 512)"| Add_7

    O_Y(["Y FLOAT(a, b, c, d)"])
    Add_7 --> O_Y

    class I_X,O_Y ioNode
    class ConstantOfShape_0,ConstantOfShape_1,Slice_2,Concat_3,Slice_4,Neg_5 opNode
    class Concat_6,Add_7 opNode
    

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(a, 16)"])
    I_split(["split INT64(2)"])
    I_x1(["x1 FLOAT(a, 8)"])
    I_shape(["shape INT64(2)"])
    I_nx1(["nx1 FLOAT(a, 8)"])
    I_x2(["x2 FLOAT(a, 8)"])

    Constant_0[["Constant() -#gt; shape"]]
    Constant_1[["Constant() -#gt; split"]]
    ConstantOfShape_2[["ConstantOfShape(.)"]]
    Split_3[["Split(., ., axis=1)"]]
    Neg_4[["Neg(.)"]]
    Concat_5[["Concat(., ., axis=1)"]]
    ConstantOfShape_6[["ConstantOfShape(.)"]]
    Concat_7[["Concat(., ., axis=1)"]]
    Add_8[["Add(., .)"]]

    Constant_0 -->|"INT64(2)"| ConstantOfShape_2
    I_X -->|"FLOAT(a, 16)"| Split_3
    Constant_1 -->|"INT64(2)"| Split_3
    Split_3 -->|"FLOAT(a, 8)"| Neg_4
    Neg_4 -->|"FLOAT(a, 8)"| Concat_5
    ConstantOfShape_6 -->|"FLOAT(3, 8)"| Concat_5
    Constant_0 -->|"INT64(2)"| ConstantOfShape_6
    ConstantOfShape_6 -->|"FLOAT(3, 8)"| Concat_7
    Split_3 -->|"FLOAT(a, 8)"| Concat_7
    Concat_5 -->|"FLOAT(a, 16)"| Add_8
    Concat_7 -->|"FLOAT(3, 16)"| Add_8

    O_Y(["Y FLOAT(a, 16)"])
    Add_8 --> O_Y
    O_zero(["zero FLOAT(3, 8)"])
    ConstantOfShape_6 --> O_zero
    O_x1(["x1 FLOAT(a, 8)"])
    Split_3 --> O_x1
    O_nx1(["nx1 FLOAT(a, 8)"])
    Neg_4 --> O_nx1
    O_x2(["x2 FLOAT(a, 8)"])
    Split_3 --> O_x2

    class I_X,I_split,I_x1,I_shape,I_nx1,I_x2,O_Y,O_zero,O_x1,O_nx1,O_x2 ioNode
    class Constant_0,Constant_1 constNode
    class ConstantOfShape_2,Split_3,Neg_4,Concat_5,ConstantOfShape_6,Concat_7,Add_8 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(a, 16)"])
    I_split(["split INT64(2)"])
    I_x1(["x1 FLOAT(a, 8)"])
    I_shape(["shape INT64(2)"])
    I_nx1(["nx1 FLOAT(a, 8)"])
    I_x2(["x2 FLOAT(a, 8)"])

    ConstantOfShape_0[["ConstantOfShape(.)"]]
    Split_1[["Split(., ., axis=1)"]]
    Neg_2[["Neg(.)"]]
    Concat_3[["Concat(., ., axis=1)"]]

    I_shape -->|"INT64(2)"| ConstantOfShape_0
    I_X -->|"FLOAT(a, 16)"| Split_1
    I_split -->|"INT64(2)"| Split_1
    Split_1 -->|"FLOAT(a, 8)"| Neg_2
    Neg_2 -->|"FLOAT(a, 8)"| Concat_3
    Split_1 -->|"FLOAT(a, 8)"| Concat_3

    O_Y(["Y FLOAT(a, 16)"])
    Concat_3 --> O_Y
    O_zero(["zero FLOAT(3, 8)"])
    ConstantOfShape_0 --> O_zero
    O_x1(["x1 FLOAT(a, 8)"])
    Split_1 --> O_x1
    O_nx1(["nx1 FLOAT(a, 8)"])
    Neg_2 --> O_nx1
    O_x2(["x2 FLOAT(a, 8)"])
    Split_1 --> O_x2

    class I_X,I_split,I_x1,I_shape,I_nx1,I_x2,O_Y,O_zero,O_x1,O_nx1,O_x2 ioNode
    class ConstantOfShape_0,Split_1,Neg_2,Concat_3 opNode
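
The rewriting relies on a simple identity: padding two halves with zeros on opposite sides and adding them gives the same result as concatenating the halves directly. The NumPy check below illustrates it under assumptions read from the diagram (Neg is applied to the first Split output `x1`, the second output `x2` is concatenated after the zeros, and the symbolic dimension `a` equals 3 so that it matches the constant shape `(3, 8)`).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16)).astype(np.float32)

# Split(X, axis=1) into two halves of width 8.
x1, x2 = np.split(x, [8], axis=1)
# ConstantOfShape producing zeros of shape (3, 8).
zero = np.zeros((3, 8), dtype=np.float32)

# Before the rewriting: two zero-padded Concat nodes followed by Add.
before = np.concatenate([-x1, zero], axis=1) + np.concatenate([zero, x2], axis=1)
# After the rewriting: a single Concat.
after = np.concatenate([-x1, x2], axis=1)

print(np.array_equal(before, after))
```

Adding zero leaves every element unchanged, so the two expressions agree exactly and the rewritten graph needs neither the Add node nor the second Concat.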
    
match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_rotary.RotaryEmbeddingPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Fuses nodes matching RotaryEmbedding(23).

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_m2(["m2 FLOAT(1, 1, c, e)"])
    I_X(["X FLOAT(a, 2, c, d)"])
    I_m1(["m1 FLOAT(1, 1, c, e)"])

    Concat_0[["Concat(., ., axis=-1)"]]
    Concat_1[["Concat(., ., axis=-1)"]]
    Split_2[["Split(., [4, 6], axis=-1)"]]
    HalfRotaryEmbedding_3[["intermediate.HalfRotaryEmbedding(., ., .)"]]
    Concat_4[["Concat(., ., axis=-1)"]]

    I_m2 -->|"FLOAT(1, 1, c, e)"| Concat_0
    I_m1 -->|"FLOAT(1, 1, c, e)"| Concat_1
    I_X -->|"FLOAT(a, 2, c, d)"| Split_2
    Split_2 --> HalfRotaryEmbedding_3
    Concat_0 --> HalfRotaryEmbedding_3
    Concat_1 --> HalfRotaryEmbedding_3
    HalfRotaryEmbedding_3 --> Concat_4
    Split_2 --> Concat_4

    O_Y(["Y FLOAT(a, b, c, d)"])
    Concat_4 --> O_Y

    class I_m2,I_X,I_m1,O_Y ioNode
    class Concat_0,Concat_1,Split_2,HalfRotaryEmbedding_3,Concat_4 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_m2(["m2 FLOAT(1, 1, c, e)"])
    I_X(["X FLOAT(a, 2, c, d)"])
    I_m1(["m1 FLOAT(1, 1, c, e)"])

    Shape_0[["Shape(., end=1, start=0)"]]
    Concat_1[["Concat(., [1, 1], axis=0)"]]
    Squeeze_2[["Squeeze(., [1])"]]
    Squeeze_3[["Squeeze(., [1])"]]
    Expand_4[["Expand(., .)"]]
    Expand_5[["Expand(., .)"]]
    RotaryEmbedding_6[["RotaryEmbedding(., ., .)"]]

    I_X -->|"FLOAT(a, 2, c, d)"| Shape_0
    Shape_0 -->|"INT64(1)"| Concat_1
    I_m2 -->|"FLOAT(1, 1, c, e)"| Squeeze_2
    I_m1 -->|"FLOAT(1, 1, c, e)"| Squeeze_3
    Squeeze_2 -->|"FLOAT(1, c, e)"| Expand_4
    Concat_1 -->|"INT64(3)"| Expand_4
    Squeeze_3 -->|"FLOAT(1, c, e)"| Expand_5
    Concat_1 -->|"INT64(3)"| Expand_5
    I_X -->|"FLOAT(a, 2, c, d)"| RotaryEmbedding_6
    Expand_4 -->|"FLOAT(a, c, e)"| RotaryEmbedding_6
    Expand_5 -->|"FLOAT(a, c, e)"| RotaryEmbedding_6

    O_Y(["Y FLOAT(a, b, c, d)"])
    RotaryEmbedding_6 --> O_Y

    class I_m2,I_X,I_m1,O_Y ioNode
    class Shape_0,Concat_1,Squeeze_2,Squeeze_3,Expand_4,Expand_5,RotaryEmbedding_6 opNode
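
The equivalence behind this fusion can be checked numerically. The sketch below is an illustration, not the library's code: `before_fusion` follows the first diagram (Split, HalfRotaryEmbedding on the first part with duplicated `m2` and `m1`, Concat), while `rotary_reference` is written after the non-interleaved rotation that the RotaryEmbedding operator is assumed to perform on the first `rot_dim` features. All function names are hypothetical, and duplicating `m2` and `m1` in the two leading Concat nodes is an assumption read from the diagram.

```python
import numpy as np


def half_rotary(x, cos, sin):
    """x * cos + rotate_half(x) * sin, as in the HalfRotaryEmbedding local function."""
    half = x.shape[-1] // 2
    rotated = np.concatenate([-x[..., half:], x[..., :half]], axis=-1)
    return x * cos + rotated * sin


def before_fusion(x, m2, m1, rot_dim):
    """Sketch of the graph before the fusion (assumed reading of the diagram)."""
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]  # Split(., [4, 6], axis=-1)
    cos = np.concatenate([m2, m2], axis=-1)             # Concat(m2, m2, axis=-1)
    sin = np.concatenate([m1, m1], axis=-1)             # Concat(m1, m1, axis=-1)
    return np.concatenate([half_rotary(x_rot, cos, sin), x_pass], axis=-1)


def rotary_reference(x, m2, m1, rot_dim):
    """Non-interleaved rotation of the first rot_dim features (assumed semantics)."""
    half = rot_dim // 2
    a, b = x[..., :half], x[..., half:rot_dim]
    real = a * m2 - b * m1
    imag = a * m1 + b * m2
    return np.concatenate([real, imag, x[..., rot_dim:]], axis=-1)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 3, 10)).astype(np.float32)
angles = rng.standard_normal((1, 1, 3, 2)).astype(np.float32)
m2, m1 = np.cos(angles), np.sin(angles)

y_before = before_fusion(x, m2, m1, rot_dim=4)
y_ref = rotary_reference(x, m2, m1, rot_dim=4)
print(np.allclose(y_before, y_ref))
```

The last six features pass through unchanged in both formulations, which corresponds to the second output of `Split(., [4, 6], axis=-1)` in the diagram.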
    
apply(g: GraphBuilder, concat_cos: NodeProto, concat_sin: NodeProto, split_node: NodeProto, half_node: NodeProto, concat_node: NodeProto) → List[NodeProto]

The method does the rewriting and assumes it can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one are part of the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs and checks any condition it needs.

  • matched – usually unused; it holds the results already matched by a pattern

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be modified by another pattern optimizer.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.