yobx.xoptim.patterns.onnx_expand

class yobx.xoptim.patterns.onnx_expand.ExpandBroadcastPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Checks whether an Expand node placed before an element-wise operator is really needed. The objective is to save one allocation by letting the element-wise operator perform the expansion itself, through broadcasting of one of its inputs.
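
A minimal NumPy sketch of the property this pattern relies on, using reduced shapes ((2, 4, 1) and (2, 4, 4) instead of (2, 1024, 1) and (2, 1024, 1024)). The helpers `onnx_expand` and `expand_is_redundant` are hypothetical and written for illustration only; `onnx_expand` assumes ONNX Expand behaves like multiplying by `np.ones(shape)` (bidirectional broadcasting). This is not the yobx implementation.

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


def expand_is_redundant(x_shape, expand_shape, other_shape):
    # Mul(Expand(x, expand_shape), other) can drop the Expand when
    # Mul(x, other) already broadcasts to the same output shape.
    with_expand = np.broadcast_shapes(x_shape, expand_shape, other_shape)
    without_expand = np.broadcast_shapes(x_shape, other_shape)
    return with_expand == without_expand


rng = np.random.default_rng(0)
mul_25 = rng.standard_normal((2, 4, 1)).astype(np.float32)
input66 = rng.standard_normal((2, 4, 4)).astype(np.float32)

before = onnx_expand(mul_25, (2, 4, 4)) * input66  # Expand + Mul
after = mul_25 * input66                            # Mul alone, broadcasting mul_25

print(expand_is_redundant((2, 4, 1), (2, 4, 4), (2, 4, 4)))  # True
print(np.array_equal(before, after))                           # True
```

The second call shows the saving: no intermediate (2, 4, 4) tensor is materialized for `mul_25`.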

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_mul_25(["mul_25 FLOAT(2, 1024, 1)"])
    I_input66(["input66 FLOAT(2, 1024, 1024)"])

    Expand_0[["Expand(., [2, 1024, 1024])"]]
    Mul_1[["Mul(., .)"]]

    I_mul_25 -->|"FLOAT(2, 1024, 1)"| Expand_0
    Expand_0 -->|"FLOAT(2, 1024, 1024)"| Mul_1
    I_input66 -->|"FLOAT(2, 1024, 1024)"| Mul_1

    O_MulMulMulPattern__mul_27(["MulMulMulPattern--mul_27 FLOAT(2, 1024, 1024)"])
    Mul_1 --> O_MulMulMulPattern__mul_27

    class I_mul_25,I_input66,O_MulMulMulPattern__mul_27 ioNode
    class Expand_0,Mul_1 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_mul_25(["mul_25 FLOAT(2, 1024, 1)"])
    I_input66(["input66 FLOAT(2, 1024, 1024)"])

    Mul_0[["Mul(., .)"]]

    I_mul_25 -->|"FLOAT(2, 1024, 1)"| Mul_0
    I_input66 -->|"FLOAT(2, 1024, 1024)"| Mul_0

    O_MulMulMulPattern__mul_27(["MulMulMulPattern--mul_27 FLOAT(2, 1024, 1024)"])
    Mul_0 --> O_MulMulMulPattern__mul_27

    class I_mul_25,I_input66,O_MulMulMulPattern__mul_27 ioNode
    class Mul_0 opNode
    
apply(g: GraphBuilder, node: NodeProto, next_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
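
To make the match/apply contract concrete, here is a self-contained toy rewriter. `Node`, `ToyMatchResult`, `ToyExpandToIdentity` and `run_pattern` are hypothetical stand-ins for NodeProto, MatchResult, a pattern class and the optimizer loop; they are not the yobx API, only a sketch of the behavior described above: match inspects without modifying anything, and apply receives the matched nodes (which are removed) and returns the nodes to add.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple


@dataclass(eq=False)  # eq=False: nodes compare by identity, like graph objects
class Node:
    # Hypothetical stand-in for onnx.NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyMatchResult:
    # Hypothetical stand-in for MatchResult: matched nodes + rewriting function.
    nodes: List[Optional[Node]]
    apply: Callable[..., List[Node]]


class ToyExpandToIdentity:
    """Hypothetical pattern: Expand(x, s) -> Identity(x) when shapes show a no-op."""

    def __init__(self, shapes: Dict[str, Tuple[int, ...]]):
        self.shapes = shapes  # result name -> known static shape

    def match(self, node: Node) -> Optional[ToyMatchResult]:
        # Only inspects the graph; never modifies it.
        if node.op_type != "Expand":
            return None
        x_shape = self.shapes.get(node.inputs[0])
        if x_shape is None or x_shape != self.shapes.get(node.outputs[0]):
            return None
        return ToyMatchResult([node], self.apply)

    def apply(self, expand_node: Node) -> List[Node]:
        # The received node is removed; the returned nodes are added.
        return [Node("Identity", [expand_node.inputs[0]], list(expand_node.outputs))]


def run_pattern(nodes: List[Node], pattern: ToyExpandToIdentity) -> List[Node]:
    # Hypothetical optimizer loop: remove the matched nodes, insert apply's output.
    graph = list(nodes)
    for node in nodes:
        if node not in graph:
            continue  # already removed by an earlier rewriting
        res = pattern.match(node)
        if res is None:
            continue
        removed = [n for n in res.nodes if n is not None]
        pos = min(graph.index(n) for n in removed)
        new_nodes = res.apply(*res.nodes)
        graph = [n for n in graph if n not in removed]
        graph[pos:pos] = new_nodes
    return graph


graph = [Node("Expand", ["mul", "shape"], ["expand"]), Node("Relu", ["expand"], ["out"])]
shapes = {"mul": (32, 2, 10, 8), "expand": (32, 2, 10, 8)}
print([n.op_type for n in run_pattern(graph, ToyExpandToIdentity(shapes))])
# ['Identity', 'Relu']
```

Skipping None entries in `run_pattern` reflects the remark that the nodes returned by match may include None values.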

class yobx.xoptim.patterns.onnx_expand.ExpandPattern(verbose: int = 0, priority: int = 0)

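Checks whether an Expand node is really needed. When its input already has the shape the Expand would produce, as in the example below, the Expand is replaced by an Identity node.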
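
Under ONNX's bidirectional broadcasting, Expand(x, shape) is a no-op exactly when broadcasting `shape` against the shape of x leaves that shape unchanged. The sketch below checks this condition with NumPy; `expand_is_identity` and `onnx_expand` are hypothetical helpers, and whether ExpandPattern also detects the cases beyond the one in the diagram (a shorter shape, or 1s in the shape) is not stated here.

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


def expand_is_identity(x_shape, shape):
    # Expand returns x unchanged when broadcasting does not grow any dimension.
    return np.broadcast_shapes(tuple(x_shape), tuple(shape)) == tuple(x_shape)


mul = np.arange(32 * 2 * 10 * 8, dtype=np.float32).reshape(32, 2, 10, 8)
print(expand_is_identity(mul.shape, [32, 2, 10, 8]))     # True: the case in the diagram
print(expand_is_identity(mul.shape, [1, 10, 8]))         # True: 1s and fewer dims broadcast away
print(expand_is_identity(mul.shape, [2, 32, 2, 10, 8]))  # False: the rank grows
print(np.array_equal(onnx_expand(mul, [1, 10, 8]), mul))  # True
```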

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_init7_s4_32_2_10_8(["init7_s4_32_2_10_8 INT64(4)"])
    I_mul(["mul FLOAT(32, 2, 10, 8)"])

    Constant_0[["Constant() -#gt; init7_s4_32_2_10_8"]]
    Expand_1[["Expand(., .)"]]

    I_mul -->|"FLOAT(32, 2, 10, 8)"| Expand_1
    Constant_0 -->|"INT64(4)"| Expand_1

    O_expand(["expand FLOAT(32, 2, 10, 8)"])
    Expand_1 --> O_expand

    class I_init7_s4_32_2_10_8,I_mul,O_expand ioNode
    class Constant_0 constNode
    class Expand_1 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_init7_s4_32_2_10_8(["init7_s4_32_2_10_8 INT64(4)"])
    I_mul(["mul FLOAT(32, 2, 10, 8)"])

    Identity_0[["Identity(., .)"]]

    I_mul -->|"FLOAT(32, 2, 10, 8)"| Identity_0
    I_init7_s4_32_2_10_8 -->|"INT64(4)"| Identity_0

    O_expand(["expand FLOAT(32, 2, 10, 8)"])
    Identity_0 --> O_expand

    class I_init7_s4_32_2_10_8,I_mul,O_expand ioNode
    class Identity_0 opNode
    
apply(g: GraphBuilder, node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ExpandSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

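Tries to move an Expand node forward in the graph: Expand followed by an element-wise operator such as Exp can be rewritten as Exp followed by Expand, so that Exp applies to a tensor of smaller or equal size. The example below shows the same swap with Pow.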
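
Because Exp is element-wise, applying it before or after Expand gives the same values; only the number of computed elements differs. A NumPy sketch using the shapes of the diagram below, where `onnx_expand` is a hypothetical helper mimicking ONNX Expand:

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


X = np.linspace(0.5, 2.0, 35, dtype=np.float32).reshape(1, 5, 7)
shape = [3, 5, 7]

before = np.exp(onnx_expand(X, shape))  # Expand then Exp: 105 exponentials
after = onnx_expand(np.exp(X), shape)   # Exp then Expand: 35 exponentials

print(np.allclose(before, after))           # True
print(onnx_expand(X, shape).size, X.size)   # 105 35
```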

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_p(["p INT64(1)"])
    I_X(["X FLOAT(1, 5, 7)"])
    I_shape(["shape INT64(3)"])

    Constant_0[["Constant() -#gt; shape"]]
    Constant_1[["Constant() -#gt; p"]]
    Expand_2[["Expand(., .)"]]
    Pow_3[["Pow(., .)"]]

    I_X -->|"FLOAT(1, 5, 7)"| Expand_2
    Constant_0 -->|"INT64(3)"| Expand_2
    Expand_2 -->|"FLOAT(3, 5, 7)"| Pow_3
    Constant_1 -->|"INT64(1)"| Pow_3

    O_Z(["Z FLOAT(3, 5, 7)"])
    Pow_3 --> O_Z

    class I_p,I_X,I_shape,O_Z ioNode
    class Constant_0,Constant_1 constNode
    class Expand_2,Pow_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_p(["p INT64(1)"])
    I_X(["X FLOAT(1, 5, 7)"])
    I_shape(["shape INT64(3)"])

    Pow_0[["Pow(., .)"]]
    Expand_1[["Expand(., .)"]]

    I_X -->|"FLOAT(1, 5, 7)"| Pow_0
    I_p -->|"INT64(1)"| Pow_0
    Pow_0 -->|"FLOAT(1, 5, 7)"| Expand_1
    I_shape -->|"INT64(3)"| Expand_1

    O_Z(["Z FLOAT(3, 5, 7)"])
    Expand_1 --> O_Z

    class I_p,I_X,I_shape,O_Z ioNode
    class Pow_0,Expand_1 opNode
    
apply(g: GraphBuilder, node: NodeProto, next_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ExpandUnsqueezeExpandPattern(verbose: int = 0, priority: int = 0)

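Fuses the sequence Expand + Unsqueeze + Expand into Unsqueeze + Expand. When the first Expand keeps the rank of its input (in the example, X and shape1 both have rank 3), the Unsqueeze axes are also valid for the original tensor, and the final Expand can handle both the broadcasting done by the first Expand and the new dimension added by Unsqueeze. Its target shape is the element-wise Max of shape1, with 1s inserted at the unsqueezed axes, and shape2.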
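
A NumPy sketch of the equivalence with concrete sizes a=2, b=3, c=4, d=5. Here shape2 is assumed to contain a 1 in its first position, to show why the Max with shape1 is needed. `insert_ones` and `onnx_expand` are hypothetical helpers, not yobx functions.

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


def insert_ones(shape, axes):
    # Hypothetical helper: puts a 1 at each (non-negative) unsqueezed axis,
    # e.g. [c, a, b] with axes [1] -> [c, 1, a, b].
    rank = len(shape) + len(axes)
    dims = iter(shape)
    return [1 if i in axes else next(dims) for i in range(rank)]


a, b, c, d = 2, 3, 4, 5
X = np.arange(a * b, dtype=np.float32).reshape(1, a, b)
shape1, axes, shape2 = [c, a, b], [1], [1, d, a, b]

# Original: Expand -> Unsqueeze -> Expand
y_ref = onnx_expand(np.expand_dims(onnx_expand(X, shape1), tuple(axes)), shape2)

# Rewritten: Unsqueeze -> Expand(Max(shape1 with new 1s, shape2))
new_shape = np.maximum(insert_ones(shape1, axes), shape2)
y_new = onnx_expand(np.expand_dims(X, tuple(axes)), new_shape)

print(new_shape.tolist())            # [4, 5, 2, 3]
print(np.array_equal(y_ref, y_new))  # True
```

`np.maximum` reproduces broadcasting here because, for broadcast-compatible positive sizes, two dimensions are either equal or one of them is 1; a dimension of size 0 would break this.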

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, a, b)"])
    I_shape1(["shape1 INT64(3)"])
    I_axes(["axes INT64(1)"])
    I_shape2(["shape2 INT64(4)"])

    Constant_0[["Constant() -#gt; shape1"]]
    Constant_1[["Constant() -#gt; axes"]]
    Constant_2[["Constant() -#gt; shape2"]]
    Expand_3[["Expand(., .)"]]
    Unsqueeze_4[["Unsqueeze(., .)"]]
    Expand_5[["Expand(., .)"]]

    I_X -->|"FLOAT(1, a, b)"| Expand_3
    Constant_0 -->|"INT64(3)"| Expand_3
    Expand_3 -->|"FLOAT(c, a, b)"| Unsqueeze_4
    Constant_1 -->|"INT64(1)"| Unsqueeze_4
    Unsqueeze_4 -->|"FLOAT(c, 1, a, b)"| Expand_5
    Constant_2 -->|"INT64(4)"| Expand_5

    O_Y(["Y FLOAT(c, d, a, b)"])
    Expand_5 --> O_Y

    class I_X,I_shape1,I_axes,I_shape2,O_Y ioNode
    class Constant_0,Constant_1,Constant_2 constNode
    class Expand_3,Unsqueeze_4,Expand_5 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, a, b)"])
    I_axes(["axes INT64(1)"])
    I_shape1_new(["shape1_with_new_1s INT64(4)"])
    I_shape2(["shape2 INT64(4)"])

    Constant_0[["Constant() -#gt; axes"]]
    Constant_1[["Constant() -#gt; shape1_with_new_1s"]]
    Max_2[["Max(., .)"]]
    Unsqueeze_3[["Unsqueeze(., .)"]]
    Expand_4[["Expand(., .)"]]

    I_X -->|"FLOAT(1, a, b)"| Unsqueeze_3
    Constant_0 -->|"INT64(1)"| Unsqueeze_3
    Constant_1 -->|"INT64(4)"| Max_2
    I_shape2 -->|"INT64(4)"| Max_2
    Unsqueeze_3 -->|"FLOAT(1, 1, a, b)"| Expand_4
    Max_2 -->|"INT64(4)"| Expand_4

    O_Y(["Y FLOAT(c, d, a, b)"])
    Expand_4 --> O_Y

    class I_X,I_axes,I_shape2,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Max_2,Unsqueeze_3,Expand_4 opNode
    
apply(g: GraphBuilder, expand_node: NodeProto, unsq_node: NodeProto, expand2_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedConcatExpandPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

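Rewrites Expand(X, Concat(…)) when possible. In the example below, the first element of the target shape is Shape(X)[0:1], a dimension X already has; it is replaced by the constant 1, which broadcasting keeps unchanged, so the Shape node is no longer needed.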
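
A NumPy check of this rewriting on the example, with a=3. `onnx_expand` is a hypothetical helper mimicking ONNX Expand. Replacing a shape entry taken from X itself with 1 works because broadcasting keeps X's own size in that position.

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


a = 3
X = np.arange(a, dtype=np.float32).reshape(a, 1)
two = np.array([2], dtype=np.int64)

# Original: Expand(X, Concat(Shape(X, start=0, end=1), two))
shape_before = np.concatenate([np.array(X.shape[0:1], dtype=np.int64), two])
# Rewritten: Expand(X, Concat([1], two)); the leading 1 keeps X's first dimension
shape_after = np.concatenate([np.array([1], dtype=np.int64), two])

print(shape_before.tolist(), shape_after.tolist())  # [3, 2] [1, 2]
print(np.array_equal(onnx_expand(X, shape_before), onnx_expand(X, shape_after)))  # True
```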

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(a, 1)"])
    I_two(["two INT64(1)"])

    Constant_0[["Constant() -#gt; two"]]
    Shape_1[["Shape(., end=1, start=0)"]]
    Concat_2[["Concat(., ., axis=0)"]]
    Expand_3[["Expand(., .)"]]

    I_X -->|"FLOAT(a, 1)"| Shape_1
    Shape_1 -->|"INT64(1)"| Concat_2
    Constant_0 -->|"INT64(1)"| Concat_2
    I_X -->|"FLOAT(a, 1)"| Expand_3
    Concat_2 -->|"INT64(2)"| Expand_3

    O_Y(["Y FLOAT(a, 2)"])
    Expand_3 --> O_Y

    class I_X,I_two,O_Y ioNode
    class Constant_0 constNode
    class Shape_1,Concat_2,Expand_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(a, 1)"])
    I_two(["two INT64(1)"])

    Concat_0[["Concat([1], ., axis=0)"]]
    Expand_1[["Expand(., .)"]]

    I_two -->|"INT64(1)"| Concat_0
    I_X -->|"FLOAT(a, 1)"| Expand_1
    Concat_0 -->|"INT64(2)"| Expand_1

    O_Y(["Y FLOAT(a, 2)"])
    Expand_1 --> O_Y

    class I_X,I_two,O_Y ioNode
    class Concat_0,Expand_1 opNode
    
apply(g: GraphBuilder, concat_node: NodeProto, expand_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastMatMulPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Similar to yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern, but works only with MatMul.
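
ONNX MatMul behaves like `numpy.matmul`, which broadcasts the leading (batch) dimensions but not the last two matrix dimensions; this may be why MatMul gets its own variant. A sketch with arbitrary sizes a=2, b=3, c=4, d=5, showing that expanding the batch dimension of Y before MatMul changes nothing:

```python
import numpy as np

a, b, c, d = 2, 3, 4, 5
rng = np.random.default_rng(0)
X = rng.standard_normal((a, b, c)).astype(np.float32)
Y = rng.standard_normal((1, c, d)).astype(np.float32)

# Y with its batch dimension expanded from 1 to a, as an Expand would do.
with_expand = np.matmul(X, np.broadcast_to(Y, (a, c, d)))
# MatMul broadcasts the batch dimension of Y by itself.
without_expand = np.matmul(X, Y)

print(with_expand.shape, np.allclose(with_expand, without_expand))  # (2, 3, 5) True
```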

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_Y(["Y FLOAT(1, c, d)"])
    I_X(["X FLOAT(a, b, c)"])

    Shape_0[["Shape(., end=1, start=0)"]]
    Concat_1[["Concat(., [1, 1], axis=0)"]]
    Expand_2[["Expand(., .)"]]
    MatMul_3[["MatMul(., .)"]]

    I_Y -->|"FLOAT(1, c, d)"| Shape_0
    Shape_0 -->|"INT64(1)"| Concat_1
    I_Y -->|"FLOAT(1, c, d)"| Expand_2
    Concat_1 -->|"INT64(3)"| Expand_2
    I_X -->|"FLOAT(a, b, c)"| MatMul_3
    Expand_2 -->|"FLOAT(1, c, d)"| MatMul_3

    O_Z(["Z FLOAT(a, b, d)"])
    MatMul_3 --> O_Z

    class I_Y,I_X,O_Z ioNode
    class Shape_0,Concat_1,Expand_2,MatMul_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_Y(["Y FLOAT(1, c, d)"])
    I_X(["X FLOAT(a, b, c)"])

    MatMul_0[["MatMul(., .)"]]

    I_X -->|"FLOAT(a, b, c)"| MatMul_0
    I_Y -->|"FLOAT(1, c, d)"| MatMul_0

    O_Z(["Z FLOAT(a, b, d)"])
    MatMul_0 --> O_Z

    class I_Y,I_X,O_Z ioNode
    class MatMul_0 opNode
    
apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Similar to yobx.xoptim.patterns.onnx_expand.ExpandBroadcastPattern, but it also allows dynamic shapes. It does not look into the second input of Expand; it only infers that an Expand is not needed when a binary operator immediately follows it.
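
Without looking at Expand's second input, the decision can still be made from shapes, including symbolic ones. The function below is a hypothetical, conservative sketch of such reasoning, not the yobx implementation: dimensions are ints or names (like a, b, c in the diagram below), equal names are assumed to denote equal sizes, and the Expand is removable when every expanded dimension is already provided either by x or by the other input.

```python
from typing import Sequence, Union

Dim = Union[int, str]  # a static size or a symbolic dimension name


def expand_is_removable(x_shape: Sequence[Dim], expanded_shape: Sequence[Dim],
                        other_shape: Sequence[Dim]) -> bool:
    """Hypothetical check for BinaryOp(Expand(x), other) -> BinaryOp(x, other)."""
    rank = len(expanded_shape)
    # Right-align x and other on the last `rank` dimensions, padding with 1s,
    # as ONNX multidirectional broadcasting does.
    x = [1] * (rank - len(x_shape)) + list(x_shape)
    other = [1] * (rank - len(other_shape)) + list(other_shape)
    other = other[len(other) - rank:]
    # Each expanded dimension must already come from x or from the other input.
    return all(e == xd or e == od for e, xd, od in zip(expanded_shape, x, other))


print(expand_is_removable((1, "b", "c"), ("a", "b", "c"), ("a", "b", "c")))  # True
print(expand_is_removable((1, "b", "c"), ("a", "b", "c"), (1, "b", "c")))    # False
```

In the second call only the Expand provides dimension "a", so removing it would change the output shape.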

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, b, c)"])
    I_Y(["Y FLOAT(a, b, c)"])

    Expand_0[["Expand(., .)"]]
    Add_1[["Add(., .)"]]

    I_X -->|"FLOAT(1, b, c)"| Expand_0
    Expand_0 -->|"FLOAT(a, b, c)"| Add_1
    I_Y -->|"FLOAT(a, b, c)"| Add_1

    O_Z(["Z FLOAT(a, b, c)"])
    Add_1 --> O_Z

    class I_X,I_Y,O_Z ioNode
    class Expand_0,Add_1 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, b, c)"])
    I_Y(["Y FLOAT(a, b, c)"])

    Add_0[["Add(., .)"]]

    I_X -->|"FLOAT(1, b, c)"| Add_0
    I_Y -->|"FLOAT(a, b, c)"| Add_0

    O_Z(["Z FLOAT(a, b, c)"])
    Add_0 --> O_Z

    class I_X,I_Y,O_Z ioNode
    class Add_0 opNode
    
apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandCastWhereSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

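Rewrites Where(Cast(Expand(X, exp)), Expand(X, exp), cst) into Expand(Where(Cast(X), X, cst), exp), so that Cast and Where run on the tensor before it is expanded.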
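
Cast(to=BOOL) and Where are element-wise, so they commute with Expand. A NumPy sketch with b=2, c=3, where `X.astype(bool)` stands in for Cast(to=BOOL) and `onnx_expand` is a hypothetical helper mimicking ONNX Expand:

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


b, c = 2, 3
X = np.array([[0.0, 1.5, -2.0], [3.0, 0.0, 4.0]], dtype=np.float32)  # shape (b, c)
exp = [b, b, c]
cst = np.array([7.0], dtype=np.float32)

# Original: Where(Cast(Expand(X), to=BOOL), Expand(X), cst)
E = onnx_expand(X, exp)
before = np.where(E.astype(bool), E, cst)
# Rewritten: Expand(Where(Cast(X, to=BOOL), X, cst), exp)
after = onnx_expand(np.where(X.astype(bool), X, cst), exp)

print(after.shape, np.array_equal(before, after))  # (2, 2, 3) True
```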

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(b, c)"])
    I_exp(["exp INT64(3)"])
    I_cst(["cst FLOAT(1)"])

    Constant_0[["Constant() -#gt; cst"]]
    Expand_1[["Expand(., .)"]]
    Cast_2[["Cast(., to=BOOL)"]]
    Where_3[["Where(., ., .)"]]

    I_X -->|"FLOAT(b, c)"| Expand_1
    I_exp -->|"INT64(3)"| Expand_1
    Expand_1 --> Cast_2
    Cast_2 --> Where_3
    Expand_1 --> Where_3
    Constant_0 -->|"FLOAT(1)"| Where_3

    O_Y(["Y FLOAT(b, b, c)"])
    Where_3 --> O_Y

    class I_X,I_exp,I_cst,O_Y ioNode
    class Constant_0 constNode
    class Expand_1,Cast_2,Where_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(b, c)"])
    I_exp(["exp INT64(3)"])
    I_cst(["cst FLOAT(1)"])

    Cast_0[["Cast(., to=BOOL)"]]
    Where_1[["Where(., ., .)"]]
    Expand_2[["Expand(., .)"]]

    I_X -->|"FLOAT(b, c)"| Cast_0
    Cast_0 -->|"BOOL(b, c)"| Where_1
    I_X -->|"FLOAT(b, c)"| Where_1
    I_cst -->|"FLOAT(1)"| Where_1
    Where_1 -->|"FLOAT(b, c)"| Expand_2
    I_exp -->|"INT64(3)"| Expand_2

    O_Y(["Y FLOAT(b, b, c)"])
    Expand_2 --> O_Y

    class I_X,I_exp,I_cst,O_Y ioNode
    class Cast_0,Where_1,Expand_2 opNode
    
apply(g: GraphBuilder, expand_node: NodeProto, cast_node: NodeProto, where_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting is possible. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the nodes returned by method match; since they are passed as a list of arguments, some of them may be None. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and that any node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.

  • node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; it holds the matches already found.

The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult otherwise. The result must contain:

  • A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Tries to move an Expand node forward in the graph, past a binary operator. The code is similar to yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern.
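
For an element-wise binary operator whose other input (`one` below) does not need the expanded dimensions, Add(Expand(Xc), one) and Expand(Add(Xc, one)) produce the same values, and the second form runs Add on fewer elements. A NumPy sketch with d=4, where `onnx_expand` is a hypothetical helper mimicking ONNX Expand:

```python
import numpy as np


def onnx_expand(x, shape):
    # Mimics ONNX Expand: bidirectional broadcasting of x against `shape`.
    return x * np.ones(tuple(shape), dtype=x.dtype)


d = 4
Xc = np.arange(d, dtype=np.float32).reshape(d, 1)
full_shape = [d, d]
one = np.array([1.0], dtype=np.float32)

before = onnx_expand(Xc, full_shape) + one  # Add runs on d * d = 16 elements
after = onnx_expand(Xc + one, full_shape)   # Add runs on d = 4 elements

print(np.array_equal(before, after))  # True
```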

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_full_shape(["full_shape INT64(2)"])
    I_Xc(["Xc FLOAT(d, 1)"])
    I_one(["one FLOAT(1)"])

    Constant_0[["Constant() -#gt; one"]]
    Expand_1[["Expand(., .)"]]
    Add_2[["Add(., .)"]]

    I_Xc -->|"FLOAT(d, 1)"| Expand_1
    I_full_shape -->|"INT64(2)"| Expand_1
    Expand_1 --> Add_2
    Constant_0 -->|"FLOAT(1)"| Add_2

    O_Y(["Y FLOAT(d, d)"])
    Add_2 --> O_Y

    class I_full_shape,I_Xc,I_one,O_Y ioNode
    class Constant_0 constNode
    class Expand_1,Add_2 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_full_shape(["full_shape INT64(2)"])
    I_Xc(["Xc FLOAT(d, 1)"])
    I_one(["one FLOAT(1)"])

    Add_0[["Add(., .)"]]
    Expand_1[["Expand(., .)"]]

    I_Xc -->|"FLOAT(d, 1)"| Add_0
    I_one -->|"FLOAT(1)"| Add_0
    Add_0 -->|"FLOAT(d, 1)"| Expand_1
    I_full_shape -->|"INT64(2)"| Expand_1

    O_Y(["Y FLOAT(d, d)"])
    Expand_1 --> O_Y

    class I_full_shape,I_Xc,I_one,O_Y ioNode
    class Add_0,Expand_1 opNode
    
apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It receives the nodes returned by method match, i.e. the nodes impacted by the rewriting, and assumes no other pattern optimizer will modify them. Since these nodes are passed as a list of arguments, match may include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to the graph.
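The contract described above (nodes passed to apply are removed, nodes it returns are added, None entries are allowed) can be illustrated with a hypothetical, self-contained sketch. `Node`, `replace_nodes` and `swap_apply` are invented for illustration; they stand in for onnx.NodeProto, the optimizer loop and this pattern's apply, and are not the yobx API. Placing the new nodes where the first removed node stood is an assumption of the sketch:

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for onnx.NodeProto, reduced to what the sketch needs."""
    op_type: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


def replace_nodes(
    graph: List[Node],
    matched: Sequence[Optional[Node]],
    apply: Callable[..., List[Node]],
) -> List[Node]:
    """Remove every node passed to apply and add every node it returns."""
    new_nodes = apply(*matched)
    removed = {id(n) for n in matched if n is not None}  # None entries are skipped
    # Assumption: new nodes take the place of the first removed node.
    position = next(i for i, n in enumerate(graph) if id(n) in removed)
    kept = [n for n in graph if id(n) not in removed]
    return kept[:position] + new_nodes + kept[position:]


def swap_apply(
    expand_left: Optional[Node], expand_right: Optional[Node], binary_node: Node
) -> List[Node]:
    """Hypothetical apply moving Expand after the binary node (left input expanded)."""
    assert expand_left is not None and expand_right is None
    x, full_shape = expand_left.inputs
    small = binary_node.outputs[0] + "_small"
    return [
        Node(binary_node.op_type, (x, binary_node.inputs[1]), (small,)),
        Node("Expand", (small, full_shape), binary_node.outputs),
    ]


expand = Node("Expand", ("Xc", "full_shape"), ("expanded",))
add = Node("Add", ("expanded", "one"), ("Y",))
new_graph = replace_nodes([expand, add], [expand, None, add], swap_apply)
print([(n.op_type, n.inputs, n.outputs) for n in new_graph])
```

The None passed for expand_right mirrors the case where only one input of the binary operator comes from an Expand.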

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types and shapes, and about the nodes before or after a given node.

  • node – the node the matching starts from; match must determine whether some nodes around it form a set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; the list of nodes already matching a pattern.

The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:

  • a list of nodes involved in the rewriting. This does not mean all of them will be removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.

  • a function doing the rewriting (usually method apply of the pattern class).

  • an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.ShapeBasedStaticExpandPattern(verbose: int = 0, priority: int = 0)

Compares the input and output shapes to determine whether the Expand node can use a constant as its second input (the target shape) instead of a shape computed at runtime.
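How a constant target shape can be derived from the two shapes can be sketched as follows. This is a minimal sketch, assuming symbolic dimensions are represented as strings and that the Expand does not change the rank; `static_expand_shape` is a hypothetical helper, not the yobx implementation. It writes 1 wherever a dimension is unchanged (Expand broadcasting keeps it) and the static output dimension wherever a size-1 input dimension grows:

```python
from typing import List, Optional, Sequence, Union

Dim = Union[int, str]  # an int for a static dimension, a str for a symbolic one such as "d"


def static_expand_shape(
    in_shape: Sequence[Dim], out_shape: Sequence[Dim]
) -> Optional[List[int]]:
    """Constant shape for Expand mapping in_shape to out_shape, or None if there is none."""
    if len(in_shape) != len(out_shape):
        return None  # assumption of this sketch: Expand keeps the rank
    result = []
    for i, o in zip(in_shape, out_shape):
        if i == o:
            result.append(1)  # unchanged dimension, static or symbolic
        elif i == 1 and isinstance(o, int):
            result.append(o)  # size-1 dimension expanded to a known size
        else:
            return None  # symbolic target or incompatible dimensions
    return result


# Shapes from the diagram: X is (2, 3, d, 1) and Y is (2, 3, d, 2).
print(static_expand_shape((2, 3, "d", 1), (2, 3, "d", 2)))  # [1, 1, 1, 2]
```

On the example below this gives the constant [1, 1, 1, 2] that replaces the Shape + Concat subgraph.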

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(2, 3, d, 1)"])

    Shape_0[["Shape(., end=-1, start=0)"]]
    Concat_1[["Concat(., [2], axis=0)"]]
    Expand_2[["Expand(., .)"]]

    I_X -->|"FLOAT(2, 3, d, 1)"| Shape_0
    Shape_0 -->|"INT64(3)"| Concat_1
    I_X -->|"FLOAT(2, 3, d, 1)"| Expand_2
    Concat_1 -->|"INT64(4)"| Expand_2

    O_Y(["Y FLOAT(2, 3, d, 2)"])
    Expand_2 --> O_Y

    class I_X,O_Y ioNode
    class Shape_0,Concat_1,Expand_2 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(2, 3, d, 1)"])

    Expand_0[["Expand(., [1, 1, 1, 2])"]]

    I_X -->|"FLOAT(2, 3, d, 1)"| Expand_0

    O_Y(["Y FLOAT(2, 3, d, 2)"])
    Expand_0 --> O_Y

    class I_X,O_Y ioNode
    class Expand_0 opNode
    
apply(g: GraphBuilder, reshape: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It receives the nodes returned by method match, i.e. the nodes impacted by the rewriting, and assumes no other pattern optimizer will modify them. Since these nodes are passed as a list of arguments, match may include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to the graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types and shapes, and about the nodes before or after a given node.

  • node – the node the matching starts from; match must determine whether some nodes around it form a set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; the list of nodes already matching a pattern.

The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:

  • a list of nodes involved in the rewriting. This does not mean all of them will be removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.

  • a function doing the rewriting (usually method apply of the pattern class).

  • an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.SwapExpandReshapePattern(verbose: int = 0, priority: int = 0)

Checks whether an Expand followed by a Reshape can be swapped, so that the Reshape is applied to the input before it is expanded.
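One sufficient condition for the swap can be sketched at the shape level. This is a hypothetical sketch, not necessarily the rule yobx implements: it assumes every axis enlarged by the Expand lies in the leading run of zeros of the Reshape target (with ONNX Reshape's default allowzero=0, a 0 copies the input dimension), so the Reshape copies those axes untouched and only regroups the trailing, non-expanded ones. The diagram does not give the values of a or of the reshape target, so a = 3 and a target of [0, 1, 4] are assumptions:

```python
from math import prod
from typing import List, Optional, Sequence


def reshape_shape(shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Output shape of ONNX Reshape with allowzero=0: 0 copies the input dim, -1 is inferred."""
    out = [shape[i] if t == 0 else t for i, t in enumerate(target)]
    if -1 in out:
        known = prod(v for v in out if v != -1)
        out[out.index(-1)] = prod(shape) // known
    return out


def swapped_expand_shape(
    x_shape: Sequence[int], expand_shape: Sequence[int], target: Sequence[int]
) -> Optional[List[int]]:
    """new_shape such that Expand(Reshape(x, target), new_shape) matches the original."""
    if len(x_shape) != len(expand_shape):
        return None
    enlarged = [i for i, (a, b) in enumerate(zip(x_shape, expand_shape)) if a == 1 and b != 1]
    lead = 0  # length of the leading run of zeros in target
    while lead < len(target) and target[lead] == 0:
        lead += 1
    if any(i >= lead for i in enlarged):
        return None  # an expanded axis would be regrouped by the Reshape
    return list(expand_shape[:lead]) + [1] * (len(target) - lead)


x_shape, expand_shape, target = (1, 4, 1), (3, 4, 1), (0, 1, 4)
new_shape = swapped_expand_shape(x_shape, expand_shape, target)
expanded = [b if a == 1 else a for a, b in zip(x_shape, expand_shape)]
original = reshape_shape(expanded, target)  # Reshape(Expand(x))
swapped = [n if r == 1 else r for r, n in zip(reshape_shape(x_shape, target), new_shape)]
print(new_shape, original, swapped)  # [3, 1, 1] [3, 1, 4] [3, 1, 4]
```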

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_weight(["weight FLOAT(1, 4, 1)"])
    I_stat(["stat INT64(3)"])
    I_shape(["shape INT64(3)"])

    Constant_0[["Constant() -#gt; weight"]]
    Constant_1[["Constant() -#gt; stat"]]
    Expand_2[["Expand(., .)"]]
    Reshape_3[["Reshape(., .)"]]

    Constant_0 -->|"FLOAT(1, 4, 1)"| Expand_2
    I_shape -->|"INT64(3)"| Expand_2
    Expand_2 --> Reshape_3
    Constant_1 -->|"INT64(3)"| Reshape_3

    O_Y(["Y FLOAT(a, 1, 4)"])
    Reshape_3 --> O_Y

    class I_weight,I_stat,I_shape,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Expand_2,Reshape_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_weight(["weight FLOAT(1, 4, 1)"])
    I_stat(["stat INT64(3)"])
    I_shape(["shape INT64(3)"])

    Reshape_0[["Reshape(., .)"]]
    Expand_1[["Expand(., .)"]]

    I_weight -->|"FLOAT(1, 4, 1)"| Reshape_0
    I_stat -->|"INT64(3)"| Reshape_0
    Reshape_0 --> Expand_1
    I_shape -->|"INT64(3)"| Expand_1

    O_Y(["Y FLOAT(a, 1, 4)"])
    Expand_1 --> O_Y

    class I_weight,I_stat,I_shape,O_Y ioNode
    class Reshape_0,Expand_1 opNode
    
apply(g: GraphBuilder, expand_node: NodeProto, reshape_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It receives the nodes returned by method match, i.e. the nodes impacted by the rewriting, and assumes no other pattern optimizer will modify them. Since these nodes are passed as a list of arguments, match may include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to the graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types and shapes, and about the nodes before or after a given node.

  • node – the node the matching starts from; match must determine whether some nodes around it form a set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; the list of nodes already matching a pattern.

The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:

  • a list of nodes involved in the rewriting. This does not mean all of them will be removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.

  • a function doing the rewriting (usually method apply of the pattern class).

  • an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.

class yobx.xoptim.patterns.onnx_expand.SwapExpandUnsqueezePattern(verbose: int = 0, priority: int = 0)

Swaps Expand and Unsqueeze when Unsqueeze directly follows Expand. The sequence Expand(X, shape) followed by Unsqueeze(expanded, axes) is rewritten as Unsqueeze(X, axes) followed by Expand(unsqueezed, new_shape), where new_shape is obtained by inserting a 1 into the original expand shape at every position listed in axes. Performing the Unsqueeze before the Expand means the Unsqueeze operates on the smaller (pre-expanded) tensor, which is more efficient.
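The computation of new_shape, and the fact that both orders produce the same output shape, can be sketched at the shape level. This is a minimal sketch, assuming the expand shape has the same rank as X and that axes follow ONNX Unsqueeze semantics (positions in the output, negative values counted from the end); `unsqueeze_shape` and `expand_shape_of` are hypothetical helpers:

```python
from typing import List, Sequence


def unsqueeze_shape(shape: Sequence[int], axes: Sequence[int]) -> List[int]:
    """Insert a 1 at every position listed in axes (positions in the output)."""
    rank = len(shape) + len(axes)
    result = list(shape)
    # Normalize negative axes against the output rank, then insert in increasing
    # order so that each 1 lands at its final position.
    for axis in sorted(a + rank if a < 0 else a for a in axes):
        result.insert(axis, 1)
    return result


def expand_shape_of(shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Output shape of Expand(x, target) when x has the same rank as target."""
    return [t if s == 1 else s for s, t in zip(shape, target)]


# Shapes from the diagram: X is (1, 5, 7), shape is [3, 5, 7], axes is [1].
x_shape, shape, axes = [1, 5, 7], [3, 5, 7], [1]
new_shape = unsqueeze_shape(shape, axes)
original = unsqueeze_shape(expand_shape_of(x_shape, shape), axes)
swapped = expand_shape_of(unsqueeze_shape(x_shape, axes), new_shape)
print(new_shape, original == swapped)  # [3, 1, 5, 7] True
```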

Model with nodes to be fused:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, 5, 7)"])
    I_shape(["shape INT64(3)"])
    I_axes(["axes INT64(1)"])

    Constant_0[["Constant() -#gt; shape"]]
    Constant_1[["Constant() -#gt; axes"]]
    Expand_2[["Expand(., .)"]]
    Unsqueeze_3[["Unsqueeze(., .)"]]

    I_X -->|"FLOAT(1, 5, 7)"| Expand_2
    Constant_0 -->|"INT64(3)"| Expand_2
    Expand_2 -->|"FLOAT(3, 5, 7)"| Unsqueeze_3
    Constant_1 -->|"INT64(1)"| Unsqueeze_3

    O_Y(["Y FLOAT(3, 1, 5, 7)"])
    Unsqueeze_3 --> O_Y

    class I_X,I_shape,I_axes,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Expand_2,Unsqueeze_3 opNode
    

Outcome of the fusion:

        graph TD

    classDef ioNode fill:#dfd,stroke:#333,color:#333
    classDef initNode fill:#cccc00,stroke:#333,color:#333
    classDef constNode fill:#f9f,stroke:#333,stroke-width:2px,color:#333
    classDef opNode fill:#bbf,stroke:#333,stroke-width:2px,color:#333

    I_X(["X FLOAT(1, 5, 7)"])
    I_axes(["axes INT64(1)"])
    I_new_shape(["new_shape INT64(4)"])

    Constant_0[["Constant() -#gt; axes"]]
    Constant_1[["Constant() -#gt; new_shape"]]
    Unsqueeze_2[["Unsqueeze(., .)"]]
    Expand_3[["Expand(., .)"]]

    I_X -->|"FLOAT(1, 5, 7)"| Unsqueeze_2
    Constant_0 -->|"INT64(1)"| Unsqueeze_2
    Unsqueeze_2 -->|"FLOAT(1, 1, 5, 7)"| Expand_3
    Constant_1 -->|"INT64(4)"| Expand_3

    O_Y(["Y FLOAT(3, 1, 5, 7)"])
    Expand_3 --> O_Y

    class I_X,I_axes,I_new_shape,O_Y ioNode
    class Constant_0,Constant_1 constNode
    class Unsqueeze_2,Expand_3 opNode
    
apply(g: GraphBuilder, expand_node: NodeProto, unsq_node: NodeProto) → List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It receives the nodes returned by method match, i.e. the nodes impacted by the rewriting, and assumes no other pattern optimizer will modify them. Since these nodes are passed as a list of arguments, match may include None values among them. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match; they are then removed

Returns:

nodes to add to the graph.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types and shapes, and about the nodes before or after a given node.

  • node – the node the matching starts from; match must determine whether some nodes around it form a set of nodes this pattern optimizer can rewrite. From there, the function explores the graph wherever it needs to, checking any condition it needs.

  • matched – usually unused; the list of nodes already matching a pattern.

The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:

  • a list of nodes involved in the rewriting. This does not mean all of them will be removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.

  • a function doing the rewriting (usually method apply of the pattern class).

  • an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
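A match function following the contract above can be sketched for the Expand + Unsqueeze case. This is a hypothetical, self-contained illustration: `Node`, `SketchMatch`, `consumers`, `match_expand_unsqueeze` and `no_op_apply` are invented names standing in for onnx.NodeProto, MatchResult and the graph queries of GraphBuilderPatternOptimization, and the real yobx signatures differ. It only reads the graph, returns None when the pattern does not apply, and otherwise returns the involved nodes, the rewriting function and an insertion point (using the Unsqueeze as that point is an assumption):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for onnx.NodeProto."""
    op_type: str
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


@dataclass
class SketchMatch:
    """Hypothetical stand-in for MatchResult holding the three items listed above."""
    nodes: List[Optional[Node]]
    apply: Callable[..., List[Node]]
    insert_at: Optional[Node] = None


def consumers(graph: List[Node], name: str) -> List[Node]:
    """Read-only query: the nodes taking `name` as an input."""
    return [n for n in graph if name in n.inputs]


def match_expand_unsqueeze(
    graph: List[Node], node: Node, apply: Callable[..., List[Node]]
) -> Optional[SketchMatch]:
    """Return None, or the Expand + Unsqueeze pair to rewrite; never edits graph."""
    if node.op_type != "Expand":
        return None
    users = consumers(graph, node.outputs[0])
    # The expanded tensor must feed exactly one node, and it must be an Unsqueeze.
    if len(users) != 1 or users[0].op_type != "Unsqueeze":
        return None
    return SketchMatch(nodes=[node, users[0]], apply=apply, insert_at=users[0])


def no_op_apply(expand_node: Node, unsq_node: Node) -> List[Node]:
    return [expand_node, unsq_node]  # placeholder that keeps both nodes


graph = [
    Node("Expand", ("X", "shape"), ("expanded",)),
    Node("Unsqueeze", ("expanded", "axes"), ("Y",)),
]
result = match_expand_unsqueeze(graph, graph[0], no_op_apply)
assert result is not None
print([n.op_type for n in result.nodes])  # ['Expand', 'Unsqueeze']
print(match_expand_unsqueeze(graph, graph[1], no_op_apply))  # None
```

A fuller matcher would likely also need to check, for instance, that the expanded tensor is not a graph output, since removing the Expand would then drop a visible result.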