101: ONNX Model Optimization Based on Pattern Rewriting

This example shows how to optimize an ONNX graph with pattern-based rewriting. The graph is the backward graph obtained by running a dummy llama model.
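
To make the idea of pattern rewriting concrete, here is a minimal sketch on a toy graph. It does not use the `GraphBuilder` API: the `Node` tuple and the function `fuse_double_transpose` are hypothetical names invented for this illustration. The rewrite replaces two consecutive transposes whose permutations cancel out with a single `Identity`, assuming the intermediate result has exactly one consumer and is not a graph output.

```python
from typing import List, NamedTuple, Tuple


class Node(NamedTuple):
    """Toy node (hypothetical): operator type, input names, output name, permutation."""

    op_type: str
    inputs: Tuple[str, ...]
    output: str
    perm: Tuple[int, ...] = ()


def fuse_double_transpose(nodes: List[Node]) -> List[Node]:
    """Replace Transpose(Transpose(x)) with Identity(x) when the permutations cancel."""
    producers = {n.output: n for n in nodes}
    n_consumers = {}
    for n in nodes:
        for name in n.inputs:
            n_consumers[name] = n_consumers.get(name, 0) + 1

    dropped = set()  # outputs of nodes removed by the rewrite
    replaced = {}  # output name -> replacement node
    for second in nodes:
        if second.op_type != "Transpose":
            continue
        first = producers.get(second.inputs[0])
        if first is None or first.op_type != "Transpose":
            continue
        # Skip nodes already rewritten, and intermediates used elsewhere.
        if first.output in replaced or first.output in dropped:
            continue
        if n_consumers.get(first.output, 0) != 1:
            continue
        # Applying perm p, then perm q, is the same as applying p[q[i]].
        composed = tuple(first.perm[i] for i in second.perm)
        if composed != tuple(range(len(composed))):
            continue
        dropped.add(first.output)
        replaced[second.output] = Node("Identity", first.inputs, second.output)
    return [replaced.get(n.output, n) for n in nodes if n.output not in dropped]


# Three nodes taken from the listing of the model below.
nodes = [
    Node("MatMul", ("t_8", "input33"), "mm_8"),
    Node("Transpose", ("mm_8",), "t_9", (1, 0)),
    Node("Transpose", ("t_9",), "output_10", (1, 0)),
]
optimized = fuse_double_transpose(nodes)
print([n.op_type for n in optimized])  # ['MatMul', 'Identity']
```

A real optimizer runs many such patterns repeatedly until none of them matches anymore, and a separate pass can then remove the leftover `Identity` nodes.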

A model

import os
import onnx
import pandas
from experimental_experiment.helpers import pretty_onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.xbuilder.graph_builder import (
    GraphBuilder,
    OptimizationOptions,
)

filename = (
    os.path.join(os.path.dirname(__file__), "data", "dort-c-custom__1.onnx")
    if "__file__" in globals()
    else "data/dort-c-custom__1.onnx"
)
proto = onnx.load(filename)

print(f"number of nodes: {len(proto.graph.node)}")


print(pretty_onnx(proto))
number of nodes: 215
opset: domain='' version=18
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_2' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_3' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_4' type=float32 shape=(1,) -- array([0.], dtype=float32)
init: name='init1_s_' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_10' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_11' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_2' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_3' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_4' type=float32 shape=() -- array([1.], dtype=float32)
init: name='init1_s_5' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_6' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_7' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_8' type=float32 shape=() -- array([22.627417], dtype=float32)
init: name='init1_s_9' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_-12' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_-13' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])
init: name='init7_s1_02' type=int64 shape=(1,) -- array([0])
init: name='init7_s1_1024' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_10242' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_10243' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_2' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_22' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_23' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_256' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2562' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2563' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2564' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_3' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_32' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_33' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_34' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_512' type=int64 shape=(1,) -- array([512])
init: name='init7_s1_5122' type=int64 shape=(1,) -- array([512])
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_0_12' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_0_13' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_1024_10242' type=int64 shape=(2,) -- array([1024, 1024])
init: name='init7_s2_2048_1024' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10242' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10243' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10244' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10245' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10246' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10247' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s3_2_1024_1024' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102410' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102411' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102412' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102413' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102414' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102415' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10242' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10243' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10245' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10246' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10247' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10248' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10249' type=int64 shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_4_1024_1024' type=int64 shape=(3,) -- array([   4, 1024, 1024])
init: name='init7_s3_4_1024_512' type=int64 shape=(3,) -- array([   4, 1024,  512])
init: name='init7_s4_2_1024_2_512' type=int64 shape=(4,) -- array([   2, 1024,    2,  512])
init: name='init7_s4_2_2_1024_1024' type=int64 shape=(4,) -- array([   2,    2, 1024, 1024])
init: name='init7_s4_2_2_1024_256' type=int64 shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2562' type=int64 shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2563' type=int64 shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2564' type=int64 shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_512' type=int64 shape=(4,) -- array([   2,    2, 1024,  512])
init: name='init7_s4_2_2_1024_5122' type=int64 shape=(4,) -- array([   2,    2, 1024,  512])
init: name='init7_s4_2_2_512_1024' type=int64 shape=(4,) -- array([   2,    2,  512, 1024])
init: name='init7_s_-1' type=int64 shape=() -- array([-1])
Constant(value_float=0) -> output_11
Mul(input37, input2) -> _onx_mul0
  Cast(_onx_mul0, to=1) -> mul_13
    Mul(mul_13, input34) -> _onx_mul03
      Cast(_onx_mul03, to=1) -> mul_15
        ReduceSum(mul_15, init7_s1_2, keepdims=1) -> sum_2
          Mul(sum_2, init1_s_) -> _onx_mul05
            Cast(_onx_mul05, to=1) -> mul_17
Mul(input37, input36) -> _onx_mul02
  Cast(_onx_mul02, to=1) -> mul_14
    ReduceSum(mul_14, init7_s2_0_1, keepdims=1) -> sum_1
      Reshape(sum_1, init7_s1_1024) -> output_2
    Mul(mul_13, input35) -> _onx_mul04
      Cast(_onx_mul04, to=1) -> mul_16
Pow(input35, init1_s1_) -> pow_4
  Mul(mul_17, pow_4) -> _onx_mul06
    Cast(_onx_mul06, to=1) -> mul_18
      Expand(mul_18, init7_s3_2_1024_1024) -> expand_5
        Div(expand_5, init1_s_2) -> _onx_div0
          Cast(_onx_div0, to=1) -> div_1
Mul(input34, init1_s_3) -> _onx_mul07
  Cast(_onx_mul07, to=1) -> mul_19
    Mul(div_1, mul_19) -> _onx_mul08
      Cast(_onx_mul08, to=1) -> mul_20
        Add(mul_16, mul_20) -> add_8
          Reshape(add_8, init7_s2_2048_1024) -> view_22
            Transpose(view_22, perm=[1,0]) -> t_8
              MatMul(t_8, input33) -> mm_8
                Transpose(mm_8, perm=[1,0]) -> t_9
                  Transpose(t_9, perm=[1,0]) -> output_10
Transpose(input32, perm=[1,0]) -> t_10
  MatMul(view_22, t_10) -> mm_9
    Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
      Mul(view_23, input28) -> _onx_mul09
        Cast(_onx_mul09, to=1) -> mul_21
          Reshape(mul_21, init7_s2_2048_10242) -> view_24
            Transpose(view_24, perm=[1,0]) -> t_12
              MatMul(t_12, input30) -> mm_10
                Transpose(mm_10, perm=[1,0]) -> t_13
                  Transpose(t_13, perm=[1,0]) -> output_9
      Mul(view_23, input31) -> _onx_mul010
        Cast(_onx_mul010, to=1) -> mul_22
Transpose(input29, perm=[1,0]) -> t_14
  MatMul(view_24, t_14) -> mm_11
    Reshape(mm_11, init7_s3_2_1024_10243) -> view_25
Sigmoid(input27) -> sigmoid
ConstantOfShape(init7_s3_2_1024_10245, value=[1.0]) -> fill
  Sub(fill, sigmoid) -> sub
    Mul(input27, sub) -> _onx_mul011
      Cast(_onx_mul011, to=1) -> mul_23
        Add(mul_23, init1_s_4) -> add_9
  Mul(sigmoid, add_9) -> _onx_mul012
    Cast(_onx_mul012, to=1) -> mul_24
      Mul(mul_22, mul_24) -> _onx_mul013
        Cast(_onx_mul013, to=1) -> mul_25
          Reshape(mul_25, init7_s2_2048_10243) -> view_26
            Transpose(view_26, perm=[1,0]) -> t_16
              MatMul(t_16, input26) -> mm_12
                Transpose(mm_12, perm=[1,0]) -> t_17
                  Transpose(t_17, perm=[1,0]) -> output_8
Transpose(input25, perm=[1,0]) -> t_18
  MatMul(view_26, t_18) -> mm_13
    Reshape(mm_13, init7_s3_2_1024_10246) -> view_27
      Add(view_25, view_27) -> add_10
        Mul(add_10, input1) -> _onx_mul014
          Cast(_onx_mul014, to=1) -> mul_26
            Mul(mul_26, input22) -> _onx_mul016
              Cast(_onx_mul016, to=1) -> mul_28
                ReduceSum(mul_28, init7_s1_22, keepdims=1) -> sum_4
                  Mul(sum_4, init1_s_5) -> _onx_mul018
                    Cast(_onx_mul018, to=1) -> mul_30
        Mul(add_10, input24) -> _onx_mul015
          Cast(_onx_mul015, to=1) -> mul_27
            ReduceSum(mul_27, init7_s2_0_12, keepdims=1) -> sum_3
              Reshape(sum_3, init7_s1_10242) -> output_1
            Mul(mul_26, input23) -> _onx_mul017
              Cast(_onx_mul017, to=1) -> mul_29
          Add(add_8, mul_29) -> add_11
Pow(input23, init1_s1_2) -> pow_6
  Mul(mul_30, pow_6) -> _onx_mul019
    Cast(_onx_mul019, to=1) -> mul_31
      Expand(mul_31, init7_s3_2_1024_10247) -> expand_6
        Div(expand_6, init1_s_6) -> _onx_div02
          Cast(_onx_div02, to=1) -> div_2
Mul(input22, init1_s_7) -> _onx_mul020
  Cast(_onx_mul020, to=1) -> mul_32
    Mul(div_2, mul_32) -> _onx_mul021
      Cast(_onx_mul021, to=1) -> mul_33
        Add(add_11, mul_33) -> add_12
          Reshape(add_12, init7_s2_2048_10244) -> view_29
            Transpose(view_29, perm=[1,0]) -> t_20
              MatMul(t_20, input21) -> mm_14
                Transpose(mm_14, perm=[1,0]) -> t_21
                  Transpose(t_21, perm=[1,0]) -> output_7
Transpose(input20, perm=[1,0]) -> t_22
  MatMul(view_29, t_22) -> mm_15
    Reshape(mm_15, init7_s3_2_1024_10248) -> view_30
      Reshape(view_30, init7_s4_2_1024_2_512) -> view_31
        Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
          Reshape(transpose_5, init7_s3_4_1024_512) -> _unsafe_view_3
Transpose(input18, perm=[0,2,1]) -> transpose_6
  MatMul(transpose_6, _unsafe_view_3) -> bmm_2
    Reshape(bmm_2, init7_s4_2_2_1024_512) -> view_32
      Add(input39, view_32) -> add_13
        Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
          Reshape(transpose_11, init7_s3_2_1024_10249) -> _unsafe_view_4
            Reshape(_unsafe_view_4, init7_s2_2048_10245) -> view_37
              Transpose(view_37, perm=[1,0]) -> t_24
                MatMul(t_24, input12) -> mm_16
                  Transpose(mm_16, perm=[1,0]) -> t_25
                    Transpose(t_25, perm=[1,0]) -> output_6
Transpose(input19, perm=[0,2,1]) -> transpose_7
  MatMul(_unsafe_view_3, transpose_7) -> bmm_3
    Reshape(bmm_3, init7_s4_2_2_1024_1024) -> view_33
      Cast(view_33, to=1) -> _onx_cast0
        Mul(_onx_cast0, input17) -> _onx_mul022
          ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
            Mul(input17, _onx_reducesum0) -> _onx_mul023
          Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
            Div(_softmax_backward_data, init1_s_8) -> div_3
              Reshape(div_3, init7_s3_4_1024_1024) -> view_34
Transpose(input15, perm=[0,2,1]) -> transpose_8
  MatMul(transpose_8, view_34) -> bmm_4
    Reshape(bmm_4, init7_s4_2_2_512_1024) -> view_35
      Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
        Add(input38, transpose_10) -> add_14
          Mul(add_14, input14) -> _onx_mul024
            Cast(_onx_mul024, to=1) -> mul_34
              Slice(mul_34, init7_s1_0, init7_s1_256, init7_s1_3) -> slice_10
                Neg(slice_10) -> neg_2
Transpose(input16, perm=[0,2,1]) -> transpose_9
  MatMul(view_34, transpose_9) -> bmm_5
    Reshape(bmm_5, init7_s4_2_2_1024_5122) -> view_36
      Mul(view_36, input14) -> _onx_mul026
        Cast(_onx_mul026, to=1) -> mul_36
          Slice(mul_36, init7_s1_02, init7_s1_2563, init7_s1_33) -> slice_12
            Neg(slice_12) -> neg_3
Slice(mul_34, init7_s1_2562, init7_s1_512, init7_s1_32) -> slice_11
ConstantOfShape(init7_s4_2_2_1024_256, value=[0.0]) -> _onx_constantofshape0
  Concat(_onx_constantofshape0, neg_2, axis=3) -> _onx_concat0
ConstantOfShape(init7_s4_2_2_1024_2562, value=[0.0]) -> _onx_constantofshape02
  Concat(slice_11, _onx_constantofshape02, axis=3) -> _onx_concat02
    Add(_onx_concat0, _onx_concat02) -> add_15
Mul(add_14, input13) -> _onx_mul025
  Cast(_onx_mul025, to=1) -> mul_35
    Add(add_15, mul_35) -> add_16
      Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
        Reshape(transpose_12, init7_s3_2_1024_102410) -> _unsafe_view_5
          Reshape(_unsafe_view_5, init7_s2_2048_10246) -> view_39
            Transpose(view_39, perm=[1,0]) -> t_28
              MatMul(t_28, input10) -> mm_18
                Transpose(mm_18, perm=[1,0]) -> t_29
                  Transpose(t_29, perm=[1,0]) -> output_5
          Slice(mul_36, init7_s1_2564, init7_s1_5122, init7_s1_34) -> slice_13
ConstantOfShape(init7_s4_2_2_1024_2563, value=[0.0]) -> _onx_constantofshape03
  Concat(_onx_constantofshape03, neg_3, axis=3) -> _onx_concat03
ConstantOfShape(init7_s4_2_2_1024_2564, value=[0.0]) -> _onx_constantofshape04
  Concat(slice_13, _onx_constantofshape04, axis=3) -> _onx_concat04
    Add(_onx_concat03, _onx_concat04) -> add_17
Mul(view_36, input13) -> _onx_mul027
  Cast(_onx_mul027, to=1) -> mul_37
    Add(add_17, mul_37) -> add_18
      Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
        Reshape(transpose_13, init7_s3_2_1024_102411) -> _unsafe_view_6
          Reshape(_unsafe_view_6, init7_s2_2048_10247) -> view_41
            Transpose(view_41, perm=[1,0]) -> t_32
              MatMul(t_32, input8) -> mm_20
                Transpose(mm_20, perm=[1,0]) -> t_33
                  Transpose(t_33, perm=[1,0]) -> output_4
Transpose(input11, perm=[1,0]) -> t_26
  MatMul(view_37, t_26) -> mm_17
    Reshape(mm_17, init7_s3_2_1024_102412) -> view_38
Transpose(input9, perm=[1,0]) -> t_30
  MatMul(view_39, t_30) -> mm_19
    Reshape(mm_19, init7_s3_2_1024_102413) -> view_40
      Add(view_38, view_40) -> add_19
Transpose(input7, perm=[1,0]) -> t_34
  MatMul(view_41, t_34) -> mm_21
    Reshape(mm_21, init7_s3_2_1024_102414) -> view_42
      Add(add_19, view_42) -> add_20
        Mul(add_20, input0) -> _onx_mul028
          Cast(_onx_mul028, to=1) -> mul_38
            Mul(mul_38, input4) -> _onx_mul030
              Cast(_onx_mul030, to=1) -> mul_40
                ReduceSum(mul_40, init7_s1_23, keepdims=1) -> sum_6
                  Mul(sum_6, init1_s_9) -> _onx_mul032
                    Cast(_onx_mul032, to=1) -> mul_42
        Mul(add_20, input6) -> _onx_mul029
          Cast(_onx_mul029, to=1) -> mul_39
            ReduceSum(mul_39, init7_s2_0_13, keepdims=1) -> sum_5
              Reshape(sum_5, init7_s1_10243) -> output_0
            Mul(mul_38, input5) -> _onx_mul031
              Cast(_onx_mul031, to=1) -> mul_41
          Add(add_12, mul_41) -> add_21
Pow(input5, init1_s1_3) -> pow_8
  Mul(mul_42, pow_8) -> _onx_mul033
    Cast(_onx_mul033, to=1) -> mul_43
      Expand(mul_43, init7_s3_2_1024_102415) -> expand_7
        Div(expand_7, init1_s_10) -> _onx_div03
          Cast(_onx_div03, to=1) -> div_4
Mul(input4, init1_s_11) -> _onx_mul034
  Cast(_onx_mul034, to=1) -> mul_44
    Mul(div_4, mul_44) -> _onx_mul035
      Cast(_onx_mul035, to=1) -> mul_45
        Add(add_21, mul_45) -> add_22
Equal(input3, init7_s_-1) -> eq_2
  Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
    Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
  ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> _onx_scatternd0
    Identity(_onx_scatternd0) -> output_3
Constant(value_float=0) -> output_12
Constant(value_float=0) -> output_13
Constant(value_float=0) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None
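
The listing is long, and a histogram of operator types is often easier to read. The sketch below counts operator types for any iterable of nodes exposing an `op_type` attribute, which is the case for `proto.graph.node`. The helper name `count_op_types` and the stand-in nodes are assumptions made so that the snippet runs without the model file.

```python
from collections import Counter
from types import SimpleNamespace


def count_op_types(nodes):
    """Return a Counter of operator types for nodes exposing `op_type`."""
    return Counter(node.op_type for node in nodes)


# Stand-in nodes so the snippet runs on its own;
# with the real model, pass `proto.graph.node` instead.
demo_nodes = [
    SimpleNamespace(op_type=t) for t in ["Mul", "Cast", "Mul", "Cast", "ReduceSum"]
]
histogram = count_op_types(demo_nodes)
for op_type, count in histogram.most_common():
    print(f"{op_type}: {count}")
```

On the model above, such a count would make it obvious how many `Cast` and `Transpose` nodes are candidates for removal.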

The same graph, rendered visually:

[Figure: plot optimize 101]

Optimization

gr = GraphBuilder(
    proto,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns="default",
        verbose=1,  # higher values make the pattern optimizer more verbose
    ),
)
stats = gr.optimize()
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
df.to_excel("plot_optimize.xlsx")
df
[GraphBuilder.optimize] start with 214 nodes
[GraphBuilder.optimize] #patterns=44
[GraphBuilderPatternOptimization.optimize] start with 214 nodes, 73 initializers, 44 patterns, priorities=[0, 1]
[GraphBuilderPatternOptimization.optimize] iteration 0: 214 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 48 matches, 37*CastPattern, 4*ReshapeReshapePattern, 7*TransposeTransposePattern - time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization.optimize] iteration 1: 159 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 159 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 28 matches, 3*MulMulMulScalarPattern, 3*ReduceReshapePattern, 2*Reshape2Of3Pattern, 1*ReshapeReshapeBinaryPattern, 2*MatMulReshape2Of3Pattern, 2*RotaryConcatPartPattern, 1*Sub1MulPattern, 14*TransposeMatMulPattern - time=0.012 | max_time=Sub1MulPattern:0.002
[GraphBuilderPatternOptimization.optimize] iteration 3: 131 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 12 matches, 3*ExpandBroadcastPattern, 1*ReshapeReshapeBinaryPattern, 2*MatMulAddPattern, 2*MatMulReshape2Of3Pattern, 2*SlicesSplitPattern, 2*TransposeReshapeMatMulPattern - time=0.006 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 4: 121 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 6 matches, 1*MatMulAddPattern, 3*SwitchOrderBinaryPattern, 2*TransposeReshapeMatMulPattern - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 5: 121 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MatMulAddPattern replaces ['Gemm', 'Add'] - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 6: 120 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] done after 7 iterations with 120 nodes in 0.117
[GraphBuilder.optimize] done with 117 nodes in 0.128
                     pattern   time_in  removed  added  iteration  instances  match_index
0                    check_A  0.001523      NaN    NaN        NaN        NaN          NaN
1      remove_identity_nodes  0.001873      0.0    0.0        NaN        NaN          NaN
2                    check_B  0.001219      NaN    NaN        NaN        NaN          NaN
3              remove_unused  0.002189      0.0    NaN        NaN        NaN          NaN
4                    check_C  0.000911      NaN    NaN        NaN        NaN          NaN
...                      ...       ...      ...    ...        ...        ...          ...
458  build_graph_for_pattern  0.000652      NaN    NaN        6.0        NaN          NaN
459     pattern_optimization  0.119236     94.0    NaN        NaN        NaN          NaN
460                  check_F  0.000235      NaN    NaN        NaN        NaN          NaN
461            remove_unused  0.000425      3.0    NaN        NaN        NaN          NaN
462                  check_G  0.000206      NaN    NaN        NaN        NaN          NaN

463 rows × 7 columns
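
The verbose log above reports, for each iteration, how many matches every pattern applied. The sketch below extracts these counts from one such line and checks that they add up to the announced total. The regular expressions are an assumption based on the lines shown above, not a documented log format, and `parse_applied_matches` is a hypothetical helper.

```python
import re

# One line copied from the log above.
LINE = (
    "[GraphBuilderPatternOptimization.optimize] applies 48 matches, "
    "37*CastPattern, 4*ReshapeReshapePattern, 7*TransposeTransposePattern "
    "- time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003"
)


def parse_applied_matches(line):
    """Return (total, {pattern: count}) from an 'applies N matches' log line."""
    total_match = re.search(r"applies (\d+) matches", line)
    if total_match is None:
        raise ValueError("not an 'applies N matches' line")
    # Each applied pattern is written as <count>*<PatternName>.
    counts = {name: int(n) for n, name in re.findall(r"(\d+)\*(\w+)", line)}
    return int(total_match.group(1)), counts


total, counts = parse_applied_matches(LINE)
print(total, counts)
print("consistent:", total == sum(counts.values()))
```

Lines that report a single match use a different layout (`[0]=MatchResult: ...`), so this sketch would return an empty dictionary of counts for them.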



Summary

for c in df.columns:
    if "time" not in c and "pattern" not in c:
        df[c] = df[c].fillna(0).astype(int)

aggs = {
    "time_in": "sum",
    "added": "sum",
    "removed": "sum",
    "iteration": "max",
    "match_index": "max",
    "instances": "sum",
}
print(df.groupby("pattern").agg(aggs))
                                  time_in  added  removed  iteration  match_index  instances
pattern
apply_CastPattern                0.003582     37       37          0           36         37
apply_ExpandBroadcastPattern     0.000236      3        6          3            2          3
apply_MatMulAddPattern           0.000383      5        8          5            5          4
apply_MatMulReshape2Of3Pattern   0.001935     10       12          3           10          4
apply_MulMulMulScalarPattern     0.001868      6        9          2            2          3
...                                   ...    ...      ...        ...          ...        ...
match_UnsqueezeEqualPattern      0.000465      0        0          6           28          0
match_UnsqueezeUnsqueezePattern  0.000840      0        0          6           48          0
pattern_optimization             0.119236      0       94          0            0          0
remove_identity_nodes            0.004658     44       88          2            0          0
remove_unused                    0.002614      0        3          0            0          0

[72 rows x 6 columns]

The total is:

diff = df["added"].sum() - df["removed"].sum()

print(f"number of removed nodes: {-diff}")
number of removed nodes: 191
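
As a cross-check, the net reduction can also be read directly from the log. The sketch below only redoes the arithmetic with the node counts printed earlier. It gives 97 rather than 191. A plausible explanation, which the output above does not confirm, is that the statistics table holds both per-pattern rows and aggregate rows such as `pattern_optimization`, so summing every row counts some removals more than once.

```python
# Node counts copied from the verbose log above.
start_nodes = 214  # "[GraphBuilder.optimize] start with 214 nodes"
after_patterns = 120  # "done after 7 iterations with 120 nodes"
final_nodes = 117  # "[GraphBuilder.optimize] done with 117 nodes"

# Same value as the `removed` column of the pattern_optimization row.
removed_by_patterns = start_nodes - after_patterns
# Same value as the `removed` column of the last remove_unused row.
removed_as_unused = after_patterns - final_nodes
net_removed = start_nodes - final_nodes

print(removed_by_patterns, removed_as_unused, net_removed)  # 94 3 97
```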

Conversion to ONNX.

optimized_proto = gr.to_onnx(optimize=False)
with open("plot_optimize_101.onnx", "wb") as f:
    f.write(optimized_proto.SerializeToString())

print(f"number of new nodes: {len(optimized_proto.graph.node)}")
number of new nodes: 117

It gives the following.

opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_)
init: name='init1_s1_2' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_2)
init: name='init1_s1_3' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_3)
init: name='init1_s1_4' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_4)
init: name='init1_s_' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_)
init: name='init1_s_4' type=float32 shape=() -- array([1.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_4)
init: name='init1_s_5' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_5)
init: name='init1_s_8' type=float32 shape=() -- array([22.627417], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_8)
init: name='init1_s_9' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_9)
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])         -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-1)
init: name='init7_s1_-12' type=int64 shape=(1,) -- array([-1])        -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-12)
init: name='init7_s1_-13' type=int64 shape=(1,) -- array([-1])        -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-13)
init: name='init7_s1_2' type=int64 shape=(1,) -- array([2])           -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_2)
init: name='init7_s1_22' type=int64 shape=(1,) -- array([2])          -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_22)
init: name='init7_s1_23' type=int64 shape=(1,) -- array([2])          -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_23)
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1])      -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_1)
init: name='init7_s2_0_12' type=int64 shape=(2,) -- array([0, 1])     -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_12)
init: name='init7_s2_0_13' type=int64 shape=(2,) -- array([0, 1])     -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_13)
init: name='init7_s2_1024_10242' type=int64 shape=(2,) -- array([1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_1024_10242)
init: name='init7_s2_2048_1024' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)##GraphBuilder.compute_constant/from(init7_s2_2048_1024)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)
init: name='init7_s2_2048_10242' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)##GraphBuilder.compute_constant/from(init7_s2_2048_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)
init: name='init7_s2_2048_10243' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)##GraphBuilder.compute_constant/from(init7_s2_2048_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)
init: name='init7_s2_2048_10244' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)##GraphBuilder.compute_constant/from(init7_s2_2048_10244)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)
init: name='init7_s2_2048_10245' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)##GraphBuilder.compute_constant/from(init7_s2_2048_10245)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)
init: name='init7_s2_2048_10246' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)##GraphBuilder.compute_constant/from(init7_s2_2048_10246)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)
init: name='init7_s2_2048_10247' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)##GraphBuilder.compute_constant/from(init7_s2_2048_10247)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)
init: name='init7_s3_2_1024_102413' type=int64 shape=(3,) -- array([   2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)##GraphBuilder.compute_constant/from(init7_s3_2_1024_102413)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)
init: name='init7_s3_2_1024_10242' type=int64 shape=(3,) -- array([   2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)
init: name='init7_s3_2_1024_10243' type=int64 shape=(3,) -- array([   2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)
init: name='init7_s4_2_1024_2_512' type=int64 shape=(4,) -- array([   2, 1024,    2,  512])-- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)##GraphBuilder.compute_constant/from(init7_s4_2_1024_2_512)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)
init: name='init7_s_-1' type=int64 shape=() -- array([-1])            -- GraphBuilder._update_structures_with_proto.1/from(init7_s_-1)
init: name='init1_s_12' type=float32 shape=() -- array([0.00195312], dtype=float32)-- MulMulMulScalarPattern.apply.new_cst##MulMulMulScalarPattern.apply.new_cst##MulMulMulScalarPattern.apply.new_cst
init: name='init7_s4_2_2_512_10243' type=int64 shape=(4,) -- array([   2,    2,  512, 1024])-- MatMulReshape2Of3Pattern.apply.shape.2##TransposeReshapeMatMulPattern.apply.shape_name
init: name='init7_s4_2_2_1024_5123' type=int64 shape=(4,) -- array([   2,    2, 1024,  512])-- MatMulReshape2Of3Pattern.apply.shape.2##TransposeReshapeMatMulPattern.apply.shape_name##TransposeReshapeMatMulPattern.apply.shape_name
init: name='init7_s2_256_256' type=int64 shape=(2,) -- array([256, 256])-- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
init: name='init7_s4_2_2_1024_10243' type=int64 shape=(4,) -- array([   2,    2, 1024, 1024])-- TransposeReshapeMatMulPattern.apply.shape_name
Constant(value_float=0) -> output_11
  Identity(output_11) -> output_12
Reshape(input28, init7_s2_2048_10242) -> Reshape2Of3PatternR_input28
Mul(input37, input2) -> _onx_mul0
  Mul(_onx_mul0, input34) -> _onx_mul03
    ReduceSum(_onx_mul03, init7_s1_2, keepdims=1) -> sum_2
      Mul(sum_2, init1_s_) -> _onx_mul05
Mul(input37, input36) -> _onx_mul02
  ReduceSum(_onx_mul02, init7_s2_0_1, keepdims=0) -> output_2
Mul(_onx_mul0, input35) -> _onx_mul04
Pow(input35, init1_s1_) -> pow_4
  Mul(_onx_mul05, pow_4) -> _onx_mul06
    Mul(_onx_mul06, init1_s_12) -> mul-_onx_mul06
      Mul(mul-_onx_mul06, input34) -> _onx_mul08
  Add(_onx_mul04, _onx_mul08) -> add_8
    Reshape(add_8, init7_s2_2048_1024) -> view_22
      Gemm(view_22, input32, transA=0, transB=1) -> mm_9
        Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
          Mul(view_23, input31) -> _onx_mul010
      Gemm(view_22, input33, transA=1, transB=0) -> output_10
  Mul(mm_9, Reshape2Of3PatternR_input28) -> view_24
    Gemm(view_24, input30, transA=1, transB=0) -> output_9
Sigmoid(input27) -> sigmoid
  Mul(input27, sigmoid) -> Sub1MulPattern--_onx_mul011
    Sub(input27, Sub1MulPattern--_onx_mul011) -> _onx_mul011
      Add(_onx_mul011, init1_s_4) -> add_9
  Mul(sigmoid, add_9) -> _onx_mul012
    Mul(_onx_mul010, _onx_mul012) -> _onx_mul013
      Reshape(_onx_mul013, init7_s2_2048_10243) -> view_26
        Gemm(view_26, input25, transA=0, transB=1) -> mm_13
    Gemm(view_24, input29, mm_13, transA=0, transB=1) -> add-mm_11
      Reshape(add-mm_11, init7_s3_2_1024_10243) -> add_10
        Mul(add_10, input1) -> _onx_mul014
          Mul(_onx_mul014, input22) -> _onx_mul016
            ReduceSum(_onx_mul016, init7_s1_22, keepdims=1) -> sum_4
              Mul(sum_4, init1_s_5) -> _onx_mul018
        Gemm(view_26, input26, transA=1, transB=0) -> output_8
Mul(add_10, input24) -> _onx_mul015
  ReduceSum(_onx_mul015, init7_s2_0_12, keepdims=0) -> output_1
Mul(_onx_mul014, input23) -> _onx_mul017
  Add(add_8, _onx_mul017) -> add_11
Pow(input23, init1_s1_2) -> pow_6
  Mul(_onx_mul018, pow_6) -> _onx_mul019
    Mul(_onx_mul019, init1_s_12) -> mul-_onx_mul019
      Mul(mul-_onx_mul019, input22) -> _onx_mul021
    Add(add_11, _onx_mul021) -> add_12
      Reshape(add_12, init7_s2_2048_10244) -> view_29
        Gemm(view_29, input20, transA=0, transB=1) -> mm_15
          Reshape(mm_15, init7_s4_2_1024_2_512) -> view_31
            Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
        Gemm(view_29, input21, transA=1, transB=0) -> output_7
Reshape(input18, init7_s4_2_2_1024_10243) -> TransposeReshapeMatMulPatternL_input18
  Transpose(TransposeReshapeMatMulPatternL_input18, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_transpose_6
    MatMul(MatMulReshape2Of3PatternL_transpose_6, transpose_5) -> view_32
      Add(input39, view_32) -> add_13
        Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
          Reshape(transpose_11, init7_s2_2048_10245) -> view_37
            Gemm(view_37, input12, transA=1, transB=0) -> output_6
Reshape(input19, init7_s4_2_2_1024_5123) -> TransposeReshapeMatMulPatternL_input19
  Transpose(TransposeReshapeMatMulPatternL_input19, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL__unsafe_view_3
    MatMul(transpose_5, MatMulReshape2Of3PatternL__unsafe_view_3) -> view_33
      Mul(view_33, input17) -> _onx_mul022
        ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
          Mul(input17, _onx_reducesum0) -> _onx_mul023
        Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
          Div(_softmax_backward_data, init1_s_8) -> div_3
Reshape(input15, init7_s4_2_2_1024_5123) -> TransposeReshapeMatMulPatternL_input15
  Transpose(TransposeReshapeMatMulPatternL_input15, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_transpose_8
    MatMul(MatMulReshape2Of3PatternL_transpose_8, div_3) -> view_35
      Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
        Add(input38, transpose_10) -> add_14
          Mul(add_14, input14) -> _onx_mul024
            Split(_onx_mul024, init7_s2_256_256, axis=3) -> slice_10, slice_11
              Neg(slice_10) -> neg_2
              Concat(slice_11, neg_2, axis=3) -> add_15
Reshape(input16, init7_s4_2_2_512_10243) -> TransposeReshapeMatMulPatternL_input16
  Transpose(TransposeReshapeMatMulPatternL_input16, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_view_34
    MatMul(div_3, MatMulReshape2Of3PatternL_view_34) -> view_36
      Mul(view_36, input14) -> _onx_mul026
        Split(_onx_mul026, init7_s2_256_256, axis=3) -> slice_12, slice_13
          Neg(slice_12) -> neg_3
          Concat(slice_13, neg_3, axis=3) -> add_17
Mul(add_14, input13) -> _onx_mul025
  Add(add_15, _onx_mul025) -> add_16
    Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
      Reshape(transpose_12, init7_s2_2048_10246) -> view_39
        Gemm(view_39, input10, transA=1, transB=0) -> output_5
      Mul(view_36, input13) -> _onx_mul027
        Add(add_17, _onx_mul027) -> add_18
          Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
            Reshape(transpose_13, init7_s2_2048_10247) -> view_41
              Gemm(view_41, input7, transA=0, transB=1) -> mm_21
        Gemm(view_39, input9, mm_21, transA=0, transB=1) -> MatMulAddPattern--mm_19
          Gemm(view_37, input11, MatMulAddPattern--mm_19, transA=0, transB=1) -> add-Reshape2Of3PatternL_add_19
            Reshape(add-Reshape2Of3PatternL_add_19, init7_s3_2_1024_102413) -> add_20
              Mul(add_20, input0) -> _onx_mul028
                Mul(_onx_mul028, input4) -> _onx_mul030
                  ReduceSum(_onx_mul030, init7_s1_23, keepdims=1) -> sum_6
                    Mul(sum_6, init1_s_9) -> _onx_mul032
              Gemm(view_41, input8, transA=1, transB=0) -> output_4
Mul(add_20, input6) -> _onx_mul029
  ReduceSum(_onx_mul029, init7_s2_0_13, keepdims=0) -> output_0
Mul(_onx_mul028, input5) -> _onx_mul031
  Add(add_12, _onx_mul031) -> add_21
Pow(input5, init1_s1_3) -> pow_8
  Mul(_onx_mul032, pow_8) -> _onx_mul033
    Mul(_onx_mul033, init1_s_12) -> mul-_onx_mul033
      Mul(mul-_onx_mul033, input4) -> _onx_mul035
    Add(add_21, _onx_mul035) -> add_22
Equal(input3, init7_s_-1) -> eq_2
  Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
    Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
  ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> output_3
Identity(output_11) -> output_13
Identity(output_11) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None

The same graph, rendered visually:

[Figure: plot optimize 101]

Two lists of patterns are available. The first, experimental_experiment.xoptim.patterns, rewrites the graph using only standard ONNX operators. The second, experimental_experiment.xoptim.patterns_ort, is specific to onnxruntime.
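To make the difference between the two lists concrete, here is a toy sketch in plain Python. It is not the library's implementation: the `Node` class, the two rewrite functions, and the `optimize` loop are hypothetical names introduced for illustration, and shapes and most attributes are ignored. One rewrite stays within standard ONNX operators (MatMul followed by Add becomes Gemm). The other emits `FusedMatMul` from the `com.microsoft` domain, an operator provided by onnxruntime rather than by the ONNX standard.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class Node:
    """Toy stand-in for an ONNX node (hypothetical, for illustration only)."""

    op_type: str
    inputs: List[str]
    outputs: List[str]
    domain: str = ""  # "" denotes the standard ONNX domain
    attrs: Dict[str, int] = field(default_factory=dict)


def consumers(nodes: List[Node], name: str) -> List[Node]:
    """Returns every node that reads the result called `name`."""
    return [n for n in nodes if name in n.inputs]


def matmul_add_to_gemm(nodes: List[Node]) -> Optional[List[Node]]:
    """Standard rewrite: MatMul followed by a single Add becomes Gemm."""
    for mm in nodes:
        if mm.op_type != "MatMul" or mm.domain != "":
            continue
        users = consumers(nodes, mm.outputs[0])
        if len(users) != 1 or users[0].op_type != "Add":
            continue
        add = users[0]
        bias = [i for i in add.inputs if i != mm.outputs[0]]
        if len(bias) != 1:
            continue
        gemm = Node("Gemm", mm.inputs + bias, add.outputs)
        # The fused node takes the place of Add, MatMul is dropped.
        return [gemm if n is add else n for n in nodes if n is not mm]
    return None


def transpose_matmul_to_fused(nodes: List[Node]) -> Optional[List[Node]]:
    """Runtime-specific rewrite: Transpose + MatMul becomes FusedMatMul."""
    for tr in nodes:
        if tr.op_type != "Transpose":
            continue
        users = consumers(nodes, tr.outputs[0])
        if len(users) != 1 or users[0].op_type != "MatMul":
            continue
        mm = users[0]
        if mm.inputs[0] != tr.outputs[0] or mm.inputs[1] == tr.outputs[0]:
            continue
        fused = Node(
            "FusedMatMul",
            [tr.inputs[0], mm.inputs[1]],
            mm.outputs,
            domain="com.microsoft",
            attrs={"transA": 1},
        )
        return [fused if n is mm else n for n in nodes if n is not tr]
    return None


Pattern = Callable[[List[Node]], Optional[List[Node]]]


def optimize(nodes: List[Node], patterns: List[Pattern]) -> Tuple[List[Node], List[str]]:
    """Applies the patterns repeatedly until none of them matches."""
    applied: List[str] = []
    changed = True
    while changed:
        changed = False
        for pattern in patterns:
            rewritten = pattern(nodes)
            if rewritten is not None:
                nodes = rewritten
                applied.append(pattern.__name__)
                changed = True
    return nodes, applied


graph = [
    Node("Transpose", ["x"], ["xt"]),
    Node("MatMul", ["xt", "w1"], ["h"]),
    Node("MatMul", ["h", "w2"], ["hw"]),
    Node("Add", ["hw", "b"], ["y"]),
]
STANDARD = [matmul_add_to_gemm]
RUNTIME_SPECIFIC = [transpose_matmul_to_fused]

std_nodes, std_applied = optimize(graph, STANDARD)
ort_nodes, ort_applied = optimize(graph, STANDARD + RUNTIME_SPECIFIC)
print("standard only:", [n.op_type for n in std_nodes])
print("with runtime-specific:", [(n.domain, n.op_type) for n in ort_nodes])
```

With only the standard list the result remains valid for any ONNX runtime. Adding the runtime-specific list removes one more node, at the cost of tying the model to onnxruntime.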

Focus on one optimizer
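The name of the pattern used in this section, together with the shapes in the dump that follows, suggests what it targets: in a chain of two multiplications, applying the scalar to the small tensor first (2 x 1024 x 1) avoids one full pass over the large tensor (2 x 1024 x 1024). The pure-Python sketch below counts the elements produced by both orders. It is an assumption about the intent of `SwitchOrderBinary`, not its implementation, and `broadcast_shape` and `chain_cost` are hypothetical helpers.

```python
from itertools import zip_longest
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Shape, b: Shape) -> Shape:
    """Shape of an elementwise operation following ONNX/numpy broadcasting."""
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def chain_cost(first: Shape, second: Shape, third: Shape) -> int:
    """Number of elements written when computing (first * second) * third."""
    inner = broadcast_shape(first, second)
    outer = broadcast_shape(inner, third)
    return prod(inner) + prod(outer)


small = (2, 1024, 1)     # shape of _onx_mul06 in the dump
large = (2, 1024, 1024)  # shape of input34 in the dump
scalar = ()              # shape of init1_s_12 in the dump

cost_large_first = chain_cost(small, large, scalar)
cost_scalar_first = chain_cost(small, scalar, large)
print("large tensor first :", cost_large_first)
print("scalar first       :", cost_scalar_first)
```

Both orders give the same numerical result because multiplication is commutative and associative (up to floating-point rounding), so the cheaper one can be chosen safely.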

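# Rebuild a GraphBuilder from the already optimized model and enable
# a single pattern, so that the verbose log only covers that pattern.
gr = GraphBuilder(
    optimized_proto,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns="SwitchOrderBinary",  # only this pattern is applied
        verbose=10,  # high verbosity, see the log below
    ),
)
# optimize() returns statistics that can be loaded into a DataFrame.
stats = gr.optimize()
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
# Writing .xlsx requires an Excel engine such as openpyxl.
df.to_excel("plot_optimize.xlsx")
df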
[GraphBuilder.optimize] start with 117 nodes
[GraphBuilder.optimize] #patterns=1
[GraphBuilderPatternOptimization.optimize] start with 117 nodes, 36 initializers, 1 patterns, priorities=[1]
[GraphBuilderPatternOptimization.optimize] use pattern   1/1 - P1 - SwitchOrderBinaryPattern()
--

opset: : 18
init: init1_s1_: ?: ?                                                  -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_)
init: init1_s1_2: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_2)
init: init1_s1_3: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_3)
init: init1_s1_4: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_4)
init: init1_s_: ?: ?                                                   -- GraphBuilder._update_structures_with_proto.1/from(init1_s_)
init: init1_s_4: ?: ?                                                  -- GraphBuilder._update_structures_with_proto.1/from(init1_s_4)
init: init1_s_5: ?: ?                                                  -- GraphBuilder._update_structures_with_proto.1/from(init1_s_5)
init: init1_s_8: ?: ?                                                  -- GraphBuilder._update_structures_with_proto.1/from(init1_s_8)
init: init1_s_9: ?: ?                                                  -- GraphBuilder._update_structures_with_proto.1/from(init1_s_9)
init: init7_s1_-1: ?: ?                                                -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-1)
init: init7_s1_-12: ?: ?                                               -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-12)
init: init7_s1_-13: ?: ?                                               -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-13)
init: init7_s1_2: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_2)
init: init7_s1_22: ?: ?                                                -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_22)
init: init7_s1_23: ?: ?                                                -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_23)
init: init7_s2_0_1: ?: ?                                               -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_1)
init: init7_s2_0_12: ?: ?                                              -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_12)
init: init7_s2_0_13: ?: ?                                              -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_13)
init: init7_s2_1024_10242: ?: ?                                        -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_1024_10242)
init: init7_s2_2048_1024: int64: 2                                     -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)##GraphBuilder.compute_constant/from(init7_s2_2048_1024)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)
init: init7_s2_2048_10242: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)##GraphBuilder.compute_constant/from(init7_s2_2048_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)
init: init7_s2_2048_10243: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)##GraphBuilder.compute_constant/from(init7_s2_2048_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)
init: init7_s2_2048_10244: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)##GraphBuilder.compute_constant/from(init7_s2_2048_10244)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)
init: init7_s2_2048_10245: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)##GraphBuilder.compute_constant/from(init7_s2_2048_10245)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)
init: init7_s2_2048_10246: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)##GraphBuilder.compute_constant/from(init7_s2_2048_10246)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)
init: init7_s2_2048_10247: int64: 2                                    -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)##GraphBuilder.compute_constant/from(init7_s2_2048_10247)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)
init: init7_s3_2_1024_102413: int64: 3                                 -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)##GraphBuilder.compute_constant/from(init7_s3_2_1024_102413)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)
init: init7_s3_2_1024_10242: int64: 3                                  -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)
init: init7_s3_2_1024_10243: int64: 3                                  -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)
init: init7_s4_2_1024_2_512: int64: 4                                  -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)##GraphBuilder.compute_constant/from(init7_s4_2_1024_2_512)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)
init: init7_s_-1: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init7_s_-1)
init: init1_s_12: ?: ?                                                 -- GraphBuilder._update_structures_with_proto.1/from(init1_s_12)
init: init7_s4_2_2_512_10243: int64: 4                                 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_512_10243)##GraphBuilder.compute_constant/from(init7_s4_2_2_512_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_512_10243)
init: init7_s4_2_2_1024_5123: int64: 4                                 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_5123)##GraphBuilder.compute_constant/from(init7_s4_2_2_1024_5123)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_5123)
init: init7_s2_256_256: ?: ?                                           -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_256_256)
init: init7_s4_2_2_1024_10243: int64: 4                                -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s4_2_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_10243)
input:: input0                                                                  |T1: 1024
input:: input1                                                                  |T1: 1024
input:: input2                                                                  |T1: 1024
input:: input3                                                                  |T7: 2 x 1024
input:: input4                                                                  |T1: 2 x 1024 x 1024
input:: input5                                                                  |T1: 2 x 1024 x 1
input:: input6                                                                  |T1: 2 x 1024 x 1024
input:: input7                                                                  |T1: 1024 x 1024
input:: input8                                                                  |T1: 2048 x 1024
input:: input9                                                                  |T1: 1024 x 1024
input:: input10                                                                 |T1: 2048 x 1024
input:: input11                                                                 |T1: 1024 x 1024
input:: input12                                                                 |T1: 2048 x 1024
input:: input13                                                                 |T1: 1024 x 512
input:: input14                                                                 |T1: 1024 x 512
input:: input15                                                                 |T1: 4 x 1024 x 512
input:: input16                                                                 |T1: 4 x 512 x 1024
input:: input17                                                                 |T1: 2 x 2 x 1024 x 1024
input:: input18                                                                 |T1: 4 x 1024 x 1024
input:: input19                                                                 |T1: 4 x 1024 x 512
input:: input20                                                                 |T1: 1024 x 1024
input:: input21                                                                 |T1: 2048 x 1024
input:: input22                                                                 |T1: 2 x 1024 x 1024
input:: input23                                                                 |T1: 2 x 1024 x 1
input:: input24                                                                 |T1: 2 x 1024 x 1024
input:: input25                                                                 |T1: 1024 x 1024
input:: input26                                                                 |T1: 2048 x 1024
input:: input27                                                                 |T1: 2 x 1024 x 1024
input:: input28                                                                 |T1: 2 x 1024 x 1024
input:: input29                                                                 |T1: 1024 x 1024
input:: input30                                                                 |T1: 2048 x 1024
input:: input31                                                                 |T1: 2 x 1024 x 1024
input:: input32                                                                 |T1: 1024 x 1024
input:: input33                                                                 |T1: 2048 x 1024
input:: input34                                                                 |T1: 2 x 1024 x 1024
input:: input35                                                                 |T1: 2 x 1024 x 1
input:: input36                                                                 |T1: 2 x 1024 x 1024
input:: input37                                                                 |T1: 2 x 1024 x 1024
input:: input38                                                                 |T1: 2 x 2 x 1024 x 512
input:: input39                                                                 |T1: 2 x 2 x 1024 x 512
Reshape: input28, init7_s2_2048_10242 -> Reshape2Of3PatternR_input28            |T1: 2048 x 1024              - Reshape2Of3Pattern--mul17
Mul: input37, input2 -> _onx_mul0                                               |T1: 2 x 1024 x 1024          - mul
Mul: input37, input36 -> _onx_mul02                                             |T1: 2 x 1024 x 1024          - mul3
ReduceSum: _onx_mul02, init7_s2_0_1 -> output_2                                 |T1: 1024                     - ReduceReshapePattern--sum
Mul: _onx_mul0, input34 -> _onx_mul03                                           |T1: 2 x 1024 x 1024          - mul5
Mul: _onx_mul0, input35 -> _onx_mul04                                           |T1: 2 x 1024 x 1024          - mul7
ReduceSum: _onx_mul03, init7_s1_2 -> sum_2                                      |T1: 2 x 1024 x 1             - sum2
Pow: input35, init1_s1_ -> pow_4                                                |T1: 2 x 1024 x 1             - Pow
Mul: sum_2, init1_s_ -> _onx_mul05                                              |T1: 2 x 1024 x 1             - mul9
Mul: _onx_mul05, pow_4 -> _onx_mul06                                            |T1: 2 x 1024 x 1             - mul11
Mul: _onx_mul06, init1_s_12 -> mul-_onx_mul06                                   |T1: 2 x 1024 x 1             - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst
Mul: mul-_onx_mul06, input34 -> _onx_mul08                                      |T1: 2 x 1024 x 1024          - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst2
Add: _onx_mul04, _onx_mul08 -> add_8                                            |T1: 2 x 1024 x 1024          - add_Tensor
Reshape: add_8, init7_s2_2048_1024 -> view_22                                   |T1: 2048 x 1024              - view2
Gemm: view_22, input32 -> mm_9                                                  |T1: 2048 x 1024              - TransposeMatMulPattern--mm2
Gemm: view_22, input33 -> output_10                                             |T1: 1024 x 1024              - TransposeMatMulPattern--mm
Reshape: mm_9, init7_s3_2_1024_10242 -> view_23                                 |T1: 2 x 1024 x 1024          - view3
Mul: mm_9, Reshape2Of3PatternR_input28 -> view_24                               |T1: 2048 x 1024              - Reshape2Of3Pattern--mul172
Gemm: view_24, input30 -> output_9                                              |T1: 1024 x 1024              - TransposeMatMulPattern--mm3
Mul: view_23, input31 -> _onx_mul010                                            |T1: 2 x 1024 x 1024          - mul19
Sigmoid: input27 -> sigmoid                                                     |T1: 2 x 1024 x 1024          - Sigmoid
Mul: input27, sigmoid -> Sub1MulPattern--_onx_mul011                            |T1: 2 x 1024 x 1024          - Sub1MulPattern--mul21
Sub: input27, Sub1MulPattern--_onx_mul011 -> _onx_mul011                        |T1: 2 x 1024 x 1024          - Sub1MulPattern--mul212
Add: _onx_mul011, init1_s_4 -> add_9                                            |T1: 2 x 1024 x 1024          - add_Scalar
Mul: sigmoid, add_9 -> _onx_mul012                                              |T1: 2 x 1024 x 1024          - mul23
Mul: _onx_mul010, _onx_mul012 -> _onx_mul013                                    |T1: 2 x 1024 x 1024          - mul25
Reshape: _onx_mul013, init7_s2_2048_10243 -> view_26                            |T1: 2048 x 1024              - view6
Gemm: view_26, input25 -> mm_13                                                 |T1: 2048 x 1024              - TransposeMatMulPattern--mm6
Gemm: view_26, input26 -> output_8                                              |T1: 1024 x 1024              - TransposeMatMulPattern--mm5
Gemm: view_24, input29, mm_13 -> add-mm_11                                      |T1: 2048 x 1024              - MatMulAddPattern--TransposeMatMulPattern--mm4
Reshape: add-mm_11, init7_s3_2_1024_10243 -> add_10                             |T1: 2 x 1024 x 1024          - ReshapeReshapeBinaryPattern--add_Tensor22
Mul: add_10, input1 -> _onx_mul014                                              |T1: 2 x 1024 x 1024          - mul27
Mul: add_10, input24 -> _onx_mul015                                             |T1: 2 x 1024 x 1024          - mul29
ReduceSum: _onx_mul015, init7_s2_0_12 -> output_1                               |T1: 1024                     - ReduceReshapePattern--sum3
Mul: _onx_mul014, input22 -> _onx_mul016                                        |T1: 2 x 1024 x 1024          - mul31
Mul: _onx_mul014, input23 -> _onx_mul017                                        |T1: 2 x 1024 x 1024          - mul33
ReduceSum: _onx_mul016, init7_s1_22 -> sum_4                                    |T1: 2 x 1024 x 1             - sum4
Add: add_8, _onx_mul017 -> add_11                                               |T1: 2 x 1024 x 1024          - add_Tensor3
Pow: input23, init1_s1_2 -> pow_6                                               |T1: 2 x 1024 x 1             - Pow1
Mul: sum_4, init1_s_5 -> _onx_mul018                                            |T1: 2 x 1024 x 1             - mul35
Mul: _onx_mul018, pow_6 -> _onx_mul019                                          |T1: 2 x 1024 x 1             - mul37
Mul: _onx_mul019, init1_s_12 -> mul-_onx_mul019                                 |T1: 2 x 1024 x 1             - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst
Mul: mul-_onx_mul019, input22 -> _onx_mul021                                    |T1: 2 x 1024 x 1024          - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst2
Add: add_11, _onx_mul021 -> add_12                                              |T1: 2 x 1024 x 1024          - add_Tensor4
Reshape: add_12, init7_s2_2048_10244 -> view_29                                 |T1: 2048 x 1024              - view9
Gemm: view_29, input20 -> mm_15                                                 |T1: 2048 x 1024              - TransposeMatMulPattern--mm8
Gemm: view_29, input21 -> output_7                                              |T1: 1024 x 1024              - TransposeMatMulPattern--mm7
Reshape: mm_15, init7_s4_2_1024_2_512 -> view_31                                |T1: 2 x 1024 x 2 x 512       - ReshapeReshapePattern--view10
Transpose: view_31 -> transpose_5                                               |T1: 2 x 2 x 1024 x 512       - Transpose
Reshape: input18, init7_s4_2_2_1024_10243 -> TransposeReshapeMatMulPatternL_input18   |T1: 2 x 2 x 1024 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm2
Transpose: TransposeReshapeMatMulPatternL_input18 -> MatMulReshape2Of3PatternL_transpose_6          |T1: 2 x 2 x 1024 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm22
MatMul: MatMulReshape2Of3PatternL_transpose_6, transpose_5 -> view_32           |T1: 2 x 2 x 1024 x 512       - MatMulReshape2Of3Pattern--bmm2
Reshape: input19, init7_s4_2_2_1024_5123 -> TransposeReshapeMatMulPatternL_input19  |T1: 2 x 2 x 1024 x 512   - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm23
Transpose: TransposeReshapeMatMulPatternL_input19 -> MatMulReshape2Of3PatternL__unsafe_view_3             |T1: 2 x 2 x 512 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm232
MatMul: transpose_5, MatMulReshape2Of3PatternL__unsafe_view_3 -> view_33        |T1: 2 x 2 x 1024 x 1024      - MatMulReshape2Of3Pattern--bmm23
Add: input39, view_32 -> add_13                                                 |T1: 2 x 2 x 1024 x 512       - add_Tensor5
Mul: view_33, input17 -> _onx_mul022                                            |T1: 2 x 2 x 1024 x 1024      - Mul
ReduceSum: _onx_mul022, init7_s1_-1 -> _onx_reducesum0                          |T1: 2 x 2 x 1024 x 1         - softmax_backward_data
Mul: input17, _onx_reducesum0 -> _onx_mul023                                    |T1: 2 x 2 x 1024 x 1024      - softmax_backward_data2
Sub: _onx_mul022, _onx_mul023 -> _softmax_backward_data                         |T1: 2 x 2 x 1024 x 1024      - softmax_backward_data3
Div: _softmax_backward_data, init1_s_8 -> div_3                                 |T1: 2 x 2 x 1024 x 1024      - div_Tensor
Reshape: input15, init7_s4_2_2_1024_5123 -> TransposeReshapeMatMulPatternL_input15  |T1: 2 x 2 x 1024 x 512   - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm32
Transpose: TransposeReshapeMatMulPatternL_input15 -> MatMulReshape2Of3PatternL_transpose_8          |T1: 2 x 2 x 512 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm322
MatMul: MatMulReshape2Of3PatternL_transpose_8, div_3 -> view_35                 |T1: 2 x 2 x 512 x 1024       - MatMulReshape2Of3Pattern--bmm32
Reshape: input16, init7_s4_2_2_512_10243 -> TransposeReshapeMatMulPatternL_input16  |T1: 2 x 2 x 512 x 1024   - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm42
Transpose: TransposeReshapeMatMulPatternL_input16 -> MatMulReshape2Of3PatternL_view_34      |T1: 2 x 2 x 1024 x 512- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm422
MatMul: div_3, MatMulReshape2Of3PatternL_view_34 -> view_36                     |T1: 2 x 2 x 1024 x 512       - MatMulReshape2Of3Pattern--bmm42
Transpose: view_35 -> transpose_10                                              |T1: 2 x 2 x 1024 x 512       - Transpose12345
Add: input38, transpose_10 -> add_14                                            |T1: 2 x 2 x 1024 x 512       - add_Tensor6
Mul: add_14, input14 -> _onx_mul024                                             |T1: 2 x 2 x 1024 x 512       - mul43
Split: _onx_mul024, init7_s2_256_256 -> slice_10, slice_11                      |T1: 2 x 2 x 1024 x 256 T1: 2 x 2 x 1024 x 256- SlicesSplitPattern--Slice
Neg: slice_10 -> neg_2                                                          |T1: 2 x 2 x 1024 x 256       - Neg
Concat: slice_11, neg_2 -> add_15                                               |T1: 2 x 2 x 1024 x 512       - RotaryConcatPartPattern--add_Tensor7
Mul: add_14, input13 -> _onx_mul025                                             |T1: 2 x 2 x 1024 x 512       - mul45
Add: add_15, _onx_mul025 -> add_16                                              |T1: 2 x 2 x 1024 x 512       - add_Tensor8
Mul: view_36, input14 -> _onx_mul026                                            |T1: 2 x 2 x 1024 x 512       - mul47
Split: _onx_mul026, init7_s2_256_256 -> slice_12, slice_13                      |T1: 2 x 2 x 1024 x 256 T1: 2 x 2 x 1024 x 256- SlicesSplitPattern--Slice12
Neg: slice_12 -> neg_3                                                          |T1: 2 x 2 x 1024 x 256       - Neg1
Concat: slice_13, neg_3 -> add_17                                               |T1: 2 x 2 x 1024 x 512       - RotaryConcatPartPattern--add_Tensor9
Mul: view_36, input13 -> _onx_mul027                                            |T1: 2 x 2 x 1024 x 512       - mul49
Add: add_17, _onx_mul027 -> add_18                                              |T1: 2 x 2 x 1024 x 512       - add_Tensor10
Transpose: add_13 -> transpose_11                                               |T1: 2 x 1024 x 2 x 512       - Transpose123456
Transpose: add_16 -> transpose_12                                               |T1: 2 x 1024 x 2 x 512       - Transpose1234567
Transpose: add_18 -> transpose_13                                               |T1: 2 x 1024 x 2 x 512       - Transpose12345678
Reshape: transpose_11, init7_s2_2048_10245 -> view_37                           |T1: 2048 x 1024              - ReshapeReshapePattern--_unsafe_view2
Gemm: view_37, input12 -> output_6                                              |T1: 1024 x 1024              - TransposeMatMulPattern--mm9
Reshape: transpose_12, init7_s2_2048_10246 -> view_39                           |T1: 2048 x 1024              - ReshapeReshapePattern--_unsafe_view3
Gemm: view_39, input10 -> output_5                                              |T1: 1024 x 1024              - TransposeMatMulPattern--mm11
Reshape: transpose_13, init7_s2_2048_10247 -> view_41                           |T1: 2048 x 1024              - ReshapeReshapePattern--_unsafe_view4
Gemm: view_41, input7 -> mm_21                                                  |T1: 2048 x 1024              - TransposeMatMulPattern--mm14
Gemm: view_41, input8 -> output_4                                               |T1: 1024 x 1024              - TransposeMatMulPattern--mm13
Gemm: view_39, input9, mm_21 -> MatMulAddPattern--mm_19                         |T1: 2048 x 1024              - MatMulAddPattern--TransposeMatMulPattern--mm12
Gemm: view_37, input11, MatMulAddPattern--mm_19 -> add-Reshape2Of3PatternL_add_19 |T1: 2048 x 1024            - MatMulAddPattern--MatMulAddPattern--TransposeMatMulPattern--mm102
Reshape: add-Reshape2Of3PatternL_add_19, init7_s3_2_1024_102413 -> add_20       |T1: 2 x 1024 x 1024          - ReshapeReshapeBinaryPattern--add_Tensor122
Mul: add_20, input0 -> _onx_mul028                                              |T1: 2 x 1024 x 1024          - mul51
Mul: add_20, input6 -> _onx_mul029                                              |T1: 2 x 1024 x 1024          - mul53
ReduceSum: _onx_mul029, init7_s2_0_13 -> output_0                               |T1: 1024                     - ReduceReshapePattern--sum5
Mul: _onx_mul028, input4 -> _onx_mul030                                         |T1: 2 x 1024 x 1024          - mul55
Mul: _onx_mul028, input5 -> _onx_mul031                                         |T1: 2 x 1024 x 1024          - mul57
ReduceSum: _onx_mul030, init7_s1_23 -> sum_6                                    |T1: 2 x 1024 x 1             - sum6
Add: add_12, _onx_mul031 -> add_21                                              |T1: 2 x 1024 x 1024          - add_Tensor13
Pow: input5, init1_s1_3 -> pow_8                                                |T1: 2 x 1024 x 1             - Pow12
Mul: sum_6, init1_s_9 -> _onx_mul032                                            |T1: 2 x 1024 x 1             - mul59
Mul: _onx_mul032, pow_8 -> _onx_mul033                                          |T1: 2 x 1024 x 1             - mul61
Mul: _onx_mul033, init1_s_12 -> mul-_onx_mul033                                 |T1: 2 x 1024 x 1             - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst
Mul: mul-_onx_mul033, input4 -> _onx_mul035                                     |T1: 2 x 1024 x 1024          - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst2
Add: add_21, _onx_mul035 -> add_22                                              |T1: 2 x 1024 x 1024          - add_Tensor14
Equal: input3, init7_s_-1 -> eq_2                                               |T9: 2 x 1024                 - Equal
Unsqueeze: eq_2, init7_s1_-12 -> unsqueeze_6                                    |T9: 2 x 1024 x 1             - Unsqueeze
Where: unsqueeze_6, init1_s1_4, add_22 -> _onx_where0                           |T1: 2 x 1024 x 1024          - masked_fill_Scalar
Unsqueeze: input3, init7_s1_-13 -> _onx_unsqueeze0                              |T7: 2 x 1024 x 1             - aten__unsafe_index_put
ConstantOfShape: init7_s2_1024_10242 -> _onx_constantofshape05                  |T1: 1024 x 1024              - aten__unsafe_index_put2
ScatterND: _onx_constantofshape05, _onx_unsqueeze0, _onx_where0 -> output_3     |T1: 1024 x 1024              - aten__unsafe_index_put3
Constant:  -> output_11                                                         |T1:                          - Constant
Identity: output_11 -> output_12                                                |T1:                          - ._update_structures_with_proto
Identity: output_11 -> output_13                                                |T1:                          - ._update_structures_with_proto
Identity: output_11 -> output_14                                                |T1:                          - ._update_structures_with_proto
output:: output_0                                                               |T1: 1024
output:: output_1                                                               |T1: 1024
output:: output_2                                                               |T1: 1024
output:: output_3                                                               |T1: 1024 x 1024
output:: output_4                                                               |T1: 1024 x 1024
output:: output_5                                                               |T1: 1024 x 1024
output:: output_6                                                               |T1: 1024 x 1024
output:: output_7                                                               |T1: 1024 x 1024
output:: output_8                                                               |T1: 1024 x 1024
output:: output_9                                                               |T1: 1024 x 1024
output:: output_10                                                              |T1: 1024 x 1024
output:: output_11                                                              |T1:
output:: output_12                                                              |T1:
output:: output_13                                                              |T1:
output:: output_14                                                              |T1:
--
[GraphBuilderPatternOptimization.optimize] iteration 0: 117 nodes, priority=1
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul5
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul7
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul11
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 184:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul25
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul31
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul33
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor3
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul37
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor4
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul55
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul57
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor13
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul61
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor14
[GraphBuilderPatternOptimization.optimize] done all: -0 +0 nodes
[GraphBuilderPatternOptimization.optimize] done after 1 iterations with 117 nodes in 0.002
    STAT build_graph_for_pattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00035422100336290896
    STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00019108100241282955
    STAT check_pattern_B0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0002679980025277473
    STAT match_SwitchOrderBinaryPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0006592560021090321
    STAT remove_identity_nodes +0 -0 #it=1 maxmatch=0 i=0 - time=0.0003321409967611544
--MODEL: 117 nodes, 40 inputs, 15 outputs, 36 initializers--
         INPUT:  39 x 1t
         INPUT:   1 x 7t
     INPUT-SEQ:  40 x Falset
        OUTPUT:  15 x 1t
    OUTPUT-SEQ:  15 x Falset
          INIT:  10 x 1t
          INIT:  26 x 7t
          NODE:  10 x Add
          NODE:   2 x Concat
          NODE:   1 x Constant
          NODE:   1 x ConstantOfShape
          NODE:   1 x Div
          NODE:   1 x Equal
          NODE:  14 x Gemm
          NODE:   3 x Identity
          NODE:   4 x MatMul
          NODE:  35 x Mul
          NODE:   2 x Neg
          NODE:   3 x Pow
          NODE:   7 x ReduceSum
          NODE:  15 x Reshape
          NODE:   1 x ScatterND
          NODE:   1 x Sigmoid
          NODE:   2 x Split
          NODE:   2 x Sub
          NODE:   9 x Transpose
          NODE:   2 x Unsqueeze
          NODE:   1 x Where
--MODEL: 117 nodes, 40 inputs, 15 outputs, 36 initializers--DETAILED--
     INPUT:   3 x 1t[1024]
     INPUT:   7 x 1t[1024x1024]
     INPUT:   2 x 1t[1024x512]
     INPUT:   7 x 1t[2048x1024]
     INPUT:  10 x 1t[2x1024x1024]
     INPUT:   3 x 1t[2x1024x1]
     INPUT:   1 x 1t[2x2x1024x1024]
     INPUT:   2 x 1t[2x2x1024x512]
     INPUT:   1 x 1t[4x1024x1024]
     INPUT:   2 x 1t[4x1024x512]
     INPUT:   1 x 1t[4x512x1024]
     INPUT:   1 x 7t[2x1024]
    OUTPUT:   3 x 1t[1024]
    OUTPUT:   8 x 1t[1024x1024]
    OUTPUT:   4 x 1t[1]
      INIT:  10 x 1t[1]
      INIT:   7 x 7t[1]
      INIT:  12 x 7t[2]
      INIT:   3 x 7t[3]
      INIT:   4 x 7t[4]
      NODE:   1 x Add -SIG- 1t[2x1024x1024], 1t[1]
      NODE:   5 x Add -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
      NODE:   4 x Add -SIG- 1t[2x2x1024x512], 1t[2x2x1024x512]
      NODE:   2 x Concat -SIG- 1t[2x2x1024x256], 1t[2x2x1024x256]
      NODE:   1 x Constant -SIG-
      NODE:   1 x ConstantOfShape -SIG- 7t[2]
      NODE:   1 x Div -SIG- 1t[2x2x1024x1024], 1t[1]
      NODE:   1 x Equal -SIG- 7t[2x1024], 7t[1]
      NODE:   4 x Gemm -SIG- 1t[2048x1024], 1t[1024x1024]
      NODE:   3 x Gemm -SIG- 1t[2048x1024], 1t[1024x1024], 1t[2048x1024]
      NODE:   7 x Gemm -SIG- 1t[2048x1024], 1t[2048x1024]
      NODE:   3 x Identity -SIG- 1t[1]
      NODE:   2 x MatMul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x512]
      NODE:   1 x MatMul -SIG- 1t[2x2x1024x512], 1t[2x2x512x1024]
      NODE:   1 x MatMul -SIG- 1t[2x2x512x1024], 1t[2x2x1024x1024]
      NODE:   1 x Mul -SIG- 1t[2048x1024], 1t[2048x1024]
      NODE:   3 x Mul -SIG- 1t[2x1024x1024], 1t[1024]
      NODE:  10 x Mul -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
      NODE:   3 x Mul -SIG- 1t[2x1024x1024], 1t[2x1024x1]
      NODE:   6 x Mul -SIG- 1t[2x1024x1], 1t[1]
      NODE:   3 x Mul -SIG- 1t[2x1024x1], 1t[2x1024x1024]
      NODE:   3 x Mul -SIG- 1t[2x1024x1], 1t[2x1024x1]
      NODE:   1 x Mul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1024]
      NODE:   1 x Mul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1]
      NODE:   4 x Mul -SIG- 1t[2x2x1024x512], 1t[1024x512]
      NODE:   2 x Neg -SIG- 1t[2x2x1024x256]
      NODE:   3 x Pow -SIG- 1t[2x1024x1], 1t[1]
      NODE:   3 x ReduceSum -SIG- 1t[2x1024x1024], 7t[1]
      NODE:   3 x ReduceSum -SIG- 1t[2x1024x1024], 7t[2]
      NODE:   1 x ReduceSum -SIG- 1t[2x2x1024x1024], 7t[1]
      NODE:   3 x Reshape -SIG- 1t[2048x1024], 7t[3]
      NODE:   1 x Reshape -SIG- 1t[2048x1024], 7t[4]
      NODE:   4 x Reshape -SIG- 1t[2x1024x1024], 7t[2]
      NODE:   3 x Reshape -SIG- 1t[2x1024x2x512], 7t[2]
      NODE:   1 x Reshape -SIG- 1t[4x1024x1024], 7t[4]
      NODE:   2 x Reshape -SIG- 1t[4x1024x512], 7t[4]
      NODE:   1 x Reshape -SIG- 1t[4x512x1024], 7t[4]
      NODE:   1 x ScatterND -SIG- 1t[1024x1024], 7t[2x1024x1], 1t[2x1024x1024]
      NODE:   1 x Sigmoid -SIG- 1t[2x1024x1024]
      NODE:   2 x Split -SIG- 1t[2x2x1024x512], 7t[2]
      NODE:   1 x Sub -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
      NODE:   1 x Sub -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1024]
      NODE:   1 x Transpose -SIG- 1t[2x1024x2x512]-perm=0;2;1;3
      NODE:   1 x Transpose -SIG- 1t[2x2x1024x1024]-perm=0;1;3;2
      NODE:   2 x Transpose -SIG- 1t[2x2x1024x512]-perm=0;1;3;2
      NODE:   3 x Transpose -SIG- 1t[2x2x1024x512]-perm=0;2;1;3
      NODE:   2 x Transpose -SIG- 1t[2x2x512x1024]-perm=0;1;3;2
      NODE:   1 x Unsqueeze -SIG- 7t[2x1024], 7t[1]
      NODE:   1 x Unsqueeze -SIG- 9t[2x1024], 7t[1]
      NODE:   1 x Where -SIG- 9t[2x1024x1], 1t[1], 1t[2x1024x1024]
[GraphBuilder.optimize] done with 117 nodes in 0.008
                           pattern   time_in  removed  added  iteration  instances  match_index
0                          check_A  0.000237      NaN    NaN        NaN        NaN          NaN
1            remove_identity_nodes  0.000339      0.0    0.0        NaN        NaN          NaN
2                          check_B  0.000206      NaN    NaN        NaN        NaN          NaN
3                    remove_unused  0.000397      0.0    NaN        NaN        NaN          NaN
4                          check_C  0.000227      NaN    NaN        NaN        NaN          NaN
5                 check_pattern_00  0.000191      NaN    NaN       -1.0        NaN          NaN
6   match_SwitchOrderBinaryPattern  0.000659      NaN    NaN        0.0        0.0          0.0
7            remove_identity_nodes  0.000332      0.0    0.0        0.0        NaN          NaN
8                 check_pattern_B0  0.000268      NaN    NaN        0.0        NaN          NaN
9          build_graph_for_pattern  0.000354      NaN    NaN        0.0        NaN          NaN
10            pattern_optimization  0.006072      0.0    NaN        NaN        NaN          NaN
11                         check_F  0.000206      NaN    NaN        NaN        NaN          NaN
12                   remove_unused  0.000360      0.0    NaN        NaN        NaN          NaN
13                         check_G  0.000192      NaN    NaN        NaN        NaN          NaN
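
The per-step statistics above can be condensed into one row per step. The snippet below is only a sketch: it assumes the statistics are available as a list of dictionaries (the variable name `stats` and this layout are assumptions, not something the output confirms), and it copies a few rows from the table by hand instead of calling the optimizer.

```python
import pandas

# Hypothetical container: a handful of rows copied by hand from the table above.
# Keys that were NaN in the table are simply left out of the dictionaries.
stats = [
    {"pattern": "check_A", "time_in": 0.000237},
    {"pattern": "remove_identity_nodes", "time_in": 0.000339, "removed": 0.0, "added": 0.0},
    {"pattern": "remove_unused", "time_in": 0.000397, "removed": 0.0},
    {"pattern": "match_SwitchOrderBinaryPattern", "time_in": 0.000659,
     "iteration": 0.0, "instances": 0.0, "match_index": 0.0},
    {"pattern": "remove_identity_nodes", "time_in": 0.000332,
     "removed": 0.0, "added": 0.0, "iteration": 0.0},
    {"pattern": "pattern_optimization", "time_in": 0.006072, "removed": 0.0},
    {"pattern": "remove_unused", "time_in": 0.000360, "removed": 0.0},
]

df = pandas.DataFrame(stats)

# Sum time and node counts per step; missing values are skipped by sum().
summary = (
    df.groupby("pattern")[["time_in", "removed", "added"]]
    .sum()
    .sort_values("time_in", ascending=False)
)
print(summary)
```

Sorting by `time_in` puts the most expensive step first, which makes it easier to see where the optimizer spends its time when no pattern matches, as in this run.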


Total running time of the script: (0 minutes 0.832 seconds)

Related examples

101: Onnx Model Rewriting

201: Evaluate different ways to export a torch model to ONNX

102: Fuse kernels in a small Llama Model
