101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx-extended is optional: its helpers are only used to convert the
# profiling output into a dataframe and to plot it.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Bind both names so that neither is left undefined when the package is
    # missing; the script tests ``js_profile_to_dataframe is not None``
    # before using either of them.
    js_profile_to_dataframe = None
    plot_ort_profile = None

# Default model path: start from the path relative to the working directory
# and, whenever ``__file__`` is defined, anchor it to this script's folder.
filename = "data/example_4700-CPUep-opt.onnx"
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""), *filename.split("/")
    )
except NameError:
    # ``__file__`` is not defined in this execution context:
    # keep the relative path assigned above.
    pass

# Options of this example. The keywords give the default of each option:
# ``filename`` is passed as a (default value, help text) pair.
# NOTE(review): presumably get_parsed_args lets the command line override
# these defaults and ``expose`` selects which options are exposed there --
# confirm against experimental_experiment.args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the option values the script is going to use.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs.

def create_random_input(sess):
    """Build random feeds for every input declared by ``sess``.

    :param sess: object exposing ``get_inputs()``; every returned item must
        provide the attributes ``name``, ``shape`` and ``type``
    :return: dictionary mapping each input name to a numpy array filled
        with uniform samples from ``[0, 1)``
    :raises ValueError: if an input type is not one of the supported
        floating point tensor types
    """
    # Only floating point tensors are supported: uniform samples in [0, 1)
    # cast to an integer type would all be zero.
    dtypes = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }
    feeds = {}
    for i in sess.get_inputs():
        ot = i.type
        if ot not in dtypes:
            raise ValueError(f"Unsupported onnx type {ot}.")
        # A dimension which is not an integer (a symbolic name or None for a
        # dynamic axis) cannot be given to numpy: 1 is used instead.
        shape = [d if isinstance(d, (int, np.integer)) else 1 for d in i.shape]
        # np.random.rand() returns a Python float when the shape is empty,
        # np.asarray turns it into a 0-d array of the expected dtype.
        feeds[i.name] = np.asarray(np.random.rand(*shape), dtype=dtypes[ot])
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session running ``filename`` on CPU.

    :param filename: model given to ``InferenceSession``
    :param profiling: if True, the session is created with session options
        whose ``enable_profiling`` flag is set
    :return: the ``InferenceSession`` instance
    """
    # Imported here so that onnxruntime is only required when a session
    # is actually created.
    from onnxruntime import InferenceSession, SessionOptions

    providers = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=providers)
    return InferenceSession(filename, providers=providers)


# First run without profiling: the session gives the input definitions
# needed to build the random feeds, and the model is executed once on them.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# None requests every output of the model.
sess.run(None, feeds)
[array([[0.09610969, 0.7728384 , 0.75196004, ..., 0.84491324, 0.79128575,
        0.7401204 ],
       [0.9479155 , 0.47235975, 0.5942689 , ..., 0.19155996, 0.488694  ,
        0.34412995],
       [0.14876392, 0.56245506, 0.42884845, ..., 0.9198805 , 0.5506332 ,
        0.13670772],
       ...,
       [0.5250744 , 0.6346656 , 0.05917783, ..., 0.4984336 , 0.67380965,
        0.06971068],
       [0.10134266, 0.26747262, 0.51159984, ..., 0.5425203 , 0.87270105,
        0.7203621 ],
       [0.6009859 , 0.74893504, 0.90190494, ..., 0.8438301 , 0.91150784,
        0.26544774]], shape=(128, 1024), dtype=float32), array([[0.5450286 , 0.05108841, 0.74802214, ..., 0.34395698, 0.09970681,
        0.2158931 ],
       [0.3145486 , 0.5574005 , 0.38632682, ..., 0.65387124, 0.6815179 ,
        0.3011002 ],
       [0.78981787, 0.38738513, 0.67306495, ..., 0.6603452 , 0.34233794,
        0.9149113 ],
       ...,
       [0.47101635, 0.08823352, 0.884104  , ..., 0.50432026, 0.2064694 ,
        0.26791674],
       [0.5662113 , 0.0291094 , 0.19492099, ..., 0.416333  , 0.570644  ,
        0.12336027],
       [0.46335468, 0.7540493 , 0.2930684 , ..., 0.84592444, 0.83217835,
        0.17364559]], shape=(1024, 30752), dtype=float32), array([[0.77160484, 0.44732657, 0.8793495 , ..., 0.13767523, 0.28322092,
        0.38347045],
       [0.50376034, 0.6951072 , 0.34270832, ..., 0.9852718 , 0.59448594,
        0.97335804],
       [0.6232289 , 0.1364311 , 0.9206978 , ..., 0.13986202, 0.686012  ,
        0.40576643],
       ...,
       [0.640396  , 0.49620828, 0.10661684, ..., 0.336249  , 0.5838923 ,
        0.7280593 ],
       [0.5700163 , 0.2603261 , 0.15854801, ..., 0.8565149 , 0.34878585,
        0.24813211],
       [0.5626924 , 0.34012517, 0.9681925 , ..., 0.36979583, 0.6830576 ,
        0.03190738]], shape=(10, 128), dtype=float32), array([[[[4.675813 , 4.8393917, 4.530805 , ..., 6.5850463, 6.2330737,
          5.4244423],
         [4.9031816, 5.4459305, 5.242003 , ..., 5.8558264, 5.349012 ,
          5.2379303],
         [6.1449842, 6.261963 , 5.8365436, ..., 6.8934035, 5.818215 ,
          6.8576436],
         ...,
         [6.31502  , 6.184564 , 5.1498985, ..., 4.7691255, 5.3211126,
          5.9128103],
         [5.7589073, 5.65446  , 5.1514096, ..., 5.3978014, 5.7793846,
          5.7055755],
         [6.150046 , 6.0458603, 5.7789774, ..., 5.6919484, 4.9626207,
          5.324859 ]],

        [[5.23308  , 6.6074586, 5.5526505, ..., 6.7088737, 6.4515886,
          6.605689 ],
         [5.907752 , 7.2733054, 6.4803004, ..., 7.2436705, 7.625379 ,
          6.598878 ],
         [5.995107 , 6.6457796, 7.346325 , ..., 7.8733835, 7.744594 ,
          7.3948293],
         ...,
         [7.216176 , 7.8466363, 7.0026503, ..., 6.7702537, 6.845193 ,
          6.908416 ],
         [7.3323965, 7.7593193, 6.0310626, ..., 6.0389333, 6.5763497,
          6.9780807],
         [7.755426 , 7.8468566, 6.6688113, ..., 6.5624204, 6.388437 ,
          6.630726 ]],

        [[6.1693115, 6.3360524, 5.7711062, ..., 7.301279 , 7.476224 ,
          7.1803865],
         [5.6562405, 6.940252 , 6.729952 , ..., 7.276113 , 7.164065 ,
          6.4023805],
         [6.6390224, 7.4007125, 7.504481 , ..., 7.2014318, 7.6115255,
          7.6915755],
         ...,
         [8.021874 , 7.5101   , 7.514248 , ..., 6.454876 , 6.2781773,
          6.844224 ],
         [7.1982107, 7.663066 , 6.7969904, ..., 5.8152018, 6.313769 ,
          6.8283544],
         [6.97674  , 7.514856 , 7.1654916, ..., 6.6485343, 6.0571275,
          6.1950116]],

        ...,

        [[5.250869 , 6.19024  , 5.323181 , ..., 5.994178 , 6.261137 ,
          6.1661167],
         [5.456216 , 5.6767855, 5.1586323, ..., 6.633404 , 5.7539883,
          5.8595333],
         [6.2778807, 6.055975 , 6.7682457, ..., 6.6117325, 6.6505585,
          5.7998695],
         ...,
         [6.555544 , 7.179186 , 6.5355883, ..., 5.6287804, 6.0172253,
          4.982135 ],
         [6.309924 , 6.5869055, 5.8591833, ..., 5.7681103, 5.536516 ,
          5.5897274],
         [5.7137356, 6.009883 , 6.0124655, ..., 5.9652815, 5.05702  ,
          6.3460093]],

        [[6.806348 , 7.5467434, 6.6534033, ..., 8.538446 , 7.9266596,
          8.176761 ],
         [7.7786427, 7.4173274, 7.744265 , ..., 8.736114 , 8.012027 ,
          8.194593 ],
         [8.307279 , 8.539384 , 8.796736 , ..., 9.735919 , 8.960765 ,
          8.6386175],
         ...,
         [8.377878 , 9.7039995, 8.404376 , ..., 7.522732 , 8.194346 ,
          7.623464 ],
         [8.239339 , 9.0724745, 7.5216002, ..., 7.866619 , 7.4154396,
          8.1689825],
         [8.725547 , 8.374278 , 8.149051 , ..., 8.232766 , 7.010842 ,
          7.2677546]],

        [[7.031036 , 6.4757147, 6.307449 , ..., 8.507093 , 8.194411 ,
          7.3640084],
         [7.44024  , 7.726051 , 6.5270677, ..., 7.9953556, 8.34257  ,
          7.2676883],
         [7.387988 , 9.045776 , 8.43912  , ..., 8.32619  , 8.143594 ,
          7.542256 ],
         ...,
         [8.217804 , 8.8999195, 8.855006 , ..., 6.9249506, 7.4461083,
          8.200155 ],
         [7.690132 , 8.070051 , 7.920203 , ..., 7.6894836, 6.5439115,
          7.548409 ],
         [8.138238 , 8.376512 , 7.9650464, ..., 6.967308 , 7.0520144,
          6.8442097]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  374,   131,     9, ...,   237,   367,   369],
         [  746,   748,   630, ...,   733,   615,   616],
         [ 1364,  1245,  1127, ...,  1477,  1109,  1360],
         ...,
         [14263, 14266, 14144, ..., 14000, 14007, 14257],
         [14511, 14512, 14767, ..., 14622, 14624, 14752],
         [14880, 14884, 15015, ..., 14992, 14998, 15127]],

        [[  373,     7,   132, ...,   485,   491,   369],
         [  498,   874,   755, ...,   857,   615,   616],
         [ 1242,  1369,  1375, ...,  1477,  1481,  1113],
         ...,
         [14263, 14264, 14268, ..., 14002, 14007, 14381],
         [14759, 14636, 14640, ..., 14746, 14749, 14504],
         [14883, 15008, 15015, ..., 15240, 14996, 15002]],

        [[  374,     7,   132, ...,   485,   366,   371],
         [  498,   500,   755, ...,   858,   615,   616],
         [ 1242,  1369,  1375, ...,  1477,  1110,  1113],
         ...,
         [14263, 14264, 14268, ..., 14002, 14006, 14008],
         [14759, 14512, 14519, ..., 14746, 14624, 14505],
         [15004, 15008, 15263, ..., 14992, 14999, 15002]],

        ...,

        [[  373,   131,   133, ...,   112,   367,   370],
         [  622,   873,   755, ...,   857,   860,   867],
         [ 1242,  1244,  1127, ...,  1477,  1481,  1112],
         ...,
         [14139, 14265, 14144, ..., 14373, 14007, 14008],
         [14759, 14636, 14767, ..., 14621, 14749, 14505],
         [15005, 15256, 15263, ..., 15240, 14999, 15375]],

        [[  373,   131,   132, ...,   112,   367,   369],
         [  746,   624,   631, ...,   733,   615,   617],
         [ 1366,  1246,  1127, ...,  1476,  1110,  1114],
         ...,
         [14263, 14265, 14144, ..., 14003, 14007, 14008],
         [14759, 14636, 14518, ..., 14621, 14872, 14628],
         [15005, 15009, 15139, ..., 15240, 14999, 15373]],

        [[  249,   131,   382, ...,   112,   366,   368],
         [  746,   749,   507, ...,   859,   614,   616],
         [ 1240,  1246,  1127, ...,  1228,  1109,  1113],
         ...,
         [14262, 14141, 14145, ..., 14002, 14004, 14257],
         [14387, 14388, 14766, ..., 14621, 14748, 14628],
         [15005, 15008, 15139, ..., 15117, 14999, 15127]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[6.7413177, 6.5252485, 6.6804986, ..., 7.863183 , 8.615847 ,
        8.200155 ]], shape=(1, 30752), dtype=float32), array([[118367.37 , 119177.234, 118987.086, ..., 119547.74 , 118254.24 ,
        118718.13 ]], shape=(1, 1024), dtype=float32), array([[6.4128168e+07, 5.8999764e+07, 5.9586596e+07, 6.0295636e+07,
        6.1974340e+07, 6.0698212e+07, 6.1241008e+07, 6.1970800e+07,
        6.1567644e+07, 6.1991200e+07, 6.1572856e+07, 6.1601976e+07,
        6.1340664e+07, 6.0645956e+07, 5.8501168e+07, 6.0780040e+07,
        6.0628152e+07, 5.9184696e+07, 6.1584416e+07, 6.0284976e+07,
        6.0002884e+07, 6.1890352e+07, 6.1937068e+07, 6.1375384e+07,
        6.2477452e+07, 6.2767800e+07, 6.0903304e+07, 6.0515528e+07,
        6.2023640e+07, 6.0238416e+07, 5.9209660e+07, 5.9731120e+07,
        6.0374568e+07, 5.9482168e+07, 6.0115768e+07, 6.1582104e+07,
        6.1052744e+07, 6.0998608e+07, 6.1180288e+07, 5.9299032e+07,
        6.0731440e+07, 5.9612764e+07, 5.9659656e+07, 5.9558632e+07,
        6.0549936e+07, 6.1873928e+07, 5.6873096e+07, 6.1038344e+07,
        6.2429520e+07, 6.1385976e+07, 5.9541248e+07, 5.9505952e+07,
        6.0852816e+07, 6.1084104e+07, 6.0501532e+07, 6.1744480e+07,
        5.9608024e+07, 5.8768036e+07, 6.0296912e+07, 5.9994420e+07,
        5.8566616e+07, 6.0935168e+07, 6.1183752e+07, 6.1137880e+07,
        6.3117792e+07, 6.0234340e+07, 6.0272708e+07, 6.0511352e+07,
        6.1614240e+07, 6.0392688e+07, 6.1473696e+07, 6.0956428e+07,
        5.9723384e+07, 6.1914472e+07, 5.9426000e+07, 6.1341472e+07,
        5.9775700e+07, 6.1335480e+07, 6.0282920e+07, 6.0643136e+07,
        6.1424400e+07, 5.9674412e+07, 6.0923216e+07, 5.9795080e+07,
        5.8237068e+07, 6.0031240e+07, 6.0288172e+07, 6.1723260e+07,
        6.0229552e+07, 5.9320568e+07, 6.1351368e+07, 6.0019828e+07,
        6.2556932e+07, 5.9816776e+07, 6.0898860e+07, 6.0151792e+07,
        6.2236608e+07, 6.1135512e+07, 6.1322792e+07, 5.9843996e+07,
        6.1660856e+07, 5.9824032e+07, 6.1159356e+07, 5.8991304e+07,
        5.9940408e+07, 6.2538884e+07, 6.0827760e+07, 6.1265472e+07,
        6.1207960e+07, 6.1216996e+07, 6.1483088e+07, 6.1859440e+07,
        5.9887332e+07, 6.2837212e+07, 6.2696132e+07, 6.1264368e+07,
        6.0190904e+07, 6.0846368e+07, 6.1244204e+07, 6.0282240e+07,
        5.9752948e+07, 5.9783396e+07, 5.9826720e+07, 6.1863904e+07,
        6.1392704e+07, 6.0093600e+07, 5.9726020e+07, 5.9484752e+07]],
      dtype=float32), array([[3.8916920e+09, 4.4963220e+09, 3.7605591e+09, 3.7983089e+09,
        3.7505807e+09, 3.8865334e+09, 4.0423076e+09, 4.2628155e+09,
        3.9225544e+09, 3.7468856e+09]], dtype=float32)]

Profiling

# A new session is created with profiling enabled and the model is run
# ``repeat`` times on the feeds built earlier.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

# The value returned by end_profiling() is what js_profile_to_dataframe
# receives to build the dataframe.
prof = sess.end_profiling()
if js_profile_to_dataframe is None:
    # The optional dependency is missing: nothing can be converted or drawn.
    print("Install onnx-extended first.")
else:
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)

    # The raw profiling data is saved in two formats.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two graphs side by side, both drawn by plot_ort_profile.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.231 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

102: Tweak onnx export

102: Tweak onnx export

101: Onnx Model Rewriting

101: Onnx Model Rewriting

Gallery generated by Sphinx-Gallery