Times

fx_mode

symbolic

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

# Build a Llama decoder (see get_llama_decoder) with 2 hidden layers and its
# example inputs. Model creation is deliberately kept outside the timed region.
print("creating model")
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=2,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Reset dynamo's state so the measurement is not affected by earlier tracing.
torch._dynamo.reset()

# Exporting triggers a stream of DeprecationWarnings coming from
# torch/amp/autocast_mode.py; silence them so only the timing is printed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    begin = time.perf_counter()
    torch._dynamo.export(model, tracing_mode="symbolic")(*example_args_collection[0])
    print(f"time to export symbolic --- {time.perf_counter() - begin}")

>>>

    creating model
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    time to export symbolic --- 0.9746730999977444

fake

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

# Build a Llama decoder (see get_llama_decoder) with 2 hidden layers and its
# example inputs. Model creation is deliberately kept outside the timed region.
print("creating model")
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=2,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Reset dynamo's state so the measurement is not affected by earlier tracing.
torch._dynamo.reset()

# Exporting triggers a stream of DeprecationWarnings coming from
# torch/amp/autocast_mode.py; silence them so only the timing is printed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    begin = time.perf_counter()
    torch._dynamo.export(model, tracing_mode="fake")(*example_args_collection[0])
    print(f"time to export fake --- {time.perf_counter() - begin}")

>>>

    creating model
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:206: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.fast_dtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:327: DeprecationWarning: torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:593.)
      self.prev = torch.is_autocast_cpu_enabled()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:328: DeprecationWarning: torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:728.)
      self.prev_fastdtype = torch.get_autocast_cpu_dtype()
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:329: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self._enabled)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:330: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:378: DeprecationWarning: torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:584.)
      torch.set_autocast_cpu_enabled(self.prev)
    /home/xadupre/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:379: DeprecationWarning: torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead. (Triggered internally at ../torch/csrc/autograd/init.cpp:677.)
      torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    time to export fake --- 0.22533040000416804

Custom Exporter

With a very simple model:

<<<

import time
from experimental_experiment.checks import print_import_time

# Print how long each heavy dependency (onnx, torch, onnxruntime, ...) takes to
# import, before any timed export runs below.
print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """A single linear layer followed by a sigmoid activation.

    :param n_dims: number of input features
    :param n_targets: number of outputs
    """

    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Affine projection, then squash every output into (0, 1).
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

# Export the same model twice: the first timing presumably includes one-time
# initialization costs, the second shows the cost of a repeated export.
for label in ("1x", "2x"):
    begin = time.perf_counter()
    onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
    print(f"time to export {label} --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.18910039999900619
    time to import onnx_array_api --- 0.0001554000045871362
    time to import torch --- 1.0526032999987365
    'torch.export' already imported
    time to import torch.export --- 1.8999999156221747e-06
    time to import onnxscript --- 0.0835092999986955
    [2024-05-08 14:07:30,451] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    time to import onnxruntime --- 2.1038609000024735
    'torch.onnx' already imported
    time to import torch.onnx --- 2.3999964469112456e-06
    time to import torch._dynamo --- 0.1279768000022159
    time to import experimental_experiment.torch_interpreter --- 0.008113299998512957
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.0021956999989924952
    time to export 1x --- 0.3059434999959194
    time to export 2x --- 0.04171910000150092

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# The traced graph contains a torch._C._set_grad_enabled call, which
# torch.export rejects ("Encountered autograd state manager op ...") unless the
# export runs under torch.no_grad(), as the error message itself recommends.
with torch.no_grad():
    begin = time.perf_counter()
    onx = experimental_experiment.torch_interpreter.to_onnx(
        model, example_args_collection[0]
    )
    print(f"time to export 1x --- {time.perf_counter() - begin}")

    begin = time.perf_counter()
    onx = experimental_experiment.torch_interpreter.to_onnx(
        model, example_args_collection[0]
    )
    print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    [2024-05-08 14:07:36,450] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 37, in <module>
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 321, in to_onnx
        graph_module, builder, interpreter = _make_builder_interpreter(
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 175, in _make_builder_interpreter
        exported_mod = _export(
      File "/home/xadupre/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 102, in _export
        exported_mod = torch.export.export(mod, args, dynamic_shapes=dynamic_shapes)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 174, in export
        return _export(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 833, in wrapper
        raise e
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 816, in wrapper
        ep = fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 85, in wrapper
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1168, in _export
        ep_non_strict = _export_non_strict(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 520, in _export_non_strict
        gm, graph_signature = transform(aot_export_module)(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1135, in aot_export_module
        fx_g, metadata, in_spec, out_spec = _aot_export_function(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1354, in _aot_export_function
        fx_g, meta = create_aot_dispatcher_function(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 268, in time_wrapper
        r = func(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 685, in create_aot_dispatcher_function
        compiled_fn = compiler_fn(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 470, in aot_wrapper_dedupe
        return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in aot_wrapper_synthetic_base
        return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 134, in aot_dispatch_base_graph
        fw_module = _create_graph(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 43, in _create_graph
        fx_g = make_fx(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1276, in wrapped
        t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
        return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 658, in dispatch_trace
        graph = tracer.trace(root, concrete_args)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1028, in trace
        res = super().trace(root, concrete_args)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 793, in trace
        (self.create_arg(fn(*args)),),
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 676, in wrapped
        out = f(*tensors)
      File "<string>", line 1, in <lambda>
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 387, in _functionalized_f_helper
        f_outs = fn(*f_args)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 71, in inner_fn
        outs = fn(*args)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 176, in flat_fn
        tree_out = fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 695, in functional_call
        out = PropagateUnbackedSymInts(mod).run(
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
        self.env[node] = self.run_node(node)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4835, in run_node
        result = super().run_node(n)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
        return getattr(self, n.op)(n.target, args, kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/interpreter.py", line 274, in call_function
        return target(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 721, in __torch_function__
        return func(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/export/_safeguard.py", line 37, in __torch_function__
        raise RuntimeError(
    RuntimeError: Encountered autograd state manager op <built-in function _set_grad_enabled> trying to change global autograd state while exporting. This is unsafe because we don't capture this op in torch.export today, hence we can't reflect the user intention soundly. You can fix this by adding a torch.no_grad() context around the export call.
    
    While executing %_set_grad_enabled_1 : [num_users=0] = call_function[target=torch._C._set_grad_enabled](args = (True,), kwargs = {})
    Original traceback:
    None

Dynamo Exporter

<<<

import time
import warnings

from experimental_experiment.checks import print_import_time

# Print how long each heavy dependency (onnx, torch, onnxruntime, ...) takes to
# import, before any timed export runs below.
print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """A single linear layer followed by a sigmoid activation.

    :param n_dims: number of input features
    :param n_targets: number of outputs
    """

    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Affine projection, then squash every output into (0, 1).
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

# Hide Python warnings raised during export so only the timings are printed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Two consecutive exports: the first presumably pays one-time setup costs.
    for label in ("1x", "2x"):
        begin = time.perf_counter()
        onx = torch.onnx.dynamo_export(model, x)
        print(f"time to export {label} --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.19956049999746028
    time to import onnx_array_api --- 0.00016119999781949446
    time to import torch --- 1.07216960000369
    'torch.export' already imported
    time to import torch.export --- 2.0000006770715117e-06
    time to import onnxscript --- 0.09584490000270307
    [2024-05-08 14:08:07,020] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    time to import onnxruntime --- 2.1254180999967502
    'torch.onnx' already imported
    time to import torch.onnx --- 1.5999976312741637e-06
    time to import torch._dynamo --- 0.10234369999670889
    time to import experimental_experiment.torch_interpreter --- 0.007185000002209563
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.0016757000048528425
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    time to export 1x --- 0.8329499000028591
    Applied 0 of general pattern rewrite rules.
    Applied 0 of general pattern rewrite rules.
    time to export 2x --- 0.07871480000176234

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

with warnings.catch_warnings():
    # NOTE: this only filters the ``warnings`` module. The onnxscript
    # constant-folding messages ("Skip storing constant folded nvalue ...")
    # appear to go through ``logging`` and are therefore still printed.
    warnings.simplefilter("ignore")

    for label in ("1x", "2x"):
        begin = time.perf_counter()
        onx = torch.onnx.dynamo_export(model, *example_args_collection[0])
        print(f"time to export {label} --- {time.perf_counter() - begin}")

>>>

    [2024-05-08 14:08:13,681] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    [runpythonerror]
    2024-05-08 14:09:18,568 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 67108864.
    2024-05-08 14:09:18,575 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t due to large size 67108864.
    2024-05-08 14:09:18,583 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 67108864.
    2024-05-08 14:09:18,583 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_1 due to large size 67108864.
    2024-05-08 14:09:18,589 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 67108864.
    2024-05-08 14:09:18,589 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_2 due to large size 67108864.
    2024-05-08 14:09:18,734 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 67108864.
    2024-05-08 14:09:18,734 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_3 due to large size 67108864.
    2024-05-08 14:09:18,748 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 180355072.
    2024-05-08 14:09:18,748 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_4 due to large size 180355072.
    2024-05-08 14:09:18,752 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 180355072.
    2024-05-08 14:09:18,752 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_5 due to large size 180355072.
    2024-05-08 14:09:18,756 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue result_1 due to large size 180355072.
    2024-05-08 14:09:18,756 onnxscript.optimizer.constant_folding [WARNING] - Skip storing constant folded nvalue t_6 due to large size 180355072.