201: Evaluate different ways to export a torch model to ONNX

The example evaluates the performance of onnxruntime on a simple torch model after it was converted to ONNX through different processes:

  • TorchScript-based ONNX Exporter, let’s call it script

  • TorchDynamo-based ONNX Exporter, let’s call it dynamo

  • if available, the same exporter followed by an optimization of the exported model, dynopt

  • a custom exporter cus_p0; it supports a very limited set of models. Like dynamo, it relies on torch.fx, but its design is closer to what tensorflow-onnx does.

  • the same exporter, but with unused nodes removed and constants folded, cus_p2

To run the script:

python _doc/examples/plot_torch_export_201.py --help

The script takes around 12 minutes with the large model.

Some helpers

from experimental_experiment.args import get_parsed_args


script_args = get_parsed_args(
    "plot_torch_export",
    description=__doc__,
    scenarios={
        "small": "small model to test",
        "middle": "55Mb model",
        "large": "1Gb model",
    },
    warmup=5,
    repeat=5,
    maxtime=(
        2,
        "maximum time to run a model to measure the computation time, "
        "it is 0.1 when scenario is small",
    ),
    expose="scenarios,repeat,warmup",
)


import contextlib
import itertools
import os
import platform
import pprint
import multiprocessing
import time
import cProfile
import pstats
import io
import warnings
import logging
from pstats import SortKey

try:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import onnxruntime

        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import numpy as np
import matplotlib.pyplot as plt
import pandas
import onnx
from onnx_array_api.profiling import profile2graph
import torch
from torch import nn
import torch.nn.functional as F
import experimental_experiment
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.plotting.memory import memory_peak_plot
from experimental_experiment.ext_test_case import measure_time, get_figure
from experimental_experiment.memory_peak import start_spying_on
from experimental_experiment.ext_test_case import unit_test_going
from experimental_experiment.helpers import pretty_onnx
from tqdm import tqdm

has_cuda = has_cuda and torch.cuda.is_available()
logging.disable(logging.ERROR)


def system_info():
    obs = {}
    obs["processor"] = platform.processor()
    obs["cores"] = multiprocessing.cpu_count()
    try:
        obs["cuda"] = 1 if torch.cuda.is_available() else 0
        obs["cuda_count"] = torch.cuda.device_count()
        obs["cuda_name"] = torch.cuda.get_device_name()
        obs["cuda_capa"] = torch.cuda.get_device_capability()
    except (RuntimeError, AssertionError):
        # no cuda
        pass
    return obs


pprint.pprint(system_info())
{'cores': 20,
 'cuda': 1,
 'cuda_capa': (8, 9),
 'cuda_count': 1,
 'cuda_name': 'NVIDIA GeForce RTX 4060 Laptop GPU',
 'processor': 'x86_64'}

Script arguments

if script_args.scenario in (None, "small"):
    script_args.maxtime = 0.1

if unit_test_going():
    script_args.warmup = 1
    script_args.repeat = 1
    script_args.maxtime = 0.1
    script_args.scenario = "small"

print(f"scenario={script_args.scenario or 'small'}")
print(f"warmup={script_args.warmup}")
print(f"repeat={script_args.repeat}")
print(f"maxtime={script_args.maxtime}")
scenario=small
warmup=5
repeat=5
maxtime=0.1

The model

A simple model to convert.

class MyModelClass(nn.Module):
    def __init__(self, scenario=script_args.scenario):
        super().__init__()
        if scenario == "middle":
            self.large = False
            self.conv1 = nn.Conv2d(1, 128, 5)
            self.conv2 = nn.Conv2d(128, 16, 5)
            self.fc1 = nn.Linear(13456, 1024)
            self.fcs = []
            self.fc2 = nn.Linear(1024, 128)
            self.fc3 = nn.Linear(128, 10)
        elif scenario in (None, "small"):
            self.large = False
            self.conv1 = nn.Conv2d(1, 16, 5)
            self.conv2 = nn.Conv2d(16, 16, 5)
            self.fc1 = nn.Linear(16, 512)
            self.fcs = []
            self.fc2 = nn.Linear(512, 128)
            self.fc3 = nn.Linear(128, 10)
        elif scenario in (None, "large"):
            self.large = True
            self.conv1 = nn.Conv2d(1, 128, 5)
            self.conv2 = nn.Conv2d(128, 16, 5)
            self.fc1 = nn.Linear(13456, 4096)
            # torch script does not support loops.
            self.fca = nn.Linear(4096, 4096)
            self.fcb = nn.Linear(4096, 4096)
            self.fcc = nn.Linear(4096, 4096)
            self.fcd = nn.Linear(4096, 4096)
            self.fce = nn.Linear(4096, 4096)
            self.fcf = nn.Linear(4096, 4096)
            self.fcg = nn.Linear(4096, 4096)
            self.fch = nn.Linear(4096, 4096)
            self.fci = nn.Linear(4096, 4096)
            self.fck = nn.Linear(4096, 4096)
            self.fcl = nn.Linear(4096, 4096)
            self.fcm = nn.Linear(4096, 4096)
            self.fcn = nn.Linear(4096, 4096)
            # end of the unfolded loop.
            self.fc2 = nn.Linear(4096, 128)
            self.fc3 = nn.Linear(128, 10)
        else:
            raise ValueError(f"Unsupported scenario={scenario!r}.")

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        if self.large:
            # loop
            x = F.relu(self.fca(x))
            x = F.relu(self.fcb(x))
            x = F.relu(self.fcc(x))
            x = F.relu(self.fcd(x))
            x = F.relu(self.fce(x))
            x = F.relu(self.fcf(x))
            x = F.relu(self.fcg(x))
            x = F.relu(self.fch(x))
            x = F.relu(self.fci(x))
            x = F.relu(self.fck(x))
            x = F.relu(self.fcl(x))
            x = F.relu(self.fcm(x))
            x = F.relu(self.fcn(x))
            # end of the loop
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
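

The in_features of fc1 (16 for the small scenario, 13456 for middle and large) is the number of values left after the two convolution/pooling blocks. A small sketch to recompute it; the helper below is only illustrative and not part of the script:

def flat_features(side, channels=16, kernel=5, pool=2):
    # two Conv2d(kernel_size=5, stride=1) blocks, each followed by 2x2 max pooling
    side = (side - kernel + 1) // pool  # after conv1 + pooling
    side = (side - kernel + 1) // pool  # after conv2 + pooling
    return channels * side * side  # conv2 produces 16 channels


print(flat_features(16), flat_features(128))  # 16, 13456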


def create_model_and_input(scenario=script_args.scenario):
    if scenario == "middle":
        shape = [1, 1, 128, 128]
    elif scenario in (None, "small"):
        shape = [1, 1, 16, 16]
    elif scenario == "large":
        shape = [1, 1, 128, 128]
    else:
        raise ValueError(f"Unsupported scenario={scenario!r}.")
    input_tensor = torch.rand(*shape).to(torch.float32)
    model = MyModelClass(scenario=scenario)
    assert model(input_tensor) is not None
    return model, input_tensor


def torch_model_size(model):
    size_model = 0
    for param in model.parameters():
        size = param.numel() * torch.finfo(param.data.dtype).bits / 8
        size_model += size
    return size_model


model, input_tensor = create_model_and_input()
model_size = torch_model_size(model)
print(f"model size={model_size / 2 ** 20} Mb")
model size=0.31467437744140625 Mb
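
As a sanity check, the printed size is the parameter count times 4 bytes per float32 parameter, 82,490 parameters for the small scenario. A quick way to verify it:

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params} parameters, {n_params * 4 / 2 ** 20} Mb")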

The exporters

def export_script(filename, model, *args):
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            torch.onnx.export(model, *args, filename, input_names=["input"])


def export_dynamo(filename, model, *args):
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            export_output = torch.onnx.export(model, args, dynamo=True)
            export_output.save(filename)


def export_dynopt(filename, model, *args):
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            export_output = torch.onnx.export(model, args, dynamo=True)
            model_onnx = export_output.model_proto

            from experimental_experiment.convert.convert_helper import (
                optimize_model_proto_oxs,
            )

            optimized_model = optimize_model_proto_oxs(model_onnx)

            with open(filename, "wb") as f:
                f.write(optimized_model.SerializeToString())


def export_cus_p0(filename, model, *args):
    onx = to_onnx(model, tuple(args), input_names=["input"])
    with open(filename, "wb") as f:
        f.write(onx.SerializeToString())


def export_cus_p2(filename, model, *args):
    onx = to_onnx(
        model,
        tuple(args),
        input_names=["input"],
        options=OptimizationOptions(
            remove_unused=True,
            constant_folding=True,
        ),
    )
    with open(filename, "wb") as f:
        f.write(onx.SerializeToString())

Let’s check they are working.

export_functions = [
    export_script,
    export_dynamo,
    export_dynopt,
    export_cus_p0,
    export_cus_p2,
]

exporters = {f.__name__.replace("export_", ""): f for f in export_functions}

supported_exporters = {}
for k, v in exporters.items():
    print(f"run exporter {k}")
    filename = f"plot_torch_export_{k}.onnx"
    try:
        v(filename, model, input_tensor)
    except Exception as e:
        print(f"skipped due to {str(e)[:1000]}")
        continue
    supported_exporters[k] = v
    print(f"done. size={os.stat(filename).st_size / 2 ** 20:1.0f} Mb")
run exporter script
done. size=0 Mb
run exporter dynamo
done. size=0 Mb
run exporter dynopt
done. size=0 Mb
run exporter cus_p0
done. size=0 Mb
run exporter cus_p2
done. size=0 Mb
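
The exported files can also be loaded with onnxruntime to verify they produce the same outputs as the torch model. A minimal sketch, not part of the benchmark itself; the input name is read from the session because the dynamo exporter does not necessarily call it "input":

expected = model(input_tensor).detach().numpy()
for name in supported_exporters:
    sess = onnxruntime.InferenceSession(
        f"plot_torch_export_{name}.onnx", providers=["CPUExecutionProvider"]
    )
    feeds = {sess.get_inputs()[0].name: input_tensor.numpy()}
    got = sess.run(None, feeds)[0]
    np.testing.assert_allclose(expected, got, atol=1e-4)
    print(f"{name}: outputs match")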

Exporter memory

def flatten(ps):
    obs = ps["cpu"].to_dict(unit=2**20)
    if "gpus" in ps:
        for i, g in enumerate(ps["gpus"]):
            for k, v in g.to_dict(unit=2**20).items():
                obs[f"gpu{i}_{k}"] = v
    return obs


data = []

for k, v in supported_exporters.items():
    print(f"run exporter for memory {k}")
    filename = f"plot_torch_export_{k}.onnx"
    if has_cuda:
        torch.cuda.set_device(0)
    stat = start_spying_on(cuda=1 if has_cuda else 0)
    v(filename, model, input_tensor)
    obs = flatten(stat.stop())
    print("done.")
    onx = onnx.load(filename)
    obs.update(dict(nodes=len(onx.graph.node), export=k))
    data.append(obs)

stat = start_spying_on(cuda=1 if has_cuda else 0)
exported_mod = torch.export.export(model, (input_tensor,))
obs = flatten(stat.stop())
obs.update(dict(export="torch.fx"))
data.append(obs)
run exporter for memory script
done.
run exporter for memory dynamo
done.
run exporter for memory dynopt
done.
run exporter for memory cus_p0
done.
run exporter for memory cus_p2
done.

The result.

df1 = pandas.DataFrame(data)
df1.to_csv("plot_torch_export_memory.csv", index=False)
df1.to_excel("plot_torch_export_memory.xlsx", index=False)
print(df1)

ax = memory_peak_plot(
    data,
    bars=[model_size * i / 2**20 for i in range(1, 5)],
    suptitle=f"Memory Consumption of the Export\nmodel size={model_size / 2**20:1.0f} Mb",
)
get_figure(ax).savefig("plot_torch_export_memory.png")
Figure: Memory Consumption of the Export (model size=0 Mb); panels: Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb).
          peak         mean   n        begin          end   gpu0_peak   gpu0_mean  gpu0_n  gpu0_begin    gpu0_end  nodes    export
0  3335.367188  3335.367188   5  3335.367188  3335.367188  412.617188  412.617188       5  412.617188  412.617188   12.0    script
1  3335.367188  3335.367188  65  3335.367188  3335.367188  412.617188  412.617188      65  412.617188  412.617188   17.0    dynamo
2  3335.367188  3335.367188  72  3335.367188  3335.367188  412.617188  412.617188      72  412.617188  412.617188   16.0    dynopt
3  3335.367188  3335.367188  19  3335.367188  3335.367188  412.617188  412.617188      19  412.617188  412.617188   12.0    cus_p0
4  3335.371094  3335.367746  21  3335.367188  3335.371094  412.617188  412.617188      21  412.617188  412.617188   12.0    cus_p2
5  3335.371094  3335.371094  18  3335.371094  3335.371094  412.617188  412.617188      18  412.617188  412.617188    NaN  torch.fx
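
The quantity to look at is the difference between the memory peak and the memory measured when the export started; with the small model it is negligible for every exporter. A small sketch to extract it from the dataframe above:

extra = df1[["export"]].copy()
extra["peak_minus_begin_Mb"] = df1["peak"] - df1["begin"]
print(extra)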

Exporter speed

data = []

for k, v in supported_exporters.items():
    print(f"run exporter {k}")
    filename = f"plot_torch_export_{k}.onnx"
    times = []
    for _ in range(script_args.repeat):
        begin = time.perf_counter()
        v(filename, model, input_tensor)
        duration = time.perf_counter() - begin
        times.append(duration)
    onx = onnx.load(filename)
    print("done.")
    data.append(
        dict(
            export=k,
            time=np.mean(times),
            min=min(times),
            max=max(times),
            first=times[0],
            last=times[-1],
            std=np.std(times),
            nodes=len(onx.graph.node),
        )
    )
run exporter script
done.
run exporter dynamo
done.
run exporter dynopt
done.
run exporter cus_p0
done.
run exporter cus_p2
done.

The last measurement is the time torch spends in torch.export.export, a step every exporter except the first one (script) has to run before it can begin the translation to ONNX.

times = []
for _ in range(script_args.repeat):
    begin = time.perf_counter()
    exported_mod = torch.export.export(model, (input_tensor,))
    duration = time.perf_counter() - begin
    times.append(duration)
data.append(
    dict(
        export="torch.fx",
        time=np.mean(times),
        min=min(times),
        max=max(times),
        first=times[0],
        last=times[-1],
        std=np.std(times),
        nodes=len(onx.graph.node),  # onx still refers to the last model loaded in the loop above
    )
)

The result.

df1 = pandas.DataFrame(data)
df1.to_csv("plot_torch_export_time.csv", index=False)
df1.to_excel("plot_torch_export_time.xlsx", index=False)
print(df1)

fig, ax = plt.subplots(1, 1)
dfi = df1[["export", "time", "std"]].set_index("export")
dfi["time"].plot.bar(ax=ax, title="Export time", yerr=dfi["std"], rot=30)
fig.tight_layout()
fig.savefig("plot_torch_export_time.png")
Figure: Export time (bar chart of mean export time with the standard deviation as error bars).
     export      time       min       max     first      last       std  nodes
0    script  0.057934  0.020471  0.105974  0.048177  0.037465  0.030370     12
1    dynamo  0.617479  0.466728  1.113500  1.113500  0.497046  0.248425     17
2    dynopt  0.626486  0.497569  1.085869  0.497569  0.515716  0.229825     16
3    cus_p0  0.152661  0.125809  0.177849  0.169642  0.125809  0.019003     12
4    cus_p2  0.144619  0.127507  0.181575  0.181575  0.127507  0.020863     12
5  torch.fx  0.105241  0.094353  0.116510  0.116510  0.103488  0.009439     12
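
Every exporter except script goes through torch.export.export first, so the torch.fx row is a lower bound for the others; normalizing the other times by it gives a rough idea of how much the translation to ONNX adds on top of the tracing. A small sketch reusing the dataframe above:

fx_time = df1.loc[df1["export"] == "torch.fx", "time"].item()
print((df1.set_index("export")["time"] / fx_time).round(2))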

Exporter Profiling

def clean_text(text):
    paths = [
        os.path.abspath(os.path.normpath(os.path.join(os.path.dirname(torch.__file__), ".."))),
        os.path.abspath(os.path.normpath(os.path.join(os.path.dirname(onnx.__file__), ".."))),
        os.path.abspath(
            os.path.normpath(
                os.path.join(os.path.dirname(experimental_experiment.__file__), "..")
            )
        ),
    ]
    for p in paths:
        text = text.replace(p, "")
    text = text.replace("experimental_experiment", "experimental_experiment".upper())
    return text


def profile_function(name, export_function, verbose=False):
    print(f"profile {name}: {export_function}")
    pr = cProfile.Profile()
    pr.enable()
    for _ in range(script_args.repeat):
        export_function("dummyc.onnx", model, input_tensor)
    pr.disable()
    s = io.StringIO()
    sortby = SortKey.CUMULATIVE
    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
    ps.print_stats()

    raw = s.getvalue()
    text = "\n".join(raw.split("\n")[:200])
    if verbose:
        print(text)
    with open(f"plot_torch_export_profile_{name}.txt", "w") as f:
        f.write(raw)

    root, nodes = profile2graph(ps, clean_text=clean_text)
    text = root.to_text()
    with open(f"plot_torch_export_profile_{name}_h.txt", "w") as f:
        f.write(text)
    print("done.")


profile_function("custom0", export_cus_p0, True)
profile_function("custom2", export_cus_p2)
profile custom0: <function export_cus_p0 at 0x7f0082f1b640>
         1331946 function calls (1294826 primitive calls) in 1.155 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        5    0.004    0.001    1.182    0.236 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_export_201.py:281(export_cus_p0)
        5    0.000    0.000    1.173    0.235 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py:632(to_onnx)
        5    0.000    0.000    0.990    0.198 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py:302(_make_builder_interpreter)
        5    0.000    0.000    0.988    0.198 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py:111(export)
        5    0.000    0.000    0.988    0.198 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/__init__.py:263(export)
        5    0.000    0.000    0.988    0.198 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:997(wrapper)
        5    0.000    0.000    0.988    0.198 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:118(wrapper)
        5    0.000    0.000    0.987    0.197 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1850(_export)
        5    0.000    0.000    0.972    0.194 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1236(_strict_export)
        5    0.001    0.000    0.972    0.194 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1264(_strict_export_lower_to_aten_ir)
        5    0.000    0.000    0.495    0.099 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:691(_export_to_aten_ir)
   120/55    0.000    0.000    0.492    0.009 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1732(_wrapped_call_impl)
   120/55    0.001    0.000    0.492    0.009 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1740(_call_impl)
        5    0.000    0.000    0.452    0.090 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:622(_export_to_torch_ir)
        5    0.000    0.000    0.450    0.090 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1403(inner)
        5    0.000    0.000    0.445    0.089 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1127(aot_export_module)
        5    0.000    0.000    0.444    0.089 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1434(_aot_export_function)
        5    0.000    0.000    0.443    0.089 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:516(create_aot_dispatcher_function)
        5    0.001    0.000    0.438    0.088 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:529(_create_aot_dispatcher_function)
        5    0.000    0.000    0.405    0.081 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:523(_fn)
        5    0.000    0.000    0.347    0.069 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:1331(__call__)
        5    0.000    0.000    0.346    0.069 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:449(__call__)
        5    0.001    0.000    0.345    0.069 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:597(_compile)
        5    0.002    0.000    0.337    0.067 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:689(compile_inner)
        5    0.000    0.000    0.334    0.067 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_utils_internal.py:89(wrapper_function)
        5    0.001    0.000    0.334    0.067 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:709(_compile_inner)
2560/2170    0.004    0.000    0.311    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_stats.py:16(wrapper)
        5    0.000    0.000    0.271    0.054 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:1329(transform_code_object)
        5    0.000    0.000    0.262    0.052 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:204(_fn)
        5    0.000    0.000    0.260    0.052 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:632(transform)
     2345    0.004    0.000    0.255    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1242(__torch_dispatch__)
    15/10    0.000    0.000    0.255    0.025 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:715(_fn)
     2345    0.011    0.000    0.250    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1768(dispatch)
        5    0.000    0.000    0.248    0.050 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:91(aot_dispatch_export)
        5    0.000    0.000    0.247    0.049 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:65(aot_dispatch_base_graph)
      800    0.005    0.000    0.231    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1326(_cached_dispatch_impl)
        5    0.000    0.000    0.230    0.046 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2907(run)
        5    0.000    0.000    0.230    0.046 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1110(run)
      280    0.001    0.000    0.230    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:998(step)
       15    0.002    0.000    0.226    0.015 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:117(run)
  430/425    0.014    0.000    0.222    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:372(__torch_dispatch__)
       10    0.000    0.000    0.222    0.022 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:178(flat_fn)
       10    0.000    0.000    0.221    0.022 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:848(functional_call)
      210    0.001    0.000    0.217    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:210(run_node)
      140    0.001    0.000    0.211    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:6471(run_node)
        5    0.000    0.000    0.204    0.041 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:46(_create_graph)
        5    0.000    0.000    0.202    0.040 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:2170(wrapped)
        5    0.000    0.000    0.202    0.040 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:2108(trace)
        5    0.000    0.000    0.201    0.040 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1999(_trace_inner)
        5    0.000    0.000    0.197    0.039 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_compile.py:22(inner)
        5    0.000    0.000    0.197    0.039 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1131(dispatch_trace)
        5    0.000    0.000    0.184    0.037 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1674(trace)
        5    0.000    0.000    0.183    0.037 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:711(trace)
       60    0.000    0.000    0.182    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:627(wrapper)
       60    0.000    0.000    0.181    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1728(CALL_FUNCTION)
       60    0.000    0.000    0.180    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:941(call_function)
       65    0.000    0.000    0.170    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2079(wrap_fx_proxy)
       65    0.002    0.000    0.170    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2141(wrap_fx_proxy_cls)
        5    0.000    0.000    0.167    0.033 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1181(wrapped)
        5    0.001    0.000    0.152    0.030 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py:4074(to_onnx)
       90    0.000    0.000    0.151    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1700(wrap_fake_exception)
       60    0.001    0.000    0.150    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:2096(get_fake_value)
        5    0.001    0.000    0.147    0.029 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:171(inner)
        5    0.000    0.000    0.147    0.029 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:663(inner_fn)
        5    0.000    0.000    0.147    0.029 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:396(_functionalized_f_helper)
       25    0.001    0.000    0.137    0.005 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:341(call_function)
        5    0.000    0.000    0.135    0.027 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:76(inner_fn)
     1865    0.005    0.000    0.127    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1230(__torch_function__)
     1865    0.002    0.000    0.120    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1259(__torch_function__)
  985/445    0.002    0.000    0.115    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:923(tree_map)
   125/60    0.001    0.000    0.114    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:830(handler)
       70    0.000    0.000    0.111    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:288(call_function)
   125/60    0.001    0.000    0.111    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_library/utils.py:269(handle_dispatch_mode)
 3855/615    0.009    0.000    0.107    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:801(unflatten)
 2250/450    0.003    0.000    0.105    0.000 /usr/lib/python3.10/copy.py:128(deepcopy)
  770/660    0.001    0.000    0.105    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:722(__call__)
       65    0.001    0.000    0.104    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:792(recompile)
       25    0.000    0.000    0.104    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1711(deepcopy_to_fake_tensor)
  460/160    0.001    0.000    0.104    0.001 /usr/lib/python3.10/copy.py:259(_reconstruct)
       25    0.000    0.000    0.104    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1713(<lambda>)
   110/35    0.001    0.000    0.102    0.003 /usr/lib/python3.10/copy.py:227(_deepcopy_dict)
       50    0.001    0.000    0.097    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/parameter.py:63(__deepcopy__)
      725    0.002    0.000    0.096    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1701(_output_from_cache_entry)
      250    0.001    0.000    0.095    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2465(__torch_function__)
       70    0.001    0.000    0.094    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1562(python_code)
      765    0.011    0.000    0.094    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1635(_get_output_tensor_from_cache_entry)
  230/180    0.047    0.000    0.092    0.001 {method 'clone' of 'torch._C.TensorBase' objects}
       50    0.000    0.000    0.090    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:334(call_module)
        5    0.000    0.000    0.089    0.018 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py:4333(optimize)
   120/80    0.000    0.000    0.089    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/functional.py:1693(relu)
      800    0.003    0.000    0.088    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1369(_cache_key)
       80    0.004    0.000    0.088    0.001 {built-in method torch.relu}
        5    0.000    0.000    0.086    0.017 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py:4591(optimize_with_patterns)
        5    0.005    0.001    0.085    0.017 /home/xadupre/github/experimental-experiment/experimental_experiment/xoptim/graph_builder_optim.py:1008(optimize)
 3025/850    0.012    0.000    0.080    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1445(_prep_args_for_hash)
      880    0.002    0.000    0.077    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:1130(tree_map_only)
       70    0.001    0.000    0.077    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1639(_python_code)
       70    0.009    0.000    0.076    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:408(_gen_python_code)
  235/185    0.004    0.000    0.076    0.000 {method 'detach' of 'torch._C.TensorBase' objects}
9750/9290    0.005    0.000    0.075    0.000 {built-in method builtins.next}
       60    0.000    0.000    0.075    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/linear.py:124(forward)
    90/60    0.005    0.000    0.075    0.001 {built-in method torch._C._nn.linear}
       60    0.000    0.000    0.072    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1328(__torch_dispatch__)
       60    0.002    0.000    0.071    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:761(proxy_call)
      230    0.002    0.000    0.070    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:325(from_real_tensor)
      155    0.002    0.000    0.066    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:744(__torch_dispatch__)
      170    0.003    0.000    0.065    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:1588(__call__)
        5    0.000    0.000    0.062    0.012 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:2107(__init__)
      970    0.013    0.000    0.061    0.000 /home/xadupre/github/experimental-experiment/experimental_experiment/xoptim/patterns_api.py:115(enumerate_matches)
   120/90    0.000    0.000    0.059    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_guards.py:296(create)
      110    0.000    0.000    0.058    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py:35(to_fun)
       50    0.001    0.000    0.057    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:490(call_module)
      110    0.001    0.000    0.057    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:228(to_functional)
        5    0.000    0.000    0.057    0.011 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1453(result_capturing_wrapper)
       25    0.000    0.000    0.055    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:811(module_call_wrapper)
       25    0.000    0.000    0.054    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1724(call_module)
    60/30    0.000    0.000    0.054    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/overrides.py:1668(handle_torch_function)
       25    0.000    0.000    0.053    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:813(forward)
        5    0.001    0.000    0.051    0.010 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py:3599(_build_initializers)
       50    0.002    0.000    0.049    0.001 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/model_container.py:60(proto_from_array)
       50    0.000    0.000    0.049    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/parameter.py:40(__new__)
       40    0.000    0.000    0.047    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/conv.py:553(forward)
        5    0.000    0.000    0.047    0.009 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:411(_produce_aten_artifact)
       55    0.000    0.000    0.047    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:161(_to_fun)
       40    0.000    0.000    0.047    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/conv.py:536(_conv_forward)
    60/40    0.004    0.000    0.046    0.001 {built-in method torch.conv2d}
      165    0.000    0.000    0.044    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2364(from_tensor)
4705/4480    0.003    0.000    0.044    0.000 /usr/lib/python3.10/contextlib.py:130(__enter__)
       65    0.004    0.000    0.042    0.001 {built-in method torch.tensor}
        5    0.000    0.000    0.042    0.008 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/functional_call.py:11(functional_call)
        5    0.000    0.000    0.042    0.008 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/utils/stateless.py:264(_functional_call)
       60    0.000    0.000    0.041    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:2146(<lambda>)
       35    0.001    0.000    0.041    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py:876(call_function)
       60    0.000    0.000    0.041    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:2236(run_node)
        5    0.000    0.000    0.041    0.008 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:115(_lazy_forward)
243560/240910    0.033    0.000    0.041    0.000 {built-in method builtins.isinstance}
       10    0.000    0.000    0.039    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:132(reset)
       30    0.000    0.000    0.039    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:74(__init__)
       30    0.000    0.000    0.038    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:77(reset)
      110    0.001    0.000    0.038    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/base.py:366(build)
4705/4480    0.004    0.000    0.038    0.000 /usr/lib/python3.10/contextlib.py:139(__exit__)
       75    0.001    0.000    0.037    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1801(_dispatch_impl)
    60/40    0.000    0.000    0.037    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_jit_internal.py:614(fn)
    60/40    0.000    0.000    0.037    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/functional.py:807(_max_pool2d)
      110    0.001    0.000    0.037    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:371(__call__)
       40    0.002    0.000    0.036    0.001 {built-in method torch.max_pool2d}
     1380    0.005    0.000    0.036    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:634(emit_node)
  610/540    0.002    0.000    0.035    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1935(__setattr__)
       25    0.001    0.000    0.035    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:437(__init__)
      925    0.012    0.000    0.035    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:933(_flatten_into)
      170    0.007    0.000    0.035    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:687(meta_tensor)
       65    0.003    0.000    0.034    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:536(_wrap)
        5    0.000    0.000    0.034    0.007 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:821(call_wrapped)
        5    0.000    0.000    0.034    0.007 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:382(__call__)
     1050    0.012    0.000    0.034    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:658(__new__)
     1195    0.001    0.000    0.034    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:1079(wrapped)
     1305    0.001    0.000    0.033    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:866(tree_flatten)
      240    0.001    0.000    0.033    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/proxy.py:209(create_proxy)
      925    0.009    0.000    0.032    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:953(extract_tensor_metadata)
4710/1305    0.007    0.000    0.032    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:845(_tree_flatten_helper)
       25    0.000    0.000    0.032    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:548(graph)
        5    0.000    0.000    0.031    0.006 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:1715(SHAPE_ENV)
      110    0.002    0.000    0.030    0.000 {built-in method torch._to_functional_tensor}
        5    0.000    0.000    0.030    0.006 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1125(rewrite_signature)
        5    0.000    0.000    0.029    0.006 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2778(__init__)
        5    0.000    0.000    0.028    0.006 /home/xadupre/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py:3985(process)
      120    0.001    0.000    0.027    0.000 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py:144(run_node)
       65    0.000    0.000    0.026    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:593(track_tensor_tree)
   120/65    0.000    0.000    0.025    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:615(wrap_with_proxy)
       55    0.000    0.000    0.025    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:134(<lambda>)
        5    0.000    0.000    0.024    0.005 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:1143(add_python_lambda_leaf_guard_to_root)
     6800    0.004    0.000    0.024    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:146(is_sparse_any)
      250    0.002    0.000    0.023    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/proxy.py:143(create_node)
      170    0.007    0.000    0.023    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:211(describe_tensor)
        5    0.000    0.000    0.023    0.005 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:2414(build_guard_function)
       90    0.001    0.000    0.022    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:1319(ID_MATCH)
      275    0.004    0.000    0.022    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:117(__new__)
       10    0.000    0.000    0.022    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/passes/replace_with_hop_pass_util.py:151(_replace_with_hop_pass_helper)
    30/20    0.001    0.000    0.021    0.001 {built-in method torch.flatten}
 2230/830    0.002    0.000    0.021    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:859(<listcomp>)
     4860    0.003    0.000    0.021    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/node.py:874(map_arg)
       65    0.000    0.000    0.021    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1869(LOAD_ATTR)
        5    0.000    0.000    0.020    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1068(transform)
        5    0.000    0.000    0.020    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:563(transform)
       65    0.000    0.000    0.020    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1862(_load_attr)
     5010    0.007    0.000    0.020    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/node.py:854(__setattr__)
        5    0.000    0.000    0.020    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:125(__enter__)
     30/5    0.000    0.000    0.020    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py:104(realize_all)
        5    0.000    0.000    0.020    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py:136(<dictcomp>)
       10    0.000    0.000    0.020    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py:61(realize)
       60    0.001    0.000    0.020    0.000 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py:1099(call_function)
       10    0.000    0.000    0.020    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py:20(realize)
9265/5415    0.009    0.000    0.020    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/node.py:883(map_aggregate)
       15    0.000    0.000    0.019    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:57(_make_graph_module)
      260    0.002    0.000    0.019    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1104(create_node)
done.
profile custom2: <function export_cus_p2 at 0x7f0082f1b5b0>
done.

Same with the dynamo exporter.

profile_function("dynamo", export_dynamo, verbose=True)
if "dynopt" in supported_exporters:
    profile_function("dynopt", export_dynopt)
profile dynamo: <function export_dynamo at 0x7f0082f1b6d0>
         6825857 function calls (6684312 primitive calls) in 5.269 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        5    0.000    0.000    5.385    1.077 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_export_201.py:256(export_dynamo)
        5    0.001    0.000    5.368    1.074 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/__init__.py:129(export)
        5    0.001    0.000    5.366    1.073 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_compat.py:114(export_compat)
        5    0.000    0.000    3.315    0.663 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_core.py:1006(export)
       10    0.000    0.000    3.094    0.309 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:118(wrapper)
        5    0.022    0.004    2.050    0.410 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_registration.py:133(from_torchlib)
        5    0.000    0.000    1.696    0.339 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py:99(__call__)
        5    0.000    0.000    1.695    0.339 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py:140(_capture)
        5    0.000    0.000    1.695    0.339 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/__init__.py:263(export)
        5    0.000    0.000    1.695    0.339 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:997(wrapper)
        5    0.000    0.000    1.694    0.339 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1850(_export)
        5    0.000    0.000    1.676    0.335 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1236(_strict_export)
        5    0.001    0.000    1.676    0.335 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:1264(_strict_export_lower_to_aten_ir)
        5    0.001    0.000    1.552    0.310 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_core.py:787(_prepare_exported_program_for_export)
        5    0.001    0.000    1.488    0.298 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_fx_passes.py:11(decompose_with_registry)
        5    0.000    0.000    1.399    0.280 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:1052(run_decompositions)
        5    0.027    0.005    1.285    0.257 /home/xadupre/github/onnxscript/onnxscript/_framework_apis/torch_2_5.py:107(get_torchlib_ops)
      920    0.008    0.000    1.252    0.001 /home/xadupre/github/onnxscript/onnxscript/values.py:640(function_ir)
   120/55    0.000    0.000    1.134    0.021 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1732(_wrapped_call_impl)
   120/55    0.001    0.000    1.134    0.021 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1740(_call_impl)
        5    0.000    0.000    1.081    0.216 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:622(_export_to_torch_ir)
        5    0.000    0.000    1.078    0.216 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1403(inner)
       10    0.000    0.000    1.047    0.105 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:691(_export_to_aten_ir)
        5    0.000    0.000    1.031    0.206 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:523(_fn)
       10    0.000    0.000    0.970    0.097 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1127(aot_export_module)
        5    0.000    0.000    0.969    0.194 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:1331(__call__)
        5    0.001    0.000    0.968    0.194 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:449(__call__)
       10    0.000    0.000    0.967    0.097 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1434(_aot_export_function)
        5    0.001    0.000    0.967    0.193 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:597(_compile)
       10    0.000    0.000    0.964    0.096 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:516(create_aot_dispatcher_function)
        5    0.000    0.000    0.958    0.192 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:689(compile_inner)
        5    0.000    0.000    0.957    0.191 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_utils_internal.py:89(wrapper_function)
        5    0.000    0.000    0.957    0.191 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:709(_compile_inner)
       10    0.002    0.000    0.951    0.095 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:529(_create_aot_dispatcher_function)
        5    0.000    0.000    0.932    0.186 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:669(_decompose_exported_program)
        5    0.001    0.000    0.913    0.183 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:322(_decompose_and_get_gm_with_new_signature_constants)
        5    0.000    0.000    0.889    0.178 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:1329(transform_code_object)
        5    0.000    0.000    0.881    0.176 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:204(_fn)
        5    0.000    0.000    0.879    0.176 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:632(transform)
      920    0.006    0.000    0.827    0.001 /home/xadupre/github/onnxscript/onnxscript/_internal/ast_utils.py:16(get_src_and_ast)
     2905    0.061    0.000    0.714    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_schemas.py:429(from_function)
      920    0.001    0.000    0.683    0.001 /usr/lib/python3.10/inspect.py:1133(getsource)
      920    0.019    0.000    0.681    0.001 /usr/lib/python3.10/inspect.py:1112(getsourcelines)
        5    0.000    0.000    0.627    0.125 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2778(__init__)
      920    0.101    0.000    0.607    0.001 /usr/lib/python3.10/inspect.py:1101(getblock)
        5    0.000    0.000    0.605    0.121 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:240(__init__)
        5    0.000    0.000    0.602    0.120 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1107(__init__)
        5    0.000    0.000    0.599    0.120 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:273(__init__)
        5    0.000    0.000    0.599    0.120 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:631(__init__)
        5    0.587    0.117    0.599    0.120 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:169(__init__)
4860/3845    0.007    0.000    0.565    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_stats.py:16(wrapper)
       10    0.000    0.000    0.555    0.056 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:91(aot_dispatch_export)
       10    0.001    0.000    0.554    0.055 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:65(aot_dispatch_base_graph)
1150/1085    0.036    0.000    0.504    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:372(__torch_dispatch__)
    25/15    0.000    0.000    0.498    0.033 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:715(_fn)
       30    0.004    0.000    0.492    0.016 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:117(run)
17590/16040    0.011    0.000    0.483    0.000 {built-in method builtins.next}
        5    0.005    0.001    0.467    0.093 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:269(_split_decomp_table_to_cia_and_python_decomp)
      585    0.002    0.000    0.459    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:210(run_node)
   133550    0.265    0.000    0.457    0.000 /usr/lib/python3.10/tokenize.py:431(_tokenize)
     4020    0.007    0.000    0.451    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1242(__torch_dispatch__)
       10    0.000    0.000    0.448    0.045 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:46(_create_graph)
       20    0.000    0.000    0.448    0.022 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:178(flat_fn)
       20    0.001    0.000    0.446    0.022 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:848(functional_call)
       10    0.000    0.000    0.445    0.044 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:2170(wrapped)
       10    0.000    0.000    0.445    0.044 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:2108(trace)
       10    0.000    0.000    0.443    0.044 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1999(_trace_inner)
     4020    0.021    0.000    0.442    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1768(dispatch)
       10    0.000    0.000    0.436    0.044 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_compile.py:22(inner)
       10    0.000    0.000    0.435    0.044 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1131(dispatch_trace)
        5    0.001    0.000    0.429    0.086 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/utils.py:1074(_collect_all_valid_cia_ops)
      140    0.005    0.000    0.429    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/utils.py:1057(_collect_all_valid_cia_ops_for_namespace)
      380    0.002    0.000    0.425    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:6471(run_node)
8540/7780    0.006    0.000    0.417    0.000 /usr/lib/python3.10/contextlib.py:130(__enter__)
       10    0.000    0.000    0.408    0.041 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1674(trace)
       10    0.001    0.000    0.408    0.041 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:711(trace)
     1230    0.010    0.000    0.407    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1326(_cached_dispatch_impl)
      140    0.162    0.001    0.393    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/utils.py:992(_materialize_cpp_cia_ops)
      920    0.002    0.000    0.388    0.000 /home/xadupre/github/onnxscript/onnxscript/converter.py:1463(translate_function_signature)
      920    0.029    0.000    0.383    0.000 /home/xadupre/github/onnxscript/onnxscript/converter.py:1378(_translate_function_signature_common)
       10    0.000    0.000    0.372    0.037 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1181(wrapped)
       20    0.089    0.004    0.345    0.017 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:188(_override_composite_implicit_decomp)
       10    0.001    0.000    0.338    0.034 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:663(inner_fn)
       10    0.000    0.000    0.337    0.034 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:396(_functionalized_f_helper)
2730/1520    0.002    0.000    0.335    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:722(__call__)
      265    0.000    0.000    0.322    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:288(call_function)
       10    0.002    0.000    0.308    0.031 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:171(inner)
 2495/875    0.006    0.000    0.307    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:923(tree_map)
8910/1230    0.021    0.000    0.289    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:801(unflatten)
       10    0.000    0.000    0.264    0.026 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:76(inner_fn)
     3945    0.009    0.000    0.264    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1230(__torch_function__)
28035/5190    0.065    0.000    0.252    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_schemas.py:267(_get_allowed_types_from_type_annotation)
        5    0.000    0.000    0.251    0.050 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2907(run)
        5    0.000    0.000    0.251    0.050 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1110(run)
      280    0.001    0.000    0.251    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:998(step)
    18795    0.246    0.000    0.246    0.000 {built-in method builtins.compile}
37875/9100    0.042    0.000    0.230    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:131(is_value_type)
     2905    0.033    0.000    0.226    0.000 /usr/lib/python3.10/typing.py:1773(get_type_hints)
    65395    0.183    0.000    0.218    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:106(inner)
      120    0.002    0.000    0.217    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:792(recompile)
       60    0.000    0.000    0.203    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:627(wrapper)
       60    0.000    0.000    0.202    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1728(CALL_FUNCTION)
       60    0.001    0.000    0.201    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:941(call_function)
     2280    0.005    0.000    0.197    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:1130(tree_map_only)
1186605/1179290    0.154    0.000    0.194    0.000 {built-in method builtins.isinstance}
       65    0.000    0.000    0.187    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2079(wrap_fx_proxy)
      125    0.001    0.000    0.187    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1562(python_code)
       65    0.002    0.000    0.187    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:2141(wrap_fx_proxy_cls)
      220    0.001    0.000    0.172    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py:35(to_fun)
      220    0.002    0.000    0.171    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:228(to_functional)
       60    0.001    0.000    0.168    0.003 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:2096(get_fake_value)
       90    0.000    0.000    0.167    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1700(wrap_fake_exception)
     1080    0.004    0.000    0.161    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1701(_output_from_cache_entry)
      685    0.002    0.000    0.158    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1328(__torch_dispatch__)
     1130    0.018    0.000    0.157    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1635(_get_output_tensor_from_cache_entry)
       25    0.001    0.000    0.153    0.006 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:341(call_function)
   522130    0.146    0.000    0.152    0.000 {built-in method builtins.getattr}
      125    0.001    0.000    0.152    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:1639(_python_code)
      125    0.018    0.000    0.151    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:408(_gen_python_code)
     5120    0.003    0.000    0.150    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:172(is_valid_type)
     1230    0.005    0.000    0.150    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1369(_cache_key)
      135    0.005    0.000    0.141    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:761(proxy_call)
4550/1280    0.019    0.000    0.137    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1445(_prep_args_for_hash)
     1865    0.002    0.000    0.127    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1259(__torch_function__)
4485/4005    0.009    0.000    0.126    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:150(<listcomp>)
 2865/460    0.005    0.000    0.125    0.000 /usr/lib/python3.10/copy.py:128(deepcopy)
  265/215    0.006    0.000    0.124    0.001 {method 'detach' of 'torch._C.TensorBase' objects}
   127360    0.122    0.000    0.122    0.000 {method 'match' of 're.Pattern' objects}
   125/60    0.001    0.000    0.121    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:830(handler)
   125/60    0.001    0.000    0.118    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_library/utils.py:269(handle_dispatch_mode)
  630/305    0.002    0.000    0.117    0.000 /usr/lib/python3.10/copy.py:259(_reconstruct)
     1010    0.001    0.000    0.115    0.000 /usr/lib/python3.10/ast.py:33(parse)
       25    0.000    0.000    0.114    0.005 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1711(deepcopy_to_fake_tensor)
     2950    0.002    0.000    0.114    0.000 /usr/lib/python3.10/inspect.py:3252(signature)
       25    0.000    0.000    0.114    0.005 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/utils.py:1713(<lambda>)
   170/70    0.001    0.000    0.113    0.002 /usr/lib/python3.10/copy.py:227(_deepcopy_dict)
48015/47965    0.020    0.000    0.112    0.000 {built-in method builtins.repr}
       70    0.000    0.000    0.112    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py:63(inner)
       70    0.001    0.000    0.112    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_higher_order_ops/utils.py:20(autograd_not_implemented_inner)
     2950    0.003    0.000    0.111    0.000 /usr/lib/python3.10/inspect.py:2998(from_callable)
2985/2950    0.016    0.000    0.108    0.000 /usr/lib/python3.10/inspect.py:2375(_signature_from_callable)
       50    0.000    0.000    0.107    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/interpreter.py:334(call_module)
       50    0.001    0.000    0.106    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/parameter.py:63(__deepcopy__)
      250    0.001    0.000    0.103    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:2465(__torch_function__)
   120/80    0.000    0.000    0.103    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/functional.py:1693(relu)
      110    0.000    0.000    0.102    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:161(_to_fun)
       80    0.004    0.000    0.102    0.001 {built-in method torch.relu}
      345    0.003    0.000    0.101    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:325(from_real_tensor)
    33780    0.015    0.000    0.096    0.000 /home/xadupre/github/onnxscript/onnxscript/ir/_core.py:1368(__hash__)
24030/10760    0.015    0.000    0.096    0.000 /usr/lib/python3.10/typing.py:320(_eval_type)
      220    0.005    0.000    0.095    0.000 {built-in method torch._to_functional_tensor}
      235    0.004    0.000    0.092    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/meta_utils.py:1588(__call__)
    37875    0.028    0.000    0.091    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:123(_is_tensor_type)
       60    0.000    0.000    0.088    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/linear.py:124(forward)
    10760    0.022    0.000    0.088    0.000 /usr/lib/python3.10/typing.py:679(_evaluate)
    90/60    0.006    0.000    0.088    0.001 {built-in method torch._C._nn.linear}
     2805    0.002    0.000    0.087    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:1079(wrapped)
      130    0.008    0.000    0.086    0.001 {built-in method torch.tensor}
     3980    0.002    0.000    0.085    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:168(is_attr_type)
       45    0.001    0.000    0.082    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:437(__init__)
 1125/985    0.004    0.000    0.081    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/nn/modules/module.py:1935(__setattr__)
8540/7780    0.008    0.000    0.080    0.000 /usr/lib/python3.10/contextlib.py:139(__exit__)
      150    0.004    0.000    0.079    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1801(_dispatch_impl)
       20    0.001    0.000    0.079    0.004 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:132(reset)
       60    0.000    0.000    0.077    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:74(__init__)
       60    0.000    0.000    0.077    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_decomp/decompositions_for_rng.py:77(reset)
  380/285    0.002    0.000    0.075    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_ops.py:757(decompose)
      155    0.002    0.000    0.074    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:744(__torch_dispatch__)
       45    0.000    0.000    0.074    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph_module.py:548(graph)
       10    0.001    0.000    0.072    0.007 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_trace.py:411(_produce_aten_artifact)
     6905    0.011    0.000    0.072    0.000 /home/xadupre/github/onnxscript/onnxscript/converter.py:451(_eval_constant_expr)
     2950    0.027    0.000    0.072    0.000 /usr/lib/python3.10/inspect.py:2280(_signature_from_function)
     3020    0.003    0.000    0.071    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:866(tree_flatten)
     2805    0.010    0.000    0.070    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/graph.py:634(emit_node)
    10760    0.010    0.000    0.069    0.000 /usr/lib/python3.10/typing.py:664(__init__)
      140    0.068    0.000    0.068    0.000 {built-in method torch._C._dispatch_get_registrations_for_dispatch_key}
9910/3020    0.016    0.000    0.068    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/utils/_pytree.py:845(_tree_flatten_helper)
        5    0.000    0.000    0.067    0.013 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/guards.py:2107(__init__)
      110    0.000    0.000    0.066    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:134(<lambda>)
    71020    0.041    0.000    0.066    0.000 /usr/lib/python3.10/typing.py:1902(get_origin)
    36805    0.031    0.000    0.066    0.000 /home/xadupre/github/onnxscript/onnxscript/ir/_core.py:1376(__repr__)
        5    0.014    0.003    0.065    0.013 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_decomp.py:15(get_onnx_implemented_overloads)
   120/90    0.000    0.000    0.064    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_guards.py:296(create)
    40260    0.021    0.000    0.064    0.000 /home/xadupre/github/onnxscript/onnxscript/type_annotation.py:70(_remove_annotation)
        5    0.001    0.000    0.063    0.013 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/onnx/_internal/exporter/_core.py:813(_exported_program_to_onnx_program)
       50    0.001    0.000    0.063    0.001 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:490(call_module)
        5    0.000    0.000    0.062    0.012 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1453(result_capturing_wrapper)
       30    0.000    0.000    0.061    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_export/utils.py:1110(_special_op_to_decompose_cia)
        5    0.000    0.000    0.060    0.012 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/exported_program.py:1024(module)
        5    0.000    0.000    0.060    0.012 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/export/_unlift.py:356(_unlift_exported_program_lifted_states)
       25    0.000    0.000    0.060    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:811(module_call_wrapper)
     1430    0.021    0.000    0.059    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:933(_flatten_into)
       25    0.000    0.000    0.059    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1724(call_module)
      535    0.009    0.000    0.058    0.000 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:117(__new__)
       25    0.000    0.000    0.058    0.002 /home/xadupre/vv/this/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:813(forward)
done.
profile dynopt: <function export_dynopt at 0x7f0082f1b760>
done.

Benchmark exported models with ORT

def benchmark(shape):
    from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel

    data = []
    data1 = []
    data_mem_load = []
    data_mem_first_run = []
    data_mem_run = []
    confs = list(
        itertools.product(
            [_ for _ in os.listdir(".") if ".onnx" in _ and _.startswith("plot_torch")],
            [
                ["CPUExecutionProvider"],
                ["CUDAExecutionProvider", "CPUExecutionProvider"],
            ],
            ["0", "1"],
        )
    )
    loop = tqdm(confs)
    print(f"number of experiments: {len(loop)}")
    for name, ps, aot in loop:
        root = os.path.split(name)[-1]
        _, ext = os.path.splitext(root)
        if ext != ".onnx":
            continue

        obs = {}  # system_info()
        obs["name"] = name
        obs["providers"] = ",".join(ps)
        p = "CUDA" if "CUDA" in obs["providers"] else "CPU"
        obs["compute"] = p
        obs["aot"] = 1 if aot == "0" else 0
        obs["export"] = name.replace("plot_torch_export_", "").replace(".onnx", "")

        if not has_cuda and p == "CUDA":
            continue

        onx = onnx.load(name)
        obs["n_nodes"] = len(onx.graph.node)
        obs["n_function"] = len(onx.functions or [])
        obs["n_sub"] = len([n for n in onx.graph.node if n.op_type == "Sub"])
        obs1 = obs.copy()
        short_obs = dict(
            name=obs["name"],
            aot=obs["aot"],
            providers=obs["providers"],
            export=obs["export"],
            compute=obs["compute"],
        )

        # ask onnxruntime to save the optimized graph on disk,
        # these files are inspected at the end of this example
        opts = SessionOptions()
        opts.add_session_config_entry("session.disable_aot_function_inlining", aot)
        opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.optimized_model_filepath = (
            f"ort-{name.replace('.onnx', '')}-{p.lower()}-aot{1 if aot == '0' else 0}.onnx"
        )

        try:
            InferenceSession(name, opts, providers=ps)
        except Exception as e:
            loop.set_description(f"ERROR-load: {name} {e}")
            obs.update({"error": e, "step": "run"})
            data.append(obs)
            continue

        # second session: measure the memory consumed while loading the model
        opts = SessionOptions()
        opts.add_session_config_entry("session.disable_aot_function_inlining", aot)
        opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        stat = start_spying_on(cuda=1 if has_cuda else 0)
        sess = InferenceSession(name, opts, providers=ps)
        memobs = flatten(stat.stop())
        memobs.update(short_obs)
        data_mem_load.append(memobs)

        input_name = sess.get_inputs()[0].name
        feeds = {input_name: np.random.rand(*shape).astype(np.float32)}

        # memory consumed by the first run
        stat = start_spying_on(cuda=1 if has_cuda else 0)
        try:
            sess.run(None, feeds)
        except Exception as e:
            loop.set_description(f"ERROR-run: {name} {e}")
            obs.update({"error": e, "step": "load"})
            data.append(obs)
            stat.stop()
            continue
        memobs = flatten(stat.stop())
        memobs.update(short_obs)
        data_mem_first_run.append(memobs)

        # memory consumed during the warmup runs
        stat = start_spying_on(cuda=1 if has_cuda else 0)
        for _ in range(0, script_args.warmup):
            sess.run(None, feeds)
        memobs = flatten(stat.stop())
        memobs.update(short_obs)
        data_mem_run.append(memobs)

        # inference time once the session is warm
        obs.update(
            measure_time(
                lambda sess=sess, feeds=feeds: sess.run(None, feeds),
                max_time=script_args.maxtime,
                repeat=script_args.repeat,
                number=1,
            )
        )

        loop.set_description(f"{obs['average']} {name} {ps}")
        data.append(obs)

        # measure loading time + first run together (stored separately in data1)
        obs1.update(
            measure_time(
                lambda name=name, opts=opts, ps=ps, feeds=feeds: InferenceSession(
                    name, opts, providers=ps
                ).run(None, feeds),
                max_time=script_args.maxtime,
                repeat=max(1, script_args.repeat // 2),
                number=1,
            )
        )
        data1.append(obs1)

    df = pandas.DataFrame(data)
    df.to_csv("plot_torch_export_ort_time.csv", index=False)
    df.to_excel("plot_torch_export_ort_time.xlsx", index=False)
    df1 = pandas.DataFrame(data1)
    df1.to_csv("plot_torch_export_ort_time1_init.csv", index=False)
    df1.to_excel("plot_torch_export_ort_time1_init.xlsx", index=False)
    dfmem = pandas.DataFrame(data_mem_load)
    dfmem.to_csv("plot_torch_export_ort_load_mem.csv", index=False)
    dfmem.to_excel("plot_torch_export_ort_load_mem.xlsx", index=False)
    dfmemr = pandas.DataFrame(data_mem_run)
    dfmemr.to_csv("plot_torch_export_ort_run_mem.csv", index=False)
    dfmemr.to_excel("plot_torch_export_ort_run_mem.xlsx", index=False)
    dfmemfr = pandas.DataFrame(data_mem_first_run)
    dfmemfr.to_csv("plot_torch_export_ort_first_run_mem.csv", index=False)
    dfmemfr.to_excel("plot_torch_export_ort_first_run_mem.xlsx", index=False)
    return df, df1, dfmem, dfmemfr, dfmemr


df, df_init, dfmem, dfmemfr, dfmemr = benchmark(list(input_tensor.shape))
print(df)
  0%|          | 0/20 [00:00<?, ?it/s]number of experiments: 20

8.653420389262618e-05 plot_torch_export_cus_p2.onnx ['CPUExecutionProvider']:   0%|          | 0/20 [00:00<?, ?it/s]
8.653420389262618e-05 plot_torch_export_cus_p2.onnx ['CPUExecutionProvider']:   5%|▌         | 1/20 [00:00<00:11,  1.67it/s]
7.380301669828995e-05 plot_torch_export_cus_p2.onnx ['CPUExecutionProvider']:   5%|▌         | 1/20 [00:01<00:11,  1.67it/s]
7.380301669828995e-05 plot_torch_export_cus_p2.onnx ['CPUExecutionProvider']:  10%|█         | 2/20 [00:01<00:10,  1.65it/s]
0.0007605597926292965 plot_torch_export_cus_p2.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  10%|█         | 2/20 [00:01<00:10,  1.65it/s]
0.0007605597926292965 plot_torch_export_cus_p2.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  15%|█▌        | 3/20 [00:02<00:12,  1.39it/s]
0.0009669082782119918 plot_torch_export_cus_p2.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  15%|█▌        | 3/20 [00:03<00:12,  1.39it/s]
0.0009669082782119918 plot_torch_export_cus_p2.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  20%|██        | 4/20 [00:03<00:15,  1.01it/s]
0.00030176520232021045 plot_torch_export_dynopt.onnx ['CPUExecutionProvider']:  20%|██        | 4/20 [00:04<00:15,  1.01it/s]
0.00030176520232021045 plot_torch_export_dynopt.onnx ['CPUExecutionProvider']:  25%|██▌       | 5/20 [00:04<00:17,  1.15s/it]
0.0003192254260644395 plot_torch_export_dynopt.onnx ['CPUExecutionProvider']:  25%|██▌       | 5/20 [00:06<00:17,  1.15s/it]
0.0003192254260644395 plot_torch_export_dynopt.onnx ['CPUExecutionProvider']:  30%|███       | 6/20 [00:06<00:16,  1.21s/it]
0.0019285396296321847 plot_torch_export_dynopt.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  30%|███       | 6/20 [00:07<00:16,  1.21s/it]
0.0019285396296321847 plot_torch_export_dynopt.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  35%|███▌      | 7/20 [00:07<00:16,  1.28s/it]
0.0019344716316130903 plot_torch_export_dynopt.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  35%|███▌      | 7/20 [00:08<00:16,  1.28s/it]
0.0019344716316130903 plot_torch_export_dynopt.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  40%|████      | 8/20 [00:08<00:14,  1.24s/it]
0.0002802157974610388 plot_torch_export_dynamo.onnx ['CPUExecutionProvider']:  40%|████      | 8/20 [00:09<00:14,  1.24s/it]
0.0002802157974610388 plot_torch_export_dynamo.onnx ['CPUExecutionProvider']:  45%|████▌     | 9/20 [00:09<00:13,  1.20s/it]
0.0003851877096741031 plot_torch_export_dynamo.onnx ['CPUExecutionProvider']:  45%|████▌     | 9/20 [00:10<00:13,  1.20s/it]
0.0003851877096741031 plot_torch_export_dynamo.onnx ['CPUExecutionProvider']:  50%|█████     | 10/20 [00:11<00:12,  1.20s/it]
0.002496057195992976 plot_torch_export_dynamo.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  50%|█████     | 10/20 [00:12<00:12,  1.20s/it]
0.002496057195992976 plot_torch_export_dynamo.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  55%|█████▌    | 11/20 [00:12<00:10,  1.22s/it]
0.005621908095539159 plot_torch_export_dynamo.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  55%|█████▌    | 11/20 [00:13<00:10,  1.22s/it]
0.005621908095539159 plot_torch_export_dynamo.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  60%|██████    | 12/20 [00:13<00:10,  1.30s/it]
0.00017777336808137624 plot_torch_export_script.onnx ['CPUExecutionProvider']:  60%|██████    | 12/20 [00:14<00:10,  1.30s/it]
0.00017777336808137624 plot_torch_export_script.onnx ['CPUExecutionProvider']:  65%|██████▌   | 13/20 [00:15<00:10,  1.45s/it]
0.00019911070765523612 plot_torch_export_script.onnx ['CPUExecutionProvider']:  65%|██████▌   | 13/20 [00:16<00:10,  1.45s/it]
0.00019911070765523612 plot_torch_export_script.onnx ['CPUExecutionProvider']:  70%|███████   | 14/20 [00:16<00:07,  1.29s/it]
0.001263198761929137 plot_torch_export_script.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  70%|███████   | 14/20 [00:18<00:07,  1.29s/it]
0.001263198761929137 plot_torch_export_script.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  75%|███████▌  | 15/20 [00:18<00:07,  1.46s/it]
0.001176670557853006 plot_torch_export_script.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  75%|███████▌  | 15/20 [00:19<00:07,  1.46s/it]
0.001176670557853006 plot_torch_export_script.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  80%|████████  | 16/20 [00:19<00:05,  1.32s/it]
0.00027771701291297383 plot_torch_export_cus_p0.onnx ['CPUExecutionProvider']:  80%|████████  | 16/20 [00:20<00:05,  1.32s/it]
0.00027771701291297383 plot_torch_export_cus_p0.onnx ['CPUExecutionProvider']:  85%|████████▌ | 17/20 [00:20<00:03,  1.29s/it]
0.0002739880473662899 plot_torch_export_cus_p0.onnx ['CPUExecutionProvider']:  85%|████████▌ | 17/20 [00:21<00:03,  1.29s/it]
0.0002739880473662899 plot_torch_export_cus_p0.onnx ['CPUExecutionProvider']:  90%|█████████ | 18/20 [00:21<00:02,  1.24s/it]
0.0017883641690197795 plot_torch_export_cus_p0.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  90%|█████████ | 18/20 [00:22<00:02,  1.24s/it]
0.0017883641690197795 plot_torch_export_cus_p0.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  95%|█████████▌| 19/20 [00:23<00:01,  1.26s/it]
0.0016247789135377388 plot_torch_export_cus_p0.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']:  95%|█████████▌| 19/20 [00:24<00:01,  1.26s/it]
0.0016247789135377388 plot_torch_export_cus_p0.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']: 100%|██████████| 20/20 [00:24<00:00,  1.25s/it]
0.0016247789135377388 plot_torch_export_cus_p0.onnx ['CUDAExecutionProvider', 'CPUExecutionProvider']: 100%|██████████| 20/20 [00:24<00:00,  1.22s/it]
                             name                                   providers compute  aot  export  n_nodes  n_function  n_sub   average  deviation  min_exec  max_exec  repeat  number     ttime  context_size  warmup_time
0   plot_torch_export_cus_p2.onnx                        CPUExecutionProvider     CPU    1  cus_p2       12           0      0  0.000087   0.000010  0.000055  0.000093       1  1491.0  0.129022            64     0.000295
1   plot_torch_export_cus_p2.onnx                        CPUExecutionProvider     CPU    0  cus_p2       12           0      0  0.000074   0.000002  0.000054  0.000093       1  1497.0  0.110483            64     0.000296
2   plot_torch_export_cus_p2.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    1  cus_p2       12           0      0  0.000761   0.000135  0.000661  0.000986       1   135.0  0.102676            64     0.002943
3   plot_torch_export_cus_p2.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    0  cus_p2       12           0      0  0.000967   0.000227  0.000843  0.002806       1   115.0  0.111194            64     0.004135
4   plot_torch_export_dynopt.onnx                        CPUExecutionProvider     CPU    1  dynopt       16           0      0  0.000302   0.000030  0.000272  0.000458       1   519.0  0.156616            64     0.001024
5   plot_torch_export_dynopt.onnx                        CPUExecutionProvider     CPU    0  dynopt       16           0      0  0.000319   0.000021  0.000202  0.000348       1   399.0  0.127371            64     0.001097
6   plot_torch_export_dynopt.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    1  dynopt       16           0      0  0.001929   0.000054  0.001597  0.001976       1    54.0  0.104141            64     0.003662
7   plot_torch_export_dynopt.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    0  dynopt       16           0      0  0.001934   0.000302  0.001513  0.002162       1    57.0  0.110265            64     0.002720
8   plot_torch_export_dynamo.onnx                        CPUExecutionProvider     CPU    1  dynamo       17           2      0  0.000280   0.000033  0.000254  0.000502       1   395.0  0.110685            64     0.000930
9   plot_torch_export_dynamo.onnx                        CPUExecutionProvider     CPU    0  dynamo       17           2      0  0.000385   0.000081  0.000297  0.000579       1   279.0  0.107467            64     0.000834
10  plot_torch_export_dynamo.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    1  dynamo       17           2      0  0.002496   0.000154  0.001779  0.002560       1    51.0  0.127299            64     0.004430
11  plot_torch_export_dynamo.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    0  dynamo       17           2      0  0.005622   0.001042  0.003066  0.006478       1    21.0  0.118060            64     0.003444
12  plot_torch_export_script.onnx                        CPUExecutionProvider     CPU    1  script       12           0      0  0.000178   0.000012  0.000144  0.000186       1   633.0  0.112531            64     0.000657
13  plot_torch_export_script.onnx                        CPUExecutionProvider     CPU    0  script       12           0      0  0.000199   0.000008  0.000146  0.000241       1   561.0  0.111701            64     0.000890
14  plot_torch_export_script.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    1  script       12           0      0  0.001263   0.000213  0.001002  0.001935       1   105.0  0.132636            64     0.003027
15  plot_torch_export_script.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    0  script       12           0      0  0.001177   0.000141  0.001073  0.001796       1    95.0  0.111784            64     0.002309
16  plot_torch_export_cus_p0.onnx                        CPUExecutionProvider     CPU    1  cus_p0       12           0      0  0.000278   0.000053  0.000206  0.000797       1   465.0  0.129138            64     0.001195
17  plot_torch_export_cus_p0.onnx                        CPUExecutionProvider     CPU    0  cus_p0       12           0      0  0.000274   0.000034  0.000150  0.000298       1   570.0  0.156173            64     0.000731
18  plot_torch_export_cus_p0.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    1  cus_p0       12           0      0  0.001788   0.000126  0.001708  0.002007       1    71.0  0.126974            64     0.002796
19  plot_torch_export_cus_p0.onnx  CUDAExecutionProvider,CPUExecutionProvider    CUDA    0  cus_p0       12           0      0  0.001625   0.000305  0.001214  0.003470       1    81.0  0.131607            64     0.004696
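
As a quick check, the fastest configuration for every exporter can be extracted from the dataframe. A minimal sketch, not part of the original script, assuming every run succeeded and relying on the columns shown above:

# keep the row with the smallest average latency for each exporter
fastest = df.sort_values("average").groupby("export").head(1)
print(fastest[["export", "compute", "aot", "average"]])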

Other view

def view_time(df, title, suffix="time"):
    piv = pandas.pivot_table(df, index="export", columns=["compute", "aot"], values="average")
    print(piv)
    piv.to_csv(f"plot_torch_export_ort_{suffix}_compute.csv")
    piv.to_excel(f"plot_torch_export_ort_{suffix}_compute.xlsx")

    piv_cpu = pandas.pivot_table(
        df[df.compute == "CPU"],
        index="export",
        columns=["compute", "aot"],
        values="average",
    )

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    fig.suptitle(title)
    piv_cpu.plot.barh(ax=ax[0], title="CPU")

    if has_cuda:
        piv_gpu = pandas.pivot_table(
            df[df.compute == "CUDA"],
            index="export",
            columns=["compute", "aot"],
            values="average",
        )
        piv_gpu.plot.barh(ax=ax[1], title="CUDA")

    fig.tight_layout()
    fig.savefig(f"plot_torch_export_ort_{suffix}.png")
    return ax


view_time(df, "Compares onnxruntime time on exported models")
Compares onnxruntime time on exported models, CPU, CUDA
compute       CPU                CUDA
aot             0         1         0         1
export
cus_p0   0.000274  0.000278  0.001625  0.001788
cus_p2   0.000074  0.000087  0.000967  0.000761
dynamo   0.000385  0.000280  0.005622  0.002496
dynopt   0.000319  0.000302  0.001934  0.001929
script   0.000199  0.000178  0.001177  0.001263

array([<Axes: title={'center': 'CPU'}, ylabel='export'>,
       <Axes: title={'center': 'CUDA'}, ylabel='export'>], dtype=object)
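
The pivot table also makes it easy to compare every exporter to the TorchScript-based one. A minimal sketch, not part of the original script, recomputing the pivot shown above:

piv = pandas.pivot_table(df, index="export", columns=["compute", "aot"], values="average")
# ratio > 1 means slower than the script exporter for that (compute, aot) pair
print(piv / piv.loc["script"])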

The same graph again, hiding the configurations with the longest times (dynamo and dynopt when AOT is disabled).

piv_cpu = pandas.pivot_table(
    df[
        (df.compute == "CPU")
        & ((df.aot == 1) | ((df.export != "dynamo") & (df.export != "dynopt")))
    ],
    index="export",
    columns=["compute", "aot"],
    values="average",
)

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
fig.suptitle("Compares onnxruntime time on exported models\nHide dynamo without AOT")
piv_cpu.plot.barh(ax=ax[0], title="CPU")

if has_cuda:
    piv_gpu = pandas.pivot_table(
        df[df.compute == "CUDA"],
        index="export",
        columns=["compute", "aot"],
        values="average",
    )
    piv_gpu.plot.barh(ax=ax[1], title="CUDA")

fig.tight_layout()
fig.savefig("plot_torch_export_ort_time_2.png")
Compares onnxruntime time on exported models Hide dynamo without AOT, CPU, CUDA

Let’s do the same with the loading time + the first run.

view_time(
    df_init,
    "Compares onnxruntime loading time and first run on exported models",
    suffix="time1_init",
)
Compares onnxruntime loading time and first run on exported models, CPU, CUDA
compute       CPU                CUDA
aot             0         1         0         1
export
cus_p0   0.030264  0.029011  0.046317  0.287893
cus_p2   0.010272  0.014303  0.035720  0.018086
dynamo   0.035333  0.034049  0.093727  0.049105
dynopt   0.129071  0.034321  0.042326  0.043476
script   0.019742  0.209895  0.037691  0.036885

array([<Axes: title={'center': 'CPU'}, ylabel='export'>,
       <Axes: title={'center': 'CUDA'}, ylabel='export'>], dtype=object)

Memory Loading Time (ORT)

for compute in ["CPU", "CUDA"]:
    if not has_cuda and compute == "CUDA":
        continue
    ax = memory_peak_plot(
        dfmem[dfmem.compute == compute],
        ("export", "aot"),
        suptitle=f"Memory Consumption of onnxruntime loading time\nrunning on {compute}",
        bars=[model_size * i / 2**20 for i in range(1, 3)],
        figsize=(18, 6),
    )
    get_figure(ax).savefig(f"plot_torch_export_ort_load_mem_{compute}.png")
  • Memory Consumption of onnxruntime loading time running on CPU, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)
  • Memory Consumption of onnxruntime loading time running on CUDA, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)
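
The plots can be complemented with a numeric summary. A minimal sketch, assuming the dictionaries returned by flatten contain columns whose names include "peak", as the plot legends above suggest:

peak_cols = [c for c in dfmem.columns if "peak" in c.lower()]
print(dfmem.groupby(["export", "aot"])[peak_cols].max())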

Memory First Running Time (ORT)

for compute in ["CPU", "CUDA"]:
    if not has_cuda and compute == "CUDA":
        continue
    ax = memory_peak_plot(
        dfmemfr[dfmemfr.compute == compute],
        ("export", "aot"),
        suptitle=f"Memory Consumption of onnxruntime first running time"
        f"\nrunning on {compute}",
        bars=[model_size * i / 2**20 for i in range(1, 3)],
        figsize=(18, 6),
    )
    get_figure(ax).savefig(f"plot_torch_export_ort_first_run_mem_{compute}.png")
  • Memory Consumption of onnxruntime first running time running on CPU, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)
  • Memory Consumption of onnxruntime first running time running on CUDA, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)

Memory Running Time (ORT)

for compute in ["CPU", "CUDA"]:
    if not has_cuda and compute == "CUDA":
        continue
    ax = memory_peak_plot(
        dfmemr[dfmemr.compute == compute],
        ("export", "aot"),
        suptitle=f"Memory Consumption of onnxruntime running time\nrunning on {compute}",
        bars=[model_size * i / 2**20 for i in range(1, 3)],
        figsize=(18, 6),
    )
    get_figure(ax).savefig(f"plot_torch_export_ort_run_mem_{compute}.png")
  • Memory Consumption of onnxruntime running time running on CPU, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)
  • Memory Consumption of onnxruntime running time running on CUDA, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb), GPU Memory peak (Mb), GPU Memory peak - memory begin (Mb), GPU Memory average - memory begin (Mb)

Show the interesting models for CPU

script

model = "ort-plot_torch_export_cus_p2-cpu-aot0.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='reorder' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='p_conv1_bias' type=dtype('float32') shape=(16,)
init: name='reorder_token_1' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='p_conv2_bias' type=dtype('float32') shape=(16,)
init: name='p_fc1_bias' type=dtype('float32') shape=(512,)
init: name='p_fc2_bias' type=dtype('float32') shape=(128,)
init: name='p_fc3_bias' type=dtype('float32') shape=(10,)
init: name='init7_s2_1_16' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='GemmTransposePattern--_onx_transpose0' type=dtype('float32') shape=(512, 16)
init: name='GemmTransposePattern--_onx_transpose02' type=dtype('float32') shape=(128, 512)
init: name='GemmTransposePattern--_onx_transpose03' type=dtype('float32') shape=(10, 128)
Conv[com.microsoft.nchwc](input, reorder, p_conv1_bias, activation=b'Relu', dilations=[1,1], group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET') -> reorder_token_0
  ReorderOutput[com.microsoft.nchwc](reorder_token_0, channels_last=0, channels=16) -> relu
    MaxPool(relu, storage_order=0, auto_pad=b'NOTSET', ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool0, _onx_maxpool1
      ReorderInput[com.microsoft.nchwc](_onx_maxpool0, channels_last=0) -> reorder_token_2
        Conv[com.microsoft.nchwc](reorder_token_2, reorder_token_1, p_conv2_bias, activation=b'Relu', dilations=[1,1], group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET') -> reorder_token_3
          ReorderOutput[com.microsoft.nchwc](reorder_token_3, channels_last=0, channels=16) -> relu_1
            MaxPool(relu_1, storage_order=0, auto_pad=b'NOTSET', ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool02, _onx_maxpool12
              Reshape(_onx_maxpool02, init7_s2_1_16, allowzero=0) -> view
                FusedGemm[com.microsoft](view, GemmTransposePattern--_onx_transpose0, p_fc1_bias, transA=0, beta=1.00, activation=b'Relu', transB=1, alpha=1.00) -> relu_2
                  FusedGemm[com.microsoft](relu_2, GemmTransposePattern--_onx_transpose02, p_fc2_bias, transA=0, beta=1.00, activation=b'Relu', transB=1, alpha=1.00) -> relu_3
                    Gemm(relu_3, GemmTransposePattern--_onx_transpose03, p_fc3_bias, transA=0, beta=1.00, transB=1, alpha=1.00) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 10]

cus_p2

model = "ort-plot_torch_export_cus_p2-cpu-aot0.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='reorder' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='p_conv1_bias' type=dtype('float32') shape=(16,)
init: name='reorder_token_1' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='p_conv2_bias' type=dtype('float32') shape=(16,)
init: name='p_fc1_bias' type=dtype('float32') shape=(512,)
init: name='p_fc2_bias' type=dtype('float32') shape=(128,)
init: name='p_fc3_bias' type=dtype('float32') shape=(10,)
init: name='init7_s2_1_16' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='GemmTransposePattern--_onx_transpose0' type=dtype('float32') shape=(512, 16)
init: name='GemmTransposePattern--_onx_transpose02' type=dtype('float32') shape=(128, 512)
init: name='GemmTransposePattern--_onx_transpose03' type=dtype('float32') shape=(10, 128)
Conv[com.microsoft.nchwc](input, reorder, p_conv1_bias, activation=b'Relu', dilations=[1,1], group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET') -> reorder_token_0
  ReorderOutput[com.microsoft.nchwc](reorder_token_0, channels_last=0, channels=16) -> relu
    MaxPool(relu, storage_order=0, auto_pad=b'NOTSET', ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool0, _onx_maxpool1
      ReorderInput[com.microsoft.nchwc](_onx_maxpool0, channels_last=0) -> reorder_token_2
        Conv[com.microsoft.nchwc](reorder_token_2, reorder_token_1, p_conv2_bias, activation=b'Relu', dilations=[1,1], group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET') -> reorder_token_3
          ReorderOutput[com.microsoft.nchwc](reorder_token_3, channels_last=0, channels=16) -> relu_1
            MaxPool(relu_1, storage_order=0, auto_pad=b'NOTSET', ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool02, _onx_maxpool12
              Reshape(_onx_maxpool02, init7_s2_1_16, allowzero=0) -> view
                FusedGemm[com.microsoft](view, GemmTransposePattern--_onx_transpose0, p_fc1_bias, transA=0, beta=1.00, activation=b'Relu', transB=1, alpha=1.00) -> relu_2
                  FusedGemm[com.microsoft](relu_2, GemmTransposePattern--_onx_transpose02, p_fc2_bias, transA=0, beta=1.00, activation=b'Relu', transB=1, alpha=1.00) -> relu_3
                    Gemm(relu_3, GemmTransposePattern--_onx_transpose03, p_fc3_bias, transA=0, beta=1.00, transB=1, alpha=1.00) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 10]

dynopt

model = "ort-plot_torch_export_dynopt-cpu-aot1.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
input: name='x' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='reorder' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='conv1.bias' type=dtype('float32') shape=(16,)
init: name='reorder_token_2' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='conv2.bias' type=dtype('float32') shape=(16,)
init: name='fc1.bias' type=dtype('float32') shape=(512,)
init: name='fc2.bias' type=dtype('float32') shape=(128,)
init: name='fc3.bias' type=dtype('float32') shape=(10,)
init: name='val_3' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='t' type=dtype('float32') shape=(16, 512)
init: name='t_1' type=dtype('float32') shape=(512, 128)
init: name='t_2' type=dtype('float32') shape=(128, 10)
Conv[com.microsoft.nchwc](x, reorder, conv1.bias, activation=b'Relu', group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET', dilations=[1,1]) -> reorder_token_0
  MaxPool[com.microsoft.nchwc](reorder_token_0, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> reorder_token_1
    Conv[com.microsoft.nchwc](reorder_token_1, reorder_token_2, conv2.bias, activation=b'Relu', group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET', dilations=[1,1]) -> reorder_token_3
      MaxPool[com.microsoft.nchwc](reorder_token_3, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> reorder_token_4
        ReorderOutput[com.microsoft.nchwc](reorder_token_4, channels_last=0, channels=16) -> max_pool2d_1
          Reshape(max_pool2d_1, val_3, allowzero=0) -> view
            FusedGemm[com.microsoft](view, t, fc1.bias, transA=0, alpha=1.00, activation=b'Relu', transB=0, beta=1.00) -> relu_2
              FusedGemm[com.microsoft](relu_2, t_1, fc2.bias, transA=0, alpha=1.00, activation=b'Relu', transB=0, beta=1.00) -> relu_3
                Gemm(relu_3, t_2, fc3.bias, transA=0, alpha=1.00, transB=0, beta=1.00) -> addmm_2
output: name='addmm_2' type=dtype('float32') shape=[1, 10]

dynamo

model = "ort-plot_torch_export_dynamo-cpu-aot1.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
input: name='x' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='reorder' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='conv1.bias' type=dtype('float32') shape=(16,)
init: name='reorder_token_1' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='conv2.bias' type=dtype('float32') shape=(16,)
init: name='fc1.weight' type=dtype('float32') shape=(512, 16)
init: name='fc1.bias' type=dtype('float32') shape=(512,)
init: name='fc2.weight' type=dtype('float32') shape=(128, 512)
init: name='fc2.bias' type=dtype('float32') shape=(128,)
init: name='fc3.weight' type=dtype('float32') shape=(10, 128)
init: name='fc3.bias' type=dtype('float32') shape=(10,)
init: name='val_2' type=dtype('int64') shape=(2,) -- array([ 1, 16])
Conv[com.microsoft.nchwc](x, reorder, conv1.bias, activation=b'Relu', group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET', dilations=[1,1]) -> reorder_token_0
  ReorderOutput[com.microsoft.nchwc](reorder_token_0, channels_last=0, channels=16) -> relu
    MaxPool(relu, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> max_pool2d, val_0
      ReorderInput[com.microsoft.nchwc](max_pool2d, channels_last=0) -> reorder_token_2
        Conv[com.microsoft.nchwc](reorder_token_2, reorder_token_1, conv2.bias, activation=b'Relu', group=1, strides=[1,1], pads=[0,0,0,0], auto_pad=b'NOTSET', dilations=[1,1]) -> reorder_token_3
          ReorderOutput[com.microsoft.nchwc](reorder_token_3, channels_last=0, channels=16) -> relu_1
            MaxPool(relu_1, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> max_pool2d_1, val_1
              Reshape(max_pool2d_1, val_2, allowzero=0) -> view
                FusedGemm[com.microsoft](view, fc1.weight, fc1.bias, activation=b'Relu', beta=1.00, transB=1, alpha=1.00, transA=0) -> relu_2
                  FusedGemm[com.microsoft](relu_2, fc2.weight, fc2.bias, activation=b'Relu', beta=1.00, transB=1, alpha=1.00, transA=0) -> relu_3
                    Gemm(relu_3, fc3.weight, fc3.bias, beta=1.00, transB=1, alpha=1.00, transA=0) -> addmm_2
output: name='addmm_2' type=dtype('float32') shape=[1, 10]
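
To quantify how different these optimized graphs are, the operator types can be counted for each model onnxruntime saved on disk. A minimal sketch, not part of the original script:

import collections
import glob

for path in sorted(glob.glob("ort-plot_torch_export_*-cpu-aot*.onnx")):
    ops = collections.Counter(n.op_type for n in onnx.load(path).graph.node)
    print(path, dict(sorted(ops.items())))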

Show the interesting models for CUDA

script

model = "ort-plot_torch_export_cus_p2-cuda-aot0.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='p_conv1_weight' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='p_conv1_bias' type=dtype('float32') shape=(16,)
init: name='p_conv2_weight' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='p_conv2_bias' type=dtype('float32') shape=(16,)
init: name='p_fc1_bias' type=dtype('float32') shape=(512,)
init: name='p_fc2_bias' type=dtype('float32') shape=(128,)
init: name='p_fc3_bias' type=dtype('float32') shape=(10,)
init: name='init7_s2_1_16' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='GemmTransposePattern--_onx_transpose0' type=dtype('float32') shape=(512, 16)
init: name='GemmTransposePattern--_onx_transpose02' type=dtype('float32') shape=(128, 512)
init: name='GemmTransposePattern--_onx_transpose03' type=dtype('float32') shape=(10, 128)
Conv(input, p_conv1_weight, p_conv1_bias, dilations=[1,1], group=1, pads=[0,0,0,0], strides=[1,1]) -> conv2d
  Relu(conv2d) -> relu
    MaxPool(relu, ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool0, _onx_maxpool1
      Conv(_onx_maxpool0, p_conv2_weight, p_conv2_bias, dilations=[1,1], group=1, pads=[0,0,0,0], strides=[1,1]) -> conv2d_1
        Relu(conv2d_1) -> relu_1
          MaxPool(relu_1, ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool02, _onx_maxpool12
            Reshape(_onx_maxpool02, init7_s2_1_16) -> view
              Gemm(view, GemmTransposePattern--_onx_transpose0, p_fc1_bias, transB=1) -> linear
                Relu(linear) -> relu_2
                  Gemm(relu_2, GemmTransposePattern--_onx_transpose02, p_fc2_bias, transB=1) -> linear_1
                    Relu(linear_1) -> relu_3
                      Gemm(relu_3, GemmTransposePattern--_onx_transpose03, p_fc3_bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 10]

cus_p2

model = "ort-plot_torch_export_cus_p2-cuda-aot0.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='p_conv1_weight' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='p_conv1_bias' type=dtype('float32') shape=(16,)
init: name='p_conv2_weight' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='p_conv2_bias' type=dtype('float32') shape=(16,)
init: name='p_fc1_bias' type=dtype('float32') shape=(512,)
init: name='p_fc2_bias' type=dtype('float32') shape=(128,)
init: name='p_fc3_bias' type=dtype('float32') shape=(10,)
init: name='init7_s2_1_16' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='GemmTransposePattern--_onx_transpose0' type=dtype('float32') shape=(512, 16)
init: name='GemmTransposePattern--_onx_transpose02' type=dtype('float32') shape=(128, 512)
init: name='GemmTransposePattern--_onx_transpose03' type=dtype('float32') shape=(10, 128)
Conv(input, p_conv1_weight, p_conv1_bias, dilations=[1,1], group=1, pads=[0,0,0,0], strides=[1,1]) -> conv2d
  Relu(conv2d) -> relu
    MaxPool(relu, ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool0, _onx_maxpool1
      Conv(_onx_maxpool0, p_conv2_weight, p_conv2_bias, dilations=[1,1], group=1, pads=[0,0,0,0], strides=[1,1]) -> conv2d_1
        Relu(conv2d_1) -> relu_1
          MaxPool(relu_1, ceil_mode=0, dilations=[1,1], kernel_shape=[2,2], pads=[0,0,0,0], strides=[2,2]) -> _onx_maxpool02, _onx_maxpool12
            Reshape(_onx_maxpool02, init7_s2_1_16) -> view
              Gemm(view, GemmTransposePattern--_onx_transpose0, p_fc1_bias, transB=1) -> linear
                Relu(linear) -> relu_2
                  Gemm(relu_2, GemmTransposePattern--_onx_transpose02, p_fc2_bias, transB=1) -> linear_1
                    Relu(linear_1) -> relu_3
                      Gemm(relu_3, GemmTransposePattern--_onx_transpose03, p_fc3_bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=[1, 10]

dynopt

model = "ort-plot_torch_export_dynopt-cuda-aot1.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
input: name='x' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='conv1.weight' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='conv1.bias' type=dtype('float32') shape=(16,)
init: name='conv2.weight' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='conv2.bias' type=dtype('float32') shape=(16,)
init: name='fc1.bias' type=dtype('float32') shape=(512,)
init: name='fc2.bias' type=dtype('float32') shape=(128,)
init: name='fc3.bias' type=dtype('float32') shape=(10,)
init: name='val_3' type=dtype('int64') shape=(2,) -- array([ 1, 16])
init: name='t' type=dtype('float32') shape=(16, 512)
init: name='t_1' type=dtype('float32') shape=(512, 128)
init: name='t_2' type=dtype('float32') shape=(128, 10)
Conv(x, conv1.weight, conv1.bias, group=1, pads=[0,0,0,0], auto_pad=b'NOTSET', strides=[1,1], dilations=[1,1]) -> conv2d
  Relu(conv2d) -> relu
    MaxPool(relu, storage_order=0, dilations=[1,1], ceil_mode=0, pads=[0,0,0,0], auto_pad=b'NOTSET', strides=[2,2], kernel_shape=[2,2]) -> max_pool2d
      Conv(max_pool2d, conv2.weight, conv2.bias, group=1, pads=[0,0,0,0], auto_pad=b'NOTSET', strides=[1,1], dilations=[1,1]) -> conv2d_1
        Relu(conv2d_1) -> relu_1
          MaxPool(relu_1, storage_order=0, dilations=[1,1], ceil_mode=0, pads=[0,0,0,0], auto_pad=b'NOTSET', strides=[2,2], kernel_shape=[2,2]) -> max_pool2d_1
            Reshape(max_pool2d_1, val_3, allowzero=0) -> view
              Gemm(view, t, fc1.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm
                Relu(addmm) -> relu_2
                  Gemm(relu_2, t_1, fc2.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_1
                    Relu(addmm_1) -> relu_3
                      Gemm(relu_3, t_2, fc3.bias, beta=1.00, transB=0, alpha=1.00, transA=0) -> addmm_2
output: name='addmm_2' type=dtype('float32') shape=[1, 10]

dynamo

model = "ort-plot_torch_export_dynamo-cuda-aot1.onnx"
if os.path.exists(model):
    print(pretty_onnx(onnx.load(model)))
opset: domain='' version=18
opset: domain='pkg.onnxscript.torch_lib.common' version=1
opset: domain='ai.onnx.ml' version=5
opset: domain='onnx_extended.ortops.optim.cuda' version=1000
opset: domain='ai.onnx.training' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='com.microsoft' version=1
opset: domain='com.microsoft.experimental' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='org.pytorch.aten' version=1
input: name='x' type=dtype('float32') shape=[1, 1, 16, 16]
init: name='conv1.weight' type=dtype('float32') shape=(16, 1, 5, 5)
init: name='conv1.bias' type=dtype('float32') shape=(16,)
init: name='conv2.weight' type=dtype('float32') shape=(16, 16, 5, 5)
init: name='conv2.bias' type=dtype('float32') shape=(16,)
init: name='fc1.weight' type=dtype('float32') shape=(512, 16)
init: name='fc1.bias' type=dtype('float32') shape=(512,)
init: name='fc2.weight' type=dtype('float32') shape=(128, 512)
init: name='fc2.bias' type=dtype('float32') shape=(128,)
init: name='fc3.weight' type=dtype('float32') shape=(10, 128)
init: name='fc3.bias' type=dtype('float32') shape=(10,)
init: name='val_2' type=dtype('int64') shape=(2,) -- array([ 1, 16])
Conv(x, conv1.weight, conv1.bias, dilations=[1,1], auto_pad=b'NOTSET', pads=[0,0,0,0], strides=[1,1], group=1) -> conv2d
  Relu(conv2d) -> relu
    MaxPool(relu, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> max_pool2d, val_0
      Conv(max_pool2d, conv2.weight, conv2.bias, dilations=[1,1], auto_pad=b'NOTSET', pads=[0,0,0,0], strides=[1,1], group=1) -> conv2d_1
        Relu(conv2d_1) -> relu_1
          MaxPool(relu_1, pads=[0,0,0,0], kernel_shape=[2,2], ceil_mode=0, auto_pad=b'NOTSET', dilations=[1,1], strides=[2,2], storage_order=0) -> max_pool2d_1, val_1
            Reshape(max_pool2d_1, val_2, allowzero=0) -> view
              Gemm(view, fc1.weight, fc1.bias, beta=1.00, transB=1, alpha=1.00, transA=0) -> addmm
                Relu(addmm) -> relu_2
                  Gemm(relu_2, fc2.weight, fc2.bias, beta=1.00, transB=1, alpha=1.00, transA=0) -> addmm_1
                    Relu(addmm_1) -> relu_3
                      Gemm(relu_3, fc3.weight, fc3.bias, beta=1.00, transB=1, alpha=1.00, transA=0) -> addmm_2
output: name='addmm_2' type=dtype('float32') shape=[1, 10]
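
A simple way to see which graphs rely on onnxruntime-specific kernels is to list the operator domains used by each optimized model: in the outputs above, only the CPU graphs use the com.microsoft.nchwc reordering operators and FusedGemm, while the CUDA graphs keep standard ONNX operators. A minimal sketch, not part of the original script:

import glob

for path in sorted(glob.glob("ort-plot_torch_export_*-aot*.onnx")):
    domains = {n.domain or "ai.onnx" for n in onnx.load(path).graph.node}
    print(path, sorted(domains))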

Total running time of the script: (1 minute 12.096 seconds)

Gallery generated by Sphinx-Gallery