Export Phi-3.5-mini-instruct with draft_export

This example tries torch.export._draft_export.draft_export() on an untrained (randomly initialized) Phi-3.5-mini-instruct model.

Model

from contextlib import redirect_stderr
from io import StringIO
from typing import Any, Dict
import torch
import torch.export._draft_export
import transformers
from experimental_experiment.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions


def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets a non initialized model with two sets of inputs and different shapes.

    :param batch_size: batch size of the first set of inputs,
        the second set uses ``batch_size + 1``
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with keys ``model``, ``inputs``, ``inputs2``

    The first set of inputs has a cache of length 30 and 3 new tokens,
    the second set has a cache of length 31 and 4 new tokens. In both cases
    the attention mask covers the cached positions and the new tokens.
    The cache tensors have shape
    ``(batch, num_key_value_heads, past_length, hidden_size // num_attention_heads)``
    and there is one (key, value) pair per hidden layer. All of these values
    are read from the configuration after *kwargs* are applied, so the inputs
    stay consistent with the model when the configuration is overwritten.

    See `Phi-3.5-mini-instruct/config.json
    <https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_phi3.Phi3Config",
            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
        },
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": [
                1.0800000429153442,
                1.1100000143051147,
                1.1399999856948853,
                1.340000033378601,
                1.5899999141693115,
                1.600000023841858,
                1.6200000047683716,
                2.620000123977661,
                3.2300000190734863,
                3.2300000190734863,
                4.789999961853027,
                7.400000095367432,
                7.700000286102295,
                9.09000015258789,
                12.199999809265137,
                17.670000076293945,
                24.46000099182129,
                28.57000160217285,
                30.420001983642578,
                30.840002059936523,
                32.590003967285156,
                32.93000411987305,
                42.320003509521484,
                44.96000289916992,
                50.340003967285156,
                50.45000457763672,
                57.55000305175781,
                57.93000411987305,
                58.21000289916992,
                60.1400032043457,
                62.61000442504883,
                62.62000274658203,
                62.71000289916992,
                63.1400032043457,
                63.1400032043457,
                63.77000427246094,
                63.93000411987305,
                63.96000289916992,
                63.970001220703125,
                64.02999877929688,
                64.06999969482422,
                64.08000183105469,
                64.12000274658203,
                64.41000366210938,
                64.4800033569336,
                64.51000213623047,
                64.52999877929688,
                64.83999633789062,
            ],
            "short_factor": [
                1.0,
                1.0199999809265137,
                1.0299999713897705,
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0699999332427979,
                1.0999999046325684,
                1.1099998950958252,
                1.1599998474121094,
                1.1599998474121094,
                1.1699998378753662,
                1.2899998426437378,
                1.339999794960022,
                1.679999828338623,
                1.7899998426437378,
                1.8199998140335083,
                1.8499997854232788,
                1.8799997568130493,
                1.9099997282028198,
                1.9399996995925903,
                1.9899996519088745,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0799996852874756,
                2.0899996757507324,
                2.189999580383301,
                2.2199995517730713,
                2.5899994373321533,
                2.729999542236328,
                2.749999523162842,
                2.8399994373321533,
            ],
            "type": "longrope",
        },
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "use_cache": True,
        "attention_bias": False,
        "vocab_size": 32064,
    }
    config.update(**kwargs)
    conf = transformers.Phi3Config(**config)
    model = transformers.Phi3ForCausalLM(conf)
    model.eval()

    # Read the shapes from the (possibly overwritten) configuration instead of
    # hard-coding 32 heads x 96 dims, which only holds for the default values.
    n_layers = config["num_hidden_layers"]
    n_kv_heads = config["num_key_value_heads"]
    head_dim = config["hidden_size"] // config["num_attention_heads"]
    vocab_size = config["vocab_size"]

    def _random_cache(batch: int, past_length: int):
        # One (key, value) pair of random tensors per layer. Key then value is
        # drawn for each layer, matching the previous random draw order.
        shape = (batch, n_kv_heads, past_length, head_dim)
        return make_dynamic_cache(
            [(torch.randn(*shape), torch.randn(*shape)) for _ in range(n_layers)]
        )

    # Both caches are drawn before the input ids to keep the same random order.
    cache = _random_cache(batch_size, 30)
    cache2 = _random_cache(batch_size + 1, 31)

    inputs = dict(
        input_ids=torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64),
        # mask length = cache length (30) + new tokens (3)
        attention_mask=torch.ones((batch_size, 33), dtype=torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, vocab_size, (batch_size + 1, 4), dtype=torch.int64),
        # mask length = cache length (31) + new tokens (4)
        attention_mask=torch.ones((batch_size + 1, 35), dtype=torch.int64),
        past_key_values=cache2,
    )
    return dict(inputs=inputs, model=model, inputs2=inputs2)


# Build a two-layer untrained model and its two sets of dummy inputs.
data = get_phi35_untrained(num_hidden_layers=2)
model = data["model"]
inputs = data["inputs"]
inputs2 = data["inputs2"]

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))

Draft Export

Call the function we want to try, capturing anything it writes to stderr.

# Collect stderr output so it can be displayed separately afterwards.
err = StringIO()
with redirect_stderr(err), register_additional_serialization_functions():
    # No positional arguments: every model input is passed by keyword.
    ep = torch.export._draft_export.draft_export(
        model,
        (),
        kwargs=inputs,
        strict=False,
    )

Print the errors captured on stderr, if any.

captured_stderr = err.getvalue()
print(captured_stderr)

Let’s print the report.

report = ep._report
print(report)
###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################

1. Data dependent error.
    When exporting, we were unable to evaluate the value of `Eq(u0, 1)`.
    This was encountered 1 times.
    This occurred at the following user stacktrace:
        File ~/vv/this312/lib/python3.12/site-packages/torch/utils/_contextlib.py, lineno 116, in decorate_context
        File ~/github/transformers/src/transformers/modeling_rope_utils.py, lineno 86, in wrapper
        File ~/github/transformers/src/transformers/modeling_rope_utils.py, lineno 50, in longrope_frequency_update
            if seq_len > original_max_position_embeddings:

        Locals:
            seq_len: ['Tensor(shape: torch.Size([]), stride: (), storage_offset: 0)']
            original_max_position_embeddings: [None]

    And the following framework stacktrace:
        File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1326, in __torch_function__
        File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1373, in __torch_function__
        File ~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py, lineno 973, in __torch_function__
            return func(*args, **kwargs)

    As a result, it was specialized to a constant (e.g. `0` in the 1st occurrence), and asserts were inserted into the graph.

    Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
    Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.

Finally, print the exported program.

print(str(ep))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32064, 3072]", p_model_layers_0_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_0_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_0_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_0_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_0_input_layernorm_weight: "f32[3072]", p_model_layers_0_post_attention_layernorm_weight: "f32[3072]", p_model_layers_1_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_1_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_1_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_1_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_1_input_layernorm_weight: "f32[3072]", p_model_layers_1_post_attention_layernorm_weight: "f32[3072]", p_model_norm_weight: "f32[3072]", p_lm_head_weight: "f32[32064, 3072]", b_model_rotary_emb_inv_freq: "f32[48]", c_model_rotary_emb_lifted_tensor_0: "f32[48]", input_ids: "i64[2, 3]", attention_mask: "i64[2, 33]", past_key_values_key_cache_0: "f32[2, 32, 30, 96]", past_key_values_key_cache_1: "f32[2, 32, 30, 96]", past_key_values_value_cache_0: "f32[2, 32, 30, 96]", past_key_values_value_cache_1: "f32[2, 32, 30, 96]"):
             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[2, 3, 3072]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids, 32000);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:454 in forward, code: cache_position = torch.arange(
            arange: "i64[3]" = torch.ops.aten.arange.start(30, 33, device = device(type='cpu'), pin_memory = False)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:459 in forward, code: position_ids = cache_position.unsqueeze(0)
            unsqueeze: "i64[1, 3]" = torch.ops.aten.unsqueeze.default(arange, 0)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:461 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[3, 33]" = torch.ops.aten.full.default([3, 33], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            arange_1: "i64[33]" = torch.ops.aten.arange.default(33, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[3, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1])
            gt: "b8[3, 33]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            arange_2: "i64[33]" = torch.ops.aten.arange.default(33, device = device(type='cpu'), pin_memory = False)
            reshape_1: "i64[3, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            sub: "i64[3, 1]" = torch.ops.aten.sub.Tensor(reshape_1, 262144);  reshape_1 = None
            le: "b8[3, 33]" = torch.ops.aten.le.Tensor(arange_2, sub);  arange_2 = sub = None
            bitwise_or_: "b8[3, 33]" = torch.ops.aten.bitwise_or_.Tensor(gt, le);  gt = le = None
            mul_: "f32[3, 33]" = torch.ops.aten.mul_.Tensor(full, bitwise_or_);  full = bitwise_or_ = None
            unsqueeze_1: "f32[1, 3, 33]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_2: "f32[1, 1, 3, 33]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
            slice_1: "f32[1, 1, 3, 33]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
            slice_2: "f32[1, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[2, 1, 3, 33]" = torch.ops.aten.expand.default(slice_2, [2, 1, -1, -1]);  slice_2 = None
            clone: "f32[2, 1, 3, 33]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
            slice_4: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_3, 1);  slice_3 = None
            slice_5: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_4, 2);  slice_4 = None
            slice_6: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, 33);  slice_5 = None
            slice_7: "i64[2, 33]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_3: "i64[2, 1, 33]" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
            unsqueeze_4: "i64[2, 1, 1, 33]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2);  unsqueeze_3 = None
            slice_8: "i64[2, 1, 1, 33]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807);  unsqueeze_4 = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[2, 1, 1, 33]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_8 = None
            add: "f32[2, 1, 3, 33]" = torch.ops.aten.add.Tensor(slice_6, to);  slice_6 = to = None
            eq: "b8[2, 1, 3, 33]" = torch.ops.aten.eq.Scalar(add, 0);  add = None
            slice_9: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
            slice_10: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_9, 1);  slice_9 = None
            slice_11: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_10, 2);  slice_10 = None
            slice_12: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, 33);  slice_11 = None
            masked_fill: "f32[2, 1, 3, 33]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq, -3.4028234663852886e+38);  slice_12 = eq = None
            slice_13: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_14: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807);  slice_13 = None
            slice_15: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807);  slice_14 = None
            copy_: "f32[2, 1, 3, 33]" = torch.ops.aten.copy_.default(slice_15, masked_fill);  slice_15 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, unsqueeze, c_model_rotary_emb_lifted_tensor_0);  submod_3 = unsqueeze = c_model_rotary_emb_lifted_tensor_0 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:385 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_7: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[0]
            to_8: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_9: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_9, 2)
            mean: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_2: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_2);  add_2 = None
            mul_2: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_9, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_10: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_10);  p_model_layers_0_input_layernorm_weight = to_10 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_qkv_proj_weight);  mul_3 = p_model_layers_0_self_attn_qkv_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: query_states = qkv[..., :query_pos]
            slice_19: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 0, 3072)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:185 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
            slice_20: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 3072, 6144)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:186 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
            slice_21: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 6144, 9223372036854775807);  linear = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:188 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
            view: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_19, [2, 3, -1, 96]);  slice_19 = None
            transpose_1: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:189 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
            view_1: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_20, [2, 3, -1, 96]);  slice_20 = None
            transpose_2: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:190 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
            view_2: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_21, [2, 3, -1, 96]);  slice_21 = None
            transpose_3: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:193 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_8: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1)
            unsqueeze_9: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1)
            alias: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_1)
            slice_22: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 96, 9223372036854775807);  transpose_1 = None
            alias_1: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_2)
            slice_23: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 96, 9223372036854775807);  transpose_2 = None
            mul_4: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias, unsqueeze_8)
            slice_24: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 0, 48)
            slice_25: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 48, 9223372036854775807);  alias = None
            neg: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_25);  slice_25 = None
            cat_1: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg, slice_24], -1);  neg = slice_24 = None
            mul_5: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9);  cat_1 = None
            add_3: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            cat_2: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_3, slice_22], -1);  add_3 = slice_22 = None
            mul_6: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_1, unsqueeze_8);  unsqueeze_8 = None
            slice_26: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 0, 48)
            slice_27: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 48, 9223372036854775807);  alias_1 = None
            neg_1: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_27);  slice_27 = None
            cat_3: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_1, slice_26], -1);  neg_1 = slice_26 = None
            mul_7: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_9);  cat_3 = unsqueeze_9 = None
            add_4: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
            cat_4: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_4, slice_23], -1);  add_4 = slice_23 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_5: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, cat_4], -2);  past_key_values_key_cache_0 = cat_4 = None
            cat_6: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output, attn_weights = attention_interface(
            slice_28: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
            slice_29: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_28, 1);  slice_28 = None
            slice_30: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_29, 2);  slice_29 = None
            slice_31: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_30, 3, None, 33);  slice_30 = None
            scaled_dot_product_attention: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_2, cat_5, cat_6, slice_31, scale = 0.10206207261596575);  cat_2 = slice_31 = None
            transpose_4: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:222 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_2: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous, [2, 3, -1]);  contiguous = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight);  reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_1, 0.0, False);  linear_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:310 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
            add_5: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_9, dropout);  to_9 = dropout = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_11: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_5, torch.float32);  add_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_11, 2)
            mean_1: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_6: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_6);  add_6 = None
            mul_8: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_11, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_12: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_12);  p_model_layers_0_post_attention_layernorm_weight = to_12 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_up_proj_weight);  mul_9 = p_model_layers_0_mlp_gate_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:69 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
            chunk = torch.ops.aten.chunk.default(linear_2, 2, -1);  linear_2 = None
            getitem_10: "f32[2, 3, 8192]" = chunk[0]
            getitem_11: "f32[2, 3, 8192]" = chunk[1];  chunk = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_10);  getitem_10 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:70 in forward, code: up_states = up_states * self.activation_fn(gate)
            mul_10: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_11, silu);  getitem_11 = silu = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_1: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_3, 0.0, False);  linear_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:315 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
            add_7: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_11, dropout_1);  to_11 = dropout_1 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_13: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
            mean_2: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_11: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_13, rsqrt_2);  rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_14: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_14);  p_model_layers_1_input_layernorm_weight = to_14 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_qkv_proj_weight);  mul_12 = p_model_layers_1_self_attn_qkv_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: query_states = qkv[..., :query_pos]
            slice_32: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 0, 3072)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:185 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
            slice_33: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 3072, 6144)

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:186 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
            slice_34: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 6144, 9223372036854775807);  linear_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:188 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
            view_3: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_32, [2, 3, -1, 96]);  slice_32 = None
            transpose_5: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_3, 1, 2);  view_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:189 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
            view_4: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_33, [2, 3, -1, 96]);  slice_33 = None
            transpose_6: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_4, 1, 2);  view_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:190 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
            view_5: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_34, [2, 3, -1, 96]);  slice_34 = None
            transpose_7: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_5, 1, 2);  view_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:193 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_10: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            unsqueeze_11: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            alias_2: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_5)
            slice_35: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 96, 9223372036854775807);  transpose_5 = None
            alias_3: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_6)
            slice_36: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 96, 9223372036854775807);  transpose_6 = None
            mul_13: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_2, unsqueeze_10)
            slice_37: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_2, 3, 0, 48)
            slice_38: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_2, 3, 48, 9223372036854775807);  alias_2 = None
            neg_2: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_38);  slice_38 = None
            cat_7: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_2, slice_37], -1);  neg_2 = slice_37 = None
            mul_14: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_7, unsqueeze_11);  cat_7 = None
            add_9: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_13, mul_14);  mul_13 = mul_14 = None
            cat_8: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_9, slice_35], -1);  add_9 = slice_35 = None
            mul_15: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_3, unsqueeze_10);  unsqueeze_10 = None
            slice_39: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 0, 48)
            slice_40: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 48, 9223372036854775807);  alias_3 = None
            neg_3: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_40);  slice_40 = None
            cat_9: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_3, slice_39], -1);  neg_3 = slice_39 = None
            mul_16: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_9, unsqueeze_11);  cat_9 = unsqueeze_11 = None
            add_10: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_15, mul_16);  mul_15 = mul_16 = None
            cat_10: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_10, slice_36], -1);  add_10 = slice_36 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_11: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, cat_10], -2);  past_key_values_key_cache_1 = cat_10 = None
            cat_12: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2);  past_key_values_value_cache_1 = transpose_7 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output, attn_weights = attention_interface(
            slice_41: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone);  clone = None
            slice_42: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_41, 1);  slice_41 = None
            slice_43: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_42, 2);  slice_42 = None
            slice_44: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_43, 3, None, 33);  slice_43 = None
            scaled_dot_product_attention_1: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_8, cat_11, cat_12, slice_44, scale = 0.10206207261596575);  cat_8 = slice_44 = None
            transpose_8: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2);  scaled_dot_product_attention_1 = None
            contiguous_1: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_8);  transpose_8 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:222 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous_1, [2, 3, -1]);  contiguous_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_1_self_attn_o_proj_weight);  reshape_3 = p_model_layers_1_self_attn_o_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_2: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_5, 0.0, False);  linear_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:310 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama
            add_11: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_13, dropout_2);  to_13 = dropout_2 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_15: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_11, torch.float32);  add_11 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_4: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
            mean_3: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True);  pow_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_12: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05);  mean_3 = None
            rsqrt_3: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_12);  add_12 = None
            mul_17: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_3);  rsqrt_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_17, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_16: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_17, torch.float32);  mul_17 = None
            mul_18: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_16);  p_model_layers_1_post_attention_layernorm_weight = to_16 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_gate_up_proj_weight);  mul_18 = p_model_layers_1_mlp_gate_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:69 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
            chunk_1 = torch.ops.aten.chunk.default(linear_6, 2, -1);  linear_6 = None
            getitem_12: "f32[2, 3, 8192]" = chunk_1[0]
            getitem_13: "f32[2, 3, 8192]" = chunk_1[1];  chunk_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
            silu_1: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_12);  getitem_12 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:70 in forward, code: up_states = up_states * self.activation_fn(gate)
            mul_19: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_13, silu_1);  getitem_13 = silu_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_19, p_model_layers_1_mlp_down_proj_weight);  mul_19 = p_model_layers_1_mlp_down_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
            dropout_3: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_7, 0.0, False);  linear_7 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:315 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama
            add_13: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_15, dropout_3);  to_15 = dropout_3 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_16 = None
            to_17: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_13, torch.float32);  add_13 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_5: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_17, 2)
            mean_4: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True);  pow_5 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_14: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05);  mean_4 = None
            rsqrt_4: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_14);  add_14 = None
            mul_20: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_17, rsqrt_4);  to_17 = rsqrt_4 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_20, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_17 = None
            to_18: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_20, torch.float32);  mul_20 = None
            mul_21: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_18);  p_model_norm_weight = to_18 = None

             # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:760 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_45: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(mul_21);  mul_21 = None
            slice_46: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(slice_45, 1, 0);  slice_45 = None
            slice_47: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(slice_46, 2);  slice_46 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_8: "f32[2, 3, 32064]" = torch.ops.aten.linear.default(slice_47, p_lm_head_weight);  slice_47 = p_lm_head_weight = None
            return (linear_8, cat_5, cat_11, cat_6, cat_12)

        class submod_1(torch.nn.Module):
            def forward(self, unsqueeze: "i64[1, 3]", c_model_rotary_emb_lifted_tensor_0: "f32[48]"):
                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:468 in forward, code: position_embeddings = self.rotary_emb(hidden_states, position_ids)
                max_1: "i64[]" = torch.ops.aten.max.default(unsqueeze)
                add_1: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1);  max_1 = None
                gt_1: "b8[]" = torch.ops.aten.gt.Scalar(add_1, 4096);  add_1 = None
                ne: "b8[]" = torch.ops.aten.ne.Scalar(gt_1, 0);  gt_1 = None
                item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
                to_1: "f32[48]" = torch.ops.aten.to.dtype_layout(c_model_rotary_emb_lifted_tensor_0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  c_model_rotary_emb_lifted_tensor_0 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:375 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(to_1, 0);  to_1 = None
                slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807);  unsqueeze_5 = None
                unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2);  slice_16 = None
                _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
                to_2: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32);  unsqueeze_6 = None
                expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_2, [1, -1, 1]);  to_2 = None
                _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
                to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:376 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_17: "i64[1, 3]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807);  unsqueeze = None
                unsqueeze_7: "i64[1, 1, 3]" = torch.ops.aten.unsqueeze.default(slice_17, 1);  slice_17 = None
                slice_18: "i64[1, 1, 3]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807);  unsqueeze_7 = None
                _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                to_4: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(slice_18, torch.float32);  slice_18 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_3, to_4);  submod_3 = to_3 = to_4 = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:382 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[1, 3, 96]" = wrap_with_autocast[0]

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:383 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[1, 3, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:385 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_7: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                to_8: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_7, to_8)

            class submod_1(torch.nn.Module):
                def forward(self, to_3: "f32[1, 48, 1]", to_4: "f32[1, 1, 3]"):
                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:380 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                    to_5: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(to_3, torch.float32);  to_3 = None
                    _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                    to_6: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    matmul: "f32[1, 48, 3]" = torch.ops.aten.matmul.default(to_5, to_6);  to_5 = to_6 = None
                    transpose: "f32[1, 3, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:381 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[1, 3, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:382 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[1, 3, 96]" = torch.ops.aten.cos.default(cat)
                    mul: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(cos, 1.1902380714238083);  cos = None

                     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:383 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[1, 3, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_1: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(sin, 1.1902380714238083);  sin = None
                    return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_self_attn_qkv_proj_weight: PARAMETER target='model.layers.0.self_attn.qkv_proj.weight'
    p_model_layers_0_mlp_gate_up_proj_weight: PARAMETER target='model.layers.0.mlp.gate_up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
    p_model_layers_1_self_attn_qkv_proj_weight: PARAMETER target='model.layers.1.self_attn.qkv_proj.weight'
    p_model_layers_1_mlp_gate_up_proj_weight: PARAMETER target='model.layers.1.mlp.gate_up_proj.weight'
    p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
    p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
    p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    c_model_rotary_emb_lifted_tensor_0: CONSTANT_TENSOR target='model.rotary_emb.lifted_tensor_0'
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_key_cache_1: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT
    past_key_values_value_cache_1: USER_INPUT

    # outputs
    linear_8: USER_OUTPUT
    cat_5: USER_OUTPUT
    cat_11: USER_OUTPUT
    cat_6: USER_OUTPUT
    cat_12: USER_OUTPUT

Range constraints: {u0: VR[0, 1]}

Total running time of the script: (0 minutes 15.820 seconds)

Related examples

Export Phi-3.5-mini-instruct with report_exportability

Export Phi-3.5-mini-instruct piece by piece

torch.onnx.export and Phi-2

to_onnx and Phi-2

Check the exporter on a dummy from HuggingFace

Gallery generated by Sphinx-Gallery