Export Times

Custom Exporter

First, with a very simple model. The script also prints how long the main packages take to import before timing the export itself:

<<<

import time
from experimental_experiment.checks import print_import_time

print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    "A single linear layer followed by a sigmoid."

    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
print(f"time to export 1x --- {time.perf_counter() - begin}")

begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.7503929399990739
    time to import onnx_array_api --- 0.00013528499948733952
    time to import torch --- 2.000482268998894
    'torch.export' already imported
    time to import torch.export --- 3.3070009521907195e-06
    time to import onnxscript --- 0.12176383400037594
    time to import onnxruntime --- 0.022342414000377175
    time to import torch.onnx --- 0.029618896000101813
    time to import torch._dynamo --- 1.2298131280003872
    time to import experimental_experiment.torch_interpreter --- 2.2079997140008345
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.004384150999612757
    time to export 1x --- 2.5003895059999195
    time to export 2x --- 0.018451580999681028
    [runpythonerror]
    nanobind: leaked 1 instances!
    nanobind: leaked 10 types!
     - leaked type "onnx.onnx_cpp2py_export.defs.Attribute"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.InferenceContext"
     - leaked type "onnx.onnx_cpp2py_export.defs.OpSchema"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.GraphInferencer"
     - leaked type "onnx.onnx_cpp2py_export.defs.FormalParameter"
     - leaked type "FormalParameterOption"
     - leaked type "SupportType"
     - leaked type "DifferentiationCategory"
     - leaked type "onnx.onnx_cpp2py_export.defs.TypeConstraintParam"
     - leaked type "AttrType"
     - ... skipped remainder
    nanobind: leaked 73 functions!
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "do_inferencing"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_display_name"
     - leaked function ""
     - leaked function "get_input_data"
     - leaked function "get_all_schemas"
     - leaked function "get_all_schemas_with_history"
     - leaked function ""
     - leaked function "get_input_sparse_data"
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "has_output"
     - leaked function "__init__"
     - leaked function "has_input"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_context_dependent_function_with_opset_version"
     - leaked function ""
     - leaked function ""
     - leaked function "get_graph_attribute_inferencer"
     - leaked function "has_schema"
     - leaked function ""
     - leaked function "set_output_type"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_num_inputs"
     - leaked function ""
     - leaked function ""
     - leaked function "_infer_node_outputs"
     - leaked function "get_num_outputs"
     - leaked function ""
     - leaked function ""
     - leaked function "get_type_and_shape_inference_function"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "is_infinite"
     - leaked function ""
     - leaked function "get_attribute"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_context_dependent_function"
     - leaked function ""
     - leaked function "get_symbolic_input"
     - leaked function "get_function_with_opset_version"
     - leaked function "set_type_and_shape_inference_function"
     - leaked function "get_input_type"
     - leaked function ""
     - leaked function ""
     - leaked function "deregister_schema"
     - leaked function "get_schema"
     - leaked function "__init__"
     - leaked function "get_output_type"
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
    nanobind: this is likely caused by a reference counting issue in the binding code.
    See https://nanobind.readthedocs.io/en/latest/refleaks.html
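
The second export is much faster than the first, presumably because caches and registries populated during the first call are reused. To check that the exported model actually runs, it can be fed to onnxruntime and compared with the eager PyTorch output. The sketch below is not part of the timed script above and assumes onx is an onnx.ModelProto, so SerializeToString is available:

    # Sanity check, not part of the timings above.
    # Assumption: ``onx`` is an onnx.ModelProto produced by to_onnx.
    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession(
        onx.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    feeds = {sess.get_inputs()[0].name: x.numpy()}
    got = sess.run(None, feeds)[0]
    expected = model(x).detach().numpy()
    np.testing.assert_allclose(expected, got, atol=1e-5)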

With a bigger model, a LLaMA architecture reduced to a single hidden layer. With the package versions used for this run, the export fails with a NotImplementedError (the full traceback and debug output follow):

<<<

import time

# All heavy packages are imported up front so that their import time
# is not included in the export times measured below.
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_model

model, example_args_collection = get_llama_model(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(
    model, example_args_collection[0]
)
print(f"time to export 1x --- {time.perf_counter() - begin}")

begin = time.perf_counter()
onx = experimental_experiment.torch_interpreter.to_onnx(
    model, example_args_collection[0]
)
print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 38, in <module>
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1059, in to_onnx
        builder.process(graph_module, interpreter)
      File "~/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 5257, in process
        interpreter.run_node(node)
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 229, in run_node
        res = self.get_attr(node)
              ^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 336, in get_attr
        raise NotImplementedError(
    NotImplementedError: Unable to handle type <class 'torch.utils._pytree.TreeSpec'> for node.name='function_const_func_spec0'
    Node(type=<class 'torch._higher_order_ops.flat_apply._ConstantFunction'>, leaves=0)
    --
    --DEBUG--
    [GraphBuilder-HDO] Message starts, there are 24 initializers, 0 nodes, 2 inputs, 2 outputs.
    --SHAPE--
    _dynamic_examples=
    dynamic_objects=
    dynamic_objects_rev=
    dynamic_dimensions_source={}
    dynamic_dimensions_source_flat=None
    output_dynamic_dimensions_source_flat=None
    dynamic_alias={}
    dynamic_shapes=None
    _known_shapes={'attention_mask': (2, 1024),
     'b_model_rotary_emb_inv_freq': (64,),
     'c_model_lifted_tensor_0': (),
     'input_ids': (2, 1024),
     'model.embed_tokens.weight': (32000, 4096),
     'model.layers.0.input_layernorm.weight': (4096,),
     'model.layers.0.mlp.down_proj.weight': (4096, 11008),
     'model.layers.0.mlp.gate_proj.weight': (11008, 4096),
     'model.layers.0.mlp.up_proj.weight': (11008, 4096),
     'model.layers.0.post_attention_layernorm.weight': (4096,),
     'model.layers.0.self_attn.k_proj.weight': (4096, 4096),
     'model.layers.0.self_attn.o_proj.weight': (4096, 4096),
     'model.layers.0.self_attn.q_proj.weight': (4096, 4096),
     'model.layers.0.self_attn.v_proj.weight': (4096, 4096),
     'model.norm.weight': (4096,),
     'p_model_embed_tokens_weight': (32000, 4096),
     'p_model_layers_0_input_layernorm_weight': (4096,),
     'p_model_layers_0_mlp_down_proj_weight': (4096, 11008),
     'p_model_layers_0_mlp_gate_proj_weight': (11008, 4096),
     'p_model_layers_0_mlp_up_proj_weight': (11008, 4096),
     'p_model_layers_0_post_attention_layernorm_weight': (4096,),
     'p_model_layers_0_self_attn_k_proj_weight': (4096, 4096),
     'p_model_layers_0_self_attn_o_proj_weight': (4096, 4096),
     'p_model_layers_0_self_attn_q_proj_weight': (4096, 4096),
     'p_model_layers_0_self_attn_v_proj_weight': (4096, 4096),
     'p_model_norm_weight': (4096,)}
    _known_types={'attention_mask': 1,
     'b_model_rotary_emb_inv_freq': 1,
     'c_model_lifted_tensor_0': 1,
     'input_ids': 7,
     'model.embed_tokens.weight': 1,
     'model.layers.0.input_layernorm.weight': 1,
     'model.layers.0.mlp.down_proj.weight': 1,
     'model.layers.0.mlp.gate_proj.weight': 1,
     'model.layers.0.mlp.up_proj.weight': 1,
     'model.layers.0.post_attention_layernorm.weight': 1,
     'model.layers.0.self_attn.k_proj.weight': 1,
     'model.layers.0.self_attn.o_proj.weight': 1,
     'model.layers.0.self_attn.q_proj.weight': 1,
     'model.layers.0.self_attn.v_proj.weight': 1,
     'model.norm.weight': 1,
     'p_model_embed_tokens_weight': 1,
     'p_model_layers_0_input_layernorm_weight': 1,
     'p_model_layers_0_mlp_down_proj_weight': 1,
     'p_model_layers_0_mlp_gate_proj_weight': 1,
     'p_model_layers_0_mlp_up_proj_weight': 1,
     'p_model_layers_0_post_attention_layernorm_weight': 1,
     'p_model_layers_0_self_attn_k_proj_weight': 1,
     'p_model_layers_0_self_attn_o_proj_weight': 1,
     'p_model_layers_0_self_attn_q_proj_weight': 1,
     'p_model_layers_0_self_attn_v_proj_weight': 1,
     'p_model_norm_weight': 1}
    _known_value_shape={}
    _known_constants=['b_model_rotary_emb_inv_freq',
     'c_model_lifted_tensor_0',
     'model.embed_tokens.weight',
     'model.layers.0.input_layernorm.weight',
     'model.layers.0.mlp.down_proj.weight',
     'model.layers.0.mlp.gate_proj.weight',
     'model.layers.0.mlp.up_proj.weight',
     'model.layers.0.post_attention_layernorm.weight',
     'model.layers.0.self_attn.k_proj.weight',
     'model.layers.0.self_attn.o_proj.weight',
     'model.layers.0.self_attn.q_proj.weight',
     'model.layers.0.self_attn.v_proj.weight',
     'model.norm.weight',
     'p_model_embed_tokens_weight',
     'p_model_layers_0_input_layernorm_weight',
     'p_model_layers_0_mlp_down_proj_weight',
     'p_model_layers_0_mlp_gate_proj_weight',
     'p_model_layers_0_mlp_up_proj_weight',
     'p_model_layers_0_post_attention_layernorm_weight',
     'p_model_layers_0_self_attn_k_proj_weight',
     'p_model_layers_0_self_attn_o_proj_weight',
     'p_model_layers_0_self_attn_q_proj_weight',
     'p_model_layers_0_self_attn_v_proj_weight',
     'p_model_norm_weight']
    _known_ranks={}
    --PARAMETERS--
    _parameter_renaming=
       p_model_embed_tokens_weight = 'model.embed_tokens.weight'
       p_model_layers_0_input_layernorm_weight = 'model.layers.0.input_layernorm.weight'
       p_model_layers_0_mlp_down_proj_weight = 'model.layers.0.mlp.down_proj.weight'
       p_model_layers_0_mlp_gate_proj_weight = 'model.layers.0.mlp.gate_proj.weight'
       p_model_layers_0_mlp_up_proj_weight = 'model.layers.0.mlp.up_proj.weight'
       p_model_layers_0_post_attention_layernorm_weight = 'model.layers.0.post_attention_layernorm.weight'
       p_model_layers_0_self_attn_k_proj_weight = 'model.layers.0.self_attn.k_proj.weight'
       p_model_layers_0_self_attn_o_proj_weight = 'model.layers.0.self_attn.o_proj.weight'
       p_model_layers_0_self_attn_q_proj_weight = 'model.layers.0.self_attn.q_proj.weight'
       p_model_layers_0_self_attn_v_proj_weight = 'model.layers.0.self_attn.v_proj.weight'
       p_model_norm_weight = 'model.norm.weight'
    --TORCH-USERS--
        attention_mask -> {to}
        b_model_rotary_emb_inv_freq -> {wrap_with_set_grad_enabled}
        c_model_lifted_tensor_0 -> {lift_fresh_copy}
        function_const_func_spec0 -> {flat_apply}
        input_ids -> {embedding}
        p_model_embed_tokens_weight -> {embedding}
        p_model_layers_0_input_layernorm_weight -> {mul_3}
        p_model_layers_0_mlp_down_proj_weight -> {linear_6}
        p_model_layers_0_mlp_gate_proj_weight -> {linear_4}
        p_model_layers_0_mlp_up_proj_weight -> {linear_5}
        p_model_layers_0_post_attention_layernorm_weight -> {mul_10}
        p_model_layers_0_self_attn_k_proj_weight -> {linear_1}
        p_model_layers_0_self_attn_o_proj_weight -> {linear_3}
        p_model_layers_0_self_attn_q_proj_weight -> {linear}
        p_model_layers_0_self_attn_v_proj_weight -> {linear_2}
        p_model_norm_weight -> {mul_13}
    --TORCH-SHAPES--
        p_model_embed_tokens_weight: ('run_node', ('', ('val', torch.float32, torch.Size([32000, 4096])))) --- 1:2:(32000, 4096):
        p_model_layers_0_self_attn_q_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096, 4096])))) --- 1:2:(4096, 4096):
        p_model_layers_0_self_attn_k_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096, 4096])))) --- 1:2:(4096, 4096):
        p_model_layers_0_self_attn_v_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096, 4096])))) --- 1:2:(4096, 4096):
        p_model_layers_0_self_attn_o_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096, 4096])))) --- 1:2:(4096, 4096):
        p_model_layers_0_mlp_gate_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([11008, 4096])))) --- 1:2:(11008, 4096):
        p_model_layers_0_mlp_up_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([11008, 4096])))) --- 1:2:(11008, 4096):
        p_model_layers_0_mlp_down_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096, 11008])))) --- 1:2:(4096, 11008):
        p_model_layers_0_input_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096])))) --- 1:1:(4096,):
        p_model_layers_0_post_attention_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096])))) --- 1:1:(4096,):
        p_model_norm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([4096])))) --- 1:1:(4096,):
        b_model_rotary_emb_inv_freq: ('run_node', ('', ('val', torch.float32, torch.Size([64])))) --- 1:1:(64,):
        c_model_lifted_tensor_0: ('run_node', ('', ('val', torch.float32, torch.Size([])))) --- 1:0:():
        input_ids: ('run_node', ('', ('val', torch.int64, torch.Size([2, 1024])))) --- 7:2:(2, 1024):
        attention_mask: ('run_node', ('', ('val', torch.float32, torch.Size([2, 1024])))) --- 1:2:(2, 1024):
        function_const_func_spec0: ('run_node', ('', '')) --- :::
    --ONNX--
    -- EXEPATH --
    export-export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.scaled_dot_product_attention.default', 'aten.setitem', <built-in function setitem>))
    -- process.graph_module --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_model_embed_tokens_weight: "f32[32000, 4096]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_o_proj_weight: "f32[4096, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_up_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_down_proj_weight: "f32[4096, 11008]", p_model_layers_0_input_layernorm_weight: "f32[4096]", p_model_layers_0_post_attention_layernorm_weight: "f32[4096]", p_model_norm_weight: "f32[4096]", b_model_rotary_emb_inv_freq: "f32[64]", c_model_lifted_tensor_0: "f32[]", input_ids: "i64[2, 1024]", attention_mask: "f32[2, 1024]"):
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
                function_const_func_spec0 = self.function_const_func_spec0
                torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
                embedding: "f32[2, 1024, 4096]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
                arange: "i64[1024]" = torch.ops.aten.arange.start(0, 1024, device = device(type='cpu'), pin_memory = False)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:380 in forward, code: position_ids = cache_position.unsqueeze(0)
                unsqueeze: "i64[1, 1024]" = torch.ops.aten.unsqueeze.default(arange, 0)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
                to: "b8[2, 1024]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
                arange_1: "i64[1024]" = torch.ops.aten.arange.default(1024, device = device(type='cpu'), pin_memory = False)
                add_: "i64[1024]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
                arange_2: "i64[2]" = torch.ops.aten.arange.default(2, device = device(type='cpu'), pin_memory = False)
                _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
                _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
                _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
                new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
                le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
                to_1: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
                and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None
                flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
                to_2: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
                and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
                _remove_batch_dim: "b8[1024]" = torch._functorch.predispatch._remove_batch_dim(and_2, 4, 1024, 0);  and_2 = None
                _remove_batch_dim_1: "b8[1024, 1024]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 1024, 0);  _remove_batch_dim = None
                _remove_batch_dim_2: "b8[1, 1024, 1024]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0);  _remove_batch_dim_1 = None
                _remove_batch_dim_3: "b8[2, 1, 1024, 1024]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0);  _remove_batch_dim_2 = None
                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_model_lifted_tensor_0);  c_model_lifted_tensor_0 = None
                detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
                where: "f32[2, 1, 1024, 1024]" = torch.ops.aten.where.ScalarOther(_remove_batch_dim_3, detach_, -3.4028234663852886e+38);  _remove_batch_dim_3 = detach_ = None
                
                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze);  submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_8: "f32[1, 1024, 128]" = wrap_with_set_grad_enabled[0]
                to_9: "f32[1, 1024, 128]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_10: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_1: "f32[2, 1024, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
                mean: "f32[2, 1024, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add: "f32[2, 1024, 1]" = torch.ops.aten.add.Tensor(mean, 1e-06);  mean = None
                rsqrt: "f32[2, 1024, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
                mul_2: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_11: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
                mul_3: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear: "f32[2, 1024, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view: "f32[2, 1024, 32, 128]" = torch.ops.aten.view.default(linear, [2, 1024, -1, 128]);  linear = None
                transpose_1: "f32[2, 32, 1024, 128]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_1: "f32[2, 1024, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_1: "f32[2, 1024, 32, 128]" = torch.ops.aten.view.default(linear_1, [2, 1024, -1, 128]);  linear_1 = None
                transpose_2: "f32[2, 32, 1024, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_2: "f32[2, 1024, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_2: "f32[2, 1024, 32, 128]" = torch.ops.aten.view.default(linear_2, [2, 1024, -1, 128]);  linear_2 = None
                transpose_3: "f32[2, 32, 1024, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
                unsqueeze_4: "f32[1, 1, 1024, 128]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
                unsqueeze_5: "f32[1, 1, 1024, 128]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
                mul_4: "f32[2, 32, 1024, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_4)
                slice_1: "f32[2, 32, 1024, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
                slice_2: "f32[2, 32, 1024, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807);  transpose_1 = None
                neg: "f32[2, 32, 1024, 64]" = torch.ops.aten.neg.default(slice_2);  slice_2 = None
                cat_1: "f32[2, 32, 1024, 128]" = torch.ops.aten.cat.default([neg, slice_1], -1);  neg = slice_1 = None
                mul_5: "f32[2, 32, 1024, 128]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5);  cat_1 = None
                add_1: "f32[2, 32, 1024, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
                mul_6: "f32[2, 32, 1024, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_4);  unsqueeze_4 = None
                slice_3: "f32[2, 32, 1024, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
                slice_4: "f32[2, 32, 1024, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807);  transpose_2 = None
                neg_1: "f32[2, 32, 1024, 64]" = torch.ops.aten.neg.default(slice_4);  slice_4 = None
                cat_2: "f32[2, 32, 1024, 128]" = torch.ops.aten.cat.default([neg_1, slice_3], -1);  neg_1 = slice_3 = None
                mul_7: "f32[2, 32, 1024, 128]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_5);  cat_2 = unsqueeze_5 = None
                add_2: "f32[2, 32, 1024, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
                transpose_4: "f32[2, 32, 128, 1024]" = torch.ops.aten.transpose.int(add_2, 2, 3);  add_2 = None
                matmul_1: "f32[2, 32, 1024, 1024]" = torch.ops.aten.matmul.default(add_1, transpose_4);  add_1 = transpose_4 = None
                mul_8: "f32[2, 32, 1024, 1024]" = torch.ops.aten.mul.Tensor(matmul_1, 0.08838834764831845);  matmul_1 = None
                alias: "f32[2, 1, 1024, 1024]" = torch.ops.aten.alias.default(where);  where = None
                add_3: "f32[2, 32, 1024, 1024]" = torch.ops.aten.add.Tensor(mul_8, alias);  mul_8 = alias = None
                softmax: "f32[2, 32, 1024, 1024]" = torch.ops.aten.softmax.int(add_3, -1, torch.float32);  add_3 = None
                to_12: "f32[2, 32, 1024, 1024]" = torch.ops.aten.to.dtype(softmax, torch.float32);  softmax = None
                dropout: "f32[2, 32, 1024, 1024]" = torch.ops.aten.dropout.default(to_12, 0.0, True);  to_12 = None
                matmul_2: "f32[2, 32, 1024, 128]" = torch.ops.aten.matmul.default(dropout, transpose_3);  dropout = transpose_3 = None
                transpose_5: "f32[2, 1024, 32, 128]" = torch.ops.aten.transpose.int(matmul_2, 1, 2);  matmul_2 = None
                contiguous: "f32[2, 1024, 32, 128]" = torch.ops.aten.contiguous.default(transpose_5);  transpose_5 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
                reshape: "f32[2, 1024, 4096]" = torch.ops.aten.reshape.default(contiguous, [2, 1024, -1]);  contiguous = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_3: "f32[2, 1024, 4096]" = torch.ops.aten.linear.default(reshape, p_model_layers_0_self_attn_o_proj_weight);  reshape = p_model_layers_0_self_attn_o_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
                add_4: "f32[2, 1024, 4096]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_13: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(add_4, torch.float32);  add_4 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_2: "f32[2, 1024, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
                mean_1: "f32[2, 1024, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_5: "f32[2, 1024, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-06);  mean_1 = None
                rsqrt_1: "f32[2, 1024, 1]" = torch.ops.aten.rsqrt.default(add_5);  add_5 = None
                mul_9: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(to_13, rsqrt_1);  rsqrt_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_14: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(mul_9, torch.float32);  mul_9 = None
                mul_10: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_14);  p_model_layers_0_post_attention_layernorm_weight = to_14 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_4: "f32[2, 1024, 11008]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
                silu: "f32[2, 1024, 11008]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_5: "f32[2, 1024, 11008]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_up_proj_weight);  mul_10 = p_model_layers_0_mlp_up_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                mul_11: "f32[2, 1024, 11008]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_6: "f32[2, 1024, 4096]" = torch.ops.aten.linear.default(mul_11, p_model_layers_0_mlp_down_proj_weight);  mul_11 = p_model_layers_0_mlp_down_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
                add_6: "f32[2, 1024, 4096]" = torch.ops.aten.add.Tensor(to_13, linear_6);  to_13 = linear_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_15: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_3: "f32[2, 1024, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
                mean_2: "f32[2, 1024, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_7: "f32[2, 1024, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-06);  mean_2 = None
                rsqrt_2: "f32[2, 1024, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
                mul_12: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_2);  to_15 = rsqrt_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_16: "f32[2, 1024, 4096]" = torch.ops.aten.to.dtype(mul_12, torch.float32);  mul_12 = None
                mul_13: "f32[2, 1024, 4096]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_16);  p_model_norm_weight = to_16 = None
                return (mul_13,)
                
            class submod_1(torch.nn.Module):
                def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", unsqueeze: "i64[1, 1024]"):
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:96 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                    unsqueeze_1: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                    unsqueeze_2: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 2);  unsqueeze_1 = None
                    _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                    to_3: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_2, torch.float32);  unsqueeze_2 = None
                    expand_1: "f32[1, 64, 1]" = torch.ops.aten.expand.default(to_3, [1, -1, 1]);  to_3 = None
                    _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                    to_4: "f32[1, 64, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:97 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                    unsqueeze_3: "i64[1, 1, 1024]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
                    _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                    to_5: "f32[1, 1, 1024]" = torch.ops.aten.to.dtype(unsqueeze_3, torch.float32);  unsqueeze_3 = None
                    
                    # No stacktrace found for following nodes
                    submod_3 = self.submod_1
                    wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                    mul: "f32[1, 1024, 128]" = wrap_with_autocast[0]
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                    mul_1: "f32[1, 1024, 128]" = wrap_with_autocast[1];  wrap_with_autocast = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                    _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
                    to_8: "f32[1, 1024, 128]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                    _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
                    to_9: "f32[1, 1024, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                    return (to_8, to_9)
                    
                class submod_1(torch.nn.Module):
                    def forward(self, to_4: "f32[1, 64, 1]", to_5: "f32[1, 1, 1024]"):
                         # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:101 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                        _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                        to_6: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                        _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                        to_7: "f32[1, 1, 1024]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                        matmul: "f32[1, 64, 1024]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                        transpose: "f32[1, 1024, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None
                        
                         # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:102 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                        cat: "f32[1, 1024, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None
                        
                         # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                        cos: "f32[1, 1024, 128]" = torch.ops.aten.cos.default(cat)
                        mul: "f32[1, 1024, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None
                        
                         # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                        sin: "f32[1, 1024, 128]" = torch.ops.aten.sin.default(cat);  cat = None
                        mul_1: "f32[1, 1024, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                        return (mul, mul_1)
                        
    Graph signature: 
        # inputs
        p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
        p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
        p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
        p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
        p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
        p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
        p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
        p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
        p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
        p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
        p_model_norm_weight: PARAMETER target='model.norm.weight'
        b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
        c_model_lifted_tensor_0: CONSTANT_TENSOR target='model.lifted_tensor_0'
        input_ids: USER_INPUT
        attention_mask: USER_INPUT
        
        # outputs
        mul_13: USER_OUTPUT
        
    Range constraints: {}
    
    -- process.graph_module.graph --
    graph():
        %p_model_embed_tokens_weight : [num_users=1] = placeholder[target=p_model_embed_tokens_weight]
        %p_model_layers_0_self_attn_q_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_q_proj_weight]
        %p_model_layers_0_self_attn_k_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_k_proj_weight]
        %p_model_layers_0_self_attn_v_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_v_proj_weight]
        %p_model_layers_0_self_attn_o_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_o_proj_weight]
        %p_model_layers_0_mlp_gate_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_gate_proj_weight]
        %p_model_layers_0_mlp_up_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_up_proj_weight]
        %p_model_layers_0_mlp_down_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_down_proj_weight]
        %p_model_layers_0_input_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_input_layernorm_weight]
        %p_model_layers_0_post_attention_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_post_attention_layernorm_weight]
        %p_model_norm_weight : [num_users=1] = placeholder[target=p_model_norm_weight]
        %b_model_rotary_emb_inv_freq : [num_users=1] = placeholder[target=b_model_rotary_emb_inv_freq]
        %c_model_lifted_tensor_0 : [num_users=1] = placeholder[target=c_model_lifted_tensor_0]
        %input_ids : [num_users=1] = placeholder[target=input_ids]
        %attention_mask : [num_users=1] = placeholder[target=attention_mask]
        %function_const_func_spec0 : [num_users=1] = get_attr[target=function_const_func_spec0]
        %torch__dynamo__trace_wrapped_higher_order_op_mod_index0 : [num_users=1] = get_attr[target=torch__dynamo__trace_wrapped_higher_order_op_ModIndex0]
        %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_model_embed_tokens_weight, %input_ids), kwargs = {})
        %arange : [num_users=2] = call_function[target=torch.ops.aten.arange.start](args = (0, 1024), kwargs = {device: cpu, pin_memory: False})
        %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
        %to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%attention_mask, cpu, torch.bool), kwargs = {})
        %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (1024,), kwargs = {device: cpu, pin_memory: False})
        %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%arange_1, 0), kwargs = {})
        %arange_2 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (2,), kwargs = {device: cpu, pin_memory: False})
        %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange_2, 0, 1), kwargs = {})
        %_add_batch_dim_2 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange, 0, 3), kwargs = {})
        %_add_batch_dim_3 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%add_, 0, 4), kwargs = {})
        %new_ones : [num_users=1] = call_function[target=torch.ops.aten.new_ones.default](args = (%_add_batch_dim_2, []), kwargs = {dtype: torch.bool, pin_memory: False})
        %le : [num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%_add_batch_dim_3, %_add_batch_dim_2), kwargs = {})
        %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%le,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_1 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%new_ones, %to_1), kwargs = {})
        %flat_apply : [num_users=1] = call_function[target=torch.ops.higher_order.flat_apply](args = (%function_const_func_spec0, %torch__dynamo__trace_wrapped_higher_order_op_mod_index0, torch._dynamo._trace_wrapped_higher_order_op.ModIndex, %to, %_add_batch_dim, %_add_batch_dim_3), kwargs = {})
        %to_2 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%flat_apply,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_2 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%and_1, %to_2), kwargs = {})
        %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%and_2, 4, 1024, 0), kwargs = {})
        %_remove_batch_dim_1 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim, 3, 1024, 0), kwargs = {})
        %_remove_batch_dim_2 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_1, 2, 1, 0), kwargs = {})
        %_remove_batch_dim_3 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_2, 1, 2, 0), kwargs = {})
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_model_lifted_tensor_0,), kwargs = {})
        %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
        %where : [num_users=1] = call_function[target=torch.ops.aten.where.ScalarOther](args = (%_remove_batch_dim_3, %detach_, -3.4028234663852886e+38), kwargs = {})
        %submod_3 : [num_users=1] = get_attr[target=submod_1]
        %wrap_with_set_grad_enabled : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_set_grad_enabled](args = (False, %submod_3, %b_model_rotary_emb_inv_freq, %unsqueeze), kwargs = {})
        %to_8 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 0), kwargs = {})
        %to_9 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_set_grad_enabled, 1), kwargs = {})
        %to_10 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%embedding, torch.float32), kwargs = {})
        %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_10, 2), kwargs = {})
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
        %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
        %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_10, %rsqrt), kwargs = {})
        %to_11 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_2, torch.float32), kwargs = {})
        %mul_3 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_input_layernorm_weight, %to_11), kwargs = {})
        %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_q_proj_weight), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear, [2, 1024, -1, 128]), kwargs = {})
        %transpose_1 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view, 1, 2), kwargs = {})
        %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_k_proj_weight), kwargs = {})
        %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_1, [2, 1024, -1, 128]), kwargs = {})
        %transpose_2 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view_1, 1, 2), kwargs = {})
        %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_v_proj_weight), kwargs = {})
        %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_2, [2, 1024, -1, 128]), kwargs = {})
        %transpose_3 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_2, 1, 2), kwargs = {})
        %unsqueeze_4 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_8, 1), kwargs = {})
        %unsqueeze_5 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_9, 1), kwargs = {})
        %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_1, %unsqueeze_4), kwargs = {})
        %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 0, 64), kwargs = {})
        %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 64, 9223372036854775807), kwargs = {})
        %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
        %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
        %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_5), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
        %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_2, %unsqueeze_4), kwargs = {})
        %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 0, 64), kwargs = {})
        %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 64, 9223372036854775807), kwargs = {})
        %neg_1 : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
        %cat_2 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
        %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %unsqueeze_5), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
        %transpose_4 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%add_2, 2, 3), kwargs = {})
        %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%add_1, %transpose_4), kwargs = {})
        %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_1, 0.08838834764831845), kwargs = {})
        %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%where,), kwargs = {})
        %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_8, %alias), kwargs = {})
        %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%add_3, -1, torch.float32), kwargs = {})
        %to_12 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%softmax, torch.float32), kwargs = {})
        %dropout : [num_users=1] = call_function[target=torch.ops.aten.dropout.default](args = (%to_12, 0.0, True), kwargs = {})
        %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%dropout, %transpose_3), kwargs = {})
        %transpose_5 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%matmul_2, 1, 2), kwargs = {})
        %contiguous : [num_users=1] = call_function[target=torch.ops.aten.contiguous.default](args = (%transpose_5,), kwargs = {})
        %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%contiguous, [2, 1024, -1]), kwargs = {})
        %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%reshape, %p_model_layers_0_self_attn_o_proj_weight), kwargs = {})
        %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_10, %linear_3), kwargs = {})
        %to_13 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%add_4, torch.float32), kwargs = {})
        %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_13, 2), kwargs = {})
        %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
        %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-06), kwargs = {})
        %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
        %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_13, %rsqrt_1), kwargs = {})
        %to_14 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_9, torch.float32), kwargs = {})
        %mul_10 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_post_attention_layernorm_weight, %to_14), kwargs = {})
        %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_gate_proj_weight), kwargs = {})
        %silu : [num_users=1] = call_function[target=torch.ops.aten.silu.default](args = (%linear_4,), kwargs = {})
        %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_up_proj_weight), kwargs = {})
        %mul_11 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%silu, %linear_5), kwargs = {})
        %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_11, %p_model_layers_0_mlp_down_proj_weight), kwargs = {})
        %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_13, %linear_6), kwargs = {})
        %to_15 : [num_users=2] = call_function[target=torch.ops.aten.to.dtype](args = (%add_6, torch.float32), kwargs = {})
        %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_15, 2), kwargs = {})
        %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
        %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-06), kwargs = {})
        %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_7,), kwargs = {})
        %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_15, %rsqrt_2), kwargs = {})
        %to_16 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_12, torch.float32), kwargs = {})
        %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_norm_weight, %to_16), kwargs = {})
        return (mul_13,)
    -- process.inputs_to_remove --
    set()
    -- process.progress --
    node 15/115 target=function_const_func_spec0
    -- 2 INPUTS
    [GraphBuilder-HDO.make_tensor_input] input_ids[7:2x1024]
    [GraphBuilder-HDO.make_tensor_input] attention_mask[1:2x1024]
    -- 24 INITIALIZERS
    [GraphBuilder-HDO.make_initializer] p_model_embed_tokens_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-HDO.make_initializer] model.embed_tokens.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_self_attn_q_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.self_attn.q_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_self_attn_k_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.self_attn.k_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_self_attn_v_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.self_attn.v_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_self_attn_o_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.self_attn.o_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_mlp_gate_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.mlp.gate_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_mlp_up_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.mlp.up_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_mlp_down_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.mlp.down_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_input_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.input_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-HDO.make_initializer] p_model_layers_0_post_attention_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-HDO.make_initializer] model.layers.0.post_attention_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-HDO.make_initializer] p_model_norm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-HDO.make_initializer] model.norm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-HDO.make_initializer] b_model_rotary_emb_inv_freq[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.0
    [GraphBuilder-HDO.make_initializer] c_model_lifted_tensor_0[torch.float32:torch.float32:[0.0]] - SOURCE: DynamoInterpret.placeholder.0
    -- 0 OUTPUTS
    [GraphBuilder-HDO] Message completed, there are 24 initializers, 0 nodes, 2 inputs, 2 outputs.
    nanobind: leaked 1 instances!
    nanobind: leaked 10 types!
     - leaked type "onnx.onnx_cpp2py_export.defs.Attribute"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.InferenceContext"
     - leaked type "onnx.onnx_cpp2py_export.defs.OpSchema"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.GraphInferencer"
     - leaked type "onnx.onnx_cpp2py_export.defs.FormalParameter"
     - leaked type "FormalParameterOption"
     - leaked type "SupportType"
     - leaked type "DifferentiationCategory"
     - leaked type "onnx.onnx_cpp2py_export.defs.TypeConstraintParam"
     - leaked type "AttrType"
     - ... skipped remainder
    nanobind: leaked 73 functions!
     - leaked function "set_type_and_shape_inference_function"
     - leaked function "get_symbolic_input"
     - leaked function ""
     - leaked function "has_output"
     - leaked function "get_context_dependent_function_with_opset_version"
     - leaked function ""
     - leaked function "get_context_dependent_function"
     - leaked function "get_input_sparse_data"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "do_inferencing"
     - leaked function ""
     - leaked function "get_schema"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_all_schemas"
     - leaked function "_infer_node_outputs"
     - leaked function "get_output_type"
     - leaked function "has_schema"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_display_name"
     - leaked function ""
     - leaked function "__init__"
     - leaked function "get_type_and_shape_inference_function"
     - leaked function ""
     - leaked function "get_graph_attribute_inferencer"
     - leaked function "get_function_with_opset_version"
     - leaked function ""
     - leaked function ""
     - leaked function "get_attribute"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "set_output_type"
     - leaked function ""
     - leaked function ""
     - leaked function "get_num_outputs"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
     - leaked function "deregister_schema"
     - leaked function ""
     - leaked function ""
     - leaked function "is_infinite"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "has_input"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_input_type"
     - leaked function "get_all_schemas_with_history"
     - leaked function "get_num_inputs"
     - leaked function "__init__"
     - leaked function ""
     - leaked function "get_input_data"
    nanobind: this is likely caused by a reference counting issue in the binding code.
    See https://nanobind.readthedocs.io/en/latest/refleaks.html
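
The verbose output above shows the FX graph produced by torch.export, followed by the two inputs and the 24 initializers registered by the GraphBuilder before any node is converted. The nanobind leak messages are printed by the onnx bindings when the process exits and appear to be unrelated to the export itself.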

Dynamo Exporter
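
With the same small model, this time using torch.onnx.export with dynamo=True: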

<<<

import time
import warnings

from experimental_experiment.checks import print_import_time

print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super(Neuron, self).__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    begin = time.perf_counter()
    onx = torch.onnx.export(model, x, dynamo=True)
    print(f"time to export 1x --- {time.perf_counter() - begin}")

    begin = time.perf_counter()
    onx = torch.onnx.export(model, x, dynamo=True)
    print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.9695512269991013
    time to import onnx_array_api --- 0.00018127700059267227
    time to import torch --- 2.42740269999922
    'torch.export' already imported
    time to import torch.export --- 3.5609991755336523e-06
    time to import onnxscript --- 0.1216287650004233
    time to import onnxruntime --- 0.031893441000647726
    time to import torch.onnx --- 0.030215790000511333
    time to import torch._dynamo --- 1.2252427189996524
    time to import experimental_experiment.torch_interpreter --- 3.1105823240013706
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.005543584000406554
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    time to export 1x --- 1.340667578000648
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `Neuron([...]` with `torch.export.export(..., strict=False)`... ✅
    [torch.onnx] Run decomposition...
    [torch.onnx] Run decomposition... ✅
    [torch.onnx] Translate the graph into ONNX...
    [torch.onnx] Translate the graph into ONNX... ✅
    time to export 2x --- 0.4698095599997032
    [runpythonerror]
    nanobind: leaked 1 instances!
    nanobind: leaked 10 types!
     - leaked type "onnx.onnx_cpp2py_export.defs.Attribute"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.InferenceContext"
     - leaked type "onnx.onnx_cpp2py_export.defs.OpSchema"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.GraphInferencer"
     - leaked type "onnx.onnx_cpp2py_export.defs.FormalParameter"
     - leaked type "FormalParameterOption"
     - leaked type "SupportType"
     - leaked type "DifferentiationCategory"
     - leaked type "onnx.onnx_cpp2py_export.defs.TypeConstraintParam"
     - leaked type "AttrType"
     - ... skipped remainder
    nanobind: leaked 73 functions!
     - leaked function ""
     - leaked function "get_symbolic_input"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function "__init__"
     - leaked function "is_infinite"
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
     - leaked function "get_output_type"
     - leaked function "has_schema"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_all_schemas_with_history"
     - leaked function ""
     - leaked function "get_input_sparse_data"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_function_with_opset_version"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "_infer_node_outputs"
     - leaked function ""
     - leaked function "get_context_dependent_function_with_opset_version"
     - leaked function ""
     - leaked function ""
     - leaked function "get_input_data"
     - leaked function "get_schema"
     - leaked function ""
     - leaked function "get_num_outputs"
     - leaked function "deregister_schema"
     - leaked function ""
     - leaked function "get_attribute"
     - leaked function "get_input_type"
     - leaked function "set_output_type"
     - leaked function "get_all_schemas"
     - leaked function ""
     - leaked function ""
     - leaked function "get_graph_attribute_inferencer"
     - leaked function "get_context_dependent_function"
     - leaked function "do_inferencing"
     - leaked function "get_type_and_shape_inference_function"
     - leaked function ""
     - leaked function ""
     - leaked function "has_output"
     - leaked function ""
     - leaked function "get_display_name"
     - leaked function "set_type_and_shape_inference_function"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_num_inputs"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function "has_input"
    nanobind: this is likely caused by a reference counting issue in the binding code.
    See https://nanobind.readthedocs.io/en/latest/refleaks.html
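
As with the custom exporter, the second export is noticeably faster than the first (about 0.47 seconds versus 1.34 seconds in this run), presumably because the imports and the export machinery are already warm.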

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_model

model, example_args_collection = get_llama_model(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    begin = time.perf_counter()
    onx = torch.onnx.export(model, *example_args_collection[0], dynamo=True)
    print(f"time to export 1x --- {time.perf_counter() - begin}")

    begin = time.perf_counter()
    onx = torch.onnx.export(model, *example_args_collection[0], dynamo=True)
    print(f"time to export 2x --- {time.perf_counter() - begin}")

>>>

    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=False)`...
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=False)`... ❌
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=True)`...
    [torch.onnx] Obtain model graph for `LlamaModelWrapper([...]` with `torch.export.export(..., strict=True)`... ❌
    [runpythonerror]
    Traceback (most recent call last):
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 118, in __call__
        exported_program = self._capture(model, args, kwargs, dynamic_shapes)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_capture_strategies.py", line 210, in _capture
        return torch.export.export(
               ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 315, in export
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 280, in export
        return _export(
               ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1173, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1139, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2275, in _export
        ep = _export_for_training(
             ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1173, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1139, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2090, in _export_for_training
        export_artifact = export_func(
                          ^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1969, in _non_strict_export
        ) = make_fake_inputs(
            ^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 355, in make_fake_inputs
        combined_args = _combine_args(nn_module, args, kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/dynamic_shapes.py", line 702, in _combine_args
        return signature.bind(*args, **kwargs).arguments
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.12/inspect.py", line 3277, in bind
        return self._bind(args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.12/inspect.py", line 3190, in _bind
        raise TypeError(msg) from None
    TypeError: missing a required argument: 'attention_mask'
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "<stdin>", line 41, in <module>
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/__init__.py", line 296, in export
        return _compat.export_compat(
               ^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_compat.py", line 143, in export_compat
        onnx_program = _core.export(
                       ^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_flags.py", line 23, in wrapper
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_core.py", line 1385, in export
        raise _errors.TorchExportError(
    torch.onnx._internal.exporter._errors.TorchExportError: Failed to export the model with torch.export. This is step 1/3 of exporting the model to ONNX. Next steps:
    - Modify the model code for `torch.export.export` to succeed. Refer to https://pytorch.org/docs/stable/generated/exportdb/index.html for more information.
    - Debug `torch.export.export` and submit a PR to PyTorch.
    - Create an issue in the PyTorch GitHub repository against the *torch.export* component and attach the full error stack as well as reproduction scripts.
    
    ## Exception summary
    
    <class 'TypeError'>: missing a required argument: 'attention_mask'
    
    (Refer to the full stack trace above for more information.)
    nanobind: leaked 1 instances!
    nanobind: leaked 10 types!
     - leaked type "onnx.onnx_cpp2py_export.defs.Attribute"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.InferenceContext"
     - leaked type "onnx.onnx_cpp2py_export.defs.OpSchema"
     - leaked type "onnx.onnx_cpp2py_export.shape_inference.GraphInferencer"
     - leaked type "onnx.onnx_cpp2py_export.defs.FormalParameter"
     - leaked type "FormalParameterOption"
     - leaked type "SupportType"
     - leaked type "DifferentiationCategory"
     - leaked type "onnx.onnx_cpp2py_export.defs.TypeConstraintParam"
     - leaked type "AttrType"
     - ... skipped remainder
    nanobind: leaked 73 functions!
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_symbolic_input"
     - leaked function "__init__"
     - leaked function "deregister_schema"
     - leaked function ""
     - leaked function ""
     - leaked function "get_num_inputs"
     - leaked function ""
     - leaked function "get_graph_attribute_inferencer"
     - leaked function "get_attribute"
     - leaked function "get_context_dependent_function"
     - leaked function "has_input"
     - leaked function "get_input_sparse_data"
     - leaked function "get_all_schemas_with_history"
     - leaked function ""
     - leaked function "has_schema"
     - leaked function ""
     - leaked function ""
     - leaked function "_infer_node_outputs"
     - leaked function "get_context_dependent_function_with_opset_version"
     - leaked function "has_output"
     - leaked function "get_input_data"
     - leaked function "get_display_name"
     - leaked function ""
     - leaked function "set_type_and_shape_inference_function"
     - leaked function ""
     - leaked function ""
     - leaked function "get_output_type"
     - leaked function "do_inferencing"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_function_with_opset_version"
     - leaked function ""
     - leaked function ""
     - leaked function "get_type_and_shape_inference_function"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_num_outputs"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_schema"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "is_infinite"
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function "__init__"
     - leaked function ""
     - leaked function "set_output_type"
     - leaked function "get_all_schemas"
     - leaked function ""
     - leaked function ""
     - leaked function ""
     - leaked function "get_input_type"
     - leaked function ""
     - leaked function ""
    nanobind: this is likely caused by a reference counting issue in the binding code.
    See https://nanobind.readthedocs.io/en/latest/refleaks.html
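
The failure above appears to come from the way the example arguments are passed rather than from the exporter itself: unpacking example_args_collection[0] sends input_ids to the args parameter of torch.onnx.export and attention_mask to the next positional parameter, so torch.export.export only receives input_ids and reports the missing attention_mask. A minimal sketch of the likely fix, not re-run or timed here, is to pass the argument tuple as a single args value:

    import warnings
    import torch
    from experimental_experiment.torch_models.llama_helper import get_llama_model

    model, example_args_collection = get_llama_model(
        input_dims=[(2, 1024)],
        hidden_size=4096,
        num_hidden_layers=1,
        vocab_size=32000,
        intermediate_size=11008,
        max_position_embeddings=2048,
        num_attention_heads=32,
        _attn_implementation="eager",
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Pass the argument tuple itself instead of unpacking it: this way both
        # input_ids and attention_mask reach torch.export.export as positional
        # arguments of the model.
        onx = torch.onnx.export(model, example_args_collection[0], dynamo=True)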