Note
Go to the end to download the full example code.
101: Profile an existing model with onnxruntime
Profiles any ONNX model on CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
try:
from onnx_extended.tools.js_profile import (
js_profile_to_dataframe,
plot_ort_profile,
)
except ImportError:
js_profile_to_dataframe = None
# Resolve the default model path relative to this script. When ``__file__``
# is undefined (NameError), fall back to a path relative to the current
# working directory.
try:
    _script_dir = os.path.dirname(__file__ or "")
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(_script_dir, "data", "example_4700-CPUep-opt.onnx")

# Script arguments: the model file to profile and the number of profiled runs.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective values of the arguments used below.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Random inputs matching the shapes and types declared by the model inputs.
def create_random_input(sess):
    """Build a dictionary of random feeds for every input of ``sess``.

    :param sess: object exposing ``get_inputs()``; each returned item must
        provide ``name``, ``shape`` and ``type`` attributes
        (e.g. an onnxruntime ``InferenceSession``)
    :return: dictionary ``{input name: numpy array}`` filled with values
        drawn uniformly from ``[0, 1)`` and cast to the input's dtype
    :raises ValueError: if an input type is not a supported float tensor

    Dimensions that are not integers (symbolic names or ``None``, which is
    how dynamic axes are reported) are replaced by 1 so that a concrete
    array can still be generated.
    """
    # Map from the type string reported by the input to the numpy dtype.
    supported = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
    }
    feeds = {}
    for i in sess.get_inputs():
        ot = i.type
        if ot not in supported:
            raise ValueError(f"Unsupported onnx type {ot}.")
        # np.random.rand only accepts integer dimensions.
        dims = [d if isinstance(d, (int, np.integer)) else 1 for d in i.shape]
        # np.asarray covers the scalar case: rand() without arguments
        # returns a Python float, which has no ``astype``.
        feeds[i.name] = np.asarray(np.random.rand(*dims)).astype(supported[ot])
    return feeds
def create_session(filename, profiling=False):
    """Create an onnxruntime ``InferenceSession`` running on CPU.

    :param filename: model passed as first argument to ``InferenceSession``
    :param profiling: if True, the session is created with
        ``SessionOptions.enable_profiling`` set to True
    :return: the ``InferenceSession``
    """
    # Imported here so that onnxruntime is only required when a session
    # is actually created.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)
# Create a session without profiling, generate random feeds for its inputs
# and run the model once with them (``None`` requests every output).
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.38556445, 0.55397576, 0.8805988 , ..., 0.80855435, 0.9698065 ,
0.1651492 ],
[0.60911006, 0.98146534, 0.48810425, ..., 0.37745422, 0.0707617 ,
0.21648818],
[0.14771894, 0.10950058, 0.7909586 , ..., 0.05426677, 0.37327752,
0.26818016],
...,
[0.3218134 , 0.29144892, 0.0526025 , ..., 0.02053035, 0.82496136,
0.9514841 ],
[0.635222 , 0.64783275, 0.4547749 , ..., 0.76839703, 0.34234193,
0.40970457],
[0.3180374 , 0.5930903 , 0.9649414 , ..., 0.33332193, 0.47703722,
0.55091107]], shape=(128, 1024), dtype=float32), array([[0.95870274, 0.636761 , 0.44476703, ..., 0.19968031, 0.4780304 ,
0.26984775],
[0.49667543, 0.8193919 , 0.7951488 , ..., 0.1742966 , 0.426188 ,
0.35868287],
[0.13918132, 0.42133337, 0.48337352, ..., 0.80362266, 0.58344287,
0.6889565 ],
...,
[0.6976248 , 0.8877859 , 0.1375428 , ..., 0.6435721 , 0.95254046,
0.40901002],
[0.8864794 , 0.25991818, 0.49433857, ..., 0.7553312 , 0.5885717 ,
0.6678185 ],
[0.13902307, 0.8430048 , 0.49835652, ..., 0.65506876, 0.72827005,
0.16622455]], shape=(1024, 30752), dtype=float32), array([[0.9065616 , 0.9900758 , 0.51137394, ..., 0.48558617, 0.63175607,
0.62513673],
[0.2600353 , 0.67811406, 0.19555514, ..., 0.10207321, 0.66219664,
0.88368404],
[0.4334533 , 0.8996908 , 0.46073616, ..., 0.67588544, 0.0141007 ,
0.89307386],
...,
[0.51598513, 0.14081694, 0.06316388, ..., 0.98565435, 0.5844102 ,
0.8166332 ],
[0.46951342, 0.7719806 , 0.8183161 , ..., 0.46453735, 0.8912257 ,
0.81608504],
[0.24607968, 0.4938955 , 0.7717674 , ..., 0.5502433 , 0.18733628,
0.7372484 ]], shape=(10, 128), dtype=float32), array([[[[7.2124248, 6.949165 , 5.3856463, ..., 4.8469324, 5.320979 ,
6.0521 ],
[7.0614085, 6.2695427, 5.4931307, ..., 6.3549614, 6.6163974,
6.017617 ],
[7.381121 , 7.3006744, 5.7618985, ..., 7.3029084, 7.318741 ,
5.310724 ],
...,
[6.0006537, 6.3537717, 6.3752923, ..., 5.9468803, 5.795259 ,
5.810827 ],
[5.379063 , 5.876003 , 6.685616 , ..., 6.7443995, 6.578243 ,
6.7477827],
[5.1108017, 5.9684243, 5.794866 , ..., 6.321777 , 7.0116224,
6.7306232]],
[[8.06898 , 7.4926353, 6.1334796, ..., 5.615914 , 7.085943 ,
6.8165927],
[7.9092493, 6.7172837, 5.5994554, ..., 7.534926 , 7.068978 ,
6.920303 ],
[8.383128 , 7.9394937, 5.7300982, ..., 7.8446493, 7.263592 ,
5.735348 ],
...,
[6.9940324, 7.7478275, 7.1182747, ..., 6.857957 , 6.194977 ,
6.4526353],
[6.337092 , 7.0728784, 7.845219 , ..., 8.164216 , 7.142987 ,
8.17029 ],
[6.4320693, 6.7336698, 7.240597 , ..., 7.342244 , 7.521635 ,
8.012213 ]],
[[5.124064 , 4.732128 , 3.5098467, ..., 4.012354 , 3.9715738,
3.7085536],
[4.8587728, 4.8069825, 2.9407508, ..., 4.826569 , 5.049167 ,
4.556578 ],
[5.6522174, 4.7376804, 3.9291642, ..., 4.69229 , 4.4731402,
4.2415004],
...,
[4.896876 , 4.9004974, 5.440536 , ..., 4.2603846, 4.126876 ,
4.629271 ],
[3.7297525, 4.5046086, 5.397951 , ..., 4.871149 , 5.1618586,
6.174954 ],
[4.2740426, 4.448325 , 5.3924384, ..., 5.0491276, 5.3751755,
4.833969 ]],
...,
[[7.6567993, 6.9736595, 5.819581 , ..., 6.2047315, 6.4378743,
5.8060503],
[7.087953 , 6.772794 , 5.4935155, ..., 7.0531735, 6.8689923,
7.3181067],
[7.8979726, 7.473395 , 6.11533 , ..., 6.4378467, 5.8727493,
6.642461 ],
...,
[6.6677213, 7.990225 , 6.453491 , ..., 5.8332553, 5.967316 ,
6.534334 ],
[5.6889033, 7.0696597, 6.1350913, ..., 6.7491493, 7.232734 ,
7.3192263],
[5.3458953, 6.3537183, 6.6923842, ..., 6.86066 , 6.358967 ,
7.2293086]],
[[6.190681 , 4.934544 , 4.2880063, ..., 4.067191 , 4.7604604,
4.095345 ],
[5.4029784, 5.408377 , 3.620253 , ..., 5.2773747, 5.409656 ,
5.743461 ],
[6.1972303, 5.4862523, 4.1798673, ..., 5.4588537, 4.4254813,
4.4029136],
...,
[5.1725945, 4.882221 , 5.595881 , ..., 4.0541406, 3.5928552,
4.3137593],
[4.4695835, 4.2526155, 5.578326 , ..., 5.469363 , 5.4409103,
5.0957747],
[4.410438 , 4.3117046, 5.8523507, ..., 5.515564 , 4.527907 ,
4.909782 ]],
[[7.51237 , 6.875618 , 5.907875 , ..., 5.4989085, 6.8030105,
5.2218556],
[7.7072873, 6.429263 , 5.3156986, ..., 7.142754 , 6.151346 ,
6.922994 ],
[7.2347255, 6.926053 , 5.7444644, ..., 7.704699 , 6.722652 ,
6.2124476],
...,
[6.441152 , 7.2276235, 6.91788 , ..., 6.3062444, 5.544852 ,
5.9097867],
[6.163632 , 5.882812 , 6.2885513, ..., 6.486774 , 6.972126 ,
6.839702 ],
[5.8554616, 5.5420575, 6.187286 , ..., 6.4345636, 7.115575 ,
6.2969093]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[ 248, 252, 380, ..., 361, 491, 244],
[ 744, 503, 753, ..., 611, 738, 617],
[ 992, 997, 1000, ..., 1476, 1111, 1239],
...,
[14136, 14141, 13899, ..., 14372, 14377, 14011],
[14757, 14762, 14394, ..., 14496, 14624, 14507],
[15131, 15010, 15137, ..., 14994, 15121, 15374]],
[[ 248, 131, 256, ..., 361, 491, 493],
[ 744, 874, 752, ..., 859, 615, 617],
[ 994, 998, 1000, ..., 1353, 1111, 1238],
...,
[14136, 13892, 14146, ..., 14248, 14377, 14135],
[14757, 14513, 14641, ..., 14496, 14872, 14507],
[15130, 15258, 15260, ..., 14993, 15246, 15251]],
[[ 248, 131, 381, ..., 361, 119, 493],
[ 620, 874, 754, ..., 610, 737, 741],
[ 994, 997, 1000, ..., 1477, 1235, 1236],
...,
[13889, 14143, 13898, ..., 14372, 14005, 14011],
[14758, 14638, 14642, ..., 14869, 14624, 14506],
[15006, 15258, 15261, ..., 14993, 14996, 15251]],
...,
[[ 248, 131, 381, ..., 360, 491, 492],
[ 744, 751, 752, ..., 858, 739, 741],
[ 994, 998, 1002, ..., 1477, 1111, 1114],
...,
[14012, 14140, 14145, ..., 14126, 14130, 14135],
[14508, 14638, 14517, ..., 14869, 14748, 14507],
[15005, 15258, 15263, ..., 15119, 15245, 15251]],
[[ 248, 255, 380, ..., 484, 489, 493],
[ 497, 750, 752, ..., 610, 862, 617],
[ 1118, 1247, 1002, ..., 1352, 1235, 1487],
...,
[13890, 13895, 14144, ..., 14372, 14005, 14008],
[14758, 14761, 14640, ..., 14496, 14872, 14755],
[15254, 15009, 15263, ..., 14993, 14997, 15002]],
[[ 124, 131, 258, ..., 361, 365, 493],
[ 744, 874, 628, ..., 859, 614, 617],
[ 1118, 997, 1375, ..., 1476, 1111, 1114],
...,
[13889, 14141, 13899, ..., 14372, 14005, 14011],
[14756, 14761, 14640, ..., 14496, 14872, 14507],
[15255, 15134, 15260, ..., 14995, 15121, 15374]]]],
shape=(1, 32, 31, 31), dtype=int64), array([[7.381121 , 6.113424 , 6.964289 , ..., 7.1669254, 7.2326717,
7.115575 ]], shape=(1, 30752), dtype=float32), array([[119141.46 , 119613.23 , 119708.78 , ..., 119339.086, 118628.85 ,
118833.1 ]], shape=(1, 1024), dtype=float32), array([[6.1533384e+07, 5.9396112e+07, 6.0553608e+07, 6.1904116e+07,
6.0651256e+07, 6.0473864e+07, 6.0628596e+07, 6.1006848e+07,
6.1611816e+07, 6.1494788e+07, 5.9946136e+07, 6.0892072e+07,
6.0989100e+07, 6.0879668e+07, 5.9669440e+07, 6.1841072e+07,
6.0864296e+07, 5.8424244e+07, 6.0586568e+07, 6.0708544e+07,
6.1408496e+07, 6.3036920e+07, 6.1415088e+07, 6.1482232e+07,
6.0187580e+07, 6.2735524e+07, 6.1316716e+07, 5.9624520e+07,
6.0540560e+07, 6.0819616e+07, 6.3122228e+07, 6.0894616e+07,
6.0451604e+07, 6.1070192e+07, 6.0979504e+07, 6.1360276e+07,
5.9039568e+07, 6.0540144e+07, 6.1145744e+07, 6.0742980e+07,
6.1351368e+07, 6.0078768e+07, 6.2439900e+07, 6.2247788e+07,
5.9820968e+07, 5.8503176e+07, 6.0834992e+07, 5.8239728e+07,
6.1263068e+07, 5.9641888e+07, 6.2209932e+07, 6.1295636e+07,
6.0269064e+07, 6.0916872e+07, 6.2716144e+07, 5.9071256e+07,
6.1943084e+07, 6.0251592e+07, 5.9849740e+07, 6.0636216e+07,
6.1342548e+07, 5.9601248e+07, 6.1807060e+07, 6.0105000e+07,
6.1794720e+07, 5.9975672e+07, 6.0113808e+07, 5.9534040e+07,
5.9255896e+07, 6.0921696e+07, 6.1040104e+07, 5.9613088e+07,
6.0227500e+07, 6.0268120e+07, 5.9954272e+07, 6.1388912e+07,
6.3367608e+07, 6.1473328e+07, 6.2184904e+07, 6.2094508e+07,
6.1093552e+07, 6.0453852e+07, 6.1754576e+07, 6.1102236e+07,
6.0100800e+07, 6.0864548e+07, 6.1891544e+07, 6.0204456e+07,
6.1565576e+07, 6.0354372e+07, 6.0864624e+07, 5.9297096e+07,
6.0754444e+07, 6.1920328e+07, 6.0301968e+07, 6.1914372e+07,
6.1900076e+07, 6.1327152e+07, 6.2442824e+07, 6.1349680e+07,
6.0438016e+07, 6.2174912e+07, 6.0849416e+07, 5.8883824e+07,
6.1998108e+07, 5.9313568e+07, 6.2345872e+07, 6.0202680e+07,
6.1111028e+07, 6.2123064e+07, 6.0621932e+07, 6.0387752e+07,
5.9783000e+07, 6.1904984e+07, 5.8992736e+07, 6.0704196e+07,
6.1475780e+07, 6.1883072e+07, 6.1640540e+07, 6.1268072e+07,
6.1788800e+07, 6.0115072e+07, 5.9435696e+07, 6.0516424e+07,
6.0247128e+07, 6.0656156e+07, 6.0185568e+07, 6.3598848e+07]],
dtype=float32), array([[3.8433152e+09, 3.5135181e+09, 4.3521690e+09, 4.0556605e+09,
3.5809229e+09, 4.0594386e+09, 4.1181942e+09, 4.0656640e+09,
3.9362488e+09, 3.9547407e+09]], dtype=float32)]
Profiling
# Create a new session with profiling enabled and run the model
# ``script_args.repeat`` times with the feeds generated earlier.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
# The value returned by end_profiling() is what js_profile_to_dataframe reads.
prof = sess.end_profiling()

if js_profile_to_dataframe is None:
    # The optional import from onnx_extended failed: nothing to analyse.
    print("Install onnx-extended first.")
else:
    # Convert the profile into a dataframe and show the available columns.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)

    # Save the raw profiling data.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Draw the profile on two side-by-side axes and save the figure.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")

Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
'it==0'],
dtype='object')
Total running time of the script: (0 minutes 1.626 seconds)
Related examples
201: Evaluate different ways to export a torch model to ONNX
201: Evaluate different ways to export a torch model to ONNX
101: Onnx Model Optimization based on Pattern Rewriting
101: Onnx Model Optimization based on Pattern Rewriting
201: Use torch to export a scikit-learn model into ONNX
201: Use torch to export a scikit-learn model into ONNX