Note
Go to the end to download the full example code.
101: Profile an existing model with onnxruntime
Profiles any ONNX model with onnxruntime on CPU.
Preparation
import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args
# onnx-extended is optional: it only provides the helpers used at the end
# of the script to turn the profiling trace into a dataframe and a plot.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# Default model: data/example_4700-CPUep-opt.onnx relative to this script.
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""),
        "data",
        "example_4700-CPUep-opt.onnx",
    )
except NameError:
    # __file__ is undefined when the code is executed interactively.
    filename = "data/example_4700-CPUep-opt.onnx"

# Script arguments with their defaults; presumably each one can be
# overridden from the command line by get_parsed_args — confirm there.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective values of the arguments used below.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10
Create random inputs matching the model inputs.
def create_random_input(sess, default_dim=1):
    """Build random feeds matching every input of an inference session.

    :param sess: an ``onnxruntime.InferenceSession``; any object whose
        ``get_inputs()`` returns items with ``name``, ``shape`` and
        ``type`` attributes works as well
    :param default_dim: size used for every dimension that is not a fixed
        integer (dynamic dimensions)
    :return: dictionary ``{input name: numpy array}``
    :raises ValueError: if an input has an unsupported element type
    """
    float_types = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    int_types = {
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
    }
    feeds = {}
    for inp in sess.get_inputs():
        # Dynamic dimensions are not integers (onnxruntime reports a
        # symbolic name or None) and would make np.random.rand fail.
        dims = tuple(d if isinstance(d, int) else default_dim for d in inp.shape)
        ot = inp.type
        if ot in float_types:
            # Uniform values in [0, 1), as before for float inputs.
            value = np.random.rand(*dims)
            dtype = float_types[ot]
        elif ot in int_types:
            # Small non-negative integers.
            value = np.random.randint(0, 10, size=dims)
            dtype = int_types[ot]
        else:
            raise ValueError(f"Unsupported onnx type {ot}.")
        # np.asarray also covers scalar inputs (dims == ()), for which
        # np.random.rand returns a Python float that has no astype.
        feeds[inp.name] = np.asarray(value, dtype=dtype)
    return feeds
def create_session(filename, profiling=False):
    """Open ``filename`` in an onnxruntime session on the CPU provider.

    :param filename: path of the ONNX model
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` switched on
    :return: an ``onnxruntime.InferenceSession``
    """
    # Imported lazily so the module can be read without onnxruntime.
    from onnxruntime import InferenceSession, SessionOptions

    providers = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=providers)
    return InferenceSession(filename, providers=providers)
# Plain (non-profiling) CPU session on the selected model, with random feeds.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# Run once without profiling; the outputs are displayed below.
sess.run(None, feeds)
[array([[0.49811938, 0.82152057, 0.3373039 , ..., 0.9690496 , 0.7567 ,
0.31627807],
[0.05741106, 0.61383396, 0.69814515, ..., 0.00981634, 0.03271949,
0.76451415],
[0.37014586, 0.5706923 , 0.94200885, ..., 0.6017973 , 0.10141023,
0.8193094 ],
...,
[0.21534732, 0.9362901 , 0.31112844, ..., 0.9897062 , 0.6552406 ,
0.94650924],
[0.8500216 , 0.2480933 , 0.42366222, ..., 0.93902403, 0.12038182,
0.47769457],
[0.18478155, 0.6985593 , 0.8998448 , ..., 0.03175963, 0.89774 ,
0.8473401 ]], dtype=float32), array([[0.23654357, 0.8846211 , 0.7895143 , ..., 0.9358064 , 0.5486316 ,
0.8861085 ],
[0.4148546 , 0.89466226, 0.4516352 , ..., 0.6255415 , 0.6314595 ,
0.28775826],
[0.5825523 , 0.14544408, 0.6117085 , ..., 0.05096382, 0.13781106,
0.2697828 ],
...,
[0.19106524, 0.04777327, 0.1348371 , ..., 0.00879816, 0.8628678 ,
0.28616464],
[0.7186147 , 0.51735187, 0.5539227 , ..., 0.5829653 , 0.43757367,
0.10427039],
[0.3955041 , 0.66906875, 0.3431797 , ..., 0.03259649, 0.9346118 ,
0.29555398]], dtype=float32), array([[0.8238373 , 0.22705522, 0.42292988, ..., 0.22040927, 0.9317856 ,
0.37155324],
[0.06378181, 0.39271095, 0.10566166, ..., 0.22306833, 0.09067719,
0.05038857],
[0.5688872 , 0.04833359, 0.69702303, ..., 0.3372865 , 0.80998343,
0.55208874],
...,
[0.24066575, 0.81232107, 0.24020675, ..., 0.41089785, 0.76754063,
0.09112121],
[0.73953503, 0.14319196, 0.89437425, ..., 0.05385933, 0.2680572 ,
0.23563877],
[0.97741354, 0.87072533, 0.87953866, ..., 0.85288197, 0.13573973,
0.01255301]], dtype=float32), array([[[[5.565089 , 5.103101 , 4.5174675, ..., 6.0058846, 5.958382 ,
5.4275723],
[5.745741 , 5.5332584, 4.767103 , ..., 7.137676 , 6.752444 ,
5.708514 ],
[5.605507 , 6.2635736, 5.501947 , ..., 6.328216 , 6.764166 ,
6.2347217],
...,
[6.753322 , 5.8139067, 6.4643807, ..., 6.5022063, 5.671295 ,
6.1274652],
[6.793723 , 6.1422615, 6.3913918, ..., 5.9583516, 4.947596 ,
5.0844755],
[5.7757854, 5.1850967, 6.2562585, ..., 5.155083 , 4.8032727,
5.463747 ]],
[[7.232889 , 5.9127665, 5.027247 , ..., 7.32901 , 6.9056263,
6.888383 ],
[6.209973 , 6.109516 , 5.9258747, ..., 7.99459 , 6.8433356,
6.627201 ],
[6.6054955, 6.3402686, 6.400939 , ..., 7.273025 , 7.7028365,
7.488401 ],
...,
[8.074051 , 7.767504 , 7.467769 , ..., 8.136668 , 7.035548 ,
8.015495 ],
[7.398261 , 6.792018 , 7.899523 , ..., 6.983091 , 6.2857585,
6.1718197],
[7.281235 , 7.182981 , 7.4691615, ..., 6.1005874, 5.021295 ,
6.4314747]],
[[5.7637143, 4.8422503, 4.1405168, ..., 5.827357 , 5.811146 ,
5.7332745],
[5.7922034, 3.9866033, 4.2269087, ..., 5.785495 , 6.3121614,
5.7971487],
[4.8831177, 4.9326544, 4.3113036, ..., 5.959783 , 5.5856047,
5.4234543],
...,
[5.5585494, 5.7830715, 5.203156 , ..., 5.5613346, 5.3291354,
6.0616064],
[5.38927 , 5.802525 , 6.0341926, ..., 5.338025 , 4.907255 ,
6.0103025],
[6.3623343, 5.3191495, 6.3359327, ..., 5.0720263, 4.6790676,
4.8685727]],
...,
[[5.3656626, 5.2313848, 5.1438403, ..., 6.5407805, 5.8744655,
5.0493045],
[5.342044 , 4.833243 , 4.9250784, ..., 6.265502 , 6.7738924,
6.4219494],
[5.494974 , 4.834587 , 5.6865044, ..., 5.897386 , 6.338793 ,
5.8890886],
...,
[6.899637 , 6.706461 , 6.3771486, ..., 6.59177 , 5.4907427,
6.230597 ],
[6.20377 , 5.6799264, 5.947048 , ..., 5.151845 , 5.4751916,
5.16739 ],
[5.8000546, 6.888856 , 5.3176284, ..., 4.7611704, 4.56488 ,
4.786338 ]],
[[6.156503 , 5.811774 , 5.3238378, ..., 7.531144 , 6.7949357,
6.5749903],
[5.286597 , 5.5069985, 6.0806537, ..., 7.161507 , 7.6698594,
5.885137 ],
[5.483968 , 5.7780085, 5.9242854, ..., 6.459801 , 6.5940285,
6.609454 ],
...,
[6.9081225, 6.7646356, 7.3245196, ..., 6.762631 , 6.119577 ,
6.1035166],
[7.1474175, 6.0217876, 6.609377 , ..., 6.1596785, 5.8335376,
5.6525984],
[6.874523 , 6.061458 , 7.470292 , ..., 6.034182 , 4.957194 ,
4.944079 ]],
[[6.6298065, 5.4778523, 4.0081897, ..., 6.080302 , 5.915414 ,
5.071162 ],
[5.2896533, 4.846699 , 5.0159574, ..., 6.968491 , 5.915704 ,
5.1313715],
[4.542079 , 5.0459566, 5.3878794, ..., 5.9239216, 5.8458543,
5.7964206],
...,
[5.7237134, 6.85091 , 5.3416386, ..., 6.1415 , 5.8549585,
6.7072096],
[6.004997 , 5.7064233, 6.2460246, ..., 6.116028 , 4.9547224,
4.8880754],
[7.134811 , 5.745888 , 5.4453382, ..., 4.930555 , 4.0033393,
4.497093 ]]]], dtype=float32), array([[[[ 249, 379, 133, ..., 363, 364, 245],
[ 745, 624, 629, ..., 857, 615, 741],
[ 1243, 1244, 1127, ..., 1476, 1356, 1114],
...,
[14137, 14017, 14021, ..., 14375, 14377, 14135],
[14509, 14515, 14764, ..., 14498, 14500, 14753],
[15255, 14887, 14888, ..., 15116, 15247, 15125]],
[[ 374, 377, 134, ..., 363, 365, 492],
[ 622, 872, 877, ..., 859, 738, 619],
[ 1367, 1368, 1125, ..., 1230, 1357, 1238],
...,
[14015, 14019, 14021, ..., 14375, 14377, 14380],
[14385, 14514, 14640, ..., 14746, 14875, 14630],
[14883, 14885, 14888, ..., 15242, 15370, 15125]],
[[ 124, 376, 133, ..., 363, 366, 244],
[ 871, 626, 630, ..., 735, 737, 742],
[ 1367, 1121, 1375, ..., 1229, 1482, 1362],
...,
[14136, 13895, 14268, ..., 14372, 14376, 14259],
[14385, 14762, 14764, ..., 14622, 14500, 14877],
[15007, 14887, 15013, ..., 15116, 15123, 15001]],
...,
[[ 375, 377, 134, ..., 114, 366, 368],
[ 745, 749, 628, ..., 735, 739, 990],
[ 995, 1368, 1375, ..., 1353, 1358, 1238],
...,
[13891, 14019, 13897, ..., 14375, 14376, 14380],
[14385, 14639, 14516, ..., 14620, 14500, 14877],
[15004, 14887, 14889, ..., 15242, 15368, 15125]],
[[ 375, 376, 135, ..., 363, 366, 246],
[ 498, 751, 629, ..., 983, 614, 616],
[ 1366, 1121, 1126, ..., 1229, 1233, 1238],
...,
[14015, 13895, 14020, ..., 14372, 14376, 14259],
[14758, 14515, 14640, ..., 14744, 14749, 14631],
[15254, 14886, 15013, ..., 15116, 15368, 15125]],
[[ 0, 379, 134, ..., 363, 366, 245],
[ 747, 749, 628, ..., 859, 738, 990],
[ 1367, 1245, 1375, ..., 1478, 1358, 1362],
...,
[14015, 14019, 14021, ..., 14375, 14377, 14383],
[14632, 14515, 14640, ..., 14747, 14500, 14631],
[15252, 14886, 14889, ..., 14994, 14999, 15127]]]], dtype=int64), array([[6.2635736, 5.6826043, 6.172062 , ..., 6.1040564, 5.822977 ,
6.7072096]], dtype=float32), array([[117578.305, 116757.78 , 118194.61 , ..., 117556.78 , 117410.74 ,
117894.62 ]], dtype=float32), array([[59846544., 58236288., 58976596., 60590736., 62058104., 60432408.,
59663544., 59718448., 61311092., 59508752., 59787840., 60822952.,
59920728., 60014712., 60831268., 61057852., 61220080., 61549000.,
58610184., 60068620., 61335560., 61194504., 60502104., 60003704.,
59070296., 58443152., 61277268., 59282320., 60398944., 59343580.,
59360288., 60844792., 59364908., 60653352., 60576632., 62950032.,
60844288., 62809440., 60193900., 59834456., 59278804., 61281816.,
59741584., 60839232., 59899280., 60719816., 59188440., 61837368.,
61231392., 60013048., 60707972., 60957232., 60685696., 59340524.,
62570220., 59074100., 62156332., 59758624., 61961888., 59178672.,
59295944., 60203348., 58871752., 59733404., 59596988., 60141568.,
60038548., 61583332., 61143492., 59138024., 60528984., 60534524.,
60534380., 58101636., 58841352., 58926016., 59822960., 60417440.,
59730784., 61367840., 62503976., 58468848., 58949380., 60468720.,
58595856., 60187576., 60268520., 63818352., 58762480., 59508080.,
62196608., 61047408., 60215336., 60450892., 59283104., 57954864.,
61185920., 61690728., 60737716., 60368768., 58991252., 60454824.,
60972984., 60767536., 59426240., 60702664., 60382692., 60766444.,
59530424., 59392020., 59829224., 60664052., 61300568., 60710384.,
60141088., 59636528., 58949032., 60760988., 59714968., 62630448.,
62124576., 60557640., 59309288., 59972832., 61738392., 60953704.,
60185584., 60170140.]], dtype=float32), array([[3.8632909e+09, 3.9355768e+09, 3.8385060e+09, 3.7788078e+09,
3.9902116e+09, 3.7555151e+09, 3.8698854e+09, 4.2214922e+09,
3.6336955e+09, 3.9290688e+09]], dtype=float32)]
Profiling
# Run the model script_args.repeat times in a session with profiling enabled.
sess = create_session(script_args.filename, profiling=True)
for _ in range(script_args.repeat):
    sess.run(None, feeds)
# Stops profiling; per the onnxruntime docs this returns the path of the
# JSON trace it wrote.
prof = sess.end_profiling()

if js_profile_to_dataframe is not None:
    # first_it_out=True presumably flags the first iteration separately
    # (see the 'it==0' column printed below) — confirm in onnx-extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Title the plot with the profiled model: the former hard-coded
    # "dort" label did not describe this example, which profiles an
    # existing ONNX model directly with onnxruntime.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
'args_thread_scheduling_stats', 'args_output_size',
'args_parameter_size', 'args_activation_size', 'args_node_index',
'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
'it==0'],
dtype='object')
Total running time of the script: (0 minutes 1.166 seconds)