experimental_experiment.torch_interpreter.oxs_dispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]

Tries the fallback even when it is not necessary, in order to check that it works.

Parameters:
  • verbose – verbosity level

  • raise_exc – whether a failing fallback raises an exception or is only reported

The class can be used as follows.

<<<

import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 19, in <module>
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1064, in to_onnx
        builder.process(graph_module, interpreter)
      File "~/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 5173, in process
        interpreter.run_node(node)
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 229, in run_node
        res = self.get_attr(node)
              ^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 336, in get_attr
        raise NotImplementedError(
    NotImplementedError: Unable to handle type <class 'torch.utils._pytree.TreeSpec'> for node.name='function_const_func_spec0'
    Node(type=<class 'torch._higher_order_ops.flat_apply._ConstantFunction'>, leaves=0)
    --
    --DEBUG--
    [GraphBuilder-YHC] Message starts, there are 24 initializers, 2 nodes, 2 inputs, 2 outputs.
    --SHAPE--
    _dynamic_examples=
    dynamic_objects=
    dynamic_objects_rev=
    dynamic_dimensions_source={}
    dynamic_dimensions_source_flat=None
    output_dynamic_dimensions_source_flat=None
    dynamic_alias={}
    dynamic_shapes=None
    _known_shapes={'attention_mask': (2, 8),
     'b_model_rotary_emb_inv_freq': (4,),
     'c_model_lifted_tensor_0': (),
     'input0': (2, 8),
     'input1': (2, 8),
     'input_ids': (2, 8),
     'model.embed_tokens.weight': (1024, 16),
     'model.layers.0.input_layernorm.weight': (16,),
     'model.layers.0.mlp.down_proj.weight': (16, 16),
     'model.layers.0.mlp.gate_proj.weight': (16, 16),
     'model.layers.0.mlp.up_proj.weight': (16, 16),
     'model.layers.0.post_attention_layernorm.weight': (16,),
     'model.layers.0.self_attn.k_proj.weight': (16, 16),
     'model.layers.0.self_attn.o_proj.weight': (16, 16),
     'model.layers.0.self_attn.q_proj.weight': (16, 16),
     'model.layers.0.self_attn.v_proj.weight': (16, 16),
     'model.norm.weight': (16,),
     'p_model_embed_tokens_weight': (1024, 16),
     'p_model_layers_0_input_layernorm_weight': (16,),
     'p_model_layers_0_mlp_down_proj_weight': (16, 16),
     'p_model_layers_0_mlp_gate_proj_weight': (16, 16),
     'p_model_layers_0_mlp_up_proj_weight': (16, 16),
     'p_model_layers_0_post_attention_layernorm_weight': (16,),
     'p_model_layers_0_self_attn_k_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_o_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_q_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_v_proj_weight': (16, 16),
     'p_model_norm_weight': (16,)}
    _known_types={'attention_mask': 1,
     'b_model_rotary_emb_inv_freq': 1,
     'c_model_lifted_tensor_0': 1,
     'input0': 7,
     'input1': 1,
     'input_ids': 7,
     'model.embed_tokens.weight': 1,
     'model.layers.0.input_layernorm.weight': 1,
     'model.layers.0.mlp.down_proj.weight': 1,
     'model.layers.0.mlp.gate_proj.weight': 1,
     'model.layers.0.mlp.up_proj.weight': 1,
     'model.layers.0.post_attention_layernorm.weight': 1,
     'model.layers.0.self_attn.k_proj.weight': 1,
     'model.layers.0.self_attn.o_proj.weight': 1,
     'model.layers.0.self_attn.q_proj.weight': 1,
     'model.layers.0.self_attn.v_proj.weight': 1,
     'model.norm.weight': 1,
     'p_model_embed_tokens_weight': 1,
     'p_model_layers_0_input_layernorm_weight': 1,
     'p_model_layers_0_mlp_down_proj_weight': 1,
     'p_model_layers_0_mlp_gate_proj_weight': 1,
     'p_model_layers_0_mlp_up_proj_weight': 1,
     'p_model_layers_0_post_attention_layernorm_weight': 1,
     'p_model_layers_0_self_attn_k_proj_weight': 1,
     'p_model_layers_0_self_attn_o_proj_weight': 1,
     'p_model_layers_0_self_attn_q_proj_weight': 1,
     'p_model_layers_0_self_attn_v_proj_weight': 1,
     'p_model_norm_weight': 1}
    _known_value_shape={}
    _known_constants=['b_model_rotary_emb_inv_freq',
     'c_model_lifted_tensor_0',
     'model.embed_tokens.weight',
     'model.layers.0.input_layernorm.weight',
     'model.layers.0.mlp.down_proj.weight',
     'model.layers.0.mlp.gate_proj.weight',
     'model.layers.0.mlp.up_proj.weight',
     'model.layers.0.post_attention_layernorm.weight',
     'model.layers.0.self_attn.k_proj.weight',
     'model.layers.0.self_attn.o_proj.weight',
     'model.layers.0.self_attn.q_proj.weight',
     'model.layers.0.self_attn.v_proj.weight',
     'model.norm.weight',
     'p_model_embed_tokens_weight',
     'p_model_layers_0_input_layernorm_weight',
     'p_model_layers_0_mlp_down_proj_weight',
     'p_model_layers_0_mlp_gate_proj_weight',
     'p_model_layers_0_mlp_up_proj_weight',
     'p_model_layers_0_post_attention_layernorm_weight',
     'p_model_layers_0_self_attn_k_proj_weight',
     'p_model_layers_0_self_attn_o_proj_weight',
     'p_model_layers_0_self_attn_q_proj_weight',
     'p_model_layers_0_self_attn_v_proj_weight',
     'p_model_norm_weight']
    _known_ranks={}
    --PARAMETERS--
    _parameter_renaming=
       p_model_embed_tokens_weight = 'model.embed_tokens.weight'
       p_model_layers_0_input_layernorm_weight = 'model.layers.0.input_layernorm.weight'
       p_model_layers_0_mlp_down_proj_weight = 'model.layers.0.mlp.down_proj.weight'
       p_model_layers_0_mlp_gate_proj_weight = 'model.layers.0.mlp.gate_proj.weight'
       p_model_layers_0_mlp_up_proj_weight = 'model.layers.0.mlp.up_proj.weight'
       p_model_layers_0_post_attention_layernorm_weight = 'model.layers.0.post_attention_layernorm.weight'
       p_model_layers_0_self_attn_k_proj_weight = 'model.layers.0.self_attn.k_proj.weight'
       p_model_layers_0_self_attn_o_proj_weight = 'model.layers.0.self_attn.o_proj.weight'
       p_model_layers_0_self_attn_q_proj_weight = 'model.layers.0.self_attn.q_proj.weight'
       p_model_layers_0_self_attn_v_proj_weight = 'model.layers.0.self_attn.v_proj.weight'
       p_model_norm_weight = 'model.norm.weight'
    --TORCH-USERS--
        attention_mask -> {to}
        b_model_rotary_emb_inv_freq -> {unsqueeze_1}
        c_model_lifted_tensor_0 -> {lift_fresh_copy}
        function_const_func_spec0 -> {flat_apply}
        input_ids -> {embedding}
        p_model_embed_tokens_weight -> {embedding}
        p_model_layers_0_input_layernorm_weight -> {mul_3}
        p_model_layers_0_mlp_down_proj_weight -> {linear_6}
        p_model_layers_0_mlp_gate_proj_weight -> {linear_4}
        p_model_layers_0_mlp_up_proj_weight -> {linear_5}
        p_model_layers_0_post_attention_layernorm_weight -> {mul_10}
        p_model_layers_0_self_attn_k_proj_weight -> {linear_1}
        p_model_layers_0_self_attn_o_proj_weight -> {linear_3}
        p_model_layers_0_self_attn_q_proj_weight -> {linear}
        p_model_layers_0_self_attn_v_proj_weight -> {linear_2}
        p_model_norm_weight -> {mul_13}
    --TORCH-SHAPES--
        p_model_embed_tokens_weight: ('run_node', ('', ('val', torch.float32, torch.Size([1024, 16])))) --- 1:2:(1024, 16):
        p_model_layers_0_self_attn_q_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_k_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_v_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_o_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_gate_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_up_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_down_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_input_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        p_model_layers_0_post_attention_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        p_model_norm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        b_model_rotary_emb_inv_freq: ('run_node', ('', ('val', torch.float32, torch.Size([4])))) --- 1:1:(4,):
        c_model_lifted_tensor_0: ('run_node', ('', ('val', torch.float32, torch.Size([])))) --- 1:0:():
        input_ids: ('run_node', ('', ('val', torch.int64, torch.Size([2, 8])))) --- 7:2:(2, 8):
        attention_mask: ('run_node', ('', ('val', torch.float32, torch.Size([2, 8])))) --- 1:2:(2, 8):
        function_const_func_spec0: ('run_node', ('', '')) --- :::
    --ONNX--
    -- EXEPATH --
    export-export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.scaled_dot_product_attention.default', 'aten.setitem', <built-in function setitem>))
    -- process.graph_module --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_model_embed_tokens_weight: "f32[1024, 16]", p_model_layers_0_self_attn_q_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_k_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_v_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_o_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_gate_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_up_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_down_proj_weight: "f32[16, 16]", p_model_layers_0_input_layernorm_weight: "f32[16]", p_model_layers_0_post_attention_layernorm_weight: "f32[16]", p_model_norm_weight: "f32[16]", b_model_rotary_emb_inv_freq: "f32[4]", c_model_lifted_tensor_0: "f32[]", input_ids: "i64[2, 8]", attention_mask: "f32[2, 8]"):
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
                function_const_func_spec0 = self.function_const_func_spec0
                torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
                embedding: "f32[2, 8, 16]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
                arange: "i64[8]" = torch.ops.aten.arange.start(0, 8, device = device(type='cpu'), pin_memory = False)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:380 in forward, code: position_ids = cache_position.unsqueeze(0)
                unsqueeze: "i64[1, 8]" = torch.ops.aten.unsqueeze.default(arange, 0)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
                to: "b8[2, 8]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
                arange_1: "i64[8]" = torch.ops.aten.arange.default(8, device = device(type='cpu'), pin_memory = False)
                add_: "i64[8]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
                arange_2: "i64[2]" = torch.ops.aten.arange.default(2, device = device(type='cpu'), pin_memory = False)
                _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
                _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
                _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
                new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
                le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
                to_1: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
                and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None
                flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
                to_2: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
                and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
                _remove_batch_dim: "b8[8]" = torch._functorch.predispatch._remove_batch_dim(and_2, 4, 8, 0);  and_2 = None
                _remove_batch_dim_1: "b8[8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 8, 0);  _remove_batch_dim = None
                _remove_batch_dim_2: "b8[1, 8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0);  _remove_batch_dim_1 = None
                _remove_batch_dim_3: "b8[2, 1, 8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0);  _remove_batch_dim_2 = None
                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_model_lifted_tensor_0);  c_model_lifted_tensor_0 = None
                detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
                where: "f32[2, 1, 8, 8]" = torch.ops.aten.where.ScalarOther(_remove_batch_dim_3, detach_, -3.4028234663852886e+38);  _remove_batch_dim_3 = detach_ = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:96 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze_1: "f32[1, 4]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                unsqueeze_2: "f32[1, 4, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 2);  unsqueeze_1 = None
                to_3: "f32[1, 4, 1]" = torch.ops.aten.to.dtype(unsqueeze_2, torch.float32);  unsqueeze_2 = None
                expand_1: "f32[1, 4, 1]" = torch.ops.aten.expand.default(to_3, [1, -1, 1]);  to_3 = None
                to_4: "f32[1, 4, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:97 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                unsqueeze_3: "i64[1, 1, 8]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
                to_5: "f32[1, 1, 8]" = torch.ops.aten.to.dtype(unsqueeze_3, torch.float32);  unsqueeze_3 = None
                
                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[1, 8, 8]" = wrap_with_autocast[0]
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[1, 8, 8]" = wrap_with_autocast[1];  wrap_with_autocast = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_8: "f32[1, 8, 8]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_9: "f32[1, 8, 8]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_10: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_1: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
                mean: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean, 1e-06);  mean = None
                rsqrt: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
                mul_2: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_11: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
                mul_3: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear, [2, 8, -1, 8]);  linear = None
                transpose_1: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_1: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_1: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear_1, [2, 8, -1, 8]);  linear_1 = None
                transpose_2: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_2: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_2: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear_2, [2, 8, -1, 8]);  linear_2 = None
                transpose_3: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
                unsqueeze_4: "f32[1, 1, 8, 8]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
                unsqueeze_5: "f32[1, 1, 8, 8]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
                mul_4: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_4)
                slice_1: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 4)
                slice_2: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 4, 9223372036854775807);  transpose_1 = None
                neg: "f32[2, 2, 8, 4]" = torch.ops.aten.neg.default(slice_2);  slice_2 = None
                cat_1: "f32[2, 2, 8, 8]" = torch.ops.aten.cat.default([neg, slice_1], -1);  neg = slice_1 = None
                mul_5: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5);  cat_1 = None
                add_1: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
                mul_6: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_4);  unsqueeze_4 = None
                slice_3: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 4)
                slice_4: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 4, 9223372036854775807);  transpose_2 = None
                neg_1: "f32[2, 2, 8, 4]" = torch.ops.aten.neg.default(slice_4);  slice_4 = None
                cat_2: "f32[2, 2, 8, 8]" = torch.ops.aten.cat.default([neg_1, slice_3], -1);  neg_1 = slice_3 = None
                mul_7: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_5);  cat_2 = unsqueeze_5 = None
                add_2: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
                transpose_4: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(add_2, 2, 3);  add_2 = None
                matmul_1: "f32[2, 2, 8, 8]" = torch.ops.aten.matmul.default(add_1, transpose_4);  add_1 = transpose_4 = None
                mul_8: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(matmul_1, 0.3535533905932738);  matmul_1 = None
                alias: "f32[2, 1, 8, 8]" = torch.ops.aten.alias.default(where);  where = None
                add_3: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_8, alias);  mul_8 = alias = None
                softmax: "f32[2, 2, 8, 8]" = torch.ops.aten.softmax.int(add_3, -1, torch.float32);  add_3 = None
                to_12: "f32[2, 2, 8, 8]" = torch.ops.aten.to.dtype(softmax, torch.float32);  softmax = None
                dropout: "f32[2, 2, 8, 8]" = torch.ops.aten.dropout.default(to_12, 0.0, True);  to_12 = None
                matmul_2: "f32[2, 2, 8, 8]" = torch.ops.aten.matmul.default(dropout, transpose_3);  dropout = transpose_3 = None
                transpose_5: "f32[2, 8, 2, 8]" = torch.ops.aten.transpose.int(matmul_2, 1, 2);  matmul_2 = None
                contiguous: "f32[2, 8, 2, 8]" = torch.ops.aten.contiguous.default(transpose_5);  transpose_5 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
                reshape: "f32[2, 8, 16]" = torch.ops.aten.reshape.default(contiguous, [2, 8, -1]);  contiguous = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_3: "f32[2, 8, 16]" = torch.ops.aten.linear.default(reshape, p_model_layers_0_self_attn_o_proj_weight);  reshape = p_model_layers_0_self_attn_o_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
                add_4: "f32[2, 8, 16]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_13: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(add_4, torch.float32);  add_4 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_2: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
                mean_1: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_5: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-06);  mean_1 = None
                rsqrt_1: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add_5);  add_5 = None
                mul_9: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_13, rsqrt_1);  rsqrt_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_14: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_9, torch.float32);  mul_9 = None
                mul_10: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_14);  p_model_layers_0_post_attention_layernorm_weight = to_14 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_4: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
                silu: "f32[2, 8, 16]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_5: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_up_proj_weight);  mul_10 = p_model_layers_0_mlp_up_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                mul_11: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_6: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_11, p_model_layers_0_mlp_down_proj_weight);  mul_11 = p_model_layers_0_mlp_down_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
                add_6: "f32[2, 8, 16]" = torch.ops.aten.add.Tensor(to_13, linear_6);  to_13 = linear_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_15: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_3: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
                mean_2: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_7: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-06);  mean_2 = None
                rsqrt_2: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
                mul_12: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_2);  to_15 = rsqrt_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_16: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_12, torch.float32);  mul_12 = None
                mul_13: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_16);  p_model_norm_weight = to_16 = None
                return (mul_13,)
                
            class submod_1(torch.nn.Module):
                def forward(self, to_4: "f32[1, 4, 1]", to_5: "f32[1, 1, 8]"):
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:101 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                    to_6: "f32[1, 4, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                    to_7: "f32[1, 1, 8]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                    matmul: "f32[1, 4, 8]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                    transpose: "f32[1, 8, 4]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:102 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[1, 8, 8]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[1, 8, 8]" = torch.ops.aten.cos.default(cat)
                    mul: "f32[1, 8, 8]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[1, 8, 8]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_1: "f32[1, 8, 8]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul, mul_1)
                    
    Graph signature: 
        # inputs
        p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
        p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
        p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
        p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
        p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
        p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
        p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
        p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
        p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
        p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
        p_model_norm_weight: PARAMETER target='model.norm.weight'
        b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
        c_model_lifted_tensor_0: CONSTANT_TENSOR target='model.lifted_tensor_0'
        input_ids: USER_INPUT
        attention_mask: USER_INPUT
        
        # outputs
        mul_13: USER_OUTPUT
        
    Range constraints: {}
    
    -- process.graph_module.graph --
    graph():
        %p_model_embed_tokens_weight : [num_users=1] = placeholder[target=p_model_embed_tokens_weight]
        %p_model_layers_0_self_attn_q_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_q_proj_weight]
        %p_model_layers_0_self_attn_k_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_k_proj_weight]
        %p_model_layers_0_self_attn_v_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_v_proj_weight]
        %p_model_layers_0_self_attn_o_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_o_proj_weight]
        %p_model_layers_0_mlp_gate_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_gate_proj_weight]
        %p_model_layers_0_mlp_up_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_up_proj_weight]
        %p_model_layers_0_mlp_down_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_down_proj_weight]
        %p_model_layers_0_input_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_input_layernorm_weight]
        %p_model_layers_0_post_attention_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_post_attention_layernorm_weight]
        %p_model_norm_weight : [num_users=1] = placeholder[target=p_model_norm_weight]
        %b_model_rotary_emb_inv_freq : [num_users=1] = placeholder[target=b_model_rotary_emb_inv_freq]
        %c_model_lifted_tensor_0 : [num_users=1] = placeholder[target=c_model_lifted_tensor_0]
        %input_ids : [num_users=1] = placeholder[target=input_ids]
        %attention_mask : [num_users=1] = placeholder[target=attention_mask]
        %function_const_func_spec0 : [num_users=1] = get_attr[target=function_const_func_spec0]
        %torch__dynamo__trace_wrapped_higher_order_op_mod_index0 : [num_users=1] = get_attr[target=torch__dynamo__trace_wrapped_higher_order_op_ModIndex0]
        %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_model_embed_tokens_weight, %input_ids), kwargs = {})
        %arange : [num_users=2] = call_function[target=torch.ops.aten.arange.start](args = (0, 8), kwargs = {device: cpu, pin_memory: False})
        %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
        %to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%attention_mask, cpu, torch.bool), kwargs = {})
        %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (8,), kwargs = {device: cpu, pin_memory: False})
        %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%arange_1, 0), kwargs = {})
        %arange_2 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (2,), kwargs = {device: cpu, pin_memory: False})
        %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange_2, 0, 1), kwargs = {})
        %_add_batch_dim_2 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange, 0, 3), kwargs = {})
        %_add_batch_dim_3 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%add_, 0, 4), kwargs = {})
        %new_ones : [num_users=1] = call_function[target=torch.ops.aten.new_ones.default](args = (%_add_batch_dim_2, []), kwargs = {dtype: torch.bool, pin_memory: False})
        %le : [num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%_add_batch_dim_3, %_add_batch_dim_2), kwargs = {})
        %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%le,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_1 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%new_ones, %to_1), kwargs = {})
        %flat_apply : [num_users=1] = call_function[target=torch.ops.higher_order.flat_apply](args = (%function_const_func_spec0, %torch__dynamo__trace_wrapped_higher_order_op_mod_index0, torch._dynamo._trace_wrapped_higher_order_op.ModIndex, %to, %_add_batch_dim, %_add_batch_dim_3), kwargs = {})
        %to_2 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%flat_apply,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_2 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%and_1, %to_2), kwargs = {})
        %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%and_2, 4, 8, 0), kwargs = {})
        %_remove_batch_dim_1 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim, 3, 8, 0), kwargs = {})
        %_remove_batch_dim_2 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_1, 2, 1, 0), kwargs = {})
        %_remove_batch_dim_3 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_2, 1, 2, 0), kwargs = {})
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_model_lifted_tensor_0,), kwargs = {})
        %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
        %where : [num_users=1] = call_function[target=torch.ops.aten.where.ScalarOther](args = (%_remove_batch_dim_3, %detach_, -3.4028234663852886e+38), kwargs = {})
        %unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%b_model_rotary_emb_inv_freq, 0), kwargs = {})
        %unsqueeze_2 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_1, 2), kwargs = {})
        %to_3 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_2, torch.float32), kwargs = {})
        %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%to_3, [1, -1, 1]), kwargs = {})
        %to_4 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%expand_1,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
        %unsqueeze_3 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze, 1), kwargs = {})
        %to_5 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_3, torch.float32), kwargs = {})
        %submod_3 : [num_users=1] = get_attr[target=submod_1]
        %wrap_with_autocast : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_autocast](args = (cpu, torch.bfloat16, False, False, %submod_3, %to_4, %to_5), kwargs = {})
        %mul : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 0), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 1), kwargs = {})
        %to_8 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul, torch.float32), kwargs = {})
        %to_9 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_1, torch.float32), kwargs = {})
        %to_10 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%embedding, torch.float32), kwargs = {})
        %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_10, 2), kwargs = {})
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
        %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
        %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_10, %rsqrt), kwargs = {})
        %to_11 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_2, torch.float32), kwargs = {})
        %mul_3 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_input_layernorm_weight, %to_11), kwargs = {})
        %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_q_proj_weight), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear, [2, 8, -1, 8]), kwargs = {})
        %transpose_1 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view, 1, 2), kwargs = {})
        %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_k_proj_weight), kwargs = {})
        %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_1, [2, 8, -1, 8]), kwargs = {})
        %transpose_2 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view_1, 1, 2), kwargs = {})
        %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_v_proj_weight), kwargs = {})
        %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_2, [2, 8, -1, 8]), kwargs = {})
        %transpose_3 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_2, 1, 2), kwargs = {})
        %unsqueeze_4 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_8, 1), kwargs = {})
        %unsqueeze_5 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_9, 1), kwargs = {})
        %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_1, %unsqueeze_4), kwargs = {})
        %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 0, 4), kwargs = {})
        %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 4, 9223372036854775807), kwargs = {})
        %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
        %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
        %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_5), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
        %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_2, %unsqueeze_4), kwargs = {})
        %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 0, 4), kwargs = {})
        %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 4, 9223372036854775807), kwargs = {})
        %neg_1 : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
        %cat_2 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
        %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %unsqueeze_5), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
        %transpose_4 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%add_2, 2, 3), kwargs = {})
        %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%add_1, %transpose_4), kwargs = {})
        %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_1, 0.3535533905932738), kwargs = {})
        %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%where,), kwargs = {})
        %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_8, %alias), kwargs = {})
        %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%add_3, -1, torch.float32), kwargs = {})
        %to_12 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%softmax, torch.float32), kwargs = {})
        %dropout : [num_users=1] = call_function[target=torch.ops.aten.dropout.default](args = (%to_12, 0.0, True), kwargs = {})
        %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%dropout, %transpose_3), kwargs = {})
        %transpose_5 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%matmul_2, 1, 2), kwargs = {})
        %contiguous : [num_users=1] = call_function[target=torch.ops.aten.contiguous.default](args = (%transpose_5,), kwargs = {})
        %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%contiguous, [2, 8, -1]), kwargs = {})
        %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%reshape, %p_model_layers_0_self_attn_o_proj_weight), kwargs = {})
        %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_10, %linear_3), kwargs = {})
        %to_13 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%add_4, torch.float32), kwargs = {})
        %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_13, 2), kwargs = {})
        %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
        %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-06), kwargs = {})
        %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
        %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_13, %rsqrt_1), kwargs = {})
        %to_14 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_9, torch.float32), kwargs = {})
        %mul_10 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_post_attention_layernorm_weight, %to_14), kwargs = {})
        %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_gate_proj_weight), kwargs = {})
        %silu : [num_users=1] = call_function[target=torch.ops.aten.silu.default](args = (%linear_4,), kwargs = {})
        %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_up_proj_weight), kwargs = {})
        %mul_11 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%silu, %linear_5), kwargs = {})
        %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_11, %p_model_layers_0_mlp_down_proj_weight), kwargs = {})
        %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_13, %linear_6), kwargs = {})
        %to_15 : [num_users=2] = call_function[target=torch.ops.aten.to.dtype](args = (%add_6, torch.float32), kwargs = {})
        %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_15, 2), kwargs = {})
        %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
        %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-06), kwargs = {})
        %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_7,), kwargs = {})
        %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_15, %rsqrt_2), kwargs = {})
        %to_16 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_12, torch.float32), kwargs = {})
        %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_norm_weight, %to_16), kwargs = {})
        return (mul_13,)
    -- process.inputs_to_remove --
    set()
    -- process.progress --
    node 15/124 target=function_const_func_spec0
    -- 2 INPUTS
    [GraphBuilder-YHC.make_tensor_input] input0[7:2x8]
    [GraphBuilder-YHC.make_tensor_input] input1[1:2x8]
    -- 24 INITIALIZERS
    [GraphBuilder-YHC.make_initializer] p_model_embed_tokens_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-YHC.make_initializer] model.embed_tokens.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_self_attn_q_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.self_attn.q_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_self_attn_k_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.self_attn.k_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_self_attn_v_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.self_attn.v_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_self_attn_o_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.self_attn.o_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_mlp_gate_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.mlp.gate_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_mlp_up_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.mlp.up_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_mlp_down_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.mlp.down_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_input_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.input_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-YHC.make_initializer] p_model_layers_0_post_attention_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-YHC.make_initializer] model.layers.0.post_attention_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-YHC.make_initializer] p_model_norm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-YHC.make_initializer] model.norm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-YHC.make_initializer] b_model_rotary_emb_inv_freq[torch.float32:torch.float32:[1.0, 0.10000000149011612, 0.009999999776482582, 0.0010000000474974513]] - SOURCE: DynamoInterpret.placeholder.0
    [GraphBuilder-YHC.make_initializer] c_model_lifted_tensor_0[torch.float32:torch.float32:[0.0]] - SOURCE: DynamoInterpret.placeholder.0
    [GraphBuilder-YHC.make_node] make_tensor_input_id [#:#   ] Identity:['input0']->['input_ids']
    [GraphBuilder-YHC.make_node] make_tensor_input_id2 [#:#   ] Identity:['input1']->['attention_mask']
    -- 0 OUTPUTS
    [GraphBuilder-YHC] Message completed, there are 24 initializers, 2 nodes, 2 inputs, 2 outputs.
fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the search for a function converting an aten function into ONNX. fct is the function found so far; it is only set when a mapping was found, and the fallback may replace it. A sketch overriding this hook is shown after the parameter list.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable
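
The hook can be used to plug custom behavior into the dispatch chain. Below is a minimal sketch, not part of the library, that subclasses OxsDebugDispatcher to record which aten functions reached the fallback; it only assumes the signature documented above, and the class name RecordingDispatcher and attribute seen are arbitrary.

    from typing import Any, Callable, Dict, List, Optional

    from experimental_experiment.torch_interpreter.oxs_dispatcher import (
        OxsDebugDispatcher,
    )


    class RecordingDispatcher(OxsDebugDispatcher):
        "Hypothetical subclass: keeps track of the names passed to the fallback."

        def __init__(self, verbose: int = 0, raise_exc: bool = False):
            super().__init__(verbose=verbose, raise_exc=raise_exc)
            self.seen: List[str] = []

        def fallback(
            self,
            name: Any,
            fct: Optional[Callable],
            args: List[Any],
            kwargs: Dict[str, Any],
            builder: Any,  # the GraphBuilder instance supplied by the exporter
        ) -> Optional[Callable]:
            # Record the request, then defer to the parent implementation.
            self.seen.append(str(name))
            return super().fallback(name, fct, args, kwargs, builder)

An instance of such a class can be passed to to_onnx through the dispatcher argument, exactly as in the example above.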

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]

If DynamoInterpreter cannot find a converting function for a specific aten function, this dispatcher tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable and functions op, Rank, IsScalar are replaced by op = OxsOpset(), op.Rank, op.Scalar. onnxscript may provide multiple overloaded functions; right now, the first one is taken. A usage sketch follows the parameter list.

Parameters:

verbose – verbosity level
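
Used with to_onnx, the dispatcher is passed through the dispatcher argument, just like OxsDebugDispatcher in the example above. The sketch below reuses the same helper model and is illustrative only; what onnxscript can actually convert depends on the installed version.

    import torch
    from experimental_experiment.torch_models.llama_helper import get_llama_model
    from experimental_experiment.torch_interpreter import to_onnx
    from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

    with torch.no_grad():
        model, input_tensors = get_llama_model()
        input_tensors = input_tensors[0]

        # The dispatcher is only consulted for aten functions the default
        # converters do not handle; onnxscript is then searched for a match.
        to_onnx(
            model,
            input_tensors,
            input_names=[f"input{i}" for i in range(len(input_tensors))],
            dispatcher=OxsDispatcher(verbose=1),
        )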

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the search for a function converting an aten function into ONNX. fct is the function found so far; it is only set when a mapping was found, and the fallback may replace it.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

property submodules: Dict[str, Callable]

Returns the submodules implementing torch functions.
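
The mapping can be enumerated to see which torch functions onnxscript covers. A minimal sketch relying only on this property; the exact names and their number depend on the installed onnxscript version.

    from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

    dispatcher = OxsDispatcher()
    submodules = dispatcher.submodules  # mapping: torch function name -> converter
    print(f"{len(submodules)} torch functions have an onnxscript implementation")
    print(sorted(submodules)[:5])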