.torch_interpreter.oxs_dispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]

Tries the fallback even if it is not necessary, to check that it is working.

Parameters:
  • verbose – verbosity

  • raise_exc – whether to raise an exception on failure

The class can be used in the following way.

<<<

import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 19, in <module>
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1068, in to_onnx
        builder.process(graph_module, interpreter)
      File "~/github/experimental-experiment/experimental_experiment/xbuilder/graph_builder.py", line 5032, in process
        interpreter.run_node(node)
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 229, in run_node
        res = self.get_attr(node)
              ^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/interpreter.py", line 336, in get_attr
        raise NotImplementedError(
    NotImplementedError: Unable to handle type <class 'torch.utils._pytree.TreeSpec'> for node.name='function_const_func_spec0'
    Node(type=<class 'torch._higher_order_ops.flat_apply._ConstantFunction'>, leaves=0)
    --
    --DEBUG--
    [GraphBuilder-ZYE] Message starts, there are 24 initializers, 2 nodes, 2 inputs, 2 outputs.
    --SHAPE--
    _dynamic_examples=
    dynamic_objects=
    dynamic_objects_rev=
    dynamic_dimensions_source={}
    dynamic_dimensions_source_flat=None
    output_dynamic_dimensions_source_flat=None
    dynamic_alias={}
    dynamic_shapes=None
    _known_shapes={'attention_mask': (2, 8),
     'b_model_rotary_emb_inv_freq': (4,),
     'c_model_lifted_tensor_0': (),
     'input0': (2, 8),
     'input1': (2, 8),
     'input_ids': (2, 8),
     'model.embed_tokens.weight': (1024, 16),
     'model.layers.0.input_layernorm.weight': (16,),
     'model.layers.0.mlp.down_proj.weight': (16, 16),
     'model.layers.0.mlp.gate_proj.weight': (16, 16),
     'model.layers.0.mlp.up_proj.weight': (16, 16),
     'model.layers.0.post_attention_layernorm.weight': (16,),
     'model.layers.0.self_attn.k_proj.weight': (16, 16),
     'model.layers.0.self_attn.o_proj.weight': (16, 16),
     'model.layers.0.self_attn.q_proj.weight': (16, 16),
     'model.layers.0.self_attn.v_proj.weight': (16, 16),
     'model.norm.weight': (16,),
     'p_model_embed_tokens_weight': (1024, 16),
     'p_model_layers_0_input_layernorm_weight': (16,),
     'p_model_layers_0_mlp_down_proj_weight': (16, 16),
     'p_model_layers_0_mlp_gate_proj_weight': (16, 16),
     'p_model_layers_0_mlp_up_proj_weight': (16, 16),
     'p_model_layers_0_post_attention_layernorm_weight': (16,),
     'p_model_layers_0_self_attn_k_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_o_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_q_proj_weight': (16, 16),
     'p_model_layers_0_self_attn_v_proj_weight': (16, 16),
     'p_model_norm_weight': (16,)}
    _known_types={'attention_mask': 1,
     'b_model_rotary_emb_inv_freq': 1,
     'c_model_lifted_tensor_0': 1,
     'input0': 7,
     'input1': 1,
     'input_ids': 7,
     'model.embed_tokens.weight': 1,
     'model.layers.0.input_layernorm.weight': 1,
     'model.layers.0.mlp.down_proj.weight': 1,
     'model.layers.0.mlp.gate_proj.weight': 1,
     'model.layers.0.mlp.up_proj.weight': 1,
     'model.layers.0.post_attention_layernorm.weight': 1,
     'model.layers.0.self_attn.k_proj.weight': 1,
     'model.layers.0.self_attn.o_proj.weight': 1,
     'model.layers.0.self_attn.q_proj.weight': 1,
     'model.layers.0.self_attn.v_proj.weight': 1,
     'model.norm.weight': 1,
     'p_model_embed_tokens_weight': 1,
     'p_model_layers_0_input_layernorm_weight': 1,
     'p_model_layers_0_mlp_down_proj_weight': 1,
     'p_model_layers_0_mlp_gate_proj_weight': 1,
     'p_model_layers_0_mlp_up_proj_weight': 1,
     'p_model_layers_0_post_attention_layernorm_weight': 1,
     'p_model_layers_0_self_attn_k_proj_weight': 1,
     'p_model_layers_0_self_attn_o_proj_weight': 1,
     'p_model_layers_0_self_attn_q_proj_weight': 1,
     'p_model_layers_0_self_attn_v_proj_weight': 1,
     'p_model_norm_weight': 1}
    _known_value_shape={}
    _known_constants=['b_model_rotary_emb_inv_freq',
     'c_model_lifted_tensor_0',
     'model.embed_tokens.weight',
     'model.layers.0.input_layernorm.weight',
     'model.layers.0.mlp.down_proj.weight',
     'model.layers.0.mlp.gate_proj.weight',
     'model.layers.0.mlp.up_proj.weight',
     'model.layers.0.post_attention_layernorm.weight',
     'model.layers.0.self_attn.k_proj.weight',
     'model.layers.0.self_attn.o_proj.weight',
     'model.layers.0.self_attn.q_proj.weight',
     'model.layers.0.self_attn.v_proj.weight',
     'model.norm.weight',
     'p_model_embed_tokens_weight',
     'p_model_layers_0_input_layernorm_weight',
     'p_model_layers_0_mlp_down_proj_weight',
     'p_model_layers_0_mlp_gate_proj_weight',
     'p_model_layers_0_mlp_up_proj_weight',
     'p_model_layers_0_post_attention_layernorm_weight',
     'p_model_layers_0_self_attn_k_proj_weight',
     'p_model_layers_0_self_attn_o_proj_weight',
     'p_model_layers_0_self_attn_q_proj_weight',
     'p_model_layers_0_self_attn_v_proj_weight',
     'p_model_norm_weight']
    _known_ranks (with no shape)={}
    --PARAMETERS--
    _parameter_renaming=
       p_model_embed_tokens_weight = 'model.embed_tokens.weight'
       p_model_layers_0_input_layernorm_weight = 'model.layers.0.input_layernorm.weight'
       p_model_layers_0_mlp_down_proj_weight = 'model.layers.0.mlp.down_proj.weight'
       p_model_layers_0_mlp_gate_proj_weight = 'model.layers.0.mlp.gate_proj.weight'
       p_model_layers_0_mlp_up_proj_weight = 'model.layers.0.mlp.up_proj.weight'
       p_model_layers_0_post_attention_layernorm_weight = 'model.layers.0.post_attention_layernorm.weight'
       p_model_layers_0_self_attn_k_proj_weight = 'model.layers.0.self_attn.k_proj.weight'
       p_model_layers_0_self_attn_o_proj_weight = 'model.layers.0.self_attn.o_proj.weight'
       p_model_layers_0_self_attn_q_proj_weight = 'model.layers.0.self_attn.q_proj.weight'
       p_model_layers_0_self_attn_v_proj_weight = 'model.layers.0.self_attn.v_proj.weight'
       p_model_norm_weight = 'model.norm.weight'
    --TORCH-USERS--
        attention_mask -> {to}
        b_model_rotary_emb_inv_freq -> {unsqueeze_1}
        c_model_lifted_tensor_0 -> {lift_fresh_copy}
        function_const_func_spec0 -> {flat_apply}
        input_ids -> {embedding}
        p_model_embed_tokens_weight -> {embedding}
        p_model_layers_0_input_layernorm_weight -> {mul_3}
        p_model_layers_0_mlp_down_proj_weight -> {linear_6}
        p_model_layers_0_mlp_gate_proj_weight -> {linear_4}
        p_model_layers_0_mlp_up_proj_weight -> {linear_5}
        p_model_layers_0_post_attention_layernorm_weight -> {mul_10}
        p_model_layers_0_self_attn_k_proj_weight -> {linear_1}
        p_model_layers_0_self_attn_o_proj_weight -> {linear_3}
        p_model_layers_0_self_attn_q_proj_weight -> {linear}
        p_model_layers_0_self_attn_v_proj_weight -> {linear_2}
        p_model_norm_weight -> {mul_13}
    --TORCH-SHAPES--
        p_model_embed_tokens_weight: ('run_node', ('', ('val', torch.float32, torch.Size([1024, 16])))) --- 1:2:(1024, 16):
        p_model_layers_0_self_attn_q_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_k_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_v_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_self_attn_o_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_gate_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_up_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_mlp_down_proj_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16, 16])))) --- 1:2:(16, 16):
        p_model_layers_0_input_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        p_model_layers_0_post_attention_layernorm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        p_model_norm_weight: ('run_node', ('', ('val', torch.float32, torch.Size([16])))) --- 1:1:(16,):
        b_model_rotary_emb_inv_freq: ('run_node', ('', ('val', torch.float32, torch.Size([4])))) --- 1:1:(4,):
        c_model_lifted_tensor_0: ('run_node', ('', ('val', torch.float32, torch.Size([])))) --- 1:0:():
        input_ids: ('run_node', ('', ('val', torch.int64, torch.Size([2, 8])))) --- 7:2:(2, 8):
        attention_mask: ('run_node', ('', ('val', torch.float32, torch.Size([2, 8])))) --- 1:2:(2, 8):
        function_const_func_spec0: ('run_node', ('', '')) --- :::
    --ONNX--
    -- EXEPATH --
    export-export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
    -- process.graph_module --
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_model_embed_tokens_weight: "f32[1024, 16]", p_model_layers_0_self_attn_q_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_k_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_v_proj_weight: "f32[16, 16]", p_model_layers_0_self_attn_o_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_gate_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_up_proj_weight: "f32[16, 16]", p_model_layers_0_mlp_down_proj_weight: "f32[16, 16]", p_model_layers_0_input_layernorm_weight: "f32[16]", p_model_layers_0_post_attention_layernorm_weight: "f32[16]", p_model_norm_weight: "f32[16]", b_model_rotary_emb_inv_freq: "f32[4]", c_model_lifted_tensor_0: "f32[]", input_ids: "i64[2, 8]", attention_mask: "f32[2, 8]"):
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:380 in forward, code: causal_mask = create_causal_mask(
                function_const_func_spec0 = self.function_const_func_spec0
                torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
                embedding: "f32[2, 8, 16]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:373 in forward, code: cache_position: torch.Tensor = torch.arange(
                arange: "i64[8]" = torch.ops.aten.arange.start(0, 8, device = device(type='cpu'), pin_memory = False)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:378 in forward, code: position_ids = cache_position.unsqueeze(0)
                unsqueeze: "i64[1, 8]" = torch.ops.aten.unsqueeze.default(arange, 0)
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:380 in forward, code: causal_mask = create_causal_mask(
                to: "b8[2, 8]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
                arange_1: "i64[8]" = torch.ops.aten.arange.default(8, device = device(type='cpu'), pin_memory = False)
                add_: "i64[8]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
                arange_2: "i64[2]" = torch.ops.aten.arange.default(2, device = device(type='cpu'), pin_memory = False)
                _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
                _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
                _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
                new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
                le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
                to_1: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
                and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None
                flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
                to_2: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
                and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
                _remove_batch_dim: "b8[8]" = torch._functorch.predispatch._remove_batch_dim(and_2, 4, 8, 0);  and_2 = None
                _remove_batch_dim_1: "b8[8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 8, 0);  _remove_batch_dim = None
                _remove_batch_dim_2: "b8[1, 8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0);  _remove_batch_dim_1 = None
                _remove_batch_dim_3: "b8[2, 1, 8, 8]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0);  _remove_batch_dim_2 = None
                lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_model_lifted_tensor_0);  c_model_lifted_tensor_0 = None
                detach_: "f32[]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
                where: "f32[2, 1, 8, 8]" = torch.ops.aten.where.ScalarOther(_remove_batch_dim_3, detach_, -3.4028234663852886e+38);  _remove_batch_dim_3 = detach_ = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:96 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze_1: "f32[1, 4]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                unsqueeze_2: "f32[1, 4, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 2);  unsqueeze_1 = None
                to_3: "f32[1, 4, 1]" = torch.ops.aten.to.dtype(unsqueeze_2, torch.float32);  unsqueeze_2 = None
                expand_1: "f32[1, 4, 1]" = torch.ops.aten.expand.default(to_3, [1, -1, 1]);  to_3 = None
                to_4: "f32[1, 4, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:97 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                unsqueeze_3: "i64[1, 1, 8]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
                to_5: "f32[1, 1, 8]" = torch.ops.aten.to.dtype(unsqueeze_3, torch.float32);  unsqueeze_3 = None
                
                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[1, 8, 8]" = wrap_with_autocast[0]
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[1, 8, 8]" = wrap_with_autocast[1];  wrap_with_autocast = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_8: "f32[1, 8, 8]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_9: "f32[1, 8, 8]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_10: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_1: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
                mean: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean, 1e-06);  mean = None
                rsqrt: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
                mul_2: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_11: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
                mul_3: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:235 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear, [2, 8, -1, 8]);  linear = None
                transpose_1: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_1: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_1: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear_1, [2, 8, -1, 8]);  linear_1 = None
                transpose_2: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_2: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
                view_2: "f32[2, 8, 2, 8]" = torch.ops.aten.view.default(linear_2, [2, 8, -1, 8]);  linear_2 = None
                transpose_3: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:240 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
                unsqueeze_4: "f32[1, 1, 8, 8]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
                unsqueeze_5: "f32[1, 1, 8, 8]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
                mul_4: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_4)
                slice_1: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 4)
                slice_2: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 4, 9223372036854775807);  transpose_1 = None
                neg: "f32[2, 2, 8, 4]" = torch.ops.aten.neg.default(slice_2);  slice_2 = None
                cat_1: "f32[2, 2, 8, 8]" = torch.ops.aten.cat.default([neg, slice_1], -1);  neg = slice_1 = None
                mul_5: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5);  cat_1 = None
                add_1: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
                mul_6: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_4);  unsqueeze_4 = None
                slice_3: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 4)
                slice_4: "f32[2, 2, 8, 4]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 4, 9223372036854775807);  transpose_2 = None
                neg_1: "f32[2, 2, 8, 4]" = torch.ops.aten.neg.default(slice_4);  slice_4 = None
                cat_2: "f32[2, 2, 8, 8]" = torch.ops.aten.cat.default([neg_1, slice_3], -1);  neg_1 = slice_3 = None
                mul_7: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_5);  cat_2 = unsqueeze_5 = None
                add_2: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:251 in forward, code: attn_output, attn_weights = attention_interface(
                transpose_4: "f32[2, 2, 8, 8]" = torch.ops.aten.transpose.int(add_2, 2, 3);  add_2 = None
                matmul_1: "f32[2, 2, 8, 8]" = torch.ops.aten.matmul.default(add_1, transpose_4);  add_1 = transpose_4 = None
                mul_8: "f32[2, 2, 8, 8]" = torch.ops.aten.mul.Tensor(matmul_1, 0.3535533905932738);  matmul_1 = None
                alias: "f32[2, 1, 8, 8]" = torch.ops.aten.alias.default(where);  where = None
                add_3: "f32[2, 2, 8, 8]" = torch.ops.aten.add.Tensor(mul_8, alias);  mul_8 = alias = None
                softmax: "f32[2, 2, 8, 8]" = torch.ops.aten.softmax.int(add_3, -1, torch.float32);  add_3 = None
                to_12: "f32[2, 2, 8, 8]" = torch.ops.aten.to.dtype(softmax, torch.float32);  softmax = None
                dropout: "f32[2, 2, 8, 8]" = torch.ops.aten.dropout.default(to_12, 0.0, True);  to_12 = None
                matmul_2: "f32[2, 2, 8, 8]" = torch.ops.aten.matmul.default(dropout, transpose_3);  dropout = transpose_3 = None
                transpose_5: "f32[2, 8, 2, 8]" = torch.ops.aten.transpose.int(matmul_2, 1, 2);  matmul_2 = None
                contiguous: "f32[2, 8, 2, 8]" = torch.ops.aten.contiguous.default(transpose_5);  transpose_5 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:262 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
                reshape: "f32[2, 8, 16]" = torch.ops.aten.reshape.default(contiguous, [2, 8, -1]);  contiguous = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_3: "f32[2, 8, 16]" = torch.ops.aten.linear.default(reshape, p_model_layers_0_self_attn_o_proj_weight);  reshape = p_model_layers_0_self_attn_o_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:302 in forward, code: hidden_states = residual + hidden_states
                add_4: "f32[2, 8, 16]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_13: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(add_4, torch.float32);  add_4 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_2: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
                mean_1: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_5: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-06);  mean_1 = None
                rsqrt_1: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add_5);  add_5 = None
                mul_9: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_13, rsqrt_1);  rsqrt_1 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_14: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_9, torch.float32);  mul_9 = None
                mul_10: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_14);  p_model_layers_0_post_attention_layernorm_weight = to_14 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_4: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/activations.py:103 in forward, code: return nn.functional.silu(input)
                silu: "f32[2, 8, 16]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_5: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_up_proj_weight);  mul_10 = p_model_layers_0_mlp_up_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
                mul_11: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None
                
                 # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
                linear_6: "f32[2, 8, 16]" = torch.ops.aten.linear.default(mul_11, p_model_layers_0_mlp_down_proj_weight);  mul_11 = p_model_layers_0_mlp_down_proj_weight = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:308 in forward, code: hidden_states = residual + hidden_states
                add_6: "f32[2, 8, 16]" = torch.ops.aten.add.Tensor(to_13, linear_6);  to_13 = linear_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
                to_15: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
                pow_3: "f32[2, 8, 16]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
                mean_2: "f32[2, 8, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
                add_7: "f32[2, 8, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-06);  mean_2 = None
                rsqrt_2: "f32[2, 8, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
                mul_12: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_2);  to_15 = rsqrt_2 = None
                
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
                to_16: "f32[2, 8, 16]" = torch.ops.aten.to.dtype(mul_12, torch.float32);  mul_12 = None
                mul_13: "f32[2, 8, 16]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_16);  p_model_norm_weight = to_16 = None
                return (mul_13,)
                
            class submod_1(torch.nn.Module):
                def forward(self, to_4: "f32[1, 4, 1]", to_5: "f32[1, 1, 8]"):
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:101 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                    to_6: "f32[1, 4, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                    to_7: "f32[1, 1, 8]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                    matmul: "f32[1, 4, 8]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                    transpose: "f32[1, 8, 4]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:102 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[1, 8, 8]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[1, 8, 8]" = torch.ops.aten.cos.default(cat)
                    mul: "f32[1, 8, 8]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None
                    
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[1, 8, 8]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_1: "f32[1, 8, 8]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul, mul_1)
                    
    Graph signature: 
        # inputs
        p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
        p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
        p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
        p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
        p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
        p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
        p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
        p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
        p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
        p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
        p_model_norm_weight: PARAMETER target='model.norm.weight'
        b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
        c_model_lifted_tensor_0: CONSTANT_TENSOR target='model.lifted_tensor_0'
        input_ids: USER_INPUT
        attention_mask: USER_INPUT
        
        # outputs
        mul_13: USER_OUTPUT
        
    Range constraints: {}
    
    -- process.graph_module.graph --
    graph():
        %p_model_embed_tokens_weight : [num_users=1] = placeholder[target=p_model_embed_tokens_weight]
        %p_model_layers_0_self_attn_q_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_q_proj_weight]
        %p_model_layers_0_self_attn_k_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_k_proj_weight]
        %p_model_layers_0_self_attn_v_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_v_proj_weight]
        %p_model_layers_0_self_attn_o_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_self_attn_o_proj_weight]
        %p_model_layers_0_mlp_gate_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_gate_proj_weight]
        %p_model_layers_0_mlp_up_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_up_proj_weight]
        %p_model_layers_0_mlp_down_proj_weight : [num_users=1] = placeholder[target=p_model_layers_0_mlp_down_proj_weight]
        %p_model_layers_0_input_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_input_layernorm_weight]
        %p_model_layers_0_post_attention_layernorm_weight : [num_users=1] = placeholder[target=p_model_layers_0_post_attention_layernorm_weight]
        %p_model_norm_weight : [num_users=1] = placeholder[target=p_model_norm_weight]
        %b_model_rotary_emb_inv_freq : [num_users=1] = placeholder[target=b_model_rotary_emb_inv_freq]
        %c_model_lifted_tensor_0 : [num_users=1] = placeholder[target=c_model_lifted_tensor_0]
        %input_ids : [num_users=1] = placeholder[target=input_ids]
        %attention_mask : [num_users=1] = placeholder[target=attention_mask]
        %function_const_func_spec0 : [num_users=1] = get_attr[target=function_const_func_spec0]
        %torch__dynamo__trace_wrapped_higher_order_op_mod_index0 : [num_users=1] = get_attr[target=torch__dynamo__trace_wrapped_higher_order_op_ModIndex0]
        %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%p_model_embed_tokens_weight, %input_ids), kwargs = {})
        %arange : [num_users=2] = call_function[target=torch.ops.aten.arange.start](args = (0, 8), kwargs = {device: cpu, pin_memory: False})
        %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
        %to : [num_users=1] = call_function[target=torch.ops.aten.to.device](args = (%attention_mask, cpu, torch.bool), kwargs = {})
        %arange_1 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (8,), kwargs = {device: cpu, pin_memory: False})
        %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%arange_1, 0), kwargs = {})
        %arange_2 : [num_users=1] = call_function[target=torch.ops.aten.arange.default](args = (2,), kwargs = {device: cpu, pin_memory: False})
        %_add_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange_2, 0, 1), kwargs = {})
        %_add_batch_dim_2 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%arange, 0, 3), kwargs = {})
        %_add_batch_dim_3 : [num_users=2] = call_function[target=torch._functorch.predispatch._add_batch_dim](args = (%add_, 0, 4), kwargs = {})
        %new_ones : [num_users=1] = call_function[target=torch.ops.aten.new_ones.default](args = (%_add_batch_dim_2, []), kwargs = {dtype: torch.bool, pin_memory: False})
        %le : [num_users=1] = call_function[target=torch.ops.aten.le.Tensor](args = (%_add_batch_dim_3, %_add_batch_dim_2), kwargs = {})
        %to_1 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%le,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_1 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%new_ones, %to_1), kwargs = {})
        %flat_apply : [num_users=1] = call_function[target=torch.ops.higher_order.flat_apply](args = (%function_const_func_spec0, %torch__dynamo__trace_wrapped_higher_order_op_mod_index0, torch._dynamo._trace_wrapped_higher_order_op.ModIndex, %to, %_add_batch_dim, %_add_batch_dim_3), kwargs = {})
        %to_2 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%flat_apply,), kwargs = {dtype: torch.bool, layout: torch.strided, device: cpu})
        %and_2 : [num_users=1] = call_function[target=torch.ops.aten.__and__.Tensor](args = (%and_1, %to_2), kwargs = {})
        %_remove_batch_dim : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%and_2, 4, 8, 0), kwargs = {})
        %_remove_batch_dim_1 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim, 3, 8, 0), kwargs = {})
        %_remove_batch_dim_2 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_1, 2, 1, 0), kwargs = {})
        %_remove_batch_dim_3 : [num_users=1] = call_function[target=torch._functorch.predispatch._remove_batch_dim](args = (%_remove_batch_dim_2, 1, 2, 0), kwargs = {})
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_model_lifted_tensor_0,), kwargs = {})
        %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {})
        %where : [num_users=1] = call_function[target=torch.ops.aten.where.ScalarOther](args = (%_remove_batch_dim_3, %detach_, -3.4028234663852886e+38), kwargs = {})
        %unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%b_model_rotary_emb_inv_freq, 0), kwargs = {})
        %unsqueeze_2 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze_1, 2), kwargs = {})
        %to_3 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_2, torch.float32), kwargs = {})
        %expand_1 : [num_users=1] = call_function[target=torch.ops.aten.expand.default](args = (%to_3, [1, -1, 1]), kwargs = {})
        %to_4 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype_layout](args = (%expand_1,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
        %unsqueeze_3 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%unsqueeze, 1), kwargs = {})
        %to_5 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%unsqueeze_3, torch.float32), kwargs = {})
        %submod_3 : [num_users=1] = get_attr[target=submod_1]
        %wrap_with_autocast : [num_users=2] = call_function[target=torch.ops.higher_order.wrap_with_autocast](args = (cpu, torch.bfloat16, False, False, %submod_3, %to_4, %to_5), kwargs = {})
        %mul : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 0), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=operator.getitem](args = (%wrap_with_autocast, 1), kwargs = {})
        %to_8 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul, torch.float32), kwargs = {})
        %to_9 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_1, torch.float32), kwargs = {})
        %to_10 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%embedding, torch.float32), kwargs = {})
        %pow_1 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_10, 2), kwargs = {})
        %mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-06), kwargs = {})
        %rsqrt : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add,), kwargs = {})
        %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_10, %rsqrt), kwargs = {})
        %to_11 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_2, torch.float32), kwargs = {})
        %mul_3 : [num_users=3] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_input_layernorm_weight, %to_11), kwargs = {})
        %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_q_proj_weight), kwargs = {})
        %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear, [2, 8, -1, 8]), kwargs = {})
        %transpose_1 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view, 1, 2), kwargs = {})
        %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_k_proj_weight), kwargs = {})
        %view_1 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_1, [2, 8, -1, 8]), kwargs = {})
        %transpose_2 : [num_users=3] = call_function[target=torch.ops.aten.transpose.int](args = (%view_1, 1, 2), kwargs = {})
        %linear_2 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_3, %p_model_layers_0_self_attn_v_proj_weight), kwargs = {})
        %view_2 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%linear_2, [2, 8, -1, 8]), kwargs = {})
        %transpose_3 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%view_2, 1, 2), kwargs = {})
        %unsqueeze_4 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_8, 1), kwargs = {})
        %unsqueeze_5 : [num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%to_9, 1), kwargs = {})
        %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_1, %unsqueeze_4), kwargs = {})
        %slice_1 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 0, 4), kwargs = {})
        %slice_2 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_1, 3, 4, 9223372036854775807), kwargs = {})
        %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_2,), kwargs = {})
        %cat_1 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg, %slice_1], -1), kwargs = {})
        %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_1, %unsqueeze_5), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_4, %mul_5), kwargs = {})
        %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%transpose_2, %unsqueeze_4), kwargs = {})
        %slice_3 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 0, 4), kwargs = {})
        %slice_4 : [num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%transpose_2, 3, 4, 9223372036854775807), kwargs = {})
        %neg_1 : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%slice_4,), kwargs = {})
        %cat_2 : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%neg_1, %slice_3], -1), kwargs = {})
        %mul_7 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_2, %unsqueeze_5), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_6, %mul_7), kwargs = {})
        %transpose_4 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%add_2, 2, 3), kwargs = {})
        %matmul_1 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%add_1, %transpose_4), kwargs = {})
        %mul_8 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%matmul_1, 0.3535533905932738), kwargs = {})
        %alias : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%where,), kwargs = {})
        %add_3 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_8, %alias), kwargs = {})
        %softmax : [num_users=1] = call_function[target=torch.ops.aten.softmax.int](args = (%add_3, -1, torch.float32), kwargs = {})
        %to_12 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%softmax, torch.float32), kwargs = {})
        %dropout : [num_users=1] = call_function[target=torch.ops.aten.dropout.default](args = (%to_12, 0.0, True), kwargs = {})
        %matmul_2 : [num_users=1] = call_function[target=torch.ops.aten.matmul.default](args = (%dropout, %transpose_3), kwargs = {})
        %transpose_5 : [num_users=1] = call_function[target=torch.ops.aten.transpose.int](args = (%matmul_2, 1, 2), kwargs = {})
        %contiguous : [num_users=1] = call_function[target=torch.ops.aten.contiguous.default](args = (%transpose_5,), kwargs = {})
        %reshape : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%contiguous, [2, 8, -1]), kwargs = {})
        %linear_3 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%reshape, %p_model_layers_0_self_attn_o_proj_weight), kwargs = {})
        %add_4 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_10, %linear_3), kwargs = {})
        %to_13 : [num_users=3] = call_function[target=torch.ops.aten.to.dtype](args = (%add_4, torch.float32), kwargs = {})
        %pow_2 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_13, 2), kwargs = {})
        %mean_1 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
        %add_5 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-06), kwargs = {})
        %rsqrt_1 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_5,), kwargs = {})
        %mul_9 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_13, %rsqrt_1), kwargs = {})
        %to_14 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_9, torch.float32), kwargs = {})
        %mul_10 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_layers_0_post_attention_layernorm_weight, %to_14), kwargs = {})
        %linear_4 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_gate_proj_weight), kwargs = {})
        %silu : [num_users=1] = call_function[target=torch.ops.aten.silu.default](args = (%linear_4,), kwargs = {})
        %linear_5 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_10, %p_model_layers_0_mlp_up_proj_weight), kwargs = {})
        %mul_11 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%silu, %linear_5), kwargs = {})
        %linear_6 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%mul_11, %p_model_layers_0_mlp_down_proj_weight), kwargs = {})
        %add_6 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%to_13, %linear_6), kwargs = {})
        %to_15 : [num_users=2] = call_function[target=torch.ops.aten.to.dtype](args = (%add_6, torch.float32), kwargs = {})
        %pow_3 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%to_15, 2), kwargs = {})
        %mean_2 : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_3, [-1], True), kwargs = {})
        %add_7 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_2, 1e-06), kwargs = {})
        %rsqrt_2 : [num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_7,), kwargs = {})
        %mul_12 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%to_15, %rsqrt_2), kwargs = {})
        %to_16 : [num_users=1] = call_function[target=torch.ops.aten.to.dtype](args = (%mul_12, torch.float32), kwargs = {})
        %mul_13 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%p_model_norm_weight, %to_16), kwargs = {})
        return (mul_13,)
    -- process.inputs_to_remove --
    set()
    -- process.progress --
    node 15/124 target=function_const_func_spec0
    -- 2 INPUTS
    [GraphBuilder-ZYE.make_tensor_input] input0[7:2x8]
    [GraphBuilder-ZYE.make_tensor_input] input1[1:2x8]
    -- 24 INITIALIZERS
    [GraphBuilder-ZYE.make_initializer] p_model_embed_tokens_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-ZYE.make_initializer] model.embed_tokens.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.embed_tokens.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_self_attn_q_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.self_attn.q_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.q_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_self_attn_k_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.self_attn.k_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.k_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_self_attn_v_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.self_attn.v_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.v_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_self_attn_o_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.self_attn.o_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.self_attn.o_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_mlp_gate_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.mlp.gate_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.gate_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_mlp_up_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.mlp.up_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.up_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_mlp_down_proj_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.mlp.down_proj.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.mlp.down_proj.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_input_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.input_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.input_layernorm.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_layers_0_post_attention_layernorm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-ZYE.make_initializer] model.layers.0.post_attention_layernorm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.layers.0.post_attention_layernorm.weight)
    [GraphBuilder-ZYE.make_initializer] p_model_norm_weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-ZYE.make_initializer] model.norm.weight[torch.float32:torch.float32] - SOURCE: DynamoInterpret.placeholder.1/P(model.norm.weight)
    [GraphBuilder-ZYE.make_initializer] b_model_rotary_emb_inv_freq[torch.float32:torch.float32:[1.0, 0.10000000149011612, 0.009999999776482582, 0.0010000000474974513]] - SOURCE: DynamoInterpret.placeholder.0
    [GraphBuilder-ZYE.make_initializer] c_model_lifted_tensor_0[torch.float32:torch.float32:[0.0]] - SOURCE: DynamoInterpret.placeholder.0
    [GraphBuilder-ZYE.make_node] make_tensor_input_id [#:#   ] Identity:['input0']->['input_ids']
    [GraphBuilder-ZYE.make_node] make_tensor_input_id2 [#:#   ] Identity:['input1']->['attention_mask']
    -- 0 OUTPUTS
    [GraphBuilder-ZYE] Message completed, there are 24 initializers, 2 nodes, 2 inputs, 2 outputs.
fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the converter translating an aten function into ONNX has been looked up. fct is that converter; it can be replaced, and it is only set when a mapping was found. A sketch of overriding this method follows the parameter list below.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable
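
Since fallback receives the converter and returns the one actually used, it is also a convenient extension point. The following sketch is hypothetical and not part of the library: it subclasses OxsDebugDispatcher to record every fallback lookup before delegating to the parent class; the name LoggingDispatcher and its seen attribute are illustrative only.

# A minimal sketch (hypothetical, not part of the library): record every
# fallback lookup before delegating to OxsDebugDispatcher.fallback.
from typing import Any, Callable, Dict, List, Optional

from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDebugDispatcher


class LoggingDispatcher(OxsDebugDispatcher):
    """Keeps track of which aten functions triggered a fallback lookup."""

    def __init__(self, verbose: int = 0, raise_exc: bool = True):
        super().__init__(verbose=verbose, raise_exc=raise_exc)
        self.seen: List[Any] = []

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: "GraphBuilder",  # annotation kept as a string to avoid the import
    ) -> Optional[Callable]:
        # Remember the lookup, then let the parent class decide which
        # converter (if any) should be used for this aten function.
        self.seen.append((name, fct is not None))
        return super().fallback(name, fct, args, kwargs, builder)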

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]

If DynamoInterpreter cannot find any converting function for a specific aten function, this dispatcher tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable op and the functions Rank and IsScalar are replaced by op = OxsOpset(), op.Rank, and op.Scalar respectively. onnxscript may provide multiple overloaded functions; right now, the first one is taken.

Parameters:

verbose – verbosity
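
Like OxsDebugDispatcher above, this dispatcher is given to to_onnx through its dispatcher argument. A minimal usage sketch, assuming any exportable model (the tiny module below is only an illustration):

# A minimal usage sketch: plug OxsDispatcher into to_onnx so that aten
# functions without a known converter are looked up in onnxscript.
import torch
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


with torch.no_grad():
    model = TinyModel()
    x = torch.randn(3, 4)

    onx = to_onnx(
        model,
        (x,),
        input_names=["x"],
        dispatcher=OxsDispatcher(verbose=1),
    )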

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the converter translating an aten function into ONNX has been looked up. fct is that converter; it can be replaced, and it is only set when a mapping was found.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

property submodules: Dict[str, Callable]

Returns the submodules implementing torch functions.
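
A short sketch of how this property could be inspected, for instance to list the torch functions that already have an onnxscript-backed implementation (this assumes the dictionary is populated on a default-constructed dispatcher):

# A small sketch: list the keys of the submodules dictionary.
from experimental_experiment.torch_interpreter.oxs_dispatcher import OxsDispatcher

dispatcher = OxsDispatcher()
for name in sorted(dispatcher.submodules):
    print(name)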