Profiling with onnxruntime

By default, onnxruntime optimizes the ONNX graph before running inference. It modifies or fuses existing operators and may add new ones. Some of them are standard ONNX operators; others are implemented only in onnxruntime (see Supported Operators). This example profiles both the original model and the optimized one.

Optimize a model with onnxruntime

import os
import numpy
import matplotlib.pyplot as plt
from onnxruntime import get_available_providers
from onnx_array_api.ext_test_case import example_path
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
from onnx_array_api.ort.ort_profile import ort_profile, merge_ort_profile
from onnx_array_api.plotting.stat_plot import plot_ort_profile


suffix = ""
filename = example_path(f"data/small{suffix}.onnx")
print(f"model={filename!r}")

# The optimized model is cached next to the original one and only
# regenerated when the file is missing.
optimized = f"{filename}.optimized.onnx"
if not os.path.exists(optimized):
    ort_optimized_model(filename, output=optimized)
print(f"optimized={optimized!r}")
model='data/small.onnx'
optimized='data/small.onnx.optimized.onnx'

Profiling

# One random float32 tensor of shape (1, 3, 112, 112), fed to the input
# named "input".
image = numpy.random.random((1, 3, 112, 112)).astype(numpy.float32)
feeds = dict(input=image)

# Profile the original model on CPU over 6 repetitions.
# ``disable_optimization=True`` presumably turns off onnxruntime's own graph
# rewrites so the profile reflects the graph as stored (confirm in ort_profile).
prof_base = ort_profile(
    filename,
    feeds,
    providers=["CPUExecutionProvider"],
    disable_optimization=True,
    repeat=6,
)
prof_base.to_excel(f"prof_base{suffix}.xlsx", index=False)
prof_base
cat pid tid dur ts ph name args_thread_scheduling_stats args_output_type_shape args_output_size args_parameter_size args_activation_size args_node_index args_input_type_shape args_provider args_op_name op_name event_name iteration
0 Session 120046 120046 955 10 X model_loading_uri NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN model_loading_uri -1
1 Session 120046 120046 876 1020 X session_initialization NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN session_initialization -1
2 Node 120046 120046 1058 2106 X n0_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 7168 150528 0 [{'float': [1, 3, 112, 112]}, {'float': [64, 3... CPUExecutionProvider Conv n0 kernel_time -1
3 Node 120046 120046 1048 3192 X n1_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 256 3211264 1 [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider PRelu n1 kernel_time -1
4 Node 120046 120046 428 4263 X n3_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 1, 112, 112]}] 50176 0 3211264 3 [{'float': [1, 64, 112, 112]}] CPUExecutionProvider ReduceMax n3 kernel_time -1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
93 Node 120046 120046 122 74962 X n10_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 256 3211264 10 [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider PRelu n10 kernel_time 4
94 Node 120046 120046 1445 75093 X n11_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 56, 56]}] 802816 147712 3211264 11 [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider Conv n11 kernel_time 4
95 Node 120046 120046 44 76547 X n13_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 56, 56]}] 802816 0 1605632 13 [{'float': [1, 64, 56, 56]}, {'float': [1, 64,... CPUExecutionProvider Add n13 kernel_time 4
96 Session 120046 120046 8501 68097 X SequentialExecutor::Execute NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN SequentialExecutor::Execute 5
97 Session 120046 120046 8529 68080 X model_run NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN model_run 5

98 rows × 19 columns



And the optimized model.

# Same measurement settings as the baseline, applied to the model saved by
# ``ort_optimized_model``. Optimization is disabled again, presumably so the
# already-optimized graph is run as saved.
opti_kwargs = dict(
    repeat=6,
    disable_optimization=True,
    providers=["CPUExecutionProvider"],
)
prof_opti = ort_profile(optimized, feeds, **opti_kwargs)
prof_opti.to_excel(f"prof_opti{suffix}.xlsx", index=False)
prof_opti
cat pid tid dur ts ph name args_thread_scheduling_stats args_output_type_shape args_output_size args_parameter_size args_activation_size args_node_index args_input_type_shape args_provider args_op_name op_name event_name iteration
0 Session 120046 120046 690 3 X model_loading_uri NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN model_loading_uri -1
1 Session 120046 120046 643 726 X session_initialization NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN session_initialization -1
2 Node 120046 120046 413 1568 X r0_nchwc_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 7168 150528 0 [{'float': [1, 3, 112, 112]}, {'float': [64, 3... CPUExecutionProvider Conv r0_nchwc kernel_time -1
3 Node 120046 120046 234 2002 X ReorderOutput_token_14_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 0 3211264 1 [{'float': [1, 64, 112, 112]}] CPUExecutionProvider ReorderOutput ReorderOutput_token_14 kernel_time -1
4 Node 120046 120046 151 2252 X n1_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 256 3211264 2 [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider PRelu n1 kernel_time -1
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
123 Node 120046 120046 76 45040 X ReorderInput_token_12_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 112, 112]}] 3211264 0 3211264 14 [{'float': [1, 64, 112, 112]}] CPUExecutionProvider ReorderInput ReorderInput_token_12 kernel_time 4
124 Node 120046 120046 673 45123 X r11_nchwc_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 56, 56]}] 802816 147712 4014080 17 [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider Conv r11_nchwc kernel_time 4
125 Node 120046 120046 40 45805 X ReorderOutput_kernel_time {'main_thread': {'thread_pool_name': 'session-... [{'float': [1, 64, 56, 56]}] 802816 0 802816 18 [{'float': [1, 64, 56, 56]}] CPUExecutionProvider ReorderOutput ReorderOutput kernel_time 4
126 Session 120046 120046 5359 40491 X SequentialExecutor::Execute NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN SequentialExecutor::Execute 5
127 Session 120046 120046 5445 40413 X model_run NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN model_run 5

128 rows × 19 columns



And the graph is:

# Baseline profile on the top row, optimized profile on the bottom row.
# The figure height follows the number of distinct operator types found in
# EITHER profile: the optimized graph introduces operators (e.g.
# ReorderInput / ReorderOutput) that the baseline does not have.
unique_op = set(prof_base["args_op_name"]) | set(prof_opti["args_op_name"])
fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized")
fig.tight_layout()
fig.savefig(f"plot_profiling{suffix}.png")
baseline, n occurences, optimized, n occurences

Merging profiles

Let’s compare both profiles, assuming every iteration processes the same image and that the input and output sizes are the same at every iteration.

# ``merge`` has one row per (operator, output/input type-shape, provider, idx)
# with ``durbase``/``countbase`` and ``duropti``/``countopti`` side by side.
# NaN marks an operator that is present in only one profile.
# ``gr`` holds the same figures keyed by a textual operator label.
merge, gr = merge_ort_profile(prof_base, prof_opti)
merge_file = f"plot_profiling_merged{suffix}.xlsx"
merge.to_excel(merge_file, index=False)
merge
~/github/onnx-array-api/onnx_array_api/ort/ort_profile.py:260: FutureWarning: The provided callable <function sum at 0x7ce342d63880> is currently using SeriesGroupBy.sum. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "sum" instead.
  .agg(
~/github/onnx-array-api/onnx_array_api/ort/ort_profile.py:260: FutureWarning: The provided callable <function sum at 0x7ce342d63880> is currently using SeriesGroupBy.sum. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "sum" instead.
  .agg(
args_op_name args_output_type_shape args_input_type_shape args_provider idx durbase countbase duropti countopti
0 Add [{'float': [1, 64, 56, 56]}] [{'float': [1, 64, 56, 56]}, {'float': [1, 64,... CPUExecutionProvider 0 603.0 6.0 NaN NaN
1 BatchNormalization [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}, {'float': [64]}... CPUExecutionProvider 0 3786.0 6.0 2795.0 6.0
2 Concat [{'float': [1, 2, 112, 112]}] [{'float': [1, 1, 112, 112]}, {'float': [1, 1,... CPUExecutionProvider 0 512.0 6.0 198.0 6.0
3 Conv [{'float': [1, 1, 112, 112]}] [{'float': [1, 2, 112, 112]}, {'float': [1, 2,... CPUExecutionProvider 0 1765.0 6.0 NaN NaN
4 Conv [{'float': [1, 64, 112, 112]}] [{'float': [1, 3, 112, 112]}, {'float': [64, 3... CPUExecutionProvider 0 4170.0 6.0 1436.0 6.0
5 Conv [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 0 38055.0 6.0 20675.0 6.0
6 Conv [{'float': [1, 64, 56, 56]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 0 2732.0 6.0 901.0 6.0
7 Conv [{'float': [1, 64, 56, 56]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 0 NaN NaN 5302.0 6.0
8 Conv [{'float': [1, 64, 56, 56]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 0 11862.0 6.0 NaN NaN
9 Conv [{'float': [1, 8, 112, 112]}] [{'float': [1, 2, 112, 112]}, {'float': [8, 2,... CPUExecutionProvider 0 NaN NaN 1097.0 6.0
10 Mul [{'float': [1, 64, 112, 112]}] [{'float': [1, 1, 112, 112]}, {'float': [1, 64... CPUExecutionProvider 0 788.0 6.0 608.0 6.0
11 PRelu [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 0 2733.0 6.0 805.0 6.0
12 PRelu [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}, {'float': [64, ... CPUExecutionProvider 1 1165.0 6.0 757.0 6.0
13 ReduceMax [{'float': [1, 1, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 0 1895.0 6.0 1258.0 6.0
14 ReduceMean [{'float': [1, 1, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 0 1906.0 6.0 1514.0 6.0
15 ReorderInput [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 0 NaN NaN 844.0 6.0
16 ReorderInput [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 1 NaN NaN 784.0 6.0
17 ReorderInput [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 2 NaN NaN 579.0 6.0
18 ReorderOutput [{'float': [1, 1, 112, 112]}] [{'float': [1, 8, 112, 112]}] CPUExecutionProvider 0 NaN NaN 309.0 6.0
19 ReorderOutput [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 0 NaN NaN 1247.0 6.0
20 ReorderOutput [{'float': [1, 64, 112, 112]}] [{'float': [1, 64, 112, 112]}] CPUExecutionProvider 1 NaN NaN 810.0 6.0
21 ReorderOutput [{'float': [1, 64, 56, 56]}] [{'float': [1, 64, 56, 56]}] CPUExecutionProvider 0 NaN NaN 304.0 6.0
22 Sigmoid [{'float': [1, 1, 112, 112]}] [{'float': [1, 1, 112, 112]}] CPUExecutionProvider 0 160.0 6.0 NaN NaN


More details

# ``gr`` is indexed by an operator label such as
# "[=CPU]Conv(f-1x3x112x112,...)->f-1x64x112x112". The prefix appears to tell
# whether the operator exists in both profiles (=), only the optimized one (+)
# or only the baseline (-). The label lives in the index, so the index must be
# written; ``index=False`` would drop the only column identifying each row.
gr.to_excel(f"plot_profiling_merged_details{suffix}.xlsx")
gr
durbase duropti countbase countopti
label
[+CPU]Conv(f-1x2x112x112,f-8x2x7x7)->f-1x8x112x112 0.0 1097.0 0.0 6.0
[+CPU]Conv(f-1x64x112x112,f-64x64x3x3,f-64,f-1x64x56x56)->f-1x64x56x56 0.0 5302.0 0.0 6.0
[+CPU]ReorderInput(f-1x64x112x112)->f-1x64x112x112 0.0 2207.0 0.0 18.0
[+CPU]ReorderOutput(f-1x64x112x112)->f-1x64x112x112 0.0 2057.0 0.0 12.0
[+CPU]ReorderOutput(f-1x64x56x56)->f-1x64x56x56 0.0 304.0 0.0 6.0
[+CPU]ReorderOutput(f-1x8x112x112)->f-1x1x112x112 0.0 309.0 0.0 6.0
[-CPU]Add(f-1x64x56x56,f-1x64x56x56)->f-1x64x56x56 603.0 0.0 6.0 0.0
[-CPU]Conv(f-1x2x112x112,f-1x2x7x7)->f-1x1x112x112 1765.0 0.0 6.0 0.0
[-CPU]Conv(f-1x64x112x112,f-64x64x3x3,f-64)->f-1x64x56x56 11862.0 0.0 6.0 0.0
[-CPU]Sigmoid(f-1x1x112x112)->f-1x1x112x112 160.0 0.0 6.0 0.0
[=CPU]BatchNormalization(f-1x64x112x112,f-64,f-64,f-64,f-64)->f-1x64x112x112 3786.0 2795.0 6.0 6.0
[=CPU]Concat(f-1x1x112x112,f-1x1x112x112)->f-1x2x112x112 512.0 198.0 6.0 6.0
[=CPU]Conv(f-1x3x112x112,f-64x3x3x3,f-64)->f-1x64x112x112 4170.0 1436.0 6.0 6.0
[=CPU]Conv(f-1x64x112x112,f-64x64x1x1,f-64)->f-1x64x56x56 2732.0 901.0 6.0 6.0
[=CPU]Conv(f-1x64x112x112,f-64x64x3x3,f-64)->f-1x64x112x112 38055.0 20675.0 6.0 6.0
[=CPU]Mul(f-1x1x112x112,f-1x64x112x112)->f-1x64x112x112 788.0 608.0 6.0 6.0
[=CPU]PRelu(f-1x64x112x112,f-64x1x1)->f-1x64x112x112 3898.0 1562.0 12.0 12.0
[=CPU]ReduceMax(f-1x64x112x112)->f-1x1x112x112 1895.0 1258.0 6.0 6.0
[=CPU]ReduceMean(f-1x64x112x112)->f-1x1x112x112 1906.0 1514.0 6.0 6.0


Final plot

# Let's filter out insignificant operators: keep only the rows whose combined
# baseline + optimized duration is at least 1% of the overall total.
grmax = gr["durbase"] + gr["duropti"]
total = grmax.sum()
grmax /= total
# Copy right after filtering so ``gr`` is an independent frame rather than a
# view-like slice of the unfiltered one.
gr = gr[grmax >= 0.01].copy()


# Height grows with the number of kept operators, capped at 500.
fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True)
gr[["durbase", "duropti"]].plot.barh(ax=ax[0])
ax[0].set_title("Side by side duration")
gr[["countbase", "countopti"]].plot.barh(ax=ax[1])
ax[1].set_title("Side by side count")
fig.tight_layout()
fig.savefig(f"plot_profiling_side_by_side{suffix}.png")
Side by side duration, Side by side count

On CUDA

# Repeat the whole comparison on CUDA when that provider is available.
# Every artifact written here carries a "cuda" tag so the CPU results saved
# above are never overwritten.
if "CUDAExecutionProvider" in get_available_providers():
    print("Profiling on CUDA")
    prof_base = ort_profile(
        filename,
        feeds,
        repeat=6,
        disable_optimization=True,
        providers=["CUDAExecutionProvider"],
    )
    prof_base.to_excel(f"prof_cuda_base{suffix}.xlsx", index=False)

    # NOTE(review): the CPU profile shows the optimized graph contains
    # ReorderInput/ReorderOutput nodes; CPUExecutionProvider is presumably
    # listed as a fallback for nodes CUDA cannot run — confirm.
    prof_opti = ort_profile(
        optimized,
        feeds,
        repeat=6,
        disable_optimization=True,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    prof_opti.to_excel(f"prof_cuda_opti{suffix}.xlsx", index=False)

    # One row per profile; size the figure for the operators of both.
    unique_op = set(prof_base["args_op_name"]) | set(prof_opti["args_op_name"])
    fig, ax = plt.subplots(2, 2, figsize=(10, len(unique_op)), sharex="col")
    plot_ort_profile(prof_base, ax[0, 0], ax[0, 1], title="baseline")
    plot_ort_profile(prof_opti, ax[1, 0], ax[1, 1], title="optimized")
    fig.tight_layout()
    fig.savefig(f"plot_profiling_cuda{suffix}.png")

    merge, gr = merge_ort_profile(prof_base, prof_opti)
    # CUDA-specific file names: the CPU section already wrote
    # plot_profiling_merged{suffix}.xlsx and its details counterpart.
    merge.to_excel(f"plot_profiling_merged_cuda{suffix}.xlsx", index=False)
    # ``gr`` is indexed by the operator label; keep the index in the sheet.
    gr.to_excel(f"plot_profiling_merged_details_cuda{suffix}.xlsx")

    # Keep only operators accounting for at least 1% of the total duration.
    grmax = gr["durbase"] + gr["duropti"]
    total = grmax.sum()
    grmax /= total
    gr = gr[grmax >= 0.01].copy()

    fig, ax = plt.subplots(1, 2, figsize=(14, min(gr.shape[0], 500)), sharey=True)
    gr[["durbase", "duropti"]].plot.barh(ax=ax[0])
    ax[0].set_title("Side by side duration")
    gr[["countbase", "countopti"]].plot.barh(ax=ax[1])
    ax[1].set_title("Side by side count")
    fig.tight_layout()
    fig.savefig(f"plot_profiling_side_by_side_cuda{suffix}.png")

else:
    print(f"CUDA not available in {get_available_providers()}.")
    fig, ax = None, None

ax
  • baseline, n occurences, optimized, n occurences
  • Side by side duration, Side by side count
Profiling on CUDA
~/github/onnx-array-api/onnx_array_api/ort/ort_profile.py:260: FutureWarning: The provided callable <function sum at 0x7ce342d63880> is currently using SeriesGroupBy.sum. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "sum" instead.
  .agg(
~/github/onnx-array-api/onnx_array_api/ort/ort_profile.py:260: FutureWarning: The provided callable <function sum at 0x7ce342d63880> is currently using SeriesGroupBy.sum. In a future version of pandas, the provided callable will be used directly. To keep current behavior pass the string "sum" instead.
  .agg(

array([<Axes: title={'center': 'Side by side duration'}, ylabel='label'>,
       <Axes: title={'center': 'Side by side count'}, ylabel='label'>],
      dtype=object)

Total running time of the script: (0 minutes 4.155 seconds)

Gallery generated by Sphinx-Gallery