Export Times

Custom Exporter

With a very simple model:

<<<

import time
from experimental_experiment.checks import print_import_time

print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """Minimal model: a single linear layer followed by a sigmoid."""

    def __init__(self, n_dims: int, n_targets: int):
        """Create the linear layer.

        :param n_dims: number of input features of the linear layer
        :param n_targets: number of output features of the linear layer
        """
        # Zero-argument super() is the Python 3 idiom; it does not need to be
        # kept in sync with the class name.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))``."""
        return torch.sigmoid(self.linear(x))


# Build the model (3 input features, 1 output) and a random input
# made of 5 rows of 3 features.
model = Neuron(3, 1)
x = torch.rand(5, 3)

# Time a first export of the model with the custom exporter.
begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
print(f"time to export 1x --- {time.perf_counter() - begin}")

# Time a second, identical export so it can be compared with the first call.
begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 1.2991559179990873
    time to import onnx_array_api --- 0.0002983110007335199
    time to import torch --- 3.486133541999152
    'torch.export' already imported
    time to import torch.export --- 7.969001671881415e-06
    time to import onnxscript --- 0.2613919089999399
    time to import onnxruntime --- 0.066880033000416
    time to import torch.onnx --- 0.04449679699973785
    time to import torch._dynamo --- 1.7750149229996168
    time to import experimental_experiment.torch_interpreter --- 0.35057601999869803
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.006864240000140853
    time to export 1x --- 7.663285067999823
    time to export 2x --- 0.02144033999866224

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_model

model, example_args_collection = get_llama_model(  # returns the model and a collection of example inputs
    input_dims=[(2, 1024)],  # one input configuration; presumably (batch, sequence length) — TODO confirm
    hidden_size=4096,
    num_hidden_layers=1,  # a single hidden layer
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

begin = time.perf_counter()  # start the timer for the first export
onx = experimental_experiment.torch_interpreter.to_onnx(
    model, example_args_collection[0]  # first entry of the example inputs, passed whole as the export arguments
)
print(f"time to export 1x --- {time.perf_counter() - begin}")

begin = time.perf_counter()  # restart the timer for a second, identical export
onx = experimental_experiment.torch_interpreter.to_onnx(
    model, example_args_collection[0]
)
print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 38, in <module>
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 957, in to_onnx
        graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 573, in _make_builder_interpreter
        exported_program = export_options.export(
                           ^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 471, in export
        exported_program = torch.export.export(
                           ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
        return _export(
               ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export
        ep = _export_for_training(
             ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
        export_artifact = export_func(
                          ^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1975, in _non_strict_export
        aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx
        gm, graph_signature = transform(_make_fx_helper)(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1901, in _aot_export_non_strict
        gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1679, in _make_fx_helper
        gm = make_fx(
             ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2295, in wrapped
        return make_fx_tracer.trace(f, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2233, in trace
        return self._trace_inner(f, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2204, in _trace_inner
        t = dispatch_trace(
            ^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
        return disable_fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
        graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1792, in trace
        res = super().trace(root, concrete_args)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
        (self.create_arg(fn(*args)),),
                         ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1279, in wrapped
        out = f(*tensors)  # type:ignore[call-arg]
              ^^^^^^^^^^^
      File "<string>", line 1, in <lambda>
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1583, in wrapped_fn
        return tuple(flat_fn(*args))
                     ^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
        tree_out = fn(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
        out = mod(*args[params_len:], **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1885, in forward
        tree_out = mod(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_models/llama_helper.py", line 75, in forward
        model_output = self.model(
                       ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/utils/generic.py", line 943, in wrapper
        output = func(self, *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/models/llama/modeling_llama.py", line 422, in forward
        causal_mask = create_causal_mask(
                      ^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 719, in create_causal_mask
        causal_mask = mask_interface(
                      ^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 466, in eager_mask
        mask = sdpa_mask(
               ^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 350, in sdpa_mask_recent_torch
        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 48, in and_mask
        result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 118, in inner_mask
        return padding_mask[batch_idx, kv_idx]
               ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 141, in __torch_function__
        return mod_index(args[0], index_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/autograd/function.py", line 589, in apply
        return custom_function_call(cls, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1327, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1374, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 976, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 497, in wrapper
        return self.dispatch(
               ^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 458, in dispatch
        assert type(curr_mode) in self.python_key_table, (
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    AssertionError: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x71aa24280800> not registered

Dynamo Exporter

<<<

import time
import warnings

from experimental_experiment.checks import print_import_time

print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """Minimal model: a single linear layer followed by a sigmoid."""

    def __init__(self, n_dims: int, n_targets: int):
        """Create the linear layer.

        :param n_dims: number of input features of the linear layer
        :param n_targets: number of output features of the linear layer
        """
        # Zero-argument super() is the Python 3 idiom; it does not need to be
        # kept in sync with the class name.
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))``."""
        return torch.sigmoid(self.linear(x))


# Build the model (3 input features, 1 output) and a random input
# made of 5 rows of 3 features.
model = Neuron(3, 1)
x = torch.rand(5, 3)

with warnings.catch_warnings():
    # Ignore every warning raised inside this block; the previous warning
    # filters are restored when the block exits.
    warnings.simplefilter("ignore")

    # Time a first export with torch.onnx.export (dynamo=True).
    begin = time.perf_counter()
    onx = torch.onnx.export(model, x, dynamo=True)
    print(f"time to export 1x --- {time.perf_counter() - begin}")

    # Time a second, identical export so it can be compared with the first call.
    begin = time.perf_counter()
    onx = torch.onnx.export(model, x, dynamo=True)
    print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 1.3706378790011513
    time to import onnx_array_api --- 0.00039332699998340104
    time to import torch --- 2.7809098750003614
    'torch.export' already imported
    time to import torch.export --- 5.783000233350322e-06
    time to import onnxscript --- 0.2271753869990789
    time to import onnxruntime --- 0.0509410510003363
    time to import torch.onnx --- 0.04528911800116475
    time to import torch._dynamo --- 1.476014149000548
    time to import experimental_experiment.torch_interpreter --- 0.34863923900047666
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.00876640100068471
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    time to export 1x --- 1.6561283920000278
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    time to export 2x --- 0.7944334240000899

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_model

# Build the model to export together with a collection of example inputs.
model, example_args_collection = get_llama_model(
    input_dims=[(2, 1024)],  # one input configuration; presumably (batch, sequence length) — TODO confirm
    hidden_size=4096,
    num_hidden_layers=1,  # a single hidden layer
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

with warnings.catch_warnings():
    # Ignore every warning raised inside this block; the previous warning
    # filters are restored when the block exits.
    warnings.simplefilter("ignore")

    # torch.onnx.export expects all positional model inputs as ONE tuple in
    # its second parameter. Unpacking the inputs with ``*`` would send the
    # first tensor as ``args`` and the second one as ``f`` (the output file),
    # so the model would be traced without its ``attention_mask`` input.
    export_args = tuple(example_args_collection[0])

    # Time a first export with torch.onnx.export (dynamo=True).
    begin = time.perf_counter()
    onx = torch.onnx.export(model, export_args, dynamo=True)
    print(f"time to export 1x --- {time.perf_counter() - begin}")

    # Time a second, identical export so it can be compared with the first call.
    begin = time.perf_counter()
    onx = torch.onnx.export(model, export_args, dynamo=True)
    print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=False)`... ❌
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=True)`...
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=True)`... ❌
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export draft_export`...
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export draft_export`... ❌
    [runpythonerror]
    Traceback (most recent call last):
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 118, in __call__
        exported_program = self._capture(model, args, kwargs, dynamic_shapes)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 202, in _capture
        return torch.export.export(
               ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
        return _export(
               ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export
        ep = _export_for_training(
             ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
        export_artifact = export_func(
                          ^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1933, in _non_strict_export
        ) = make_fake_inputs(
            ^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 286, in make_fake_inputs
        combined_args = _combine_args(nn_module, args, kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/dynamic_shapes.py", line 654, in _combine_args
        return signature.bind(*args, **kwargs).arguments
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.12/inspect.py", line 3277, in bind
        return self._bind(args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.12/inspect.py", line 3190, in _bind
        raise TypeError(msg) from None
    TypeError: missing a required argument: 'attention_mask'
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 41, in <module>
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/__init__.py", line 367, in export
        return _compat.export_compat(
               ^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_compat.py", line 119, in export_compat
        onnx_program = _core.export(
                       ^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_flags.py", line 20, in wrapper
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 1332, in export
        raise _errors.TorchExportError(
    torch.onnx._internal.exporter._errors.TorchExportError: Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
    - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
    - Debug `torch.export.export` and summit a PR to PyTorch.
    - Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.
    
    ## Exception summary
    
    <class 'TypeError'>: missing a required argument: 'attention_mask'
    
    (Refer to the full stack trace above for more information.)