101: Profile an existing model with onnxruntime

Profiles any ONNX model with onnxruntime on CPU.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# The model ships in the ``data`` folder next to this script. ``__file__``
# is undefined when the code is executed interactively (e.g. in a notebook);
# the path is then taken relative to the current working directory.
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""),
        "data",
        "example_4700-CPUep-opt.onnx",
    )
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"

# Each keyword supplies the default of one option; ``filename`` is paired
# with what is presumably its help text, and ``repeat`` is the number of
# profiled runs performed further below.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the options that drive this example.
for option in ("filename", "repeat"):
    value = getattr(script_args, option)
    print(f"{option}={value}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs: build random feeds for the model and run it once.

def create_random_input(sess, *, dynamic_dim=1):
    """Build a dictionary of random input tensors for an inference session.

    Every input declared by ``sess.get_inputs()`` receives an array filled
    with values drawn uniformly from ``[0, 1)`` by :func:`numpy.random.rand`.

    :param sess: object exposing ``get_inputs()`` (such as an
        ``onnxruntime.InferenceSession``); each returned input must have
        ``name``, ``shape`` and ``type`` attributes
    :param dynamic_dim: size used for every dimension that is not a fixed
        integer (a symbolic name or ``None``), so models with dynamic axes
        can still be fed
    :return: dictionary mapping each input name to a numpy array
    :raises ValueError: if an input is not a supported floating-point tensor
    """
    type_to_dtype = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for inp in sess.get_inputs():
        dtype = type_to_dtype.get(inp.type)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        # onnxruntime reports symbolic dimensions as strings and unknown ones
        # as None; np.random.rand only accepts integer sizes.
        shape = [d if isinstance(d, int) else dynamic_dim for d in inp.shape]
        feeds[inp.name] = np.random.rand(*shape).astype(dtype)
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime inference session on the CPU execution provider.

    :param filename: path to the ONNX model to load
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` switched on
    :return: an ``onnxruntime.InferenceSession``
    """
    from onnxruntime import InferenceSession, SessionOptions

    providers = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=providers)
    return InferenceSession(filename, providers=providers)


# Plain (non-profiling) session: feed random data once to check that the
# model loads and runs before it is profiled.
sess = create_session(script_args.filename, profiling=False)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.48701167, 0.8743634 , 0.41765413, ..., 0.5896347 , 0.50106037,
        0.79936886],
       [0.22311887, 0.7746693 , 0.36482093, ..., 0.0888119 , 0.63223225,
        0.24123782],
       [0.42908487, 0.31348827, 0.37753874, ..., 0.2996923 , 0.5862409 ,
        0.46235615],
       ...,
       [0.63901   , 0.958252  , 0.8576642 , ..., 0.84840333, 0.27156168,
        0.92425513],
       [0.7107026 , 0.22514978, 0.7041683 , ..., 0.8769297 , 0.5114278 ,
        0.2611587 ],
       [0.09398966, 0.04657802, 0.36166787, ..., 0.12336926, 0.382032  ,
        0.8705768 ]], shape=(128, 1024), dtype=float32), array([[0.70109737, 0.30644372, 0.73326844, ..., 0.7341234 , 0.19501053,
        0.7508927 ],
       [0.559656  , 0.33357525, 0.8111176 , ..., 0.17147014, 0.36077788,
        0.14913547],
       [0.41688848, 0.24148993, 0.57769305, ..., 0.9043198 , 0.32304505,
        0.93879217],
       ...,
       [0.78704005, 0.15369062, 0.5913906 , ..., 0.37713814, 0.7967207 ,
        0.8414134 ],
       [0.44902855, 0.52110785, 0.875147  , ..., 0.24375667, 0.66569865,
        0.08956924],
       [0.10142928, 0.09177364, 0.41900054, ..., 0.17988825, 0.7206119 ,
        0.42871594]], shape=(1024, 30752), dtype=float32), array([[0.86050165, 0.9987919 , 0.6943593 , ..., 0.6296961 , 0.9132347 ,
        0.36226135],
       [0.39775822, 0.28806388, 0.61356646, ..., 0.12404636, 0.74102914,
        0.8745291 ],
       [0.3514584 , 0.78925663, 0.69602776, ..., 0.08101459, 0.2081717 ,
        0.57725066],
       ...,
       [0.21829125, 0.6406352 , 0.10680309, ..., 0.19909513, 0.99277294,
        0.20461668],
       [0.5313196 , 0.05398373, 0.03478933, ..., 0.08970655, 0.14648752,
        0.2508738 ],
       [0.66879046, 0.98973495, 0.37505123, ..., 0.32469818, 0.3674183 ,
        0.25382754]], shape=(10, 128), dtype=float32), array([[[[6.9381046, 6.5217695, 7.6244864, ..., 5.582769 , 5.8021383,
          6.2628956],
         [7.6268106, 7.2249756, 7.9783726, ..., 5.9969845, 6.0601482,
          6.789772 ],
         [8.471845 , 8.172918 , 7.848505 , ..., 6.512267 , 6.198912 ,
          7.7962027],
         ...,
         [7.6225095, 7.7220645, 8.021209 , ..., 8.20195  , 8.130193 ,
          7.501075 ],
         [7.16429  , 7.809957 , 8.504215 , ..., 8.331161 , 7.5941954,
          8.112395 ],
         [6.6636386, 6.8771586, 8.429311 , ..., 7.214736 , 7.764557 ,
          6.867657 ]],

        [[6.942608 , 7.002553 , 7.092223 , ..., 5.5438323, 5.919307 ,
          6.4588156],
         [7.931553 , 6.8467336, 8.158717 , ..., 5.805227 , 6.538684 ,
          7.9998164],
         [8.64638  , 7.670035 , 9.19433  , ..., 5.948997 , 6.564786 ,
          7.8416657],
         ...,
         [7.624467 , 6.9467664, 8.990562 , ..., 8.466427 , 7.6095743,
          7.165219 ],
         [7.355692 , 8.671929 , 8.821998 , ..., 6.906161 , 8.366258 ,
          8.648574 ],
         [6.6341805, 8.415855 , 8.414853 , ..., 7.8235483, 7.994326 ,
          6.9552045]],

        [[8.164729 , 6.5867295, 8.375455 , ..., 6.138112 , 6.583139 ,
          6.581451 ],
         [8.739716 , 7.1404786, 8.5598   , ..., 6.6141286, 6.797036 ,
          8.238909 ],
         [8.733824 , 8.301625 , 9.551962 , ..., 6.5382314, 6.8334813,
          8.334098 ],
         ...,
         [7.4179096, 7.4367905, 8.479232 , ..., 9.271135 , 7.3919373,
          7.9040008],
         [7.5609527, 8.912743 , 8.987174 , ..., 7.988373 , 8.6921215,
          8.542471 ],
         [7.477041 , 8.285168 , 8.529076 , ..., 7.860715 , 8.320988 ,
          7.324272 ]],

        ...,

        [[6.5135617, 6.6306157, 7.2058196, ..., 5.273611 , 6.055018 ,
          6.3206887],
         [7.1804037, 8.146592 , 7.4991827, ..., 5.9763246, 5.8871   ,
          7.3664207],
         [6.8489943, 8.127537 , 8.35078  , ..., 5.554768 , 5.6488175,
          7.7726083],
         ...,
         [6.4850345, 7.0976562, 7.1712484, ..., 7.5530953, 6.952131 ,
          7.9194865],
         [7.8035326, 8.187436 , 8.276912 , ..., 7.218614 , 7.766487 ,
          8.483334 ],
         [7.1173396, 7.586395 , 7.9764857, ..., 7.0274477, 6.920168 ,
          6.6587515]],

        [[6.4631553, 6.628012 , 7.0430827, ..., 5.1135397, 5.8247857,
          6.3276763],
         [6.99526  , 6.884177 , 6.9137273, ..., 5.8827195, 6.256574 ,
          6.662396 ],
         [6.997056 , 7.254697 , 8.030955 , ..., 5.343504 , 6.2028036,
          7.165063 ],
         ...,
         [6.3499227, 6.697188 , 7.2855396, ..., 6.9896274, 7.5395117,
          6.501973 ],
         [6.8816857, 7.331365 , 7.9585133, ..., 7.4408507, 6.645596 ,
          7.789961 ],
         [6.2344813, 6.9750338, 7.0859184, ..., 6.461979 , 7.0840373,
          6.7824125]],

        [[5.157804 , 5.0492544, 6.3488107, ..., 4.6396203, 5.6160116,
          6.3968463],
         [6.204927 , 6.5673246, 7.0573864, ..., 5.4988527, 5.306164 ,
          5.5089493],
         [7.104061 , 7.3269444, 6.97711  , ..., 5.161616 , 5.345628 ,
          6.5845017],
         ...,
         [6.328454 , 7.4622273, 6.851907 , ..., 6.154734 , 6.6091948,
          6.6437707],
         [6.2652664, 7.0809107, 7.2807364, ..., 7.6275997, 7.6195917,
          6.556688 ],
         [6.112903 , 5.683591 , 6.3526325, ..., 6.390601 , 6.3589277,
          6.285007 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  374,   255,   256, ...,   487,   364,   495],
         [  622,   502,   506, ...,   610,   613,   743],
         [  995,  1121,  1375, ...,  1231,  1108,  1486],
         ...,
         [14136, 14266, 14020, ..., 14000, 14254, 14132],
         [14632, 14512, 14765, ..., 14868, 14872, 14754],
         [14881, 15257, 15012, ..., 14993, 14996, 15249]],

        [[  250,   255,   380, ...,   486,   489,   247],
         [  496,   501,   505, ...,   859,   613,   619],
         [  995,  1120,  1251, ...,  1107,  1234,  1236],
         ...,
         [14263, 14265, 13896, ..., 14124, 14004, 14256],
         [14632, 14761, 14641, ..., 14744, 14872, 14878],
         [15006, 15011, 15012, ..., 15364, 14996, 15251]],

        [[  250,   254,   257, ...,   363,   488,   371],
         [  497,   503,   506, ...,   859,   612,   619],
         [  993,  1244,  1251, ...,  1107,  1110,  1237],
         ...,
         [14263, 14266, 14144, ..., 14248, 14254, 14132],
         [14756, 14761, 14641, ..., 14744, 14624, 14878],
         [14883, 15010, 15012, ..., 15364, 15368, 15125]],

        ...,

        [[  251,   255,   381, ...,   486,   489,   371],
         [  746,   501,   504, ...,   735,   860,   619],
         [ 1119,   997,  1374, ...,  1104,  1108,  1484],
         ...,
         [14261, 14142, 14269, ..., 14248, 14129, 14259],
         [14632, 14637, 14641, ..., 14620, 14872, 14878],
         [15130, 14884, 15014, ..., 15365, 14996, 15251]],

        [[  250,   255,   381, ...,   487,   489,   495],
         [  622,   502,   504, ...,   735,   612,   741],
         [ 1119,   996,  1375, ...,  1104,  1108,  1237],
         ...,
         [14263, 14266, 14145, ..., 14000, 14129, 14133],
         [14633, 14512, 14764, ..., 14620, 14872, 14877],
         [15130, 15009, 15012, ..., 15365, 15369, 15251]],

        [[  249,   378,   256, ...,   487,   365,   495],
         [  746,   503,   506, ...,   611,   612,   616],
         [  994,  1120,  1251, ...,  1231,  1109,  1484],
         ...,
         [14136, 14265, 14023, ..., 14372, 14006, 14382],
         [14632, 14636, 14640, ..., 14620, 14872, 14878],
         [15005, 15258, 15260, ..., 14992, 15370, 15249]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[ 8.551404 ,  9.075711 , 10.183338 , ...,  7.4780765,  7.504664 ,
         7.6275997]], shape=(1, 30752), dtype=float32), array([[124201.74, 123932.5 , 123731.24, ..., 124368.14, 123931.2 ,
        124096.99]], shape=(1, 1024), dtype=float32), array([[6.4778988e+07, 6.4752188e+07, 6.3478560e+07, 6.3431992e+07,
        6.2845112e+07, 6.2816280e+07, 6.3186184e+07, 6.2828616e+07,
        6.3401080e+07, 6.4651852e+07, 6.5227608e+07, 6.4261348e+07,
        6.3334416e+07, 6.4268824e+07, 6.4465036e+07, 6.2655608e+07,
        6.3980860e+07, 6.3351448e+07, 6.3707960e+07, 6.3757176e+07,
        6.3070980e+07, 6.3671792e+07, 6.2698872e+07, 6.4065440e+07,
        6.2696552e+07, 6.1894188e+07, 6.1968096e+07, 6.3503116e+07,
        6.3428416e+07, 6.2804112e+07, 6.2969936e+07, 6.4613808e+07,
        6.3114368e+07, 6.1673216e+07, 6.1399904e+07, 6.3363136e+07,
        6.2868064e+07, 6.3285360e+07, 6.3200552e+07, 6.3906976e+07,
        6.3544144e+07, 6.4617608e+07, 6.4236456e+07, 6.4171144e+07,
        6.4546072e+07, 6.5735400e+07, 6.1626588e+07, 6.2508152e+07,
        6.4805248e+07, 6.3258272e+07, 6.2658680e+07, 6.4459128e+07,
        6.2330432e+07, 6.3170896e+07, 6.3915176e+07, 6.5436168e+07,
        6.1975424e+07, 6.2712816e+07, 6.1869632e+07, 6.4432088e+07,
        6.1149184e+07, 6.5277240e+07, 6.4590152e+07, 6.3650064e+07,
        6.5898972e+07, 6.4919588e+07, 6.5656608e+07, 6.1004656e+07,
        6.3604208e+07, 6.3417440e+07, 6.3228576e+07, 6.5363364e+07,
        6.2866740e+07, 6.2075844e+07, 6.3093392e+07, 6.2610008e+07,
        6.4688008e+07, 6.3479416e+07, 6.2789948e+07, 6.3242288e+07,
        6.5709732e+07, 6.2552616e+07, 6.2463848e+07, 6.4400504e+07,
        6.3953224e+07, 6.3333696e+07, 6.0803796e+07, 6.4233304e+07,
        6.2097032e+07, 6.2286620e+07, 6.2913760e+07, 6.2893208e+07,
        6.3360688e+07, 6.2201636e+07, 6.4829096e+07, 6.2872588e+07,
        6.5251840e+07, 6.5018896e+07, 6.3913596e+07, 6.3603968e+07,
        6.6101288e+07, 6.3634424e+07, 6.2189604e+07, 6.3058096e+07,
        6.3686656e+07, 6.3947236e+07, 6.3781296e+07, 6.5288384e+07,
        6.3459504e+07, 6.3316600e+07, 6.3293832e+07, 6.4853800e+07,
        6.2266100e+07, 6.2312656e+07, 6.3572748e+07, 6.4707192e+07,
        6.3457200e+07, 6.2271480e+07, 6.1815512e+07, 6.4171308e+07,
        6.2221656e+07, 6.3431888e+07, 6.3596744e+07, 6.3966296e+07,
        6.3165764e+07, 6.2923784e+07, 6.3093224e+07, 6.1638304e+07]],
      dtype=float32), array([[4.1098030e+09, 3.7085363e+09, 3.6985411e+09, 3.6896095e+09,
        4.0884142e+09, 4.2186007e+09, 4.1056561e+09, 4.1638871e+09,
        4.0401833e+09, 4.2523576e+09]], dtype=float32)]

Profiling

# Second session, this time with profiling enabled; the random feeds built
# above are reused for every run.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

# end_profiling() stops profiling and returns the name of the JSON trace file.
prof = sess.end_profiling()
if js_profile_to_dataframe is not None:
    # first_it_out=True adds the ``it==0`` column (see the printed columns)
    # so the first run can be told apart from the following ones.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Title the figure after the profiled model; the previous hard-coded
    # "dort" label was copied from the custom torch backend example.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.202 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

Gallery generated by Sphinx-Gallery