101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx-extended is an optional dependency, only needed to turn the
# onnxruntime profile into a dataframe and a plot.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Bind *both* names so that neither is left undefined when the package
    # is missing; callers test ``js_profile_to_dataframe is not None``.
    js_profile_to_dataframe = None
    plot_ort_profile = None

# Default model: the example stored in the ``data`` folder next to this
# script. ``__file__`` is not defined in every execution context; when it is
# missing, fall back to a path relative to the working directory.
try:
    filename = os.path.dirname(__file__ or "")
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(filename, "data", "example_4700-CPUep-opt.onnx")

# Script parameters and their default values. ``filename`` is passed as a
# pair, presumably (default value, help text) -- confirm against
# ``get_parsed_args``.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the effective settings, one ``name=value`` line per parameter.
for att in ("filename", "repeat"):
    print(att, getattr(script_args, att), sep="=")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Create random inputs matching the inputs declared by the model.

def create_random_input(sess):
    """Build random feeds for every input declared by a session.

    :param sess: object exposing ``get_inputs()``; each returned item must
        have the attributes ``name``, ``shape`` and ``type``
    :return: dictionary ``{input name: numpy array}`` filled with values
        drawn by ``np.random.rand`` and cast to the input type
    :raises ValueError: if an input has a type other than
        ``tensor(float16)``, ``tensor(float)`` or ``tensor(double)``

    Dimensions which are not integers (symbolic names or ``None``) are
    replaced by 1 so that models with dynamic shapes can be fed as well.
    """
    dtypes = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
    feeds = {}
    for i in sess.get_inputs():
        dtype = dtypes.get(i.type)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {i.type}.")
        # A dynamic dimension has no fixed size: use 1.
        shape = tuple(d if isinstance(d, int) else 1 for d in i.shape)
        # np.asarray (rather than .astype) also covers rank-0 inputs, for
        # which np.random.rand() returns a plain Python float.
        feeds[i.name] = np.asarray(np.random.rand(*shape), dtype=dtype)
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session running on the CPU provider.

    :param filename: model given to ``InferenceSession``
    :param profiling: if True, the session is created with session options
        whose ``enable_profiling`` flag is set
    :return: the ``InferenceSession``
    """
    # onnxruntime is imported here, not at module level.
    from onnxruntime import InferenceSession, SessionOptions

    providers = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=providers)
    # No session options are passed at all when profiling is off.
    return InferenceSession(filename, providers=providers)


# First run without profiling: create a plain session, build random feeds
# from the inputs it declares, and execute the model once.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# ``None`` means no explicit list of output names is given. The call is left
# as a bare expression: its value (the list of outputs) is presumably what
# the gallery captures and displays -- keep it as the last statement.
sess.run(None, feeds)
[array([[0.17552273, 0.68732756, 0.5290164 , ..., 0.35221276, 0.5952169 ,
        0.6361147 ],
       [0.26412433, 0.6056844 , 0.26679057, ..., 0.1014453 , 0.26523158,
        0.12201253],
       [0.15094997, 0.23318888, 0.33176416, ..., 0.47327882, 0.01688397,
        0.6390618 ],
       ...,
       [0.45506528, 0.60604316, 0.20031592, ..., 0.44629756, 0.14123793,
        0.605792  ],
       [0.2856596 , 0.5170656 , 0.1257047 , ..., 0.2743392 , 0.86900175,
        0.3817029 ],
       [0.68995863, 0.35227323, 0.3087481 , ..., 0.70574284, 0.8777836 ,
        0.08797178]], shape=(128, 1024), dtype=float32), array([[0.6262958 , 0.14943743, 0.2705886 , ..., 0.96264994, 0.17708224,
        0.6566122 ],
       [0.75425136, 0.6893254 , 0.55515844, ..., 0.31009108, 0.33601576,
        0.7627456 ],
       [0.09976625, 0.9933214 , 0.92155707, ..., 0.7390853 , 0.75784385,
        0.10904144],
       ...,
       [0.3405509 , 0.74223226, 0.43465543, ..., 0.16283236, 0.17136581,
        0.6028398 ],
       [0.7576841 , 0.14133763, 0.0268735 , ..., 0.36411503, 0.42016342,
        0.2523382 ],
       [0.6622517 , 0.9972709 , 0.70545095, ..., 0.55234516, 0.55279505,
        0.37863356]], shape=(1024, 30752), dtype=float32), array([[0.87253857, 0.61326844, 0.29135424, ..., 0.520614  , 0.4448589 ,
        0.3812911 ],
       [0.2610425 , 0.5773703 , 0.28815785, ..., 0.3792513 , 0.19267453,
        0.5602341 ],
       [0.00478233, 0.33551973, 0.9061004 , ..., 0.37026966, 0.23843631,
        0.37388778],
       ...,
       [0.6506074 , 0.08489661, 0.07420032, ..., 0.38674375, 0.18430923,
        0.5142526 ],
       [0.19755149, 0.05648618, 0.5339526 , ..., 0.13197313, 0.27616256,
        0.34396595],
       [0.5768615 , 0.06713143, 0.7600817 , ..., 0.83105594, 0.00970379,
        0.36359698]], shape=(10, 128), dtype=float32), array([[[[6.541774 , 7.1397333, 6.4583387, ..., 6.6771774, 6.4730577,
          6.143708 ],
         [6.991296 , 7.9154997, 6.8231926, ..., 6.8475914, 7.4929714,
          5.790102 ],
         [7.8834457, 6.731552 , 8.608542 , ..., 6.4620914, 5.9557953,
          5.6758785],
         ...,
         [7.0229406, 7.358061 , 6.3541336, ..., 8.081129 , 6.435375 ,
          7.715969 ],
         [6.887336 , 6.7774124, 7.4658194, ..., 5.5332785, 6.699683 ,
          6.575598 ],
         [7.5383596, 6.978982 , 6.1099443, ..., 5.840475 , 5.1531305,
          5.4639664]],

        [[6.4287634, 7.3069305, 6.4111567, ..., 5.5177803, 5.7740135,
          6.4742136],
         [7.5329466, 6.5277605, 7.914441 , ..., 6.233163 , 6.396261 ,
          6.311652 ],
         [7.757143 , 7.386943 , 5.841728 , ..., 6.425473 , 6.9136934,
          5.5440044],
         ...,
         [6.0037737, 6.775131 , 7.2628813, ..., 6.22062  , 6.3684034,
          6.139326 ],
         [6.118169 , 5.3894033, 6.7282476, ..., 5.6896772, 5.2998266,
          6.248169 ],
         [6.0344553, 5.4952593, 5.9007463, ..., 5.010995 , 4.769184 ,
          5.8169947]],

        [[7.710986 , 7.8632936, 7.7742176, ..., 7.750189 , 7.619346 ,
          7.581566 ],
         [8.428953 , 8.669618 , 7.4389567, ..., 6.1240788, 6.937768 ,
          8.110433 ],
         [8.529059 , 8.606105 , 6.8140907, ..., 7.2021084, 7.4796557,
          8.078393 ],
         ...,
         [7.1082826, 7.734694 , 7.134812 , ..., 8.082739 , 6.9610605,
          7.4709454],
         [6.705849 , 7.094438 , 6.665864 , ..., 6.979936 , 5.616958 ,
          6.6198244],
         [7.051154 , 8.128046 , 7.561047 , ..., 6.1680183, 5.8016076,
          6.852093 ]],

        ...,

        [[7.6443763, 8.171648 , 7.9598823, ..., 7.3544908, 6.9462285,
          6.963777 ],
         [8.032927 , 9.0624275, 7.686527 , ..., 6.6499424, 7.2192874,
          8.614247 ],
         [8.6131735, 8.59194  , 6.555879 , ..., 6.4765887, 7.7468696,
          8.064987 ],
         ...,
         [7.4804873, 7.6413074, 6.7742167, ..., 8.395108 , 7.581279 ,
          6.6565967],
         [6.882045 , 7.4179277, 6.6342173, ..., 7.494274 , 5.386202 ,
          6.537912 ],
         [6.5171466, 7.4882607, 7.8532877, ..., 6.2411585, 5.8037887,
          7.885693 ]],

        [[6.826994 , 7.509895 , 7.068952 , ..., 6.610836 , 7.276745 ,
          6.538041 ],
         [7.6680346, 8.223108 , 7.161263 , ..., 5.848102 , 6.4660506,
          7.023722 ],
         [8.188906 , 7.121763 , 6.209566 , ..., 5.0943704, 6.4900093,
          7.317706 ],
         ...,
         [6.9751906, 6.9915376, 6.3570476, ..., 7.2226825, 5.8441195,
          6.5887012],
         [5.93663  , 6.758532 , 6.5631847, ..., 6.236425 , 6.0695944,
          6.1516733],
         [6.4594965, 6.3136163, 6.73944  , ..., 5.298742 , 4.3623834,
          7.0184884]],

        [[7.250488 , 7.5170054, 7.9495263, ..., 7.8090196, 7.5962143,
          6.4389358],
         [8.3508005, 8.858233 , 7.555114 , ..., 7.0662875, 8.539405 ,
          7.5458837],
         [7.890096 , 8.068041 , 7.548929 , ..., 6.439225 , 7.410716 ,
          7.466496 ],
         ...,
         [7.985826 , 7.102221 , 6.146627 , ..., 8.187226 , 7.072011 ,
          7.066733 ],
         [7.3340826, 7.1575885, 6.4419236, ..., 7.5089645, 6.5330195,
          6.7033505],
         [7.217041 , 8.488352 , 7.5277205, ..., 5.8744206, 5.734039 ,
          7.2435207]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  373,   254,   383, ...,   484,   243,   494],
         [  499,   501,   754, ...,   982,   862,   990],
         [ 1117,  1368,  1127, ...,  1478,  1234,  1114],
         ...,
         [14261, 13894, 14146, ..., 14374, 14130, 14258],
         [14508, 14515, 14393, ..., 14499, 14875, 14877],
         [14881, 15010, 15012, ..., 15242, 14999, 15124]],

        [[  126,   130,   256, ...,   238,   241,   493],
         [  497,   503,   755, ...,   981,   863,   866],
         [ 1118,  1371,  1127, ...,  1479,  1359,  1114],
         ...,
         [14263, 14018, 14146, ..., 14372, 14006, 14382],
         [14384, 14388, 14517, ..., 14620, 14749, 14876],
         [14881, 14886, 14888, ..., 15241, 14998, 15001]],

        [[  125,   129,   380, ...,   238,   119,   494],
         [  496,   627,   755, ...,   857,   738,   742],
         [ 1117,  1247,  1127, ...,  1476,  1235,  1114],
         ...,
         [14261, 14143, 14147, ..., 14372, 14130, 14134],
         [14757, 14515, 14394, ..., 14496, 14627, 14876],
         [14881, 14887, 14888, ..., 14995, 14997, 15000]],

        ...,

        [[  125,   129,   256, ...,   113,   119,   247],
         [  496,   503,   755, ...,   981,   738,   990],
         [ 1117,  1247,  1127, ...,  1104,  1235,  1114],
         ...,
         [14260, 13894, 14147, ..., 14372, 14129, 14134],
         [14509, 14636, 14640, ..., 14496, 14627, 14876],
         [14881, 14884, 15136, ..., 14995, 14997, 15124]],

        [[  125,   253,   257, ...,   239,   119,   494],
         [  496,   503,   631, ...,   980,   739,   742],
         [ 1117,  1120,  1127, ...,  1478,  1480,  1239],
         ...,
         [14138, 14142, 14147, ..., 14372, 14255, 14258],
         [14509, 14636, 14517, ..., 14496, 14627, 14876],
         [14881, 14887, 14888, ..., 15116, 14997, 15000]],

        [[  125,     6,   383, ...,   238,   243,   493],
         [  620,   501,   755, ...,   980,   862,   990],
         [ 1117,  1120,  1126, ...,  1478,  1359,  1114],
         ...,
         [14260, 13894, 14146, ..., 14374, 14252, 14259],
         [14386, 14515, 14640, ..., 14498, 14627, 14876],
         [14881, 14884, 15136, ..., 14995, 14999, 15124]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[ 8.818667 ,  8.741206 ,  7.6588907, ...,  7.623483 ,  9.450416 ,
        10.103    ]], shape=(1, 30752), dtype=float32), array([[121938.49, 122991.26, 122431.5 , ..., 122242.68, 122499.35,
        121873.24]], shape=(1, 1024), dtype=float32), array([[6.2450444e+07, 6.2220700e+07, 6.1190176e+07, 6.3167696e+07,
        6.2004308e+07, 6.1224968e+07, 6.3300728e+07, 6.3419072e+07,
        6.1955488e+07, 6.2982268e+07, 6.2610872e+07, 6.3773728e+07,
        6.4853100e+07, 6.3461352e+07, 6.3389968e+07, 6.1539860e+07,
        6.4640764e+07, 6.1633064e+07, 6.3224144e+07, 6.0951288e+07,
        6.2754852e+07, 6.4091496e+07, 6.2561488e+07, 6.3191880e+07,
        6.4009104e+07, 6.2058772e+07, 6.2818204e+07, 6.4956424e+07,
        6.4168700e+07, 6.1692984e+07, 6.2049728e+07, 6.3018308e+07,
        6.2309576e+07, 6.3348184e+07, 6.2978608e+07, 6.2252444e+07,
        6.4097916e+07, 6.1873804e+07, 6.1026408e+07, 6.3245608e+07,
        6.2778556e+07, 6.4267408e+07, 6.3394620e+07, 6.1036804e+07,
        6.3017628e+07, 6.2088788e+07, 6.2658288e+07, 6.1447776e+07,
        6.3300524e+07, 6.1952940e+07, 6.3819272e+07, 6.2927200e+07,
        6.1671440e+07, 6.2876016e+07, 6.3558792e+07, 6.3583796e+07,
        6.3187736e+07, 6.3879340e+07, 6.3893040e+07, 6.1954264e+07,
        6.2760296e+07, 6.2056180e+07, 6.4518320e+07, 6.2761804e+07,
        6.3484816e+07, 6.3962104e+07, 6.3886112e+07, 6.1592400e+07,
        6.1213908e+07, 6.3018244e+07, 6.3060572e+07, 6.4491304e+07,
        6.1076944e+07, 6.2846312e+07, 6.3781372e+07, 6.1998904e+07,
        6.2545476e+07, 6.3056836e+07, 6.2192024e+07, 6.3603904e+07,
        6.3614616e+07, 6.1620200e+07, 6.1389972e+07, 6.0547028e+07,
        6.0045824e+07, 6.2011444e+07, 6.2373192e+07, 6.1064968e+07,
        6.3699060e+07, 6.2242632e+07, 6.3014820e+07, 6.3220616e+07,
        6.6438148e+07, 6.0932744e+07, 6.2267676e+07, 6.3499132e+07,
        6.3538168e+07, 6.3654944e+07, 6.3216960e+07, 6.2468476e+07,
        6.3452512e+07, 6.3855344e+07, 6.3842688e+07, 6.1678632e+07,
        6.0655608e+07, 6.4783488e+07, 6.2537184e+07, 6.3015240e+07,
        6.1708304e+07, 6.4084712e+07, 6.4186248e+07, 6.3160056e+07,
        6.2585448e+07, 6.4225536e+07, 6.3082508e+07, 6.3511488e+07,
        6.3496200e+07, 6.3591412e+07, 6.1728960e+07, 6.3251836e+07,
        6.3813520e+07, 6.1531944e+07, 6.4221516e+07, 6.4271608e+07,
        6.1765020e+07, 6.3652768e+07, 6.2478248e+07, 6.3571648e+07]],
      dtype=float32), array([[3.9998950e+09, 3.9757458e+09, 4.4043233e+09, 3.8797891e+09,
        4.1281039e+09, 3.8398976e+09, 4.0783470e+09, 4.0691031e+09,
        4.0202798e+09, 4.1086341e+09]], dtype=float32)]

Profiling

# New session with profiling enabled; the feeds built earlier are reused.
sess = create_session(script_args.filename, profiling=True)

# Execute the model ``repeat`` times so the profile covers several runs.
for _ in range(script_args.repeat):
    sess.run(None, feeds)

# Stop profiling; the returned value is what the profile parser consumes.
prof = sess.end_profiling()

if js_profile_to_dataframe is None:
    # The optional dependency is missing: nothing to parse or plot.
    print("Install onnx-extended first.")
else:
    # NOTE(review): ``first_it_out=True`` presumably sets the first
    # iteration apart (see the ``it==0`` column) -- confirm in onnx-extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)

    # Save the raw profile data in two formats.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two side-by-side axes, both passed to the plotting helper.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.190 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

Gallery generated by Sphinx-Gallery