Note
Go to the end to download the full example code
101: Profile an existing model with onnxruntime
Profiles any ONNX model on the CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
# onnx_extended is optional: it is only used to turn the profiling trace into
# a DataFrame and to plot it. When it is missing, bind BOTH imported names to
# None so later code can test for availability without risking a NameError.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None
    plot_ort_profile = None
# Default model: "data/example_4700-CPUep-opt.onnx" next to this script.
# ``__file__`` does not exist when the code runs interactively (e.g. in a
# notebook); probe for it and fall back to a path relative to the current
# working directory in that case.
try:
    __file__
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(
        os.path.dirname(__file__ or ""), "data", "example_4700-CPUep-opt.onnx"
    )
# Build the script options through the project's argument helper. The
# ``(value, text)`` tuple for ``filename`` presumably means (default, help
# string) -- confirm against ``get_parsed_args``.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective settings so the run is self-describing.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Random inputs matching the model's declared input signature.
def create_random_input(sess, dynamic_dim=1):
    """Build a feed dictionary of random tensors matching the inputs of *sess*.

    Parameters
    ----------
    sess
        Object exposing ``get_inputs()`` (e.g. an onnxruntime
        ``InferenceSession``); every returned input must provide ``name``,
        ``shape`` and ``type``.
    dynamic_dim : int, optional
        Size substituted for any dimension that is not a concrete integer
        (symbolic names such as ``"batch"`` or ``None``). Defaults to 1.

    Returns
    -------
    dict
        Maps each input name to a numpy array of the matching shape and
        dtype. Floating-point inputs are uniform in [0, 1), integer inputs
        are drawn from [0, 10) and boolean inputs are random.

    Raises
    ------
    ValueError
        If an input has an element type not handled here.
    """
    float_types = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    int_types = {
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
    }
    feeds = {}
    for inp in sess.get_inputs():
        # Dynamic dimensions come back as strings/None; numpy needs ints.
        shape = tuple(
            d if isinstance(d, (int, np.integer)) else dynamic_dim
            for d in inp.shape
        )
        ot = inp.type
        # np.asarray keeps scalar inputs (shape ()) working: numpy's random
        # functions return plain Python scalars when no dimension is given.
        if ot in float_types:
            t = np.asarray(np.random.rand(*shape), dtype=float_types[ot])
        elif ot in int_types:
            t = np.asarray(np.random.randint(0, 10, size=shape), dtype=int_types[ot])
        elif ot == "tensor(bool)":
            t = np.asarray(np.random.rand(*shape) < 0.5)
        else:
            raise ValueError(f"Unsupported onnx type {ot}.")
        feeds[inp.name] = t
    return feeds
def create_session(filename, profiling=False):
    """Create an onnxruntime ``InferenceSession`` running on the CPU only.

    Parameters
    ----------
    filename
        Path of the ONNX model to load.
    profiling : bool, optional
        When True, the session is created with ``SessionOptions`` whose
        ``enable_profiling`` flag is set.

    Returns
    -------
    onnxruntime.InferenceSession
        The session, restricted to ``CPUExecutionProvider``.
    """
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)
# Sanity check: build a plain (non-profiling) CPU session, generate random
# inputs from its declared inputs and run the model once.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# The outputs are not stored; the same ``feeds`` are reused for profiling.
sess.run(None, feeds)
[array([[0.91398245, 0.57628196, 0.48764992, ..., 0.3955947 , 0.52806693,
0.52786803],
[0.31164184, 0.58315444, 0.80397445, ..., 0.61994034, 0.94090754,
0.17720585],
[0.3395866 , 0.5943828 , 0.53631425, ..., 0.35516572, 0.09682213,
0.33556533],
...,
[0.25707904, 0.7535674 , 0.7554002 , ..., 0.36074117, 0.06085893,
0.13981643],
[0.07814518, 0.09656924, 0.0421176 , ..., 0.8285064 , 0.28555152,
0.42705196],
[0.41327047, 0.14391418, 0.40519983, ..., 0.5858572 , 0.62700355,
0.7861884 ]], dtype=float32), array([[0.47667688, 0.39695138, 0.9381092 , ..., 0.68025386, 0.58981436,
0.19682835],
[0.96116537, 0.8519381 , 0.12407897, ..., 0.91617495, 0.2277926 ,
0.9581247 ],
[0.69623035, 0.5257477 , 0.20292161, ..., 0.4543687 , 0.6661971 ,
0.22071014],
...,
[0.38079822, 0.68998075, 0.7934502 , ..., 0.99610054, 0.05723822,
0.02473171],
[0.54504544, 0.2616041 , 0.29621118, ..., 0.4465734 , 0.28895664,
0.769591 ],
[0.4225101 , 0.58036405, 0.61927366, ..., 0.5072816 , 0.4815128 ,
0.2072598 ]], dtype=float32), array([[0.87149805, 0.46779412, 0.7452013 , ..., 0.59170425, 0.9209106 ,
0.42639002],
[0.82107943, 0.32070556, 0.8573875 , ..., 0.56483036, 0.260574 ,
0.09135629],
[0.24738845, 0.2951498 , 0.7146122 , ..., 0.380619 , 0.09730798,
0.60640943],
...,
[0.46793056, 0.66070056, 0.48432958, ..., 0.99664325, 0.9882933 ,
0.22435805],
[0.20465992, 0.15862264, 0.65854317, ..., 0.84058243, 0.66796696,
0.79107755],
[0.33797637, 0.05030651, 0.7483493 , ..., 0.0564948 , 0.6848731 ,
0.46056956]], dtype=float32), array([[[[7.205508 , 7.9895926, 8.416544 , ..., 6.770774 , 7.1625376,
6.483858 ],
[7.2672195, 7.465618 , 7.3520966, ..., 6.974065 , 7.622259 ,
7.517603 ],
[8.056629 , 7.7550955, 7.3357353, ..., 7.7675734, 8.218457 ,
8.343167 ],
...,
[6.913253 , 7.083952 , 6.453188 , ..., 7.847801 , 7.494713 ,
7.8264074],
[5.8578596, 5.2047787, 5.60857 , ..., 7.2954445, 7.638356 ,
7.395913 ],
[7.3160095, 5.8252716, 5.454154 , ..., 7.4482 , 7.358966 ,
6.83504 ]],
[[6.260421 , 6.4745965, 7.3947716, ..., 6.7765603, 6.589321 ,
6.5809693],
[6.9425473, 6.665863 , 7.7645326, ..., 6.6559825, 6.4143305,
7.185762 ],
[7.4409366, 7.647948 , 6.9305844, ..., 6.788102 , 7.252824 ,
7.472464 ],
...,
[6.580416 , 5.977857 , 6.6630354, ..., 7.151876 , 7.6054873,
6.188504 ],
[6.3581014, 4.8171573, 5.2133064, ..., 6.865668 , 5.930305 ,
6.7570367],
[6.217736 , 5.5824056, 5.3415155, ..., 7.3240457, 6.229552 ,
6.772177 ]],
[[5.116146 , 6.1506577, 6.2882648, ..., 5.455061 , 5.7735896,
6.2993226],
[5.7724338, 5.2778144, 5.490645 , ..., 5.749422 , 5.0189543,
6.112331 ],
[5.991749 , 6.41042 , 6.491544 , ..., 6.086307 , 6.8470335,
5.5497823],
...,
[5.746537 , 5.4410415, 6.043397 , ..., 6.1263237, 6.614131 ,
5.5966153],
[4.83159 , 3.9778526, 3.850477 , ..., 5.3486814, 5.485555 ,
6.6554236],
[5.3024807, 5.1484284, 5.582978 , ..., 6.465997 , 5.312218 ,
5.423742 ]],
...,
[[7.871123 , 8.718124 , 8.500517 , ..., 8.246982 , 6.7822537,
7.3001914],
[7.943476 , 8.12264 , 8.797746 , ..., 7.1407943, 8.218125 ,
8.40282 ],
[8.7769985, 8.391742 , 7.024096 , ..., 8.839397 , 9.07862 ,
8.473246 ],
...,
[7.268252 , 6.7797985, 7.61532 , ..., 8.0960045, 8.261134 ,
8.577002 ],
[6.6318307, 6.080726 , 6.6229615, ..., 8.542605 , 7.6251235,
9.220777 ],
[6.7647943, 6.6094775, 5.9505105, ..., 8.322737 , 8.047828 ,
8.65976 ]],
[[6.0740004, 6.434598 , 6.6850815, ..., 5.139678 , 7.0877285,
4.9569206],
[5.4855757, 6.0544186, 6.5418234, ..., 5.907666 , 6.399254 ,
6.6314116],
[7.3262463, 6.4814157, 6.3402596, ..., 7.255395 , 5.759486 ,
6.988225 ],
...,
[5.3537436, 6.1104374, 6.298465 , ..., 6.884762 , 6.7736454,
6.212181 ],
[4.9507933, 3.9594915, 4.143952 , ..., 5.8256702, 6.6111565,
5.860141 ],
[6.217616 , 5.4622936, 4.6362677, ..., 5.642166 , 6.3643227,
5.842002 ]],
[[6.9261794, 6.779761 , 7.083336 , ..., 6.6269736, 5.9768634,
5.9925976],
[6.841673 , 6.8821535, 7.004058 , ..., 7.052854 , 6.4486403,
6.5491166],
[7.0067296, 7.0405135, 6.5167766, ..., 7.1424656, 7.1956654,
7.4732656],
...,
[5.8668175, 5.9137573, 6.1626215, ..., 6.865367 , 7.8011456,
6.1410894],
[5.517689 , 4.248037 , 5.042473 , ..., 7.0548387, 6.4594965,
6.0213876],
[5.998788 , 5.2361135, 5.016924 , ..., 7.831756 , 6.2911563,
6.0944743]]]], dtype=float32), array([[[[ 372, 379, 258, ..., 362, 365, 371],
[ 497, 872, 879, ..., 980, 739, 618],
[ 995, 996, 1003, ..., 1104, 1233, 1114],
...,
[14137, 14018, 13899, ..., 14001, 14006, 14135],
[14510, 14763, 14643, ..., 14871, 14873, 14752],
[14883, 15259, 15262, ..., 14995, 14997, 15000]],
[[ 373, 128, 382, ..., 238, 117, 371],
[ 496, 624, 878, ..., 980, 613, 618],
[ 995, 996, 1003, ..., 1104, 1357, 1113],
...,
[14261, 14264, 14269, ..., 14002, 14006, 14009],
[14510, 14762, 14766, ..., 14870, 14873, 14752],
[14883, 15259, 15014, ..., 14995, 14998, 15124]],
[[ 372, 5, 11, ..., 362, 365, 370],
[ 746, 875, 879, ..., 856, 613, 617],
[ 994, 998, 1127, ..., 1104, 1356, 1114],
...,
[14260, 13893, 13899, ..., 14002, 14006, 14135],
[14387, 14760, 14767, ..., 14870, 14872, 14752],
[15007, 15135, 15260, ..., 14995, 15120, 15251]],
...,
[[ 372, 254, 258, ..., 238, 240, 494],
[ 496, 873, 879, ..., 980, 613, 741],
[ 994, 996, 1003, ..., 1105, 1232, 1484],
...,
[14261, 14018, 13898, ..., 14127, 14007, 14135],
[14510, 14389, 14767, ..., 14870, 14872, 14752],
[14883, 15135, 15137, ..., 14995, 14996, 15000]],
[[ 248, 4, 382, ..., 363, 241, 369],
[ 497, 874, 879, ..., 980, 863, 742],
[ 995, 997, 1375, ..., 1104, 1358, 1112],
...,
[14137, 14019, 13899, ..., 14001, 14006, 14011],
[14510, 14636, 14766, ..., 14871, 14873, 14876],
[15006, 15259, 15014, ..., 14995, 15244, 15001]],
[[ 373, 252, 382, ..., 362, 240, 368],
[ 496, 624, 506, ..., 980, 862, 618],
[ 995, 996, 1003, ..., 1104, 1357, 1114],
...,
[14262, 14142, 13898, ..., 14000, 14005, 14135],
[14509, 14636, 14766, ..., 14870, 14873, 14752],
[15007, 15259, 15013, ..., 14995, 14997, 15373]]]], dtype=int64), array([[8.478149 , 7.7415524, 8.194025 , ..., 7.9691305, 8.42513 ,
7.831756 ]], dtype=float32), array([[122311.34 , 122470.555, 122021.74 , ..., 121663.766, 122541.27 ,
122380.37 ]], dtype=float32), array([[62954464., 61740216., 61013392., 60682196., 63875272., 62679800.,
63408672., 60948412., 64281992., 63381936., 64307360., 64403020.,
61565344., 62251208., 62248408., 63189680., 63147016., 63402980.,
62256272., 62496312., 64481904., 62637632., 62345368., 64697372.,
61741924., 61665400., 61223840., 61516116., 64228740., 62423644.,
62610144., 61648480., 64921048., 62355680., 62512224., 63721700.,
61828672., 61592120., 62517488., 64757480., 62091776., 61604048.,
61134364., 62620720., 64929752., 62785368., 62697552., 64157768.,
61666024., 63453024., 63248024., 63152216., 62553164., 60997044.,
61797560., 63103188., 63727820., 62445656., 62467520., 63751204.,
65052816., 64381952., 60688352., 63484808., 62104568., 61991008.,
64230896., 63531720., 62413112., 62044092., 63717560., 63264736.,
63445592., 63367236., 60130480., 62355728., 62062500., 63019924.,
61160204., 63512624., 62444544., 63664784., 64227928., 62368784.,
63602248., 63713976., 62248200., 61052456., 62105060., 61357232.,
63723532., 63136000., 59691656., 62853992., 62107412., 63981192.,
62717824., 61267524., 63453692., 61288608., 62116480., 62351944.,
65580556., 63909912., 63410272., 62190400., 62707108., 61451220.,
62037912., 62947060., 63640856., 62525152., 61416576., 61228320.,
62707616., 63277360., 63264856., 60375660., 61273596., 61515272.,
62446944., 62716088., 61571832., 62582452., 61839944., 61273944.,
62814292., 62870916.]], dtype=float32), array([[4.0849713e+09, 3.9629972e+09, 3.8007370e+09, 4.1178373e+09,
3.7346391e+09, 4.1855841e+09, 4.5715620e+09, 3.7984412e+09,
4.2013747e+09, 4.2602982e+09]], dtype=float32)]
Profiling
# Profile ``repeat`` runs of the model on the random inputs built above.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
# Stop the profiler; onnxruntime returns the name of the trace file it wrote.
prof = sess.end_profiling()

if js_profile_to_dataframe is not None:
    # first_it_out=True presumably separates the first (warm-up) iteration;
    # the resulting frame has an ``it==0`` column -- confirm in onnx_extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    try:
        # Excel export relies on an optional pandas engine (e.g. openpyxl);
        # report and skip it rather than aborting after a successful profile.
        df.to_excel("plot_profile_existing_onnx.xlsx")
    except ImportError as e:
        print(f"Skipping Excel export: {e}")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Title the figure after the profiled model instead of a fixed label.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name', 'args_op_name',
'op_name', 'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'event_name', 'iteration', 'it==0'],
dtype='object')
Total running time of the script: (0 minutes 5.058 seconds)