101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

# onnx_extended is optional: it is only used to turn the onnxruntime profiling
# trace into a DataFrame and to plot it. When it is missing, both helpers are
# set to None so later code can check availability consistently.
try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None
    plot_ort_profile = None

# Default model: the ``data`` folder next to this script. ``__file__`` is not
# defined when the code is executed interactively (e.g. in a notebook); the
# path is then taken relative to the current working directory.
if "__file__" in globals():
    filename = os.path.join(
        os.path.dirname(__file__ or ""), "data", "example_4700-CPUep-opt.onnx"
    )
else:
    filename = "data/example_4700-CPUep-opt.onnx"

# Script settings: the model to profile and how many profiled runs to do.
# NOTE(review): presumably these defaults can be overridden from the command
# line by get_parsed_args — confirm in experimental_experiment.args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective settings, one ``name=value`` line each.
print(
    "\n".join(
        f"{name}={getattr(script_args, name)}" for name in ("filename", "repeat")
    )
)
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs: build a feed of random tensors matching the model's inputs.

def create_random_input(sess, dynamic_dim=1):
    """Build a feed dictionary of random tensors matching the session inputs.

    :param sess: an onnxruntime ``InferenceSession`` (any object exposing
        ``get_inputs()`` whose items have ``name``, ``shape`` and ``type``)
    :param dynamic_dim: size used for every dimension that is not a fixed
        integer (symbolic dimension names or ``None`` in the model signature)
    :return: dictionary ``{input name: numpy array}``; floating inputs are
        uniform in [0, 1), integer inputs are uniform in [0, 10)
    :raises ValueError: if an input has an element type not handled here
    """
    float_types = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
    }
    int_types = {
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
    }
    feeds = {}
    for i in sess.get_inputs():
        # Dynamic dimensions are not integers (symbolic name or None):
        # numpy needs concrete sizes, so substitute ``dynamic_dim``.
        shape = tuple(d if isinstance(d, int) else dynamic_dim for d in i.shape)
        ot = i.type
        if ot in float_types:
            # np.asarray keeps scalar inputs (shape ()) working: np.random.rand()
            # with no argument returns a Python float, not an array.
            t = np.asarray(np.random.rand(*shape), dtype=float_types[ot])
        elif ot in int_types:
            t = np.asarray(np.random.randint(0, 10, size=shape), dtype=int_types[ot])
        else:
            raise ValueError(f"Unsupported onnx type {ot}.")
        feeds[i.name] = t
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session restricted to the CPU execution provider.

    :param filename: path to the ONNX model
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` set, so that ``end_profiling()``
        can be called on it later
    :return: ``onnxruntime.InferenceSession``
    """
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)


# Sanity check before profiling: build a plain (non-profiling) session,
# generate random feeds from its input signature and run the model once.
# The same ``feeds`` are reused by the profiling section.
# NOTE(review): the bare ``sess.run(...)`` expression is presumably left as the
# last statement so the gallery renders its result; keep it unassigned.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.18355778, 0.76602465, 0.6846948 , ..., 0.81538266, 0.48825163,
        0.8215927 ],
       [0.8380143 , 0.82199407, 0.91720724, ..., 0.42312086, 0.7012276 ,
        0.85677046],
       [0.03198922, 0.73179334, 0.497972  , ..., 0.05410693, 0.4504848 ,
        0.75826985],
       ...,
       [0.35802746, 0.31844774, 0.24433039, ..., 0.42168942, 0.786901  ,
        0.57012194],
       [0.44429383, 0.91724664, 0.87534624, ..., 0.7951918 , 0.35054106,
        0.2510082 ],
       [0.00353544, 0.9529619 , 0.12721317, ..., 0.47134268, 0.9127621 ,
        0.18090533]], shape=(128, 1024), dtype=float32), array([[0.45642313, 0.97663534, 0.44507095, ..., 0.17866996, 0.40479878,
        0.3248754 ],
       [0.7451527 , 0.32974935, 0.6402123 , ..., 0.04206786, 0.02783561,
        0.14506415],
       [0.76905596, 0.34797153, 0.5862115 , ..., 0.05950727, 0.6474196 ,
        0.61212987],
       ...,
       [0.7462301 , 0.61285716, 0.97019523, ..., 0.7958755 , 0.08900838,
        0.48395354],
       [0.02824485, 0.6900251 , 0.8272219 , ..., 0.9637669 , 0.68942076,
        0.56973773],
       [0.6575323 , 0.83865434, 0.7762486 , ..., 0.62679636, 0.05806791,
        0.08949406]], shape=(1024, 30752), dtype=float32), array([[0.4156681 , 0.8852881 , 0.82216066, ..., 0.59876835, 0.9890533 ,
        0.64269114],
       [0.7018817 , 0.7474065 , 0.14376931, ..., 0.28401065, 0.2374286 ,
        0.02105127],
       [0.4993807 , 0.14813603, 0.60560733, ..., 0.09373682, 0.31881204,
        0.59701616],
       ...,
       [0.4337615 , 0.8195667 , 0.52077234, ..., 0.28911933, 0.2622677 ,
        0.5486428 ],
       [0.66053134, 0.44122005, 0.8410543 , ..., 0.5868373 , 0.5415453 ,
        0.9791288 ],
       [0.56729203, 0.97646445, 0.9090489 , ..., 0.46370596, 0.9985369 ,
        0.9798734 ]], shape=(10, 128), dtype=float32), array([[[[ 6.07834  ,  6.9498243,  6.860451 , ...,  6.2489305,
           6.2862563,  6.7256274],
         [ 5.921294 ,  5.088063 ,  5.9590945, ...,  6.1864996,
           6.4864893,  6.445052 ],
         [ 6.229443 ,  6.221809 ,  6.837107 , ...,  6.673263 ,
           7.107069 ,  5.732827 ],
         ...,
         [ 7.924215 ,  7.702151 ,  8.286121 , ...,  6.619735 ,
           7.1034517,  6.7472777],
         [ 8.172043 ,  8.223117 ,  7.5010796, ...,  6.6054106,
           6.514733 ,  6.607249 ],
         [ 8.234721 ,  7.6824584,  7.5782495, ...,  7.1097865,
           6.588666 ,  6.167768 ]],

        [[ 4.816602 ,  4.980968 ,  5.2392216, ...,  4.2774105,
           4.963634 ,  5.6008587],
         [ 4.3513794,  4.6608553,  4.400864 , ...,  5.190287 ,
           4.7333574,  4.9007587],
         [ 3.716007 ,  4.7814445,  4.8018336, ...,  4.9781647,
           5.1291504,  5.112105 ],
         ...,
         [ 5.5040097,  6.288353 ,  5.96729  , ...,  5.183324 ,
           4.977249 ,  5.3500094],
         [ 6.1326423,  5.8048773,  6.0327473, ...,  4.91949  ,
           4.7221417,  3.8280673],
         [ 6.105724 ,  5.871759 ,  5.0002174, ...,  4.562823 ,
           4.7435555,  4.44974  ]],

        [[ 5.53453  ,  5.794347 ,  6.2370534, ...,  5.161731 ,
           5.707613 ,  6.292285 ],
         [ 5.318179 ,  5.586127 ,  5.1728587, ...,  5.939946 ,
           6.1873393,  5.5668015],
         [ 5.7686415,  6.2614045,  6.0903535, ...,  6.001094 ,
           6.1330996,  6.294871 ],
         ...,
         [ 6.554962 ,  7.35803  ,  7.0087237, ...,  6.5292606,
           5.8693843,  6.15562  ],
         [ 7.457148 ,  7.4793177,  7.0415397, ...,  5.5478067,
           5.611039 ,  5.2644186],
         [ 7.1043415,  7.0413504,  6.655721 , ...,  5.4384737,
           6.1681967,  5.869578 ]],

        ...,

        [[ 7.1980147,  7.4931073,  7.642072 , ...,  6.556544 ,
           6.066242 ,  7.6273203],
         [ 5.922835 ,  7.3984213,  6.640364 , ...,  7.8946056,
           8.098768 ,  7.6947145],
         [ 7.808523 ,  8.190684 ,  7.591018 , ...,  7.269128 ,
           8.425458 ,  7.6608944],
         ...,
         [ 8.021083 ,  8.460705 ,  9.2485895, ...,  7.5667057,
           8.170313 ,  7.944552 ],
         [ 9.075381 ,  8.964181 ,  8.893072 , ...,  7.1060567,
           7.391473 ,  6.6333447],
         [ 8.8052435,  9.684732 ,  7.981792 , ...,  6.981919 ,
           6.945111 ,  7.851127 ]],

        [[ 6.298741 ,  5.891486 ,  7.373604 , ...,  5.9188366,
           6.3638263,  6.752047 ],
         [ 5.65589  ,  6.512294 ,  6.267812 , ...,  6.6380067,
           6.721056 ,  7.0125165],
         [ 6.279974 ,  7.7126894,  7.7464128, ...,  6.673944 ,
           8.247188 ,  7.128787 ],
         ...,
         [ 7.7832313,  8.129161 ,  7.890993 , ...,  6.6077604,
           7.6892266,  7.4518147],
         [ 8.370986 ,  8.25616  ,  7.6481733, ...,  7.224927 ,
           5.9082007,  7.0773153],
         [ 8.270368 ,  7.39873  ,  7.473835 , ...,  6.827078 ,
           7.0893016,  5.7405863]],

        [[ 7.668147 ,  8.556849 ,  8.313988 , ...,  6.982935 ,
           8.241359 ,  8.147688 ],
         [ 7.324009 ,  7.1250477,  7.237815 , ...,  7.9485464,
           8.406957 ,  8.741947 ],
         [ 8.189062 ,  7.6091394,  8.813858 , ...,  7.4599285,
           9.423107 ,  7.2442865],
         ...,
         [10.089016 ,  9.962808 ,  9.380581 , ...,  8.365626 ,
           8.823648 ,  9.186967 ],
         [ 9.522221 , 10.118403 ,  9.132392 , ...,  8.858189 ,
           7.239127 ,  7.566367 ],
         [10.822688 ,  9.112359 ,  9.190873 , ...,  8.424476 ,
           9.06429  ,  7.6042747]]]],
      shape=(1, 32, 124, 124), dtype=float32), array([[[[  373,     5,   132, ...,   114,   116,   370],
         [  745,   624,   631, ...,   980,   614,   618],
         [ 1240,  1368,  1127, ...,  1476,  1110,  1361],
         ...,
         [14013, 14017, 13896, ..., 14003, 14007, 14009],
         [14757, 14515, 14394, ..., 14622, 14503, 14752],
         [15007, 14884, 15012, ..., 14992, 15369, 15373]],

        [[  375,     4,     8, ...,   114,   367,   123],
         [  745,   751,   631, ...,   856,   614,   618],
         [ 1241,   999,  1127, ...,  1476,  1108,  1487],
         ...,
         [14013, 13893, 14145, ..., 14251, 14007, 14009],
         [14510, 14515, 14393, ..., 14621, 14626, 14754],
         [15005, 15135, 14891, ..., 14992, 15120, 15248]],

        [[  373,   131,     8, ...,   115,   116,   371],
         [  745,   627,   631, ...,   980,   614,   618],
         [ 1240,  1368,  1127, ...,  1476,  1110,  1487],
         ...,
         [14015, 14143, 14146, ..., 14127, 14007, 14009],
         [14509, 14515, 14392, ..., 14621, 14501, 14754],
         [14881, 15009, 15013, ..., 14992, 14999, 15125]],

        ...,

        [[  374,   253,   259, ...,   238,   116,   370],
         [  745,   627,   631, ...,   856,   614,   618],
         [ 1241,  1123,  1127, ...,  1476,  1110,  1487],
         ...,
         [14012, 14016, 14269, ..., 14127, 14006, 14008],
         [14633, 14515, 14516, ..., 14744, 14502, 14755],
         [15253, 15132, 14891, ..., 14992, 15122, 15124]],

        [[  250,   253,     8, ...,   112,   366,   370],
         [  497,   627,   755, ...,   980,   862,   990],
         [ 1241,  1244,  1127, ...,  1476,  1482,  1113],
         ...,
         [14137, 14267, 14144, ..., 14127, 14004, 14008],
         [14758, 14514, 14518, ..., 14744, 14501, 14879],
         [14881, 15009, 15139, ..., 15116, 15369, 15124]],

        [[  373,     6,     8, ...,   114,   116,   370],
         [  745,   627,   629, ...,   856,   736,   618],
         [ 1241,  1244,  1127, ...,  1476,  1110,  1114],
         ...,
         [14013, 14017, 14020, ..., 14127, 14131, 14009],
         [14385, 14515, 14394, ..., 14622, 14503, 14505],
         [15252, 14884, 15015, ..., 14992, 15369, 15124]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[7.1899786, 7.3577456, 7.6858153, ..., 8.958739 , 9.312826 ,
        9.3664255]], shape=(1, 30752), dtype=float32), array([[120812.44 , 120839.29 , 121072.41 , ..., 120453.914, 120551.21 ,
        121532.125]], shape=(1, 1024), dtype=float32), array([[63484704., 60533000., 62130512., 61941408., 63664984., 61566120.,
        63286264., 63703008., 63079624., 61745824., 62223184., 60576504.,
        61746176., 61067632., 62630712., 61837336., 62830504., 62107448.,
        59699292., 62542712., 61596336., 63030620., 62054184., 60935224.,
        61544172., 63167976., 62568496., 62869856., 60167544., 62973136.,
        62354568., 63392608., 61804184., 61095572., 63384256., 59698552.,
        61696724., 60751416., 62832744., 62190488., 63270368., 61412040.,
        61984328., 60175340., 62067800., 61359244., 61532304., 63054916.,
        62221136., 64136068., 60459248., 63276432., 62465880., 63874488.,
        62690524., 61346160., 62701872., 61051048., 61692488., 61089428.,
        62281944., 63674976., 61875704., 62530728., 63580112., 60634688.,
        62300992., 61001220., 62401952., 60494996., 60453248., 59658960.,
        63709336., 63080192., 61398376., 61248872., 62664520., 60157236.,
        61482216., 63152300., 61016756., 60946248., 61574216., 63291412.,
        61452816., 62619224., 61836464., 61059952., 60368320., 60074352.,
        62982848., 62147304., 60649312., 61713400., 61994580., 59903760.,
        63128048., 60552948., 63441584., 58940096., 60717852., 62886584.,
        63638224., 61619360., 62009328., 62346952., 60813680., 62829600.,
        60454452., 61774616., 62928664., 60918748., 61073508., 60762400.,
        62621056., 62109944., 61926528., 61843684., 63290372., 60095112.,
        62949164., 61098680., 63197128., 61460384., 62764144., 63427484.,
        62074096., 62518216.]], dtype=float32), array([[4.2773996e+09, 3.7981742e+09, 3.7741798e+09, 3.8146074e+09,
        4.3467366e+09, 4.3224289e+09, 4.0250798e+09, 4.1375685e+09,
        3.9062408e+09, 4.0790200e+09]], dtype=float32)]

Profiling

# Profile the model: a new session with profiling enabled runs ``repeat``
# times on the random feeds built above. ``end_profiling`` stops profiling
# and returns the name of the JSON trace file written by onnxruntime.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

prof = sess.end_profiling()
if js_profile_to_dataframe is not None:
    # NOTE(review): first_it_out=True presumably flags the first (warm-up)
    # iteration separately — cf. the 'it==0' column printed below; confirm.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Title the plots after the profiled model file.
    plot_ort_profile(df, ax[0], ax[1], os.path.basename(script_args.filename))
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.295 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate DORT Training

201: Evaluate DORT Training

102: Measure LLAMA speed

102: Measure LLAMA speed

201: Evaluate DORT

201: Evaluate DORT

301: Compares LLAMA exporters

301: Compares LLAMA exporters

Gallery generated by Sphinx-Gallery