101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx_extended is treated as optional: its helpers are only used further
# down to post-process and plot the profiling results.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    # Sentinel tested before post-processing the profile. plot_ort_profile is
    # left undefined here; it is only called when this sentinel is not None.
    js_profile_to_dataframe = None

# Default model to profile: data/example_4700-CPUep-opt.onnx, located
# relative to the directory of this script when __file__ is available.
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""), "data", "example_4700-CPUep-opt.onnx"
    )
except NameError:
    # __file__ is not defined in this execution context: use a path relative
    # to the current working directory instead.
    filename = "data/example_4700-CPUep-opt.onnx"

# Script parameters with their default values.
# NOTE(review): presumably get_parsed_args lets these be overridden from the
# command line and `expose` selects which ones are exposed -- confirm against
# experimental_experiment.args.get_parsed_args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the parameters actually used for this run.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Create a session and feed it random inputs.

def create_random_input(sess, dynamic_dim=1):
    """
    Creates random inputs for every input of an inference session.

    :param sess: object exposing ``get_inputs()``, every returned item
        must have the attributes ``name``, ``shape`` and ``type``
        (as onnxruntime's ``InferenceSession`` does)
    :param dynamic_dim: size used for every dimension which is not an
        integer (symbolic or unknown dimension)
    :return: dictionary ``{input name: numpy array}``, values are drawn
        uniformly in ``[0, 1)`` and cast to the input element type
    :raises ValueError: if an input has an unsupported element type
    """
    # Floating point types only: casting uniform [0, 1) samples to an
    # integer type would produce nothing but zeros.
    np_dtypes = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for inp in sess.get_inputs():
        dtype = np_dtypes.get(inp.type)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        # Symbolic or unknown dimensions are not integers (names or None),
        # they are replaced by a concrete size.
        shape = tuple(
            d if isinstance(d, int) else dynamic_dim for d in inp.shape
        )
        # np.random.rand() returns a python float for an empty shape,
        # np.asarray makes sure the result is always an array.
        feeds[inp.name] = np.asarray(np.random.rand(*shape), dtype=dtype)
    return feeds


def create_session(filename, profiling=False):
    """
    Builds an onnxruntime inference session running on CPU.

    :param filename: model given to ``InferenceSession``
    :param profiling: if True, the session is created with
        ``SessionOptions.enable_profiling`` switched on
    :return: the ``InferenceSession``
    """
    # onnxruntime is imported lazily, only when a session is requested.
    import onnxruntime

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = onnxruntime.SessionOptions()
        options.enable_profiling = True
        return onnxruntime.InferenceSession(filename, options, providers=cpu_only)
    return onnxruntime.InferenceSession(filename, providers=cpu_only)


# First run without profiling: builds the random feeds reused by the
# profiling section below and checks the model runs with them.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# Last expression of the cell: the outputs returned here are what the
# gallery displays underneath.
sess.run(None, feeds)
[array([[0.510395  , 0.09870491, 0.75266755, ..., 0.5195775 , 0.57707196,
        0.7490044 ],
       [0.8236655 , 0.46533   , 0.9431352 , ..., 0.5351355 , 0.01673928,
        0.48355934],
       [0.21692038, 0.3470374 , 0.4777763 , ..., 0.8625944 , 0.74289674,
        0.32841423],
       ...,
       [0.8921656 , 0.13616458, 0.46435517, ..., 0.8052121 , 0.65348035,
        0.2504524 ],
       [0.01860473, 0.39064184, 0.87167305, ..., 0.43814293, 0.869365  ,
        0.73878515],
       [0.6230802 , 0.38814384, 0.11948249, ..., 0.4887564 , 0.21102405,
        0.73262644]], shape=(128, 1024), dtype=float32), array([[0.540234  , 0.24377689, 0.59050584, ..., 0.7471279 , 0.5253918 ,
        0.4479876 ],
       [0.53054404, 0.89304227, 0.8826398 , ..., 0.60216653, 0.12245536,
        0.8086819 ],
       [0.08696877, 0.45958072, 0.94682866, ..., 0.3839614 , 0.31130317,
        0.6475578 ],
       ...,
       [0.9040087 , 0.83760476, 0.42550895, ..., 0.98305   , 0.6565117 ,
        0.96121   ],
       [0.72511977, 0.28390366, 0.6875673 , ..., 0.94175047, 0.03937987,
        0.7381576 ],
       [0.35966566, 0.93998736, 0.08909597, ..., 0.413577  , 0.36816013,
        0.20294131]], shape=(1024, 30752), dtype=float32), array([[0.85992175, 0.40893006, 0.91688865, ..., 0.77127737, 0.61978394,
        0.1234277 ],
       [0.79496896, 0.02235206, 0.71498513, ..., 0.55956376, 0.99967736,
        0.00727563],
       [0.9504341 , 0.08722831, 0.1550522 , ..., 0.5807844 , 0.05640255,
        0.04382119],
       ...,
       [0.91768384, 0.65331656, 0.8954979 , ..., 0.11000395, 0.6667147 ,
        0.5823768 ],
       [0.02742839, 0.8250603 , 0.6195953 , ..., 0.63641924, 0.47468057,
        0.4685825 ],
       [0.2796416 , 0.5099486 , 0.2765593 , ..., 0.00334817, 0.27891853,
        0.03767961]], shape=(10, 128), dtype=float32), array([[[[5.806987 , 5.72013  , 6.157274 , ..., 6.0545344, 6.0556784,
          4.8173094],
         [5.006063 , 5.585447 , 5.7751427, ..., 7.25757  , 5.9223084,
          6.1376605],
         [5.662475 , 5.541954 , 5.4569254, ..., 7.0972886, 6.890184 ,
          5.921923 ],
         ...,
         [7.3290105, 6.1417856, 6.660156 , ..., 6.4197974, 4.817443 ,
          6.5680733],
         [6.981306 , 5.95116  , 6.1571884, ..., 4.5505958, 4.821188 ,
          5.9442015],
         [6.3763127, 6.270368 , 5.9747205, ..., 5.8328056, 4.9997   ,
          6.6118903]],

        [[7.573307 , 6.5594516, 6.727903 , ..., 7.9036674, 6.7875094,
          6.2193174],
         [6.427677 , 6.47501  , 6.579662 , ..., 8.290004 , 7.887871 ,
          6.4572086],
         [6.797099 , 6.7822533, 6.449536 , ..., 7.9760294, 7.972854 ,
          7.420871 ],
         ...,
         [8.562127 , 8.30687  , 8.352249 , ..., 6.1537843, 6.6855364,
          6.839553 ],
         [8.273453 , 8.019238 , 7.298452 , ..., 5.799291 , 5.9442134,
          6.7026215],
         [7.6595435, 6.855773 , 7.2733693, ..., 5.8534894, 5.257635 ,
          7.074494 ]],

        [[6.09562  , 6.7099795, 6.398712 , ..., 6.0222287, 5.7456374,
          5.703014 ],
         [5.0917773, 6.155957 , 5.435502 , ..., 7.1121435, 6.876064 ,
          6.320275 ],
         [6.7134843, 5.6703215, 5.376893 , ..., 7.0713973, 7.6736345,
          6.6342244],
         ...,
         [7.2245374, 7.2838416, 7.4497914, ..., 5.919186 , 6.2245483,
          6.95924  ],
         [7.426585 , 7.0253086, 6.164856 , ..., 5.324277 , 5.532171 ,
          6.470504 ],
         [6.863257 , 6.337329 , 5.7022367, ..., 6.1614037, 5.120746 ,
          6.5681777]],

        ...,

        [[7.062579 , 7.1482577, 6.780678 , ..., 7.158068 , 7.3833985,
          6.317856 ],
         [6.6447344, 6.860033 , 6.8069925, ..., 8.233599 , 7.6687574,
          7.788963 ],
         [6.79121  , 6.6062245, 6.7519684, ..., 8.406888 , 8.140712 ,
          7.4982758],
         ...,
         [8.328917 , 8.669029 , 7.8935614, ..., 6.456511 , 6.5150266,
          7.0281615],
         [7.913402 , 7.1996055, 7.08994  , ..., 5.751628 , 5.792744 ,
          7.0866942],
         [8.275266 , 7.7689857, 6.9132953, ..., 6.9814863, 6.5053353,
          8.01444  ]],

        [[7.197398 , 6.4248443, 6.1547756, ..., 6.64271  , 6.639652 ,
          5.758892 ],
         [6.9436264, 5.676034 , 5.6722555, ..., 8.202153 , 7.6702285,
          6.667065 ],
         [7.3901343, 6.4414344, 6.5813518, ..., 7.9361444, 7.8210273,
          7.4510965],
         ...,
         [7.976137 , 7.1362987, 6.7917104, ..., 7.2537456, 6.6555095,
          7.178766 ],
         [7.793176 , 6.9324927, 7.0725803, ..., 5.5457215, 5.4725423,
          6.641788 ],
         [8.095506 , 7.33383  , 6.4648485, ..., 6.4031425, 5.2241807,
          6.689626 ]],

        [[7.427603 , 6.369605 , 7.0653505, ..., 7.7532473, 7.7680655,
          5.86767  ],
         [6.354887 , 6.89118  , 6.5639534, ..., 8.676196 , 7.833905 ,
          7.8430805],
         [7.0764723, 7.4269753, 5.7521772, ..., 8.7820015, 8.515766 ,
          7.6656866],
         ...,
         [9.665207 , 8.070228 , 7.947341 , ..., 6.6797304, 6.740755 ,
          7.251953 ],
         [8.025931 , 7.7180023, 7.346319 , ..., 6.4853144, 5.9467134,
          7.3486214],
         [8.240556 , 7.143462 , 6.993849 , ..., 6.2916117, 6.252305 ,
          8.599945 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[    2,   128,   134, ...,   485,   118,   493],
         [  870,   872,   878, ...,   857,   987,   865],
         [  993,   999,  1124, ...,  1105,  1232,  1360],
         ...,
         [13888, 14019, 14271, ..., 14372, 14006, 14009],
         [14632, 14515, 14519, ..., 14868, 14751, 14630],
         [14880, 14884, 15261, ..., 15242, 15368, 15375]],

        [[    0,     6,   258, ...,   236,   118,   495],
         [  870,   872,   878, ...,   733,   984,   989],
         [  995,  1120,  1003, ...,  1352,  1480,  1115],
         ...,
         [13891, 14266, 14271, ..., 14374, 14007, 14382],
         [14508, 14515, 14394, ..., 14496, 14500, 14631],
         [14880, 14884, 15260, ..., 15365, 15244, 15372]],

        [[  248,     5,   134, ...,   237,   118,   494],
         [  869,   875,   876, ...,   733,   987,   617],
         [ 1117,   999,  1124, ...,  1105,  1233,  1114],
         ...,
         [13889, 13894, 14269, ..., 14373, 14006, 14008],
         [14756, 14515, 14394, ..., 14622, 14503, 14753],
         [14880, 15259, 15261, ..., 15364, 15244, 15127]],

        ...,

        [[    3,     5,    11, ...,   485,   118,   494],
         [  869,   751,   879, ...,   981,   737,   866],
         [ 1117,   999,  1248, ...,  1352,  1235,  1113],
         ...,
         [14137, 14017, 14268, ..., 14374, 14005, 14008],
         [14384, 14515, 14395, ..., 14622, 14875, 14629],
         [15005, 14884, 15261, ..., 15364, 14998, 15375]],

        [[  248,   129,   382, ...,   239,   119,   245],
         [  870,   872,   878, ...,   609,   737,   866],
         [ 1118,   996,  1248, ...,  1352,  1233,  1114],
         ...,
         [13888, 13894, 14271, ..., 14374, 14007, 14010],
         [14384, 14514, 14516, ..., 14868, 14500, 14631],
         [14880, 14884, 15261, ..., 15242, 15244, 15125]],

        [[    0,     5,   258, ...,   112,   118,   494],
         [  869,   872,   878, ...,   857,   985,   990],
         [  995,  1120,  1002, ...,  1352,  1109,  1115],
         ...,
         [14012, 14142, 14270, ..., 14374, 14007, 14008],
         [14756, 14638, 14394, ..., 14499, 14501, 14877],
         [15004, 14884, 15263, ..., 15366, 15245, 15375]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[6.157274 , 6.732797 , 6.8466873, ..., 9.558445 , 7.943347 ,
        8.599945 ]], shape=(1, 30752), dtype=float32), array([[122931.21 , 122812.01 , 122660.01 , ..., 122802.49 , 121542.086,
        122556.76 ]], shape=(1, 1024), dtype=float32), array([[6.2685740e+07, 6.3410160e+07, 6.2390724e+07, 6.3186020e+07,
        6.4175284e+07, 6.2694756e+07, 6.4098784e+07, 6.2466672e+07,
        6.2496920e+07, 6.2552360e+07, 6.3269496e+07, 6.2899440e+07,
        6.2978024e+07, 6.4169028e+07, 6.2707548e+07, 6.4332960e+07,
        6.3459072e+07, 6.2106776e+07, 6.4202848e+07, 6.3429140e+07,
        6.3302292e+07, 6.1958276e+07, 6.1664368e+07, 6.1055936e+07,
        6.3270112e+07, 6.1401068e+07, 6.2135536e+07, 6.2427304e+07,
        6.0256600e+07, 6.3209992e+07, 6.0593304e+07, 6.0781636e+07,
        6.1482384e+07, 6.1717400e+07, 6.2610832e+07, 6.3769260e+07,
        6.1473264e+07, 6.1505804e+07, 6.3947816e+07, 6.4301336e+07,
        6.4167992e+07, 6.2819408e+07, 6.5220052e+07, 6.1881272e+07,
        6.2952696e+07, 6.3190024e+07, 6.2599972e+07, 6.4674996e+07,
        6.1654040e+07, 6.2040848e+07, 6.5397256e+07, 6.2409552e+07,
        6.0697624e+07, 6.3951028e+07, 6.3041340e+07, 6.4684656e+07,
        6.3020680e+07, 6.3066268e+07, 6.2280680e+07, 6.2135076e+07,
        6.2253896e+07, 6.1113460e+07, 6.2437828e+07, 6.3035072e+07,
        6.4578888e+07, 6.3574296e+07, 6.3139376e+07, 6.5137860e+07,
        6.2854876e+07, 6.1501360e+07, 6.7199296e+07, 6.0134952e+07,
        6.3340264e+07, 6.3838528e+07, 6.1254408e+07, 6.3398144e+07,
        6.2998444e+07, 6.1963864e+07, 6.2621528e+07, 6.0110416e+07,
        6.1974800e+07, 6.3285376e+07, 6.1449940e+07, 6.3292648e+07,
        6.0577176e+07, 6.3333320e+07, 6.3513912e+07, 6.1636444e+07,
        6.2994152e+07, 6.3698160e+07, 6.2220116e+07, 6.4425136e+07,
        6.3645176e+07, 6.1214688e+07, 6.4143920e+07, 6.4210424e+07,
        6.0982224e+07, 6.2320652e+07, 6.3713752e+07, 6.5227760e+07,
        6.3964752e+07, 6.0706800e+07, 6.4768920e+07, 6.2442664e+07,
        6.2656032e+07, 6.4456688e+07, 6.0053592e+07, 6.3901548e+07,
        6.1754392e+07, 6.2550728e+07, 6.5093520e+07, 6.3075572e+07,
        6.4294520e+07, 6.5430116e+07, 6.3140252e+07, 6.2226720e+07,
        6.3469724e+07, 6.3753616e+07, 6.1655140e+07, 6.1167440e+07,
        6.2269368e+07, 6.3205008e+07, 6.3155536e+07, 6.2887120e+07,
        6.2112636e+07, 6.2542400e+07, 6.3138648e+07, 6.1983700e+07]],
      dtype=float32), array([[4.0259461e+09, 4.0417464e+09, 3.9178358e+09, 3.7554437e+09,
        4.1699364e+09, 4.0457357e+09, 3.7798787e+09, 3.9954289e+09,
        4.5021348e+09, 3.9691351e+09]], dtype=float32)]

Profiling

# New session with profiling enabled, fed with the same random inputs.
sess = create_session(script_args.filename, profiling=True)

# Run the model `repeat` times so the profile holds several iterations.
for _ in range(script_args.repeat):
    sess.run(None, feeds)

# Stop profiling; the returned value is what js_profile_to_dataframe consumes
# (onnxruntime documents it as the name of the profile file it wrote).
prof = sess.end_profiling()
if js_profile_to_dataframe is not None:
    # NOTE(review): first_it_out=True presumably marks the first iteration
    # apart from the others (see the 'it==0' column) -- confirm in onnx-extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    # Keep the raw profiling data next to the figure.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # One chart per axis, "dort" is the title passed to the plotting helper.
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    # onnx-extended could not be imported: the profile is not post-processed.
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='str')

Total running time of the script: (0 minutes 1.720 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

101: Some dummy examples with torch.export.export

101: Some dummy examples with torch.export.export

Gallery generated by Sphinx-Gallery