101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# Default model path: the ``data`` folder located next to this script.
# When ``__file__`` is not defined (NameError), a path relative to the
# current directory is used instead.
try:
    filename = __file__ or ""
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(
        os.path.dirname(filename), "data", "example_4700-CPUep-opt.onnx"
    )

# Options of this example with their default values.
# NOTE(review): the (value, text) tuple presumably pairs a default with its
# help message -- confirm against get_parsed_args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the options used for this run, one ``name=value`` per line.
for att in ("filename", "repeat"):
    print(att, getattr(script_args, att), sep="=")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Create random inputs for the model.

def create_random_input(sess):
    """Build a dictionary of random tensors, one for every input of *sess*.

    :param sess: object exposing ``get_inputs()``; every returned item must
        provide the attributes ``name``, ``shape`` and ``type``
    :return: dictionary ``{input name: numpy array}``, values are drawn
        uniformly from ``[0, 1)``
    :raises ValueError: if an input is not a supported floating point tensor
    """
    # Only floating point tensors are supported: casting samples from
    # [0, 1) to an integer type would produce nothing but zeros.
    dtypes = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for inp in sess.get_inputs():
        if inp.type not in dtypes:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        # A dimension which is not an integer (None or a symbolic name) is
        # dynamic: any size is valid, 1 is the cheapest choice.
        shape = tuple(
            d if isinstance(d, (int, np.integer)) else 1 for d in inp.shape
        )
        # np.random.rand() returns a python float for an empty shape,
        # np.asarray makes sure the result is always an array.
        feeds[inp.name] = np.asarray(np.random.rand(*shape), dtype=dtypes[inp.type])
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime inference session running on CPU.

    :param filename: model handed to ``InferenceSession``
    :param profiling: if True, the session is created with session options
        whose ``enable_profiling`` flag is set
    :return: the ``InferenceSession``
    """
    # Imported here so that onnxruntime is only needed once a session is
    # actually requested.
    from onnxruntime import InferenceSession, SessionOptions

    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(
            filename, options, providers=["CPUExecutionProvider"]
        )
    return InferenceSession(filename, providers=["CPUExecutionProvider"])


# Session without profiling: this first run checks that the model can be
# executed with the random inputs.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
# Kept as a bare expression on the last line: the outputs returned by the
# model are what the rendered example displays right below this cell.
sess.run(None, feeds)
[array([[0.8800316 , 0.1565457 , 0.7311014 , ..., 0.62529314, 0.9170331 ,
        0.11257499],
       [0.02621556, 0.2150497 , 0.59243214, ..., 0.14446622, 0.71780896,
        0.3996856 ],
       [0.2285278 , 0.04529846, 0.02729026, ..., 0.44896546, 0.32855508,
        0.8476245 ],
       ...,
       [0.8691981 , 0.94416565, 0.11693375, ..., 0.49558818, 0.9814404 ,
        0.78239745],
       [0.9971977 , 0.74231964, 0.6078006 , ..., 0.4596116 , 0.11658058,
        0.5496513 ],
       [0.45001355, 0.36166173, 0.62603885, ..., 0.4452884 , 0.4126759 ,
        0.17366609]], shape=(128, 1024), dtype=float32), array([[0.29583213, 0.14490867, 0.589824  , ..., 0.4678403 , 0.04309676,
        0.14172134],
       [0.60258716, 0.34892586, 0.8109066 , ..., 0.19553368, 0.48262998,
        0.6269261 ],
       [0.8850752 , 0.09799934, 0.84918606, ..., 0.31640843, 0.860387  ,
        0.5689675 ],
       ...,
       [0.20949169, 0.22091351, 0.90057385, ..., 0.34955096, 0.80877936,
        0.19594395],
       [0.8107549 , 0.8188743 , 0.985135  , ..., 0.9879154 , 0.4369427 ,
        0.19141744],
       [0.90216166, 0.47737488, 0.9437667 , ..., 0.623779  , 0.12656698,
        0.8749854 ]], shape=(1024, 30752), dtype=float32), array([[0.6211051 , 0.38196164, 0.70972604, ..., 0.38812354, 0.8140742 ,
        0.2158481 ],
       [0.6738045 , 0.79611367, 0.01371412, ..., 0.04355347, 0.29237035,
        0.7252755 ],
       [0.9716458 , 0.11741051, 0.0205212 , ..., 0.3371951 , 0.26055747,
        0.26125574],
       ...,
       [0.5123202 , 0.18919867, 0.11697026, ..., 0.228091  , 0.9696022 ,
        0.60125136],
       [0.78565794, 0.76074916, 0.1477966 , ..., 0.7654378 , 0.24070972,
        0.08616269],
       [0.1744032 , 0.23314552, 0.55516225, ..., 0.45925698, 0.7293221 ,
        0.7714493 ]], shape=(10, 128), dtype=float32), array([[[[5.6080728, 5.879957 , 6.2761335, ..., 6.354779 , 4.9485126,
          5.215397 ],
         [5.6972356, 6.175117 , 5.6709957, ..., 5.9555974, 5.853348 ,
          4.939538 ],
         [5.9978356, 6.3872027, 5.93476  , ..., 5.894418 , 5.348946 ,
          5.347117 ],
         ...,
         [6.6759267, 6.595423 , 6.6967854, ..., 4.3580065, 3.9331465,
          4.464229 ],
         [6.199033 , 6.543278 , 6.2544227, ..., 4.690233 , 3.9420273,
          3.9371111],
         [6.619613 , 6.163595 , 5.7971234, ..., 4.591831 , 3.8123624,
          3.705595 ]],

        [[5.579121 , 5.4777207, 5.567755 , ..., 5.4606247, 4.7075315,
          5.0974855],
         [5.126064 , 4.217288 , 4.82486  , ..., 4.868908 , 4.7508106,
          4.3303437],
         [5.6175604, 5.2288322, 4.7507257, ..., 4.9057775, 4.5535617,
          4.419969 ],
         ...,
         [5.3279824, 5.4914746, 6.279444 , ..., 3.75188  , 3.4050195,
          4.713725 ],
         [5.4564505, 5.325256 , 5.8769484, ..., 3.924797 , 3.4823897,
          3.9182181],
         [5.6844954, 4.9992666, 4.492304 , ..., 4.27733  , 4.0100174,
          4.217204 ]],

        [[8.638106 , 8.01329  , 8.646049 , ..., 7.968223 , 7.0755343,
          6.9595375],
         [7.8601084, 7.8778753, 7.034657 , ..., 7.722811 , 7.7659597,
          6.9368424],
         [8.143422 , 8.111586 , 8.189595 , ..., 7.7231994, 7.576418 ,
          7.24572  ],
         ...,
         [8.332782 , 8.942136 , 8.971578 , ..., 6.433213 , 5.25701  ,
          5.98688  ],
         [8.435105 , 9.1750965, 8.553229 , ..., 6.435356 , 5.877287 ,
          6.4431562],
         [8.722213 , 8.021475 , 7.398889 , ..., 6.3686037, 6.2669754,
          5.613288 ]],

        ...,

        [[6.8795333, 7.029456 , 7.1767206, ..., 7.0778255, 5.807105 ,
          5.8944917],
         [5.638324 , 7.0470634, 6.7055745, ..., 7.0124536, 6.749146 ,
          6.6055756],
         [7.0582795, 7.3123403, 6.485142 , ..., 6.8116426, 5.7626076,
          6.7492757],
         ...,
         [7.137337 , 8.042789 , 7.903657 , ..., 5.253233 , 5.3354077,
          4.35629  ],
         [7.389143 , 7.700496 , 7.9052114, ..., 5.201979 , 5.102844 ,
          5.263525 ],
         [7.1734104, 7.358558 , 6.82334  , ..., 5.6460195, 4.830977 ,
          4.7260404]],

        [[6.3675914, 5.775776 , 6.5316486, ..., 5.442385 , 5.1557956,
          5.481117 ],
         [5.048122 , 6.3470526, 5.8324585, ..., 5.574298 , 5.562227 ,
          5.6668124],
         [6.138811 , 6.763809 , 5.184014 , ..., 5.6179814, 5.5911336,
          5.930111 ],
         ...,
         [6.5419087, 7.182144 , 6.937564 , ..., 4.7039623, 4.2790594,
          4.0735683],
         [6.5316443, 6.7497096, 6.4860544, ..., 4.6504436, 4.1984644,
          4.650841 ],
         [6.4661884, 6.3849177, 5.2701406, ..., 5.250996 , 4.55113  ,
          4.4763727]],

        [[6.2261124, 5.7660675, 6.4132643, ..., 6.0000668, 5.725874 ,
          5.4538755],
         [5.2041693, 5.794277 , 6.921607 , ..., 6.001631 , 5.681967 ,
          6.1989183],
         [6.337393 , 6.864723 , 6.5822067, ..., 6.038706 , 5.9784155,
          6.6927824],
         ...,
         [6.6316986, 7.894251 , 7.431756 , ..., 4.85993  , 4.6346374,
          4.6964607],
         [6.923641 , 6.957471 , 6.6798964, ..., 5.663718 , 4.894273 ,
          4.425264 ],
         [6.5367413, 6.879862 , 6.090458 , ..., 5.172448 , 4.631626 ,
          4.877141 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  249,     6,   257, ...,   360,   118,   121],
         [  747,   874,   505, ...,   608,   984,   743],
         [ 1364,   997,  1375, ...,  1354,  1235,  1486],
         ...,
         [14262, 13892, 13897, ..., 14003, 14253, 14008],
         [14509, 14760, 14643, ..., 14496, 14875, 14631],
         [14881, 15008, 14891, ..., 15118, 15121, 15000]],

        [[  373,   128,   380, ...,   486,   118,   120],
         [  621,   500,   504, ...,   610,   986,   988],
         [ 1116,  1370,  1372, ...,  1230,  1232,  1238],
         ...,
         [14261, 14140, 14020, ..., 14249, 14128, 14380],
         [14634, 14760, 14517, ..., 14746, 14750, 14506],
         [15006, 15008, 15139, ..., 15117, 15244, 15124]],

        [[    2,     6,   380, ...,   236,   119,   121],
         [  497,   751,   505, ...,   734,   986,   743],
         [ 1240,  1371,  1373, ...,  1230,  1232,  1485],
         ...,
         [14261, 14265, 14021, ..., 14002, 14004, 14256],
         [14510, 14636, 14395, ..., 14620, 14751, 14506],
         [15129, 15008, 15015, ..., 15366, 15245, 15000]],

        ...,

        [[  372,     6,   383, ...,   486,   119,   120],
         [  744,   749,   505, ...,   734,   984,   742],
         [ 1240,   998,  1373, ...,  1355,  1235,  1361],
         ...,
         [14136, 14018, 14021, ..., 14003, 14252, 14008],
         [14509, 14388, 14395, ..., 14621, 14875, 14753],
         [14882, 15008, 15139, ..., 15118, 15244, 15000]],

        [[  249,     6,   380, ...,   484,   118,   120],
         [  747,   875,   629, ...,   858,   986,   864],
         [ 1116,  1371,  1373, ...,  1230,  1232,  1361],
         ...,
         [14136, 14140, 14021, ..., 14251, 14006, 14256],
         [14632, 14391, 14643, ..., 14621, 14751, 14506],
         [14881, 15011, 15139, ..., 15365, 15121, 15372]],

        [[  126,   129,   258, ...,   486,   116,   371],
         [  871,   750,   506, ...,   608,   986,   864],
         [ 1366,  1371,  1372, ...,  1230,  1232,  1236],
         ...,
         [14136, 14019, 14021, ..., 14002, 14004, 14382],
         [14510, 14636, 14393, ..., 14621, 14750, 14628],
         [15005, 15256, 15015, ..., 15365, 14998, 15002]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[6.3872027, 6.7197433, 6.2416167, ..., 8.1515255, 7.3032866,
        6.0225306]], shape=(1, 30752), dtype=float32), array([[115955.516, 115892.23 , 115751.95 , ..., 115718.945, 115916.52 ,
        116846.695]], shape=(1, 1024), dtype=float32), array([[5.9659184e+07, 5.9664440e+07, 5.9971368e+07, 5.8910832e+07,
        5.9589992e+07, 6.0142564e+07, 5.8924392e+07, 5.9213020e+07,
        5.8786476e+07, 5.7909300e+07, 5.8865600e+07, 5.9845692e+07,
        5.8653948e+07, 5.7685956e+07, 6.1129084e+07, 6.1001976e+07,
        5.9196260e+07, 5.8439996e+07, 6.0301160e+07, 5.8395100e+07,
        5.7276176e+07, 5.9090208e+07, 5.9499800e+07, 6.0687320e+07,
        6.0186816e+07, 6.0105032e+07, 5.8936948e+07, 5.8344060e+07,
        5.8476692e+07, 5.9210984e+07, 5.9409392e+07, 5.9647736e+07,
        5.7836096e+07, 5.9485664e+07, 5.9718200e+07, 6.0732392e+07,
        5.9111792e+07, 6.0246808e+07, 5.9730104e+07, 5.8876280e+07,
        5.8835644e+07, 5.9149848e+07, 5.8859928e+07, 5.9045244e+07,
        5.8783280e+07, 5.9544180e+07, 5.8657880e+07, 5.8061760e+07,
        5.9287328e+07, 5.9157300e+07, 5.8784144e+07, 5.8773584e+07,
        5.8979432e+07, 5.8705060e+07, 5.9545304e+07, 5.9339996e+07,
        5.9318112e+07, 6.0859368e+07, 5.9944676e+07, 6.0262364e+07,
        5.7780352e+07, 5.7548232e+07, 5.9343296e+07, 5.8010624e+07,
        5.9631920e+07, 5.9877120e+07, 5.9207224e+07, 6.0099536e+07,
        5.8787616e+07, 5.8922216e+07, 6.1278384e+07, 5.9981192e+07,
        5.8116460e+07, 6.0157168e+07, 5.8923860e+07, 5.9267360e+07,
        5.9244228e+07, 6.0760544e+07, 6.0221824e+07, 5.9175440e+07,
        5.8358952e+07, 5.8613448e+07, 6.0620736e+07, 5.9199648e+07,
        5.8213216e+07, 6.0463424e+07, 5.9123888e+07, 6.0122680e+07,
        6.0504668e+07, 6.0344984e+07, 5.8893760e+07, 5.7030332e+07,
        5.9282164e+07, 5.9703600e+07, 5.9106192e+07, 5.8312568e+07,
        6.0589992e+07, 5.8813680e+07, 5.9460092e+07, 6.0281276e+07,
        5.9175032e+07, 5.8732708e+07, 5.9557620e+07, 6.1021932e+07,
        5.9096120e+07, 5.9126220e+07, 5.9398288e+07, 6.1209176e+07,
        6.0228208e+07, 6.1183100e+07, 5.9450080e+07, 5.8432408e+07,
        6.0312720e+07, 5.8739944e+07, 5.8204404e+07, 5.9773452e+07,
        5.9734852e+07, 5.7584768e+07, 6.0984816e+07, 5.8816848e+07,
        5.8835856e+07, 5.9760296e+07, 5.9261568e+07, 5.9491024e+07,
        5.9193900e+07, 5.8271896e+07, 6.1066428e+07, 5.8811600e+07]],
      dtype=float32), array([[4.2370161e+09, 3.8225513e+09, 3.8155474e+09, 3.8518804e+09,
        4.1606067e+09, 3.7995356e+09, 3.7226161e+09, 4.0438641e+09,
        3.9391332e+09, 3.8708603e+09]], dtype=float32)]

Profiling

# New session, this time created with profiling enabled.
sess = create_session(script_args.filename, profiling=True)

# The model is executed ``repeat`` times on the same inputs.
for _ in range(script_args.repeat):
    sess.run(None, feeds)

# The value returned by end_profiling is what js_profile_to_dataframe
# consumes below.
prof = sess.end_profiling()
if js_profile_to_dataframe is None:
    # onnx-extended could not be imported: nothing can be post-processed.
    print("Install onnx-extended first.")
else:
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)

    # Raw profiling data, saved in two formats.
    df.to_csv("plot_profile_existing_onnx.csv")
    df.to_excel("plot_profile_existing_onnx.xlsx")

    # Two side-by-side axes filled by plot_ort_profile, saved as an image.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(df, ax[0], ax[1], "dort")
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='str')

Total running time of the script: (0 minutes 1.316 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

101: Some dummy examples with torch.export.export

101: Some dummy examples with torch.export.export

Gallery generated by Sphinx-Gallery