Pattern Optimizer

The pattern optimizer is implemented by class GraphBuilderPatternOptimization. It searches for a specific sequence of nodes in the graph and replaces it with another one without changing the inputs or the outputs of the graph. The goal of the optimizer is to make the whole computation graph more efficient. The goal of this implementation is to make this optimization as fast as possible. Assuming the nodes in an ONNX graph are ordered so that every input of a node is produced by a previous node, the optimizer must not require any global reordering. The cost should be O(N P I) in the worst case, where N is the number of nodes, P is the number of patterns and I is the number of iterations.
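The ordering assumption can be checked with a few lines. The sketch below is not part of the library: nodes are represented by hypothetical `(op_type, inputs, outputs)` tuples instead of NodeProto, and `is_topologically_ordered` is a name chosen for this illustration.

```python
from typing import Iterable, List, Sequence, Tuple

# Hypothetical stand-in for NodeProto: (op_type, inputs, outputs).
ToyNode = Tuple[str, Sequence[str], Sequence[str]]


def is_topologically_ordered(nodes: List[ToyNode], known: Iterable[str]) -> bool:
    """Tells whether every input of a node is produced by a previous node."""
    available = set(known)  # graph inputs and initializers
    for _, inputs, outputs in nodes:
        # An empty name stands for an omitted optional input in ONNX.
        if any(name and name not in available for name in inputs):
            return False
        available.update(outputs)
    return True


nodes = [
    ("Transpose", ["p_layers_0_weight"], ["t"]),
    ("Gemm", ["x", "t", "p_layers_0_bias"], ["addmm"]),
    ("Relu", ["addmm"], ["relu"]),
]
known = ["x", "p_layers_0_weight", "p_layers_0_bias"]
ordered = is_topologically_ordered(nodes, known)
shuffled = is_topologically_ordered(nodes[::-1], known)
```

A rewriting that inserts its new nodes at the position of the last node it removes keeps this property, which is why no global reordering is needed.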

It is difficult to foresee what a pattern needs in order to rewrite a part of the graph. This API tries to give as much freedom as it can without leaving too much work to the developer who adds a new pattern.

Patterns

Patterns must inherit from PatternOptimization. This class defines two methods.

PatternOptimization.match

def match(
    self,
    g: "GraphBuilderPatternOptimization",
    node: NodeProto,
    matched: List[MatchResult],
) -> Optional[MatchResult]:
  • g is a GraphBuilderPatternOptimization. It holds all the existing nodes and can return any information about types, shapes, the node before or the node after another one.

  • node: the matching must determine whether some nodes around this one are part of the set of nodes this pattern can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched: usually unused, it holds the matches already found, that is, the nodes already matching a pattern.

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by another pattern.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
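The contract of match can be illustrated on toy structures. This is only a sketch: `ToyNode`, `ToyMatch`, `rewrite_stub` and `match_transpose_gemm` are hypothetical stand-ins for NodeProto, MatchResult and a real pattern, and the `producers` dictionary plays the role of the information g provides about the node before another one.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class ToyNode:
    # Hypothetical stand-in for NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyMatch:
    # Hypothetical stand-in for MatchResult.
    nodes: List[ToyNode]  # nodes needed by the rewriting
    apply: Callable[..., List[ToyNode]]  # function doing the rewriting
    insert_at: Optional[ToyNode] = None  # optional insertion point


def rewrite_stub(*nodes: ToyNode) -> List[ToyNode]:
    # Placeholder: a real pattern would return the fused nodes here.
    return list(nodes)


def match_transpose_gemm(
    producers: Dict[str, ToyNode], node: ToyNode
) -> Optional[ToyMatch]:
    # Read-only exploration starting from `node`, the graph is never modified.
    if node.op_type != "Gemm":
        return None
    before = producers.get(node.inputs[1])
    if before is None or before.op_type != "Transpose":
        return None
    return ToyMatch(nodes=[before, node], apply=rewrite_stub, insert_at=node)


t = ToyNode("Transpose", ["w"], ["t"])
gemm = ToyNode("Gemm", ["x", "t", "b"], ["y"])
relu = ToyNode("Relu", ["y"], ["z"])
producers = {out: n for n in (t, gemm, relu) for out in n.outputs}

found = match_transpose_gemm(producers, gemm)  # a ToyMatch
missed = match_transpose_gemm(producers, relu)  # None
```

The match only collects what the rewriting needs. The decision to remove or keep each node is left to the rewriting function.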

Debugging: method none

def none(
    self,
    node: Optional[NodeProto] = None,
    lineno: Optional[int] = None,
    msg: str = "",
):

It may be useful to know why a pattern failed to match. Instead of returning None, method match can return the following expression:

return self.none(node, inspect.currentframe().f_lineno)

By setting the verbosity (see section Verbosity below), the user may then know which line in the code returned None and which condition failed.
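The idea can be reproduced with the standard library alone. The class below is a hypothetical stand-in, not the library's implementation: it stores messages in a list instead of printing them, and nodes are plain dictionaries.

```python
import inspect
from typing import List, Optional


class ToyPattern:
    # Hypothetical stand-in for a pattern class, only the logging idea is shown.
    def __init__(self, verbose: int = 0):
        self.verbose = verbose
        self.log: List[str] = []

    def none(
        self,
        node: Optional[dict] = None,
        lineno: Optional[int] = None,
        msg: str = "",
    ):
        # Records where the matching gave up, then returns None like a failed match.
        if node is not None and self.verbose:
            suffix = f", {msg}" if msg else ""
            self.log.append(
                f"[{type(self).__name__}.match] NONE - line: {lineno}, "
                f"op_type={node['op_type']}{suffix}"
            )
        return None

    def match(self, node: dict):
        if node["op_type"] != "Transpose":
            # Plain None: the pattern simply does not apply, nothing to report.
            return None
        if len(node["perm"]) != 2:
            return self.none(node, inspect.currentframe().f_lineno, "perm must have 2 axes")
        return node


p = ToyPattern(verbose=1)
rejected = p.match({"op_type": "Transpose", "perm": [0, 2, 1]})
accepted = p.match({"op_type": "Transpose", "perm": [1, 0]})
```

After these two calls, `p.log` holds a single message pointing at the line of the failing condition.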

PatternOptimization.apply

@classmethod
def apply(
    cls, g: "GraphBuilder", *nodes: Sequence[NodeProto]
) -> List[NodeProto]:

The method does the rewriting. It assumes the rewriting can happen. It receives the list of nodes returned by method match, that is, the nodes impacted by the rewriting, and it assumes no other pattern modified them or will modify them. Since the nodes are passed as positional arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
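The following sketch illustrates this contract on toy structures. `ToyNode` and `apply_transpose_gemm` are hypothetical names, the Transpose is assumed to swap the two axes of a 2D tensor, and the third argument shows how a match may pass None or a node that must be kept.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ToyNode:
    # Hypothetical stand-in for NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]
    attrs: Dict[str, int] = field(default_factory=dict)


def apply_transpose_gemm(
    transpose: ToyNode, gemm: ToyNode, keep: Optional[ToyNode] = None
) -> List[ToyNode]:
    # Folds the Transpose into the Gemm by flipping transB.
    fused = ToyNode(
        "Gemm",
        [gemm.inputs[0], transpose.inputs[0], *gemm.inputs[2:]],
        list(gemm.outputs),
        {**gemm.attrs, "transB": 1 - gemm.attrs.get("transB", 0)},
    )
    # Every received node is considered removed: a node which must
    # survive (its result is used elsewhere) has to be returned again.
    return [fused] if keep is None else [keep, fused]


t = ToyNode("Transpose", ["w"], ["t"])
gemm = ToyNode("Gemm", ["x", "t", "b"], ["y"], {"transB": 0})

replaced = apply_transpose_gemm(t, gemm, None)  # Transpose disappears
preserved = apply_transpose_gemm(t, gemm, t)  # Transpose is kept
```

In the second call, the match is assumed to have detected that the result of the Transpose has other consumers and to have passed the node a second time.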

Optimization Algorithm

It is implemented in method optimize

def optimize(
    self, max_iter=-1, remove_identity: bool = True
) -> List[Dict[str, Any]]:

The algorithm runs multiple iterations until the graph stops changing or max_iter is reached. By default, max_iter is equal to the number of nodes. An iteration is:

matches = []

build all successors and predecessors

# Step 1: match

for all patterns p:
    for all nodes n:
        r = p.match(n)
        if r:
            if no node of r is already scheduled to be rewritten by another match:
                matches.append(r)

# Step 2: apply

for all matches r:
    apply the match r

# Step 3: clean

remove unused nodes
remove identity nodes

This algorithm may apply more than one rewriting at each iteration, but it guarantees that the local structure a rewriting relies on was not altered by another one.
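The loop can be sketched in plain Python under simplifying assumptions: `ToyNode` is a hypothetical stand-in for NodeProto, patterns are plain functions, a match is a pair (nodes, apply), new nodes are inserted at the position of the last removed node, and step 3 is left out. The toy pattern replaces Neg(Neg(x)) by Identity(x).

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass(eq=False)  # identity-based hashing, nodes can be stored in sets
class ToyNode:
    # Hypothetical stand-in for NodeProto.
    op_type: str
    inputs: List[str]
    outputs: List[str]


Match = Tuple[List[ToyNode], Callable[..., List[ToyNode]]]


def apply_neg_neg(first: ToyNode, second: ToyNode) -> List[ToyNode]:
    # Neg(Neg(x)) == x
    return [ToyNode("Identity", list(first.inputs), list(second.outputs))]


def match_neg_neg(producers, consumers, node) -> Optional[Match]:
    if node.op_type != "Neg":
        return None
    first = producers.get(node.inputs[0])
    if first is None or first.op_type != "Neg":
        return None
    if len(consumers[first.outputs[0]]) != 1:
        return None  # the intermediate result is used elsewhere
    return [first, node], apply_neg_neg


def optimize(nodes, patterns, max_iter=-1):
    if max_iter < 0:
        max_iter = len(nodes)
    for _ in range(max_iter):
        # builds all successors and predecessors
        producers = {o: n for n in nodes for o in n.outputs}
        consumers = {}
        for n in nodes:
            for name in n.inputs:
                consumers.setdefault(name, []).append(n)
        # Step 1: match
        matches, marked = [], set()
        for pattern in patterns:
            for node in nodes:
                r = pattern(producers, consumers, node)
                if r is not None and marked.isdisjoint(r[0]):
                    marked.update(r[0])
                    matches.append(r)
        if not matches:
            break
        # Step 2: apply, new nodes replace the last matched node
        for matched, apply in matches:
            last = max(i for i, n in enumerate(nodes) if n in matched)
            before = [n for n in nodes[:last] if n not in matched]
            nodes = before + apply(*matched) + nodes[last + 1:]
        # Step 3 (removing unused and identity nodes) is left out of this sketch.
    return nodes


names = ["x", "a", "b", "c", "d"]
graph = [ToyNode("Neg", [i], [o]) for i, o in zip(names, names[1:])]
graph.append(ToyNode("Relu", ["d"], ["y"]))
result = optimize(graph, [match_neg_neg])
```

In the first iteration, the second and third Neg nodes also match, but that match shares a node with the one already scheduled and is skipped. The first and last pairs are rewritten in the same iteration because they do not overlap.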

Adding a pattern

See #80 about the addition of a new pattern.

Example

Simple API

We consider the following simple model:

<<<

import torch
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.layers(x)


x = torch.rand(3, 10)
onx = to_onnx(
    MLP(), (x,), input_names=["x"], options=OptimizationOptions(patterns=None)
)
with open("temp_doc_mlp.onnx", "wb") as f:
    f.write(onx.SerializeToString())
print(onnx_simple_text_plot(onx))

>>>

    opset: domain='' version=18
    input: name='x' type=dtype('float32') shape=[3, 10]
    init: name='p_layers_0_weight' type=dtype('float32') shape=(32, 10)
    init: name='p_layers_0_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_2_weight' type=dtype('float32') shape=(1, 32)
    init: name='p_layers_2_bias' type=dtype('float32') shape=(1,) -- array([0.061], dtype=float32)
    Transpose(p_layers_0_weight, perm=[1,0]) -> t
      Gemm(x, t, p_layers_0_bias, alpha=1.00, beta=1.00) -> addmm
        Relu(addmm) -> relu
    Transpose(p_layers_2_weight, perm=[1,0]) -> t_1
      Gemm(relu, t_1, p_layers_2_bias, alpha=1.00, beta=1.00) -> output_0
    output: name='output_0' type=dtype('float32') shape=[3, 1]

Which renders as follows:

digraph{
  ranksep=0.25;
  orientation=portrait;
  nodesep=0.05;
  size=7;

  x [shape=box color=red label="x\nTensorProto.FLOAT\nshape=[3, 10]" fontsize=10];

  output_0 [shape=box color=green label="output_0\nTensorProto.FLOAT\nshape=[3, 1]" fontsize=10];

  p_layers_0_weight [shape=box label="p_layers_0_weight\nfloat32((32, 10))\n[[-1.942e-01 -2.460e-01  1.436e-01 -1.956e-01  1.6..." fontsize=10];
  p_layers_0_bias [shape=box label="p_layers_0_bias\nfloat32((32,))\n[-0.139  0.235  0.093  0.263  0.213 -0.314 -0.238 ..." fontsize=10];
  p_layers_2_weight [shape=box label="p_layers_2_weight\nfloat32((1, 32))\n[[-0.009 -0.042  0.046  0.104  0.116  0.054 -0.027..." fontsize=10];
  p_layers_2_bias [shape=box label="p_layers_2_bias\nfloat32((1,))\n[0.061]" fontsize=10];

  t [shape=box label="t" fontsize=10];
  t [shape=box style="filled,rounded" color=orange label="Transpose\nperm=[1, 0]" fontsize=10];
  p_layers_0_weight -> t;
  t -> t;

  addmm [shape=box label="addmm" fontsize=10];
  addmm [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=1.0\nbeta=1.0" fontsize=10];
  x -> addmm;
  t -> addmm;
  p_layers_0_bias -> addmm;
  addmm -> addmm;

  relu [shape=box label="relu" fontsize=10];
  Opset [shape=box style="filled,rounded" color=orange label="Relu" fontsize=10];
  addmm -> Opset;
  Opset -> relu;

  t_1 [shape=box label="t_1" fontsize=10];
  t2 [shape=box style="filled,rounded" color=orange label="Transpose\nperm=[1, 0]" fontsize=10];
  p_layers_2_weight -> t2;
  t2 -> t_1;

  addmm2 [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=1.0\nbeta=1.0" fontsize=10];
  relu -> addmm2;
  t_1 -> addmm2;
  p_layers_2_bias -> addmm2;
  addmm2 -> output_0;
}

We then apply the optimizations by writing the following code:

<<<

import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder

onx = onnx.load("temp_doc_mlp.onnx")

# The model is placed in a GraphBuilder.
# It creates dictionaries to store shapes, ranks, types
# to make it easier for the optimizers to find the information
# they need. It still uses NodeProto to store nodes.
gr = GraphBuilder(onx, infer_shapes=True)

# Let's optimize.
opt_onx = gr.to_onnx(optimize=True)
with open("temp_doc_mlp_opt.onnx", "wb") as f:
    f.write(opt_onx.SerializeToString())
print(onnx_simple_text_plot(opt_onx))

>>>

    opset: domain='' version=18
    input: name='x' type=dtype('float32') shape=[3, 10]
    init: name='p_layers_0_weight' type=dtype('float32') shape=(32, 10)
    init: name='p_layers_0_bias' type=dtype('float32') shape=(32,)
    init: name='p_layers_2_weight' type=dtype('float32') shape=(1, 32)
    init: name='p_layers_2_bias' type=dtype('float32') shape=(1,) -- array([0.061], dtype=float32)
    Gemm(x, p_layers_0_weight, p_layers_0_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
      Relu(addmm) -> relu
        Gemm(relu, p_layers_2_weight, p_layers_2_bias, transA=0, transB=1, alpha=1.00, beta=1.00) -> output_0
    output: name='output_0' type=dtype('float32') shape=[3, 1]

Which renders as follows:

digraph{
  ranksep=0.25;
  orientation=portrait;
  nodesep=0.05;
  size=7;

  x [shape=box color=red label="x\nTensorProto.FLOAT\nshape=[3, 10]" fontsize=10];

  output_0 [shape=box color=green label="output_0\nTensorProto.FLOAT\nshape=[3, 1]" fontsize=10];

  p_layers_0_weight [shape=box label="p_layers_0_weight\nfloat32((32, 10))\n[[-1.942e-01 -2.460e-01  1.436e-01 -1.956e-01  1.6..." fontsize=10];
  p_layers_0_bias [shape=box label="p_layers_0_bias\nfloat32((32,))\n[-0.139  0.235  0.093  0.263  0.213 -0.314 -0.238 ..." fontsize=10];
  p_layers_2_weight [shape=box label="p_layers_2_weight\nfloat32((1, 32))\n[[-0.009 -0.042  0.046  0.104  0.116  0.054 -0.027..." fontsize=10];
  p_layers_2_bias [shape=box label="p_layers_2_bias\nfloat32((1,))\n[0.061]" fontsize=10];

  addmm [shape=box label="addmm" fontsize=10];
  TransposeMatMulPattern__addmm [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=1.0\nbeta=1.0\ntransA=0\ntransB=1" fontsize=10];
  x -> TransposeMatMulPattern__addmm;
  p_layers_0_weight -> TransposeMatMulPattern__addmm;
  p_layers_0_bias -> TransposeMatMulPattern__addmm;
  TransposeMatMulPattern__addmm -> addmm;

  relu [shape=box label="relu" fontsize=10];
  Opset [shape=box style="filled,rounded" color=orange label="Relu" fontsize=10];
  addmm -> Opset;
  Opset -> relu;

  TransposeMatMulPattern__addmm2 [shape=box style="filled,rounded" color=orange label="Gemm\nalpha=1.0\nbeta=1.0\ntransA=0\ntransB=1" fontsize=10];
  relu -> TransposeMatMulPattern__addmm2;
  p_layers_2_weight -> TransposeMatMulPattern__addmm2;
  p_layers_2_bias -> TransposeMatMulPattern__addmm2;
  TransposeMatMulPattern__addmm2 -> output_0;
}

Verbosity

<<<

import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(onx, infer_shapes=True, verbose=1)
opt_onx = gr.to_onnx(optimize=True)

>>>

    [GraphBuilderPatternOptimization.optimize] start with 5 nodes and 27 patterns, priorities=[0, 1]
    [GraphBuilderPatternOptimization.optimize] use pattern   1/27 - P0 - CastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   2/27 - P0 - ExpandPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   3/27 - P0 - IdentityPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   4/27 - P0 - ReshapeReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   5/27 - P0 - SameChildrenPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   6/27 - P0 - TransposeTransposePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   7/27 - P0 - UnsqueezeUnsqueezePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   8/27 - P1 - CastCastBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   9/27 - P1 - CastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  10/27 - P1 - ComputationCastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  11/27 - P1 - DivByMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  12/27 - P1 - ExpandBroadcastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  13/27 - P1 - ExpandSwapPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  14/27 - P1 - MatMulReshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  15/27 - P1 - MulMulMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  16/27 - P1 - ReduceReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  17/27 - P1 - ReduceSumNormalizePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  18/27 - P1 - Reshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  19/27 - P1 - ReshapeMatMulReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  20/27 - P1 - ReshapeReshapeBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  21/27 - P1 - RotaryConcatPartPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  22/27 - P1 - SlicesSplitPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  23/27 - P1 - Sub1MulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  24/27 - P1 - SwitchOrderBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  25/27 - P1 - TransposeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  26/27 - P1 - TransposeReshapeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  27/27 - P1 - UnsqueezeEqualPattern()
    [GraphBuilderPatternOptimization.optimize] iteration 0: 5 nodes, priority=0
    [GraphBuilderPatternOptimization.optimize] increase priority to 1
    [GraphBuilderPatternOptimization.optimize] iteration 1: 5 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*TransposeMatMulPattern - time=0.000 | max_time=TransposeMatMulPattern:0.000
    [GraphBuilderPatternOptimization.optimize] iteration 2: 3 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] done after 3 iterations with 3 nodes in 0.001
    [GraphBuilder] done with 3 nodes in 0.001
        STAT apply_TransposeMatMulPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00020869999571004882
        STAT build_for_pattern +0 -0 #it=3 maxmatch=0 i=0 - time=7.160000677686185e-05
        STAT check_A +0 -0 #it=0 maxmatch=0 i=0 - time=1.479999627918005e-05
        STAT check_B +0 -0 #it=0 maxmatch=0 i=0 - time=9.79999458650127e-06
        STAT check_C +0 -0 #it=0 maxmatch=0 i=0 - time=9.700001101009548e-06
        STAT check_F +0 -0 #it=0 maxmatch=0 i=0 - time=8.000002708286047e-06
        STAT check_G +0 -0 #it=0 maxmatch=0 i=0 - time=6.999995093792677e-06
        STAT check_pattern_A +0 -0 #it=1 maxmatch=0 i=0 - time=2.1399995603132993e-05
        STAT check_pattern_B +0 -0 #it=3 maxmatch=0 i=0 - time=3.139999171253294e-05
        STAT match_CastCastBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.60000033956021e-06
        STAT match_CastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=1.2200005585327744e-05
        STAT match_CastPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.179999864893034e-05
        STAT match_ComputationCastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.500006515532732e-06
        STAT match_DivByMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.000002708286047e-06
        STAT match_ExpandBroadcastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.09999619377777e-06
        STAT match_ExpandPattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.3400007446762174e-05
        STAT match_ExpandSwapPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.499998901039362e-06
        STAT match_IdentityPattern +0 -0 #it=3 maxmatch=0 i=0 - time=3.9200000173877925e-05
        STAT match_MatMulReshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.000003046821803e-06
        STAT match_MulMulMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.299997716676444e-06
        STAT match_ReduceReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.699993825051934e-06
        STAT match_ReduceSumNormalizePattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.39999847812578e-06
        STAT match_Reshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.2999980552122e-06
        STAT match_ReshapeMatMulReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.199996955227107e-06
        STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.100003469735384e-06
        STAT match_ReshapeReshapePattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.2600001355167478e-05
        STAT match_RotaryConcatPartPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.59999272506684e-06
        STAT match_SameChildrenPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.469999890308827e-05
        STAT match_SlicesSplitPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.999995770864189e-06
        STAT match_Sub1MulPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.39999847812578e-06
        STAT match_SwitchOrderBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.499999239575118e-06
        STAT match_TransposeMatMulPattern +0 -0 #it=2 maxmatch=2 i=2 - time=0.00015490000077988952
        STAT match_TransposeReshapeMatMulPattern +0 -0 #it=2 maxmatch=2 i=0 - time=1.0399999155197293e-05
        STAT match_TransposeTransposePattern +0 -0 #it=3 maxmatch=2 i=0 - time=2.3700013116467744e-05
        STAT match_UnsqueezeEqualPattern +0 -0 #it=2 maxmatch=2 i=0 - time=8.900002285372466e-06
        STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=3 maxmatch=2 i=0 - time=1.390001125400886e-05
        STAT pattern_optimization +0 -2 #it=0 maxmatch=0 i=0 - time=0.001188700000056997
        STAT remove_identity_nodes +0 -0 #it=3 maxmatch=0 i=0 - time=0.00011399999493733048
        STAT remove_unused +0 -0 #it=0 maxmatch=0 i=0 - time=4.7400004405062646e-05
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--
         INPUT:   1 x 1t
        OUTPUT:   1 x 1t
          INIT:   4 x 1t
          NODE:   2 x Gemm
          NODE:   1 x Relu
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--DETAILED--
         INPUT:   1 x 1t[3x10]
        OUTPUT:   1 x 1t[3x1]
          INIT:   1 x 1t[1]
          INIT:   1 x 1t[1x32]
          INIT:   1 x 1t[32]
          INIT:   1 x 1t[32x10]
          NODE:   1 x Gemm -SIG- 1t[3x10], 1t[32x10], 1t[32]
          NODE:   1 x Gemm -SIG- 1t[3x32], 1t[1x32], 1t[1]
          NODE:   1 x Relu -SIG- 1t[3x32]
    [GraphBuilder-WDC.to_onnx] make_model
    [GraphBuilder-WDC._build_initializers] start with 4 initializers, large_model=False, external_threshold=1024
    [GraphBuilder-WDC._build_initializers] switch low/high order
    [GraphBuilder-WDC._build_initializers] done in 1.400003384333104e-06s with 4 initializers, 0 large initializers

With more verbosity:

<<<

import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(onx, infer_shapes=True, verbose=11)
opt_onx = gr.to_onnx(optimize=True)

>>>

    [GraphBuilder._update_structures_with_proto] starts with 5 nodes
    [GraphBuilder-KHK.set_shape] p_layers_0_weight:(32, 10)
    [GraphBuilder-KHK.set_rank] p_layers_0_weight:2
    [GraphBuilder-KHK.set_type] p_layers_0_weight:1
    [GraphBuilder-KHK.set_shape] p_layers_0_bias:(32,)
    [GraphBuilder-KHK.set_rank] p_layers_0_bias:1
    [GraphBuilder-KHK.set_type] p_layers_0_bias:1
    [GraphBuilder-KHK.set_shape] p_layers_2_weight:(1, 32)
    [GraphBuilder-KHK.set_rank] p_layers_2_weight:2
    [GraphBuilder-KHK.set_type] p_layers_2_weight:1
    [GraphBuilder-KHK.set_shape] p_layers_2_bias:(1,)
    [GraphBuilder-KHK.set_rank] p_layers_2_bias:1
    [GraphBuilder-KHK.set_type] p_layers_2_bias:1
    [GraphBuilder-KHK.set_type] x:1
    [GraphBuilder-KHK.set_shape] x:(3, 10)
    [GraphBuilder-KHK.set_rank] x:2
    [GraphBuilder-KHK.set_type] output_0:1
    [GraphBuilder-KHK.set_shape] output_0:(3, 1)
    [GraphBuilder-KHK.set_rank] output_0:2
    [GraphBuilder-KHK.set_type] t:1
    [GraphBuilder-KHK.set_shape] t:(10, 32)
    [GraphBuilder-KHK.set_rank] t:2
    [GraphBuilder-KHK.set_type] addmm:1
    [GraphBuilder-KHK.set_shape] addmm:(3, 32)
    [GraphBuilder-KHK.set_rank] addmm:2
    [GraphBuilder-KHK.set_type] relu:1
    [GraphBuilder-KHK.set_shape] relu:(3, 32)
    [GraphBuilder-KHK.set_rank] relu:2
    [GraphBuilder-KHK.set_type] t_1:1
    [GraphBuilder-KHK.set_shape] t_1:(32, 1)
    [GraphBuilder-KHK.set_rank] t_1:2
    [GraphBuilder._update_structures_with_proto] ends with 5 nodes in 0.00024529999791411683
    [GraphBuilder.constant_folding] starts with 4 constants and 5 nodes.
    [GraphBuilder.constant_folding] ends with 4 constants and 5 nodes in 1.1099997209385037e-05 seconds
    [GraphBuilder._update_shape_types_with_proto] starts with 5 nodes and 5 shapes.
    [GraphBuilder._update_shape_types_with_proto] infer shapes
    [GraphBuilder._update_shape_types_with_proto] infer shapes done 0.00010060000204248354 seconds
    [GraphBuilder._update_shape_types_with_proto] _clean_shapes after 0.00012170000263722613 seconds
    [GraphBuilder._update_shape_types_with_proto] walk through 5 shapes.
    [GraphBuilder-KHK.set_type] t_1:1
    [GraphBuilder-KHK.set_type] addmm_1:1
    [GraphBuilder-KHK.set_shape] addmm_1:(3, 1)
    [GraphBuilder-KHK.set_rank] addmm_1:2
    [GraphBuilder-KHK.set_type] addmm:1
    [GraphBuilder-KHK.set_type] relu:1
    [GraphBuilder-KHK.set_type] t:1
    [GraphBuilder._update_shape_types_with_proto] ends in 6.579999899258837e-05 seconds.
    [GraphBuilder.remove_identity_nodes] starts with 5
    [GraphBuilder.remove_identity_nodes] found 0 replacements
    [GraphBuilder.remove_identity_nodes] kept 5 nodes
    [GraphBuilder.remove_identity_nodes] ends with 5 nodes in 1.679999695625156e-05 seconds
    [GraphBuilderPatternOptimization.optimize] start with 5 nodes and 27 patterns, priorities=[0, 1]
    [GraphBuilderPatternOptimization.optimize] use pattern   1/27 - P0 - CastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   2/27 - P0 - ExpandPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   3/27 - P0 - IdentityPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   4/27 - P0 - ReshapeReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   5/27 - P0 - SameChildrenPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   6/27 - P0 - TransposeTransposePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   7/27 - P0 - UnsqueezeUnsqueezePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   8/27 - P1 - CastCastBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   9/27 - P1 - CastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  10/27 - P1 - ComputationCastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  11/27 - P1 - DivByMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  12/27 - P1 - ExpandBroadcastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  13/27 - P1 - ExpandSwapPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  14/27 - P1 - MatMulReshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  15/27 - P1 - MulMulMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  16/27 - P1 - ReduceReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  17/27 - P1 - ReduceSumNormalizePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  18/27 - P1 - Reshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  19/27 - P1 - ReshapeMatMulReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  20/27 - P1 - ReshapeReshapeBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  21/27 - P1 - RotaryConcatPartPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  22/27 - P1 - SlicesSplitPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  23/27 - P1 - Sub1MulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  24/27 - P1 - SwitchOrderBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  25/27 - P1 - TransposeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  26/27 - P1 - TransposeReshapeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  27/27 - P1 - UnsqueezeEqualPattern()
    [GraphBuilderPatternOptimization.optimize] iteration 0: 5 nodes, priority=0
    [IdentityPattern.match] NONE - line: 157:experimental_experiment.xoptim.patterns.onnx_any, op_type=Transpose, name=t
    [IdentityPattern.match] NONE - line: 157:experimental_experiment.xoptim.patterns.onnx_any, op_type=Transpose, name=t2
    [TransposeTransposePattern.match] NONE - line: 49:experimental_experiment.xoptim.patterns.onnx_transpose, op_type=Transpose, name=t
    [TransposeTransposePattern.match] NONE - line: 49:experimental_experiment.xoptim.patterns.onnx_transpose, op_type=Transpose, name=t2
    [GraphBuilderPatternOptimization.optimize] done all: -0 +0 nodes
    [GraphBuilder.remove_identity_nodes] starts with 5
    [GraphBuilder.remove_identity_nodes] found 0 replacements
    [GraphBuilder.remove_identity_nodes] kept 5 nodes
    [GraphBuilder.remove_identity_nodes] ends with 5 nodes in 1.7099999240599573e-05 seconds
    [GraphBuilderPatternOptimization.optimize] increase priority to 1
    [GraphBuilderPatternOptimization.optimize] iteration 1: 5 nodes, priority=1
    [IdentityPattern.match] NONE - line: 157:experimental_experiment.xoptim.patterns.onnx_any, op_type=Transpose, name=t
    [IdentityPattern.match] NONE - line: 157:experimental_experiment.xoptim.patterns.onnx_any, op_type=Transpose, name=t2
    [GraphBuilderPatternOptimization.optimize] match=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']
    [GraphBuilderPatternOptimization.optimize] match=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']
    [TransposeTransposePattern.match] NONE - line: 49:experimental_experiment.xoptim.patterns.onnx_transpose, op_type=Transpose, name=t
    [TransposeTransposePattern.match] NONE - line: 49:experimental_experiment.xoptim.patterns.onnx_transpose, op_type=Transpose, name=t2
    [GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*TransposeMatMulPattern - time=0.000 | max_time=TransposeMatMulPattern:0.000
    [GraphBuilderPatternOptimization.optimize] apply MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm'], inputs: {'p_layers_0_weight', 'p_layers_0_bias', 'x', 't'}, outputs: {'t', 'addmm'}
    [GraphBuilderPatternOptimization.apply_match] MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']
      - Transpose: ['p_layers_0_weight'] -> ['t']
      - Gemm: ['x', 't', 'p_layers_0_bias'] -> ['addmm']
      + Gemm: ['x', 'p_layers_0_weight', 'p_layers_0_bias'] -> ['addmm']
    [GraphBuilder-KHK.set_type] addmm:1
    [GraphBuilderPatternOptimization.apply_match] MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm'] applied.
    [GraphBuilderPatternOptimization.optimize] - add ['Gemm']
    [GraphBuilderPatternOptimization.optimize] done MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']: -2 +1 nodes
    [GraphBuilderPatternOptimization.optimize] removed outputs {'t'}
    [GraphBuilderPatternOptimization.optimize] apply MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm'], inputs: {'t_1', 'relu', 'p_layers_2_weight', 'p_layers_2_bias'}, outputs: {'t_1', 'output_0'}
    [GraphBuilderPatternOptimization.apply_match] MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']
      - Transpose: ['p_layers_2_weight'] -> ['t_1']
      - Gemm: ['relu', 't_1', 'p_layers_2_bias'] -> ['output_0']
      + Gemm: ['relu', 'p_layers_2_weight', 'p_layers_2_bias'] -> ['output_0']
    [GraphBuilder-KHK.set_type] output_0:1
    [GraphBuilderPatternOptimization.apply_match] MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm'] applied.
    [GraphBuilderPatternOptimization.optimize] - add ['Gemm']
    [GraphBuilderPatternOptimization.optimize] done MatchResult: TransposeMatMulPattern replaces ['Transpose', 'Gemm']: -2 +1 nodes
    [GraphBuilderPatternOptimization.optimize] removed outputs {'t_1'}
    [GraphBuilderPatternOptimization.optimize] done all: -4 +2 nodes
    [GraphBuilder.remove_identity_nodes] starts with 3
    [GraphBuilder.remove_identity_nodes] found 0 replacements
    [GraphBuilder.remove_identity_nodes] kept 3 nodes
    [GraphBuilder.remove_identity_nodes] ends with 3 nodes in 1.1900003300979733e-05 seconds
    [GraphBuilderPatternOptimization.optimize] iteration 2: 3 nodes, priority=1
    [TransposeMatMulPattern.match] NONE - line: 490:experimental_experiment.xoptim.patterns.onnx_matmul, op_type=Gemm, name=TransposeMatMulPattern--addmm
    [TransposeMatMulPattern.match] NONE - line: 490:experimental_experiment.xoptim.patterns.onnx_matmul, op_type=Gemm, name=TransposeMatMulPattern--addmm2
    [GraphBuilderPatternOptimization.optimize] done all: -0 +0 nodes
    [GraphBuilder.remove_identity_nodes] starts with 3
    [GraphBuilder.remove_identity_nodes] found 0 replacements
    [GraphBuilder.remove_identity_nodes] kept 3 nodes
    [GraphBuilder.remove_identity_nodes] ends with 3 nodes in 1.410000550094992e-05 seconds
    [GraphBuilderPatternOptimization.optimize] done after 3 iterations with 3 nodes in 0.001
        STAT apply_TransposeMatMulPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0002466000005370006
        STAT build_for_pattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.630000567995012e-05
        STAT check_pattern_A +0 -0 #it=1 maxmatch=0 i=0 - time=2.1899999410379678e-05
        STAT check_pattern_B +0 -0 #it=3 maxmatch=0 i=0 - time=3.0800001695752144e-05
        STAT match_CastCastBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.899995009414852e-06
        STAT match_CastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=1.1399992217775434e-05
        STAT match_CastPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.040000254055485e-05
        STAT match_ComputationCastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=2.1299994841683656e-05
        STAT match_DivByMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.099988917820156e-06
        STAT match_ExpandBroadcastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.499998901039362e-06
        STAT match_ExpandPattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.2699994840659201e-05
        STAT match_ExpandSwapPattern +0 -0 #it=2 maxmatch=0 i=0 - time=6.900001608300954e-06
        STAT match_IdentityPattern +0 -0 #it=3 maxmatch=0 i=0 - time=4.940000508213416e-05
        STAT match_MatMulReshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.700000762473792e-06
        STAT match_MulMulMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.90000194683671e-06
        STAT match_ReduceReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.90000956133008e-06
        STAT match_ReduceSumNormalizePattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.299990779254586e-06
        STAT match_Reshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.700008038431406e-06
        STAT match_ReshapeMatMulReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.299997378140688e-06
        STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.700000423938036e-06
        STAT match_ReshapeReshapePattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.2100004823878407e-05
        STAT match_RotaryConcatPartPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.000002708286047e-06
        STAT match_SameChildrenPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.4500004656147212e-05
        STAT match_SlicesSplitPattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.199997293762863e-06
        STAT match_Sub1MulPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.09999619377777e-06
        STAT match_SwitchOrderBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.700000423938036e-06
        STAT match_TransposeMatMulPattern +0 -0 #it=2 maxmatch=2 i=2 - time=0.00016340000001946464
        STAT match_TransposeReshapeMatMulPattern +0 -0 #it=2 maxmatch=2 i=0 - time=1.069999416358769e-05
        STAT match_TransposeTransposePattern +0 -0 #it=3 maxmatch=2 i=0 - time=3.309999738121405e-05
        STAT match_UnsqueezeEqualPattern +0 -0 #it=2 maxmatch=2 i=0 - time=7.699993147980422e-06
        STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=3 maxmatch=2 i=0 - time=1.1900003300979733e-05
        STAT remove_identity_nodes +0 -0 #it=3 maxmatch=0 i=0 - time=7.139999797800556e-05
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--
         INPUT:   1 x 1t
        OUTPUT:   1 x 1t
          INIT:   4 x 1t
          NODE:   2 x Gemm
          NODE:   1 x Relu
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--DETAILED--
         INPUT:   1 x 1t[3x10]
        OUTPUT:   1 x 1t[3x1]
          INIT:   1 x 1t[1]
          INIT:   1 x 1t[1x32]
          INIT:   1 x 1t[32]
          INIT:   1 x 1t[32x10]
          NODE:   1 x Gemm -SIG- 1t[3x10], 1t[32x10], 1t[32]
          NODE:   1 x Gemm -SIG- 1t[3x32], 1t[1x32], 1t[1]
          NODE:   1 x Relu -SIG- 1t[3x32]
    [GraphBuilder] done with 3 nodes in 0.002
        STAT apply_TransposeMatMulPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.0002466000005370006
        STAT build_for_pattern +0 -0 #it=3 maxmatch=0 i=0 - time=9.630000567995012e-05
        STAT check_A +0 -0 #it=0 maxmatch=0 i=0 - time=1.410000550094992e-05
        STAT check_B +0 -0 #it=0 maxmatch=0 i=0 - time=1.0700001439545304e-05
        STAT check_C +0 -0 #it=0 maxmatch=0 i=0 - time=9.60000033956021e-06
        STAT check_F +0 -0 #it=0 maxmatch=0 i=0 - time=9.79999458650127e-06
        STAT check_G +0 -0 #it=0 maxmatch=0 i=0 - time=7.499998901039362e-06
        STAT check_pattern_A +0 -0 #it=1 maxmatch=0 i=0 - time=2.1899999410379678e-05
        STAT check_pattern_B +0 -0 #it=3 maxmatch=0 i=0 - time=3.0800001695752144e-05
        STAT match_CastCastBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.899995009414852e-06
        STAT match_CastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=1.1399992217775434e-05
        STAT match_CastPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.040000254055485e-05
        STAT match_ComputationCastOpCastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=2.1299994841683656e-05
        STAT match_DivByMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.099988917820156e-06
        STAT match_ExpandBroadcastPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.499998901039362e-06
        STAT match_ExpandPattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.2699994840659201e-05
        STAT match_ExpandSwapPattern +0 -0 #it=2 maxmatch=0 i=0 - time=6.900001608300954e-06
        STAT match_IdentityPattern +0 -0 #it=3 maxmatch=0 i=0 - time=4.940000508213416e-05
        STAT match_MatMulReshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.700000762473792e-06
        STAT match_MulMulMulScalarPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.90000194683671e-06
        STAT match_ReduceReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.90000956133008e-06
        STAT match_ReduceSumNormalizePattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.299990779254586e-06
        STAT match_Reshape2Of3Pattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.700008038431406e-06
        STAT match_ReshapeMatMulReshapePattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.299997378140688e-06
        STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.700000423938036e-06
        STAT match_ReshapeReshapePattern +0 -0 #it=3 maxmatch=0 i=0 - time=1.2100004823878407e-05
        STAT match_RotaryConcatPartPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.000002708286047e-06
        STAT match_SameChildrenPattern +0 -0 #it=3 maxmatch=0 i=0 - time=2.4500004656147212e-05
        STAT match_SlicesSplitPattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.199997293762863e-06
        STAT match_Sub1MulPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.09999619377777e-06
        STAT match_SwitchOrderBinaryPattern +0 -0 #it=2 maxmatch=0 i=0 - time=7.700000423938036e-06
        STAT match_TransposeMatMulPattern +0 -0 #it=2 maxmatch=2 i=2 - time=0.00016340000001946464
        STAT match_TransposeReshapeMatMulPattern +0 -0 #it=2 maxmatch=2 i=0 - time=1.069999416358769e-05
        STAT match_TransposeTransposePattern +0 -0 #it=3 maxmatch=2 i=0 - time=3.309999738121405e-05
        STAT match_UnsqueezeEqualPattern +0 -0 #it=2 maxmatch=2 i=0 - time=7.699993147980422e-06
        STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=3 maxmatch=2 i=0 - time=1.1900003300979733e-05
        STAT pattern_optimization +0 -2 #it=0 maxmatch=0 i=0 - time=0.001451299998734612
        STAT remove_identity_nodes +0 -0 #it=3 maxmatch=0 i=0 - time=9.7999996796716e-05
        STAT remove_unused +0 -0 #it=0 maxmatch=0 i=0 - time=4.6299988753162324e-05
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--
         INPUT:   1 x 1t
        OUTPUT:   1 x 1t
          INIT:   4 x 1t
          NODE:   2 x Gemm
          NODE:   1 x Relu
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--DETAILED--
         INPUT:   1 x 1t[3x10]
        OUTPUT:   1 x 1t[3x1]
          INIT:   1 x 1t[1]
          INIT:   1 x 1t[1x32]
          INIT:   1 x 1t[32]
          INIT:   1 x 1t[32x10]
          NODE:   1 x Gemm -SIG- 1t[3x10], 1t[32x10], 1t[32]
          NODE:   1 x Gemm -SIG- 1t[3x32], 1t[1x32], 1t[1]
          NODE:   1 x Relu -SIG- 1t[3x32]
    [GraphBuilder-KHK.to_onnx] make_model
    [GraphBuilder-KHK._build_initializers] start with 4 initializers, large_model=False, external_threshold=1024
    [GraphBuilder-KHK._build_initializers] switch low/high order
    [GraphBuilder-KHK._build_initializers] TensorProto-p_layers_0_weight:1[(32, 10)]
    [GraphBuilder-KHK._build_initializers] TensorProto-p_layers_0_bias:1[(32,)]
    [GraphBuilder-KHK._build_initializers] TensorProto-p_layers_2_weight:1[(1, 32)]
    [GraphBuilder-KHK._build_initializers] TensorProto-p_layers_2_bias:1[(1,)]
    [GraphBuilder-KHK._build_initializers] done in 1.2999953469261527e-06s with 4 initializers, 0 large initializers

Select the pattern to use

Class OptimizationOptions is used to enable or disable patterns.

<<<

import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(
    onx,
    infer_shapes=True,
    optimization_options=OptimizationOptions(
        patterns="TransposeTranspose,TransposeMatMul", verbose=1
    ),
)
opt_onx = gr.to_onnx(optimize=True)

>>>

    [GraphBuilderPatternOptimization.optimize] start with 5 nodes and 2 patterns, priorities=[0, 1]
    [GraphBuilderPatternOptimization.optimize] use pattern   1/2 - P0 - TransposeTransposePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   2/2 - P1 - TransposeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] iteration 0: 5 nodes, priority=0
    [GraphBuilderPatternOptimization.optimize] increase priority to 1
    [GraphBuilderPatternOptimization.optimize] iteration 1: 5 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*TransposeMatMulPattern - time=0.000 | max_time=TransposeMatMulPattern:0.000
    [GraphBuilderPatternOptimization.optimize] iteration 2: 3 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] done after 3 iterations with 3 nodes in 0.001
    [GraphBuilder] done with 3 nodes in 0.001
        STAT apply_TransposeMatMulPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00025289999757660553
        STAT build_for_pattern +0 -0 #it=3 maxmatch=0 i=0 - time=8.569999772589654e-05
        STAT check_A +0 -0 #it=0 maxmatch=0 i=0 - time=2.4300003133248538e-05
        STAT check_B +0 -0 #it=0 maxmatch=0 i=0 - time=2.559999848017469e-05
        STAT check_C +0 -0 #it=0 maxmatch=0 i=0 - time=1.5099998563528061e-05
        STAT check_F +0 -0 #it=0 maxmatch=0 i=0 - time=8.800001523923129e-06
        STAT check_G +0 -0 #it=0 maxmatch=0 i=0 - time=7.800001185387373e-06
        STAT check_pattern_A +0 -0 #it=1 maxmatch=0 i=0 - time=2.2899999748915434e-05
        STAT check_pattern_B +0 -0 #it=3 maxmatch=0 i=0 - time=3.760000254260376e-05
        STAT match_TransposeMatMulPattern +0 -0 #it=2 maxmatch=2 i=2 - time=7.900000491645187e-05
        STAT match_TransposeTransposePattern +0 -0 #it=3 maxmatch=0 i=0 - time=5.999999848427251e-05
        STAT pattern_optimization +0 -2 #it=0 maxmatch=0 i=0 - time=0.0008443999977316707
        STAT remove_identity_nodes +0 -0 #it=3 maxmatch=0 i=0 - time=0.00010399999882793054
        STAT remove_unused +0 -0 #it=0 maxmatch=0 i=0 - time=6.520000169984996e-05
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--
         INPUT:   1 x 1t
        OUTPUT:   1 x 1t
          INIT:   4 x 1t
          NODE:   2 x Gemm
          NODE:   1 x Relu
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--DETAILED--
         INPUT:   1 x 1t[3x10]
        OUTPUT:   1 x 1t[3x1]
          INIT:   1 x 1t[1]
          INIT:   1 x 1t[1x32]
          INIT:   1 x 1t[32]
          INIT:   1 x 1t[32x10]
          NODE:   1 x Gemm -SIG- 1t[3x10], 1t[32x10], 1t[32]
          NODE:   1 x Gemm -SIG- 1t[3x32], 1t[1x32], 1t[1]
          NODE:   1 x Relu -SIG- 1t[3x32]

There exist some predefined lists of patterns:

  • default: includes all patterns which only use standard onnx operators.

  • onnxruntime: patterns specific to onnxruntime. The final model may be executed by onnxruntime, and possibly only by onnxruntime, as these patterns may introduce operators from Supported Operators and Data Types.

<<<

import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(
    onx,
    infer_shapes=True,
    optimization_options=OptimizationOptions(patterns="default+onnxruntime", verbose=1),
)
opt_onx = gr.to_onnx(optimize=True)

>>>

    [GraphBuilderPatternOptimization.optimize] start with 5 nodes and 33 patterns, priorities=[0, 1, 2, 3]
    [GraphBuilderPatternOptimization.optimize] use pattern   1/33 - P0 - CastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   2/33 - P0 - ExpandPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   3/33 - P0 - IdentityPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   4/33 - P0 - ReshapeReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   5/33 - P0 - SameChildrenPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   6/33 - P0 - TransposeTransposePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   7/33 - P0 - UnsqueezeUnsqueezePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   8/33 - P1 - CastCastBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern   9/33 - P1 - CastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  10/33 - P1 - ComputationCastOpCastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  11/33 - P1 - DivByMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  12/33 - P1 - ExpandBroadcastPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  13/33 - P1 - ExpandSwapPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  14/33 - P1 - MatMulReshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  15/33 - P1 - MulMulMulScalarPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  16/33 - P1 - ReduceReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  17/33 - P1 - ReduceSumNormalizePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  18/33 - P1 - Reshape2Of3Pattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  19/33 - P1 - ReshapeMatMulReshapePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  20/33 - P1 - ReshapeReshapeBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  21/33 - P1 - RotaryConcatPartPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  22/33 - P1 - SimplifiedLayerNormalizationPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  23/33 - P1 - SlicesSplitPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  24/33 - P1 - SoftmaxGradPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  25/33 - P1 - Sub1MulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  26/33 - P1 - SwitchOrderBinaryPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  27/33 - P1 - TransposeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  28/33 - P1 - TransposeReshapeMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  29/33 - P1 - UnsqueezeEqualPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  30/33 - P2 - FusedMatMulDivPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  31/33 - P2 - FusedMatMulPattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  32/33 - P3 - FusedMatMulTransposePattern()
    [GraphBuilderPatternOptimization.optimize] use pattern  33/33 - P3 - FusedMatMulx2Pattern()
    [GraphBuilderPatternOptimization.optimize] iteration 0: 5 nodes, priority=0
    [GraphBuilderPatternOptimization.optimize] increase priority to 1
    [GraphBuilderPatternOptimization.optimize] iteration 1: 5 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] applies 2 matches, 2*TransposeMatMulPattern - time=0.000 | max_time=TransposeMatMulPattern:0.000
    [GraphBuilderPatternOptimization.optimize] iteration 2: 3 nodes, priority=1
    [GraphBuilderPatternOptimization.optimize] increase priority to 2
    [GraphBuilderPatternOptimization.optimize] iteration 3: 3 nodes, priority=2
    [GraphBuilderPatternOptimization.optimize] increase priority to 3
    [GraphBuilderPatternOptimization.optimize] iteration 4: 3 nodes, priority=3
    [GraphBuilderPatternOptimization.optimize] done after 5 iterations with 3 nodes in 0.001
    [GraphBuilder] done with 3 nodes in 0.002
        STAT apply_TransposeMatMulPattern +2 -4 #it=1 maxmatch=1 i=2 - time=0.00023330000112764537
        STAT build_for_pattern +0 -0 #it=5 maxmatch=0 i=0 - time=0.00015189999976428226
        STAT check_A +0 -0 #it=0 maxmatch=0 i=0 - time=1.5300000086426735e-05
        STAT check_B +0 -0 #it=0 maxmatch=0 i=0 - time=1.0599993402138352e-05
        STAT check_C +0 -0 #it=0 maxmatch=0 i=0 - time=1.0600000678095967e-05
        STAT check_F +0 -0 #it=0 maxmatch=0 i=0 - time=9.700001101009548e-06
        STAT check_G +0 -0 #it=0 maxmatch=0 i=0 - time=8.09999619377777e-06
        STAT check_pattern_A +0 -0 #it=1 maxmatch=0 i=0 - time=2.2600004740525037e-05
        STAT check_pattern_B +0 -0 #it=3 maxmatch=0 i=0 - time=3.220000508008525e-05
        STAT match_CastCastBinaryPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.6699996194802225e-05
        STAT match_CastOpCastPattern +0 -0 #it=4 maxmatch=0 i=0 - time=2.160000440198928e-05
        STAT match_CastPattern +0 -0 #it=5 maxmatch=0 i=0 - time=3.35000004270114e-05
        STAT match_ComputationCastOpCastPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.569999585626647e-05
        STAT match_DivByMulScalarPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.569999585626647e-05
        STAT match_ExpandBroadcastPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.5199999324977398e-05
        STAT match_ExpandPattern +0 -0 #it=5 maxmatch=0 i=0 - time=2.0999992557335645e-05
        STAT match_ExpandSwapPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.4199998986441642e-05
        STAT match_FusedMatMulDivPattern +0 -0 #it=2 maxmatch=0 i=0 - time=8.500006515532732e-06
        STAT match_FusedMatMulPattern +0 -0 #it=2 maxmatch=0 i=0 - time=9.2999980552122e-06
        STAT match_FusedMatMulTransposePattern +0 -0 #it=1 maxmatch=0 i=0 - time=4.900000931229442e-06
        STAT match_FusedMatMulx2Pattern +0 -0 #it=1 maxmatch=0 i=0 - time=5.399997462518513e-06
        STAT match_IdentityPattern +0 -0 #it=5 maxmatch=0 i=0 - time=4.889999399892986e-05
        STAT match_MatMulReshape2Of3Pattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.6599988157395273e-05
        STAT match_MulMulMulScalarPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.5899997379165143e-05
        STAT match_ReduceReshapePattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.7599995771888644e-05
        STAT match_ReduceSumNormalizePattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.6600002709310502e-05
        STAT match_Reshape2Of3Pattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.6900012269616127e-05
        STAT match_ReshapeMatMulReshapePattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.5900004655122757e-05
        STAT match_ReshapeReshapeBinaryPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.529999281046912e-05
        STAT match_ReshapeReshapePattern +0 -0 #it=5 maxmatch=0 i=0 - time=1.999999221879989e-05
        STAT match_RotaryConcatPartPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.4899997040629387e-05
        STAT match_SameChildrenPattern +0 -0 #it=5 maxmatch=0 i=0 - time=4.6600005589425564e-05
        STAT match_SimplifiedLayerNormalizationPattern +0 -0 #it=4 maxmatch=2 i=0 - time=1.6699996194802225e-05
        STAT match_SlicesSplitPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.68999977177009e-05
        STAT match_SoftmaxGradPattern +0 -0 #it=4 maxmatch=2 i=0 - time=1.5600002370774746e-05
        STAT match_Sub1MulPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.5499994333367795e-05
        STAT match_SwitchOrderBinaryPattern +0 -0 #it=4 maxmatch=0 i=0 - time=1.5700003132224083e-05
        STAT match_TransposeMatMulPattern +0 -0 #it=4 maxmatch=2 i=2 - time=9.029999637277797e-05
        STAT match_TransposeReshapeMatMulPattern +0 -0 #it=4 maxmatch=2 i=0 - time=1.7300000763498247e-05
        STAT match_TransposeTransposePattern +0 -0 #it=5 maxmatch=2 i=0 - time=3.099998866673559e-05
        STAT match_UnsqueezeEqualPattern +0 -0 #it=4 maxmatch=2 i=0 - time=1.4699988241773099e-05
        STAT match_UnsqueezeUnsqueezePattern +0 -0 #it=5 maxmatch=2 i=0 - time=1.9699989934451878e-05
        STAT pattern_optimization +0 -2 #it=0 maxmatch=0 i=0 - time=0.0016413999983342364
        STAT remove_identity_nodes +0 -0 #it=3 maxmatch=0 i=0 - time=8.019999950192869e-05
        STAT remove_unused +0 -0 #it=0 maxmatch=0 i=0 - time=5.4799995268695056e-05
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--
         INPUT:   1 x 1t
        OUTPUT:   1 x 1t
          INIT:   4 x 1t
          NODE:   2 x Gemm
          NODE:   1 x Relu
    --MODEL: 3 nodes, 1 inputs, 1 outputs, 4 initializers--DETAILED--
         INPUT:   1 x 1t[3x10]
        OUTPUT:   1 x 1t[3x1]
          INIT:   1 x 1t[1]
          INIT:   1 x 1t[1x32]
          INIT:   1 x 1t[32]
          INIT:   1 x 1t[32x10]
          NODE:   1 x Gemm -SIG- 1t[3x10], 1t[32x10], 1t[32]
          NODE:   1 x Gemm -SIG- 1t[3x32], 1t[1x32], 1t[1]
          NODE:   1 x Relu -SIG- 1t[3x32]
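
The verbose output above suggests that patterns are grouped by priority: the optimizer iterates at a given priority, using the patterns whose priority is not higher than the current one, and moves to the next priority once an iteration finds nothing to rewrite. The following is a toy sketch of that loop under this assumption. It works on a plain list of operator names, and every name in it (drop_identity, fuse_transpose_matmul, optimize) is hypothetical; it is not the implementation of GraphBuilderPatternOptimization.

```python
from typing import Callable, List, Optional, Tuple

Rewrite = Callable[[List[str]], Optional[List[str]]]
Pattern = Tuple[str, int, Rewrite]  # (name, priority, rewrite function)


def drop_identity(nodes: List[str]) -> Optional[List[str]]:
    """Toy priority-0 pattern: removes every Identity node."""
    kept = [n for n in nodes if n != "Identity"]
    return kept if len(kept) < len(nodes) else None


def fuse_transpose_matmul(nodes: List[str]) -> Optional[List[str]]:
    """Toy priority-1 pattern: replaces each Transpose followed by MatMul with Gemm."""
    out, i, found = [], 0, False
    while i < len(nodes):
        if nodes[i : i + 2] == ["Transpose", "MatMul"]:
            out.append("Gemm")
            i += 2
            found = True
        else:
            out.append(nodes[i])
            i += 1
    return out if found else None


def optimize(nodes: List[str], patterns: List[Pattern]):
    """Iterates at the current priority, raises it when nothing changes."""
    priorities = sorted({prio for _, prio, _ in patterns})
    level, iteration, log = 0, 0, []
    while level < len(priorities):
        current = priorities[level]
        log.append((iteration, len(nodes), current))
        changed = False
        for _name, prio, rewrite in patterns:
            if prio > current:
                continue  # pattern not enabled yet at this priority
            rewritten = rewrite(nodes)
            if rewritten is not None:
                nodes, changed = rewritten, True
        iteration += 1
        if not changed:
            level += 1  # nothing matched: move to the next priority
    return nodes, log


patterns = [
    ("DropIdentity", 0, drop_identity),
    ("TransposeMatMul", 1, fuse_transpose_matmul),
]
toy_nodes = ["Transpose", "MatMul", "Relu", "Transpose", "MatMul"]
optimized, log = optimize(toy_nodes, patterns)
print(optimized)
for it, n, prio in log:
    print(f"iteration {it}: {n} nodes, priority={prio}")
```

With this toy input, the loop runs three iterations, the same count as the two-pattern log earlier on this page: one at priority 0 with no match, one at priority 1 which rewrites the graph, and a last one at priority 1 which finds nothing left to do.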

Statistics

The optimizer returns statistics which can be used to see when a pattern is applied and how long it takes.

<<<

import pandas
import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(
    onx,
    infer_shapes=True,
    optimization_options=OptimizationOptions(patterns="default"),
)
stat = gr.optimize()

print(pandas.DataFrame(stat))

>>>

                      pattern   time_in  removed  added  iteration  instances  match_index
    0                 check_A  0.000016      NaN    NaN        NaN        NaN          NaN
    1   remove_identity_nodes  0.000024      0.0    0.0        NaN        NaN          NaN
    2                 check_B  0.000011      NaN    NaN        NaN        NaN          NaN
    3           remove_unused  0.000030      0.0    NaN        NaN        NaN          NaN
    4                 check_C  0.000011      NaN    NaN        NaN        NaN          NaN
    ..                    ...       ...      ...    ...        ...        ...          ...
    78      build_for_pattern  0.000023      NaN    NaN        2.0        NaN          NaN
    79   pattern_optimization  0.001070      2.0    NaN        NaN        NaN          NaN
    80                check_F  0.000008      NaN    NaN        NaN        NaN          NaN
    81          remove_unused  0.000020      0.0    NaN        NaN        NaN          NaN
    82                check_G  0.000008      NaN    NaN        NaN        NaN          NaN
    
    [83 rows x 7 columns]

These statistics can be aggregated by pattern:

<<<

import pandas
import onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

onx = onnx.load("temp_doc_mlp.onnx")

gr = GraphBuilder(
    onx,
    infer_shapes=True,
    optimization_options=OptimizationOptions(patterns="default"),
)
stat = gr.optimize()

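df = pandas.DataFrame(stat)
# Missing counters show up as NaN: replace them with 0 and cast to int,
# leaving the timing and pattern name columns untouched.
for c in df.columns:
    if "time" not in c and "pattern" not in c:
        df[c] = df[c].fillna(0).astype(int)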
aggs = {
    "time_in": "sum",
    "added": "sum",
    "removed": "sum",
    "iteration": "max",
    "match_index": "max",
    "instances": "sum",
}
print(df.groupby("pattern").agg(aggs))

>>>

                                          time_in  added  removed  iteration  match_index  instances
    pattern                                                                                         
    apply_TransposeMatMulPattern         0.000234      2        4          1            1          2
    build_for_pattern                    0.000071      0        0          2            0          0
    check_A                              0.000015      0        0          0            0          0
    check_B                              0.000010      0        0          0            0          0
    check_C                              0.000010      0        0          0            0          0
    check_F                              0.000008      0        0          0            0          0
    check_G                              0.000007      0        0          0            0          0
    check_pattern_A                      0.000022      0        0          1            0          0
    check_pattern_B                      0.000030      0        0          2            0          0
    match_CastCastBinaryPattern          0.000009      0        0          2            0          0
    match_CastOpCastPattern              0.000012      0        0          2            0          0
    match_CastPattern                    0.000024      0        0          2            0          0
    match_ComputationCastOpCastPattern   0.000008      0        0          2            0          0
    match_DivByMulScalarPattern          0.000008      0        0          2            0          0
    match_ExpandBroadcastPattern         0.000008      0        0          2            0          0
    match_ExpandPattern                  0.000014      0        0          2            0          0
    match_ExpandSwapPattern              0.000007      0        0          2            0          0
    match_IdentityPattern                0.000036      0        0          2            0          0
    match_MatMulReshape2Of3Pattern       0.000009      0        0          2            0          0
    match_MulMulMulScalarPattern         0.000008      0        0          2            0          0
    match_ReduceReshapePattern           0.000009      0        0          2            0          0
    match_ReduceSumNormalizePattern      0.000009      0        0          2            0          0
    match_Reshape2Of3Pattern             0.000009      0        0          2            0          0
    match_ReshapeMatMulReshapePattern    0.000008      0        0          2            0          0
    match_ReshapeReshapeBinaryPattern    0.000009      0        0          2            0          0
    match_ReshapeReshapePattern          0.000013      0        0          2            0          0
    match_RotaryConcatPartPattern        0.000008      0        0          2            0          0
    match_SameChildrenPattern            0.000027      0        0          2            0          0
    match_SlicesSplitPattern             0.000009      0        0          2            0          0
    match_Sub1MulPattern                 0.000008      0        0          2            0          0
    match_SwitchOrderBinaryPattern       0.000008      0        0          2            0          0
    match_TransposeMatMulPattern         0.000054      0        0          2            2          2
    match_TransposeReshapeMatMulPattern  0.000009      0        0          2            2          0
    match_TransposeTransposePattern      0.000022      0        0          2            2          0
    match_UnsqueezeEqualPattern          0.000008      0        0          2            2          0
    match_UnsqueezeUnsqueezePattern      0.000013      0        0          2            2          0
    pattern_optimization                 0.000982      0        2          0            0          0
    remove_identity_nodes                0.000072      0        0          2            0          0
    remove_unused                        0.000046      0        0          0            0          0
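
When pandas is not available, a similar aggregation can be sketched with the standard library only. The sketch below assumes the statistics are a list of dictionaries using the column names shown above (pattern, time_in, added, removed), where a counter may be missing or None. The function aggregate and the sample rows are made up for the illustration; they are not produced by the library.

```python
from collections import defaultdict
from typing import Any, Dict, List


def aggregate(stat: List[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Sums time_in, added and removed for every pattern name."""
    totals: Dict[str, Dict[str, float]] = defaultdict(
        lambda: {"time_in": 0.0, "added": 0, "removed": 0}
    )
    for row in stat:
        agg = totals[row["pattern"]]
        agg["time_in"] += row.get("time_in") or 0.0
        # A missing or None counter is treated as 0.
        agg["added"] += int(row.get("added") or 0)
        agg["removed"] += int(row.get("removed") or 0)
    return dict(totals)


# Hypothetical rows, only there to exercise the function.
sample = [
    {"pattern": "match_A", "time_in": 0.5},
    {"pattern": "apply_A", "time_in": 0.25, "added": 2, "removed": 4},
    {"pattern": "match_A", "time_in": 0.25, "removed": None},
]
summary = aggregate(sample)
# Slowest entries first.
for name, agg in sorted(summary.items(), key=lambda kv: -kv[1]["time_in"]):
    print(name, agg)
```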

Shape inference

The optimizers need to know the shapes to make sure they can rewrite some nodes without producing a model which returns different results. If the shape information is missing, some patterns cannot be sure a rewrite is valid and they do not match.

This information can be built by running shape inference on the onnx models. That’s what is done in the previous examples. However, the best case is when this information comes from torch.

Function to_onnx converts a torch model into ONNX. While doing so, it stores the shape information coming from torch. There is no need to run shape inference on the onnx model it generates before optimizing it.
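
To illustrate why a missing shape prevents a match, here is a minimal sketch of a check a pattern could make before removing a Reshape node: the rewrite is only safe when both shapes are known and equal. The function reshape_is_noop and the dictionary of shapes are hypothetical and only stand for the kind of information the optimizer keeps; they are not part of the library.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def reshape_is_noop(shapes: Dict[str, Shape], input_name: str, output_name: str) -> bool:
    """Tells if a Reshape can be dropped, refusing to match on unknown shapes."""
    before = shapes.get(input_name)
    after = shapes.get(output_name)
    if before is None or after is None:
        # Without shape information, the rewrite cannot be proven safe.
        return False
    return before == after


known = {"x": (3, 10), "y": (3, 10), "z": (30,)}
print(reshape_is_noop(known, "x", "y"))  # same shape
print(reshape_is_noop(known, "x", "z"))  # different shape
print(reshape_is_noop(known, "x", "unknown"))  # shape is missing
```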

Available Patterns and API

All patterns may be found at Onnx (default) Patterns and Ort Patterns.

When writing a pattern, walking along the graph or checking the shape is very common. Class GraphBuilderPatternOptimization provides the following methods.

Opsets

Patterns must rewrite using the nodes of the opset defined in the model.
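
One known case where the opset matters is Unsqueeze: since opset 13, the axes are given as a second input instead of an attribute. The sketch below shows how a rewriting function could take this into account. It describes a node with a plain dictionary, and the helper describe_unsqueeze is hypothetical; it is not the library's API.

```python
from typing import Any, Dict, List


def describe_unsqueeze(opset: int, data: str, axes: List[int], output: str) -> Dict[str, Any]:
    """Describes an Unsqueeze node valid for the given opset."""
    if opset >= 13:
        # axes is an input: it must be stored as a constant in the graph.
        axes_name = f"{output}_axes"
        return {
            "op_type": "Unsqueeze",
            "inputs": [data, axes_name],
            "outputs": [output],
            "attributes": {},
            "constants": {axes_name: axes},
        }
    # Before opset 13, axes is an attribute of the node.
    return {
        "op_type": "Unsqueeze",
        "inputs": [data],
        "outputs": [output],
        "attributes": {"axes": axes},
        "constants": {},
    }


node_old = describe_unsqueeze(11, "x", [0], "y")
node_new = describe_unsqueeze(18, "x", [0], "y")
print(node_old["inputs"], node_old["attributes"])
print(node_new["inputs"], node_new["constants"])
```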

Shapes, Types

Constants

  • is_constant: tells if a node is a constant (it may be a constant, an initializer or any value built on other constants)

  • is_constant_scalar: checks that a constant is a scalar and compares its value to a number

  • get_computed_constant: returns the constant, computing it first if it is built from other constants

  • get_attribute: returns an attribute of a node
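
The idea that a value built only from other constants is itself a constant can be sketched with a short recursive function. This toy version assumes the graph is described by a set of initializer names and a dictionary mapping every node output to the inputs of the node producing it. It ignores nodes with random outputs, and the function below is not the method is_constant of the library.

```python
from typing import Dict, List, Set


def is_constant(name: str, initializers: Set[str], producers: Dict[str, List[str]]) -> bool:
    """Tells if a result only depends on initializers and Constant nodes."""
    if name in initializers:
        return True
    inputs = producers.get(name)
    if inputs is None:
        # Not produced by any node: this is a graph input.
        return False
    # A Constant node has no input, all([]) is True.
    return all(is_constant(i, initializers, producers) for i in inputs)


initializers = {"w"}
producers = {
    "c": [],  # output of a Constant node
    "wt": ["w"],  # Transpose(w)
    "s": ["wt", "c"],  # Mul(wt, c)
    "y": ["x", "wt"],  # MatMul(x, wt), x is a graph input
}
for result in ["w", "c", "wt", "s", "y", "x"]:
    print(result, is_constant(result, initializers, producers))
```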

Graph

Nodes

  • make_node: creates a node without adding it to the graph

  • make_node_check_opset: creates a node without adding it to the graph, deals with some constraints related to opset version