101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx-extended is an optional dependency: it supplies the two helpers used
# at the end of the script to post-process and plot the profile.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Bind both names so that neither is left undefined when the package is
    # missing; the script checks ``js_profile_to_dataframe is None`` before
    # using either of them.
    js_profile_to_dataframe = None
    plot_ort_profile = None

# Default model: data/example_4700-CPUep-opt.onnx, looked up next to this
# script when ``__file__`` exists, otherwise relative to the current
# working directory.
try:
    _script_dir = os.path.dirname(__file__ or "")
except NameError:
    # ``__file__`` is not defined in this execution context.
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(_script_dir, "data", "example_4700-CPUep-opt.onnx")

# Script options and their default values: the model to profile, the number
# of profiled runs, and ``expose`` (passed through to ``get_parsed_args``).
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    **{
        "filename": (filename, "input file"),
        "repeat": 10,
        "expose": "",
    },
)


# Echo the options this run is going to use.
for att in ["filename", "repeat"]:
    value = getattr(script_args, att)
    print(f"{att}={value}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs.

def create_random_input(sess):
    """Build one random feed for every input of an inference session.

    Values are drawn uniformly from ``[0, 1)`` with ``np.random.rand``.
    Dimensions which are not plain integers (``None`` or a symbolic name,
    i.e. dynamic dimensions) are replaced by 1 so that a concrete array can
    be created. Scalar (rank 0) inputs are supported as well.

    :param sess: object exposing ``get_inputs()``, each returned item having
        the attributes ``name``, ``shape`` and ``type``
    :return: dictionary ``{input name: numpy array}``
    :raises ValueError: if an input is not a floating point tensor
    """
    # Restricted to floating point tensors: samples in [0, 1) would all be
    # truncated to zero once cast to an integer type.
    float_types = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
    feeds = {}
    for inp in sess.get_inputs():
        try:
            dtype = float_types[inp.type]
        except KeyError:
            raise ValueError(f"Unsupported onnx type {inp.type}.") from None
        shape = tuple(d if isinstance(d, int) else 1 for d in inp.shape)
        # np.random.rand() returns a Python float when called without any
        # dimension; np.asarray turns every case into an ndarray.
        feeds[inp.name] = np.asarray(np.random.rand(*shape), dtype=dtype)
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime inference session restricted to the CPU provider.

    :param filename: model to load, passed as is to ``InferenceSession``
    :param profiling: if True, the session is created with session options
        whose ``enable_profiling`` flag is set
    :return: the ``InferenceSession``
    """
    # Imported here so that onnxruntime is only needed when a session is built.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    # No explicit options: the session keeps onnxruntime's defaults.
    return InferenceSession(filename, providers=cpu_only)


# First session, created without profiling (default ``profiling=False``).
sess = create_session(script_args.filename)
# One random array per model input, reused later for the profiled runs.
feeds = create_random_input(sess)
# Single run to check the model executes with these feeds; ``None`` as
# output names presumably requests every output (see onnxruntime API).
# Kept as a bare expression: its value is what the gallery displays below.
sess.run(None, feeds)
[array([[0.7622397 , 0.3948375 , 0.00454115, ..., 0.3603591 , 0.8313883 ,
        0.9305894 ],
       [0.867129  , 0.81864125, 0.9579705 , ..., 0.01149698, 0.22738685,
        0.8824927 ],
       [0.06152409, 0.7195818 , 0.6953397 , ..., 0.12988514, 0.45296088,
        0.20958205],
       ...,
       [0.14514922, 0.4850222 , 0.29493746, ..., 0.19283625, 0.28072295,
        0.5928812 ],
       [0.06665322, 0.5587319 , 0.13076316, ..., 0.37115642, 0.47280303,
        0.4993226 ],
       [0.45474702, 0.7475003 , 0.47418386, ..., 0.2874437 , 0.16821039,
        0.9798428 ]], shape=(128, 1024), dtype=float32), array([[0.23882882, 0.8788065 , 0.73025626, ..., 0.27776092, 0.32640588,
        0.66311795],
       [0.7866556 , 0.5745767 , 0.90008336, ..., 0.22392821, 0.6119184 ,
        0.04824727],
       [0.3842913 , 0.26284382, 0.09479997, ..., 0.6735889 , 0.1882229 ,
        0.9013381 ],
       ...,
       [0.7093272 , 0.07671154, 0.5008208 , ..., 0.45241088, 0.7314623 ,
        0.40232268],
       [0.1170959 , 0.76129943, 0.3027121 , ..., 0.85832405, 0.7117575 ,
        0.8150239 ],
       [0.7202666 , 0.24943186, 0.04251959, ..., 0.42018646, 0.72920686,
        0.2768651 ]], shape=(1024, 30752), dtype=float32), array([[0.08462225, 0.10620564, 0.88498867, ..., 0.06644436, 0.7688141 ,
        0.48664108],
       [0.5396219 , 0.5220608 , 0.2595914 , ..., 0.43502304, 0.07777964,
        0.13756065],
       [0.69854313, 0.8913882 , 0.07625081, ..., 0.211255  , 0.7514542 ,
        0.05432117],
       ...,
       [0.6838675 , 0.36867303, 0.88452   , ..., 0.27762923, 0.24816072,
        0.5181851 ],
       [0.3923327 , 0.59165126, 0.18200317, ..., 0.8978244 , 0.39478236,
        0.6790107 ],
       [0.45270625, 0.80843145, 0.98569936, ..., 0.48746213, 0.29303834,
        0.27993268]], shape=(10, 128), dtype=float32), array([[[[7.25335  , 6.092447 , 7.1044426, ..., 6.288283 , 6.0907507,
          6.578005 ],
         [6.340392 , 5.859031 , 6.8526893, ..., 5.611672 , 4.7989955,
          5.7885695],
         [6.139291 , 5.454653 , 5.7528486, ..., 5.1224375, 5.1665998,
          6.1372833],
         ...,
         [5.663054 , 6.542413 , 6.31105  , ..., 8.339398 , 8.214965 ,
          6.5121355],
         [6.5103974, 6.844792 , 5.8238344, ..., 7.1897564, 6.343715 ,
          7.6876006],
         [6.350681 , 6.652443 , 5.664641 , ..., 6.0560694, 6.3888054,
          6.9841714]],

        [[5.7222915, 5.27071  , 5.787939 , ..., 4.573338 , 4.4101624,
          5.4591727],
         [5.347622 , 4.7774625, 5.1061506, ..., 4.492314 , 4.622383 ,
          5.1965666],
         [4.2397056, 4.280427 , 4.4080253, ..., 4.4916224, 3.6859689,
          4.137514 ],
         ...,
         [5.795001 , 5.6777983, 5.3661776, ..., 6.371372 , 6.6373405,
          6.1602263],
         [5.0105305, 5.1061244, 4.945548 , ..., 5.714592 , 5.2842774,
          6.185224 ],
         [5.2444296, 5.0435405, 5.3091407, ..., 5.465613 , 5.358578 ,
          5.330665 ]],

        [[6.3701806, 6.268568 , 6.685126 , ..., 5.700896 , 5.7933445,
          6.310914 ],
         [6.3644857, 5.7881184, 5.9648323, ..., 5.19162  , 5.082944 ,
          5.835834 ],
         [5.052043 , 4.6574144, 5.2456083, ..., 5.5615635, 5.666439 ,
          5.345464 ],
         ...,
         [6.1138115, 6.6561394, 6.3554044, ..., 7.7756014, 7.6107287,
          6.5494533],
         [6.259518 , 6.4505787, 5.8155065, ..., 6.705423 , 6.405305 ,
          6.9957724],
         [6.01953  , 5.9426394, 5.9624224, ..., 6.463517 , 6.324596 ,
          6.3861804]],

        ...,

        [[6.4881315, 5.355841 , 7.2149835, ..., 5.4985733, 4.6323256,
          6.0953207],
         [5.887835 , 4.7430115, 5.839446 , ..., 5.3647013, 4.6528296,
          5.216407 ],
         [4.731409 , 5.5231156, 5.037255 , ..., 4.947035 , 4.7249427,
          5.1621213],
         ...,
         [5.943808 , 6.598487 , 5.279142 , ..., 7.032588 , 6.25327  ,
          7.3865237],
         [5.508034 , 5.4245048, 5.4503727, ..., 6.4917355, 5.982107 ,
          6.5126786],
         [6.5220366, 5.4298773, 5.8106337, ..., 5.668078 , 5.9818754,
          6.049671 ]],

        [[5.707375 , 5.537081 , 6.193824 , ..., 5.104591 , 5.1793427,
          5.3827815],
         [5.689829 , 5.1114807, 5.611065 , ..., 4.686502 , 4.844582 ,
          5.524507 ],
         [5.0667048, 4.7684093, 5.114347 , ..., 4.333488 , 3.8729854,
          4.2466536],
         ...,
         [5.2362638, 6.1093154, 4.747502 , ..., 6.542044 , 6.26671  ,
          5.6878943],
         [5.0081067, 5.202534 , 5.1538105, ..., 5.7579174, 6.5322604,
          6.2236714],
         [5.534713 , 5.7451835, 5.420356 , ..., 5.7245054, 5.6039653,
          5.694382 ]],

        [[7.33721  , 6.5606656, 7.5397043, ..., 6.723059 , 6.309789 ,
          6.8766327],
         [7.075066 , 6.3946285, 6.517103 , ..., 6.07599  , 6.0083413,
          5.92645  ],
         [5.471342 , 6.0628424, 6.409721 , ..., 5.255186 , 4.6954293,
          6.592276 ],
         ...,
         [6.8590612, 7.5713954, 6.8450813, ..., 7.989971 , 7.597007 ,
          8.374088 ],
         [6.225486 , 6.791375 , 6.517845 , ..., 7.8340845, 7.4529495,
          7.456212 ],
         [6.784119 , 6.360591 , 6.4721723, ..., 6.890297 , 6.8633842,
          7.3899875]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[    3,     6,   257, ...,   484,   118,   120],
         [  869,   749,   879, ...,   735,   612,   867],
         [ 1117,   996,  1000, ...,  1476,  1481,  1236],
         ...,
         [13891, 14141, 14271, ..., 14373, 14376, 14010],
         [14759, 14763, 14516, ..., 14623, 14750, 14878],
         [15129, 14887, 14888, ..., 15364, 14998, 15125]],

        [[    3,     6,   381, ...,   360,   117,   123],
         [  869,   749,   876, ...,   609,   985,   867],
         [ 1241,  1120,  1248, ...,  1352,  1481,  1360],
         ...,
         [14139, 14267, 14268, ..., 14372, 14252, 14011],
         [14757, 14639, 14519, ..., 14499, 14874, 14877],
         [15007, 14887, 14888, ..., 15364, 15122, 15001]],

        [[    3,     6,     9, ...,   360,   117,   123],
         [  745,   749,   877, ...,   734,   985,   867],
         [ 1116,  1244,  1248, ...,  1352,  1481,  1484],
         ...,
         [14014, 14267, 14269, ..., 14372, 14376, 14008],
         [14635, 14639, 14392, ..., 14747, 14626, 14878],
         [15007, 14887, 14891, ..., 15364, 14998, 15003]],

        ...,

        [[    2,     4,   381, ...,   363,   116,   123],
         [  745,   503,   876, ...,   733,   736,   867],
         [ 1117,   996,  1000, ...,  1476,  1481,  1360],
         ...,
         [13890, 14019, 14269, ..., 14374, 14252, 14008],
         [14757, 14763, 14641, ..., 14498, 14875, 14754],
         [15005, 14884, 14888, ..., 14995, 14998, 15127]],

        [[    3,     6,   381, ...,   487,   241,   120],
         [  869,   749,   876, ...,   733,   736,   991],
         [  993,   996,  1000, ...,  1478,  1481,  1360],
         ...,
         [13890, 14267, 14268, ..., 14373, 14252, 14134],
         [14635, 14515, 14516, ..., 14747, 14500, 14877],
         [15005, 14887, 14888, ..., 15364, 14998, 15002]],

        [[    3,     4,   257, ...,   484,   116,   123],
         [  745,   875,   876, ...,   733,   736,   743],
         [  993,   996,  1000, ...,  1476,  1357,  1237],
         ...,
         [14138, 14265, 14147, ..., 14375, 14252, 14010],
         [14635, 14639, 14392, ..., 14499, 14500, 14877],
         [15005, 15010, 15136, ..., 15240, 15122, 15001]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[7.8439555, 7.374462 , 7.6010556, ..., 8.507759 , 8.757925 ,
        8.998404 ]], shape=(1, 30752), dtype=float32), array([[117425.445, 116953.836, 117018.1  , ..., 117248.9  , 117419.86 ,
        117128.484]], shape=(1, 1024), dtype=float32), array([[6.1343380e+07, 6.2636124e+07, 5.9247464e+07, 6.0892408e+07,
        5.8150496e+07, 6.1070312e+07, 5.8421432e+07, 6.1112796e+07,
        6.0667044e+07, 5.9761064e+07, 6.0731944e+07, 6.0729200e+07,
        6.0464232e+07, 5.9464180e+07, 5.9798892e+07, 5.9603756e+07,
        6.0550064e+07, 5.8835552e+07, 5.9909992e+07, 5.9340920e+07,
        5.8943608e+07, 5.8785496e+07, 6.0178400e+07, 6.0820968e+07,
        5.9141816e+07, 5.9457196e+07, 5.9368456e+07, 6.0708072e+07,
        5.9597232e+07, 6.1526856e+07, 6.0072272e+07, 6.0797724e+07,
        5.9145564e+07, 5.9507552e+07, 6.1506528e+07, 6.0343884e+07,
        5.9225664e+07, 6.0059012e+07, 6.1453308e+07, 6.1497408e+07,
        6.1240820e+07, 5.8557168e+07, 6.1503244e+07, 5.9650908e+07,
        5.9718768e+07, 5.9576368e+07, 6.0403072e+07, 5.9344496e+07,
        5.9739256e+07, 6.0360028e+07, 5.9566040e+07, 5.9019644e+07,
        6.1985736e+07, 6.1077056e+07, 6.2145440e+07, 6.0551552e+07,
        6.1109576e+07, 5.8255640e+07, 6.2266184e+07, 6.2205528e+07,
        5.9402564e+07, 6.0764832e+07, 6.0396964e+07, 5.8904640e+07,
        6.1015764e+07, 5.8554672e+07, 6.1066844e+07, 5.9413208e+07,
        5.9035636e+07, 6.1841156e+07, 5.7880032e+07, 6.0978224e+07,
        5.8592384e+07, 5.8777848e+07, 5.8602480e+07, 5.9209092e+07,
        5.9232752e+07, 6.0686652e+07, 6.1306652e+07, 6.0160496e+07,
        5.9661200e+07, 6.0648920e+07, 6.0720400e+07, 5.9872164e+07,
        6.0106492e+07, 6.0077072e+07, 5.9365584e+07, 5.8362724e+07,
        5.7927708e+07, 6.0234168e+07, 6.1023460e+07, 6.0089428e+07,
        5.9343032e+07, 6.1039844e+07, 6.1351236e+07, 5.9827360e+07,
        5.9070420e+07, 5.9593320e+07, 6.0585360e+07, 5.9890728e+07,
        5.9452552e+07, 5.9800864e+07, 5.8910272e+07, 6.0819172e+07,
        6.1327764e+07, 6.0617828e+07, 5.8974616e+07, 5.9805400e+07,
        6.1213064e+07, 5.9533720e+07, 6.1804220e+07, 5.8574240e+07,
        6.0623760e+07, 5.9123880e+07, 6.1209360e+07, 6.0802704e+07,
        5.9781792e+07, 6.0100936e+07, 6.0957520e+07, 6.0918424e+07,
        5.8822968e+07, 5.9812152e+07, 6.1266608e+07, 5.9730688e+07,
        5.9199544e+07, 6.1678616e+07, 6.0241704e+07, 6.0782984e+07]],
      dtype=float32), array([[3.5455688e+09, 3.9710582e+09, 3.3279688e+09, 3.8877512e+09,
        4.1192499e+09, 3.7947940e+09, 4.1565317e+09, 3.9360046e+09,
        3.9853967e+09, 3.9054011e+09]], dtype=float32)]

Profiling

# New session, this time with profiling enabled.
sess = create_session(script_args.filename, profiling=True)

# Execute the model ``repeat`` times with the same feeds.
for _ in range(script_args.repeat):
    sess.run(None, feeds)

# Stop profiling; the returned value is handed to js_profile_to_dataframe
# (presumably the name of the file written by onnxruntime - TODO confirm).
prof = sess.end_profiling()
if js_profile_to_dataframe is None:
    # The optional helpers could not be imported: nothing to post-process.
    print("Install onnx-extended first.")
else:
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    # Keep the raw figures next to the script, as CSV and as Excel.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")
    # Two side-by-side axes, both filled by plot_ort_profile.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.225 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

Gallery generated by Sphinx-Gallery