101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx-extended is optional: it is only needed to turn the onnxruntime
# profiling trace into a dataframe and to plot it.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Define both optional names so that either one can safely be compared
    # against None instead of raising NameError.
    js_profile_to_dataframe = None
    plot_ort_profile = None

# Default model: ``data/example_4700-CPUep-opt.onnx`` next to this script.
# When ``__file__`` does not exist (interactive run), fall back to a path
# relative to the current working directory.
try:
    __file__
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(
        os.path.dirname(__file__ or ""),
        "data",
        "example_4700-CPUep-opt.onnx",
    )

# Script parameters: the model to profile and the number of profiled runs.
# The ``(value, text)`` tuple presumably gives a default plus a help string,
# and ``expose=""`` is forwarded unchanged to get_parsed_args — TODO confirm
# both against get_parsed_args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Print the effective value of each main argument.
for att in ("filename", "repeat"):
    print("%s=%s" % (att, getattr(script_args, att)))
filename=data/example_4700-CPUep-opt.onnx
repeat=10

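Random inputs. One helper builds random inputs from the model's declared input shapes and types. Another creates a CPU session, optionally with profiling enabled.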

def create_random_input(sess, *, dynamic_dim=1):
    """Build a feed dictionary of random tensors for every input of *sess*.

    :param sess: an onnxruntime ``InferenceSession``, or any object whose
        ``get_inputs()`` yields items with ``name``, ``shape`` and ``type``
    :param dynamic_dim: size used for every dimension that is not a fixed
        integer (symbolic dimensions such as ``"batch"``, or ``None``)
    :return: dict mapping each input name to a numpy array of values drawn
        uniformly from [0, 1) and cast to the input's element type
    :raises ValueError: if an input's element type is not float, double
        or float16
    """
    # Supported onnxruntime element types and their numpy equivalents.
    dtypes = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for i in sess.get_inputs():
        ot = i.type
        if ot not in dtypes:
            raise ValueError(f"Unsupported onnx type {ot}.")
        # Dynamic dimensions are reported as strings or None and cannot be
        # given to np.random.rand, so they are replaced by a fixed size.
        shape = [d if isinstance(d, int) else dynamic_dim for d in i.shape]
        # np.random.rand() with no argument returns a Python float (no
        # .astype); np.asarray turns it into a 0-d array for scalar inputs.
        feeds[i.name] = np.asarray(np.random.rand(*shape)).astype(dtypes[ot])
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session that runs *filename* on CPU only.

    :param filename: path to the ONNX model to load
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` set
    :return: an ``onnxruntime.InferenceSession``
    """
    # Imported lazily so the rest of the module can be imported without it.
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)


# Create a plain (non-profiling) CPU session and draw one random feed
# dictionary. The same feeds are reused by the profiled runs below.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# Run the model once. This bare expression is the last statement of the
# cell, so Sphinx-Gallery displays the returned list of outputs.
sess.run(None, feeds)
[array([[0.49811938, 0.82152057, 0.3373039 , ..., 0.9690496 , 0.7567    ,
        0.31627807],
       [0.05741106, 0.61383396, 0.69814515, ..., 0.00981634, 0.03271949,
        0.76451415],
       [0.37014586, 0.5706923 , 0.94200885, ..., 0.6017973 , 0.10141023,
        0.8193094 ],
       ...,
       [0.21534732, 0.9362901 , 0.31112844, ..., 0.9897062 , 0.6552406 ,
        0.94650924],
       [0.8500216 , 0.2480933 , 0.42366222, ..., 0.93902403, 0.12038182,
        0.47769457],
       [0.18478155, 0.6985593 , 0.8998448 , ..., 0.03175963, 0.89774   ,
        0.8473401 ]], dtype=float32), array([[0.23654357, 0.8846211 , 0.7895143 , ..., 0.9358064 , 0.5486316 ,
        0.8861085 ],
       [0.4148546 , 0.89466226, 0.4516352 , ..., 0.6255415 , 0.6314595 ,
        0.28775826],
       [0.5825523 , 0.14544408, 0.6117085 , ..., 0.05096382, 0.13781106,
        0.2697828 ],
       ...,
       [0.19106524, 0.04777327, 0.1348371 , ..., 0.00879816, 0.8628678 ,
        0.28616464],
       [0.7186147 , 0.51735187, 0.5539227 , ..., 0.5829653 , 0.43757367,
        0.10427039],
       [0.3955041 , 0.66906875, 0.3431797 , ..., 0.03259649, 0.9346118 ,
        0.29555398]], dtype=float32), array([[0.8238373 , 0.22705522, 0.42292988, ..., 0.22040927, 0.9317856 ,
        0.37155324],
       [0.06378181, 0.39271095, 0.10566166, ..., 0.22306833, 0.09067719,
        0.05038857],
       [0.5688872 , 0.04833359, 0.69702303, ..., 0.3372865 , 0.80998343,
        0.55208874],
       ...,
       [0.24066575, 0.81232107, 0.24020675, ..., 0.41089785, 0.76754063,
        0.09112121],
       [0.73953503, 0.14319196, 0.89437425, ..., 0.05385933, 0.2680572 ,
        0.23563877],
       [0.97741354, 0.87072533, 0.87953866, ..., 0.85288197, 0.13573973,
        0.01255301]], dtype=float32), array([[[[5.565089 , 5.103101 , 4.5174675, ..., 6.0058846, 5.958382 ,
          5.4275723],
         [5.745741 , 5.5332584, 4.767103 , ..., 7.137676 , 6.752444 ,
          5.708514 ],
         [5.605507 , 6.2635736, 5.501947 , ..., 6.328216 , 6.764166 ,
          6.2347217],
         ...,
         [6.753322 , 5.8139067, 6.4643807, ..., 6.5022063, 5.671295 ,
          6.1274652],
         [6.793723 , 6.1422615, 6.3913918, ..., 5.9583516, 4.947596 ,
          5.0844755],
         [5.7757854, 5.1850967, 6.2562585, ..., 5.155083 , 4.8032727,
          5.463747 ]],

        [[7.232889 , 5.9127665, 5.027247 , ..., 7.32901  , 6.9056263,
          6.888383 ],
         [6.209973 , 6.109516 , 5.9258747, ..., 7.99459  , 6.8433356,
          6.627201 ],
         [6.6054955, 6.3402686, 6.400939 , ..., 7.273025 , 7.7028365,
          7.488401 ],
         ...,
         [8.074051 , 7.767504 , 7.467769 , ..., 8.136668 , 7.035548 ,
          8.015495 ],
         [7.398261 , 6.792018 , 7.899523 , ..., 6.983091 , 6.2857585,
          6.1718197],
         [7.281235 , 7.182981 , 7.4691615, ..., 6.1005874, 5.021295 ,
          6.4314747]],

        [[5.7637143, 4.8422503, 4.1405168, ..., 5.827357 , 5.811146 ,
          5.7332745],
         [5.7922034, 3.9866033, 4.2269087, ..., 5.785495 , 6.3121614,
          5.7971487],
         [4.8831177, 4.9326544, 4.3113036, ..., 5.959783 , 5.5856047,
          5.4234543],
         ...,
         [5.5585494, 5.7830715, 5.203156 , ..., 5.5613346, 5.3291354,
          6.0616064],
         [5.38927  , 5.802525 , 6.0341926, ..., 5.338025 , 4.907255 ,
          6.0103025],
         [6.3623343, 5.3191495, 6.3359327, ..., 5.0720263, 4.6790676,
          4.8685727]],

        ...,

        [[5.3656626, 5.2313848, 5.1438403, ..., 6.5407805, 5.8744655,
          5.0493045],
         [5.342044 , 4.833243 , 4.9250784, ..., 6.265502 , 6.7738924,
          6.4219494],
         [5.494974 , 4.834587 , 5.6865044, ..., 5.897386 , 6.338793 ,
          5.8890886],
         ...,
         [6.899637 , 6.706461 , 6.3771486, ..., 6.59177  , 5.4907427,
          6.230597 ],
         [6.20377  , 5.6799264, 5.947048 , ..., 5.151845 , 5.4751916,
          5.16739  ],
         [5.8000546, 6.888856 , 5.3176284, ..., 4.7611704, 4.56488  ,
          4.786338 ]],

        [[6.156503 , 5.811774 , 5.3238378, ..., 7.531144 , 6.7949357,
          6.5749903],
         [5.286597 , 5.5069985, 6.0806537, ..., 7.161507 , 7.6698594,
          5.885137 ],
         [5.483968 , 5.7780085, 5.9242854, ..., 6.459801 , 6.5940285,
          6.609454 ],
         ...,
         [6.9081225, 6.7646356, 7.3245196, ..., 6.762631 , 6.119577 ,
          6.1035166],
         [7.1474175, 6.0217876, 6.609377 , ..., 6.1596785, 5.8335376,
          5.6525984],
         [6.874523 , 6.061458 , 7.470292 , ..., 6.034182 , 4.957194 ,
          4.944079 ]],

        [[6.6298065, 5.4778523, 4.0081897, ..., 6.080302 , 5.915414 ,
          5.071162 ],
         [5.2896533, 4.846699 , 5.0159574, ..., 6.968491 , 5.915704 ,
          5.1313715],
         [4.542079 , 5.0459566, 5.3878794, ..., 5.9239216, 5.8458543,
          5.7964206],
         ...,
         [5.7237134, 6.85091  , 5.3416386, ..., 6.1415   , 5.8549585,
          6.7072096],
         [6.004997 , 5.7064233, 6.2460246, ..., 6.116028 , 4.9547224,
          4.8880754],
         [7.134811 , 5.745888 , 5.4453382, ..., 4.930555 , 4.0033393,
          4.497093 ]]]], dtype=float32), array([[[[  249,   379,   133, ...,   363,   364,   245],
         [  745,   624,   629, ...,   857,   615,   741],
         [ 1243,  1244,  1127, ...,  1476,  1356,  1114],
         ...,
         [14137, 14017, 14021, ..., 14375, 14377, 14135],
         [14509, 14515, 14764, ..., 14498, 14500, 14753],
         [15255, 14887, 14888, ..., 15116, 15247, 15125]],

        [[  374,   377,   134, ...,   363,   365,   492],
         [  622,   872,   877, ...,   859,   738,   619],
         [ 1367,  1368,  1125, ...,  1230,  1357,  1238],
         ...,
         [14015, 14019, 14021, ..., 14375, 14377, 14380],
         [14385, 14514, 14640, ..., 14746, 14875, 14630],
         [14883, 14885, 14888, ..., 15242, 15370, 15125]],

        [[  124,   376,   133, ...,   363,   366,   244],
         [  871,   626,   630, ...,   735,   737,   742],
         [ 1367,  1121,  1375, ...,  1229,  1482,  1362],
         ...,
         [14136, 13895, 14268, ..., 14372, 14376, 14259],
         [14385, 14762, 14764, ..., 14622, 14500, 14877],
         [15007, 14887, 15013, ..., 15116, 15123, 15001]],

        ...,

        [[  375,   377,   134, ...,   114,   366,   368],
         [  745,   749,   628, ...,   735,   739,   990],
         [  995,  1368,  1375, ...,  1353,  1358,  1238],
         ...,
         [13891, 14019, 13897, ..., 14375, 14376, 14380],
         [14385, 14639, 14516, ..., 14620, 14500, 14877],
         [15004, 14887, 14889, ..., 15242, 15368, 15125]],

        [[  375,   376,   135, ...,   363,   366,   246],
         [  498,   751,   629, ...,   983,   614,   616],
         [ 1366,  1121,  1126, ...,  1229,  1233,  1238],
         ...,
         [14015, 13895, 14020, ..., 14372, 14376, 14259],
         [14758, 14515, 14640, ..., 14744, 14749, 14631],
         [15254, 14886, 15013, ..., 15116, 15368, 15125]],

        [[    0,   379,   134, ...,   363,   366,   245],
         [  747,   749,   628, ...,   859,   738,   990],
         [ 1367,  1245,  1375, ...,  1478,  1358,  1362],
         ...,
         [14015, 14019, 14021, ..., 14375, 14377, 14383],
         [14632, 14515, 14640, ..., 14747, 14500, 14631],
         [15252, 14886, 14889, ..., 14994, 14999, 15127]]]], dtype=int64), array([[6.2635736, 5.6826043, 6.172062 , ..., 6.1040564, 5.822977 ,
        6.7072096]], dtype=float32), array([[117578.305, 116757.78 , 118194.61 , ..., 117556.78 , 117410.74 ,
        117894.62 ]], dtype=float32), array([[59846544., 58236288., 58976596., 60590736., 62058104., 60432408.,
        59663544., 59718448., 61311092., 59508752., 59787840., 60822952.,
        59920728., 60014712., 60831268., 61057852., 61220080., 61549000.,
        58610184., 60068620., 61335560., 61194504., 60502104., 60003704.,
        59070296., 58443152., 61277268., 59282320., 60398944., 59343580.,
        59360288., 60844792., 59364908., 60653352., 60576632., 62950032.,
        60844288., 62809440., 60193900., 59834456., 59278804., 61281816.,
        59741584., 60839232., 59899280., 60719816., 59188440., 61837368.,
        61231392., 60013048., 60707972., 60957232., 60685696., 59340524.,
        62570220., 59074100., 62156332., 59758624., 61961888., 59178672.,
        59295944., 60203348., 58871752., 59733404., 59596988., 60141568.,
        60038548., 61583332., 61143492., 59138024., 60528984., 60534524.,
        60534380., 58101636., 58841352., 58926016., 59822960., 60417440.,
        59730784., 61367840., 62503976., 58468848., 58949380., 60468720.,
        58595856., 60187576., 60268520., 63818352., 58762480., 59508080.,
        62196608., 61047408., 60215336., 60450892., 59283104., 57954864.,
        61185920., 61690728., 60737716., 60368768., 58991252., 60454824.,
        60972984., 60767536., 59426240., 60702664., 60382692., 60766444.,
        59530424., 59392020., 59829224., 60664052., 61300568., 60710384.,
        60141088., 59636528., 58949032., 60760988., 59714968., 62630448.,
        62124576., 60557640., 59309288., 59972832., 61738392., 60953704.,
        60185584., 60170140.]], dtype=float32), array([[3.8632909e+09, 3.9355768e+09, 3.8385060e+09, 3.7788078e+09,
        3.9902116e+09, 3.7555151e+09, 3.8698854e+09, 4.2214922e+09,
        3.6336955e+09, 3.9290688e+09]], dtype=float32)]

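Profiling: the model runs ``repeat`` times in a session with profiling enabled. If onnx-extended is installed, the trace is then converted into a dataframe and plotted.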

# Profile the model: run it ``repeat`` times in a session created with
# profiling enabled, then stop profiling; end_profiling returns the name of
# the JSON trace file written by onnxruntime.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

prof = sess.end_profiling()
if js_profile_to_dataframe is not None:
    # first_it_out=True presumably separates the first (warm-up) iteration,
    # cf. the 'iteration' / 'it==0' columns — TODO confirm in onnx-extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    try:
        # pandas needs an optional engine (e.g. openpyxl) to write .xlsx;
        # a missing engine should not abort the plotting below.
        df.to_excel("plot_profile_existing_onnx.xlsx")
    except ImportError as e:
        print(f"Excel export skipped: {e}")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Title the plots after the profiled model instead of a fixed label.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.166 seconds)

Gallery generated by Sphinx-Gallery