Note
Go to the end to download the full example code.
101: Profile an existing model with onnxruntime
Profiles any ONNX model on CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
try:
    # Optional dependency: onnx-extended provides the helpers used in the
    # profiling section to turn the onnxruntime profile into a DataFrame
    # and to plot it.
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Bind both names so that neither is left undefined when onnx-extended
    # is missing; callers test for ``None`` instead of hitting a NameError.
    js_profile_to_dataframe = None
    plot_ort_profile = None
# Default model: ``data/example_4700-CPUep-opt.onnx`` next to this script.
try:
    _script_dir = os.path.dirname(__file__ or "")
except NameError:
    # ``__file__`` does not exist when the code is run interactively;
    # fall back to a path relative to the current working directory.
    _script_dir = ""
# Single construction point so both cases use the platform's path separator.
filename = os.path.join(_script_dir, "data", "example_4700-CPUep-opt.onnx")
# Script parameters: the model file (defaulting to ``filename`` above) and
# the number of profiled runs. NOTE(review): how these defaults are
# overridden (command line / environment) is decided by ``get_parsed_args``;
# confirm there.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective values of the two parameters used below.
for shown_attr in ("filename", "repeat"):
    shown_value = getattr(script_args, shown_attr)
    print(f"{shown_attr}={shown_value}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Random inputs matching the model's input names and shapes.
def create_random_input(sess, dim_value=1):
    """Build a dictionary of random inputs matching the session's inputs.

    :param sess: object exposing ``get_inputs()`` (for example an
        ``onnxruntime.InferenceSession``); each returned input must provide
        ``name``, ``shape`` and ``type``
    :param dim_value: size substituted for every dynamic dimension, i.e. any
        shape entry that is not a non-negative integer (a symbolic name
        such as ``"batch"`` or ``None``)
    :return: dictionary ``{input name: numpy array}``; floating tensors are
        drawn uniformly from ``[0, 1)``, integer tensors from ``[0, 10)``
    :raises ValueError: if an input has an unsupported element type
    """
    float_types = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    int_types = {
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
    }
    feeds = {}
    for inp in sess.get_inputs():
        shape = tuple(
            d if isinstance(d, (int, np.integer)) and d >= 0 else dim_value
            for d in inp.shape
        )
        if inp.type in float_types:
            # np.asarray keeps scalar inputs (shape ()) working: np.random.rand()
            # with no dimensions returns a Python float, which has no .astype.
            value = np.asarray(np.random.rand(*shape), dtype=float_types[inp.type])
        elif inp.type in int_types:
            value = np.asarray(
                np.random.randint(0, 10, size=shape), dtype=int_types[inp.type]
            )
        else:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        feeds[inp.name] = value
    return feeds
def create_session(filename, profiling=False):
    """Load a model into an onnxruntime session on the CPU execution provider.

    :param filename: path to the ONNX model
    :param profiling: when True, the session is built from a
        ``SessionOptions`` with ``enable_profiling`` switched on
    :return: an ``onnxruntime.InferenceSession``
    """
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        session_options = SessionOptions()
        session_options.enable_profiling = True
        return InferenceSession(filename, session_options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)
# Create a regular (non-profiling) session, build random feeds for its
# inputs and run the model once. ``feeds`` is reused by the profiling
# section. The bare ``sess.run`` call is left as the last expression so its
# result (the list of model outputs) is displayed on the page below.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.09610969, 0.7728384 , 0.75196004, ..., 0.84491324, 0.79128575,
0.7401204 ],
[0.9479155 , 0.47235975, 0.5942689 , ..., 0.19155996, 0.488694 ,
0.34412995],
[0.14876392, 0.56245506, 0.42884845, ..., 0.9198805 , 0.5506332 ,
0.13670772],
...,
[0.5250744 , 0.6346656 , 0.05917783, ..., 0.4984336 , 0.67380965,
0.06971068],
[0.10134266, 0.26747262, 0.51159984, ..., 0.5425203 , 0.87270105,
0.7203621 ],
[0.6009859 , 0.74893504, 0.90190494, ..., 0.8438301 , 0.91150784,
0.26544774]], shape=(128, 1024), dtype=float32), array([[0.5450286 , 0.05108841, 0.74802214, ..., 0.34395698, 0.09970681,
0.2158931 ],
[0.3145486 , 0.5574005 , 0.38632682, ..., 0.65387124, 0.6815179 ,
0.3011002 ],
[0.78981787, 0.38738513, 0.67306495, ..., 0.6603452 , 0.34233794,
0.9149113 ],
...,
[0.47101635, 0.08823352, 0.884104 , ..., 0.50432026, 0.2064694 ,
0.26791674],
[0.5662113 , 0.0291094 , 0.19492099, ..., 0.416333 , 0.570644 ,
0.12336027],
[0.46335468, 0.7540493 , 0.2930684 , ..., 0.84592444, 0.83217835,
0.17364559]], shape=(1024, 30752), dtype=float32), array([[0.77160484, 0.44732657, 0.8793495 , ..., 0.13767523, 0.28322092,
0.38347045],
[0.50376034, 0.6951072 , 0.34270832, ..., 0.9852718 , 0.59448594,
0.97335804],
[0.6232289 , 0.1364311 , 0.9206978 , ..., 0.13986202, 0.686012 ,
0.40576643],
...,
[0.640396 , 0.49620828, 0.10661684, ..., 0.336249 , 0.5838923 ,
0.7280593 ],
[0.5700163 , 0.2603261 , 0.15854801, ..., 0.8565149 , 0.34878585,
0.24813211],
[0.5626924 , 0.34012517, 0.9681925 , ..., 0.36979583, 0.6830576 ,
0.03190738]], shape=(10, 128), dtype=float32), array([[[[4.675813 , 4.8393917, 4.530805 , ..., 6.5850463, 6.2330737,
5.4244423],
[4.9031816, 5.4459305, 5.242003 , ..., 5.8558264, 5.349012 ,
5.2379303],
[6.1449842, 6.261963 , 5.8365436, ..., 6.8934035, 5.818215 ,
6.8576436],
...,
[6.31502 , 6.184564 , 5.1498985, ..., 4.7691255, 5.3211126,
5.9128103],
[5.7589073, 5.65446 , 5.1514096, ..., 5.3978014, 5.7793846,
5.7055755],
[6.150046 , 6.0458603, 5.7789774, ..., 5.6919484, 4.9626207,
5.324859 ]],
[[5.23308 , 6.6074586, 5.5526505, ..., 6.7088737, 6.4515886,
6.605689 ],
[5.907752 , 7.2733054, 6.4803004, ..., 7.2436705, 7.625379 ,
6.598878 ],
[5.995107 , 6.6457796, 7.346325 , ..., 7.8733835, 7.744594 ,
7.3948293],
...,
[7.216176 , 7.8466363, 7.0026503, ..., 6.7702537, 6.845193 ,
6.908416 ],
[7.3323965, 7.7593193, 6.0310626, ..., 6.0389333, 6.5763497,
6.9780807],
[7.755426 , 7.8468566, 6.6688113, ..., 6.5624204, 6.388437 ,
6.630726 ]],
[[6.1693115, 6.3360524, 5.7711062, ..., 7.301279 , 7.476224 ,
7.1803865],
[5.6562405, 6.940252 , 6.729952 , ..., 7.276113 , 7.164065 ,
6.4023805],
[6.6390224, 7.4007125, 7.504481 , ..., 7.2014318, 7.6115255,
7.6915755],
...,
[8.021874 , 7.5101 , 7.514248 , ..., 6.454876 , 6.2781773,
6.844224 ],
[7.1982107, 7.663066 , 6.7969904, ..., 5.8152018, 6.313769 ,
6.8283544],
[6.97674 , 7.514856 , 7.1654916, ..., 6.6485343, 6.0571275,
6.1950116]],
...,
[[5.250869 , 6.19024 , 5.323181 , ..., 5.994178 , 6.261137 ,
6.1661167],
[5.456216 , 5.6767855, 5.1586323, ..., 6.633404 , 5.7539883,
5.8595333],
[6.2778807, 6.055975 , 6.7682457, ..., 6.6117325, 6.6505585,
5.7998695],
...,
[6.555544 , 7.179186 , 6.5355883, ..., 5.6287804, 6.0172253,
4.982135 ],
[6.309924 , 6.5869055, 5.8591833, ..., 5.7681103, 5.536516 ,
5.5897274],
[5.7137356, 6.009883 , 6.0124655, ..., 5.9652815, 5.05702 ,
6.3460093]],
[[6.806348 , 7.5467434, 6.6534033, ..., 8.538446 , 7.9266596,
8.176761 ],
[7.7786427, 7.4173274, 7.744265 , ..., 8.736114 , 8.012027 ,
8.194593 ],
[8.307279 , 8.539384 , 8.796736 , ..., 9.735919 , 8.960765 ,
8.6386175],
...,
[8.377878 , 9.7039995, 8.404376 , ..., 7.522732 , 8.194346 ,
7.623464 ],
[8.239339 , 9.0724745, 7.5216002, ..., 7.866619 , 7.4154396,
8.1689825],
[8.725547 , 8.374278 , 8.149051 , ..., 8.232766 , 7.010842 ,
7.2677546]],
[[7.031036 , 6.4757147, 6.307449 , ..., 8.507093 , 8.194411 ,
7.3640084],
[7.44024 , 7.726051 , 6.5270677, ..., 7.9953556, 8.34257 ,
7.2676883],
[7.387988 , 9.045776 , 8.43912 , ..., 8.32619 , 8.143594 ,
7.542256 ],
...,
[8.217804 , 8.8999195, 8.855006 , ..., 6.9249506, 7.4461083,
8.200155 ],
[7.690132 , 8.070051 , 7.920203 , ..., 7.6894836, 6.5439115,
7.548409 ],
[8.138238 , 8.376512 , 7.9650464, ..., 6.967308 , 7.0520144,
6.8442097]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[ 374, 131, 9, ..., 237, 367, 369],
[ 746, 748, 630, ..., 733, 615, 616],
[ 1364, 1245, 1127, ..., 1477, 1109, 1360],
...,
[14263, 14266, 14144, ..., 14000, 14007, 14257],
[14511, 14512, 14767, ..., 14622, 14624, 14752],
[14880, 14884, 15015, ..., 14992, 14998, 15127]],
[[ 373, 7, 132, ..., 485, 491, 369],
[ 498, 874, 755, ..., 857, 615, 616],
[ 1242, 1369, 1375, ..., 1477, 1481, 1113],
...,
[14263, 14264, 14268, ..., 14002, 14007, 14381],
[14759, 14636, 14640, ..., 14746, 14749, 14504],
[14883, 15008, 15015, ..., 15240, 14996, 15002]],
[[ 374, 7, 132, ..., 485, 366, 371],
[ 498, 500, 755, ..., 858, 615, 616],
[ 1242, 1369, 1375, ..., 1477, 1110, 1113],
...,
[14263, 14264, 14268, ..., 14002, 14006, 14008],
[14759, 14512, 14519, ..., 14746, 14624, 14505],
[15004, 15008, 15263, ..., 14992, 14999, 15002]],
...,
[[ 373, 131, 133, ..., 112, 367, 370],
[ 622, 873, 755, ..., 857, 860, 867],
[ 1242, 1244, 1127, ..., 1477, 1481, 1112],
...,
[14139, 14265, 14144, ..., 14373, 14007, 14008],
[14759, 14636, 14767, ..., 14621, 14749, 14505],
[15005, 15256, 15263, ..., 15240, 14999, 15375]],
[[ 373, 131, 132, ..., 112, 367, 369],
[ 746, 624, 631, ..., 733, 615, 617],
[ 1366, 1246, 1127, ..., 1476, 1110, 1114],
...,
[14263, 14265, 14144, ..., 14003, 14007, 14008],
[14759, 14636, 14518, ..., 14621, 14872, 14628],
[15005, 15009, 15139, ..., 15240, 14999, 15373]],
[[ 249, 131, 382, ..., 112, 366, 368],
[ 746, 749, 507, ..., 859, 614, 616],
[ 1240, 1246, 1127, ..., 1228, 1109, 1113],
...,
[14262, 14141, 14145, ..., 14002, 14004, 14257],
[14387, 14388, 14766, ..., 14621, 14748, 14628],
[15005, 15008, 15139, ..., 15117, 14999, 15127]]]],
shape=(1, 32, 31, 31), dtype=int64), array([[6.7413177, 6.5252485, 6.6804986, ..., 7.863183 , 8.615847 ,
8.200155 ]], shape=(1, 30752), dtype=float32), array([[118367.37 , 119177.234, 118987.086, ..., 119547.74 , 118254.24 ,
118718.13 ]], shape=(1, 1024), dtype=float32), array([[6.4128168e+07, 5.8999764e+07, 5.9586596e+07, 6.0295636e+07,
6.1974340e+07, 6.0698212e+07, 6.1241008e+07, 6.1970800e+07,
6.1567644e+07, 6.1991200e+07, 6.1572856e+07, 6.1601976e+07,
6.1340664e+07, 6.0645956e+07, 5.8501168e+07, 6.0780040e+07,
6.0628152e+07, 5.9184696e+07, 6.1584416e+07, 6.0284976e+07,
6.0002884e+07, 6.1890352e+07, 6.1937068e+07, 6.1375384e+07,
6.2477452e+07, 6.2767800e+07, 6.0903304e+07, 6.0515528e+07,
6.2023640e+07, 6.0238416e+07, 5.9209660e+07, 5.9731120e+07,
6.0374568e+07, 5.9482168e+07, 6.0115768e+07, 6.1582104e+07,
6.1052744e+07, 6.0998608e+07, 6.1180288e+07, 5.9299032e+07,
6.0731440e+07, 5.9612764e+07, 5.9659656e+07, 5.9558632e+07,
6.0549936e+07, 6.1873928e+07, 5.6873096e+07, 6.1038344e+07,
6.2429520e+07, 6.1385976e+07, 5.9541248e+07, 5.9505952e+07,
6.0852816e+07, 6.1084104e+07, 6.0501532e+07, 6.1744480e+07,
5.9608024e+07, 5.8768036e+07, 6.0296912e+07, 5.9994420e+07,
5.8566616e+07, 6.0935168e+07, 6.1183752e+07, 6.1137880e+07,
6.3117792e+07, 6.0234340e+07, 6.0272708e+07, 6.0511352e+07,
6.1614240e+07, 6.0392688e+07, 6.1473696e+07, 6.0956428e+07,
5.9723384e+07, 6.1914472e+07, 5.9426000e+07, 6.1341472e+07,
5.9775700e+07, 6.1335480e+07, 6.0282920e+07, 6.0643136e+07,
6.1424400e+07, 5.9674412e+07, 6.0923216e+07, 5.9795080e+07,
5.8237068e+07, 6.0031240e+07, 6.0288172e+07, 6.1723260e+07,
6.0229552e+07, 5.9320568e+07, 6.1351368e+07, 6.0019828e+07,
6.2556932e+07, 5.9816776e+07, 6.0898860e+07, 6.0151792e+07,
6.2236608e+07, 6.1135512e+07, 6.1322792e+07, 5.9843996e+07,
6.1660856e+07, 5.9824032e+07, 6.1159356e+07, 5.8991304e+07,
5.9940408e+07, 6.2538884e+07, 6.0827760e+07, 6.1265472e+07,
6.1207960e+07, 6.1216996e+07, 6.1483088e+07, 6.1859440e+07,
5.9887332e+07, 6.2837212e+07, 6.2696132e+07, 6.1264368e+07,
6.0190904e+07, 6.0846368e+07, 6.1244204e+07, 6.0282240e+07,
5.9752948e+07, 5.9783396e+07, 5.9826720e+07, 6.1863904e+07,
6.1392704e+07, 6.0093600e+07, 5.9726020e+07, 5.9484752e+07]],
dtype=float32), array([[3.8916920e+09, 4.4963220e+09, 3.7605591e+09, 3.7983089e+09,
3.7505807e+09, 3.8865334e+09, 4.0423076e+09, 4.2628155e+09,
3.9225544e+09, 3.7468856e+09]], dtype=float32)]
Profiling
# Run the model ``repeat`` times in a profiling-enabled session;
# ``end_profiling`` stops profiling and returns the profile file name.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
prof = sess.end_profiling()

if js_profile_to_dataframe is None:
    print("Install onnx-extended first.")
else:
    # NOTE(review): ``first_it_out=True`` presumably flags the first
    # iteration separately (the printed columns include ``it==0``) — confirm.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two side-by-side panels filled by onnx-extended's plotting helper.
    # NOTE(review): the "dort" title looks carried over from another example.
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, left_ax, right_ax, "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")

Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
'it==0'],
dtype='object')
Total running time of the script: (0 minutes 1.231 seconds)
Related examples

201: Evaluate different ways to export a torch model to ONNX
201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting
101: Onnx Model Optimization based on Pattern Rewriting

201: Use torch to export a scikit-learn model into ONNX
201: Use torch to export a scikit-learn model into ONNX