Note
Go to the end to download the full example code.
101: Profile an existing model with onnxruntime
Profiles any ONNX model on CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
# onnx-extended is optional: it is only needed to turn the onnxruntime
# profiling trace into a dataframe and to plot it.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Define both optional names so neither is left unbound; the profiling
    # cell checks js_profile_to_dataframe against None and prints an
    # install hint instead of profiling.
    js_profile_to_dataframe = None
    plot_ort_profile = None
# Default model: data/example_4700-CPUep-opt.onnx located next to this script.
# __file__ can be undefined (NameError) when the code is not run from a file,
# e.g. when it is exec'd by a documentation builder; in that case fall back to
# a path relative to the current working directory.
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""), "data", "example_4700-CPUep-opt.onnx"
    )
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
# Parse the script arguments. The keyword values look like defaults, with
# filename carrying a help string ("input file").
# NOTE(review): the meaning of expose="" is not visible here; confirm against
# experimental_experiment.args.get_parsed_args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)
# Echo the effective values of the two settings used below.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Random inputs, generated from the names, shapes and types of the model inputs.
def create_random_input(sess):
    """Build a feed dictionary of random tensors, one per input of *sess*.

    Values are drawn uniformly from [0, 1) with ``np.random.rand`` and cast
    to the numpy dtype matching the ONNX element type of each input.

    :param sess: an object exposing ``get_inputs()`` (such as an
        ``onnxruntime.InferenceSession``); every returned item must have
        ``name``, ``shape`` and ``type`` attributes
    :return: dictionary ``{input name: numpy array}``
    :raises ValueError: if an input has an unsupported ONNX type or a shape
        containing a non-integer (symbolic or unknown) dimension
    """
    # ONNX element types this helper knows how to fill with random values.
    onnx_to_numpy = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for inp in sess.get_inputs():
        ot = inp.type
        dtype = onnx_to_numpy.get(ot)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {ot}.")
        shape = tuple(inp.shape)
        # Dynamic dimensions may be reported as strings or None, which
        # np.random.rand cannot use: fail with an explicit message.
        if not all(isinstance(d, (int, np.integer)) for d in shape):
            raise ValueError(
                f"Input {inp.name!r} has a non-static shape {shape}, "
                f"random inputs cannot be generated."
            )
        # np.random.rand() returns a Python float for an empty shape, so wrap
        # it in an array to also support scalar inputs.
        feeds[inp.name] = np.asarray(np.random.rand(*shape)).astype(dtype)
    return feeds
def create_session(filename, profiling=False):
    """Open *filename* with onnxruntime on the CPU execution provider.

    :param filename: path to the ONNX model to load
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` switched on
    :return: an ``onnxruntime.InferenceSession``
    """
    # Imported lazily so that merely defining this helper does not require
    # onnxruntime.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)
# Plain (non-profiling) session: build random inputs from its input metadata
# and run the model once. The same feeds are reused by the profiling loop.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.48701167, 0.8743634 , 0.41765413, ..., 0.5896347 , 0.50106037,
0.79936886],
[0.22311887, 0.7746693 , 0.36482093, ..., 0.0888119 , 0.63223225,
0.24123782],
[0.42908487, 0.31348827, 0.37753874, ..., 0.2996923 , 0.5862409 ,
0.46235615],
...,
[0.63901 , 0.958252 , 0.8576642 , ..., 0.84840333, 0.27156168,
0.92425513],
[0.7107026 , 0.22514978, 0.7041683 , ..., 0.8769297 , 0.5114278 ,
0.2611587 ],
[0.09398966, 0.04657802, 0.36166787, ..., 0.12336926, 0.382032 ,
0.8705768 ]], shape=(128, 1024), dtype=float32), array([[0.70109737, 0.30644372, 0.73326844, ..., 0.7341234 , 0.19501053,
0.7508927 ],
[0.559656 , 0.33357525, 0.8111176 , ..., 0.17147014, 0.36077788,
0.14913547],
[0.41688848, 0.24148993, 0.57769305, ..., 0.9043198 , 0.32304505,
0.93879217],
...,
[0.78704005, 0.15369062, 0.5913906 , ..., 0.37713814, 0.7967207 ,
0.8414134 ],
[0.44902855, 0.52110785, 0.875147 , ..., 0.24375667, 0.66569865,
0.08956924],
[0.10142928, 0.09177364, 0.41900054, ..., 0.17988825, 0.7206119 ,
0.42871594]], shape=(1024, 30752), dtype=float32), array([[0.86050165, 0.9987919 , 0.6943593 , ..., 0.6296961 , 0.9132347 ,
0.36226135],
[0.39775822, 0.28806388, 0.61356646, ..., 0.12404636, 0.74102914,
0.8745291 ],
[0.3514584 , 0.78925663, 0.69602776, ..., 0.08101459, 0.2081717 ,
0.57725066],
...,
[0.21829125, 0.6406352 , 0.10680309, ..., 0.19909513, 0.99277294,
0.20461668],
[0.5313196 , 0.05398373, 0.03478933, ..., 0.08970655, 0.14648752,
0.2508738 ],
[0.66879046, 0.98973495, 0.37505123, ..., 0.32469818, 0.3674183 ,
0.25382754]], shape=(10, 128), dtype=float32), array([[[[6.9381046, 6.5217695, 7.6244864, ..., 5.582769 , 5.8021383,
6.2628956],
[7.6268106, 7.2249756, 7.9783726, ..., 5.9969845, 6.0601482,
6.789772 ],
[8.471845 , 8.172918 , 7.848505 , ..., 6.512267 , 6.198912 ,
7.7962027],
...,
[7.6225095, 7.7220645, 8.021209 , ..., 8.20195 , 8.130193 ,
7.501075 ],
[7.16429 , 7.809957 , 8.504215 , ..., 8.331161 , 7.5941954,
8.112395 ],
[6.6636386, 6.8771586, 8.429311 , ..., 7.214736 , 7.764557 ,
6.867657 ]],
[[6.942608 , 7.002553 , 7.092223 , ..., 5.5438323, 5.919307 ,
6.4588156],
[7.931553 , 6.8467336, 8.158717 , ..., 5.805227 , 6.538684 ,
7.9998164],
[8.64638 , 7.670035 , 9.19433 , ..., 5.948997 , 6.564786 ,
7.8416657],
...,
[7.624467 , 6.9467664, 8.990562 , ..., 8.466427 , 7.6095743,
7.165219 ],
[7.355692 , 8.671929 , 8.821998 , ..., 6.906161 , 8.366258 ,
8.648574 ],
[6.6341805, 8.415855 , 8.414853 , ..., 7.8235483, 7.994326 ,
6.9552045]],
[[8.164729 , 6.5867295, 8.375455 , ..., 6.138112 , 6.583139 ,
6.581451 ],
[8.739716 , 7.1404786, 8.5598 , ..., 6.6141286, 6.797036 ,
8.238909 ],
[8.733824 , 8.301625 , 9.551962 , ..., 6.5382314, 6.8334813,
8.334098 ],
...,
[7.4179096, 7.4367905, 8.479232 , ..., 9.271135 , 7.3919373,
7.9040008],
[7.5609527, 8.912743 , 8.987174 , ..., 7.988373 , 8.6921215,
8.542471 ],
[7.477041 , 8.285168 , 8.529076 , ..., 7.860715 , 8.320988 ,
7.324272 ]],
...,
[[6.5135617, 6.6306157, 7.2058196, ..., 5.273611 , 6.055018 ,
6.3206887],
[7.1804037, 8.146592 , 7.4991827, ..., 5.9763246, 5.8871 ,
7.3664207],
[6.8489943, 8.127537 , 8.35078 , ..., 5.554768 , 5.6488175,
7.7726083],
...,
[6.4850345, 7.0976562, 7.1712484, ..., 7.5530953, 6.952131 ,
7.9194865],
[7.8035326, 8.187436 , 8.276912 , ..., 7.218614 , 7.766487 ,
8.483334 ],
[7.1173396, 7.586395 , 7.9764857, ..., 7.0274477, 6.920168 ,
6.6587515]],
[[6.4631553, 6.628012 , 7.0430827, ..., 5.1135397, 5.8247857,
6.3276763],
[6.99526 , 6.884177 , 6.9137273, ..., 5.8827195, 6.256574 ,
6.662396 ],
[6.997056 , 7.254697 , 8.030955 , ..., 5.343504 , 6.2028036,
7.165063 ],
...,
[6.3499227, 6.697188 , 7.2855396, ..., 6.9896274, 7.5395117,
6.501973 ],
[6.8816857, 7.331365 , 7.9585133, ..., 7.4408507, 6.645596 ,
7.789961 ],
[6.2344813, 6.9750338, 7.0859184, ..., 6.461979 , 7.0840373,
6.7824125]],
[[5.157804 , 5.0492544, 6.3488107, ..., 4.6396203, 5.6160116,
6.3968463],
[6.204927 , 6.5673246, 7.0573864, ..., 5.4988527, 5.306164 ,
5.5089493],
[7.104061 , 7.3269444, 6.97711 , ..., 5.161616 , 5.345628 ,
6.5845017],
...,
[6.328454 , 7.4622273, 6.851907 , ..., 6.154734 , 6.6091948,
6.6437707],
[6.2652664, 7.0809107, 7.2807364, ..., 7.6275997, 7.6195917,
6.556688 ],
[6.112903 , 5.683591 , 6.3526325, ..., 6.390601 , 6.3589277,
6.285007 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[ 374, 255, 256, ..., 487, 364, 495],
[ 622, 502, 506, ..., 610, 613, 743],
[ 995, 1121, 1375, ..., 1231, 1108, 1486],
...,
[14136, 14266, 14020, ..., 14000, 14254, 14132],
[14632, 14512, 14765, ..., 14868, 14872, 14754],
[14881, 15257, 15012, ..., 14993, 14996, 15249]],
[[ 250, 255, 380, ..., 486, 489, 247],
[ 496, 501, 505, ..., 859, 613, 619],
[ 995, 1120, 1251, ..., 1107, 1234, 1236],
...,
[14263, 14265, 13896, ..., 14124, 14004, 14256],
[14632, 14761, 14641, ..., 14744, 14872, 14878],
[15006, 15011, 15012, ..., 15364, 14996, 15251]],
[[ 250, 254, 257, ..., 363, 488, 371],
[ 497, 503, 506, ..., 859, 612, 619],
[ 993, 1244, 1251, ..., 1107, 1110, 1237],
...,
[14263, 14266, 14144, ..., 14248, 14254, 14132],
[14756, 14761, 14641, ..., 14744, 14624, 14878],
[14883, 15010, 15012, ..., 15364, 15368, 15125]],
...,
[[ 251, 255, 381, ..., 486, 489, 371],
[ 746, 501, 504, ..., 735, 860, 619],
[ 1119, 997, 1374, ..., 1104, 1108, 1484],
...,
[14261, 14142, 14269, ..., 14248, 14129, 14259],
[14632, 14637, 14641, ..., 14620, 14872, 14878],
[15130, 14884, 15014, ..., 15365, 14996, 15251]],
[[ 250, 255, 381, ..., 487, 489, 495],
[ 622, 502, 504, ..., 735, 612, 741],
[ 1119, 996, 1375, ..., 1104, 1108, 1237],
...,
[14263, 14266, 14145, ..., 14000, 14129, 14133],
[14633, 14512, 14764, ..., 14620, 14872, 14877],
[15130, 15009, 15012, ..., 15365, 15369, 15251]],
[[ 249, 378, 256, ..., 487, 365, 495],
[ 746, 503, 506, ..., 611, 612, 616],
[ 994, 1120, 1251, ..., 1231, 1109, 1484],
...,
[14136, 14265, 14023, ..., 14372, 14006, 14382],
[14632, 14636, 14640, ..., 14620, 14872, 14878],
[15005, 15258, 15260, ..., 14992, 15370, 15249]]]],
shape=(1, 32, 31, 31), dtype=int64), array([[ 8.551404 , 9.075711 , 10.183338 , ..., 7.4780765, 7.504664 ,
7.6275997]], shape=(1, 30752), dtype=float32), array([[124201.74, 123932.5 , 123731.24, ..., 124368.14, 123931.2 ,
124096.99]], shape=(1, 1024), dtype=float32), array([[6.4778988e+07, 6.4752188e+07, 6.3478560e+07, 6.3431992e+07,
6.2845112e+07, 6.2816280e+07, 6.3186184e+07, 6.2828616e+07,
6.3401080e+07, 6.4651852e+07, 6.5227608e+07, 6.4261348e+07,
6.3334416e+07, 6.4268824e+07, 6.4465036e+07, 6.2655608e+07,
6.3980860e+07, 6.3351448e+07, 6.3707960e+07, 6.3757176e+07,
6.3070980e+07, 6.3671792e+07, 6.2698872e+07, 6.4065440e+07,
6.2696552e+07, 6.1894188e+07, 6.1968096e+07, 6.3503116e+07,
6.3428416e+07, 6.2804112e+07, 6.2969936e+07, 6.4613808e+07,
6.3114368e+07, 6.1673216e+07, 6.1399904e+07, 6.3363136e+07,
6.2868064e+07, 6.3285360e+07, 6.3200552e+07, 6.3906976e+07,
6.3544144e+07, 6.4617608e+07, 6.4236456e+07, 6.4171144e+07,
6.4546072e+07, 6.5735400e+07, 6.1626588e+07, 6.2508152e+07,
6.4805248e+07, 6.3258272e+07, 6.2658680e+07, 6.4459128e+07,
6.2330432e+07, 6.3170896e+07, 6.3915176e+07, 6.5436168e+07,
6.1975424e+07, 6.2712816e+07, 6.1869632e+07, 6.4432088e+07,
6.1149184e+07, 6.5277240e+07, 6.4590152e+07, 6.3650064e+07,
6.5898972e+07, 6.4919588e+07, 6.5656608e+07, 6.1004656e+07,
6.3604208e+07, 6.3417440e+07, 6.3228576e+07, 6.5363364e+07,
6.2866740e+07, 6.2075844e+07, 6.3093392e+07, 6.2610008e+07,
6.4688008e+07, 6.3479416e+07, 6.2789948e+07, 6.3242288e+07,
6.5709732e+07, 6.2552616e+07, 6.2463848e+07, 6.4400504e+07,
6.3953224e+07, 6.3333696e+07, 6.0803796e+07, 6.4233304e+07,
6.2097032e+07, 6.2286620e+07, 6.2913760e+07, 6.2893208e+07,
6.3360688e+07, 6.2201636e+07, 6.4829096e+07, 6.2872588e+07,
6.5251840e+07, 6.5018896e+07, 6.3913596e+07, 6.3603968e+07,
6.6101288e+07, 6.3634424e+07, 6.2189604e+07, 6.3058096e+07,
6.3686656e+07, 6.3947236e+07, 6.3781296e+07, 6.5288384e+07,
6.3459504e+07, 6.3316600e+07, 6.3293832e+07, 6.4853800e+07,
6.2266100e+07, 6.2312656e+07, 6.3572748e+07, 6.4707192e+07,
6.3457200e+07, 6.2271480e+07, 6.1815512e+07, 6.4171308e+07,
6.2221656e+07, 6.3431888e+07, 6.3596744e+07, 6.3966296e+07,
6.3165764e+07, 6.2923784e+07, 6.3093224e+07, 6.1638304e+07]],
dtype=float32), array([[4.1098030e+09, 3.7085363e+09, 3.6985411e+09, 3.6896095e+09,
4.0884142e+09, 4.2186007e+09, 4.1056561e+09, 4.1638871e+09,
4.0401833e+09, 4.2523576e+09]], dtype=float32)]
Profiling
# Run the model script_args.repeat times in a profiling session, then stop
# profiling; end_profiling() returns the name of the JSON trace file written
# by onnxruntime.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
prof = sess.end_profiling()

if js_profile_to_dataframe is not None:
    # Convert the JSON trace into a dataframe. With first_it_out=True, the
    # resulting frame carries an 'it==0' column, presumably flagging the first
    # (warm-up) iteration so it can be told apart from the others.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Title the plots with the profiled model's file name.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")

Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
'it==0'],
dtype='object')
Total running time of the script: (0 minutes 1.202 seconds)
Related examples
201: Evaluate different ways to export a torch model to ONNX
201: Evaluate different ways to export a torch model to ONNX
101: Onnx Model Optimization based on Pattern Rewriting
101: Onnx Model Optimization based on Pattern Rewriting
201: Use torch to export a scikit-learn model into ONNX
201: Use torch to export a scikit-learn model into ONNX