Export Phi-3.5-mini-instruct with draft_export

This example tries torch.export._draft_export.draft_export(), which attempts the export and produces a report describing any issue found along the way.

Model

from contextlib import redirect_stderr
from io import StringIO
from typing import Any, Dict
import torch
import torch.export._draft_export
import transformers
from experimental_experiment.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions


def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets an untrained (randomly initialized) model with two sets of inputs of different shapes.

    :param batch_size: batch size
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with the model and the two sets of inputs

    See `Phi-3.5-mini-instruct/config.json
    <https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_phi3.Phi3Config",
            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
        },
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": [
                1.0800000429153442,
                1.1100000143051147,
                1.1399999856948853,
                1.340000033378601,
                1.5899999141693115,
                1.600000023841858,
                1.6200000047683716,
                2.620000123977661,
                3.2300000190734863,
                3.2300000190734863,
                4.789999961853027,
                7.400000095367432,
                7.700000286102295,
                9.09000015258789,
                12.199999809265137,
                17.670000076293945,
                24.46000099182129,
                28.57000160217285,
                30.420001983642578,
                30.840002059936523,
                32.590003967285156,
                32.93000411987305,
                42.320003509521484,
                44.96000289916992,
                50.340003967285156,
                50.45000457763672,
                57.55000305175781,
                57.93000411987305,
                58.21000289916992,
                60.1400032043457,
                62.61000442504883,
                62.62000274658203,
                62.71000289916992,
                63.1400032043457,
                63.1400032043457,
                63.77000427246094,
                63.93000411987305,
                63.96000289916992,
                63.970001220703125,
                64.02999877929688,
                64.06999969482422,
                64.08000183105469,
                64.12000274658203,
                64.41000366210938,
                64.4800033569336,
                64.51000213623047,
                64.52999877929688,
                64.83999633789062,
            ],
            "short_factor": [
                1.0,
                1.0199999809265137,
                1.0299999713897705,
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0699999332427979,
                1.0999999046325684,
                1.1099998950958252,
                1.1599998474121094,
                1.1599998474121094,
                1.1699998378753662,
                1.2899998426437378,
                1.339999794960022,
                1.679999828338623,
                1.7899998426437378,
                1.8199998140335083,
                1.8499997854232788,
                1.8799997568130493,
                1.9099997282028198,
                1.9399996995925903,
                1.9899996519088745,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0799996852874756,
                2.0899996757507324,
                2.189999580383301,
                2.2199995517730713,
                2.5899994373321533,
                2.729999542236328,
                2.749999523162842,
                2.8399994373321533,
            ],
            "type": "longrope",
        },
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "use_cache": True,
        "attention_bias": False,
        "vocab_size": 32064,
    }
    config.update(**kwargs)
    conf = transformers.Phi3Config(**config)
    model = transformers.Phi3ForCausalLM(conf)
    model.eval()

    cache = make_dynamic_cache(
        [
            (torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96))
            for _ in range(config["num_hidden_layers"])
        ]
    )
    cache2 = make_dynamic_cache(
        [
            (torch.randn(batch_size + 1, 32, 31, 96), torch.randn(batch_size + 1, 32, 31, 96))
            for _ in range(config["num_hidden_layers"])
        ]
    )

    inputs = dict(
        input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
        attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
        attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
        past_key_values=cache2,
    )
    return dict(inputs=inputs, model=model, inputs2=inputs2)


data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
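
Before exporting, a quick eager sanity check confirms the untrained model accepts these inputs. The sketch below is an addition to the original script; it runs the forward pass on a deep copy of the inputs because a DynamicCache is updated in place, and the tensors passed to draft_export below must stay untouched (assuming the cache deep-copies cleanly).

import copy

with torch.no_grad():
    # deep copy: the forward pass grows the cache in place (30 -> 33 here)
    expected = model(**copy.deepcopy(inputs))
print(string_type(expected, with_shape=True))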

Draft Export

The function we want to try.

err = StringIO()
with redirect_stderr(err), register_additional_serialization_functions(patch_transformers=True):
    ep = torch.export._draft_export.draft_export(model, tuple(), kwargs=inputs, strict=False)

Errors, if any, captured from stderr.

print(err.getvalue())

Let’s print the report.

print(ep._report)
##############################################################################################
Congratuations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
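
The report suggests switching back to torch.export.export(). A minimal sketch of that follow-up, reusing the same patched serialization context; it is not part of the original run, and without dynamic_shapes the resulting graph stays specialized to the sample shapes (batch 2, sequence 3, cache length 30).

with register_additional_serialization_functions(patch_transformers=True):
    # same calling convention as draft_export above
    ep_static = torch.export.export(model, tuple(), kwargs=inputs, strict=False)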

And the exported program.

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32064, 3072]", p_model_layers_0_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_0_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_0_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_0_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_0_input_layernorm_weight: "f32[3072]", p_model_layers_0_post_attention_layernorm_weight: "f32[3072]", p_model_layers_1_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_1_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_1_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_1_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_1_input_layernorm_weight: "f32[3072]", p_model_layers_1_post_attention_layernorm_weight: "f32[3072]", p_model_norm_weight: "f32[3072]", p_lm_head_weight: "f32[32064, 3072]", b_model_rotary_emb_inv_freq: "f32[48]", c_lifted_tensor_0: "f32[0]", c_lifted_tensor_1: "f32[0]", c_lifted_tensor_2: "f32[0]", c_lifted_tensor_3: "f32[0]", c_model_rotary_emb_lifted_tensor_4: "f32[48]", input_ids: "i64[2, 3]", attention_mask: "i64[2, 33]", past_key_values_key_cache_0: "f32[2, 32, 30, 96]", past_key_values_key_cache_1: "f32[2, 32, 30, 96]", past_key_values_value_cache_0: "f32[2, 32, 30, 96]", past_key_values_value_cache_1: "f32[2, 32, 30, 96]"):
             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:388 in forward, code: causal_mask = mask_function(
            function_const_func_spec0 = self.function_const_func_spec0
            torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0

            # No stacktrace found for following nodes
            lift_fresh_copy: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            detach_: "f32[0]" = torch.ops.aten.detach_.default(lift_fresh_copy);  lift_fresh_copy = None
            lift_fresh_copy_1: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_1);  c_lifted_tensor_1 = None
            detach__1: "f32[0]" = torch.ops.aten.detach_.default(lift_fresh_copy_1);  lift_fresh_copy_1 = None
            cat: "f32[2, 32, 30, 96]" = torch.ops.aten.cat.default([detach_, past_key_values_key_cache_0], -2);  detach_ = past_key_values_key_cache_0 = None
            cat_1: "f32[2, 32, 30, 96]" = torch.ops.aten.cat.default([detach__1, past_key_values_value_cache_0], -2);  detach__1 = past_key_values_value_cache_0 = None
            lift_fresh_copy_2: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_2);  c_lifted_tensor_2 = None
            detach__2: "f32[0]" = torch.ops.aten.detach_.default(lift_fresh_copy_2);  lift_fresh_copy_2 = None
            lift_fresh_copy_3: "f32[0]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_3);  c_lifted_tensor_3 = None
            detach__3: "f32[0]" = torch.ops.aten.detach_.default(lift_fresh_copy_3);  lift_fresh_copy_3 = None
            cat_2: "f32[2, 32, 30, 96]" = torch.ops.aten.cat.default([detach__2, past_key_values_key_cache_1], -2);  detach__2 = past_key_values_key_cache_1 = None
            cat_3: "f32[2, 32, 30, 96]" = torch.ops.aten.cat.default([detach__3, past_key_values_value_cache_1], -2);  detach__3 = past_key_values_value_cache_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[2, 3, 3072]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids, 32000);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:380 in forward, code: cache_position = torch.arange(
            arange: "i64[3]" = torch.ops.aten.arange.start(30, 33, device = device(type='cpu'), pin_memory = False)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:385 in forward, code: position_ids = cache_position.unsqueeze(0)
            unsqueeze: "i64[1, 3]" = torch.ops.aten.unsqueeze.default(arange, 0)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:388 in forward, code: causal_mask = mask_function(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[2, 33]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[33]" = torch.ops.aten.arange.default(33, device = device(type='cpu'), pin_memory = False)
            add_: "i64[33]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[2]" = torch.ops.aten.arange.default(2, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions = None
            _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
            _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
            lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_1 = None
            _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error');  _vmap_increment_nesting_1 = None
            _add_batch_dim_1: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_3, 0, 2);  arange_3 = _add_batch_dim_1 = None
            lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_2 = None
            _vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting_2 = None
            _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
            lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_3 = None
            _vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(33, 'error');  _vmap_increment_nesting_3 = None
            _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
            new_ones_1: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
            sub: "i64[]" = torch.ops.aten.sub.Tensor(_add_batch_dim_2, 262144)
            gt: "b8[]" = torch.ops.aten.gt.Tensor(_add_batch_dim_3, sub);  sub = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(gt, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[]" = torch.ops.aten.to.dtype_layout(gt, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  gt = None
            and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones_1, to_1);  new_ones_1 = to_1 = None
            le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
            _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(and_2, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
            to_3: "b8[]" = torch.ops.aten.to.dtype_layout(and_2, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  and_2 = None
            and_3: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_3);  new_ones = to_3 = None
            flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
            _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(flat_apply, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
            to_4: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
            and_4: "b8[]" = torch.ops.aten.__and__.Tensor(and_3, to_4);  and_3 = to_4 = None
            _remove_batch_dim: "b8[33]" = torch._functorch.predispatch._remove_batch_dim(and_4, 4, 33, 0);  and_4 = None
            _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
            _remove_batch_dim_1: "b8[3, 33]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, 3, 0);  _remove_batch_dim = None
            _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
            _remove_batch_dim_2: "b8[1, 3, 33]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0)
            expand: "b8[1, 3, 33]" = torch.ops.aten.expand.default(_remove_batch_dim_1, [1, 3, 33]);  _remove_batch_dim_1 = expand = None
            _vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_2 = None
            _remove_batch_dim_3: "b8[2, 1, 3, 33]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, 2, 0);  _remove_batch_dim_2 = None
            _vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_3 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, unsqueeze, c_model_rotary_emb_lifted_tensor_4);  submod_3 = unsqueeze = c_model_rotary_emb_lifted_tensor_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:335 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_11: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[0]
            to_12: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:227 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:228 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
            mean: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:229 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_1: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_1);  add_1 = None
            mul_2: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_13, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:230 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_14);  p_model_layers_0_input_layernorm_weight = to_14 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_qkv_proj_weight);  mul_3 = p_model_layers_0_self_attn_qkv_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:178 in forward, code: query_states = qkv[..., :query_pos]
            slice_1: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 0, 3072)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:179 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
            slice_2: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 3072, 6144)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:180 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
            slice_3: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 6144, 9223372036854775807);  linear = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:182 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
            view: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_1, [2, 3, -1, 96]);  slice_1 = None
            transpose_1: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:183 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
            view_1: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_2, [2, 3, -1, 96]);  slice_2 = None
            transpose_2: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
            view_2: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_3, [2, 3, -1, 96]);  slice_3 = None
            transpose_3: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:187 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_4: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_11, 1)
            unsqueeze_5: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_12, 1)
            alias: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_1)
            slice_4: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 96, 9223372036854775807);  transpose_1 = None
            alias_1: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_2)
            slice_5: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 96, 9223372036854775807);  transpose_2 = None
            mul_4: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias, unsqueeze_4)
            slice_6: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 0, 48)
            slice_7: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 48, 9223372036854775807);  alias = None
            neg: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_7);  slice_7 = None
            cat_5: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg, slice_6], -1);  neg = slice_6 = None
            mul_5: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_5);  cat_5 = None
            add_2: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            cat_6: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_2, slice_4], -1);  add_2 = slice_4 = None
            mul_6: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_1, unsqueeze_4);  unsqueeze_4 = None
            slice_8: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 0, 48)
            slice_9: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 48, 9223372036854775807);  alias_1 = None
            neg_1: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_9);  slice_9 = None
            cat_7: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_1, slice_8], -1);  neg_1 = slice_8 = None
            mul_7: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_7, unsqueeze_5);  cat_7 = unsqueeze_5 = None
            add_3: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
            cat_8: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_3, slice_5], -1);  add_3 = slice_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:192 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_9: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([cat, cat_8], -2);  cat = cat_8 = None
            cat_10: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2);  cat_1 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: attn_output, attn_weights = attention_interface(
            alias_2: "b8[2, 1, 3, 33]" = torch.ops.aten.alias.default(_remove_batch_dim_3)
            scaled_dot_product_attention: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_6, cat_9, cat_10, alias_2, scale = 0.10206207261596575);  cat_6 = alias_2 = None
            transpose_4: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous, [2, 3, -1]);  contiguous = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape, p_model_layers_0_self_attn_o_proj_weight);  reshape = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:73 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_1, 0.0, False);  linear_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:273 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
            add_4: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_13, dropout);  to_13 = dropout = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:227 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(add_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_4, torch.float32);  add_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:228 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
            mean_1: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:229 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_5: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_5);  add_5 = None
            mul_8: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:230 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_16 = None
            to_16: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_16);  p_model_layers_0_post_attention_layernorm_weight = to_16 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_up_proj_weight);  mul_9 = p_model_layers_0_mlp_gate_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:62 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
            chunk = torch.ops.aten.chunk.default(linear_2, 2, -1);  linear_2 = None
            getitem_14: "f32[2, 3, 8192]" = chunk[0]
            getitem_15: "f32[2, 3, 8192]" = chunk[1];  chunk = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_14);  getitem_14 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:63 in forward, code: up_states = up_states * self.activation_fn(gate)
            mul_10: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_15, silu);  getitem_15 = silu = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:73 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_1: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_3, 0.0, False);  linear_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:278 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
            add_6: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_15, dropout_1);  to_15 = dropout_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:227 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_17 = None
            to_17: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:228 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_17, 2)
            mean_2: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:229 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_7: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
            mul_11: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_17, rsqrt_2);  rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:230 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_18 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_18 = None
            to_18: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_18);  p_model_layers_1_input_layernorm_weight = to_18 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_qkv_proj_weight);  mul_12 = p_model_layers_1_self_attn_qkv_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:178 in forward, code: query_states = qkv[..., :query_pos]
            slice_10: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 0, 3072)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:179 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
            slice_11: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 3072, 6144)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:180 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
            slice_12: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 6144, 9223372036854775807);  linear_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:182 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
            view_3: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_10, [2, 3, -1, 96]);  slice_10 = None
            transpose_5: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_3, 1, 2);  view_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:183 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
            view_4: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_11, [2, 3, -1, 96]);  slice_11 = None
            transpose_6: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_4, 1, 2);  view_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
            view_5: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_12, [2, 3, -1, 96]);  slice_12 = None
            transpose_7: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_5, 1, 2);  view_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:187 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_6: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_11, 1);  to_11 = None
            unsqueeze_7: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_12, 1);  to_12 = None
            alias_3: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_5)
            slice_13: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 96, 9223372036854775807);  transpose_5 = None
            alias_4: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_6)
            slice_14: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 96, 9223372036854775807);  transpose_6 = None
            mul_13: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_3, unsqueeze_6)
            slice_15: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 0, 48)
            slice_16: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 48, 9223372036854775807);  alias_3 = None
            neg_2: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_16);  slice_16 = None
            cat_11: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_2, slice_15], -1);  neg_2 = slice_15 = None
            mul_14: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_11, unsqueeze_7);  cat_11 = None
            add_8: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_13, mul_14);  mul_13 = mul_14 = None
            cat_12: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_8, slice_13], -1);  add_8 = slice_13 = None
            mul_15: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_4, unsqueeze_6);  unsqueeze_6 = None
            slice_17: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_4, 3, 0, 48)
            slice_18: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_4, 3, 48, 9223372036854775807);  alias_4 = None
            neg_3: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_13: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_3, slice_17], -1);  neg_3 = slice_17 = None
            mul_16: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_13, unsqueeze_7);  cat_13 = unsqueeze_7 = None
            add_9: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_15, mul_16);  mul_15 = mul_16 = None
            cat_14: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_9, slice_14], -1);  add_9 = slice_14 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:192 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_15: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([cat_2, cat_14], -2);  cat_2 = cat_14 = None
            cat_16: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([cat_3, transpose_7], -2);  cat_3 = transpose_7 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: attn_output, attn_weights = attention_interface(
            alias_5: "b8[2, 1, 3, 33]" = torch.ops.aten.alias.default(_remove_batch_dim_3);  _remove_batch_dim_3 = None
            scaled_dot_product_attention_1: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_12, cat_15, cat_16, alias_5, scale = 0.10206207261596575);  cat_12 = alias_5 = None
            transpose_8: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2);  scaled_dot_product_attention_1 = None
            contiguous_1: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_8);  transpose_8 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_1: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous_1, [2, 3, -1]);  contiguous_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape_1, p_model_layers_1_self_attn_o_proj_weight);  reshape_1 = p_model_layers_1_self_attn_o_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:73 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_2: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_5, 0.0, False);  linear_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:273 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
            add_10: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_17, dropout_2);  to_17 = dropout_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:227 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_19 = torch.ops.aten._assert_tensor_metadata.default(add_10, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_19 = None
            to_19: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_10, torch.float32);  add_10 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:228 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_4: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_19, 2)
            mean_3: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True);  pow_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:229 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_11: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05);  mean_3 = None
            rsqrt_3: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_11);  add_11 = None
            mul_17: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_19, rsqrt_3);  rsqrt_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:230 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_20 = torch.ops.aten._assert_tensor_metadata.default(mul_17, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_20 = None
            to_20: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_17, torch.float32);  mul_17 = None
            mul_18: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_20);  p_model_layers_1_post_attention_layernorm_weight = to_20 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_gate_up_proj_weight);  mul_18 = p_model_layers_1_mlp_gate_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:62 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
            chunk_1 = torch.ops.aten.chunk.default(linear_6, 2, -1);  linear_6 = None
            getitem_16: "f32[2, 3, 8192]" = chunk_1[0]
            getitem_17: "f32[2, 3, 8192]" = chunk_1[1];  chunk_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu_1: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_16);  getitem_16 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:63 in forward, code: up_states = up_states * self.activation_fn(gate)
            mul_19: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_17, silu_1);  getitem_17 = silu_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_19, p_model_layers_1_mlp_down_proj_weight);  mul_19 = p_model_layers_1_mlp_down_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:73 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_3: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_7, 0.0, False);  linear_7 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:278 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
            add_12: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_19, dropout_3);  to_19 = dropout_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:227 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_21 = torch.ops.aten._assert_tensor_metadata.default(add_12, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_21 = None
            to_21: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_12, torch.float32);  add_12 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:228 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_5: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_21, 2)
            mean_4: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True);  pow_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:229 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_13: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05);  mean_4 = None
            rsqrt_4: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_13);  add_13 = None
            mul_20: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_21, rsqrt_4);  to_21 = rsqrt_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:230 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_22 = torch.ops.aten._assert_tensor_metadata.default(mul_20, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_22 = None
            to_22: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_20, torch.float32);  mul_20 = None
            mul_21: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_22);  p_model_norm_weight = to_22 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:479 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            alias_6: "f32[2, 3, 3072]" = torch.ops.aten.alias.default(mul_21);  mul_21 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_8: "f32[2, 3, 32064]" = torch.ops.aten.linear.default(alias_6, p_lm_head_weight);  alias_6 = p_lm_head_weight = None
            return (linear_8, cat_9, cat_15, cat_10, cat_16)

        class submod_1(torch.nn.Module):
            def forward(self, unsqueeze: "i64[1, 3]", c_model_rotary_emb_lifted_tensor_4: "f32[48]"):
                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:398 in forward, code: position_embeddings = self.rotary_emb(hidden_states, position_ids)
                max_1: "i64[]" = torch.ops.aten.max.default(unsqueeze)
                add: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1);  max_1 = None
                gt_1: "b8[]" = torch.ops.aten.gt.Scalar(add, 4096);  add = None
                ne: "b8[]" = torch.ops.aten.ne.Scalar(gt_1, 0);  gt_1 = None
                item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(c_model_rotary_emb_lifted_tensor_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[48]" = torch.ops.aten.to.dtype_layout(c_model_rotary_emb_lifted_tensor_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  c_model_rotary_emb_lifted_tensor_4 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:325 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze_1: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(to_5, 0);  to_5 = None
                unsqueeze_2: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 2);  unsqueeze_1 = None
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_2, torch.float32);  unsqueeze_2 = None
                expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_6, [1, -1, 1]);  to_6 = None
                _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                to_7: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:326 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                unsqueeze_3: "i64[1, 1, 3]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
                _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
                to_8: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(unsqueeze_3, torch.float32);  unsqueeze_3 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_7, to_8);  submod_3 = to_7 = to_8 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:332 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[1, 3, 96]" = wrap_with_autocast[0]

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:333 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[1, 3, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:335 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
                to_11: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
                to_12: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_11, to_12)

            class submod_1(torch.nn.Module):
                def forward(self, to_7: "f32[1, 48, 1]", to_8: "f32[1, 1, 3]"):
                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:330 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(to_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
                    to_9: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(to_7, torch.float32);  to_7 = None
                    _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(to_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
                    to_10: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(to_8, torch.float32);  to_8 = None
                    matmul: "f32[1, 48, 3]" = torch.ops.aten.matmul.default(to_9, to_10);  to_9 = to_10 = None
                    transpose: "f32[1, 3, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:331 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat_4: "f32[1, 3, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:332 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[1, 3, 96]" = torch.ops.aten.cos.default(cat_4)
                    mul: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(cos, 1.1902380714238083);  cos = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:333 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[1, 3, 96]" = torch.ops.aten.sin.default(cat_4);  cat_4 = None
                    mul_1: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(sin, 1.1902380714238083);  sin = None
                    return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_self_attn_qkv_proj_weight: PARAMETER target='model.layers.0.self_attn.qkv_proj.weight'
    p_model_layers_0_mlp_gate_up_proj_weight: PARAMETER target='model.layers.0.mlp.gate_up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
    p_model_layers_1_self_attn_qkv_proj_weight: PARAMETER target='model.layers.1.self_attn.qkv_proj.weight'
    p_model_layers_1_mlp_gate_up_proj_weight: PARAMETER target='model.layers.1.mlp.gate_up_proj.weight'
    p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
    p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
    p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    c_lifted_tensor_0: CONSTANT_TENSOR target='lifted_tensor_0'
    c_lifted_tensor_1: CONSTANT_TENSOR target='lifted_tensor_1'
    c_lifted_tensor_2: CONSTANT_TENSOR target='lifted_tensor_2'
    c_lifted_tensor_3: CONSTANT_TENSOR target='lifted_tensor_3'
    c_model_rotary_emb_lifted_tensor_4: CONSTANT_TENSOR target='model.rotary_emb.lifted_tensor_4'
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_key_cache_1: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT
    past_key_values_value_cache_1: USER_INPUT

    # outputs
    linear_8: USER_OUTPUT
    cat_9: USER_OUTPUT
    cat_15: USER_OUTPUT
    cat_10: USER_OUTPUT
    cat_16: USER_OUTPUT

Range constraints: {u0: VR[0, 1]}
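
As a closing check, the module rebuilt from the exported program should reproduce the eager forward pass on these exact shapes. A sketch, again an addition to the original script; it assumes the patched pytree registration is still needed to flatten the DynamicCache input.

with register_additional_serialization_functions(patch_transformers=True):
    got = ep.module()(**copy.deepcopy(inputs))
print(string_type(got, with_shape=True))
# the logits could then be compared against an eager run, e.g. with
# torch.testing.assert_close(model(**copy.deepcopy(inputs)).logits, got.logits)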

Total running time of the script: (0 minutes 5.085 seconds)

Related examples

Export Phi-3.5-mini-instruct with report_exportability

Export Phi-3.5-mini-instruct piece by piece

Check the exporter on a dummy from HuggingFace

to_onnx and padding one dimension to a multiple of a constant

to_onnx and submodules from LLMs

Gallery generated by Sphinx-Gallery