Note
Go to the end to download the full example code.
101: Profile an existing model with onnxruntime
Profiles any onnx model on CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
try:
from onnx_extended.tools.js_profile import (
js_profile_to_dataframe,
plot_ort_profile,
)
except ImportError:
js_profile_to_dataframe = None
# Default model: the file stored in the ``data`` folder next to this script.
try:
    _script_dir = os.path.dirname(__file__ or "")
except NameError:
    # ``__file__`` is not defined in this execution context:
    # fall back to a path relative to the current directory.
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(_script_dir, "data", "example_4700-CPUep-opt.onnx")

# Script parameters; ``filename`` is given as a (default value, help text) pair.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the parameters this run is going to use.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Create random inputs matching the shapes and types declared by the model.
def create_random_input(sess):
    """Build random feeds matching the inputs declared by a session.

    :param sess: object exposing ``get_inputs()``; every returned item
        must provide the attributes ``name``, ``shape`` and ``type``
    :return: dictionary ``{input name: numpy array}`` filled with
        random values drawn in ``[0, 1)``
    :raises ValueError: if an input type is not one of the supported
        floating point tensor types
    """
    # Only floating point tensors are supported: values drawn in [0, 1)
    # would all become zero once cast to an integer type.
    dtypes = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for i in sess.get_inputs():
        ot = i.type
        if ot not in dtypes:
            raise ValueError(f"Unsupported onnx type {ot}.")
        # A dimension which is not a concrete integer (presumably None or a
        # symbolic name for a dynamic axis) is replaced by 1 so that an
        # array can still be allocated.
        shape = tuple(d if isinstance(d, (int, np.integer)) else 1 for d in i.shape)
        # np.asarray: ``np.random.rand()`` returns a python float for an
        # empty shape, which has no method ``astype``.
        feeds[i.name] = np.asarray(np.random.rand(*shape)).astype(dtypes[ot])
    return feeds
def create_session(filename, profiling=False):
    """Create an onnxruntime session running on CPU.

    :param filename: model given to ``InferenceSession``
    :param profiling: if True, the session is created with
        ``SessionOptions.enable_profiling`` set to True
    :return: the ``InferenceSession``
    """
    # Imported here so that onnxruntime is only needed when a session is built.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)
# Session created without profiling: it provides the input description
# needed to build the random feeds.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# First run with the random feeds. The call is left as a bare expression;
# NOTE(review): presumably its value is what the rendered page displays
# right after this cell - confirm before assigning it to a variable.
sess.run(None, feeds)
[array([[0.510395 , 0.09870491, 0.75266755, ..., 0.5195775 , 0.57707196,
0.7490044 ],
[0.8236655 , 0.46533 , 0.9431352 , ..., 0.5351355 , 0.01673928,
0.48355934],
[0.21692038, 0.3470374 , 0.4777763 , ..., 0.8625944 , 0.74289674,
0.32841423],
...,
[0.8921656 , 0.13616458, 0.46435517, ..., 0.8052121 , 0.65348035,
0.2504524 ],
[0.01860473, 0.39064184, 0.87167305, ..., 0.43814293, 0.869365 ,
0.73878515],
[0.6230802 , 0.38814384, 0.11948249, ..., 0.4887564 , 0.21102405,
0.73262644]], shape=(128, 1024), dtype=float32), array([[0.540234 , 0.24377689, 0.59050584, ..., 0.7471279 , 0.5253918 ,
0.4479876 ],
[0.53054404, 0.89304227, 0.8826398 , ..., 0.60216653, 0.12245536,
0.8086819 ],
[0.08696877, 0.45958072, 0.94682866, ..., 0.3839614 , 0.31130317,
0.6475578 ],
...,
[0.9040087 , 0.83760476, 0.42550895, ..., 0.98305 , 0.6565117 ,
0.96121 ],
[0.72511977, 0.28390366, 0.6875673 , ..., 0.94175047, 0.03937987,
0.7381576 ],
[0.35966566, 0.93998736, 0.08909597, ..., 0.413577 , 0.36816013,
0.20294131]], shape=(1024, 30752), dtype=float32), array([[0.85992175, 0.40893006, 0.91688865, ..., 0.77127737, 0.61978394,
0.1234277 ],
[0.79496896, 0.02235206, 0.71498513, ..., 0.55956376, 0.99967736,
0.00727563],
[0.9504341 , 0.08722831, 0.1550522 , ..., 0.5807844 , 0.05640255,
0.04382119],
...,
[0.91768384, 0.65331656, 0.8954979 , ..., 0.11000395, 0.6667147 ,
0.5823768 ],
[0.02742839, 0.8250603 , 0.6195953 , ..., 0.63641924, 0.47468057,
0.4685825 ],
[0.2796416 , 0.5099486 , 0.2765593 , ..., 0.00334817, 0.27891853,
0.03767961]], shape=(10, 128), dtype=float32), array([[[[5.806987 , 5.72013 , 6.157274 , ..., 6.0545344, 6.0556784,
4.8173094],
[5.006063 , 5.585447 , 5.7751427, ..., 7.25757 , 5.9223084,
6.1376605],
[5.662475 , 5.541954 , 5.4569254, ..., 7.0972886, 6.890184 ,
5.921923 ],
...,
[7.3290105, 6.1417856, 6.660156 , ..., 6.4197974, 4.817443 ,
6.5680733],
[6.981306 , 5.95116 , 6.1571884, ..., 4.5505958, 4.821188 ,
5.9442015],
[6.3763127, 6.270368 , 5.9747205, ..., 5.8328056, 4.9997 ,
6.6118903]],
[[7.573307 , 6.5594516, 6.727903 , ..., 7.9036674, 6.7875094,
6.2193174],
[6.427677 , 6.47501 , 6.579662 , ..., 8.290004 , 7.887871 ,
6.4572086],
[6.797099 , 6.7822533, 6.449536 , ..., 7.9760294, 7.972854 ,
7.420871 ],
...,
[8.562127 , 8.30687 , 8.352249 , ..., 6.1537843, 6.6855364,
6.839553 ],
[8.273453 , 8.019238 , 7.298452 , ..., 5.799291 , 5.9442134,
6.7026215],
[7.6595435, 6.855773 , 7.2733693, ..., 5.8534894, 5.257635 ,
7.074494 ]],
[[6.09562 , 6.7099795, 6.398712 , ..., 6.0222287, 5.7456374,
5.703014 ],
[5.0917773, 6.155957 , 5.435502 , ..., 7.1121435, 6.876064 ,
6.320275 ],
[6.7134843, 5.6703215, 5.376893 , ..., 7.0713973, 7.6736345,
6.6342244],
...,
[7.2245374, 7.2838416, 7.4497914, ..., 5.919186 , 6.2245483,
6.95924 ],
[7.426585 , 7.0253086, 6.164856 , ..., 5.324277 , 5.532171 ,
6.470504 ],
[6.863257 , 6.337329 , 5.7022367, ..., 6.1614037, 5.120746 ,
6.5681777]],
...,
[[7.062579 , 7.1482577, 6.780678 , ..., 7.158068 , 7.3833985,
6.317856 ],
[6.6447344, 6.860033 , 6.8069925, ..., 8.233599 , 7.6687574,
7.788963 ],
[6.79121 , 6.6062245, 6.7519684, ..., 8.406888 , 8.140712 ,
7.4982758],
...,
[8.328917 , 8.669029 , 7.8935614, ..., 6.456511 , 6.5150266,
7.0281615],
[7.913402 , 7.1996055, 7.08994 , ..., 5.751628 , 5.792744 ,
7.0866942],
[8.275266 , 7.7689857, 6.9132953, ..., 6.9814863, 6.5053353,
8.01444 ]],
[[7.197398 , 6.4248443, 6.1547756, ..., 6.64271 , 6.639652 ,
5.758892 ],
[6.9436264, 5.676034 , 5.6722555, ..., 8.202153 , 7.6702285,
6.667065 ],
[7.3901343, 6.4414344, 6.5813518, ..., 7.9361444, 7.8210273,
7.4510965],
...,
[7.976137 , 7.1362987, 6.7917104, ..., 7.2537456, 6.6555095,
7.178766 ],
[7.793176 , 6.9324927, 7.0725803, ..., 5.5457215, 5.4725423,
6.641788 ],
[8.095506 , 7.33383 , 6.4648485, ..., 6.4031425, 5.2241807,
6.689626 ]],
[[7.427603 , 6.369605 , 7.0653505, ..., 7.7532473, 7.7680655,
5.86767 ],
[6.354887 , 6.89118 , 6.5639534, ..., 8.676196 , 7.833905 ,
7.8430805],
[7.0764723, 7.4269753, 5.7521772, ..., 8.7820015, 8.515766 ,
7.6656866],
...,
[9.665207 , 8.070228 , 7.947341 , ..., 6.6797304, 6.740755 ,
7.251953 ],
[8.025931 , 7.7180023, 7.346319 , ..., 6.4853144, 5.9467134,
7.3486214],
[8.240556 , 7.143462 , 6.993849 , ..., 6.2916117, 6.252305 ,
8.599945 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[ 2, 128, 134, ..., 485, 118, 493],
[ 870, 872, 878, ..., 857, 987, 865],
[ 993, 999, 1124, ..., 1105, 1232, 1360],
...,
[13888, 14019, 14271, ..., 14372, 14006, 14009],
[14632, 14515, 14519, ..., 14868, 14751, 14630],
[14880, 14884, 15261, ..., 15242, 15368, 15375]],
[[ 0, 6, 258, ..., 236, 118, 495],
[ 870, 872, 878, ..., 733, 984, 989],
[ 995, 1120, 1003, ..., 1352, 1480, 1115],
...,
[13891, 14266, 14271, ..., 14374, 14007, 14382],
[14508, 14515, 14394, ..., 14496, 14500, 14631],
[14880, 14884, 15260, ..., 15365, 15244, 15372]],
[[ 248, 5, 134, ..., 237, 118, 494],
[ 869, 875, 876, ..., 733, 987, 617],
[ 1117, 999, 1124, ..., 1105, 1233, 1114],
...,
[13889, 13894, 14269, ..., 14373, 14006, 14008],
[14756, 14515, 14394, ..., 14622, 14503, 14753],
[14880, 15259, 15261, ..., 15364, 15244, 15127]],
...,
[[ 3, 5, 11, ..., 485, 118, 494],
[ 869, 751, 879, ..., 981, 737, 866],
[ 1117, 999, 1248, ..., 1352, 1235, 1113],
...,
[14137, 14017, 14268, ..., 14374, 14005, 14008],
[14384, 14515, 14395, ..., 14622, 14875, 14629],
[15005, 14884, 15261, ..., 15364, 14998, 15375]],
[[ 248, 129, 382, ..., 239, 119, 245],
[ 870, 872, 878, ..., 609, 737, 866],
[ 1118, 996, 1248, ..., 1352, 1233, 1114],
...,
[13888, 13894, 14271, ..., 14374, 14007, 14010],
[14384, 14514, 14516, ..., 14868, 14500, 14631],
[14880, 14884, 15261, ..., 15242, 15244, 15125]],
[[ 0, 5, 258, ..., 112, 118, 494],
[ 869, 872, 878, ..., 857, 985, 990],
[ 995, 1120, 1002, ..., 1352, 1109, 1115],
...,
[14012, 14142, 14270, ..., 14374, 14007, 14008],
[14756, 14638, 14394, ..., 14499, 14501, 14877],
[15004, 14884, 15263, ..., 15366, 15245, 15375]]]],
shape=(1, 32, 31, 31), dtype=int64), array([[6.157274 , 6.732797 , 6.8466873, ..., 9.558445 , 7.943347 ,
8.599945 ]], shape=(1, 30752), dtype=float32), array([[122931.21 , 122812.01 , 122660.01 , ..., 122802.49 , 121542.086,
122556.76 ]], shape=(1, 1024), dtype=float32), array([[6.2685740e+07, 6.3410160e+07, 6.2390724e+07, 6.3186020e+07,
6.4175284e+07, 6.2694756e+07, 6.4098784e+07, 6.2466672e+07,
6.2496920e+07, 6.2552360e+07, 6.3269496e+07, 6.2899440e+07,
6.2978024e+07, 6.4169028e+07, 6.2707548e+07, 6.4332960e+07,
6.3459072e+07, 6.2106776e+07, 6.4202848e+07, 6.3429140e+07,
6.3302292e+07, 6.1958276e+07, 6.1664368e+07, 6.1055936e+07,
6.3270112e+07, 6.1401068e+07, 6.2135536e+07, 6.2427304e+07,
6.0256600e+07, 6.3209992e+07, 6.0593304e+07, 6.0781636e+07,
6.1482384e+07, 6.1717400e+07, 6.2610832e+07, 6.3769260e+07,
6.1473264e+07, 6.1505804e+07, 6.3947816e+07, 6.4301336e+07,
6.4167992e+07, 6.2819408e+07, 6.5220052e+07, 6.1881272e+07,
6.2952696e+07, 6.3190024e+07, 6.2599972e+07, 6.4674996e+07,
6.1654040e+07, 6.2040848e+07, 6.5397256e+07, 6.2409552e+07,
6.0697624e+07, 6.3951028e+07, 6.3041340e+07, 6.4684656e+07,
6.3020680e+07, 6.3066268e+07, 6.2280680e+07, 6.2135076e+07,
6.2253896e+07, 6.1113460e+07, 6.2437828e+07, 6.3035072e+07,
6.4578888e+07, 6.3574296e+07, 6.3139376e+07, 6.5137860e+07,
6.2854876e+07, 6.1501360e+07, 6.7199296e+07, 6.0134952e+07,
6.3340264e+07, 6.3838528e+07, 6.1254408e+07, 6.3398144e+07,
6.2998444e+07, 6.1963864e+07, 6.2621528e+07, 6.0110416e+07,
6.1974800e+07, 6.3285376e+07, 6.1449940e+07, 6.3292648e+07,
6.0577176e+07, 6.3333320e+07, 6.3513912e+07, 6.1636444e+07,
6.2994152e+07, 6.3698160e+07, 6.2220116e+07, 6.4425136e+07,
6.3645176e+07, 6.1214688e+07, 6.4143920e+07, 6.4210424e+07,
6.0982224e+07, 6.2320652e+07, 6.3713752e+07, 6.5227760e+07,
6.3964752e+07, 6.0706800e+07, 6.4768920e+07, 6.2442664e+07,
6.2656032e+07, 6.4456688e+07, 6.0053592e+07, 6.3901548e+07,
6.1754392e+07, 6.2550728e+07, 6.5093520e+07, 6.3075572e+07,
6.4294520e+07, 6.5430116e+07, 6.3140252e+07, 6.2226720e+07,
6.3469724e+07, 6.3753616e+07, 6.1655140e+07, 6.1167440e+07,
6.2269368e+07, 6.3205008e+07, 6.3155536e+07, 6.2887120e+07,
6.2112636e+07, 6.2542400e+07, 6.3138648e+07, 6.1983700e+07]],
dtype=float32), array([[4.0259461e+09, 4.0417464e+09, 3.9178358e+09, 3.7554437e+09,
4.1699364e+09, 4.0457357e+09, 3.7798787e+09, 3.9954289e+09,
4.5021348e+09, 3.9691351e+09]], dtype=float32)]
Profiling
# New session with profiling enabled, run ``repeat`` times on the same feeds.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
# NOTE(review): ``prof`` is presumably the name of the profile file written
# by onnxruntime - confirm against the onnxruntime documentation.
prof = sess.end_profiling()

if js_profile_to_dataframe is None:
    # The optional import of onnx_extended failed: nothing to display.
    print("Install onnx-extended first.")
else:
    # Profile converted into a dataframe, then saved as CSV and Excel.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two graphs side by side, filled by ``plot_ort_profile``.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")

Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
'it==0'],
dtype='str')
Total running time of the script: (0 minutes 1.720 seconds)
Related examples
201: Evaluate different ways to export a torch model to ONNX
201: Evaluate different ways to export a torch model to ONNX
101: Onnx Model Optimization based on Pattern Rewriting
101: Onnx Model Optimization based on Pattern Rewriting