201: Evaluate DORT Training

It compares DORT (the onnxruntime backend for ``torch.compile``) with eager mode and with the default ``torch.compile`` backend.

To run the script:

python _doc/examples/plot_torch_aot_201.py --help

Some helpers

import warnings

# The whole example depends on onnxruntime: probe it first and exit cleanly
# (status 0) when it is not installed instead of failing later on.
try:
    with warnings.catch_warnings():
        # Hide any warning emitted while importing onnxruntime.
        warnings.simplefilter("ignore")
        import onnxruntime

        # True only if this onnxruntime build ships the CUDA execution provider;
        # whether torch itself sees a GPU is not checked here.
        has_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
except ImportError:
    print("onnxruntime not available.")
    import sys

    sys.exit(0)

import torch._dynamo
import contextlib
import itertools
import os
import gc
import platform

# import pickle
import pprint
import multiprocessing
import time
import cProfile
import pstats
import io
import logging
from pstats import SortKey

import numpy as np
import matplotlib.pyplot as plt
import pandas
import onnx
from onnx_array_api.profiling import profile2graph
import torch
from torch import nn
import torch.nn.functional as F
import experimental_experiment
from experimental_experiment.plotting.memory import memory_peak_plot
from experimental_experiment.ext_test_case import measure_time, get_figure
from experimental_experiment.args import get_parsed_args
from experimental_experiment.memory_peak import start_spying_on
from experimental_experiment.torch_models.training_helper import make_aot_ort
from tqdm import tqdm

# CUDA is used only when onnxruntime exposes the CUDA provider AND torch sees a GPU.
if has_cuda:
    has_cuda = torch.cuda.is_available()
# Disable every logging call at ERROR level or below to keep the output short.
logging.disable(level=logging.ERROR)


def system_info():
    """Describe the host: processor, core count and, when present, CUDA devices.

    CUDA queries stop at the first one that fails (no driver, CPU-only build);
    the keys collected before the failure are kept.
    """
    info = {
        "processor": platform.processor(),
        "cores": multiprocessing.cpu_count(),
    }
    try:
        info["cuda"] = int(torch.cuda.is_available())
        info["cuda_count"] = torch.cuda.device_count()
        info["cuda_name"] = torch.cuda.get_device_name()
        info["cuda_capa"] = torch.cuda.get_device_capability()
    except (RuntimeError, AssertionError):
        # No usable CUDA device.
        pass
    return info


pprint.pprint(system_info())
{'cores': 8, 'cuda': 0, 'cuda_count': 0, 'processor': 'x86_64'}

Script arguments

script_args = get_parsed_args(
    "plot_torch_aot",
    description=__doc__,
    scenarios={
        "small": "small model to test",
        "middle": "55Mb model",
        "large": "1Gb model",
    },
    warmup=5,
    repeat=5,
    repeat1=(1, "repeat for the first iteration"),
    maxtime=(
        2,
        "maximum time to run a model to measure the computation time, "
        "it is 0.1 when scenario is small",
    ),
    expose="scenarios,repeat,repeat1,warmup",
)

# The small scenario (also the default one) gets a much shorter time budget.
if script_args.scenario in (None, "small"):
    script_args.maxtime = 0.1

# Echo the effective settings, one "name=value" line each.
print(f"scenario={script_args.scenario or 'small'}")
for _setting in ("warmup", "repeat", "repeat1", "maxtime"):
    print(f"{_setting}={getattr(script_args, _setting)}")
scenario=small
warmup=5
repeat=5
repeat1=1
maxtime=0.1

The model

A simple model to convert.

class MyModelClass(nn.Module):
    """Convolution followed by a stack of linear layers, sized by *scenario*.

    * ``None`` or ``"small"``: 16 conv channels, hidden sizes 512 and 128,
      matching ``(1, 1, 16, 16)`` inputs;
    * ``"middle"``: 32 conv channels, hidden sizes 1024 and 128, matching
      ``(1, 1, 128, 128)`` inputs;
    * ``"large"``: 32 conv channels, a 4096-wide first hidden layer followed
      by nine extra 4096x4096 linear layers, matching ``(1, 1, 128, 128)``.

    Any other scenario raises ``ValueError``.
    """

    def __init__(self, scenario=script_args.scenario):
        super().__init__()
        if scenario == "middle":
            self.large = False
            self.conv1 = nn.Conv2d(1, 32, 5)
            # self.conv2 = nn.Conv2d(128, 16, 5)
            # 30752 = 32 channels * 31 * 31: conv(5) then max_pool(4) on 128x128.
            self.fc1 = nn.Linear(30752, 1024)
            self.fcs = []
            self.fc2 = nn.Linear(1024, 128)
            self.fc3 = nn.Linear(128, 10)
        elif scenario in (None, "small"):
            self.large = False
            self.conv1 = nn.Conv2d(1, 16, 5)
            # self.conv2 = nn.Conv2d(16, 16, 5)
            # 144 = 16 channels * 3 * 3: conv(5) then max_pool(4) on 16x16.
            self.fc1 = nn.Linear(144, 512)
            self.fcs = []
            self.fc2 = nn.Linear(512, 128)
            self.fc3 = nn.Linear(128, 10)
        elif scenario == "large":
            # ``None`` is handled by the small branch above, only "large" gets here.
            self.large = True
            self.conv1 = nn.Conv2d(1, 32, 5)
            # self.conv2 = nn.Conv2d(128, 16, 5)
            self.fc1 = nn.Linear(30752, 4096)
            # torch script does not support loops.
            self.fca = nn.Linear(4096, 4096)
            self.fcb = nn.Linear(4096, 4096)
            self.fcc = nn.Linear(4096, 4096)
            self.fcd = nn.Linear(4096, 4096)
            self.fce = nn.Linear(4096, 4096)
            self.fcf = nn.Linear(4096, 4096)
            self.fcg = nn.Linear(4096, 4096)
            self.fch = nn.Linear(4096, 4096)
            self.fci = nn.Linear(4096, 4096)
            # end of the unfolded loop.
            self.fc2 = nn.Linear(4096, 128)
            self.fc3 = nn.Linear(128, 10)
        else:
            raise ValueError(f"Unsupported scenario={scenario!r}.")

    def forward(self, x):
        """Return 10 outputs per sample (output of ``fc3``)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (4, 4))
        # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        if self.large:
            # loop
            x = F.relu(self.fca(x))
            x = F.relu(self.fcb(x))
            x = F.relu(self.fcc(x))
            x = F.relu(self.fcd(x))
            x = F.relu(self.fce(x))
            x = F.relu(self.fcf(x))
            x = F.relu(self.fcg(x))
            x = F.relu(self.fch(x))
            x = F.relu(self.fci(x))
            # end of the loop
        x = F.relu(self.fc2(x))
        y = self.fc3(x)
        return y


def create_model_and_input(scenario=script_args.scenario):
    """Build a fresh model for *scenario* and one random training sample.

    Returns ``(model, (input_tensor, target))``:

    * ``input_tensor`` has shape ``(1, 1, side, side)``, with ``side`` 16 for
      the small scenario and 128 for "middle" and "large";
    * ``target`` is a random ``(1, 10)`` tensor.

    Raises ``ValueError`` for an unknown scenario.
    """
    if scenario in (None, "small"):
        side = 16
    elif scenario in ("middle", "large"):
        side = 128
    else:
        raise ValueError(f"Unsupported scenario={scenario!r}.")
    input_tensor = torch.rand(1, 1, side, side).to(torch.float32)
    target = torch.rand((1, 10)).to(torch.float32)
    model = MyModelClass(scenario=scenario)
    # Sanity check: one forward pass must succeed on the generated input.
    assert model(input_tensor) is not None
    return model, (input_tensor, target)


def torch_model_size(model):
    """Return the number of bytes used by the parameters of *model*.

    ``Tensor.element_size()`` gives the byte size of one element for any dtype.
    ``torch.finfo`` raises for integer and bool parameters.
    """
    return float(
        sum(param.numel() * param.element_size() for param in model.parameters())
    )


model, input_tensors = create_model_and_input()
model_size = torch_model_size(model)
# Same value as the byte count, expressed in Mb (2**20 bytes).
model_size_in_mb = model_size / 2**20
print(f"model size={model_size_in_mb} Mb")
model size=0.5401992797851562 Mb

Backends

def run(model, tensor_x, tensor_y):
    """Run one training step (forward, MSE loss, backward) on *model*.

    Parameter gradients are reset to ``None`` first, so each call starts from
    a clean state. Returns ``(loss, grads)``, where ``grads`` is a tuple with
    the ``.grad`` of every parameter, captured right after the backward pass.
    Any exception raised by the forward pass is re-raised as an
    ``AssertionError``.
    """
    tensor_x = tensor_x.detach()
    tensor_y = tensor_y.detach()
    for param in model.parameters():
        param.grad = None
    try:
        output = model(tensor_x)
    except Exception as e:
        raise AssertionError(f"issue with {type(tensor_x)}") from e
    loss = F.mse_loss(output, tensor_y)

    # A named function so that the backward pass gets its own entry in
    # cProfile reports.
    def _backward_():
        loss.backward()

    _backward_()
    # Materialize the gradients now. A lazy generator would read ``.grad`` only
    # when consumed, possibly after a later call has replaced them.
    return loss, tuple(param.grad for param in model.parameters())


def get_torch_eager(model, *args):
    """Compile *model* with a pass-through backend and run one training step.

    The custom backend returns the captured graph module's ``forward``
    unchanged, so no backend-specific code generation happens. ``args`` are
    forwarded to :func:`run` (input tensor and target). Stdout and warnings
    are silenced while compiling. Returns the compiled module.
    """

    def my_compiler(gm, example_inputs):
        return gm.forward

    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            optimized_mod = torch.compile(model, fullgraph=True, backend=my_compiler)
            # Not wrapped in ``assert``: asserts are stripped under ``python -O``
            # and this warm-up step would silently never run.
            run(optimized_mod, *args)
            return optimized_mod


def get_torch_default(model, *args):
    """Compile *model* with the default ``torch.compile`` backend.

    Uses ``mode="reduce-overhead"``, then runs one training step via
    :func:`run` with ``args`` (input tensor and target). Stdout and warnings
    are silenced while compiling. Returns the compiled module.
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            optimized_mod = torch.compile(model, fullgraph=True, mode="reduce-overhead")
            # Not wrapped in ``assert``: asserts are stripped under ``python -O``
            # and this warm-up step would silently never run.
            run(optimized_mod, *args)
            return optimized_mod


def get_torch_dort(model, *args):
    """Compile *model* with the onnxruntime backend (DORT) and train twice.

    The backend is built by ``make_aot_ort(dynamic=True, rewrite=True)``. The
    compiled module then runs two training steps via :func:`run` with ``args``
    (input tensor and target). Stdout and warnings are silenced while
    compiling. Returns the compiled module.
    """
    with contextlib.redirect_stdout(io.StringIO()):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            local_aot_ort, _ = make_aot_ort(dynamic=True, rewrite=True)
            optimized_mod = torch.compile(model, backend=local_aot_ort, fullgraph=True)
            # Two steps, as before. NOTE(review): presumably the second one
            # exercises the already-compiled path; confirm.
            # Neither call is wrapped in ``assert``, which ``python -O`` strips.
            run(optimized_mod, *args)
            run(optimized_mod, *args)
            return optimized_mod

Let’s check that every backend works.

export_functions = [
    get_torch_eager,
    get_torch_default,
    get_torch_dort,
]

# "get_torch_eager" -> "torch_eager", and so on.
exporters = {f.__name__.replace("get_", ""): f for f in export_functions}

# Keep only the backends that can compile the model and run a training step.
supported_exporters = {}
for k, v in exporters.items():
    print(f"run function {k}")
    torch._dynamo.reset()
    model, input_tensors = create_model_and_input()
    try:
        # Call the exporter itself: running the plain eager ``model`` would
        # not tell whether the backend works.
        v(model, *input_tensors)
    except Exception as e:
        # Best effort: an unsupported backend is reported and skipped.
        print(f"skipped due to {str(e)[:1000]}")
    else:
        supported_exporters[k] = v
    del model
    gc.collect()
    time.sleep(1)
run function torch_eager
run function torch_default
run function torch_dort

Compile and Memory

def flatten(ps):
    """Merge memory statistics into a single flat dictionary.

    Every statistics object is converted with ``to_dict(unit=2**20)``
    (presumably Mb). ``ps["cpu"]`` provides the base keys. Each entry of the
    optional ``ps["gpus"]`` list adds its keys prefixed with ``gpu<index>_``.
    """
    mb = 2**20
    flat = ps["cpu"].to_dict(unit=mb)
    if "gpus" not in ps:
        return flat
    for index, gpu_stats in enumerate(ps["gpus"]):
        flat.update(
            {f"gpu{index}_{key}": value for key, value in gpu_stats.to_dict(unit=mb).items()}
        )
    return flat


def _measure_compile_memory(name, export_function, device):
    """Compile a fresh model with *export_function* on *device*; return its memory row."""
    torch._dynamo.reset()
    model, input_tensors = create_model_and_input()
    if device == "cuda":
        model = model.cuda()
        input_tensors = [t.cuda() for t in input_tensors]
    print(f"run compile for memory {name} on {device}")
    stat = start_spying_on(cuda=1 if has_cuda else 0)
    # Measure the exporter (compilation + training step), not plain eager ``run``.
    export_function(model, *input_tensors)
    obs = flatten(stat.stop())
    print("done.")
    obs.update(dict(export=name, p=device))
    del model
    gc.collect()
    time.sleep(1)
    return obs


data = []

for k, v in supported_exporters.items():
    if has_cuda:
        torch.cuda.set_device(0)
    data.append(_measure_compile_memory(k, v, "cpu"))
    if has_cuda:
        data.append(_measure_compile_memory(k, v, "cuda"))
run compile for memory torch_eager on cpu
done.
run compile for memory torch_default on cpu
done.
run compile for memory torch_dort on cpu
done.

The result.

df1 = pandas.DataFrame(data)
df1.to_csv("plot_torch_aot_1_memory.csv", index=False)
df1.to_excel("plot_torch_aot_1_memory.xlsx", index=False)
print(df1)

# One figure per device; the CUDA one only when a GPU is usable.
for p in ["cpu", "cuda"]:
    if p == "cuda" and not has_cuda:
        continue
    subset = df1[df1["p"] == p]
    # Horizontal reference bars at 1x, 2x, 3x and 4x the model size (Mb).
    reference_bars = [model_size * factor / 2**20 for factor in range(1, 5)]
    title = (
        f"Memory Consumption of the Compilation on {p}\n"
        f"model size={model_size / 2**20:1.0f} Mb"
    )
    ax = memory_peak_plot(
        subset,
        key=("export",),
        bars=reference_bars,
        suptitle=title,
    )
    get_figure(ax).savefig(f"plot_torch_aot_1_memory_{p}.png")
Memory Consumption of the Compilation on cpu model size=1 Mb, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb)
          peak         mean  ...         export    p
0  1031.144531  1030.970703  ...    torch_eager  cpu
1  1031.144531  1031.144531  ...  torch_default  cpu
2  1031.144531  1031.144531  ...     torch_dort  cpu

[3 rows x 7 columns]
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")

First iteration speed of each backend

def _time_first_iterations(name, export_function, device):
    """Time ``repeat1`` calls of *export_function* on fresh models; return a summary row.

    Each repetition rebuilds the model and resets dynamo, so every measurement
    includes the compilation plus the first training step(s).
    """
    print(f"run dort {device} {name}: {script_args.repeat1}")
    times = []
    for _ in range(int(script_args.repeat1)):
        model, input_tensors = create_model_and_input()
        if device == "cuda":
            model = model.cuda()
            input_tensors = [t.cuda() for t in input_tensors]
        torch._dynamo.reset()
        begin = time.perf_counter()
        # Time the exporter itself, not the plain eager training step.
        export_function(model, *input_tensors)
        if device == "cuda":
            # CUDA kernels run asynchronously: wait for them before stopping the clock.
            torch.cuda.synchronize()
        times.append(time.perf_counter() - begin)
        del model
        gc.collect()
        time.sleep(1)

    print(f"done: {times[-1]}")
    return dict(
        export=name,
        time=np.mean(times),
        min=min(times),
        max=max(times),
        first=times[0],
        last=times[-1],
        std=np.std(times),
        p=device,
    )


data = []

for k, v in supported_exporters.items():
    data.append(_time_first_iterations(k, v, "cpu"))
    if has_cuda:
        data.append(_time_first_iterations(k, v, "cuda"))
run dort cpu torch_eager: 1
done: 0.0010658000010153046
run dort cpu torch_default: 1
done: 0.002268799999001203
run dort cpu torch_dort: 1
done: 0.0011696000001393259

The result.

df1 = pandas.DataFrame(data)
df1.to_csv("plot_torch_aot_1_time.csv", index=False)
df1.to_excel("plot_torch_aot_1_time.xlsx", index=False)
print(df1)

# Bar chart of the mean time per (backend, device), with std as error bars.
dfi = df1[["export", "p", "time", "std"]].set_index(["export", "p"])
fig, ax = plt.subplots()
dfi["time"].plot.bar(ax=ax, title="Compilation time", yerr=dfi["std"], rot=30)
fig.tight_layout()
fig.savefig("plot_torch_aot_1_time.png")
Compilation time
          export      time       min  ...      last  std    p
0    torch_eager  0.001066  0.001066  ...  0.001066  0.0  cpu
1  torch_default  0.002269  0.002269  ...  0.002269  0.0  cpu
2     torch_dort  0.001170  0.001170  ...  0.001170  0.0  cpu

[3 rows x 8 columns]

Compilation Profiling

def clean_text(text):
    """Shorten file paths found in a profiling report.

    The parent directory of each installed ``torch``, ``onnx`` and
    ``experimental_experiment`` package is removed from *text*. Every
    ``experimental_experiment`` occurrence is then upper-cased.
    """
    for package in (torch, onnx, experimental_experiment):
        root = os.path.abspath(
            os.path.normpath(os.path.join(os.path.dirname(package.__file__), ".."))
        )
        text = text.replace(root, "")
    return text.replace("experimental_experiment", "experimental_experiment".upper())


def profile_function(
    name, export_function, with_args=True, verbose=False, suffix="export"
):
    """Profile *export_function* with cProfile and write the reports to disk.

    When *with_args* is true, a fresh model and its inputs are created. The
    function is called once unprofiled as ``export_function(model,
    input_tensors)``, then ``repeat1`` more times under the profiler.
    Otherwise it is called with no arguments, ``repeat1`` times, all profiled.

    Two UTF-8 text files are written:

    * ``plot_torch_aot_profile_{name}_{suffix}.txt``: pstats output sorted by
      cumulative time;
    * ``plot_torch_aot_profile_{name}_{suffix}_h.txt``: the call tree built by
      ``profile2graph``.
    """
    if verbose:
        print(f"profile {name}: {export_function}")
    if with_args:
        model, input_tensors = create_model_and_input()
        # First call kept out of the profile.
        export_function(model, input_tensors)
        call_args = (model, input_tensors)
    else:
        call_args = ()

    pr = cProfile.Profile()
    pr.enable()
    try:
        for _ in range(int(script_args.repeat1)):
            export_function(*call_args)
    finally:
        # Never leave the profiler running if the profiled code raises.
        pr.disable()

    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats(SortKey.CUMULATIVE)
    ps.print_stats()

    raw = s.getvalue()
    if verbose:
        # Only the first 200 lines are echoed; the file keeps everything.
        print("\n".join(raw.split("\n")[:200]))
    with open(f"plot_torch_aot_profile_{name}_{suffix}.txt", "w", encoding="utf-8") as f:
        f.write(raw)

    root, _ = profile2graph(ps, clean_text=clean_text)
    with open(
        f"plot_torch_aot_profile_{name}_{suffix}_h.txt", "w", encoding="utf-8"
    ) as f:
        f.write(root.to_text())
    if verbose:
        print("done.")


model, input_tensors = create_model_and_input()


def function_to_profile(model=model, input_tensors=input_tensors):
    """Compile *model* with DORT; *input_tensors* is the ``(input, target)`` pair."""
    tensors = tuple(input_tensors)
    return get_torch_dort(model, *tensors)


profile_function(
    name="dort", export_function=function_to_profile, verbose=True, suffix="1"
)
profile dort: <function function_to_profile at 0x7f2fca84c700>
         1022198 function calls (1000784 primitive calls) in 1.548 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    1.630    1.630 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_aot_201.py:516(function_to_profile)
        1    0.000    0.000    1.630    1.630 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_aot_201.py:258(get_torch_dort)
        2    0.000    0.000    1.373    0.687 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_aot_201.py:218(run)
      9/5    0.000    0.000    1.366    0.273 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:367(_fn)
     23/4    0.000    0.000    1.310    0.327 /home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1523(_wrapped_call_impl)
     23/4    0.000    0.000    1.310    0.327 /home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1529(_call_impl)
      6/4    0.000    0.000    0.774    0.194 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:105(call_func_at_runtime_with_args)
     12/4    0.001    0.000    0.771    0.193 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph_module.py:736(call_wrapped)
        4    0.000    0.000    0.771    0.193 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph_module.py:299(__call__)
        8    0.001    0.000    0.769    0.096 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py:837(_ort_acclerated_call)
        3    0.000    0.000    0.685    0.228 /home/xadupre/.local/lib/python3.10/site-packages/torch/autograd/graph.py:739(_engine_run_backward)
        3    0.002    0.001    0.685    0.228 {method 'run_backward' of 'torch._C._EngineBase' objects}
        4    0.000    0.000    0.583    0.146 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_models/training_helper.py:77(<lambda>)
        4    0.000    0.000    0.583    0.146 /home/xadupre/github/experimental-experiment/experimental_experiment/convert/convert_helper.py:55(optimize_model_proto)
        4    0.001    0.000    0.582    0.146 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/__init__.py:28(optimize)
        2    0.000    0.000    0.577    0.289 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_aot_201.py:230(_backward_)
        2    0.000    0.000    0.577    0.289 /home/xadupre/.local/lib/python3.10/site-packages/torch/_tensor.py:466(backward)
        2    0.000    0.000    0.577    0.289 /home/xadupre/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:165(backward)
        2    0.000    0.000    0.577    0.288 /home/xadupre/.local/lib/python3.10/site-packages/torch/autograd/function.py:286(apply)
        2    0.000    0.000    0.577    0.288 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:679(backward)
        2    0.000    0.000    0.576    0.288 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:879(call_compiled_backward)
        1    0.000    0.000    0.535    0.535 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:887(catch_errors)
        1    0.000    0.000    0.535    0.535 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:288(_convert_frame_assert)
      2/1    0.000    0.000    0.535    0.535 /usr/lib/python3.10/contextlib.py:76(inner)
        1    0.000    0.000    0.534    0.534 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:444(_compile)
      3/1    0.000    0.000    0.533    0.533 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:258(time_wrapper)
        1    0.000    0.000    0.533    0.533 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:527(compile_inner)
      5/4    0.000    0.000    0.530    0.132 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:34(inner)
        1    0.000    0.000    0.522    0.522 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:1028(transform_code_object)
        1    0.000    0.000    0.519    0.519 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:150(_fn)
        1    0.000    0.000    0.519    0.519 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480(transform)
        1    0.000    0.000    0.516    0.516 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2115(run)
        1    0.000    0.000    0.516    0.516 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:835(run)
       44    0.000    0.000    0.516    0.012 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:727(step)
        2    0.000    0.000    0.514    0.257 <eval_with_key>.28:4(forward)
        1    0.000    0.000    0.476    0.476 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:2221(RETURN_VALUE)
        1    0.000    0.000    0.476    0.476 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:813(compile_subgraph)
        1    0.000    0.000    0.475    0.475 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:1075(compile_and_call_fx_graph)
       32    0.001    0.000    0.473    0.015 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:784(visit_model)
        1    0.000    0.000    0.470    0.470 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:1160(call_user_compiler)
      2/1    0.000    0.000    0.470    0.470 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py:59(debug_wrapper)
        1    0.000    0.000    0.470    0.470 /home/xadupre/.local/lib/python3.10/site-packages/torch/__init__.py:1777(__call__)
        1    0.000    0.000    0.470    0.470 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py:1089(__call__)
        1    0.000    0.000    0.470    0.470 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/backends/common.py:18(compiler_fn)
        1    0.000    0.000    0.469    0.469 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:804(aot_module_simplified)
        1    0.000    0.000    0.469    0.469 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:411(create_aot_dispatcher_function)
        1    0.000    0.000    0.400    0.400 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:411(aot_wrapper_dedupe)
        1    0.000    0.000    0.400    0.400 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:630(aot_wrapper_synthetic_base)
        1    0.000    0.000    0.399    0.399 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:233(aot_dispatch_autograd)
    90/32    0.006    0.000    0.323    0.010 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:644(visit_graph)
       24    0.000    0.000    0.290    0.012 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/simple_function_folding.py:28(visit_model)
        2    0.000    0.000    0.260    0.130 /home/xadupre/github/experimental-experiment/_doc/examples/plot_torch_aot_201.py:163(forward)
        2    0.000    0.000    0.260    0.130 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:913(forward)
      6/2    0.000    0.000    0.260    0.130 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:88(g)
        2    0.000    0.000    0.260    0.130 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py:77(runtime_wrapper)
        2    0.000    0.000    0.260    0.130 /home/xadupre/.local/lib/python3.10/site-packages/torch/autograd/function.py:590(apply)
        2    0.000    0.000    0.259    0.130 {built-in method apply}
        2    0.000    0.000    0.259    0.130 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:534(forward)
        1    0.000    0.000    0.258    0.258 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py:112(_lazy_forward)
        2    0.000    0.000    0.256    0.128 <eval_with_key>.24:4(forward)
        1    0.000    0.000    0.256    0.256 /home/xadupre/github/experimental-experiment/experimental_experiment/torch_models/training_helper.py:5(make_aot_ort)
        1    0.000    0.000    0.255    0.255 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py:722(__init__)
        2    0.044    0.022    0.243    0.122 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/decomposition_table.py:18(_create_onnx_supports_op_overload_table)
      285    0.011    0.000    0.238    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:268(__torch_dispatch__)
     9914    0.029    0.000    0.234    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:559(process_value_info)
        1    0.000    0.000    0.221    0.221 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:129(aot_dispatch_autograd_graph)
        1    0.000    0.000    0.215    0.215 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:33(_create_graph)
        1    0.000    0.000    0.215    0.215 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:1080(wrapped)
        1    0.000    0.000    0.214    0.214 /home/xadupre/.local/lib/python3.10/site-packages/torch/_compile.py:20(inner)
        1    0.000    0.000    0.214    0.214 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:627(dispatch_trace)
        1    0.000    0.000    0.213    0.213 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:663(trace)
        1    0.000    0.000    0.208    0.208 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:650(flatten_fn)
        1    0.000    0.000    0.208    0.208 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:643(wrapped)
 1248/483    0.007    0.000    0.197    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:797(visit_node)
       16    0.000    0.000    0.192    0.012 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/simple_function_folding.py:209(inline_simple_functions)
        1    0.000    0.000    0.191    0.191 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:520(joint_helper)
        1    0.000    0.000    0.191    0.191 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:350(_functionalized_f_helper)
        8    0.000    0.000    0.183    0.023 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/constant_folding.py:272(fold_constants)
        8    0.000    0.000    0.183    0.023 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/constant_folding.py:266(visit_model)
        1    0.000    0.000    0.172    0.172 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:251(inner_fn_with_anomaly)
        1    0.000    0.000    0.171    0.171 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:186(inner_fn)
  958/569    0.002    0.000    0.171    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_stats.py:15(wrapper)
        6    0.001    0.000    0.145    0.024 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:106(run)
  130/100    0.006    0.000    0.144    0.001 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:830(process_function_node)
    12090    0.032    0.000    0.137    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:60(load_from_value_info)
        1    0.000    0.000    0.136    0.136 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:357(__init__)
       32    0.000    0.000    0.134    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:734(_gather_function_metadata)
       32    0.000    0.000    0.134    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:696(visit_model)
       32    0.000    0.000    0.133    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:333(visit_model)
       32    0.005    0.000    0.132    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:347(visit_graph)
   305/12    0.008    0.000    0.131    0.011 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py:71(wrapper)
        1    0.000    0.000    0.125    0.125 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/decomposition_table.py:78(create_onnx_friendly_decomposition_table)
        1    0.001    0.001    0.125    0.125 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/partitioners.py:637(min_cut_rematerialization_partition)
      114    0.001    0.000    0.120    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:184(run_node)
    14704    0.017    0.000    0.119    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:251(is_registered_op)
        1    0.000    0.000    0.109    0.109 /home/xadupre/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:278(grad)
6057/5961    0.011    0.000    0.109    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/node.py:724(map_arg)
  263/242    0.002    0.000    0.109    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:727(__torch_dispatch__)
       24    0.000    0.000    0.107    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/simple_function_folding.py:24(_gather_function_metadata)
      420    0.007    0.000    0.106    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/constant_folding.py:181(process_node)
    14763    0.026    0.000    0.103    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/exporter.py:228(get_op_functions)
  263/242    0.001    0.000    0.101    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:758(inner_torch_dispatch)
        8    0.000    0.000    0.101    0.013 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/simple_function_folding.py:227(inline_functions_with_unused_outputs)
  671/667    0.002    0.000    0.098    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:886(__torch_dispatch__)
    12162    0.014    0.000    0.098    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:465(lookup_or_create)
    27/21    0.000    0.000    0.097    0.005 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/constant_folding.py:234(process_function_node)
    14266    0.055    0.000    0.096    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:69(process_value_info)
12850/5970    0.043    0.000    0.096    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/node.py:732(map_aggregate)
  671/667    0.006    0.000    0.095    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1215(dispatch)
        2    0.000    0.000    0.095    0.047 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:670(functional_call)
    69/54    0.003    0.000    0.094    0.002 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:298(proxy_call)
       22    0.000    0.000    0.092    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:652(run_node)
       12    0.000    0.000    0.091    0.008 /home/xadupre/github/onnx-rewriter/onnxrewriter/rewriter/__init__.py:24(rewrite)
    12162    0.020    0.000    0.084    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:288(lookup_or_create)
      345    0.003    0.000    0.084    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:940(_cached_dispatch_impl)
  595/417    0.002    0.000    0.082    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:859(tree_map)
        2    0.000    0.000    0.082    0.041 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/onnxruntime.py:1035(compile)
        2    0.000    0.000    0.081    0.040 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/passes/infra/partitioner.py:326(partition_and_fuse)
      525    0.003    0.000    0.079    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:886(create_node)
        4    0.000    0.000    0.077    0.019 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py:495(run)
97898/96567    0.052    0.000    0.074    0.000 {built-in method builtins.isinstance}
        2    0.000    0.000    0.072    0.036 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/passes/infra/partitioner.py:265(fuse_partitions)
        2    0.000    0.000    0.072    0.036 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/passes/utils/fuser_utils.py:218(fuse_by_partitions)
      340    0.002    0.000    0.068    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:1228(node_copy)
      572    0.002    0.000    0.068    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:1066(tree_map_only)
        4    0.002    0.000    0.068    0.017 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/partitioners.py:59(_extract_graph_with_inputs_outputs)
       89    0.001    0.000    0.067    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py:413(run_node)
       74    0.000    0.000    0.066    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:256(call_function)
  775/668    0.001    0.000    0.066    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/_ops.py:597(__call__)
    39456    0.034    0.000    0.063    0.000 {method 'get' of 'dict' objects}
   103/79    0.003    0.000    0.062    0.001 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/simple_function_folding.py:36(process_function_node)
7927/7689    0.007    0.000    0.062    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/node.py:738(<genexpr>)
        1    0.000    0.000    0.061    0.061 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/backends/common.py:34(_wrapped_bw_compiler)
        1    0.000    0.000    0.061    0.061 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:102(inner_fn)
       61    0.001    0.000    0.060    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py:647(call_function)
        8    0.000    0.000    0.060    0.008 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/irbuilder.py:218(build_ir)
        8    0.000    0.000    0.060    0.008 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/irbuilder.py:34(visit_model)
 2422/498    0.011    0.000    0.059    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:734(unflatten)
    12162    0.024    0.000    0.057    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:201(lookup_or_create)
      538    0.006    0.000    0.054    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/node.py:166(__init__)
        1    0.000    0.000    0.054    0.054 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/partitioners.py:157(_extract_fwd_bwd_modules)
        1    0.000    0.000    0.054    0.054 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:128(inner)
        8    0.000    0.000    0.053    0.007 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/_pass.py:240(run)
        4    0.000    0.000    0.052    0.013 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py:1716(_run)
      943    0.001    0.000    0.050    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:799(tree_flatten)
       10    0.000    0.000    0.049    0.005 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:1291(python_code)
 2978/943    0.010    0.000    0.048    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:778(_tree_flatten_helper)
       10    0.001    0.000    0.046    0.005 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:1353(_python_code)
        8    0.000    0.000    0.045    0.006 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:298(call_module)
       10    0.004    0.000    0.045    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:380(_gen_python_code)
    26869    0.029    0.000    0.044    0.000 /usr/lib/python3.10/logging/__init__.py:1455(debug)
   105/97    0.002    0.000    0.043    0.000 {method 'detach' of 'torch._C.TensorBase' objects}
       86    0.001    0.000    0.042    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/passes/type_promotion.py:1618(run_node)
       61    0.000    0.000    0.041    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:212(track_tensor_tree)
    76/61    0.000    0.000    0.041    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:213(wrap_with_proxy)
     8699    0.022    0.000    0.041    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/registration.py:55(from_qualified_name)
      345    0.007    0.000    0.040    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:975(_cache_key)
       59    0.000    0.000    0.040    0.001 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:111(dispatch)
        4    0.001    0.000    0.039    0.010 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/passes/utils/fuser_utils.py:91(fuse_as_graphmodule)
      212    0.001    0.000    0.039    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/optimizer/evaluator.py:38(evaluate)
        9    0.000    0.000    0.038    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:115(forward)
        9    0.002    0.000    0.038    0.004 {built-in method torch._C._nn.linear}
    85/80    0.000    0.000    0.037    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:139(extract_val)
       82    0.000    0.000    0.036    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:136(snapshot_fake)
      212    0.001    0.000    0.036    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:624(eval)
       74    0.000    0.000    0.035    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:172(set_meta)
      770    0.005    0.000    0.035    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/node.py:461(__update_args_kwargs)
      104    0.001    0.000    0.034    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/proxy.py:173(create_proxy)
        8    0.002    0.000    0.033    0.004 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/irbuilder.py:48(visit_graph)
       19    0.000    0.000    0.032    0.002 /home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py:21(to_fun)
       19    0.000    0.000    0.032    0.002 /home/xadupre/.local/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:172(to_functional)
        9    0.000    0.000    0.032    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:474(wrapper)
        9    0.000    0.000    0.032    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1198(CALL_FUNCTION)
        9    0.000    0.000    0.031    0.003 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:673(call_function)
        6    0.000    0.000    0.031    0.005 /home/xadupre/.local/lib/python3.10/site-packages/torch/_logging/_internal.py:1026(trace_structured)
    14763    0.017    0.000    0.030    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/registration.py:44(from_name_parts)
 1507/644    0.003    0.000    0.030    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:792(<listcomp>)
        4    0.000    0.000    0.029    0.007 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph_module.py:820(print_readable)
        9    0.000    0.000    0.028    0.003 /home/xadupre/.local/lib/python3.10/site-packages/torch/nn/functional.py:1489(relu)
        9    0.001    0.000    0.028    0.003 {built-in method torch.relu}
        4    0.000    0.000    0.027    0.007 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:761(module_call_wrapper)
    46965    0.027    0.000    0.027    0.000 {method 'split' of 'str' objects}
        4    0.000    0.000    0.027    0.007 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py:552(call_module)
        4    0.000    0.000    0.027    0.007 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py:763(forward)
     2176    0.004    0.000    0.027    0.000 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/irbuilder.py:200(process_value_info)
       14    0.003    0.000    0.026    0.002 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:1395(lint)
       10    0.000    0.000    0.026    0.003 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:1332(wrap_fx_proxy)
       10    0.000    0.000    0.026    0.003 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py:1392(wrap_fx_proxy_cls)
        8    0.002    0.000    0.026    0.003 /home/xadupre/github/onnx-rewriter/onnxrewriter/ir/visitor.py:42(load_from_model_proto)
53660/53572    0.026    0.000    0.026    0.000 {built-in method builtins.len}
      212    0.002    0.000    0.024    0.000 /home/xadupre/github/onnx/onnx/reference/op_run.py:588(create)
       59    0.000    0.000    0.024    0.000 /home/xadupre/.local/lib/python3.10/site-packages/torch/onnx/_internal/fx/onnxfunction_dispatcher.py:198(_find_the_perfect_or_nearest_match_onnxfunction)
        4    0.000    0.000    0.024    0.006 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py:249(call_function)
        6    0.001    0.000    0.024    0.004 /home/xadupre/.local/lib/python3.10/site-packages/torch/fx/graph.py:1466(eliminate_dead_code)
       14    0.000    0.000    0.023    0.002 /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:1207(wrap_fake_exception)
done.

Benchmark exported models with ORT

def _benchmark_one(export_fct, p, loop):
    """Runs one (exporter, device) experiment for :func:`benchmark`.

    The model, its inputs, the exported model and the ``call_model`` closure
    are all local to this function. They become unreachable as soon as it
    returns, so the caller's ``gc.collect()`` can actually release them.

    :param export_fct: one of ``export_functions``, called as
        ``export_fct(model, *input_tensors)``
    :param p: ``"CPU"`` or ``"CUDA"``
    :param loop: the tqdm progress bar, its description reports the status
    :return: ``(obs, memobs_first_run, memobs_run)``; both memory entries
        are None if the export or the first run failed, ``obs`` then holds
        an ``error`` entry with the exception message
    """
    name = export_fct.__name__.replace("get_torch_", "")
    obs = {"name": name, "compute": p, "export": name}

    model, input_tensors = create_model_and_input()
    if p == "CUDA":
        model = model.cuda()
        input_tensors = [i.cuda() for i in input_tensors]
    try:
        exported_model = export_fct(model, *input_tensors)
    except Exception as e:
        loop.set_description(f"ERROR-export: {name} {e}")
        obs["error"] = str(e)
        return obs, None, None

    def call_model():
        return run(exported_model, *input_tensors)

    cuda = 1 if has_cuda else 0

    # Memory consumed by the very first call.
    stat = start_spying_on(cuda=cuda)
    try:
        call_model()
    except Exception as e:
        stat.stop()
        loop.set_description(f"ERROR-run: {name} {e}")
        # Store the message, not the exception object: it must be writable by
        # to_csv/to_excel and consistent with the export failure above.
        obs.update({"error": str(e), "step": "load"})
        return obs, None, None
    memobs_first_run = flatten(stat.stop())
    memobs_first_run.update(obs)

    # Memory consumed by the warmup calls.
    stat = start_spying_on(cuda=cuda)
    try:
        for _ in range(script_args.warmup):
            call_model()
    except BaseException:
        # Stop the memory spy before the failure aborts the benchmark.
        stat.stop()
        raise
    memobs_run = flatten(stat.stop())
    memobs_run.update(obs)

    obs.update(
        measure_time(
            call_model,
            max_time=script_args.maxtime,
            repeat=script_args.repeat,
            number=1,
        )
    )

    profile_function(name, call_model, with_args=False, suffix=f"run_{p}")

    loop.set_description(f"{obs['average']} {name} {p}")
    return obs, memobs_first_run, memobs_run


def benchmark(shape):
    """Benchmarks every exporter of ``export_functions`` on CPU and CUDA.

    For each (exporter, device) pair, a fresh model is built with
    ``create_model_and_input`` and exported. The function then records the
    memory used by the first call, the memory used by ``script_args.warmup``
    further calls, and the timing statistics returned by ``measure_time``.
    CUDA experiments are skipped when ``has_cuda`` is false. An export or
    first-run failure is recorded as a row with an ``error`` column instead
    of stopping the benchmark.

    :param shape: unused, kept for backward compatibility (the inputs come
        from ``create_model_and_input``)
    :return: ``(df, dfmemfr, dfmemr)``: timings, first-run memory and
        warmup memory; each is also saved as a CSV and an Excel file
    """
    data = []
    data_mem_first_run = []
    data_mem_run = []
    confs = list(itertools.product(export_functions, ["CPU", "CUDA"]))
    loop = tqdm(confs)
    print(f"number of experiments: {len(loop)}")
    for export_fct, p in loop:
        if p == "CUDA" and not has_cuda:
            # Checked before building anything, so there is nothing to release.
            continue
        obs, memobs_first_run, memobs_run = _benchmark_one(export_fct, p, loop)
        data.append(obs)
        if memobs_first_run is not None:
            data_mem_first_run.append(memobs_first_run)
            data_mem_run.append(memobs_run)
        # The experiment's models and tensors are unreachable now: free them
        # before building the next one.
        gc.collect()
        if memobs_first_run is not None:
            time.sleep(1)

    df = pandas.DataFrame(data)
    df.to_csv("plot_torch_aot_ort_time.csv", index=False)
    df.to_excel("plot_torch_aot_ort_time.xlsx", index=False)
    dfmemr = pandas.DataFrame(data_mem_run)
    dfmemr.to_csv("plot_torch_aot_ort_run_mem.csv", index=False)
    dfmemr.to_excel("plot_torch_aot_ort_run_mem.xlsx", index=False)
    dfmemfr = pandas.DataFrame(data_mem_first_run)
    dfmemfr.to_csv("plot_torch_aot_ort_first_run_mem.csv", index=False)
    dfmemfr.to_excel("plot_torch_aot_ort_first_run_mem.xlsx", index=False)
    return df, dfmemfr, dfmemr


# Run every (exporter, device) experiment and keep the three result tables:
# timings, first-run memory and warmup memory.
df, dfmemfr, dfmemr = benchmark(shape=list(input_tensors[0].shape))
print(df)
  0%|          | 0/6 [00:00<?, ?it/s]number of experiments: 6

0.0010289129629738993 eager CPU:   0%|          | 0/6 [00:00<?, ?it/s]
0.0010289129629738993 eager CPU:  17%|█▋        | 1/6 [00:01<00:08,  1.72s/it]
0.0011632103448402172 default CPU:  17%|█▋        | 1/6 [00:15<00:08,  1.72s/it]
0.0011632103448402172 default CPU:  50%|█████     | 3/6 [00:17<00:18,  6.20s/it]/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:117: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(

0.0029109999999776598 dort CPU:  50%|█████     | 3/6 [00:18<00:18,  6.20s/it]
0.0029109999999776598 dort CPU:  83%|████████▎ | 5/6 [00:19<00:03,  3.63s/it]
0.0029109999999776598 dort CPU: 100%|██████████| 6/6 [00:19<00:00,  3.28s/it]
      name compute  ... context_size  warmup_time
0    eager     CPU  ...           64     0.001599
1  default     CPU  ...           64     0.001992
2     dort     CPU  ...           64     0.004979

[3 rows x 12 columns]

Average processing time per exporter and device

def view_time(df, title, suffix="time"):
    """Prints, saves and plots the average time of every exporter per device.

    The pivot table (exporter x device) of the ``average`` column is printed
    and saved as ``plot_torch_aot_{suffix}_compute.csv`` and ``.xlsx``. One
    horizontal bar chart with a logarithmic x-axis is then drawn per device,
    and the figure is saved as ``plot_torch_aot_{suffix}.png``.

    :param df: benchmark results with columns ``export``, ``compute`` and
        ``average``
    :param title: title of the figure
    :param suffix: suffix inserted in the names of the saved files
    :return: the array of the two matplotlib axes (CPU, CUDA)
    """
    piv = pandas.pivot_table(df, index="export", columns=["compute"], values="average")
    print(piv)
    piv.to_csv(f"plot_torch_aot_{suffix}_compute.csv")
    piv.to_excel(f"plot_torch_aot_{suffix}_compute.xlsx")

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    fig.suptitle(title)

    computes = ["CPU", "CUDA"] if has_cuda else ["CPU"]
    if not has_cuda:
        # No CUDA results: hide the second axis instead of drawing an empty frame.
        ax[1].set_visible(False)

    for axis, compute in zip(ax, computes):
        selected = df[df.compute == compute]
        if not selected["average"].notna().any():
            # Every experiment failed on this device. DataFrame.plot raises a
            # TypeError when there is no numeric data to plot.
            axis.set_title(f"{compute} (no data)")
            continue
        piv_compute = pandas.pivot_table(
            selected,
            index="export",
            columns=["compute"],
            values="average",
        )
        piv_compute.plot.barh(ax=axis, title=compute, logx=True)

    fig.tight_layout()
    fig.savefig(f"plot_torch_aot_{suffix}.png")
    return ax


# One bar chart per device with the average time of every exporter.
view_time(df, title="Compares processing time on backends")
Compares processing time on backends, CPU
compute       CPU
export
default  0.001163
dort     0.002911
eager    0.001029

array([<Axes: title={'center': 'CPU'}, ylabel='export'>, <Axes: >],
      dtype=object)

Memory consumption during the first run (ORT)

# Memory consumed by the first call of every exported model, one figure per device.
for compute in ["CPU", "CUDA"]:
    if not has_cuda and compute == "CUDA":
        continue
    # Nothing to plot when no first run succeeded on this device. An empty
    # DataFrame has no ``compute`` column at all.
    if "compute" not in dfmemfr.columns or not (dfmemfr.compute == compute).any():
        continue
    ax = memory_peak_plot(
        dfmemfr[dfmemfr.compute == compute],
        ("export",),
        suptitle="Memory Consumption of backend, first running time"
        f"\nrunning on {compute}",
        # reference lines at one and two times the model size, in Mb
        bars=[model_size * i / 2**20 for i in range(1, 3)],
        figsize=(18, 6),
    )
    get_figure(ax).savefig(f"plot_torch_aot_first_run_mem_{compute}.png")
Memory Consumption of backend, first running time running on CPU, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb)
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")

Memory consumption during the warmup runs (ORT)

# Memory consumed by the warmup calls of every exported model, one figure per device.
for compute in ["CPU", "CUDA"]:
    if not has_cuda and compute == "CUDA":
        continue
    # Nothing to plot when no experiment reached the warmup on this device.
    # An empty DataFrame has no ``compute`` column at all.
    if "compute" not in dfmemr.columns or not (dfmemr.compute == compute).any():
        continue
    ax = memory_peak_plot(
        dfmemr[dfmemr.compute == compute],
        ("export",),
        suptitle="Memory Consumption of backends, running time"
        f"\nrunning on {compute}",
        # reference lines at one and two times the model size, in Mb
        bars=[model_size * i / 2**20 for i in range(1, 3)],
        figsize=(18, 6),
    )
    get_figure(ax).savefig(f"plot_torch_aot_run_mem_{compute}.png")
Memory Consumption of backens, running time running on CPU, Memory peak (Mb), Memory peak - memory begin (Mb), Memory average - memory begin (Mb)
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")
/home/xadupre/github/experimental-experiment/experimental_experiment/plotting/memory.py:68: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax[i, j].set_xticklabels(ls, ha="right")

Total running time of the script: (0 minutes 38.796 seconds)

Gallery generated by Sphinx-Gallery