101: Graph Optimization

This example shows how to optimize a graph using pattern optimization. The graph was obtained by running a dummy llama model. It is the backward graph.

A model

import os
import onnx
import pandas
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.xbuilder.graph_builder import (
    GraphBuilder,
    OptimizationOptions,
)

filename = (
    os.path.join(os.path.dirname(__file__), "data", "dort-c-custom__1.onnx")
    if "__file__" in globals()
    else "data/dort-c-custom__1.onnx"
)
proto = onnx.load(filename)

print(f"number of nodes: {len(proto.graph.node)}")


print(onnx_simple_text_plot(proto))
number of nodes: 215
opset: domain='' version=18
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='init1_s_' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_10' type=dtype('float32') shape=() -- array([1024.], dtype=float32)
init: name='init1_s_11' type=dtype('float32') shape=() -- array([2.], dtype=float32)
init: name='init1_s_2' type=dtype('float32') shape=() -- array([1024.], dtype=float32)
init: name='init1_s_3' type=dtype('float32') shape=() -- array([2.], dtype=float32)
init: name='init1_s_4' type=dtype('float32') shape=() -- array([1.], dtype=float32)
init: name='init1_s_5' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_6' type=dtype('float32') shape=() -- array([1024.], dtype=float32)
init: name='init1_s_7' type=dtype('float32') shape=() -- array([2.], dtype=float32)
init: name='init1_s_8' type=dtype('float32') shape=() -- array([22.627417], dtype=float32)
init: name='init1_s_9' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init7_s1_-1' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_-12' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_-13' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init7_s1_02' type=dtype('int64') shape=(1,) -- array([0])
init: name='init7_s1_1024' type=dtype('int64') shape=(1,) -- array([1024])
init: name='init7_s1_10242' type=dtype('int64') shape=(1,) -- array([1024])
init: name='init7_s1_10243' type=dtype('int64') shape=(1,) -- array([1024])
init: name='init7_s1_2' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s1_22' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s1_23' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s1_256' type=dtype('int64') shape=(1,) -- array([256])
init: name='init7_s1_2562' type=dtype('int64') shape=(1,) -- array([256])
init: name='init7_s1_2563' type=dtype('int64') shape=(1,) -- array([256])
init: name='init7_s1_2564' type=dtype('int64') shape=(1,) -- array([256])
init: name='init7_s1_3' type=dtype('int64') shape=(1,) -- array([3])
init: name='init7_s1_32' type=dtype('int64') shape=(1,) -- array([3])
init: name='init7_s1_33' type=dtype('int64') shape=(1,) -- array([3])
init: name='init7_s1_34' type=dtype('int64') shape=(1,) -- array([3])
init: name='init7_s1_512' type=dtype('int64') shape=(1,) -- array([512])
init: name='init7_s1_5122' type=dtype('int64') shape=(1,) -- array([512])
init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_0_12' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_0_13' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_1024_10242' type=dtype('int64') shape=(2,) -- array([1024, 1024])
init: name='init7_s2_2048_1024' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10242' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10243' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10244' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10245' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10246' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10247' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s3_2_1024_1024' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102410' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102411' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102412' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102413' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102414' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_102415' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10242' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10243' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10245' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10246' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10247' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10248' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10249' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_4_1024_1024' type=dtype('int64') shape=(3,) -- array([   4, 1024, 1024])
init: name='init7_s3_4_1024_512' type=dtype('int64') shape=(3,) -- array([   4, 1024,  512])
init: name='init7_s4_2_1024_2_512' type=dtype('int64') shape=(4,) -- array([   2, 1024,    2,  512])
init: name='init7_s4_2_2_1024_1024' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024, 1024])
init: name='init7_s4_2_2_1024_256' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2562' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2563' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_2564' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  256])
init: name='init7_s4_2_2_1024_512' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  512])
init: name='init7_s4_2_2_1024_5122' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  512])
init: name='init7_s4_2_2_512_1024' type=dtype('int64') shape=(4,) -- array([   2,    2,  512, 1024])
init: name='init7_s_-1' type=dtype('int64') shape=() -- array([-1])
Constant(value_float=0) -> output_11
Mul(input37, input2) -> _onx_mul0
  Cast(_onx_mul0, to=1) -> mul_13
    Mul(mul_13, input34) -> _onx_mul03
      Cast(_onx_mul03, to=1) -> mul_15
        ReduceSum(mul_15, init7_s1_2, keepdims=1) -> sum_2
          Mul(sum_2, init1_s_) -> _onx_mul05
            Cast(_onx_mul05, to=1) -> mul_17
Mul(input37, input36) -> _onx_mul02
  Cast(_onx_mul02, to=1) -> mul_14
    ReduceSum(mul_14, init7_s2_0_1, keepdims=1) -> sum_1
      Reshape(sum_1, init7_s1_1024) -> output_2
    Mul(mul_13, input35) -> _onx_mul04
      Cast(_onx_mul04, to=1) -> mul_16
Pow(input35, init1_s1_) -> pow_4
  Mul(mul_17, pow_4) -> _onx_mul06
    Cast(_onx_mul06, to=1) -> mul_18
      Expand(mul_18, init7_s3_2_1024_1024) -> expand_5
        Div(expand_5, init1_s_2) -> _onx_div0
          Cast(_onx_div0, to=1) -> div_1
Mul(input34, init1_s_3) -> _onx_mul07
  Cast(_onx_mul07, to=1) -> mul_19
    Mul(div_1, mul_19) -> _onx_mul08
      Cast(_onx_mul08, to=1) -> mul_20
        Add(mul_16, mul_20) -> add_8
          Reshape(add_8, init7_s2_2048_1024) -> view_22
            Transpose(view_22, perm=[1,0]) -> t_8
              MatMul(t_8, input33) -> mm_8
                Transpose(mm_8, perm=[1,0]) -> t_9
                  Transpose(t_9, perm=[1,0]) -> output_10
Transpose(input32, perm=[1,0]) -> t_10
  MatMul(view_22, t_10) -> mm_9
    Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
      Mul(view_23, input28) -> _onx_mul09
        Cast(_onx_mul09, to=1) -> mul_21
          Reshape(mul_21, init7_s2_2048_10242) -> view_24
            Transpose(view_24, perm=[1,0]) -> t_12
              MatMul(t_12, input30) -> mm_10
                Transpose(mm_10, perm=[1,0]) -> t_13
                  Transpose(t_13, perm=[1,0]) -> output_9
      Mul(view_23, input31) -> _onx_mul010
        Cast(_onx_mul010, to=1) -> mul_22
Transpose(input29, perm=[1,0]) -> t_14
  MatMul(view_24, t_14) -> mm_11
    Reshape(mm_11, init7_s3_2_1024_10243) -> view_25
Sigmoid(input27) -> sigmoid
ConstantOfShape(init7_s3_2_1024_10245, value=[1.0]) -> fill
  Sub(fill, sigmoid) -> sub
    Mul(input27, sub) -> _onx_mul011
      Cast(_onx_mul011, to=1) -> mul_23
        Add(mul_23, init1_s_4) -> add_9
  Mul(sigmoid, add_9) -> _onx_mul012
    Cast(_onx_mul012, to=1) -> mul_24
      Mul(mul_22, mul_24) -> _onx_mul013
        Cast(_onx_mul013, to=1) -> mul_25
          Reshape(mul_25, init7_s2_2048_10243) -> view_26
            Transpose(view_26, perm=[1,0]) -> t_16
              MatMul(t_16, input26) -> mm_12
                Transpose(mm_12, perm=[1,0]) -> t_17
                  Transpose(t_17, perm=[1,0]) -> output_8
Transpose(input25, perm=[1,0]) -> t_18
  MatMul(view_26, t_18) -> mm_13
    Reshape(mm_13, init7_s3_2_1024_10246) -> view_27
      Add(view_25, view_27) -> add_10
        Mul(add_10, input1) -> _onx_mul014
          Cast(_onx_mul014, to=1) -> mul_26
            Mul(mul_26, input22) -> _onx_mul016
              Cast(_onx_mul016, to=1) -> mul_28
                ReduceSum(mul_28, init7_s1_22, keepdims=1) -> sum_4
                  Mul(sum_4, init1_s_5) -> _onx_mul018
                    Cast(_onx_mul018, to=1) -> mul_30
        Mul(add_10, input24) -> _onx_mul015
          Cast(_onx_mul015, to=1) -> mul_27
            ReduceSum(mul_27, init7_s2_0_12, keepdims=1) -> sum_3
              Reshape(sum_3, init7_s1_10242) -> output_1
            Mul(mul_26, input23) -> _onx_mul017
              Cast(_onx_mul017, to=1) -> mul_29
          Add(add_8, mul_29) -> add_11
Pow(input23, init1_s1_2) -> pow_6
  Mul(mul_30, pow_6) -> _onx_mul019
    Cast(_onx_mul019, to=1) -> mul_31
      Expand(mul_31, init7_s3_2_1024_10247) -> expand_6
        Div(expand_6, init1_s_6) -> _onx_div02
          Cast(_onx_div02, to=1) -> div_2
Mul(input22, init1_s_7) -> _onx_mul020
  Cast(_onx_mul020, to=1) -> mul_32
    Mul(div_2, mul_32) -> _onx_mul021
      Cast(_onx_mul021, to=1) -> mul_33
        Add(add_11, mul_33) -> add_12
          Reshape(add_12, init7_s2_2048_10244) -> view_29
            Transpose(view_29, perm=[1,0]) -> t_20
              MatMul(t_20, input21) -> mm_14
                Transpose(mm_14, perm=[1,0]) -> t_21
                  Transpose(t_21, perm=[1,0]) -> output_7
Transpose(input20, perm=[1,0]) -> t_22
  MatMul(view_29, t_22) -> mm_15
    Reshape(mm_15, init7_s3_2_1024_10248) -> view_30
      Reshape(view_30, init7_s4_2_1024_2_512) -> view_31
        Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
          Reshape(transpose_5, init7_s3_4_1024_512) -> _unsafe_view_3
Transpose(input18, perm=[0,2,1]) -> transpose_6
  MatMul(transpose_6, _unsafe_view_3) -> bmm_2
    Reshape(bmm_2, init7_s4_2_2_1024_512) -> view_32
      Add(input39, view_32) -> add_13
        Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
          Reshape(transpose_11, init7_s3_2_1024_10249) -> _unsafe_view_4
            Reshape(_unsafe_view_4, init7_s2_2048_10245) -> view_37
              Transpose(view_37, perm=[1,0]) -> t_24
                MatMul(t_24, input12) -> mm_16
                  Transpose(mm_16, perm=[1,0]) -> t_25
                    Transpose(t_25, perm=[1,0]) -> output_6
Transpose(input19, perm=[0,2,1]) -> transpose_7
  MatMul(_unsafe_view_3, transpose_7) -> bmm_3
    Reshape(bmm_3, init7_s4_2_2_1024_1024) -> view_33
      Cast(view_33, to=1) -> _onx_cast0
        Mul(_onx_cast0, input17) -> _onx_mul022
          ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
            Mul(input17, _onx_reducesum0) -> _onx_mul023
          Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
            Div(_softmax_backward_data, init1_s_8) -> div_3
              Reshape(div_3, init7_s3_4_1024_1024) -> view_34
Transpose(input15, perm=[0,2,1]) -> transpose_8
  MatMul(transpose_8, view_34) -> bmm_4
    Reshape(bmm_4, init7_s4_2_2_512_1024) -> view_35
      Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
        Add(input38, transpose_10) -> add_14
          Mul(add_14, input14) -> _onx_mul024
            Cast(_onx_mul024, to=1) -> mul_34
              Slice(mul_34, init7_s1_0, init7_s1_256, init7_s1_3) -> slice_10
                Neg(slice_10) -> neg_2
Transpose(input16, perm=[0,2,1]) -> transpose_9
  MatMul(view_34, transpose_9) -> bmm_5
    Reshape(bmm_5, init7_s4_2_2_1024_5122) -> view_36
      Mul(view_36, input14) -> _onx_mul026
        Cast(_onx_mul026, to=1) -> mul_36
          Slice(mul_36, init7_s1_02, init7_s1_2563, init7_s1_33) -> slice_12
            Neg(slice_12) -> neg_3
Slice(mul_34, init7_s1_2562, init7_s1_512, init7_s1_32) -> slice_11
ConstantOfShape(init7_s4_2_2_1024_256, value=[0.0]) -> _onx_constantofshape0
  Concat(_onx_constantofshape0, neg_2, axis=3) -> _onx_concat0
ConstantOfShape(init7_s4_2_2_1024_2562, value=[0.0]) -> _onx_constantofshape02
  Concat(slice_11, _onx_constantofshape02, axis=3) -> _onx_concat02
    Add(_onx_concat0, _onx_concat02) -> add_15
Mul(add_14, input13) -> _onx_mul025
  Cast(_onx_mul025, to=1) -> mul_35
    Add(add_15, mul_35) -> add_16
      Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
        Reshape(transpose_12, init7_s3_2_1024_102410) -> _unsafe_view_5
          Reshape(_unsafe_view_5, init7_s2_2048_10246) -> view_39
            Transpose(view_39, perm=[1,0]) -> t_28
              MatMul(t_28, input10) -> mm_18
                Transpose(mm_18, perm=[1,0]) -> t_29
                  Transpose(t_29, perm=[1,0]) -> output_5
          Slice(mul_36, init7_s1_2564, init7_s1_5122, init7_s1_34) -> slice_13
ConstantOfShape(init7_s4_2_2_1024_2563, value=[0.0]) -> _onx_constantofshape03
  Concat(_onx_constantofshape03, neg_3, axis=3) -> _onx_concat03
ConstantOfShape(init7_s4_2_2_1024_2564, value=[0.0]) -> _onx_constantofshape04
  Concat(slice_13, _onx_constantofshape04, axis=3) -> _onx_concat04
    Add(_onx_concat03, _onx_concat04) -> add_17
Mul(view_36, input13) -> _onx_mul027
  Cast(_onx_mul027, to=1) -> mul_37
    Add(add_17, mul_37) -> add_18
      Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
        Reshape(transpose_13, init7_s3_2_1024_102411) -> _unsafe_view_6
          Reshape(_unsafe_view_6, init7_s2_2048_10247) -> view_41
            Transpose(view_41, perm=[1,0]) -> t_32
              MatMul(t_32, input8) -> mm_20
                Transpose(mm_20, perm=[1,0]) -> t_33
                  Transpose(t_33, perm=[1,0]) -> output_4
Transpose(input11, perm=[1,0]) -> t_26
  MatMul(view_37, t_26) -> mm_17
    Reshape(mm_17, init7_s3_2_1024_102412) -> view_38
Transpose(input9, perm=[1,0]) -> t_30
  MatMul(view_39, t_30) -> mm_19
    Reshape(mm_19, init7_s3_2_1024_102413) -> view_40
      Add(view_38, view_40) -> add_19
Transpose(input7, perm=[1,0]) -> t_34
  MatMul(view_41, t_34) -> mm_21
    Reshape(mm_21, init7_s3_2_1024_102414) -> view_42
      Add(add_19, view_42) -> add_20
        Mul(add_20, input0) -> _onx_mul028
          Cast(_onx_mul028, to=1) -> mul_38
            Mul(mul_38, input4) -> _onx_mul030
              Cast(_onx_mul030, to=1) -> mul_40
                ReduceSum(mul_40, init7_s1_23, keepdims=1) -> sum_6
                  Mul(sum_6, init1_s_9) -> _onx_mul032
                    Cast(_onx_mul032, to=1) -> mul_42
        Mul(add_20, input6) -> _onx_mul029
          Cast(_onx_mul029, to=1) -> mul_39
            ReduceSum(mul_39, init7_s2_0_13, keepdims=1) -> sum_5
              Reshape(sum_5, init7_s1_10243) -> output_0
            Mul(mul_38, input5) -> _onx_mul031
              Cast(_onx_mul031, to=1) -> mul_41
          Add(add_12, mul_41) -> add_21
Pow(input5, init1_s1_3) -> pow_8
  Mul(mul_42, pow_8) -> _onx_mul033
    Cast(_onx_mul033, to=1) -> mul_43
      Expand(mul_43, init7_s3_2_1024_102415) -> expand_7
        Div(expand_7, init1_s_10) -> _onx_div03
          Cast(_onx_div03, to=1) -> div_4
Mul(input4, init1_s_11) -> _onx_mul034
  Cast(_onx_mul034, to=1) -> mul_44
    Mul(div_4, mul_44) -> _onx_mul035
      Cast(_onx_mul035, to=1) -> mul_45
        Add(add_21, mul_45) -> add_22
Equal(input3, init7_s_-1) -> eq_2
  Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
    Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
  ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> _onx_scatternd0
    Identity(_onx_scatternd0) -> output_3
Constant(value_float=0) -> output_12
Constant(value_float=0) -> output_13
Constant(value_float=0) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None

And visually.

plot optimize 101
<Axes: >

Optimization

gr = GraphBuilder(
    proto,
    infer_shapes=True,
    optimization_options=OptimizationOptions(
        patterns="default",
        verbose=1,  # a higher value increases the verbosity when optimizations for patterns
    ),
)
stats = gr.optimize()
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
df.to_excel("plot_optimize.xlsx")
df
[GraphBuilderPatternOptimization.optimize] start with 214 nodes and 20 patterns
[GraphBuilderPatternOptimization.optimize] use pattern 1/20 - CastPattern
[GraphBuilderPatternOptimization.optimize] use pattern 2/20 - CastCastBinaryPattern
[GraphBuilderPatternOptimization.optimize] use pattern 3/20 - ExpandPattern
[GraphBuilderPatternOptimization.optimize] use pattern 4/20 - ExpandBroadcastPattern
[GraphBuilderPatternOptimization.optimize] use pattern 5/20 - ExpandSwapPattern
[GraphBuilderPatternOptimization.optimize] use pattern 6/20 - MulMulMulScalarPattern
[GraphBuilderPatternOptimization.optimize] use pattern 7/20 - ReduceReshapePattern
[GraphBuilderPatternOptimization.optimize] use pattern 8/20 - ReshapeMatMulReshapePattern
[GraphBuilderPatternOptimization.optimize] use pattern 9/20 - Reshape2Of3Pattern
[GraphBuilderPatternOptimization.optimize] use pattern 10/20 - ReshapeReshapeBinaryPattern
[GraphBuilderPatternOptimization.optimize] use pattern 11/20 - MatMulReshape2Of3Pattern
[GraphBuilderPatternOptimization.optimize] use pattern 12/20 - ReshapeReshapePattern
[GraphBuilderPatternOptimization.optimize] use pattern 13/20 - RotaryConcatPartPattern
[GraphBuilderPatternOptimization.optimize] use pattern 14/20 - SlicesSplitPattern
[GraphBuilderPatternOptimization.optimize] use pattern 15/20 - Sub1MulPattern
[GraphBuilderPatternOptimization.optimize] use pattern 16/20 - SwitchOrderBinaryPattern
[GraphBuilderPatternOptimization.optimize] use pattern 17/20 - TransposeMatMulPattern
[GraphBuilderPatternOptimization.optimize] use pattern 18/20 - TransposeReshapeMatMulPattern
[GraphBuilderPatternOptimization.optimize] use pattern 19/20 - TransposeTransposePattern
[GraphBuilderPatternOptimization.optimize] use pattern 20/20 - UnsqueezeUnsqueezePattern
[GraphBuilderPatternOptimization.optimize] iteration 0: 214 nodes
[GraphBuilderPatternOptimization.optimize] applies 55 matches, 37*CastPattern, 3*ReduceReshapePattern, 1*Reshape2Of3Pattern, 1*ReshapeReshapeBinaryPattern, 4*ReshapeReshapePattern, 2*SlicesSplitPattern, 7*TransposeTransposePattern - time=0.006 | max_time=Sub1MulPattern:0.002
[GraphBuilderPatternOptimization.optimize] iteration 1: 152 nodes
[GraphBuilderPatternOptimization.optimize] applies 5 matches, 3*MulMulMulScalarPattern, 1*Reshape2Of3Pattern, 1*ReshapeReshapeBinaryPattern - time=0.003 | max_time=Reshape2Of3Pattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 2: 148 nodes
[GraphBuilderPatternOptimization.optimize] applies 4 matches, 3*ExpandBroadcastPattern, 1*MatMulReshape2Of3Pattern - time=0.003 | max_time=Reshape2Of3Pattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 3: 145 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MatMulReshape2Of3Pattern replaces ['Reshape', 'MatMul', 'Reshape'] - time=0.003 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 4: 144 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MatMulReshape2Of3Pattern replaces ['Reshape', 'MatMul', 'Reshape'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 5: 144 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MatMulReshape2Of3Pattern replaces ['Reshape', 'MatMul', 'Reshape'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 6: 143 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: RotaryConcatPartPattern replaces ['ConstantOfShape', 'Split', 'Neg', 'Concat', 'ConstantOfShape', 'Concat', 'Add'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 7: 140 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: RotaryConcatPartPattern replaces ['ConstantOfShape', 'Split', 'Neg', 'Concat', 'ConstantOfShape', 'Concat', 'Add'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 8: 137 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: Sub1MulPattern replaces ['Mul', 'Sub'] - time=0.002 | max_time=Reshape2Of3Pattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 9: 137 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: SwitchOrderBinaryPattern replaces ['Mul', 'Mul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 10: 137 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: SwitchOrderBinaryPattern replaces ['Mul', 'Mul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 11: 137 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: SwitchOrderBinaryPattern replaces ['Mul', 'Mul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 12: 137 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 13: 136 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 14: 135 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 15: 134 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 16: 133 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 17: 132 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 18: 131 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=Reshape2Of3Pattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 19: 130 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 20: 129 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 21: 128 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 22: 127 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.003 | max_time=ReshapeReshapePattern:0.001
[GraphBuilderPatternOptimization.optimize] iteration 23: 126 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 24: 125 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 25: 124 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeMatMulPattern replaces ['Transpose', 'MatMul'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 26: 123 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeReshapeMatMulPattern replaces ['MatMul', 'Reshape', 'Transpose'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 27: 123 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeReshapeMatMulPattern replaces ['MatMul', 'Reshape', 'Transpose'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 28: 123 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeReshapeMatMulPattern replaces ['MatMul', 'Reshape', 'Transpose'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 29: 123 nodes
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: TransposeReshapeMatMulPattern replaces ['MatMul', 'Reshape', 'Transpose'] - time=0.002 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 30: 123 nodes
[GraphBuilderPatternOptimization.optimize] done after 31 iterations with 123 nodes in 0.149
pattern time_in removed iteration instances match_index added
0 check_A 0.000369 NaN NaN NaN NaN NaN
1 remove_identity_nodes 0.000442 0.0 NaN NaN NaN NaN
2 check_B 0.000318 NaN NaN NaN NaN NaN
3 remove_unused 0.000853 0.0 NaN NaN NaN NaN
4 check_C 0.000324 NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ...
899 build_for_pattern 0.000329 NaN 30.0 NaN NaN NaN
900 pattern_optimization 0.150636 91.0 NaN NaN NaN NaN
901 check_F 0.000180 NaN NaN NaN NaN NaN
902 remove_unused 0.000414 3.0 NaN NaN NaN NaN
903 check_G 0.000170 NaN NaN NaN NaN NaN

904 rows × 7 columns



Summary

for c in df.columns:
    if "time" not in c and "pattern" not in c:
        df[c] = df[c].fillna(0).astype(int)

aggs = {
    "time_in": "sum",
    "added": "sum",
    "removed": "sum",
    "iteration": "max",
    "match_index": "max",
    "instances": "sum",
}
print(df.groupby("pattern").agg(aggs))
                                      time_in  ...  instances
pattern                                        ...
apply_CastPattern                    0.003240  ...         37
apply_ExpandBroadcastPattern         0.000341  ...          3
apply_MatMulReshape2Of3Pattern       0.001726  ...          4
apply_MulMulMulScalarPattern         0.001253  ...          3
apply_ReduceReshapePattern           0.000350  ...          3
apply_Reshape2Of3Pattern             0.000820  ...          2
apply_ReshapeReshapeBinaryPattern    0.000237  ...          2
apply_ReshapeReshapePattern          0.000430  ...          4
apply_RotaryConcatPartPattern        0.003891  ...          2
apply_SlicesSplitPattern             0.000887  ...          2
apply_Sub1MulPattern                 0.000349  ...          1
apply_SwitchOrderBinaryPattern       0.000420  ...          3
apply_TransposeMatMulPattern         0.001636  ...         14
apply_TransposeReshapeMatMulPattern  0.000770  ...          4
apply_TransposeTransposePattern      0.002484  ...          7
build_for_pattern                    0.011141  ...          0
check_A                              0.000369  ...          0
check_B                              0.000318  ...          0
check_C                              0.000324  ...          0
check_F                              0.000180  ...          0
check_G                              0.000170  ...          0
check_pattern_A                      0.024054  ...          0
check_pattern_B                      0.005938  ...          0
match_CastCastBinaryPattern          0.005830  ...          0
match_CastPattern                    0.001434  ...         37
match_ExpandBroadcastPattern         0.001290  ...          3
match_ExpandPattern                  0.001465  ...          0
match_ExpandSwapPattern              0.001215  ...          0
match_MatMulReshape2Of3Pattern       0.005099  ...          4
match_MulMulMulScalarPattern         0.003237  ...          3
match_ReduceReshapePattern           0.002717  ...          3
match_Reshape2Of3Pattern             0.010424  ...          2
match_ReshapeMatMulReshapePattern    0.001981  ...          0
match_ReshapeReshapeBinaryPattern    0.004620  ...          2
match_ReshapeReshapePattern          0.003068  ...          4
match_RotaryConcatPartPattern        0.003231  ...          2
match_SlicesSplitPattern             0.001529  ...          2
match_Sub1MulPattern                 0.004499  ...          1
match_SwitchOrderBinaryPattern       0.011525  ...          3
match_TransposeMatMulPattern         0.004925  ...         14
match_TransposeReshapeMatMulPattern  0.003558  ...          4
match_TransposeTransposePattern      0.002230  ...          7
match_UnsqueezeUnsqueezePattern      0.001330  ...          0
pattern_optimization                 0.150636  ...          0
remove_identity_nodes                0.008686  ...          0
remove_unused                        0.001267  ...          0

[46 rows x 6 columns]

The total is:

diff = df["added"].sum() - df["removed"].sum()

print(f"number of removed nodes: {-diff}")
number of removed nodes: 185

Conversion to onnx.

optimized_proto = gr.to_onnx(optimize=False)
with open("plot_optimize_101.onnx", "wb") as f:
    f.write(optimized_proto.SerializeToString())

print(f"number of new nodes: {len(optimized_proto.graph.node)}")
number of new nodes: 120

It gives the following.

opset: domain='' version=18
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_2' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_3' type=dtype('float32') shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_4' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='init1_s_' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_4' type=dtype('float32') shape=() -- array([1.], dtype=float32)
init: name='init1_s_5' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_8' type=dtype('float32') shape=() -- array([22.627417], dtype=float32)
init: name='init1_s_9' type=dtype('float32') shape=() -- array([-0.5], dtype=float32)
init: name='init7_s1_-1' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_-12' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_-13' type=dtype('int64') shape=(1,) -- array([-1])
init: name='init7_s1_2' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s1_22' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s1_23' type=dtype('int64') shape=(1,) -- array([2])
init: name='init7_s2_0_1' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_0_12' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_0_13' type=dtype('int64') shape=(2,) -- array([0, 1])
init: name='init7_s2_1024_10242' type=dtype('int64') shape=(2,) -- array([1024, 1024])
init: name='init7_s2_2048_1024' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10242' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10243' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10244' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10245' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10246' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10247' type=dtype('int64') shape=(2,) -- array([2048, 1024])
init: name='init7_s3_2_1024_102413' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10242' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s3_2_1024_10243' type=dtype('int64') shape=(3,) -- array([   2, 1024, 1024])
init: name='init7_s4_2_1024_2_512' type=dtype('int64') shape=(4,) -- array([   2, 1024,    2,  512])
init: name='init7_s4_2_2_1024_1024' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024, 1024])
init: name='init7_s4_2_2_1024_512' type=dtype('int64') shape=(4,) -- array([   2,    2, 1024,  512])
init: name='init7_s4_2_2_512_1024' type=dtype('int64') shape=(4,) -- array([   2,    2,  512, 1024])
init: name='init7_s_-1' type=dtype('int64') shape=() -- array([-1])
init: name='init7_s2_256_256' type=dtype('int64') shape=(2,) -- array([256, 256])
init: name='init1_s_12' type=dtype('float32') shape=() -- array([0.00195312], dtype=float32)
init: name='init1_s_13' type=dtype('float32') shape=() -- array([0.00195312], dtype=float32)
init: name='init1_s_14' type=dtype('float32') shape=() -- array([0.00195312], dtype=float32)
Constant(value_float=0) -> output_11
  Identity(output_11) -> output_12
Reshape(input28, init7_s2_2048_10242) -> typeR_input28
Mul(input37, input2) -> _onx_mul0
  Mul(_onx_mul0, input34) -> _onx_mul03
    ReduceSum(_onx_mul03, init7_s1_2, keepdims=1) -> sum_2
      Mul(sum_2, init1_s_) -> _onx_mul05
Mul(input37, input36) -> _onx_mul02
  ReduceSum(_onx_mul02, init7_s2_0_1, keepdims=0) -> output_2
Mul(_onx_mul0, input35) -> _onx_mul04
Pow(input35, init1_s1_) -> pow_4
  Mul(_onx_mul05, pow_4) -> _onx_mul06
    Mul(_onx_mul06, init1_s_12) -> mul-_onx_mul06
      Mul(mul-_onx_mul06, input34) -> _onx_mul08
  Add(_onx_mul04, _onx_mul08) -> add_8
    Reshape(add_8, init7_s2_2048_1024) -> view_22
      Gemm(view_22, input33, transA=1, transB=0) -> output_10
Gemm(view_22, input32, transA=0, transB=1) -> mm_9
  Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
    Mul(view_23, input31) -> _onx_mul010
  Mul(mm_9, typeR_input28) -> view_24
    Gemm(view_24, input30, transA=1, transB=0) -> output_9
Gemm(view_24, input29, transA=0, transB=1) -> mm_11
Sigmoid(input27) -> sigmoid
  Mul(input27, sigmoid) -> type--_onx_mul011
    Sub(input27, type--_onx_mul011) -> _onx_mul011
      Add(_onx_mul011, init1_s_4) -> add_9
  Mul(sigmoid, add_9) -> _onx_mul012
    Mul(_onx_mul010, _onx_mul012) -> _onx_mul013
      Reshape(_onx_mul013, init7_s2_2048_10243) -> view_26
        Gemm(view_26, input26, transA=1, transB=0) -> output_8
Gemm(view_26, input25, transA=0, transB=1) -> mm_13
  Add(mm_11, mm_13) -> add-mm_11
    Reshape(add-mm_11, init7_s3_2_1024_10243) -> add_10
      Mul(add_10, input1) -> _onx_mul014
        Mul(_onx_mul014, input22) -> _onx_mul016
          ReduceSum(_onx_mul016, init7_s1_22, keepdims=1) -> sum_4
            Mul(sum_4, init1_s_5) -> _onx_mul018
      Mul(add_10, input24) -> _onx_mul015
        ReduceSum(_onx_mul015, init7_s2_0_12, keepdims=0) -> output_1
Mul(_onx_mul014, input23) -> _onx_mul017
  Add(add_8, _onx_mul017) -> add_11
Pow(input23, init1_s1_2) -> pow_6
  Mul(_onx_mul018, pow_6) -> _onx_mul019
    Mul(_onx_mul019, init1_s_13) -> mul-_onx_mul019
      Mul(mul-_onx_mul019, input22) -> _onx_mul021
    Add(add_11, _onx_mul021) -> add_12
      Reshape(add_12, init7_s2_2048_10244) -> view_29
        Gemm(view_29, input21, transA=1, transB=0) -> output_7
Gemm(view_29, input20, transA=0, transB=1) -> mm_15
  Reshape(mm_15, init7_s4_2_1024_2_512) -> view_31
    Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
Reshape(input18, init7_s4_2_2_1024_1024) -> typeL_input18
  Transpose(typeL_input18, perm=[0,1,3,2]) -> typeL_transpose_6
    MatMul(typeL_transpose_6, transpose_5) -> view_32
      Add(input39, view_32) -> add_13
        Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
          Reshape(transpose_11, init7_s2_2048_10245) -> view_37
            Gemm(view_37, input12, transA=1, transB=0) -> output_6
Reshape(input19, init7_s4_2_2_1024_512) -> typeL_input19
  Transpose(typeL_input19, perm=[0,1,3,2]) -> typeL__unsafe_view_3
    MatMul(transpose_5, typeL__unsafe_view_3) -> view_33
      Mul(view_33, input17) -> _onx_mul022
        ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
          Mul(input17, _onx_reducesum0) -> _onx_mul023
        Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
          Div(_softmax_backward_data, init1_s_8) -> div_3
Reshape(input15, init7_s4_2_2_1024_512) -> typeL_input15
  Transpose(typeL_input15, perm=[0,1,3,2]) -> typeL_transpose_8
    MatMul(typeL_transpose_8, div_3) -> view_35
      Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
        Add(input38, transpose_10) -> add_14
          Mul(add_14, input14) -> _onx_mul024
            Split(_onx_mul024, init7_s2_256_256, axis=3) -> slice_10, slice_11
              Neg(slice_10) -> neg_2
              Concat(slice_11, neg_2, axis=3) -> add_15
Reshape(input16, init7_s4_2_2_512_1024) -> typeL_input16
  Transpose(typeL_input16, perm=[0,1,3,2]) -> typeL_view_34
    MatMul(div_3, typeL_view_34) -> view_36
      Mul(view_36, input14) -> _onx_mul026
        Split(_onx_mul026, init7_s2_256_256, axis=3) -> slice_12, slice_13
          Neg(slice_12) -> neg_3
          Concat(slice_13, neg_3, axis=3) -> add_17
Mul(add_14, input13) -> _onx_mul025
  Add(add_15, _onx_mul025) -> add_16
    Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
      Reshape(transpose_12, init7_s2_2048_10246) -> view_39
        Gemm(view_39, input10, transA=1, transB=0) -> output_5
      Mul(view_36, input13) -> _onx_mul027
        Add(add_17, _onx_mul027) -> add_18
          Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
            Reshape(transpose_13, init7_s2_2048_10247) -> view_41
              Gemm(view_41, input8, transA=1, transB=0) -> output_4
            Gemm(view_37, input11, transA=0, transB=1) -> mm_17
        Gemm(view_39, input9, transA=0, transB=1) -> mm_19
          Add(mm_17, mm_19) -> typeL_add_19
Gemm(view_41, input7, transA=0, transB=1) -> mm_21
  Add(typeL_add_19, mm_21) -> add-typeL_add_19
    Reshape(add-typeL_add_19, init7_s3_2_1024_102413) -> add_20
      Mul(add_20, input0) -> _onx_mul028
        Mul(_onx_mul028, input4) -> _onx_mul030
          ReduceSum(_onx_mul030, init7_s1_23, keepdims=1) -> sum_6
            Mul(sum_6, init1_s_9) -> _onx_mul032
      Mul(add_20, input6) -> _onx_mul029
        ReduceSum(_onx_mul029, init7_s2_0_13, keepdims=0) -> output_0
Mul(_onx_mul028, input5) -> _onx_mul031
  Add(add_12, _onx_mul031) -> add_21
Pow(input5, init1_s1_3) -> pow_8
  Mul(_onx_mul032, pow_8) -> _onx_mul033
    Mul(_onx_mul033, init1_s_14) -> mul-_onx_mul033
      Mul(mul-_onx_mul033, input4) -> _onx_mul035
    Add(add_21, _onx_mul035) -> add_22
Equal(input3, init7_s_-1) -> eq_2
  Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
    Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
  ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> output_3
Identity(output_11) -> output_13
Identity(output_11) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None

And visually.

plot optimize 101
<Axes: >

The first list of patterns optimizes the graph with only standard onnx operators: Onnx (default) Patterns. The second list is specific to onnxruntime: Ort Patterns.

Focus on one optimizer

gr = GraphBuilder(
    optimized_proto,
    infer_shapes=True,
    optimization_options=OptimizationOptions(
        patterns="SwitchOrderBinary",
        verbose=10,
    ),
)
stats = gr.optimize()
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
df.to_excel("plot_optimize.xlsx")
df
[GraphBuilderPatternOptimization.optimize] start with 120 nodes and 1 patterns
[GraphBuilderPatternOptimization.optimize] use pattern 1/1 - SwitchOrderBinaryPattern
[GraphBuilderPatternOptimization.optimize] iteration 0: 120 nodes
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 203:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul
[SwitchOrderBinaryPattern.match] NONE - line: 192:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add
[GraphBuilderPatternOptimization.optimize] done all: -0 +0 nodes
[GraphBuilderPatternOptimization.optimize] done after 1 iterations with 120 nodes in 0.002
pattern time_in removed iteration instances match_index
0 check_A 0.000195 NaN NaN NaN NaN
1 remove_identity_nodes 0.000240 0.0 NaN NaN NaN
2 check_B 0.000178 NaN NaN NaN NaN
3 remove_unused 0.000430 0.0 NaN NaN NaN
4 check_C 0.000175 NaN NaN NaN NaN
5 match_SwitchOrderBinaryPattern 0.000913 NaN 0.0 0.0 0.0
6 remove_identity_nodes 0.000235 0.0 0.0 NaN NaN
7 check_pattern_B 0.000199 NaN 0.0 NaN NaN
8 build_for_pattern 0.000341 NaN 0.0 NaN NaN
9 pattern_optimization 0.002288 0.0 NaN NaN NaN
10 check_F 0.000181 NaN NaN NaN NaN
11 remove_unused 0.000414 0.0 NaN NaN NaN
12 check_G 0.000175 NaN NaN NaN NaN


Total running time of the script: (0 minutes 3.689 seconds)

Gallery generated by Sphinx-Gallery