101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# Locate the example model relative to this script. ``__file__`` may be
# undefined depending on how the script is executed; the NameError branch
# then falls back to a path relative to the current directory.
try:
    script_dir = os.path.dirname(__file__ or "")
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(script_dir, "data", "example_4700-CPUep-opt.onnx")

# Script parameters with their default values: the model to profile and the
# number of profiled runs.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the values actually used by the script.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs matching the inputs declared by the model.

def create_random_input(sess):
    """Build a feed dictionary of random tensors for every session input.

    :param sess: object exposing ``get_inputs()``; each returned item must
        provide ``name``, ``shape`` and ``type`` attributes
    :return: dictionary ``{input name: numpy array}`` whose values are drawn
        with ``np.random.rand`` and cast to the input element type
    :raises ValueError: if an input type is not one of the supported
        floating point tensor types
    """
    # Element types which can be filled by casting uniform samples in [0, 1).
    supported = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
    feeds = {}
    for inp in sess.get_inputs():
        dtype = supported.get(inp.type)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        # A dimension that is not a plain integer (symbolic name, None) is
        # dynamic: np.random.rand cannot take it, so it is replaced by 1.
        dims = [d if isinstance(d, (int, np.integer)) else 1 for d in inp.shape]
        # np.random.rand() called without any dimension returns a Python
        # float, np.asarray turns it into a 0-d array of the expected type.
        feeds[inp.name] = np.asarray(np.random.rand(*dims), dtype=dtype)
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session running on the CPU execution provider.

    :param filename: model given to ``InferenceSession``
    :param profiling: if True, the session is created with
        ``SessionOptions.enable_profiling`` switched on
    :return: the ``InferenceSession``
    """
    # Imported here so that onnxruntime is only needed when a session is built.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)


# First run without profiling: build a plain session, draw random inputs for
# it and execute the model once. ``feeds`` is reused by the profiled runs.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# Passing None as output names asks onnxruntime for every model output.
sess.run(None, feeds)
[array([[0.91398245, 0.57628196, 0.48764992, ..., 0.3955947 , 0.52806693,
        0.52786803],
       [0.31164184, 0.58315444, 0.80397445, ..., 0.61994034, 0.94090754,
        0.17720585],
       [0.3395866 , 0.5943828 , 0.53631425, ..., 0.35516572, 0.09682213,
        0.33556533],
       ...,
       [0.25707904, 0.7535674 , 0.7554002 , ..., 0.36074117, 0.06085893,
        0.13981643],
       [0.07814518, 0.09656924, 0.0421176 , ..., 0.8285064 , 0.28555152,
        0.42705196],
       [0.41327047, 0.14391418, 0.40519983, ..., 0.5858572 , 0.62700355,
        0.7861884 ]], dtype=float32), array([[0.47667688, 0.39695138, 0.9381092 , ..., 0.68025386, 0.58981436,
        0.19682835],
       [0.96116537, 0.8519381 , 0.12407897, ..., 0.91617495, 0.2277926 ,
        0.9581247 ],
       [0.69623035, 0.5257477 , 0.20292161, ..., 0.4543687 , 0.6661971 ,
        0.22071014],
       ...,
       [0.38079822, 0.68998075, 0.7934502 , ..., 0.99610054, 0.05723822,
        0.02473171],
       [0.54504544, 0.2616041 , 0.29621118, ..., 0.4465734 , 0.28895664,
        0.769591  ],
       [0.4225101 , 0.58036405, 0.61927366, ..., 0.5072816 , 0.4815128 ,
        0.2072598 ]], dtype=float32), array([[0.87149805, 0.46779412, 0.7452013 , ..., 0.59170425, 0.9209106 ,
        0.42639002],
       [0.82107943, 0.32070556, 0.8573875 , ..., 0.56483036, 0.260574  ,
        0.09135629],
       [0.24738845, 0.2951498 , 0.7146122 , ..., 0.380619  , 0.09730798,
        0.60640943],
       ...,
       [0.46793056, 0.66070056, 0.48432958, ..., 0.99664325, 0.9882933 ,
        0.22435805],
       [0.20465992, 0.15862264, 0.65854317, ..., 0.84058243, 0.66796696,
        0.79107755],
       [0.33797637, 0.05030651, 0.7483493 , ..., 0.0564948 , 0.6848731 ,
        0.46056956]], dtype=float32), array([[[[7.205508 , 7.9895926, 8.416544 , ..., 6.770774 , 7.1625376,
          6.483858 ],
         [7.2672195, 7.465618 , 7.3520966, ..., 6.974065 , 7.622259 ,
          7.517603 ],
         [8.056629 , 7.7550955, 7.3357353, ..., 7.7675734, 8.218457 ,
          8.343167 ],
         ...,
         [6.913253 , 7.083952 , 6.453188 , ..., 7.847801 , 7.494713 ,
          7.8264074],
         [5.8578596, 5.2047787, 5.60857  , ..., 7.2954445, 7.638356 ,
          7.395913 ],
         [7.3160095, 5.8252716, 5.454154 , ..., 7.4482   , 7.358966 ,
          6.83504  ]],

        [[6.260421 , 6.4745965, 7.3947716, ..., 6.7765603, 6.589321 ,
          6.5809693],
         [6.9425473, 6.665863 , 7.7645326, ..., 6.6559825, 6.4143305,
          7.185762 ],
         [7.4409366, 7.647948 , 6.9305844, ..., 6.788102 , 7.252824 ,
          7.472464 ],
         ...,
         [6.580416 , 5.977857 , 6.6630354, ..., 7.151876 , 7.6054873,
          6.188504 ],
         [6.3581014, 4.8171573, 5.2133064, ..., 6.865668 , 5.930305 ,
          6.7570367],
         [6.217736 , 5.5824056, 5.3415155, ..., 7.3240457, 6.229552 ,
          6.772177 ]],

        [[5.116146 , 6.1506577, 6.2882648, ..., 5.455061 , 5.7735896,
          6.2993226],
         [5.7724338, 5.2778144, 5.490645 , ..., 5.749422 , 5.0189543,
          6.112331 ],
         [5.991749 , 6.41042  , 6.491544 , ..., 6.086307 , 6.8470335,
          5.5497823],
         ...,
         [5.746537 , 5.4410415, 6.043397 , ..., 6.1263237, 6.614131 ,
          5.5966153],
         [4.83159  , 3.9778526, 3.850477 , ..., 5.3486814, 5.485555 ,
          6.6554236],
         [5.3024807, 5.1484284, 5.582978 , ..., 6.465997 , 5.312218 ,
          5.423742 ]],

        ...,

        [[7.871123 , 8.718124 , 8.500517 , ..., 8.246982 , 6.7822537,
          7.3001914],
         [7.943476 , 8.12264  , 8.797746 , ..., 7.1407943, 8.218125 ,
          8.40282  ],
         [8.7769985, 8.391742 , 7.024096 , ..., 8.839397 , 9.07862  ,
          8.473246 ],
         ...,
         [7.268252 , 6.7797985, 7.61532  , ..., 8.0960045, 8.261134 ,
          8.577002 ],
         [6.6318307, 6.080726 , 6.6229615, ..., 8.542605 , 7.6251235,
          9.220777 ],
         [6.7647943, 6.6094775, 5.9505105, ..., 8.322737 , 8.047828 ,
          8.65976  ]],

        [[6.0740004, 6.434598 , 6.6850815, ..., 5.139678 , 7.0877285,
          4.9569206],
         [5.4855757, 6.0544186, 6.5418234, ..., 5.907666 , 6.399254 ,
          6.6314116],
         [7.3262463, 6.4814157, 6.3402596, ..., 7.255395 , 5.759486 ,
          6.988225 ],
         ...,
         [5.3537436, 6.1104374, 6.298465 , ..., 6.884762 , 6.7736454,
          6.212181 ],
         [4.9507933, 3.9594915, 4.143952 , ..., 5.8256702, 6.6111565,
          5.860141 ],
         [6.217616 , 5.4622936, 4.6362677, ..., 5.642166 , 6.3643227,
          5.842002 ]],

        [[6.9261794, 6.779761 , 7.083336 , ..., 6.6269736, 5.9768634,
          5.9925976],
         [6.841673 , 6.8821535, 7.004058 , ..., 7.052854 , 6.4486403,
          6.5491166],
         [7.0067296, 7.0405135, 6.5167766, ..., 7.1424656, 7.1956654,
          7.4732656],
         ...,
         [5.8668175, 5.9137573, 6.1626215, ..., 6.865367 , 7.8011456,
          6.1410894],
         [5.517689 , 4.248037 , 5.042473 , ..., 7.0548387, 6.4594965,
          6.0213876],
         [5.998788 , 5.2361135, 5.016924 , ..., 7.831756 , 6.2911563,
          6.0944743]]]], dtype=float32), array([[[[  372,   379,   258, ...,   362,   365,   371],
         [  497,   872,   879, ...,   980,   739,   618],
         [  995,   996,  1003, ...,  1104,  1233,  1114],
         ...,
         [14137, 14018, 13899, ..., 14001, 14006, 14135],
         [14510, 14763, 14643, ..., 14871, 14873, 14752],
         [14883, 15259, 15262, ..., 14995, 14997, 15000]],

        [[  373,   128,   382, ...,   238,   117,   371],
         [  496,   624,   878, ...,   980,   613,   618],
         [  995,   996,  1003, ...,  1104,  1357,  1113],
         ...,
         [14261, 14264, 14269, ..., 14002, 14006, 14009],
         [14510, 14762, 14766, ..., 14870, 14873, 14752],
         [14883, 15259, 15014, ..., 14995, 14998, 15124]],

        [[  372,     5,    11, ...,   362,   365,   370],
         [  746,   875,   879, ...,   856,   613,   617],
         [  994,   998,  1127, ...,  1104,  1356,  1114],
         ...,
         [14260, 13893, 13899, ..., 14002, 14006, 14135],
         [14387, 14760, 14767, ..., 14870, 14872, 14752],
         [15007, 15135, 15260, ..., 14995, 15120, 15251]],

        ...,

        [[  372,   254,   258, ...,   238,   240,   494],
         [  496,   873,   879, ...,   980,   613,   741],
         [  994,   996,  1003, ...,  1105,  1232,  1484],
         ...,
         [14261, 14018, 13898, ..., 14127, 14007, 14135],
         [14510, 14389, 14767, ..., 14870, 14872, 14752],
         [14883, 15135, 15137, ..., 14995, 14996, 15000]],

        [[  248,     4,   382, ...,   363,   241,   369],
         [  497,   874,   879, ...,   980,   863,   742],
         [  995,   997,  1375, ...,  1104,  1358,  1112],
         ...,
         [14137, 14019, 13899, ..., 14001, 14006, 14011],
         [14510, 14636, 14766, ..., 14871, 14873, 14876],
         [15006, 15259, 15014, ..., 14995, 15244, 15001]],

        [[  373,   252,   382, ...,   362,   240,   368],
         [  496,   624,   506, ...,   980,   862,   618],
         [  995,   996,  1003, ...,  1104,  1357,  1114],
         ...,
         [14262, 14142, 13898, ..., 14000, 14005, 14135],
         [14509, 14636, 14766, ..., 14870, 14873, 14752],
         [15007, 15259, 15013, ..., 14995, 14997, 15373]]]], dtype=int64), array([[8.478149 , 7.7415524, 8.194025 , ..., 7.9691305, 8.42513  ,
        7.831756 ]], dtype=float32), array([[122311.34 , 122470.555, 122021.74 , ..., 121663.766, 122541.27 ,
        122380.37 ]], dtype=float32), array([[62954464., 61740216., 61013392., 60682196., 63875272., 62679800.,
        63408672., 60948412., 64281992., 63381936., 64307360., 64403020.,
        61565344., 62251208., 62248408., 63189680., 63147016., 63402980.,
        62256272., 62496312., 64481904., 62637632., 62345368., 64697372.,
        61741924., 61665400., 61223840., 61516116., 64228740., 62423644.,
        62610144., 61648480., 64921048., 62355680., 62512224., 63721700.,
        61828672., 61592120., 62517488., 64757480., 62091776., 61604048.,
        61134364., 62620720., 64929752., 62785368., 62697552., 64157768.,
        61666024., 63453024., 63248024., 63152216., 62553164., 60997044.,
        61797560., 63103188., 63727820., 62445656., 62467520., 63751204.,
        65052816., 64381952., 60688352., 63484808., 62104568., 61991008.,
        64230896., 63531720., 62413112., 62044092., 63717560., 63264736.,
        63445592., 63367236., 60130480., 62355728., 62062500., 63019924.,
        61160204., 63512624., 62444544., 63664784., 64227928., 62368784.,
        63602248., 63713976., 62248200., 61052456., 62105060., 61357232.,
        63723532., 63136000., 59691656., 62853992., 62107412., 63981192.,
        62717824., 61267524., 63453692., 61288608., 62116480., 62351944.,
        65580556., 63909912., 63410272., 62190400., 62707108., 61451220.,
        62037912., 62947060., 63640856., 62525152., 61416576., 61228320.,
        62707616., 63277360., 63264856., 60375660., 61273596., 61515272.,
        62446944., 62716088., 61571832., 62582452., 61839944., 61273944.,
        62814292., 62870916.]], dtype=float32), array([[4.0849713e+09, 3.9629972e+09, 3.8007370e+09, 4.1178373e+09,
        3.7346391e+09, 4.1855841e+09, 4.5715620e+09, 3.7984412e+09,
        4.2013747e+09, 4.2602982e+09]], dtype=float32)]

Profiling

# A new session is needed: profiling is switched on when the session is built.
sess = create_session(script_args.filename, profiling=True)

# Run the model ``repeat`` times on the same random inputs.
for _ in range(script_args.repeat):
    sess.run(None, feeds)

# end_profiling returns the value handed to js_profile_to_dataframe below.
prof = sess.end_profiling()
if js_profile_to_dataframe is None:
    # onnx-extended could not be imported, the profile cannot be post-processed.
    print("Install onnx-extended first.")
else:
    # NOTE(review): first_it_out=True presumably sets the first iteration
    # apart (see the 'it==0' column) - confirm in onnx-extended documentation.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)

    # Keep the raw profile in both formats.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two side-by-side axes filled by plot_ort_profile, saved as an image.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name', 'args_op_name',
       'op_name', 'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'event_name', 'iteration', 'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 5.058 seconds)

Gallery generated by Sphinx-Gallery