101: Profile an existing model with onnxruntime

Profiles any ONNX model on CPU with onnxruntime.

Preparation

import os
import numpy as np
import matplotlib.pyplot as plt
from experimental_experiment.args import get_parsed_args

try:
    from onnx_extended.tools.js_profile import (
        js_profile_to_dataframe,
        plot_ort_profile,
    )
except ImportError:
    js_profile_to_dataframe = None

# Default model: ``data/example_4700-CPUep-opt.onnx`` next to this script.
# ``__file__`` is not defined when the code is executed interactively
# (e.g. in a notebook), so fall back to a path relative to the working directory.
try:
    script_dir = os.path.dirname(__file__ or "")
except NameError:
    filename = "data/example_4700-CPUep-opt.onnx"
else:
    filename = os.path.join(script_dir, "data", "example_4700-CPUep-opt.onnx")

# Command-line options of the example. The ``filename`` tuple appears to be
# (default value, help text) -- TODO confirm against get_parsed_args.
script_args = get_parsed_args(
    "plot_profile_existing_onnx",
    filename=(filename, "input file"),
    repeat=10,
    expose="",
)


# Echo the options that drive the rest of the script.
for arg_name in ("filename", "repeat"):
    print(f"{arg_name}={getattr(script_args, arg_name)}")
filename=data/example_4700-CPUep-opt.onnx
repeat=10

Random inputs and session creation.

def create_random_input(sess):
    """Build a feed dictionary of random tensors for every input of *sess*.

    :param sess: an object exposing ``get_inputs()`` (such as an
        ``onnxruntime.InferenceSession``); each returned input must provide
        ``name``, ``shape`` and ``type``
    :return: dictionary ``{input name: numpy array}`` whose values are drawn
        uniformly from [0, 1) and cast to the input's element type
    :raises ValueError: if an input has an element type not listed below

    Only floating point element types are supported. Every dimension of an
    input shape must be a concrete integer: symbolic (string) or unknown
    (``None``) dimensions are not resolved here.
    """
    # onnxruntime input type strings -> numpy dtypes.
    dtypes = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }
    feeds = {}
    for inp in sess.get_inputs():
        dtype = dtypes.get(inp.type)
        if dtype is None:
            raise ValueError(f"Unsupported onnx type {inp.type}.")
        feeds[inp.name] = np.random.rand(*inp.shape).astype(dtype)
    return feeds


def create_session(filename, profiling=False):
    """Create an onnxruntime session running *filename* on CPU.

    :param filename: path to the ONNX model
    :param profiling: when True, the session is created with
        ``SessionOptions.enable_profiling`` set so that onnxruntime records
        a profile of every run
    :return: an ``onnxruntime.InferenceSession`` bound to
        ``CPUExecutionProvider``
    """
    from onnxruntime import InferenceSession, SessionOptions

    cpu_only = ["CPUExecutionProvider"]
    if profiling:
        options = SessionOptions()
        options.enable_profiling = True
        return InferenceSession(filename, options, providers=cpu_only)
    return InferenceSession(filename, providers=cpu_only)


# Plain session (no profiling): build random inputs from the model's input
# signature and run the model once. The trailing bare expression is kept on
# purpose: its value (the list of outputs) is what the rendered page shows.
sess = create_session(script_args.filename)
feeds = create_random_input(sess)
sess.run(None, feeds)
[array([[0.38556445, 0.55397576, 0.8805988 , ..., 0.80855435, 0.9698065 ,
        0.1651492 ],
       [0.60911006, 0.98146534, 0.48810425, ..., 0.37745422, 0.0707617 ,
        0.21648818],
       [0.14771894, 0.10950058, 0.7909586 , ..., 0.05426677, 0.37327752,
        0.26818016],
       ...,
       [0.3218134 , 0.29144892, 0.0526025 , ..., 0.02053035, 0.82496136,
        0.9514841 ],
       [0.635222  , 0.64783275, 0.4547749 , ..., 0.76839703, 0.34234193,
        0.40970457],
       [0.3180374 , 0.5930903 , 0.9649414 , ..., 0.33332193, 0.47703722,
        0.55091107]], shape=(128, 1024), dtype=float32), array([[0.95870274, 0.636761  , 0.44476703, ..., 0.19968031, 0.4780304 ,
        0.26984775],
       [0.49667543, 0.8193919 , 0.7951488 , ..., 0.1742966 , 0.426188  ,
        0.35868287],
       [0.13918132, 0.42133337, 0.48337352, ..., 0.80362266, 0.58344287,
        0.6889565 ],
       ...,
       [0.6976248 , 0.8877859 , 0.1375428 , ..., 0.6435721 , 0.95254046,
        0.40901002],
       [0.8864794 , 0.25991818, 0.49433857, ..., 0.7553312 , 0.5885717 ,
        0.6678185 ],
       [0.13902307, 0.8430048 , 0.49835652, ..., 0.65506876, 0.72827005,
        0.16622455]], shape=(1024, 30752), dtype=float32), array([[0.9065616 , 0.9900758 , 0.51137394, ..., 0.48558617, 0.63175607,
        0.62513673],
       [0.2600353 , 0.67811406, 0.19555514, ..., 0.10207321, 0.66219664,
        0.88368404],
       [0.4334533 , 0.8996908 , 0.46073616, ..., 0.67588544, 0.0141007 ,
        0.89307386],
       ...,
       [0.51598513, 0.14081694, 0.06316388, ..., 0.98565435, 0.5844102 ,
        0.8166332 ],
       [0.46951342, 0.7719806 , 0.8183161 , ..., 0.46453735, 0.8912257 ,
        0.81608504],
       [0.24607968, 0.4938955 , 0.7717674 , ..., 0.5502433 , 0.18733628,
        0.7372484 ]], shape=(10, 128), dtype=float32), array([[[[7.2124248, 6.949165 , 5.3856463, ..., 4.8469324, 5.320979 ,
          6.0521   ],
         [7.0614085, 6.2695427, 5.4931307, ..., 6.3549614, 6.6163974,
          6.017617 ],
         [7.381121 , 7.3006744, 5.7618985, ..., 7.3029084, 7.318741 ,
          5.310724 ],
         ...,
         [6.0006537, 6.3537717, 6.3752923, ..., 5.9468803, 5.795259 ,
          5.810827 ],
         [5.379063 , 5.876003 , 6.685616 , ..., 6.7443995, 6.578243 ,
          6.7477827],
         [5.1108017, 5.9684243, 5.794866 , ..., 6.321777 , 7.0116224,
          6.7306232]],

        [[8.06898  , 7.4926353, 6.1334796, ..., 5.615914 , 7.085943 ,
          6.8165927],
         [7.9092493, 6.7172837, 5.5994554, ..., 7.534926 , 7.068978 ,
          6.920303 ],
         [8.383128 , 7.9394937, 5.7300982, ..., 7.8446493, 7.263592 ,
          5.735348 ],
         ...,
         [6.9940324, 7.7478275, 7.1182747, ..., 6.857957 , 6.194977 ,
          6.4526353],
         [6.337092 , 7.0728784, 7.845219 , ..., 8.164216 , 7.142987 ,
          8.17029  ],
         [6.4320693, 6.7336698, 7.240597 , ..., 7.342244 , 7.521635 ,
          8.012213 ]],

        [[5.124064 , 4.732128 , 3.5098467, ..., 4.012354 , 3.9715738,
          3.7085536],
         [4.8587728, 4.8069825, 2.9407508, ..., 4.826569 , 5.049167 ,
          4.556578 ],
         [5.6522174, 4.7376804, 3.9291642, ..., 4.69229  , 4.4731402,
          4.2415004],
         ...,
         [4.896876 , 4.9004974, 5.440536 , ..., 4.2603846, 4.126876 ,
          4.629271 ],
         [3.7297525, 4.5046086, 5.397951 , ..., 4.871149 , 5.1618586,
          6.174954 ],
         [4.2740426, 4.448325 , 5.3924384, ..., 5.0491276, 5.3751755,
          4.833969 ]],

        ...,

        [[7.6567993, 6.9736595, 5.819581 , ..., 6.2047315, 6.4378743,
          5.8060503],
         [7.087953 , 6.772794 , 5.4935155, ..., 7.0531735, 6.8689923,
          7.3181067],
         [7.8979726, 7.473395 , 6.11533  , ..., 6.4378467, 5.8727493,
          6.642461 ],
         ...,
         [6.6677213, 7.990225 , 6.453491 , ..., 5.8332553, 5.967316 ,
          6.534334 ],
         [5.6889033, 7.0696597, 6.1350913, ..., 6.7491493, 7.232734 ,
          7.3192263],
         [5.3458953, 6.3537183, 6.6923842, ..., 6.86066  , 6.358967 ,
          7.2293086]],

        [[6.190681 , 4.934544 , 4.2880063, ..., 4.067191 , 4.7604604,
          4.095345 ],
         [5.4029784, 5.408377 , 3.620253 , ..., 5.2773747, 5.409656 ,
          5.743461 ],
         [6.1972303, 5.4862523, 4.1798673, ..., 5.4588537, 4.4254813,
          4.4029136],
         ...,
         [5.1725945, 4.882221 , 5.595881 , ..., 4.0541406, 3.5928552,
          4.3137593],
         [4.4695835, 4.2526155, 5.578326 , ..., 5.469363 , 5.4409103,
          5.0957747],
         [4.410438 , 4.3117046, 5.8523507, ..., 5.515564 , 4.527907 ,
          4.909782 ]],

        [[7.51237  , 6.875618 , 5.907875 , ..., 5.4989085, 6.8030105,
          5.2218556],
         [7.7072873, 6.429263 , 5.3156986, ..., 7.142754 , 6.151346 ,
          6.922994 ],
         [7.2347255, 6.926053 , 5.7444644, ..., 7.704699 , 6.722652 ,
          6.2124476],
         ...,
         [6.441152 , 7.2276235, 6.91788  , ..., 6.3062444, 5.544852 ,
          5.9097867],
         [6.163632 , 5.882812 , 6.2885513, ..., 6.486774 , 6.972126 ,
          6.839702 ],
         [5.8554616, 5.5420575, 6.187286 , ..., 6.4345636, 7.115575 ,
          6.2969093]]]], shape=(1, 32, 124, 124), dtype=float32), array([[[[  248,   252,   380, ...,   361,   491,   244],
         [  744,   503,   753, ...,   611,   738,   617],
         [  992,   997,  1000, ...,  1476,  1111,  1239],
         ...,
         [14136, 14141, 13899, ..., 14372, 14377, 14011],
         [14757, 14762, 14394, ..., 14496, 14624, 14507],
         [15131, 15010, 15137, ..., 14994, 15121, 15374]],

        [[  248,   131,   256, ...,   361,   491,   493],
         [  744,   874,   752, ...,   859,   615,   617],
         [  994,   998,  1000, ...,  1353,  1111,  1238],
         ...,
         [14136, 13892, 14146, ..., 14248, 14377, 14135],
         [14757, 14513, 14641, ..., 14496, 14872, 14507],
         [15130, 15258, 15260, ..., 14993, 15246, 15251]],

        [[  248,   131,   381, ...,   361,   119,   493],
         [  620,   874,   754, ...,   610,   737,   741],
         [  994,   997,  1000, ...,  1477,  1235,  1236],
         ...,
         [13889, 14143, 13898, ..., 14372, 14005, 14011],
         [14758, 14638, 14642, ..., 14869, 14624, 14506],
         [15006, 15258, 15261, ..., 14993, 14996, 15251]],

        ...,

        [[  248,   131,   381, ...,   360,   491,   492],
         [  744,   751,   752, ...,   858,   739,   741],
         [  994,   998,  1002, ...,  1477,  1111,  1114],
         ...,
         [14012, 14140, 14145, ..., 14126, 14130, 14135],
         [14508, 14638, 14517, ..., 14869, 14748, 14507],
         [15005, 15258, 15263, ..., 15119, 15245, 15251]],

        [[  248,   255,   380, ...,   484,   489,   493],
         [  497,   750,   752, ...,   610,   862,   617],
         [ 1118,  1247,  1002, ...,  1352,  1235,  1487],
         ...,
         [13890, 13895, 14144, ..., 14372, 14005, 14008],
         [14758, 14761, 14640, ..., 14496, 14872, 14755],
         [15254, 15009, 15263, ..., 14993, 14997, 15002]],

        [[  124,   131,   258, ...,   361,   365,   493],
         [  744,   874,   628, ...,   859,   614,   617],
         [ 1118,   997,  1375, ...,  1476,  1111,  1114],
         ...,
         [13889, 14141, 13899, ..., 14372, 14005, 14011],
         [14756, 14761, 14640, ..., 14496, 14872, 14507],
         [15255, 15134, 15260, ..., 14995, 15121, 15374]]]],
      shape=(1, 32, 31, 31), dtype=int64), array([[7.381121 , 6.113424 , 6.964289 , ..., 7.1669254, 7.2326717,
        7.115575 ]], shape=(1, 30752), dtype=float32), array([[119141.46 , 119613.23 , 119708.78 , ..., 119339.086, 118628.85 ,
        118833.1  ]], shape=(1, 1024), dtype=float32), array([[6.1533384e+07, 5.9396112e+07, 6.0553608e+07, 6.1904116e+07,
        6.0651256e+07, 6.0473864e+07, 6.0628596e+07, 6.1006848e+07,
        6.1611816e+07, 6.1494788e+07, 5.9946136e+07, 6.0892072e+07,
        6.0989100e+07, 6.0879668e+07, 5.9669440e+07, 6.1841072e+07,
        6.0864296e+07, 5.8424244e+07, 6.0586568e+07, 6.0708544e+07,
        6.1408496e+07, 6.3036920e+07, 6.1415088e+07, 6.1482232e+07,
        6.0187580e+07, 6.2735524e+07, 6.1316716e+07, 5.9624520e+07,
        6.0540560e+07, 6.0819616e+07, 6.3122228e+07, 6.0894616e+07,
        6.0451604e+07, 6.1070192e+07, 6.0979504e+07, 6.1360276e+07,
        5.9039568e+07, 6.0540144e+07, 6.1145744e+07, 6.0742980e+07,
        6.1351368e+07, 6.0078768e+07, 6.2439900e+07, 6.2247788e+07,
        5.9820968e+07, 5.8503176e+07, 6.0834992e+07, 5.8239728e+07,
        6.1263068e+07, 5.9641888e+07, 6.2209932e+07, 6.1295636e+07,
        6.0269064e+07, 6.0916872e+07, 6.2716144e+07, 5.9071256e+07,
        6.1943084e+07, 6.0251592e+07, 5.9849740e+07, 6.0636216e+07,
        6.1342548e+07, 5.9601248e+07, 6.1807060e+07, 6.0105000e+07,
        6.1794720e+07, 5.9975672e+07, 6.0113808e+07, 5.9534040e+07,
        5.9255896e+07, 6.0921696e+07, 6.1040104e+07, 5.9613088e+07,
        6.0227500e+07, 6.0268120e+07, 5.9954272e+07, 6.1388912e+07,
        6.3367608e+07, 6.1473328e+07, 6.2184904e+07, 6.2094508e+07,
        6.1093552e+07, 6.0453852e+07, 6.1754576e+07, 6.1102236e+07,
        6.0100800e+07, 6.0864548e+07, 6.1891544e+07, 6.0204456e+07,
        6.1565576e+07, 6.0354372e+07, 6.0864624e+07, 5.9297096e+07,
        6.0754444e+07, 6.1920328e+07, 6.0301968e+07, 6.1914372e+07,
        6.1900076e+07, 6.1327152e+07, 6.2442824e+07, 6.1349680e+07,
        6.0438016e+07, 6.2174912e+07, 6.0849416e+07, 5.8883824e+07,
        6.1998108e+07, 5.9313568e+07, 6.2345872e+07, 6.0202680e+07,
        6.1111028e+07, 6.2123064e+07, 6.0621932e+07, 6.0387752e+07,
        5.9783000e+07, 6.1904984e+07, 5.8992736e+07, 6.0704196e+07,
        6.1475780e+07, 6.1883072e+07, 6.1640540e+07, 6.1268072e+07,
        6.1788800e+07, 6.0115072e+07, 5.9435696e+07, 6.0516424e+07,
        6.0247128e+07, 6.0656156e+07, 6.0185568e+07, 6.3598848e+07]],
      dtype=float32), array([[3.8433152e+09, 3.5135181e+09, 4.3521690e+09, 4.0556605e+09,
        3.5809229e+09, 4.0594386e+09, 4.1181942e+09, 4.0656640e+09,
        3.9362488e+09, 3.9547407e+09]], dtype=float32)]

Profiling

# Second session, this time with profiling enabled. The feeds built above are
# reused: both sessions load the same model, so they expect the same inputs.
sess = create_session(script_args.filename, profiling=True)

for _ in range(script_args.repeat):
    sess.run(None, feeds)

# Stop profiling; onnxruntime returns the name of the JSON trace it wrote.
prof = sess.end_profiling()
if js_profile_to_dataframe is not None:
    # first_it_out=True presumably singles out the first (warm-up) iteration
    # (see the 'iteration' / 'it==0' columns) -- TODO confirm in onnx-extended.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    print(df.columns)
    df.to_csv("plot_profile_existing_onnx.csv")
    # pandas needs an Excel writer engine (e.g. openpyxl) for this call.
    df.to_excel("plot_profile_existing_onnx.xlsx")
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    # Title the plots after the profiled model. The former "dort" label was
    # apparently copied from the torch-backend examples and does not describe
    # the model profiled here.
    model_name = os.path.splitext(os.path.basename(script_args.filename))[0]
    plot_ort_profile(df, ax[0], ax[1], model_name)
    fig.tight_layout()
    fig.savefig("plot_profile_existing_onnx.png")
else:
    print("Install onnx-extended first.")
dort, n occurences
Index(['cat', 'pid', 'tid', 'dur', 'ts', 'ph', 'name',
       'args_thread_scheduling_stats', 'args_output_size',
       'args_parameter_size', 'args_activation_size', 'args_node_index',
       'args_provider', 'args_op_name', 'op_name', 'event_name', 'iteration',
       'it==0'],
      dtype='object')

Total running time of the script: (0 minutes 1.626 seconds)

Related examples

201: Evaluate different ways to export a torch model to ONNX

201: Evaluate different ways to export a torch model to ONNX

101: Onnx Model Optimization based on Pattern Rewriting

101: Onnx Model Optimization based on Pattern Rewriting

201: Better shape inference

201: Better shape inference

101: A custom backend for torch

101: A custom backend for torch

201: Use torch to export a scikit-learn model into ONNX

201: Use torch to export a scikit-learn model into ONNX

Gallery generated by Sphinx-Gallery