101: ONNX Model Optimization based on Pattern Rewriting
This example shows how to optimize an ONNX graph with pattern rewriting. The graph was obtained by running a dummy Llama model and is its backward graph.
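Pattern rewriting repeatedly searches the graph for small subgraphs that match a known pattern and replaces each match with an equivalent, usually smaller, subgraph, until no pattern matches anymore. The optimization log further down shows these iterations. The following is a minimal, self-contained sketch of that loop on a toy node list. It is not how `GraphBuilder` is implemented, and `Node`, `remove_useless_cast` and `optimize` are hypothetical names. Its single rule mimics the kind of rewrite the log reports as CastPattern: removing a `Cast` whose input already has the target element type.

```python
from dataclasses import dataclass, field

FLOAT = 1  # ONNX element type for float32 (TensorProto.FLOAT)


@dataclass
class Node:
    """A toy node: operator type, input names, one output name, attributes."""

    op_type: str
    inputs: list
    output: str
    attrs: dict = field(default_factory=dict)


def remove_useless_cast(nodes, types):
    """Rule: a Cast to the element type its input already has is a no-op.

    Returns a new node list if the rule matched, None otherwise.
    Graph outputs are ignored in this toy version.
    """
    for i, node in enumerate(nodes):
        if node.op_type == "Cast" and types.get(node.inputs[0]) == node.attrs.get("to"):
            src, dst = node.inputs[0], node.output
            # Drop the Cast and make its consumers read the Cast input instead.
            return [
                Node(
                    n.op_type,
                    [src if name == dst else name for name in n.inputs],
                    n.output,
                    n.attrs,
                )
                for j, n in enumerate(nodes)
                if j != i
            ]
    return None


def optimize(nodes, types, rules):
    """Apply the first matching rule, then start again until no rule matches."""
    iteration = 0
    while True:
        for rule in rules:
            rewritten = rule(nodes, types)
            if rewritten is not None:
                nodes = rewritten
                break
        else:
            # No rule matched: a fixed point is reached.
            return nodes, iteration
        iteration += 1


# Mul -> Cast(to=FLOAT) -> Mul -> Cast(to=FLOAT) -> ReduceSum, all tensors float32.
graph = [
    Node("Mul", ["a", "b"], "m1"),
    Node("Cast", ["m1"], "c1", {"to": FLOAT}),
    Node("Mul", ["c1", "d"], "m2"),
    Node("Cast", ["m2"], "c2", {"to": FLOAT}),
    Node("ReduceSum", ["c2"], "out"),
]
types = {"a": FLOAT, "b": FLOAT, "d": FLOAT, "m1": FLOAT, "m2": FLOAT}
optimized, n_iter = optimize(graph, types, [remove_useless_cast])
print([n.op_type for n in optimized], n_iter)  # ['Mul', 'Mul', 'ReduceSum'] 2
```

Applying a single match per iteration keeps the sketch short; the optimizer used in this example applies many matches per iteration (48 in its first one, according to the log).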
A model
import os
import onnx
import pandas
from experimental_experiment.helpers import pretty_onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.xbuilder.graph_builder import (
    GraphBuilder,
    OptimizationOptions,
)

# The model file is looked up next to the script when it runs as a file,
# and relative to the current directory otherwise (e.g. in a notebook).
filename = (
    os.path.join(os.path.dirname(__file__), "data", "dort-c-custom__1.onnx")
    if "__file__" in globals()
    else "data/dort-c-custom__1.onnx"
)
proto = onnx.load(filename)
print(f"number of nodes: {len(proto.graph.node)}")
print(pretty_onnx(proto))
number of nodes: 215
opset: domain='' version=18
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_2' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_3' type=float32 shape=(1,) -- array([3.], dtype=float32)
init: name='init1_s1_4' type=float32 shape=(1,) -- array([0.], dtype=float32)
init: name='init1_s_' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_10' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_11' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_2' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_3' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_4' type=float32 shape=() -- array([1.], dtype=float32)
init: name='init1_s_5' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init1_s_6' type=float32 shape=() -- array([1024.], dtype=float32)
init: name='init1_s_7' type=float32 shape=() -- array([2.], dtype=float32)
init: name='init1_s_8' type=float32 shape=() -- array([22.627417], dtype=float32)
init: name='init1_s_9' type=float32 shape=() -- array([-0.5], dtype=float32)
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_-12' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_-13' type=int64 shape=(1,) -- array([-1])
init: name='init7_s1_0' type=int64 shape=(1,) -- array([0])
init: name='init7_s1_02' type=int64 shape=(1,) -- array([0])
init: name='init7_s1_1024' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_10242' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_10243' type=int64 shape=(1,) -- array([1024])
init: name='init7_s1_2' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_22' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_23' type=int64 shape=(1,) -- array([2])
init: name='init7_s1_256' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2562' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2563' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_2564' type=int64 shape=(1,) -- array([256])
init: name='init7_s1_3' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_32' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_33' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_34' type=int64 shape=(1,) -- array([3])
init: name='init7_s1_512' type=int64 shape=(1,) -- array([512])
init: name='init7_s1_5122' type=int64 shape=(1,) -- array([512])
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_0_12' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_0_13' type=int64 shape=(2,) -- array([0, 1])
init: name='init7_s2_1024_10242' type=int64 shape=(2,) -- array([1024, 1024])
init: name='init7_s2_2048_1024' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10242' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10243' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10244' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10245' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10246' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s2_2048_10247' type=int64 shape=(2,) -- array([2048, 1024])
init: name='init7_s3_2_1024_1024' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102410' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102411' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102412' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102413' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102414' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_102415' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10242' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10243' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10245' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10246' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10247' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10248' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_2_1024_10249' type=int64 shape=(3,) -- array([ 2, 1024, 1024])
init: name='init7_s3_4_1024_1024' type=int64 shape=(3,) -- array([ 4, 1024, 1024])
init: name='init7_s3_4_1024_512' type=int64 shape=(3,) -- array([ 4, 1024, 512])
init: name='init7_s4_2_1024_2_512' type=int64 shape=(4,) -- array([ 2, 1024, 2, 512])
init: name='init7_s4_2_2_1024_1024' type=int64 shape=(4,) -- array([ 2, 2, 1024, 1024])
init: name='init7_s4_2_2_1024_256' type=int64 shape=(4,) -- array([ 2, 2, 1024, 256])
init: name='init7_s4_2_2_1024_2562' type=int64 shape=(4,) -- array([ 2, 2, 1024, 256])
init: name='init7_s4_2_2_1024_2563' type=int64 shape=(4,) -- array([ 2, 2, 1024, 256])
init: name='init7_s4_2_2_1024_2564' type=int64 shape=(4,) -- array([ 2, 2, 1024, 256])
init: name='init7_s4_2_2_1024_512' type=int64 shape=(4,) -- array([ 2, 2, 1024, 512])
init: name='init7_s4_2_2_1024_5122' type=int64 shape=(4,) -- array([ 2, 2, 1024, 512])
init: name='init7_s4_2_2_512_1024' type=int64 shape=(4,) -- array([ 2, 2, 512, 1024])
init: name='init7_s_-1' type=int64 shape=() -- array([-1])
Constant(value_float=0) -> output_11
Mul(input37, input2) -> _onx_mul0
Cast(_onx_mul0, to=1) -> mul_13
Mul(mul_13, input34) -> _onx_mul03
Cast(_onx_mul03, to=1) -> mul_15
ReduceSum(mul_15, init7_s1_2, keepdims=1) -> sum_2
Mul(sum_2, init1_s_) -> _onx_mul05
Cast(_onx_mul05, to=1) -> mul_17
Mul(input37, input36) -> _onx_mul02
Cast(_onx_mul02, to=1) -> mul_14
ReduceSum(mul_14, init7_s2_0_1, keepdims=1) -> sum_1
Reshape(sum_1, init7_s1_1024) -> output_2
Mul(mul_13, input35) -> _onx_mul04
Cast(_onx_mul04, to=1) -> mul_16
Pow(input35, init1_s1_) -> pow_4
Mul(mul_17, pow_4) -> _onx_mul06
Cast(_onx_mul06, to=1) -> mul_18
Expand(mul_18, init7_s3_2_1024_1024) -> expand_5
Div(expand_5, init1_s_2) -> _onx_div0
Cast(_onx_div0, to=1) -> div_1
Mul(input34, init1_s_3) -> _onx_mul07
Cast(_onx_mul07, to=1) -> mul_19
Mul(div_1, mul_19) -> _onx_mul08
Cast(_onx_mul08, to=1) -> mul_20
Add(mul_16, mul_20) -> add_8
Reshape(add_8, init7_s2_2048_1024) -> view_22
Transpose(view_22, perm=[1,0]) -> t_8
MatMul(t_8, input33) -> mm_8
Transpose(mm_8, perm=[1,0]) -> t_9
Transpose(t_9, perm=[1,0]) -> output_10
Transpose(input32, perm=[1,0]) -> t_10
MatMul(view_22, t_10) -> mm_9
Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
Mul(view_23, input28) -> _onx_mul09
Cast(_onx_mul09, to=1) -> mul_21
Reshape(mul_21, init7_s2_2048_10242) -> view_24
Transpose(view_24, perm=[1,0]) -> t_12
MatMul(t_12, input30) -> mm_10
Transpose(mm_10, perm=[1,0]) -> t_13
Transpose(t_13, perm=[1,0]) -> output_9
Mul(view_23, input31) -> _onx_mul010
Cast(_onx_mul010, to=1) -> mul_22
Transpose(input29, perm=[1,0]) -> t_14
MatMul(view_24, t_14) -> mm_11
Reshape(mm_11, init7_s3_2_1024_10243) -> view_25
Sigmoid(input27) -> sigmoid
ConstantOfShape(init7_s3_2_1024_10245, value=[1.0]) -> fill
Sub(fill, sigmoid) -> sub
Mul(input27, sub) -> _onx_mul011
Cast(_onx_mul011, to=1) -> mul_23
Add(mul_23, init1_s_4) -> add_9
Mul(sigmoid, add_9) -> _onx_mul012
Cast(_onx_mul012, to=1) -> mul_24
Mul(mul_22, mul_24) -> _onx_mul013
Cast(_onx_mul013, to=1) -> mul_25
Reshape(mul_25, init7_s2_2048_10243) -> view_26
Transpose(view_26, perm=[1,0]) -> t_16
MatMul(t_16, input26) -> mm_12
Transpose(mm_12, perm=[1,0]) -> t_17
Transpose(t_17, perm=[1,0]) -> output_8
Transpose(input25, perm=[1,0]) -> t_18
MatMul(view_26, t_18) -> mm_13
Reshape(mm_13, init7_s3_2_1024_10246) -> view_27
Add(view_25, view_27) -> add_10
Mul(add_10, input1) -> _onx_mul014
Cast(_onx_mul014, to=1) -> mul_26
Mul(mul_26, input22) -> _onx_mul016
Cast(_onx_mul016, to=1) -> mul_28
ReduceSum(mul_28, init7_s1_22, keepdims=1) -> sum_4
Mul(sum_4, init1_s_5) -> _onx_mul018
Cast(_onx_mul018, to=1) -> mul_30
Mul(add_10, input24) -> _onx_mul015
Cast(_onx_mul015, to=1) -> mul_27
ReduceSum(mul_27, init7_s2_0_12, keepdims=1) -> sum_3
Reshape(sum_3, init7_s1_10242) -> output_1
Mul(mul_26, input23) -> _onx_mul017
Cast(_onx_mul017, to=1) -> mul_29
Add(add_8, mul_29) -> add_11
Pow(input23, init1_s1_2) -> pow_6
Mul(mul_30, pow_6) -> _onx_mul019
Cast(_onx_mul019, to=1) -> mul_31
Expand(mul_31, init7_s3_2_1024_10247) -> expand_6
Div(expand_6, init1_s_6) -> _onx_div02
Cast(_onx_div02, to=1) -> div_2
Mul(input22, init1_s_7) -> _onx_mul020
Cast(_onx_mul020, to=1) -> mul_32
Mul(div_2, mul_32) -> _onx_mul021
Cast(_onx_mul021, to=1) -> mul_33
Add(add_11, mul_33) -> add_12
Reshape(add_12, init7_s2_2048_10244) -> view_29
Transpose(view_29, perm=[1,0]) -> t_20
MatMul(t_20, input21) -> mm_14
Transpose(mm_14, perm=[1,0]) -> t_21
Transpose(t_21, perm=[1,0]) -> output_7
Transpose(input20, perm=[1,0]) -> t_22
MatMul(view_29, t_22) -> mm_15
Reshape(mm_15, init7_s3_2_1024_10248) -> view_30
Reshape(view_30, init7_s4_2_1024_2_512) -> view_31
Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
Reshape(transpose_5, init7_s3_4_1024_512) -> _unsafe_view_3
Transpose(input18, perm=[0,2,1]) -> transpose_6
MatMul(transpose_6, _unsafe_view_3) -> bmm_2
Reshape(bmm_2, init7_s4_2_2_1024_512) -> view_32
Add(input39, view_32) -> add_13
Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
Reshape(transpose_11, init7_s3_2_1024_10249) -> _unsafe_view_4
Reshape(_unsafe_view_4, init7_s2_2048_10245) -> view_37
Transpose(view_37, perm=[1,0]) -> t_24
MatMul(t_24, input12) -> mm_16
Transpose(mm_16, perm=[1,0]) -> t_25
Transpose(t_25, perm=[1,0]) -> output_6
Transpose(input19, perm=[0,2,1]) -> transpose_7
MatMul(_unsafe_view_3, transpose_7) -> bmm_3
Reshape(bmm_3, init7_s4_2_2_1024_1024) -> view_33
Cast(view_33, to=1) -> _onx_cast0
Mul(_onx_cast0, input17) -> _onx_mul022
ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
Mul(input17, _onx_reducesum0) -> _onx_mul023
Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
Div(_softmax_backward_data, init1_s_8) -> div_3
Reshape(div_3, init7_s3_4_1024_1024) -> view_34
Transpose(input15, perm=[0,2,1]) -> transpose_8
MatMul(transpose_8, view_34) -> bmm_4
Reshape(bmm_4, init7_s4_2_2_512_1024) -> view_35
Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
Add(input38, transpose_10) -> add_14
Mul(add_14, input14) -> _onx_mul024
Cast(_onx_mul024, to=1) -> mul_34
Slice(mul_34, init7_s1_0, init7_s1_256, init7_s1_3) -> slice_10
Neg(slice_10) -> neg_2
Transpose(input16, perm=[0,2,1]) -> transpose_9
MatMul(view_34, transpose_9) -> bmm_5
Reshape(bmm_5, init7_s4_2_2_1024_5122) -> view_36
Mul(view_36, input14) -> _onx_mul026
Cast(_onx_mul026, to=1) -> mul_36
Slice(mul_36, init7_s1_02, init7_s1_2563, init7_s1_33) -> slice_12
Neg(slice_12) -> neg_3
Slice(mul_34, init7_s1_2562, init7_s1_512, init7_s1_32) -> slice_11
ConstantOfShape(init7_s4_2_2_1024_256, value=[0.0]) -> _onx_constantofshape0
Concat(_onx_constantofshape0, neg_2, axis=3) -> _onx_concat0
ConstantOfShape(init7_s4_2_2_1024_2562, value=[0.0]) -> _onx_constantofshape02
Concat(slice_11, _onx_constantofshape02, axis=3) -> _onx_concat02
Add(_onx_concat0, _onx_concat02) -> add_15
Mul(add_14, input13) -> _onx_mul025
Cast(_onx_mul025, to=1) -> mul_35
Add(add_15, mul_35) -> add_16
Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
Reshape(transpose_12, init7_s3_2_1024_102410) -> _unsafe_view_5
Reshape(_unsafe_view_5, init7_s2_2048_10246) -> view_39
Transpose(view_39, perm=[1,0]) -> t_28
MatMul(t_28, input10) -> mm_18
Transpose(mm_18, perm=[1,0]) -> t_29
Transpose(t_29, perm=[1,0]) -> output_5
Slice(mul_36, init7_s1_2564, init7_s1_5122, init7_s1_34) -> slice_13
ConstantOfShape(init7_s4_2_2_1024_2563, value=[0.0]) -> _onx_constantofshape03
Concat(_onx_constantofshape03, neg_3, axis=3) -> _onx_concat03
ConstantOfShape(init7_s4_2_2_1024_2564, value=[0.0]) -> _onx_constantofshape04
Concat(slice_13, _onx_constantofshape04, axis=3) -> _onx_concat04
Add(_onx_concat03, _onx_concat04) -> add_17
Mul(view_36, input13) -> _onx_mul027
Cast(_onx_mul027, to=1) -> mul_37
Add(add_17, mul_37) -> add_18
Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
Reshape(transpose_13, init7_s3_2_1024_102411) -> _unsafe_view_6
Reshape(_unsafe_view_6, init7_s2_2048_10247) -> view_41
Transpose(view_41, perm=[1,0]) -> t_32
MatMul(t_32, input8) -> mm_20
Transpose(mm_20, perm=[1,0]) -> t_33
Transpose(t_33, perm=[1,0]) -> output_4
Transpose(input11, perm=[1,0]) -> t_26
MatMul(view_37, t_26) -> mm_17
Reshape(mm_17, init7_s3_2_1024_102412) -> view_38
Transpose(input9, perm=[1,0]) -> t_30
MatMul(view_39, t_30) -> mm_19
Reshape(mm_19, init7_s3_2_1024_102413) -> view_40
Add(view_38, view_40) -> add_19
Transpose(input7, perm=[1,0]) -> t_34
MatMul(view_41, t_34) -> mm_21
Reshape(mm_21, init7_s3_2_1024_102414) -> view_42
Add(add_19, view_42) -> add_20
Mul(add_20, input0) -> _onx_mul028
Cast(_onx_mul028, to=1) -> mul_38
Mul(mul_38, input4) -> _onx_mul030
Cast(_onx_mul030, to=1) -> mul_40
ReduceSum(mul_40, init7_s1_23, keepdims=1) -> sum_6
Mul(sum_6, init1_s_9) -> _onx_mul032
Cast(_onx_mul032, to=1) -> mul_42
Mul(add_20, input6) -> _onx_mul029
Cast(_onx_mul029, to=1) -> mul_39
ReduceSum(mul_39, init7_s2_0_13, keepdims=1) -> sum_5
Reshape(sum_5, init7_s1_10243) -> output_0
Mul(mul_38, input5) -> _onx_mul031
Cast(_onx_mul031, to=1) -> mul_41
Add(add_12, mul_41) -> add_21
Pow(input5, init1_s1_3) -> pow_8
Mul(mul_42, pow_8) -> _onx_mul033
Cast(_onx_mul033, to=1) -> mul_43
Expand(mul_43, init7_s3_2_1024_102415) -> expand_7
Div(expand_7, init1_s_10) -> _onx_div03
Cast(_onx_div03, to=1) -> div_4
Mul(input4, init1_s_11) -> _onx_mul034
Cast(_onx_mul034, to=1) -> mul_44
Mul(div_4, mul_44) -> _onx_mul035
Cast(_onx_mul035, to=1) -> mul_45
Add(add_21, mul_45) -> add_22
Equal(input3, init7_s_-1) -> eq_2
Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> _onx_scatternd0
Identity(_onx_scatternd0) -> output_3
Constant(value_float=0) -> output_12
Constant(value_float=0) -> output_13
Constant(value_float=0) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None
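The listing contains many `Cast(..., to=1)` nodes applied to results that are already `float32` (`to=1` denotes FLOAT in ONNX), for instance `Cast(_onx_mul0, to=1) -> mul_13` where `_onx_mul0` is the product of two float32 inputs. Counting operator types gives a quick overview of what the optimizer can work on. A minimal sketch using only the standard library; the helper name `op_type_histogram` is hypothetical:

```python
from collections import Counter
from types import SimpleNamespace


def op_type_histogram(nodes):
    """Count nodes per operator type.

    ``nodes`` may be ``proto.graph.node`` (a sequence of onnx NodeProto) or
    any iterable of objects with an ``op_type`` attribute.
    """
    return Counter(node.op_type for node in nodes)


# Tiny stand-in for proto.graph.node.
sample = [
    SimpleNamespace(op_type=t)
    for t in ["Mul", "Cast", "Mul", "Transpose", "Cast", "Mul"]
]
histogram = op_type_histogram(sample)
print(histogram.most_common(2))  # [('Mul', 3), ('Cast', 2)]
```

With the model loaded above, `op_type_histogram(proto.graph.node).most_common(5)` lists its five most frequent operators; calling it again on the optimized model shows which operators the patterns removed.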
And visually.
Optimization
gr = GraphBuilder(
    proto,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns="default",
        verbose=1,  # a higher value prints more details about the patterns being applied
    ),
)
stats = gr.optimize()

# Store the statistics collected during the optimization.
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
df.to_excel("plot_optimize.xlsx")  # needs an Excel writer such as openpyxl
df
[GraphBuilder.optimize] start with 214 nodes
[GraphBuilder.optimize] #patterns=44
[GraphBuilderPatternOptimization.optimize] start with 214 nodes, 73 initializers, 44 patterns, priorities=[0, 1]
[GraphBuilderPatternOptimization.optimize] iteration 0: 214 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] applies 48 matches, 37*CastPattern, 4*ReshapeReshapePattern, 7*TransposeTransposePattern - time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization.optimize] iteration 1: 159 nodes, priority=0
[GraphBuilderPatternOptimization.optimize] increase priority to 1
[GraphBuilderPatternOptimization.optimize] iteration 2: 159 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 28 matches, 3*MulMulMulScalarPattern, 3*ReduceReshapePattern, 2*Reshape2Of3Pattern, 1*ReshapeReshapeBinaryPattern, 2*MatMulReshape2Of3Pattern, 2*RotaryConcatPartPattern, 1*Sub1MulPattern, 14*TransposeMatMulPattern - time=0.012 | max_time=Sub1MulPattern:0.002
[GraphBuilderPatternOptimization.optimize] iteration 3: 131 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 12 matches, 3*ExpandBroadcastPattern, 1*ReshapeReshapeBinaryPattern, 2*MatMulAddPattern, 2*MatMulReshape2Of3Pattern, 2*SlicesSplitPattern, 2*TransposeReshapeMatMulPattern - time=0.006 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 4: 121 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 6 matches, 1*MatMulAddPattern, 3*SwitchOrderBinaryPattern, 2*TransposeReshapeMatMulPattern - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 5: 121 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] applies 1 matches, [0]=MatchResult: MatMulAddPattern replaces ['Gemm', 'Add'] - time=0.005 | max_time=SwitchOrderBinaryPattern:0.000
[GraphBuilderPatternOptimization.optimize] iteration 6: 120 nodes, priority=1
[GraphBuilderPatternOptimization.optimize] done after 7 iterations with 120 nodes in 0.117
[GraphBuilder.optimize] done with 117 nodes in 0.128
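Two of the rewrites in this log deal with transpositions. TransposeTransposePattern (7 matches in the first iteration) removes chains such as `Transpose(mm_8, perm=[1,0]) -> t_9` followed by `Transpose(t_9, perm=[1,0]) -> output_10`: two consecutive transposes equal one transpose whose permutation is the composition of both, and when that composition is the identity both nodes can be dropped. TransposeMatMulPattern (14 matches) folds a `Transpose` feeding a `MatMul` into the transposition flags of a `Gemm`, so that `Transpose(view_22) -> t_8` plus `MatMul(t_8, input33)` appears as `Gemm(view_22, input33, transA=1, transB=0)` in the optimized graph; the log also shows MatMulAddPattern absorbing an `Add` into a `Gemm`. The NumPy sketch below checks these identities on small random tensors. `compose_perms`, `is_identity` and `gemm` are hypothetical helpers written for illustration, not functions of onnx or experimental_experiment; `gemm` follows the ONNX Gemm definition `alpha * A' @ B' + beta * C`.

```python
import numpy as np


def compose_perms(first, second):
    """Permutation equivalent to Transpose(perm=first) then Transpose(perm=second).

    ONNX Transpose follows numpy.transpose: out.shape[i] == in.shape[perm[i]],
    so axis i of the final result is axis first[second[i]] of the input.
    """
    return [first[axis] for axis in second]


def is_identity(perm):
    """True if the permutation leaves every axis in place."""
    return list(perm) == list(range(len(perm)))


def gemm(a, b, c=None, trans_a=0, trans_b=0, alpha=1.0, beta=1.0):
    """NumPy reference for ONNX Gemm: alpha * A' @ B' + beta * C (C optional)."""
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    y = alpha * (a @ b)
    if c is not None:
        y = y + beta * c
    return y


# TransposeTransposePattern: perm=[1, 0] applied twice is the identity.
perm = compose_perms([1, 0], [1, 0])
print(perm, is_identity(perm))  # [0, 1] True

# General composition, checked against NumPy on a 3D tensor.
x = np.arange(24).reshape(2, 3, 4)
p1, p2 = [1, 2, 0], [1, 0, 2]
chained = np.transpose(np.transpose(x, p1), p2)
fused = np.transpose(x, compose_perms(p1, p2))
print(np.array_equal(chained, fused))  # True

rng = np.random.default_rng(0)
# Small stand-ins for the (2048, 1024) and (1024, 1024) tensors of the graph.
view_22, input33 = rng.standard_normal((16, 8)), rng.standard_normal((16, 8))
view_24, view_26 = rng.standard_normal((16, 8)), rng.standard_normal((16, 8))
input25, input29 = rng.standard_normal((8, 8)), rng.standard_normal((8, 8))

# Transpose + MatMul -> Gemm(transA=1), as for output_10.
before_10 = view_22.T @ input33
after_10 = gemm(view_22, input33, trans_a=1)

# MatMul + MatMul + Add -> one Gemm whose C input is the other product
# (the Reshape nodes around them in the real graph are omitted here).
before_add = view_24 @ input29.T + view_26 @ input25.T
after_add = gemm(view_24, input29, gemm(view_26, input25, trans_b=1), trans_b=1)

print(np.allclose(before_10, after_10), np.allclose(before_add, after_add))  # True True
```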
Summary
                                  time_in  added  removed  iteration  match_index  instances
pattern
apply_CastPattern                0.003582     37       37          0           36         37
apply_ExpandBroadcastPattern     0.000236      3        6          3            2          3
apply_MatMulAddPattern           0.000383      5        8          5            5          4
apply_MatMulReshape2Of3Pattern   0.001935     10       12          3           10          4
apply_MulMulMulScalarPattern     0.001868      6        9          2            2          3
...                                   ...    ...      ...        ...          ...        ...
match_UnsqueezeEqualPattern      0.000465      0        0          6           28          0
match_UnsqueezeUnsqueezePattern  0.000840      0        0          6           48          0
pattern_optimization             0.119236      0       94          0            0          0
remove_identity_nodes            0.004658     44       88          2            0          0
remove_unused                    0.002614      0        3          0            0          0

[72 rows x 6 columns]
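The table above groups the optimization statistics by pattern name. A minimal sketch of such an aggregation with pandas, assuming `stats` is a list of dictionaries with the keys shown in the table; the choice of `sum` for times and node counts and `max` for `iteration` and `match_index` is an assumption, not necessarily what produced the table above. `summarize` is a hypothetical helper.

```python
import pandas


def summarize(stats):
    """Aggregate optimization statistics per pattern (sketch).

    Assumes ``stats`` is a list of dicts with the keys used below; missing
    counts are treated as 0.
    """
    df = pandas.DataFrame(stats)
    counts = ["added", "removed", "iteration", "match_index", "instances"]
    df[counts] = df[counts].fillna(0).astype(int)
    return df.groupby("pattern").agg(
        time_in=("time_in", "sum"),
        added=("added", "sum"),
        removed=("removed", "sum"),
        iteration=("iteration", "max"),
        match_index=("match_index", "max"),
        instances=("instances", "sum"),
    )


# Made-up records with the same keys as the table above (not real measurements).
sample = [
    {"pattern": "apply_CastPattern", "time_in": 0.001, "added": 1, "removed": 1,
     "iteration": 0, "match_index": 3, "instances": 1},
    {"pattern": "apply_CastPattern", "time_in": 0.002, "added": 2, "removed": 2,
     "iteration": 1, "match_index": 5, "instances": 2},
    {"pattern": "remove_unused", "time_in": 0.003, "removed": 3},
]
summary = summarize(sample)
print(summary)
```

Calling `summarize(stats)` on the statistics returned by `gr.optimize()` would give one row per pattern, as in the table above.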
The total is:
number of removed nodes: 191
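Part of this reduction comes from folding scalar constants together. In the original graph, `Div(expand_5, init1_s_2)` divides by 1024 and `Mul(input34, init1_s_3)` multiplies by 2 before both results are multiplied; in the optimized graph the two scalars are merged into the new initializer `init1_s_12` (0.00195312, i.e. 2/1024), whose provenance names MulMulMulScalarPattern, and the `Expand` disappears as well because `Mul` already broadcasts (ExpandBroadcastPattern). The NumPy sketch below only checks this identity on small random tensors standing in for `mul_18` and `input34`; it does not run the optimizer, and the function names are hypothetical.

```python
import numpy as np


def original_nodes(mul_18, input34):
    """Original graph: Expand, Div by 1024, Mul by 2, then Mul."""
    expand_5 = np.broadcast_to(mul_18, input34.shape)
    div_1 = expand_5 / np.float32(1024)
    mul_19 = input34 * np.float32(2)
    return div_1 * mul_19


def optimized_nodes(mul_18, input34):
    """Optimized graph: one Mul by 2/1024, then Mul (broadcasting replaces Expand)."""
    scaled = mul_18 * np.float32(2 / 1024)
    return scaled * input34


rng = np.random.default_rng(0)
# Shapes reduced from (2, 1024, 1) and (2, 1024, 1024) to keep the check fast.
mul_18 = rng.standard_normal((2, 16, 1)).astype(np.float32)
input34 = rng.standard_normal((2, 16, 16)).astype(np.float32)
before = original_nodes(mul_18, input34)
after = optimized_nodes(mul_18, input34)
print(np.abs(before - after).max())
```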
Conversion to ONNX.
# The graph was already optimized by gr.optimize() above.
optimized_proto = gr.to_onnx(optimize=False)
with open("plot_optimize_101.onnx", "wb") as f:
    f.write(optimized_proto.SerializeToString())
print(f"number of new nodes: {len(optimized_proto.graph.node)}")
number of new nodes: 117
Printing the optimized model gives the following.
print(pretty_onnx(optimized_proto))
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input0' type=dtype('float32') shape=[1024]
input: name='input1' type=dtype('float32') shape=[1024]
input: name='input2' type=dtype('float32') shape=[1024]
input: name='input3' type=dtype('int64') shape=[2, 1024]
input: name='input4' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input5' type=dtype('float32') shape=[2, 1024, 1]
input: name='input6' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input7' type=dtype('float32') shape=[1024, 1024]
input: name='input8' type=dtype('float32') shape=[2048, 1024]
input: name='input9' type=dtype('float32') shape=[1024, 1024]
input: name='input10' type=dtype('float32') shape=[2048, 1024]
input: name='input11' type=dtype('float32') shape=[1024, 1024]
input: name='input12' type=dtype('float32') shape=[2048, 1024]
input: name='input13' type=dtype('float32') shape=[1024, 512]
input: name='input14' type=dtype('float32') shape=[1024, 512]
input: name='input15' type=dtype('float32') shape=[4, 1024, 512]
input: name='input16' type=dtype('float32') shape=[4, 512, 1024]
input: name='input17' type=dtype('float32') shape=[2, 2, 1024, 1024]
input: name='input18' type=dtype('float32') shape=[4, 1024, 1024]
input: name='input19' type=dtype('float32') shape=[4, 1024, 512]
input: name='input20' type=dtype('float32') shape=[1024, 1024]
input: name='input21' type=dtype('float32') shape=[2048, 1024]
input: name='input22' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input23' type=dtype('float32') shape=[2, 1024, 1]
input: name='input24' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input25' type=dtype('float32') shape=[1024, 1024]
input: name='input26' type=dtype('float32') shape=[2048, 1024]
input: name='input27' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input28' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input29' type=dtype('float32') shape=[1024, 1024]
input: name='input30' type=dtype('float32') shape=[2048, 1024]
input: name='input31' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input32' type=dtype('float32') shape=[1024, 1024]
input: name='input33' type=dtype('float32') shape=[2048, 1024]
input: name='input34' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input35' type=dtype('float32') shape=[2, 1024, 1]
input: name='input36' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input37' type=dtype('float32') shape=[2, 1024, 1024]
input: name='input38' type=dtype('float32') shape=[2, 2, 1024, 512]
input: name='input39' type=dtype('float32') shape=[2, 2, 1024, 512]
init: name='init1_s1_' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_)
init: name='init1_s1_2' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_2)
init: name='init1_s1_3' type=float32 shape=(1,) -- array([3.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_3)
init: name='init1_s1_4' type=float32 shape=(1,) -- array([0.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s1_4)
init: name='init1_s_' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_)
init: name='init1_s_4' type=float32 shape=() -- array([1.], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_4)
init: name='init1_s_5' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_5)
init: name='init1_s_8' type=float32 shape=() -- array([22.627417], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_8)
init: name='init1_s_9' type=float32 shape=() -- array([-0.5], dtype=float32)-- GraphBuilder._update_structures_with_proto.1/from(init1_s_9)
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-1)
init: name='init7_s1_-12' type=int64 shape=(1,) -- array([-1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-12)
init: name='init7_s1_-13' type=int64 shape=(1,) -- array([-1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-13)
init: name='init7_s1_2' type=int64 shape=(1,) -- array([2]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_2)
init: name='init7_s1_22' type=int64 shape=(1,) -- array([2]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_22)
init: name='init7_s1_23' type=int64 shape=(1,) -- array([2]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_23)
init: name='init7_s2_0_1' type=int64 shape=(2,) -- array([0, 1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_1)
init: name='init7_s2_0_12' type=int64 shape=(2,) -- array([0, 1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_12)
init: name='init7_s2_0_13' type=int64 shape=(2,) -- array([0, 1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_13)
init: name='init7_s2_1024_10242' type=int64 shape=(2,) -- array([1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_1024_10242)
init: name='init7_s2_2048_1024' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)##GraphBuilder.compute_constant/from(init7_s2_2048_1024)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)
init: name='init7_s2_2048_10242' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)##GraphBuilder.compute_constant/from(init7_s2_2048_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)
init: name='init7_s2_2048_10243' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)##GraphBuilder.compute_constant/from(init7_s2_2048_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)
init: name='init7_s2_2048_10244' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)##GraphBuilder.compute_constant/from(init7_s2_2048_10244)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)
init: name='init7_s2_2048_10245' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)##GraphBuilder.compute_constant/from(init7_s2_2048_10245)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)
init: name='init7_s2_2048_10246' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)##GraphBuilder.compute_constant/from(init7_s2_2048_10246)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)
init: name='init7_s2_2048_10247' type=int64 shape=(2,) -- array([2048, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)##GraphBuilder.compute_constant/from(init7_s2_2048_10247)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)
init: name='init7_s3_2_1024_102413' type=int64 shape=(3,) -- array([ 2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)##GraphBuilder.compute_constant/from(init7_s3_2_1024_102413)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)
init: name='init7_s3_2_1024_10242' type=int64 shape=(3,) -- array([ 2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)
init: name='init7_s3_2_1024_10243' type=int64 shape=(3,) -- array([ 2, 1024, 1024])-- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)
init: name='init7_s4_2_1024_2_512' type=int64 shape=(4,) -- array([ 2, 1024, 2, 512])-- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)##GraphBuilder.compute_constant/from(init7_s4_2_1024_2_512)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)
init: name='init7_s_-1' type=int64 shape=() -- array([-1]) -- GraphBuilder._update_structures_with_proto.1/from(init7_s_-1)
init: name='init1_s_12' type=float32 shape=() -- array([0.00195312], dtype=float32)-- MulMulMulScalarPattern.apply.new_cst##MulMulMulScalarPattern.apply.new_cst##MulMulMulScalarPattern.apply.new_cst
init: name='init7_s4_2_2_512_10243' type=int64 shape=(4,) -- array([ 2, 2, 512, 1024])-- MatMulReshape2Of3Pattern.apply.shape.2##TransposeReshapeMatMulPattern.apply.shape_name
init: name='init7_s4_2_2_1024_5123' type=int64 shape=(4,) -- array([ 2, 2, 1024, 512])-- MatMulReshape2Of3Pattern.apply.shape.2##TransposeReshapeMatMulPattern.apply.shape_name##TransposeReshapeMatMulPattern.apply.shape_name
init: name='init7_s2_256_256' type=int64 shape=(2,) -- array([256, 256])-- SlicesSplitPattern.apply.splits##SlicesSplitPattern.apply.splits
init: name='init7_s4_2_2_1024_10243' type=int64 shape=(4,) -- array([ 2, 2, 1024, 1024])-- TransposeReshapeMatMulPattern.apply.shape_name
Constant(value_float=0) -> output_11
Identity(output_11) -> output_12
Reshape(input28, init7_s2_2048_10242) -> Reshape2Of3PatternR_input28
Mul(input37, input2) -> _onx_mul0
Mul(_onx_mul0, input34) -> _onx_mul03
ReduceSum(_onx_mul03, init7_s1_2, keepdims=1) -> sum_2
Mul(sum_2, init1_s_) -> _onx_mul05
Mul(input37, input36) -> _onx_mul02
ReduceSum(_onx_mul02, init7_s2_0_1, keepdims=0) -> output_2
Mul(_onx_mul0, input35) -> _onx_mul04
Pow(input35, init1_s1_) -> pow_4
Mul(_onx_mul05, pow_4) -> _onx_mul06
Mul(_onx_mul06, init1_s_12) -> mul-_onx_mul06
Mul(mul-_onx_mul06, input34) -> _onx_mul08
Add(_onx_mul04, _onx_mul08) -> add_8
Reshape(add_8, init7_s2_2048_1024) -> view_22
Gemm(view_22, input32, transA=0, transB=1) -> mm_9
Reshape(mm_9, init7_s3_2_1024_10242) -> view_23
Mul(view_23, input31) -> _onx_mul010
Gemm(view_22, input33, transA=1, transB=0) -> output_10
Mul(mm_9, Reshape2Of3PatternR_input28) -> view_24
Gemm(view_24, input30, transA=1, transB=0) -> output_9
Sigmoid(input27) -> sigmoid
Mul(input27, sigmoid) -> Sub1MulPattern--_onx_mul011
Sub(input27, Sub1MulPattern--_onx_mul011) -> _onx_mul011
Add(_onx_mul011, init1_s_4) -> add_9
Mul(sigmoid, add_9) -> _onx_mul012
Mul(_onx_mul010, _onx_mul012) -> _onx_mul013
Reshape(_onx_mul013, init7_s2_2048_10243) -> view_26
Gemm(view_26, input25, transA=0, transB=1) -> mm_13
Gemm(view_24, input29, mm_13, transA=0, transB=1) -> add-mm_11
Reshape(add-mm_11, init7_s3_2_1024_10243) -> add_10
Mul(add_10, input1) -> _onx_mul014
Mul(_onx_mul014, input22) -> _onx_mul016
ReduceSum(_onx_mul016, init7_s1_22, keepdims=1) -> sum_4
Mul(sum_4, init1_s_5) -> _onx_mul018
Gemm(view_26, input26, transA=1, transB=0) -> output_8
Mul(add_10, input24) -> _onx_mul015
ReduceSum(_onx_mul015, init7_s2_0_12, keepdims=0) -> output_1
Mul(_onx_mul014, input23) -> _onx_mul017
Add(add_8, _onx_mul017) -> add_11
Pow(input23, init1_s1_2) -> pow_6
Mul(_onx_mul018, pow_6) -> _onx_mul019
Mul(_onx_mul019, init1_s_12) -> mul-_onx_mul019
Mul(mul-_onx_mul019, input22) -> _onx_mul021
Add(add_11, _onx_mul021) -> add_12
Reshape(add_12, init7_s2_2048_10244) -> view_29
Gemm(view_29, input20, transA=0, transB=1) -> mm_15
Reshape(mm_15, init7_s4_2_1024_2_512) -> view_31
Transpose(view_31, perm=[0,2,1,3]) -> transpose_5
Gemm(view_29, input21, transA=1, transB=0) -> output_7
Reshape(input18, init7_s4_2_2_1024_10243) -> TransposeReshapeMatMulPatternL_input18
Transpose(TransposeReshapeMatMulPatternL_input18, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_transpose_6
MatMul(MatMulReshape2Of3PatternL_transpose_6, transpose_5) -> view_32
Add(input39, view_32) -> add_13
Transpose(add_13, perm=[0,2,1,3]) -> transpose_11
Reshape(transpose_11, init7_s2_2048_10245) -> view_37
Gemm(view_37, input12, transA=1, transB=0) -> output_6
Reshape(input19, init7_s4_2_2_1024_5123) -> TransposeReshapeMatMulPatternL_input19
Transpose(TransposeReshapeMatMulPatternL_input19, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL__unsafe_view_3
MatMul(transpose_5, MatMulReshape2Of3PatternL__unsafe_view_3) -> view_33
Mul(view_33, input17) -> _onx_mul022
ReduceSum(_onx_mul022, init7_s1_-1, keepdims=1) -> _onx_reducesum0
Mul(input17, _onx_reducesum0) -> _onx_mul023
Sub(_onx_mul022, _onx_mul023) -> _softmax_backward_data
Div(_softmax_backward_data, init1_s_8) -> div_3
Reshape(input15, init7_s4_2_2_1024_5123) -> TransposeReshapeMatMulPatternL_input15
Transpose(TransposeReshapeMatMulPatternL_input15, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_transpose_8
MatMul(MatMulReshape2Of3PatternL_transpose_8, div_3) -> view_35
Transpose(view_35, perm=[0,1,3,2]) -> transpose_10
Add(input38, transpose_10) -> add_14
Mul(add_14, input14) -> _onx_mul024
Split(_onx_mul024, init7_s2_256_256, axis=3) -> slice_10, slice_11
Neg(slice_10) -> neg_2
Concat(slice_11, neg_2, axis=3) -> add_15
Reshape(input16, init7_s4_2_2_512_10243) -> TransposeReshapeMatMulPatternL_input16
Transpose(TransposeReshapeMatMulPatternL_input16, perm=[0,1,3,2]) -> MatMulReshape2Of3PatternL_view_34
MatMul(div_3, MatMulReshape2Of3PatternL_view_34) -> view_36
Mul(view_36, input14) -> _onx_mul026
Split(_onx_mul026, init7_s2_256_256, axis=3) -> slice_12, slice_13
Neg(slice_12) -> neg_3
Concat(slice_13, neg_3, axis=3) -> add_17
Mul(add_14, input13) -> _onx_mul025
Add(add_15, _onx_mul025) -> add_16
Transpose(add_16, perm=[0,2,1,3]) -> transpose_12
Reshape(transpose_12, init7_s2_2048_10246) -> view_39
Gemm(view_39, input10, transA=1, transB=0) -> output_5
Mul(view_36, input13) -> _onx_mul027
Add(add_17, _onx_mul027) -> add_18
Transpose(add_18, perm=[0,2,1,3]) -> transpose_13
Reshape(transpose_13, init7_s2_2048_10247) -> view_41
Gemm(view_41, input7, transA=0, transB=1) -> mm_21
Gemm(view_39, input9, mm_21, transA=0, transB=1) -> MatMulAddPattern--mm_19
Gemm(view_37, input11, MatMulAddPattern--mm_19, transA=0, transB=1) -> add-Reshape2Of3PatternL_add_19
Reshape(add-Reshape2Of3PatternL_add_19, init7_s3_2_1024_102413) -> add_20
Mul(add_20, input0) -> _onx_mul028
Mul(_onx_mul028, input4) -> _onx_mul030
ReduceSum(_onx_mul030, init7_s1_23, keepdims=1) -> sum_6
Mul(sum_6, init1_s_9) -> _onx_mul032
Gemm(view_41, input8, transA=1, transB=0) -> output_4
Mul(add_20, input6) -> _onx_mul029
ReduceSum(_onx_mul029, init7_s2_0_13, keepdims=0) -> output_0
Mul(_onx_mul028, input5) -> _onx_mul031
Add(add_12, _onx_mul031) -> add_21
Pow(input5, init1_s1_3) -> pow_8
Mul(_onx_mul032, pow_8) -> _onx_mul033
Mul(_onx_mul033, init1_s_12) -> mul-_onx_mul033
Mul(mul-_onx_mul033, input4) -> _onx_mul035
Add(add_21, _onx_mul035) -> add_22
Equal(input3, init7_s_-1) -> eq_2
Unsqueeze(eq_2, init7_s1_-12) -> unsqueeze_6
Where(unsqueeze_6, init1_s1_4, add_22) -> _onx_where0
Unsqueeze(input3, init7_s1_-13) -> _onx_unsqueeze0
ConstantOfShape(init7_s2_1024_10242, value=[0.0]) -> _onx_constantofshape05
ScatterND(_onx_constantofshape05, _onx_unsqueeze0, _onx_where0, reduction=b'add') -> output_3
Identity(output_11) -> output_13
Identity(output_11) -> output_14
output: name='output_0' type=dtype('float32') shape=[1024]
output: name='output_1' type=dtype('float32') shape=[1024]
output: name='output_2' type=dtype('float32') shape=[1024]
output: name='output_3' type=dtype('float32') shape=[1024, 1024]
output: name='output_4' type=dtype('float32') shape=[1024, 1024]
output: name='output_5' type=dtype('float32') shape=[1024, 1024]
output: name='output_6' type=dtype('float32') shape=[1024, 1024]
output: name='output_7' type=dtype('float32') shape=[1024, 1024]
output: name='output_8' type=dtype('float32') shape=[1024, 1024]
output: name='output_9' type=dtype('float32') shape=[1024, 1024]
output: name='output_10' type=dtype('float32') shape=[1024, 1024]
output: name='output_11' type=dtype('float32') shape=None
output: name='output_12' type=dtype('float32') shape=None
output: name='output_13' type=dtype('float32') shape=None
output: name='output_14' type=dtype('float32') shape=None
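The optimized graph above has 117 nodes, down from the 215 of the original model. To see which operators the rewriting removed or introduced, you can compare operator histograms. The helper below is a minimal sketch. The name `op_diff` is ours and is not part of experimental_experiment. It works on plain lists of operator names, which you can get from an `onnx.ModelProto` with `[n.op_type for n in proto.graph.node]`.

```python
from collections import Counter


def op_diff(before, after):
    """Compare two lists of operator types.

    Returns (removed, added): operators that appear fewer times,
    respectively more times, in ``after`` than in ``before``.
    """
    count_before, count_after = Counter(before), Counter(after)
    removed = dict(count_before - count_after)  # Counter subtraction drops counts <= 0
    added = dict(count_after - count_before)
    return removed, added


# Hypothetical toy lists. With the models of this example, use
#   before = [n.op_type for n in proto.graph.node]
#   after = [n.op_type for n in optimized_proto.graph.node]
before = ["MatMul", "Add", "Reshape", "Reshape", "Slice", "Slice", "Mul"]
after = ["Gemm", "Reshape", "Split", "Mul"]
removed, added = op_diff(before, after)
print("removed:", removed)
print("added:", added)
```

The toy lists only mimic rewrites whose names appear in the printed graph: a Gemm with a bias input (MatMulAddPattern) and a Split replacing two Slices (SlicesSplitPattern). They are not measured data.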
And visually.
The first list of patterns rewrites the graph using only standard ONNX operators: experimental_experiment.xoptim.patterns. The second list contains patterns specific to onnxruntime: experimental_experiment.xoptim.patterns_ort.
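Every pattern works in the two steps that appear in the log and in the initializer names: *match* finds a small subgraph, and *apply* replaces it with an equivalent, cheaper one. The sketch below shows that loop on a hand-made node list. It is a toy illustration, not the experimental_experiment API, and `Node` and `fuse_reshape_reshape` are hypothetical names.

The rewrite folds Reshape(Reshape(x, s1), s2) into Reshape(x, s2), which is what the name ReshapeReshapePattern suggests. The fold is only safe when s2 contains no 0. In ONNX Reshape (with the default allowzero=0), a 0 copies the matching dimension from the input, so it depends on the intermediate shape.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Node:
    """Toy node; for Reshape the target shape is stored inline (hypothetical)."""
    op_type: str
    inputs: list
    outputs: list
    shape: tuple = ()


def fuse_reshape_reshape(nodes, graph_outputs=()):
    """Rewrite Reshape(Reshape(x, s1), s2) into Reshape(x, s2) until nothing matches.

    ``nodes`` must be topologically sorted; the input list is not modified.
    """
    nodes = list(nodes)
    while True:
        producer = {out: n for n in nodes for out in n.outputs}
        uses = Counter(name for n in nodes for name in n.inputs)
        for i, second in enumerate(nodes):
            # match: a Reshape whose target shape has no 0
            # (a 0 would copy a dimension of the intermediate tensor)
            if second.op_type != "Reshape" or 0 in second.shape:
                continue
            first = producer.get(second.inputs[0])
            if first is None or first.op_type != "Reshape":
                continue
            middle = first.outputs[0]
            if uses[middle] != 1 or middle in graph_outputs:
                continue  # the intermediate result is needed elsewhere
            # apply: read x directly and drop the first Reshape
            nodes[i] = Node("Reshape", [first.inputs[0]], second.outputs, second.shape)
            nodes = [n for n in nodes if n is not first]
            break
        else:
            return nodes  # no match left: fixed point reached


graph = [
    Node("Reshape", ["x"], ["a"], (2048, 1024)),
    Node("Reshape", ["a"], ["b"], (2, 1024, 1024)),
    Node("Relu", ["b"], ["y"]),
]
optimized = fuse_reshape_reshape(graph, graph_outputs=("y",))
print([n.op_type for n in optimized], optimized[0].inputs)
```

The optimizer in this example follows the same structure, iterating until no pattern matches ("done after 1 iterations" in the log below). It also checks many more conditions, such as types, shapes and priorities.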
Focus on one optimizer
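gr = GraphBuilder(
    optimized_proto,
    infer_shapes_options=True,
    optimization_options=OptimizationOptions(
        patterns="SwitchOrderBinary",  # enable a single pattern
        verbose=10,  # print the graph and every match attempt
    ),
)
stats = gr.optimize()  # statistics about the rewriting (nodes added/removed, time)
df = pandas.DataFrame(stats)
df.to_csv("plot_optimize.csv")
df.to_excel("plot_optimize.xlsx")  # needs an Excel writer such as openpyxl
df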
[GraphBuilder.optimize] start with 117 nodes
[GraphBuilder.optimize] #patterns=1
[GraphBuilderPatternOptimization.optimize] start with 117 nodes, 36 initializers, 1 patterns, priorities=[1]
[GraphBuilderPatternOptimization.optimize] use pattern 1/1 - P1 - SwitchOrderBinaryPattern()
--
opset: : 18
init: init1_s1_: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_)
init: init1_s1_2: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_2)
init: init1_s1_3: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_3)
init: init1_s1_4: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s1_4)
init: init1_s_: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_)
init: init1_s_4: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_4)
init: init1_s_5: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_5)
init: init1_s_8: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_8)
init: init1_s_9: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_9)
init: init7_s1_-1: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-1)
init: init7_s1_-12: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-12)
init: init7_s1_-13: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_-13)
init: init7_s1_2: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_2)
init: init7_s1_22: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_22)
init: init7_s1_23: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s1_23)
init: init7_s2_0_1: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_1)
init: init7_s2_0_12: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_12)
init: init7_s2_0_13: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_0_13)
init: init7_s2_1024_10242: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_1024_10242)
init: init7_s2_2048_1024: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)##GraphBuilder.compute_constant/from(init7_s2_2048_1024)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_1024)
init: init7_s2_2048_10242: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)##GraphBuilder.compute_constant/from(init7_s2_2048_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10242)
init: init7_s2_2048_10243: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)##GraphBuilder.compute_constant/from(init7_s2_2048_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10243)
init: init7_s2_2048_10244: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)##GraphBuilder.compute_constant/from(init7_s2_2048_10244)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10244)
init: init7_s2_2048_10245: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)##GraphBuilder.compute_constant/from(init7_s2_2048_10245)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10245)
init: init7_s2_2048_10246: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)##GraphBuilder.compute_constant/from(init7_s2_2048_10246)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10246)
init: init7_s2_2048_10247: int64: 2 -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)##GraphBuilder.compute_constant/from(init7_s2_2048_10247)##GraphBuilder._update_structures_with_proto.1/from(init7_s2_2048_10247)
init: init7_s3_2_1024_102413: int64: 3 -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)##GraphBuilder.compute_constant/from(init7_s3_2_1024_102413)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_102413)
init: init7_s3_2_1024_10242: int64: 3 -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10242)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10242)
init: init7_s3_2_1024_10243: int64: 3 -- GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s3_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s3_2_1024_10243)
init: init7_s4_2_1024_2_512: int64: 4 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)##GraphBuilder.compute_constant/from(init7_s4_2_1024_2_512)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_1024_2_512)
init: init7_s_-1: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s_-1)
init: init1_s_12: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init1_s_12)
init: init7_s4_2_2_512_10243: int64: 4 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_512_10243)##GraphBuilder.compute_constant/from(init7_s4_2_2_512_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_512_10243)
init: init7_s4_2_2_1024_5123: int64: 4 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_5123)##GraphBuilder.compute_constant/from(init7_s4_2_2_1024_5123)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_5123)
init: init7_s2_256_256: ?: ? -- GraphBuilder._update_structures_with_proto.1/from(init7_s2_256_256)
init: init7_s4_2_2_1024_10243: int64: 4 -- GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_10243)##GraphBuilder.compute_constant/from(init7_s4_2_2_1024_10243)##GraphBuilder._update_structures_with_proto.1/from(init7_s4_2_2_1024_10243)
input:: input0 |T1: 1024
input:: input1 |T1: 1024
input:: input2 |T1: 1024
input:: input3 |T7: 2 x 1024
input:: input4 |T1: 2 x 1024 x 1024
input:: input5 |T1: 2 x 1024 x 1
input:: input6 |T1: 2 x 1024 x 1024
input:: input7 |T1: 1024 x 1024
input:: input8 |T1: 2048 x 1024
input:: input9 |T1: 1024 x 1024
input:: input10 |T1: 2048 x 1024
input:: input11 |T1: 1024 x 1024
input:: input12 |T1: 2048 x 1024
input:: input13 |T1: 1024 x 512
input:: input14 |T1: 1024 x 512
input:: input15 |T1: 4 x 1024 x 512
input:: input16 |T1: 4 x 512 x 1024
input:: input17 |T1: 2 x 2 x 1024 x 1024
input:: input18 |T1: 4 x 1024 x 1024
input:: input19 |T1: 4 x 1024 x 512
input:: input20 |T1: 1024 x 1024
input:: input21 |T1: 2048 x 1024
input:: input22 |T1: 2 x 1024 x 1024
input:: input23 |T1: 2 x 1024 x 1
input:: input24 |T1: 2 x 1024 x 1024
input:: input25 |T1: 1024 x 1024
input:: input26 |T1: 2048 x 1024
input:: input27 |T1: 2 x 1024 x 1024
input:: input28 |T1: 2 x 1024 x 1024
input:: input29 |T1: 1024 x 1024
input:: input30 |T1: 2048 x 1024
input:: input31 |T1: 2 x 1024 x 1024
input:: input32 |T1: 1024 x 1024
input:: input33 |T1: 2048 x 1024
input:: input34 |T1: 2 x 1024 x 1024
input:: input35 |T1: 2 x 1024 x 1
input:: input36 |T1: 2 x 1024 x 1024
input:: input37 |T1: 2 x 1024 x 1024
input:: input38 |T1: 2 x 2 x 1024 x 512
input:: input39 |T1: 2 x 2 x 1024 x 512
Reshape: input28, init7_s2_2048_10242 -> Reshape2Of3PatternR_input28 |T1: 2048 x 1024 - Reshape2Of3Pattern--mul17
Mul: input37, input2 -> _onx_mul0 |T1: 2 x 1024 x 1024 - mul
Mul: input37, input36 -> _onx_mul02 |T1: 2 x 1024 x 1024 - mul3
ReduceSum: _onx_mul02, init7_s2_0_1 -> output_2 |T1: 1024 - ReduceReshapePattern--sum
Mul: _onx_mul0, input34 -> _onx_mul03 |T1: 2 x 1024 x 1024 - mul5
Mul: _onx_mul0, input35 -> _onx_mul04 |T1: 2 x 1024 x 1024 - mul7
ReduceSum: _onx_mul03, init7_s1_2 -> sum_2 |T1: 2 x 1024 x 1 - sum2
Pow: input35, init1_s1_ -> pow_4 |T1: 2 x 1024 x 1 - Pow
Mul: sum_2, init1_s_ -> _onx_mul05 |T1: 2 x 1024 x 1 - mul9
Mul: _onx_mul05, pow_4 -> _onx_mul06 |T1: 2 x 1024 x 1 - mul11
Mul: _onx_mul06, init1_s_12 -> mul-_onx_mul06 |T1: 2 x 1024 x 1 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst
Mul: mul-_onx_mul06, input34 -> _onx_mul08 |T1: 2 x 1024 x 1024 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst2
Add: _onx_mul04, _onx_mul08 -> add_8 |T1: 2 x 1024 x 1024 - add_Tensor
Reshape: add_8, init7_s2_2048_1024 -> view_22 |T1: 2048 x 1024 - view2
Gemm: view_22, input32 -> mm_9 |T1: 2048 x 1024 - TransposeMatMulPattern--mm2
Gemm: view_22, input33 -> output_10 |T1: 1024 x 1024 - TransposeMatMulPattern--mm
Reshape: mm_9, init7_s3_2_1024_10242 -> view_23 |T1: 2 x 1024 x 1024 - view3
Mul: mm_9, Reshape2Of3PatternR_input28 -> view_24 |T1: 2048 x 1024 - Reshape2Of3Pattern--mul172
Gemm: view_24, input30 -> output_9 |T1: 1024 x 1024 - TransposeMatMulPattern--mm3
Mul: view_23, input31 -> _onx_mul010 |T1: 2 x 1024 x 1024 - mul19
Sigmoid: input27 -> sigmoid |T1: 2 x 1024 x 1024 - Sigmoid
Mul: input27, sigmoid -> Sub1MulPattern--_onx_mul011 |T1: 2 x 1024 x 1024 - Sub1MulPattern--mul21
Sub: input27, Sub1MulPattern--_onx_mul011 -> _onx_mul011 |T1: 2 x 1024 x 1024 - Sub1MulPattern--mul212
Add: _onx_mul011, init1_s_4 -> add_9 |T1: 2 x 1024 x 1024 - add_Scalar
Mul: sigmoid, add_9 -> _onx_mul012 |T1: 2 x 1024 x 1024 - mul23
Mul: _onx_mul010, _onx_mul012 -> _onx_mul013 |T1: 2 x 1024 x 1024 - mul25
Reshape: _onx_mul013, init7_s2_2048_10243 -> view_26 |T1: 2048 x 1024 - view6
Gemm: view_26, input25 -> mm_13 |T1: 2048 x 1024 - TransposeMatMulPattern--mm6
Gemm: view_26, input26 -> output_8 |T1: 1024 x 1024 - TransposeMatMulPattern--mm5
Gemm: view_24, input29, mm_13 -> add-mm_11 |T1: 2048 x 1024 - MatMulAddPattern--TransposeMatMulPattern--mm4
Reshape: add-mm_11, init7_s3_2_1024_10243 -> add_10 |T1: 2 x 1024 x 1024 - ReshapeReshapeBinaryPattern--add_Tensor22
Mul: add_10, input1 -> _onx_mul014 |T1: 2 x 1024 x 1024 - mul27
Mul: add_10, input24 -> _onx_mul015 |T1: 2 x 1024 x 1024 - mul29
ReduceSum: _onx_mul015, init7_s2_0_12 -> output_1 |T1: 1024 - ReduceReshapePattern--sum3
Mul: _onx_mul014, input22 -> _onx_mul016 |T1: 2 x 1024 x 1024 - mul31
Mul: _onx_mul014, input23 -> _onx_mul017 |T1: 2 x 1024 x 1024 - mul33
ReduceSum: _onx_mul016, init7_s1_22 -> sum_4 |T1: 2 x 1024 x 1 - sum4
Add: add_8, _onx_mul017 -> add_11 |T1: 2 x 1024 x 1024 - add_Tensor3
Pow: input23, init1_s1_2 -> pow_6 |T1: 2 x 1024 x 1 - Pow1
Mul: sum_4, init1_s_5 -> _onx_mul018 |T1: 2 x 1024 x 1 - mul35
Mul: _onx_mul018, pow_6 -> _onx_mul019 |T1: 2 x 1024 x 1 - mul37
Mul: _onx_mul019, init1_s_12 -> mul-_onx_mul019 |T1: 2 x 1024 x 1 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst
Mul: mul-_onx_mul019, input22 -> _onx_mul021 |T1: 2 x 1024 x 1024 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst2
Add: add_11, _onx_mul021 -> add_12 |T1: 2 x 1024 x 1024 - add_Tensor4
Reshape: add_12, init7_s2_2048_10244 -> view_29 |T1: 2048 x 1024 - view9
Gemm: view_29, input20 -> mm_15 |T1: 2048 x 1024 - TransposeMatMulPattern--mm8
Gemm: view_29, input21 -> output_7 |T1: 1024 x 1024 - TransposeMatMulPattern--mm7
Reshape: mm_15, init7_s4_2_1024_2_512 -> view_31 |T1: 2 x 1024 x 2 x 512 - ReshapeReshapePattern--view10
Transpose: view_31 -> transpose_5 |T1: 2 x 2 x 1024 x 512 - Transpose
Reshape: input18, init7_s4_2_2_1024_10243 -> TransposeReshapeMatMulPatternL_input18 |T1: 2 x 2 x 1024 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm2
Transpose: TransposeReshapeMatMulPatternL_input18 -> MatMulReshape2Of3PatternL_transpose_6 |T1: 2 x 2 x 1024 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm22
MatMul: MatMulReshape2Of3PatternL_transpose_6, transpose_5 -> view_32 |T1: 2 x 2 x 1024 x 512 - MatMulReshape2Of3Pattern--bmm2
Reshape: input19, init7_s4_2_2_1024_5123 -> TransposeReshapeMatMulPatternL_input19 |T1: 2 x 2 x 1024 x 512 - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm23
Transpose: TransposeReshapeMatMulPatternL_input19 -> MatMulReshape2Of3PatternL__unsafe_view_3 |T1: 2 x 2 x 512 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm232
MatMul: transpose_5, MatMulReshape2Of3PatternL__unsafe_view_3 -> view_33 |T1: 2 x 2 x 1024 x 1024 - MatMulReshape2Of3Pattern--bmm23
Add: input39, view_32 -> add_13 |T1: 2 x 2 x 1024 x 512 - add_Tensor5
Mul: view_33, input17 -> _onx_mul022 |T1: 2 x 2 x 1024 x 1024 - Mul
ReduceSum: _onx_mul022, init7_s1_-1 -> _onx_reducesum0 |T1: 2 x 2 x 1024 x 1 - softmax_backward_data
Mul: input17, _onx_reducesum0 -> _onx_mul023 |T1: 2 x 2 x 1024 x 1024 - softmax_backward_data2
Sub: _onx_mul022, _onx_mul023 -> _softmax_backward_data |T1: 2 x 2 x 1024 x 1024 - softmax_backward_data3
Div: _softmax_backward_data, init1_s_8 -> div_3 |T1: 2 x 2 x 1024 x 1024 - div_Tensor
Reshape: input15, init7_s4_2_2_1024_5123 -> TransposeReshapeMatMulPatternL_input15 |T1: 2 x 2 x 1024 x 512 - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm32
Transpose: TransposeReshapeMatMulPatternL_input15 -> MatMulReshape2Of3PatternL_transpose_8 |T1: 2 x 2 x 512 x 1024- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm322
MatMul: MatMulReshape2Of3PatternL_transpose_8, div_3 -> view_35 |T1: 2 x 2 x 512 x 1024 - MatMulReshape2Of3Pattern--bmm32
Reshape: input16, init7_s4_2_2_512_10243 -> TransposeReshapeMatMulPatternL_input16 |T1: 2 x 2 x 512 x 1024 - TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm42
Transpose: TransposeReshapeMatMulPatternL_input16 -> MatMulReshape2Of3PatternL_view_34 |T1: 2 x 2 x 1024 x 512- TransposeReshapeMatMulPattern--MatMulReshape2Of3Pattern--bmm422
MatMul: div_3, MatMulReshape2Of3PatternL_view_34 -> view_36 |T1: 2 x 2 x 1024 x 512 - MatMulReshape2Of3Pattern--bmm42
Transpose: view_35 -> transpose_10 |T1: 2 x 2 x 1024 x 512 - Transpose12345
Add: input38, transpose_10 -> add_14 |T1: 2 x 2 x 1024 x 512 - add_Tensor6
Mul: add_14, input14 -> _onx_mul024 |T1: 2 x 2 x 1024 x 512 - mul43
Split: _onx_mul024, init7_s2_256_256 -> slice_10, slice_11 |T1: 2 x 2 x 1024 x 256 T1: 2 x 2 x 1024 x 256- SlicesSplitPattern--Slice
Neg: slice_10 -> neg_2 |T1: 2 x 2 x 1024 x 256 - Neg
Concat: slice_11, neg_2 -> add_15 |T1: 2 x 2 x 1024 x 512 - RotaryConcatPartPattern--add_Tensor7
Mul: add_14, input13 -> _onx_mul025 |T1: 2 x 2 x 1024 x 512 - mul45
Add: add_15, _onx_mul025 -> add_16 |T1: 2 x 2 x 1024 x 512 - add_Tensor8
Mul: view_36, input14 -> _onx_mul026 |T1: 2 x 2 x 1024 x 512 - mul47
Split: _onx_mul026, init7_s2_256_256 -> slice_12, slice_13 |T1: 2 x 2 x 1024 x 256 T1: 2 x 2 x 1024 x 256- SlicesSplitPattern--Slice12
Neg: slice_12 -> neg_3 |T1: 2 x 2 x 1024 x 256 - Neg1
Concat: slice_13, neg_3 -> add_17 |T1: 2 x 2 x 1024 x 512 - RotaryConcatPartPattern--add_Tensor9
Mul: view_36, input13 -> _onx_mul027 |T1: 2 x 2 x 1024 x 512 - mul49
Add: add_17, _onx_mul027 -> add_18 |T1: 2 x 2 x 1024 x 512 - add_Tensor10
Transpose: add_13 -> transpose_11 |T1: 2 x 1024 x 2 x 512 - Transpose123456
Transpose: add_16 -> transpose_12 |T1: 2 x 1024 x 2 x 512 - Transpose1234567
Transpose: add_18 -> transpose_13 |T1: 2 x 1024 x 2 x 512 - Transpose12345678
Reshape: transpose_11, init7_s2_2048_10245 -> view_37 |T1: 2048 x 1024 - ReshapeReshapePattern--_unsafe_view2
Gemm: view_37, input12 -> output_6 |T1: 1024 x 1024 - TransposeMatMulPattern--mm9
Reshape: transpose_12, init7_s2_2048_10246 -> view_39 |T1: 2048 x 1024 - ReshapeReshapePattern--_unsafe_view3
Gemm: view_39, input10 -> output_5 |T1: 1024 x 1024 - TransposeMatMulPattern--mm11
Reshape: transpose_13, init7_s2_2048_10247 -> view_41 |T1: 2048 x 1024 - ReshapeReshapePattern--_unsafe_view4
Gemm: view_41, input7 -> mm_21 |T1: 2048 x 1024 - TransposeMatMulPattern--mm14
Gemm: view_41, input8 -> output_4 |T1: 1024 x 1024 - TransposeMatMulPattern--mm13
Gemm: view_39, input9, mm_21 -> MatMulAddPattern--mm_19 |T1: 2048 x 1024 - MatMulAddPattern--TransposeMatMulPattern--mm12
Gemm: view_37, input11, MatMulAddPattern--mm_19 -> add-Reshape2Of3PatternL_add_19 |T1: 2048 x 1024 - MatMulAddPattern--MatMulAddPattern--TransposeMatMulPattern--mm102
Reshape: add-Reshape2Of3PatternL_add_19, init7_s3_2_1024_102413 -> add_20 |T1: 2 x 1024 x 1024 - ReshapeReshapeBinaryPattern--add_Tensor122
Mul: add_20, input0 -> _onx_mul028 |T1: 2 x 1024 x 1024 - mul51
Mul: add_20, input6 -> _onx_mul029 |T1: 2 x 1024 x 1024 - mul53
ReduceSum: _onx_mul029, init7_s2_0_13 -> output_0 |T1: 1024 - ReduceReshapePattern--sum5
Mul: _onx_mul028, input4 -> _onx_mul030 |T1: 2 x 1024 x 1024 - mul55
Mul: _onx_mul028, input5 -> _onx_mul031 |T1: 2 x 1024 x 1024 - mul57
ReduceSum: _onx_mul030, init7_s1_23 -> sum_6 |T1: 2 x 1024 x 1 - sum6
Add: add_12, _onx_mul031 -> add_21 |T1: 2 x 1024 x 1024 - add_Tensor13
Pow: input5, init1_s1_3 -> pow_8 |T1: 2 x 1024 x 1 - Pow12
Mul: sum_6, init1_s_9 -> _onx_mul032 |T1: 2 x 1024 x 1 - mul59
Mul: _onx_mul032, pow_8 -> _onx_mul033 |T1: 2 x 1024 x 1 - mul61
Mul: _onx_mul033, init1_s_12 -> mul-_onx_mul033 |T1: 2 x 1024 x 1 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst
Mul: mul-_onx_mul033, input4 -> _onx_mul035 |T1: 2 x 1024 x 1024 - SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst2
Add: add_21, _onx_mul035 -> add_22 |T1: 2 x 1024 x 1024 - add_Tensor14
Equal: input3, init7_s_-1 -> eq_2 |T9: 2 x 1024 - Equal
Unsqueeze: eq_2, init7_s1_-12 -> unsqueeze_6 |T9: 2 x 1024 x 1 - Unsqueeze
Where: unsqueeze_6, init1_s1_4, add_22 -> _onx_where0 |T1: 2 x 1024 x 1024 - masked_fill_Scalar
Unsqueeze: input3, init7_s1_-13 -> _onx_unsqueeze0 |T7: 2 x 1024 x 1 - aten__unsafe_index_put
ConstantOfShape: init7_s2_1024_10242 -> _onx_constantofshape05 |T1: 1024 x 1024 - aten__unsafe_index_put2
ScatterND: _onx_constantofshape05, _onx_unsqueeze0, _onx_where0 -> output_3 |T1: 1024 x 1024 - aten__unsafe_index_put3
Constant: -> output_11 |T1: - Constant
Identity: output_11 -> output_12 |T1: - ._update_structures_with_proto
Identity: output_11 -> output_13 |T1: - ._update_structures_with_proto
Identity: output_11 -> output_14 |T1: - ._update_structures_with_proto
output:: output_0 |T1: 1024
output:: output_1 |T1: 1024
output:: output_2 |T1: 1024
output:: output_3 |T1: 1024 x 1024
output:: output_4 |T1: 1024 x 1024
output:: output_5 |T1: 1024 x 1024
output:: output_6 |T1: 1024 x 1024
output:: output_7 |T1: 1024 x 1024
output:: output_8 |T1: 1024 x 1024
output:: output_9 |T1: 1024 x 1024
output:: output_10 |T1: 1024 x 1024
output:: output_11 |T1:
output:: output_12 |T1:
output:: output_13 |T1:
output:: output_14 |T1:
--
[GraphBuilderPatternOptimization.optimize] iteration 0: 117 nodes, priority=1
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul5
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul7
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul11
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 184:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul25
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul31
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul33
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor3
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul37
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul41-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor4
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul55
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul57
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor13
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=mul61
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Mul, name=SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul65-Cst2
[SwitchOrderBinaryPattern.match] NONE - line: 175:experimental_experiment.xoptim.patterns.onnx_mul, op_type=Add, name=add_Tensor14
[GraphBuilderPatternOptimization.optimize] done all: -0 +0 nodes
[GraphBuilderPatternOptimization.optimize] done after 1 iterations with 117 nodes in 0.002
STAT build_graph_for_pattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.00035422100336290896
STAT check_pattern_00 +0 -0 #it=1 maxmatch=0 i=0 - time=0.00019108100241282955
STAT check_pattern_B0 +0 -0 #it=1 maxmatch=0 i=0 - time=0.0002679980025277473
STAT match_SwitchOrderBinaryPattern +0 -0 #it=1 maxmatch=0 i=0 - time=0.0006592560021090321
STAT remove_identity_nodes +0 -0 #it=1 maxmatch=0 i=0 - time=0.0003321409967611544
--MODEL: 117 nodes, 40 inputs, 15 outputs, 36 initializers--
INPUT: 39 x 1t
INPUT: 1 x 7t
INPUT-SEQ: 40 x Falset
OUTPUT: 15 x 1t
OUTPUT-SEQ: 15 x Falset
INIT: 10 x 1t
INIT: 26 x 7t
NODE: 10 x Add
NODE: 2 x Concat
NODE: 1 x Constant
NODE: 1 x ConstantOfShape
NODE: 1 x Div
NODE: 1 x Equal
NODE: 14 x Gemm
NODE: 3 x Identity
NODE: 4 x MatMul
NODE: 35 x Mul
NODE: 2 x Neg
NODE: 3 x Pow
NODE: 7 x ReduceSum
NODE: 15 x Reshape
NODE: 1 x ScatterND
NODE: 1 x Sigmoid
NODE: 2 x Split
NODE: 2 x Sub
NODE: 9 x Transpose
NODE: 2 x Unsqueeze
NODE: 1 x Where
--MODEL: 117 nodes, 40 inputs, 15 outputs, 36 initializers--DETAILED--
INPUT: 3 x 1t[1024]
INPUT: 7 x 1t[1024x1024]
INPUT: 2 x 1t[1024x512]
INPUT: 7 x 1t[2048x1024]
INPUT: 10 x 1t[2x1024x1024]
INPUT: 3 x 1t[2x1024x1]
INPUT: 1 x 1t[2x2x1024x1024]
INPUT: 2 x 1t[2x2x1024x512]
INPUT: 1 x 1t[4x1024x1024]
INPUT: 2 x 1t[4x1024x512]
INPUT: 1 x 1t[4x512x1024]
INPUT: 1 x 7t[2x1024]
OUTPUT: 3 x 1t[1024]
OUTPUT: 8 x 1t[1024x1024]
OUTPUT: 4 x 1t[1]
INIT: 10 x 1t[1]
INIT: 7 x 7t[1]
INIT: 12 x 7t[2]
INIT: 3 x 7t[3]
INIT: 4 x 7t[4]
NODE: 1 x Add -SIG- 1t[2x1024x1024], 1t[1]
NODE: 5 x Add -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
NODE: 4 x Add -SIG- 1t[2x2x1024x512], 1t[2x2x1024x512]
NODE: 2 x Concat -SIG- 1t[2x2x1024x256], 1t[2x2x1024x256]
NODE: 1 x Constant -SIG-
NODE: 1 x ConstantOfShape -SIG- 7t[2]
NODE: 1 x Div -SIG- 1t[2x2x1024x1024], 1t[1]
NODE: 1 x Equal -SIG- 7t[2x1024], 7t[1]
NODE: 4 x Gemm -SIG- 1t[2048x1024], 1t[1024x1024]
NODE: 3 x Gemm -SIG- 1t[2048x1024], 1t[1024x1024], 1t[2048x1024]
NODE: 7 x Gemm -SIG- 1t[2048x1024], 1t[2048x1024]
NODE: 3 x Identity -SIG- 1t[1]
NODE: 2 x MatMul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x512]
NODE: 1 x MatMul -SIG- 1t[2x2x1024x512], 1t[2x2x512x1024]
NODE: 1 x MatMul -SIG- 1t[2x2x512x1024], 1t[2x2x1024x1024]
NODE: 1 x Mul -SIG- 1t[2048x1024], 1t[2048x1024]
NODE: 3 x Mul -SIG- 1t[2x1024x1024], 1t[1024]
NODE: 10 x Mul -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
NODE: 3 x Mul -SIG- 1t[2x1024x1024], 1t[2x1024x1]
NODE: 6 x Mul -SIG- 1t[2x1024x1], 1t[1]
NODE: 3 x Mul -SIG- 1t[2x1024x1], 1t[2x1024x1024]
NODE: 3 x Mul -SIG- 1t[2x1024x1], 1t[2x1024x1]
NODE: 1 x Mul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1024]
NODE: 1 x Mul -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1]
NODE: 4 x Mul -SIG- 1t[2x2x1024x512], 1t[1024x512]
NODE: 2 x Neg -SIG- 1t[2x2x1024x256]
NODE: 3 x Pow -SIG- 1t[2x1024x1], 1t[1]
NODE: 3 x ReduceSum -SIG- 1t[2x1024x1024], 7t[1]
NODE: 3 x ReduceSum -SIG- 1t[2x1024x1024], 7t[2]
NODE: 1 x ReduceSum -SIG- 1t[2x2x1024x1024], 7t[1]
NODE: 3 x Reshape -SIG- 1t[2048x1024], 7t[3]
NODE: 1 x Reshape -SIG- 1t[2048x1024], 7t[4]
NODE: 4 x Reshape -SIG- 1t[2x1024x1024], 7t[2]
NODE: 3 x Reshape -SIG- 1t[2x1024x2x512], 7t[2]
NODE: 1 x Reshape -SIG- 1t[4x1024x1024], 7t[4]
NODE: 2 x Reshape -SIG- 1t[4x1024x512], 7t[4]
NODE: 1 x Reshape -SIG- 1t[4x512x1024], 7t[4]
NODE: 1 x ScatterND -SIG- 1t[1024x1024], 7t[2x1024x1], 1t[2x1024x1024]
NODE: 1 x Sigmoid -SIG- 1t[2x1024x1024]
NODE: 2 x Split -SIG- 1t[2x2x1024x512], 7t[2]
NODE: 1 x Sub -SIG- 1t[2x1024x1024], 1t[2x1024x1024]
NODE: 1 x Sub -SIG- 1t[2x2x1024x1024], 1t[2x2x1024x1024]
NODE: 1 x Transpose -SIG- 1t[2x1024x2x512]-perm=0;2;1;3
NODE: 1 x Transpose -SIG- 1t[2x2x1024x1024]-perm=0;1;3;2
NODE: 2 x Transpose -SIG- 1t[2x2x1024x512]-perm=0;1;3;2
NODE: 3 x Transpose -SIG- 1t[2x2x1024x512]-perm=0;2;1;3
NODE: 2 x Transpose -SIG- 1t[2x2x512x1024]-perm=0;1;3;2
NODE: 1 x Unsqueeze -SIG- 7t[2x1024], 7t[1]
NODE: 1 x Unsqueeze -SIG- 9t[2x1024], 7t[1]
NODE: 1 x Where -SIG- 9t[2x1024x1], 1t[1], 1t[2x1024x1024]
[GraphBuilder.optimize] done with 117 nodes in 0.008
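The run above finds no match ("done all: -0 +0 nodes"). Several Mul nodes already carry names such as SwitchOrderBinaryPattern--MulMulMulScalarPattern--mul15-Cst, which suggests the pattern fired earlier, when optimized_proto was produced.

Judging from its name and from those nodes, the pattern seems to reorder a chain of element-wise binary operators so that small operands are combined first. For example, Mul(_onx_mul06, init1_s_12) multiplies a 2 x 1024 x 1 tensor by a scalar before the result meets input34 (2 x 1024 x 1024). This is our reading, not the pattern's documented definition.

The sketch below counts the elements each order computes under ONNX (numpy-style) multidirectional broadcasting. `broadcast_shape` and `chain_cost` are hypothetical helpers.

```python
from math import prod


def broadcast_shape(a, b):
    """Output shape of an element-wise binary op with multidirectional broadcasting."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {x} with {y}")
        out.append(y if x == 1 else x)
    return tuple(out)


def chain_cost(first, second, third):
    """Number of output elements computed by (first * second) * third."""
    inner = broadcast_shape(first, second)
    outer = broadcast_shape(inner, third)
    return prod(inner) + prod(outer)


big = (2, 1024, 1024)  # shape of input34
small = (2, 1024, 1)   # shape of _onx_mul06
scalar = ()            # shape of init1_s_12
print(chain_cost(big, small, scalar))  # big operand first
print(chain_cost(small, scalar, big))  # small operands first
```

With these shapes, combining the small operands first computes 2,099,200 elements instead of 4,194,304, roughly half. Floating-point multiplication is not associative, so such a reordering may change results in the last bits.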
Total running time of the script: (0 minutes 0.832 seconds)
Related examples
101: Onnx Model Rewriting
201: Evaluate different ways to export a torch model to ONNX
102: Fuse kernels in a small Llama Model