101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime's built-in profiler.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# Default model: the optimized ONNX file stored in the ``data`` folder next to
# this script. ``__file__`` is not defined in every execution context (for
# instance an interactive interpreter), hence the relative-path fallback.
try:
    filename = os.path.join(
        os.path.dirname(__file__ or ""),
        "data",
        "example_4700-CPUep-opt.onnx",
    )
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"

# Command-line configurable parameters: the model to profile and the number
# of profiled inference runs.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)

# Echo the effective configuration.
for att in ("filename", "repeat"):
    print(f"{att}={getattr(script_args, att)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs: build a random feed for every model input, then run the model once to check that it works.

def create_random_input(sess, default_dim=1):
    """Build a dictionary of random feeds matching the inputs of a session.

    :param sess: an ``onnxruntime.InferenceSession``, or any object whose
        ``get_inputs()`` returns items exposing ``name``, ``shape`` and
        ``type``
    :param default_dim: size used for every dimension that is not a
        non-negative integer, i.e. a symbolic (string) or unknown (``None``)
        dimension
    :return: dictionary ``{input name: numpy array}``
    :raises ValueError: if the element type of an input is not supported
    """
    float_types = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    int_types = {
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
    }
    feeds = {}
    for inp in sess.get_inputs():
        ot = inp.type
        # Dynamic dimensions have no fixed size: substitute ``default_dim``.
        shape = tuple(
            d if isinstance(d, int) and d >= 0 else default_dim for d in inp.shape
        )
        if ot in float_types:
            # np.asarray keeps scalar inputs (shape ()) working:
            # np.random.rand() alone returns a Python float.
            t = np.asarray(np.random.rand(*shape), dtype=float_types[ot])
        elif ot in int_types:
            # Small non-negative values; adjust if the model expects
            # a specific range (e.g. valid indices).
            t = np.asarray(np.random.randint(0, 10, size=shape), dtype=int_types[ot])
        else:
            raise ValueError(f"Unsupported onnx type {ot!r} for input {inp.name!r}.")
        feeds[inp.name] = t
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session restricted to the CPU execution provider.

    :param filename: path to the ONNX model
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` set
    :return: an ``onnxruntime.InferenceSession``
    """
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)


# Plain (non-profiled) run on random inputs: checks the model executes and
# summarizes what it returns instead of dumping every output array.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
outputs = sess.run(None, feeds)
for out_info, value in zip(sess.get_outputs(), outputs):
    # Non-tensor outputs (sequences, maps) have no shape/dtype attributes.
    print(
        f"{out_info.name}: shape={getattr(value, 'shape', None)}, "
        f"dtype={getattr(value, 'dtype', type(value).__name__)}"
    )
[array([[0.99603003, 0.28187045, 0.94057274, ..., 0.47897407, 0.67544824,
        0.28367826],
       [0.59909606, 0.9200839 , 0.11144039, ..., 0.2967804 , 0.54931194,
        0.4495453 ],
       [0.354489  , 0.9961373 , 0.8847841 , ..., 0.6899941 , 0.81957257,
        0.9896713 ],
       ...,
       [0.5167857 , 0.05671073, 0.23340201, ..., 0.09100904, 0.7765644 ,
        0.84998775],
       [0.45803857, 0.5119961 , 0.01425601, ..., 0.9199097 , 0.9291237 ,
        0.10630759],
       [0.1123516 , 0.13901344, 0.590363  , ..., 0.24007647, 0.39578873,
        0.702027  ]], shape=(128, 1024), dtype=float32), array([[0.08909088, 0.43504444, 0.33286464, ..., 0.8431867 , 0.971802  ,
        0.9439682 ],
       [0.963266  , 0.30320296, 0.4091526 , ..., 0.30921045, 0.10420576,
        0.11764065],
       [0.19232771, 0.99508417, 0.4739232 , ..., 0.76210153, 0.50049376,
        0.6888989 ],
       ...,
       [0.9358125 , 0.02748959, 0.17122154, ..., 0.616237  , 0.4705962 ,
        0.30539045],
       [0.5107166 , 0.1685892 , 0.07679476, ..., 0.20251717, 0.23111762,
        0.90966547],
       [0.7315253 , 0.7814603 , 0.01096786, ..., 0.88763666, 0.29489562,
        0.9761997 ]], shape=(1024, 30752), dtype=float32), array([[0.8004982 , 0.5728458 , 0.02227808, ..., 0.8448599 , 0.2603263 ,
        0.48587894],
       [0.58759516, 0.30733246, 0.42958462, ..., 0.81222117, 0.11928362,
        0.7038422 ],
       [0.6779062 , 0.971268  , 0.26938584, ..., 0.14818424, 0.4488738 ,
        0.08580573],
       ...,
       [0.7512462 , 0.48173395, 0.25311154, ..., 0.65344936, 0.16689548,
        0.3951568 ],
       [0.49963972, 0.45163333, 0.0149904 , ..., 0.35628906, 0.18334794,
        0.0660686 ],
       [0.6029076 , 0.62692666, 0.9531394 , ..., 0.16412741, 0.1392266 ,
        0.22076605]], shape=(10, 128), dtype=float32), array([[[[7.412415 , 6.947396 , 7.7499776, ..., 7.344826 , 8.160382 ,
          7.378998 ],
         [7.1864243, 6.8602967, 8.583358 , ..., 7.9021297, 8.480304 ,
          8.434767 ],
         [7.714754 , 8.168961 , 9.049314 , ..., 7.6547537, 7.9405837,
          7.4969406],
         ...,
         [6.755928 , 7.248899 , 6.8941197, ..., 7.530084 , 7.929444 ,
          8.2535   ],
         [6.245516 , 7.1055593, 8.334227 , ..., 7.163211 , 6.7189627,
          7.4504066],
         [6.868168 , 7.778609 , 7.2154665, ..., 7.093775 , 8.039719 ,
          7.9857683]],

        [[6.4633126, 6.0704403, 6.9583626, ..., 6.0713053, 7.352524 ,
          7.0360656],
         [7.0646014, 6.682242 , 7.0236006, ..., 7.077313 , 8.062915 ,
          7.6973286],
         [7.182877 , 6.7784886, 8.251508 , ..., 7.5644298, 6.985471 ,
          6.875925 ],
         ...,
         [4.956811 , 6.4329653, 7.3121424, ..., 6.418789 , 6.9557605,
          6.9164553],
         [5.555089 , 6.1387954, 7.3094616, ..., 6.5220647, 6.8980856,
          6.389908 ],
         [6.2811685, 6.8814416, 7.051608 , ..., 7.027675 , 7.2775087,
          6.9944224]],

        [[7.8101773, 8.02978  , 9.22542  , ..., 8.815471 , 8.727648 ,
          8.328909 ],
         [7.7797956, 8.424426 , 8.762375 , ..., 9.048892 , 9.356105 ,
          9.635596 ],
         [8.468139 , 8.826173 , 9.898383 , ..., 8.519879 , 9.694647 ,
          8.453154 ],
         ...,
         [6.9962697, 8.301465 , 8.131982 , ..., 8.62877  , 8.914747 ,
          9.201767 ],
         [7.599544 , 7.8799896, 8.994987 , ..., 7.9311004, 7.7485538,
          8.539363 ],
         [7.289168 , 8.507123 , 8.951217 , ..., 8.628854 , 8.851739 ,
          9.429405 ]],

        ...,

        [[5.3690186, 6.748861 , 6.405227 , ..., 6.640103 , 6.860623 ,
          6.891681 ],
         [5.975225 , 6.2555356, 6.2707887, ..., 7.000485 , 7.272141 ,
          6.163204 ],
         [6.6916194, 6.646954 , 6.9834046, ..., 6.430172 , 7.025165 ,
          6.7258735],
         ...,
         [5.072008 , 6.2744794, 7.3653674, ..., 6.208679 , 6.083686 ,
          7.1826925],
         [5.672527 , 6.5475283, 6.1471643, ..., 6.1967096, 6.40111  ,
          6.637901 ],
         [5.578671 , 6.0090785, 6.4548497, ..., 7.002253 , 6.1186047,
          6.903483 ]],

        [[7.6691523, 6.071759 , 8.334522 , ..., 7.9848547, 7.282882 ,
          6.621601 ],
         [6.8803563, 6.303494 , 7.7896066, ..., 7.136377 , 6.8392086,
          7.7262993],
         [6.916894 , 7.026955 , 8.212795 , ..., 7.4818993, 7.7161474,
          7.1175203],
         ...,
         [6.2011323, 7.709763 , 6.3460116, ..., 6.727624 , 7.8009114,
          7.369134 ],
         [6.3365254, 6.566052 , 6.2050405, ..., 6.5383816, 6.9874   ,
          6.8241477],
         [5.496375 , 6.8991847, 7.706931 , ..., 6.8613687, 7.3534236,
          6.857214 ]],

        [[8.619865 , 7.8872576, 9.443401 , ..., 7.889218 , 9.256315 ,
          8.805613 ],
         [7.604432 , 7.3575325, 9.625925 , ..., 8.851163 , 9.473678 ,
          8.617633 ],
         [8.4917965, 8.935475 , 9.398748 , ..., 8.367471 , 8.738515 ,
          9.16141  ],
         ...,
         [6.7676983, 8.4386215, 8.934378 , ..., 8.351564 , 8.401967 ,
          9.26569  ],
         [6.608559 , 8.492819 , 8.664786 , ..., 7.8224835, 8.02679  ,
          8.37425  ],
         [7.6203876, 8.453541 , 8.304399 , ..., 8.572281 , 8.418485 ,
          8.593636 ]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  250,   376,   380, ...,   237,   491,   246],
         [  622,   750,   879, ...,   735,   860,   618],
         [ 1366,  1121,  1251, ...,  1229,  1232,  1238],
         ...,
         [14137, 14018, 14144, ..., 14375, 14376, 14135],
         [14635, 14513, 14517, ..., 14622, 14748, 14630],
         [15130, 14886, 15260, ..., 14992, 14996, 15003]],

        [[  250,   130,    11, ...,   487,   116,   246],
         [  498,   626,   879, ...,   858,   863,   619],
         [ 1366,  1120,  1003, ...,  1478,  1233,  1236],
         ...,
         [14260, 14266, 14144, ..., 14124, 14254, 14135],
         [14758, 14389, 14765, ..., 14497, 14624, 14879],
         [14882, 15259, 15012, ..., 15366, 15370, 15003]],

        [[  250,   253,   380, ...,   486,   364,   370],
         [  623,   751,   879, ...,   735,   736,   742],
         [ 1367,  1368,  1251, ...,  1230,  1232,  1238],
         ...,
         [14260, 14141, 14144, ..., 14375, 14252, 14135],
         [14758, 14514, 14517, ..., 14623, 14749, 14631],
         [14881, 15259, 15260, ..., 15365, 15371, 15003]],

        ...,

        [[  127,     4,    10, ...,   239,   364,   494],
         [  498,   626,   879, ...,   734,   739,   742],
         [ 1367,  1368,  1003, ...,  1478,  1356,  1238],
         ...,
         [14136, 14142, 14146, ..., 14251, 14376, 14256],
         [14759, 14389, 14641, ..., 14623, 14501, 14506],
         [15006, 15135, 15260, ..., 15365, 14997, 15003]],

        [[  373,   253,    10, ...,   487,   488,   493],
         [  498,   749,   879, ...,   982,   860,   742],
         [ 1367,  1368,  1002, ...,  1477,  1111,  1115],
         ...,
         [13889, 14018, 14145, ..., 14375, 14376, 14256],
         [14758, 14637, 14641, ..., 14499, 14500, 14631],
         [14882, 14886, 14890, ..., 14995, 15368, 15003]],

        [[  374,     6,   380, ...,   486,   489,   246],
         [  498,   750,   877, ...,   734,   987,   990],
         [ 1366,  1121,  1003, ...,  1478,  1232,  1238],
         ...,
         [13889, 14266, 14144, ..., 14124, 14376, 14259],
         [14634, 14639, 14641, ..., 14623, 14748, 14755],
         [14882, 15256, 14889, ..., 14993, 14997, 15003]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[9.049314 , 8.8050585, 8.431708 , ..., 8.636536 , 9.145031 ,
        9.532457 ]], shape=(1, 30752), dtype=float32), array([[122692.27 , 123354.43 , 122738.766, ..., 123197.14 , 123005.64 ,
        122822.42 ]], shape=(1, 1024), dtype=float32), array([[64476524., 63178512., 63571928., 61408712., 63948544., 61254916.,
        63788640., 60656840., 62005292., 60576240., 61750332., 64028384.,
        62924260., 62590548., 62791904., 62125692., 63452328., 63223480.,
        61166956., 63162168., 63957280., 63214704., 64537548., 62347512.,
        62676004., 64064904., 61335300., 63115980., 62215372., 63729584.,
        63329504., 64048060., 61412608., 63045352., 62345248., 62715024.,
        62450608., 63085408., 64916004., 62658532., 63560752., 62164024.,
        63572168., 60429192., 65007292., 62415392., 63668064., 61908612.,
        62705496., 61592256., 62911368., 62977044., 62988068., 65124008.,
        62349432., 60898712., 63354120., 60735272., 61795744., 62678264.,
        60191036., 63176676., 61930928., 62697992., 64341036., 62204100.,
        62658300., 63576296., 65663784., 62614416., 64790332., 62813208.,
        62694712., 65478248., 63085568., 62655460., 63305816., 60891716.,
        61771448., 62881608., 62181700., 63100712., 63485992., 62307460.,
        64776068., 63789616., 62237808., 64990136., 63628948., 61646744.,
        64225640., 61165748., 63018436., 62389832., 62274148., 64072104.,
        60719808., 61923820., 64079528., 63664608., 62672704., 63486440.,
        63350464., 63448668., 61290112., 62771896., 60954944., 63210696.,
        63092212., 63400984., 62084264., 63211000., 62950264., 64543832.,
        63648704., 64333432., 63517836., 60646512., 62356552., 60787600.,
        62999144., 61455820., 63084096., 62759296., 60923248., 64858940.,
        62804056., 61582172.]], dtype=float32), array([[4.1169510e+09, 3.9174533e+09, 4.0461307e+09, 4.0379960e+09,
        3.7629184e+09, 4.1390367e+09, 4.1590989e+09, 4.0597801e+09,
        4.4176717e+09, 4.0217754e+09]], dtype=float32)]

Profiling

# Run the model ``repeat`` times with onnxruntime's profiler enabled.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

# end_profiling() stops the profiler and returns the name of the trace file.
prof = sess.end_profiling()
print(f"profile written to {prof!r}")

if js_profile_to_dataframe is not None:
    # first_it_out=True: the first iteration is separated from the
    # others in the dataframe (see the 'it==0' column).
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    try:
        df.to_excel("plot_profile_existing_onnx.xlsx")
    except ImportError as e:
        # pandas needs an optional engine (e.g. openpyxl) to write .xlsx files.
        print(f"Excel export skipped: {e}")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Label the plots with the profiled model's name rather than a fixed,
    # unrelated label.
    title = os.path.splitext(os.path.basename(script_args.filename))[0]
    plot_ort_profile(df, ax[0], ax[1], title)
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.610 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate DORT Training

201: Evaluate DORT Training

102: Measure LLAMA speed

102: Measure LLAMA speed

301: Compares LLAMA exporters

301: Compares LLAMA exporters

101: A custom backend for torch

101: A custom backend for torch

Gallery generated by Sphinx-Gallery