yobx.xoptim.patterns.onnx_expand
- class yobx.xoptim.patterns.onnx_expand.ExpandBroadcastPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Checks whether an Expand node is really needed before an element-wise operator. The objective is to save one allocation and let the next operator do the expansion by broadcasting one of its inputs.
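The equivalence this pattern relies on can be checked numerically. The sketch below assumes NumPy and models ONNX Expand with `np.broadcast_to` (Expand uses bidirectional broadcasting). `onnx_expand` and `expand_is_redundant` are hypothetical helpers written for illustration, not yobx functions, and the shapes are reduced from the 1024-wide example to keep the check small.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # ONNX Expand uses bidirectional broadcasting: the output shape is the
    # broadcast of the input shape and the requested shape.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


def expand_is_redundant(x_shape, expand_shape, other_shape) -> bool:
    # Hypothetical check: the Expand can be dropped when the element-wise
    # operator produces the same output shape from the unexpanded input.
    expanded = np.broadcast_shapes(tuple(x_shape), tuple(expand_shape))
    with_expand = np.broadcast_shapes(expanded, tuple(other_shape))
    without_expand = np.broadcast_shapes(tuple(x_shape), tuple(other_shape))
    return with_expand == without_expand


rng = np.random.default_rng(0)
mul_25 = rng.standard_normal((2, 4, 1)).astype(np.float32)
input66 = rng.standard_normal((2, 4, 4)).astype(np.float32)

before = onnx_expand(mul_25, [2, 4, 4]) * input66  # Expand + Mul
after = mul_25 * input66                           # Mul broadcasts by itself
assert np.array_equal(before, after)
assert expand_is_redundant(mul_25.shape, [2, 4, 4], input66.shape)
```

If the other operand had shape (2, 4, 1) instead, the Expand would be what determines the output shape, and it could not be removed.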
Model with nodes to be fused:
  inputs: mul_25 FLOAT(2, 1024, 1), input66 FLOAT(2, 1024, 1024)
  Expand_0: Expand(mul_25, [2, 1024, 1024]) -> FLOAT(2, 1024, 1024)
  Mul_1: Mul(Expand_0, input66) -> MulMulMulPattern--mul_27 FLOAT(2, 1024, 1024)
Outcome of the fusion:
  inputs: mul_25 FLOAT(2, 1024, 1), input66 FLOAT(2, 1024, 1024)
  Mul_0: Mul(mul_25, input66) -> MulMulMulPattern--mul_27 FLOAT(2, 1024, 1024)
- apply(g: GraphBuilder, node: NodeProto, next_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ExpandPattern(verbose: int = 0, priority: int = 0)
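Checks whether an Expand node is really needed. When the input already has the requested shape, the Expand node is replaced by an Identity node.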
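A minimal sketch of the condition, assuming NumPy's broadcasting rules match ONNX Expand semantics. `expand_is_identity` is a hypothetical helper; the actual pattern may check the shapes differently (for instance, only when they are static).

```python
import numpy as np


def expand_is_identity(input_shape, target_shape) -> bool:
    # Expand is a no-op when broadcasting the input with the requested shape
    # gives back the input shape unchanged (same rank, same dimensions).
    out = np.broadcast_shapes(tuple(input_shape), tuple(target_shape))
    return out == tuple(input_shape)


# FLOAT(32, 2, 10, 8) expanded to [32, 2, 10, 8] is unchanged.
assert expand_is_identity((32, 2, 10, 8), [32, 2, 10, 8])
# A target made of 1s is also a no-op under broadcasting.
assert expand_is_identity((32, 2, 10, 8), [1, 1, 1, 1])
# Expanding a size-1 axis, or adding a leading axis, is not.
assert not expand_is_identity((32, 2, 10, 1), [32, 2, 10, 8])
assert not expand_is_identity((2, 3), [4, 2, 3])
```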
Model with nodes to be fused:
  inputs: mul FLOAT(32, 2, 10, 8)
  Constant_0: Constant() -> init7_s4_32_2_10_8 INT64(4)
  Expand_1: Expand(mul, init7_s4_32_2_10_8) -> expand FLOAT(32, 2, 10, 8)
Outcome of the fusion:
  inputs: mul FLOAT(32, 2, 10, 8), init7_s4_32_2_10_8 INT64(4)
  Identity_0: Identity(mul, init7_s4_32_2_10_8) -> expand FLOAT(32, 2, 10, 8)
- apply(g: GraphBuilder, node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ExpandSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Tries to move an Expand node forward in the graph: Expand + Exp can be rewritten as Exp + Expand, so that Exp is applied to a tensor of smaller or equal size. The example below uses Pow with a constant exponent.
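An element-wise operator applied to a single expanded tensor commutes with Expand, which is what makes the swap valid. The NumPy sketch below checks this on the shapes of the example, using `np.power` with a float exponent of 2.0 in place of the integer constant p; `onnx_expand` is a hypothetical helper modeling Expand with broadcasting.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # Model ONNX Expand with NumPy bidirectional broadcasting.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


X = np.random.default_rng(0).standard_normal((1, 5, 7)).astype(np.float32)
shape = [3, 5, 7]

before = np.power(onnx_expand(X, shape), 2.0)  # Pow runs on 3 * 5 * 7 = 105 values
after = onnx_expand(np.power(X, 2.0), shape)   # Pow runs on 1 * 5 * 7 = 35 values
assert before.shape == after.shape == (3, 5, 7)
assert np.allclose(before, after)
```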
Model with nodes to be fused:
  inputs: X FLOAT(1, 5, 7)
  Constant_0: Constant() -> shape INT64(3)
  Constant_1: Constant() -> p INT64(1)
  Expand_2: Expand(X, shape) -> FLOAT(3, 5, 7)
  Pow_3: Pow(Expand_2, p) -> Z FLOAT(3, 5, 7)
Outcome of the fusion:
  inputs: X FLOAT(1, 5, 7), p INT64(1), shape INT64(3)
  Pow_0: Pow(X, p) -> FLOAT(1, 5, 7)
  Expand_1: Expand(Pow_0, shape) -> Z FLOAT(3, 5, 7)
- apply(g: GraphBuilder, node: NodeProto, next_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ExpandUnsqueezeExpandPattern(verbose: int = 0, priority: int = 0)
Fuses the sequence Expand + Unsqueeze + Expand into Unsqueeze + Expand. When the first Expand does not change the rank of the tensor (as in the example below), the Unsqueeze axes are valid for the original tensor as well, and the final Expand can handle both the broadcasting done by the first Expand and the new dimension added by Unsqueeze. The two target shapes are merged with Max after inserting 1s at the unsqueezed axes.
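The equivalence can be checked with NumPy once concrete values are chosen for the symbolic dimensions (a=2, b=3, c=4, d=5 are arbitrary). `onnx_expand` and `merged_target_shape` are hypothetical helpers; the latter only mirrors the Max node of the fused graph and assumes non-negative axes and compatible shapes.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # Model ONNX Expand with NumPy bidirectional broadcasting.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


def merged_target_shape(shape1, axes, shape2):
    # Insert a 1 in shape1 at each unsqueezed axis (axes assumed
    # non-negative), then take the element-wise maximum with shape2.
    shape1_with_new_1s = list(shape1)
    for axis in sorted(axes):
        shape1_with_new_1s.insert(axis, 1)
    return [int(d) for d in np.maximum(shape1_with_new_1s, list(shape2))]


X = np.arange(6, dtype=np.float32).reshape(1, 2, 3)  # FLOAT(1, a, b)
shape1, axes, shape2 = [4, 2, 3], [1], [4, 5, 2, 3]

# Expand + Unsqueeze + Expand
before = onnx_expand(np.expand_dims(onnx_expand(X, shape1), tuple(axes)), shape2)
# Unsqueeze + Expand with the merged target shape
after = onnx_expand(np.expand_dims(X, tuple(axes)), merged_target_shape(shape1, axes, shape2))
assert before.shape == after.shape == (4, 5, 2, 3)
assert np.array_equal(before, after)
```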
Model with nodes to be fused:
  inputs: X FLOAT(1, a, b)
  Constant_0: Constant() -> shape1 INT64(3)
  Constant_1: Constant() -> axes INT64(1)
  Constant_2: Constant() -> shape2 INT64(4)
  Expand_3: Expand(X, shape1) -> FLOAT(c, a, b)
  Unsqueeze_4: Unsqueeze(Expand_3, axes) -> FLOAT(c, 1, a, b)
  Expand_5: Expand(Unsqueeze_4, shape2) -> Y FLOAT(c, d, a, b)
Outcome of the fusion:
  inputs: X FLOAT(1, a, b), shape2 INT64(4)
  Constant_0: Constant() -> axes INT64(1)
  Constant_1: Constant() -> shape1_with_new_1s INT64(4)
  Max_2: Max(shape1_with_new_1s, shape2) -> INT64(4)
  Unsqueeze_3: Unsqueeze(X, axes) -> FLOAT(1, 1, a, b)
  Expand_4: Expand(Unsqueeze_3, Max_2) -> Y FLOAT(c, d, a, b)
- apply(g: GraphBuilder, expand_node: NodeProto, unsq_node: NodeProto, expand2_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedConcatExpandPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
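Rewrites Expand(X, Concat(…)) when possible: a dimension of the target shape taken from Shape(X) can be replaced by 1, since broadcasting preserves that dimension anyway.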
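The pattern works on symbolic shapes, since the dimension comes from a Shape node. The NumPy sketch below illustrates the same idea with a concrete value for a (3, arbitrary). `relax_target_shape` is a hypothetical helper that compares static dimensions instead of tracing Shape nodes, so it is only an approximation of what the pattern checks.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # Model ONNX Expand with NumPy bidirectional broadcasting.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


def relax_target_shape(x_shape, target):
    # A target dimension equal to the matching input dimension (right-aligned,
    # as in broadcasting) can be replaced by 1: broadcasting keeps it anyway.
    relaxed = list(target)
    offset = len(target) - len(x_shape)
    for i, dim in enumerate(x_shape):
        if offset + i >= 0 and relaxed[offset + i] == dim:
            relaxed[offset + i] = 1
    return relaxed


X = np.arange(3, dtype=np.float32).reshape(3, 1)  # FLOAT(a, 1) with a=3
target = [X.shape[0], 2]                          # Concat(Shape(X)[0:1], [2])
relaxed = relax_target_shape(X.shape, target)
assert relaxed == [1, 2]                          # Concat([1], [2])
assert np.array_equal(onnx_expand(X, target), onnx_expand(X, relaxed))
```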
Model with nodes to be fused:
  inputs: X FLOAT(a, 1)
  Constant_0: Constant() -> two INT64(1)
  Shape_1: Shape(X, start=0, end=1) -> INT64(1)
  Concat_2: Concat(Shape_1, two, axis=0) -> INT64(2)
  Expand_3: Expand(X, Concat_2) -> Y FLOAT(a, 2)
Outcome of the fusion:
  inputs: X FLOAT(a, 1), two INT64(1)
  Concat_0: Concat([1], two, axis=0) -> INT64(2)
  Expand_1: Expand(X, Concat_0) -> Y FLOAT(a, 2)
- apply(g: GraphBuilder, concat_node: NodeProto, expand_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastMatMulPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Similar to yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern, but works only with MatMul.
Model with nodes to be fused:
  inputs: X FLOAT(a, b, c), Y FLOAT(1, c, d)
  Shape_0: Shape(Y, start=0, end=1) -> INT64(1)
  Concat_1: Concat(Shape_0, [1, 1], axis=0) -> INT64(3)
  Expand_2: Expand(Y, Concat_1) -> FLOAT(1, c, d)
  MatMul_3: MatMul(X, Expand_2) -> Z FLOAT(a, b, d)
Outcome of the fusion:
  inputs: X FLOAT(a, b, c), Y FLOAT(1, c, d)
  MatMul_0: MatMul(X, Y) -> Z FLOAT(a, b, d)
- apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
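For ShapeBasedExpandBroadcastMatMulPattern, the key difference from element-wise operators is that MatMul broadcasts only its leading (batch) dimensions, never the last two. The sketch below assumes NumPy, whose `@` operator follows the same batch-broadcasting rule as ONNX MatMul, and states that condition with a hypothetical helper `matmul_expand_removable`; the actual pattern may check it differently. Integer arrays keep the comparison exact.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # Model ONNX Expand with NumPy bidirectional broadcasting.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


def matmul_expand_removable(x_shape, y_shape, expand_shape) -> bool:
    # Hypothetical check for MatMul(X, Expand(Y, shape)); both inputs are
    # assumed to have rank >= 2.
    y_expanded = np.broadcast_shapes(tuple(y_shape), tuple(expand_shape))
    # MatMul does not broadcast the last two (matrix) dimensions.
    if y_expanded[-2:] != tuple(y_shape[-2:]):
        return False
    # The batch dimensions must broadcast to the same result either way.
    batch_with = np.broadcast_shapes(tuple(x_shape[:-2]), y_expanded[:-2])
    batch_without = np.broadcast_shapes(tuple(x_shape[:-2]), tuple(y_shape[:-2]))
    return batch_with == batch_without


X = np.arange(24).reshape(4, 2, 3)  # (a, b, c) with a=4, b=2, c=3
Y = np.arange(15).reshape(1, 3, 5)  # (1, c, d) with d=5
assert matmul_expand_removable(X.shape, Y.shape, [4, 3, 5])
assert np.array_equal(X @ onnx_expand(Y, [4, 3, 5]), X @ Y)
# Expanding a matrix dimension changes the result, so it must be kept.
assert not matmul_expand_removable((3, 1), (3, 1), [3, 5])
```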
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Similar to yobx.xoptim.patterns.onnx_expand.ExpandBroadcastPattern, but it also allows dynamic shapes. It does not look into the second argument of Expand; it only infers from the shapes that the Expand is not needed when a binary operator follows it directly.
Model with nodes to be fused:
  inputs: X FLOAT(1, b, c), Y FLOAT(a, b, c)
  Expand_0: Expand(X, .) -> FLOAT(a, b, c)
  Add_1: Add(Expand_0, Y) -> Z FLOAT(a, b, c)
Outcome of the fusion:
  inputs: X FLOAT(1, b, c), Y FLOAT(a, b, c)
  Add_0: Add(X, Y) -> Z FLOAT(a, b, c)
- apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
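ShapeBasedExpandBroadcastPattern reasons on shapes that may contain symbolic dimensions. The pure-Python sketch below shows one way to express that reasoning; representing symbolic dimensions as strings, and the helpers `broadcast_dim`, `broadcast_shape` and `expand_removable`, are assumptions for illustration, not yobx APIs. The sketch is conservative: when two dimensions cannot be compared, it reports that the Expand must be kept.

```python
from typing import Optional, Sequence, Tuple, Union

# Assumed representation: an int is a static dimension, a str a symbolic one.
Dim = Union[int, str]


def broadcast_dim(a: Dim, b: Dim) -> Optional[Dim]:
    # Broadcasting rule for a single dimension; None means the result cannot
    # be decided from the shapes alone (e.g. two different symbols).
    if a == b:
        return a
    if a == 1:
        return b
    if b == 1:
        return a
    return None


def broadcast_shape(s1: Sequence[Dim], s2: Sequence[Dim]) -> Optional[Tuple[Dim, ...]]:
    # Right-align both shapes by padding the shorter one with leading 1s.
    n = max(len(s1), len(s2))
    p1 = (1,) * (n - len(s1)) + tuple(s1)
    p2 = (1,) * (n - len(s2)) + tuple(s2)
    out = []
    for a, b in zip(p1, p2):
        d = broadcast_dim(a, b)
        if d is None:
            return None
        out.append(d)
    return tuple(out)


def expand_removable(x_shape, expanded_shape, other_shape) -> bool:
    # The Expand is unnecessary if the binary operator yields the same
    # (possibly symbolic) output shape when it receives X directly.
    with_expand = broadcast_shape(expanded_shape, other_shape)
    without_expand = broadcast_shape(x_shape, other_shape)
    return with_expand is not None and with_expand == without_expand


# Shapes from the example: X FLOAT(1, b, c), Expand output and Y FLOAT(a, b, c).
print(expand_removable((1, "b", "c"), ("a", "b", "c"), ("a", "b", "c")))  # True
```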
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandCastWhereSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
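Rewrites Where(Cast(Expand(X)), Expand(X), cst) into Expand(Where(Cast(X), X, cst)), so that Cast and Where operate on the smaller, non-expanded tensor.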
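The swap is valid because Cast and Where act element by element, so they commute with Expand. The NumPy check below uses concrete values (b=2, c=3, arbitrary), models Cast to BOOL with `astype(bool)` and Where with `np.where`; `onnx_expand` is a hypothetical helper.

```python
import numpy as np


def onnx_expand(x: np.ndarray, shape) -> np.ndarray:
    # Model ONNX Expand with NumPy bidirectional broadcasting.
    return np.broadcast_to(x, np.broadcast_shapes(x.shape, tuple(shape)))


X = np.array([[0.0, 1.5, -2.0], [3.0, 0.0, 4.0]], dtype=np.float32)  # FLOAT(b, c)
exp_shape = [2, 2, 3]                                                 # (b, b, c)
cst = np.float32(1.0)

# Where(Cast(Expand(X)), Expand(X), cst)
expanded = onnx_expand(X, exp_shape)
before = np.where(expanded.astype(bool), expanded, cst)
# Expand(Where(Cast(X), X, cst))
after = onnx_expand(np.where(X.astype(bool), X, cst), exp_shape)
assert np.array_equal(before, after)
```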
Model with nodes to be fused:
  inputs: X FLOAT(b, c), exp INT64(3)
  Constant_0: Constant() -> cst FLOAT(1)
  Expand_1: Expand(X, exp)
  Cast_2: Cast(Expand_1, to=BOOL)
  Where_3: Where(Cast_2, Expand_1, cst) -> Y FLOAT(b, b, c)
Outcome of the fusion:
  inputs: X FLOAT(b, c), exp INT64(3), cst FLOAT(1)
  Cast_0: Cast(X, to=BOOL) -> BOOL(b, c)
  Where_1: Where(Cast_0, X, cst) -> FLOAT(b, c)
  Expand_2: Expand(Where_1, exp) -> Y FLOAT(b, b, c)
- apply(g: GraphBuilder, expand_node: NodeProto, cast_node: NodeProto, where_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes the rewriting is possible. It receives the nodes impacted by the rewriting, as returned by method match, and assumes no other pattern optimizer will modify them. Since the nodes are passed as a list of arguments, method match may include None values in that list. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be included in the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None
Determines the nodes around node that can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused; it contains the nodes already matching a pattern
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult, which must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be modified by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandSwapPattern(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Tries to move an Expand node forward in the graph, past a binary operator. The code is similar to yobx.xoptim.patterns.onnx_expand.ShapeBasedExpandBroadcastPattern.
Model with nodes to be fused:
  inputs: Xc FLOAT(d, 1), full_shape INT64(2)
  Constant_0: Constant() -> one FLOAT(1)
  Expand_1: Expand(Xc, full_shape)
  Add_2: Add(Expand_1, one) -> Y FLOAT(d, d)
Outcome of the fusion:
  inputs: Xc FLOAT(d, 1), one FLOAT(1), full_shape INT64(2)
  Add_0: Add(Xc, one) -> FLOAT(d, 1)
  Expand_1: Expand(Add_0, full_shape) -> Y FLOAT(d, d)
- apply(g: GraphBuilder, expand_left: NodeProto, expand_right: NodeProto, binary_node: NodeProto) -> List[NodeProto]
The method does the rewriting. It assumes it can happen. It takes a list of nodes impacted by the rewriting assumes no other pattern optimizer will be modify them. It receives the list of nodes returned by method apply. Since it is a list of argument, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it are added. If a received node must be kept, it must be added to the list of returned node.
- Parameters:
nodes – nodes returned by method match, there are then removed
- Returns:
nodes to add to graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) MatchResult | None[source]#
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; the list of matches already found (nodes already matching a pattern).
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult. The result must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be impacted by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
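The match/apply contract described above can be illustrated with a self-contained toy. `Node`, `MatchResult`, `IdentityIdentityPattern` and `optimize_once` are hypothetical stand-ins written for this sketch, not the yobx classes; the real optimizer works on ONNX graphs through GraphBuilderPatternOptimization and does considerably more.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for onnx.NodeProto."""
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class MatchResult:
    """Hypothetical stand-in: matched nodes, rewriting function, insertion point."""
    nodes: List[Optional[Node]]
    apply: Callable[..., List[Node]]
    insert_at: Optional[Node] = None


class IdentityIdentityPattern:
    """Toy pattern rewriting Identity(Identity(x)) into Identity(x)."""

    def match(self, graph: List[Node], node: Node) -> Optional[MatchResult]:
        # match only inspects the graph; it never modifies it.
        if node.op_type != "Identity":
            return None
        producers = [n for n in graph if node.inputs[0] in n.outputs]
        if len(producers) != 1 or producers[0].op_type != "Identity":
            return None
        first = producers[0]
        # The intermediate result must have no other consumer.
        consumers = [n for n in graph if first.outputs[0] in n.inputs]
        if len(consumers) != 1:
            return None
        return MatchResult([first, node], self.apply, insert_at=node)

    def apply(self, first: Node, second: Node) -> List[Node]:
        # Both received nodes are dropped; only the returned node is added.
        return [Node("Identity", list(first.inputs), list(second.outputs))]


def optimize_once(graph: List[Node], pattern) -> List[Node]:
    """Applies the first match found: removes matched nodes, inserts new ones."""
    for node in graph:
        res = pattern.match(graph, node)
        if res is None:
            continue
        removed = [n for n in res.nodes if n is not None]  # skip None entries
        new_nodes = res.apply(*res.nodes)
        out: List[Node] = []
        for n in graph:
            if n is res.insert_at:
                out.extend(new_nodes)
            if not any(n is r for r in removed):
                out.append(n)
        if res.insert_at is None:
            # Toy fallback; the real optimizer determines the position itself.
            out.extend(new_nodes)
        return out
    return graph


graph = [Node("Identity", ["x"], ["t"]), Node("Identity", ["t"], ["y"])]
optimized = optimize_once(graph, IdentityIdentityPattern())
print([(n.op_type, n.inputs, n.outputs) for n in optimized])
# [('Identity', ['x'], ['y'])]
```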
- class yobx.xoptim.patterns.onnx_expand.ShapeBasedStaticExpandPattern(verbose: int = 0, priority: int = 0)
Compares input and output shapes to tell whether the Expand can use a constant as its second input.
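A constant shape can be derived when every output dimension either equals the input dimension (1 in the constant keeps it, even if it is symbolic) or broadcasts a size-1 input dimension to a known size (that size goes in the constant). A minimal sketch of that rule follows; the helper `static_expand_shape` is hypothetical, not the pattern's code, and it assumes equal ranks and represents symbolic dimensions as strings.

```python
from typing import List, Optional, Sequence, Union

Dim = Union[int, str]  # int for a static dimension, str for a symbolic one


def static_expand_shape(input_shape: Sequence[Dim],
                        output_shape: Sequence[Dim]) -> Optional[List[int]]:
    """Returns a constant shape for Expand, or None if it cannot be derived."""
    if len(input_shape) != len(output_shape):
        return None  # this sketch only handles equal ranks
    shape: List[int] = []
    for i, o in zip(input_shape, output_shape):
        if i == o:
            shape.append(1)  # dimension kept as is, even if symbolic
        elif i == 1 and isinstance(o, int):
            shape.append(o)  # static broadcast of a size-1 dimension
        else:
            return None  # symbolic target or incompatible sizes
    return shape


print(static_expand_shape((2, 3, "d", 1), (2, 3, "d", 2)))  # [1, 1, 1, 2]
```

For the shapes used in this pattern's example graph, the sketch yields the same constant [1, 1, 1, 2] that replaces the Shape + Concat computation.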
Model with nodes to be fused:
X FLOAT(2, 3, d, 1) → Shape(., start=0, end=-1) → INT64(3) → Concat(., [2], axis=0) → INT64(4) shape; Expand(X, shape) → Y FLOAT(2, 3, d, 2)
Outcome of the fusion:
X FLOAT(2, 3, d, 1) → Expand(., [1, 1, 1, 2]) → Y FLOAT(2, 3, d, 2)
- apply(g: GraphBuilder, reshape: NodeProto) → List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, the list returned by match may include None values. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; the list of matches already found (nodes already matching a pattern).
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult. The result must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be impacted by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
- class yobx.xoptim.patterns.onnx_expand.SwapExpandReshapePattern(verbose: int = 0, priority: int = 0)
Checks whether an Expand followed by a Reshape can be swapped into a Reshape followed by an Expand.
Model with nodes to be fused:
weight FLOAT(1, 4, 1), from a Constant → Expand(., shape INT64(3)) → Reshape(., stat INT64(3), from a Constant) → Y FLOAT(a, 1, 4)
Outcome of the fusion:
weight FLOAT(1, 4, 1) → Reshape(., stat INT64(3)) → Expand(., shape INT64(3)) → Y FLOAT(a, 1, 4)
- apply(g: GraphBuilder, expand_node: NodeProto, reshape_node: NodeProto) → List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, the list returned by match may include None values. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; the list of matches already found (nodes already matching a pattern).
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult. The result must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be impacted by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
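For SwapExpandReshapePattern, one case where the swap is clearly valid is when the Expand only repeats the tensor along a leading axis that the Reshape leaves untouched: the Reshape can then act on the small tensor. The NumPy check below assumes exactly that case; the concrete shapes, a = 3, and the use of `np.broadcast_to` for Expand are illustrative choices, and the exact conditions the pattern checks are not documented in this section.

```python
import numpy as np

a = 3
weight = np.arange(4, dtype=np.float32).reshape(1, 4, 1)  # weight: FLOAT(1, 4, 1)

# Original order: Expand to (a, 4, 1), then Reshape to (a, 1, 4).
before = np.broadcast_to(weight, (a, 4, 1)).reshape(a, 1, 4)

# Swapped order: Reshape the small tensor to (1, 1, 4), then Expand to (a, 1, 4).
after = np.broadcast_to(weight.reshape(1, 1, 4), (a, 1, 4))

assert before.shape == after.shape == (a, 1, 4)
assert np.array_equal(before, after)
```

In the swapped order the Reshape only touches the 4 elements of `weight` rather than the a × 4 expanded elements.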
- class yobx.xoptim.patterns.onnx_expand.SwapExpandUnsqueezePattern(verbose: int = 0, priority: int = 0)
Swaps Expand and Unsqueeze when Unsqueeze directly follows Expand.
Expand(X, shape) → Unsqueeze(expanded, axes) is rewritten as Unsqueeze(X, axes) → Expand(unsqueezed, new_shape), where new_shape is obtained by inserting 1 at every position listed in axes into the original expand shape. Performing the Unsqueeze before the Expand means the Unsqueeze operates on the smaller (pre-expanded) tensor, which is more efficient.
Model with nodes to be fused:
X FLOAT(1, 5, 7) → Expand(., shape INT64(3), from a Constant) → FLOAT(3, 5, 7) → Unsqueeze(., axes INT64(1), from a Constant) → Y FLOAT(3, 1, 5, 7)
Outcome of the fusion:
X FLOAT(1, 5, 7) → Unsqueeze(., axes INT64(1), from a Constant) → FLOAT(1, 1, 5, 7) → Expand(., new_shape INT64(4), from a Constant) → Y FLOAT(3, 1, 5, 7)
- apply(g: GraphBuilder, expand_node: NodeProto, unsq_node: NodeProto) → List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, the list returned by match may include None values. The method returns the new nodes. The optimizer considers that every node given to this function is removed from the graph and every node it returns is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match; they are then removed
- Returns:
nodes to add to the graph.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, and the nodes before or after a given node.
node – the matching must determine whether some nodes around this one belong to the set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs to, checking any condition it needs.
matched – usually unused; the list of matches already found (nodes already matching a pattern).
The method must not modify the graph. It returns None if no match is found, or otherwise an instance of class MatchResult. The result must contain:
A list of nodes involved in the rewriting. Not all of them are necessarily removed, but all of them are needed for the rewriting and must not be impacted by other pattern optimizers.
A function doing the rewriting (usually method apply of the pattern class).
An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If it is not specified, the optimizer automatically determines the position of the new nodes.
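The new_shape rule of SwapExpandUnsqueezePattern (insert 1 at every position listed in axes) can be written as a small helper and checked with NumPy. This is a sketch under stated assumptions: the helper `unsqueezed_expand_shape` is hypothetical, the Expand shape is assumed to have the same rank as the expanded tensor, axes is assumed to contain no duplicates, and `np.broadcast_to` / `np.expand_dims` stand in for ONNX Expand / Unsqueeze.

```python
from typing import List, Sequence

import numpy as np


def unsqueezed_expand_shape(shape: Sequence[int], axes: Sequence[int]) -> List[int]:
    """Inserts 1 into the Expand shape at every position listed in axes."""
    rank = len(shape) + len(axes)  # rank after Unsqueeze
    # Unsqueeze axes refer to the output rank; normalize negative values.
    positions = {a + rank if a < 0 else a for a in axes}
    dims = iter(shape)
    return [1 if i in positions else next(dims) for i in range(rank)]


x = np.arange(35, dtype=np.float32).reshape(1, 5, 7)  # X: FLOAT(1, 5, 7)
shape, axes = [3, 5, 7], [1]
new_shape = unsqueezed_expand_shape(shape, axes)  # [3, 1, 5, 7]

# Original order: Expand then Unsqueeze.
before = np.expand_dims(np.broadcast_to(x, shape), tuple(axes))
# Rewritten order: Unsqueeze the small tensor, then Expand with new_shape.
after = np.broadcast_to(np.expand_dims(x, tuple(axes)), new_shape)
assert np.array_equal(before, after)
```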