Export Phi-3.5-mini-instruct piece by piece

torch.export.export() often breaks on big models because control flow or certain instructions interrupt the propagation of dynamic shapes (see …). The function usually gives an indication of where the model implementation can be fixed, but when that is not possible, we can try to export the model piece by piece: every module is converted separately from its submodules. A model can then be exported even if one of its submodules cannot.
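
Before going piece by piece, a direct call to torch.export.export() shows the kind of failure we are trying to localize. The snippet below is only a minimal sketch, using the model and inputs built in the next section; strict=False merely relaxes the tracing mode.

# Hypothetical direct export attempt: on large models it often raises
# because of control flow or unsupported instructions. The piece-by-piece
# approach below narrows such a failure down to a single submodule.
try:
    ep = torch.export.export(model, (), kwargs=inputs, strict=False)
    print("direct export succeeded:", type(ep))
except Exception as e:
    print("direct export failed:", e)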

Model

import pprint
from typing import Any, Dict
import torch
import torch._export.tools
import transformers
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from experimental_experiment.helpers import string_type
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
)


def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets an uninitialized model along with two sets of inputs of different shapes.

    :param batch_size: batch size
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary

    See `Phi-3.5-mini-instruct/config.json
    <https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_phi3.Phi3Config",
            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
        },
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": [
                1.0800000429153442,
                1.1100000143051147,
                1.1399999856948853,
                1.340000033378601,
                1.5899999141693115,
                1.600000023841858,
                1.6200000047683716,
                2.620000123977661,
                3.2300000190734863,
                3.2300000190734863,
                4.789999961853027,
                7.400000095367432,
                7.700000286102295,
                9.09000015258789,
                12.199999809265137,
                17.670000076293945,
                24.46000099182129,
                28.57000160217285,
                30.420001983642578,
                30.840002059936523,
                32.590003967285156,
                32.93000411987305,
                42.320003509521484,
                44.96000289916992,
                50.340003967285156,
                50.45000457763672,
                57.55000305175781,
                57.93000411987305,
                58.21000289916992,
                60.1400032043457,
                62.61000442504883,
                62.62000274658203,
                62.71000289916992,
                63.1400032043457,
                63.1400032043457,
                63.77000427246094,
                63.93000411987305,
                63.96000289916992,
                63.970001220703125,
                64.02999877929688,
                64.06999969482422,
                64.08000183105469,
                64.12000274658203,
                64.41000366210938,
                64.4800033569336,
                64.51000213623047,
                64.52999877929688,
                64.83999633789062,
            ],
            "short_factor": [
                1.0,
                1.0199999809265137,
                1.0299999713897705,
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0699999332427979,
                1.0999999046325684,
                1.1099998950958252,
                1.1599998474121094,
                1.1599998474121094,
                1.1699998378753662,
                1.2899998426437378,
                1.339999794960022,
                1.679999828338623,
                1.7899998426437378,
                1.8199998140335083,
                1.8499997854232788,
                1.8799997568130493,
                1.9099997282028198,
                1.9399996995925903,
                1.9899996519088745,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0799996852874756,
                2.0899996757507324,
                2.189999580383301,
                2.2199995517730713,
                2.5899994373321533,
                2.729999542236328,
                2.749999523162842,
                2.8399994373321533,
            ],
            "type": "longrope",
        },
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "use_cache": True,
        "attention_bias": False,
        "vocab_size": 32064,
    }
    config.update(**kwargs)
    conf = transformers.Phi3Config(**config)
    model = transformers.Phi3ForCausalLM(conf)
    model.eval()

    cache = make_dynamic_cache(
        [
            (torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96))
            for i in range(config["num_hidden_layers"])
        ]
    )
    cache2 = make_dynamic_cache(
        [
            (torch.randn(batch_size + 1, 32, 31, 96), torch.randn(batch_size + 1, 32, 31, 96))
            for i in range(config["num_hidden_layers"])
        ]
    )

    inputs = dict(
        input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
        attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
        attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
        past_key_values=cache2,
    )
    return dict(inputs=inputs, model=model, inputs2=inputs2)


data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
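
In this notation produced by string_type, T<d> gives the dtype as an ONNX element-type code (T1 is float32, T7 is int64) and s gives the shape, so T7s2x3 is an int64 tensor of shape 2x3. In the traces below, r gives only the rank, and a leading C with bracketed values means the tensor content was inspected to report its minimum, maximum and average (A).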

Dynamic Shapes

We want to infer the dynamic shapes from the two sets of inputs we built. For that, we use a function that traces the execution of the model, including its submodules. It runs the model twice, once with each set of inputs, and stores every intermediate input and output.
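
The tracing call below is a sketch: the exact signature of trace_execution_piece_by_piece may differ, but it is assumed here to take the model and the list of input sets, with verbose controlling the amount of logging.

# Execute the model once per input set and record every submodule's
# inputs and outputs; `diag` is reused below to guess the dynamic shapes.
diag = trace_execution_piece_by_piece(model, [inputs, inputs2], verbose=2)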

[_trace_forward_execution] -trace-  M:__main__-Phi3ForCausalLM.forward
[_trace_forward_execution] -trace- .. M:model-Phi3Model.forward
[_trace_forward_execution] -trace- .... M:embed_tokens-Embedding.forward
[_trace_forward_execution] -trace- .... M:layers[0]-Phi3DecoderLayer.forward
[_trace_forward_execution] -trace- ...... M:self_attn-Phi3Attention.forward
[_trace_forward_execution] -trace- ........ M:o_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:qkv_proj-Linear.forward
[_trace_forward_execution] -trace- ...... M:mlp-Phi3MLP.forward
[_trace_forward_execution] -trace- ........ M:gate_up_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:down_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:activation_fn-SiLU.forward
[_trace_forward_execution] -trace- ...... M:input_layernorm-Phi3RMSNorm.forward
[_trace_forward_execution] -trace- ...... M:post_attention_layernorm-Phi3RMSNorm.forward
[_trace_forward_execution] -trace- ...... M:resid_attn_dropout-Dropout.forward
[_trace_forward_execution] -trace- ...... M:resid_mlp_dropout-Dropout.forward
[_trace_forward_execution] -trace- .... M:layers[1]-Phi3DecoderLayer.forward
[_trace_forward_execution] -trace- ...... M:self_attn-Phi3Attention.forward
[_trace_forward_execution] -trace- ........ M:o_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:qkv_proj-Linear.forward
[_trace_forward_execution] -trace- ...... M:mlp-Phi3MLP.forward
[_trace_forward_execution] -trace- ........ M:gate_up_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:down_proj-Linear.forward
[_trace_forward_execution] -trace- ........ M:activation_fn-SiLU.forward
[_trace_forward_execution] -trace- ...... M:input_layernorm-Phi3RMSNorm.forward
[_trace_forward_execution] -trace- ...... M:post_attention_layernorm-Phi3RMSNorm.forward
[_trace_forward_execution] -trace- ...... M:resid_attn_dropout-Dropout.forward
[_trace_forward_execution] -trace- ...... M:resid_mlp_dropout-Dropout.forward
[_trace_forward_execution] -trace- .... M:norm-Phi3RMSNorm.forward
[_trace_forward_execution] -trace- .... M:rotary_emb-Phi3RotaryEmbedding.forward
[_trace_forward_execution] -trace- .. M:lm_head-Linear.forward
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,cache_position:None)
[embed_tokens:Embedding]     > T7r2
[embed_tokens:Embedding]     < T1r3
[rotary_emb:Phi3RotaryEmbedding]     > *(T1r3,T7r2)
[rotary_emb:Phi3RotaryEmbedding]     < *(T1r3,T1r3)
[layers[0]:Phi3DecoderLayer]     > *(T1r3,), **dict(attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[input_layernorm:Phi3RMSNorm]       > T1r3
[input_layernorm:Phi3RMSNorm]       < T1r3
[self_attn:Phi3Attention]       > **dict(hidden_states:T1r3,attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[qkv_proj:Linear]         > T1r3
[qkv_proj:Linear]         < T1r3
[o_proj:Linear]         > T1r3
[o_proj:Linear]         < T1r3
[self_attn:Phi3Attention]       < *(T1r3,None)
[resid_attn_dropout:Dropout]       > T1r3
[resid_attn_dropout:Dropout]       < T1r3
[post_attention_layernorm:Phi3RMSNorm]       > T1r3
[post_attention_layernorm:Phi3RMSNorm]       < T1r3
[mlp:Phi3MLP]       > T1r3
[gate_up_proj:Linear]         > T1r3
[gate_up_proj:Linear]         < T1r3
[activation_fn:SiLU]         > T1r3
[activation_fn:SiLU]         < T1r3
[down_proj:Linear]         > T1r3
[down_proj:Linear]         < T1r3
[mlp:Phi3MLP]       < T1r3
[resid_mlp_dropout:Dropout]       > T1r3
[resid_mlp_dropout:Dropout]       < T1r3
[layers[0]:Phi3DecoderLayer]     < *(T1r3,)
[layers[1]:Phi3DecoderLayer]     > *(T1r3,), **dict(attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[input_layernorm:Phi3RMSNorm]       > T1r3
[input_layernorm:Phi3RMSNorm]       < T1r3
[self_attn:Phi3Attention]       > **dict(hidden_states:T1r3,attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[qkv_proj:Linear]         > T1r3
[qkv_proj:Linear]         < T1r3
[o_proj:Linear]         > T1r3
[o_proj:Linear]         < T1r3
[self_attn:Phi3Attention]       < *(T1r3,None)
[resid_attn_dropout:Dropout]       > T1r3
[resid_attn_dropout:Dropout]       < T1r3
[post_attention_layernorm:Phi3RMSNorm]       > T1r3
[post_attention_layernorm:Phi3RMSNorm]       < T1r3
[mlp:Phi3MLP]       > T1r3
[gate_up_proj:Linear]         > T1r3
[gate_up_proj:Linear]         < T1r3
[activation_fn:SiLU]         > T1r3
[activation_fn:SiLU]         < T1r3
[down_proj:Linear]         > T1r3
[down_proj:Linear]         < T1r3
[mlp:Phi3MLP]       < T1r3
[resid_mlp_dropout:Dropout]       > T1r3
[resid_mlp_dropout:Dropout]       < T1r3
[layers[1]:Phi3DecoderLayer]     < *(T1r3,)
[norm:Phi3RMSNorm]     > T1r3
[norm:Phi3RMSNorm]     < T1r3
[model:Phi3Model]   < *BaseModelOutputWithPast(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *CausalLMOutputWithPast(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s3x4,attention_mask:T7s3x35,past_key_values:DynamicCache(key_cache=#2[T1s3x32x31x96,T1s3x32x31x96], value_cache=#2[T1s3x32x31x96,T1s3x32x31x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,cache_position:None)
[embed_tokens:Embedding]     > T7r2
[embed_tokens:Embedding]     < T1r3
[rotary_emb:Phi3RotaryEmbedding]     > *(T1r3,T7r2)
[rotary_emb:Phi3RotaryEmbedding]     < *(T1r3,T1r3)
[layers[0]:Phi3DecoderLayer]     > *(T1r3,), **dict(attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[input_layernorm:Phi3RMSNorm]       > T1r3
[input_layernorm:Phi3RMSNorm]       < T1r3
[self_attn:Phi3Attention]       > **dict(hidden_states:T1r3,attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[qkv_proj:Linear]         > T1r3
[qkv_proj:Linear]         < T1r3
[o_proj:Linear]         > T1r3
[o_proj:Linear]         < T1r3
[self_attn:Phi3Attention]       < *(T1r3,None)
[resid_attn_dropout:Dropout]       > T1r3
[resid_attn_dropout:Dropout]       < T1r3
[post_attention_layernorm:Phi3RMSNorm]       > T1r3
[post_attention_layernorm:Phi3RMSNorm]       < T1r3
[mlp:Phi3MLP]       > T1r3
[gate_up_proj:Linear]         > T1r3
[gate_up_proj:Linear]         < T1r3
[activation_fn:SiLU]         > T1r3
[activation_fn:SiLU]         < T1r3
[down_proj:Linear]         > T1r3
[down_proj:Linear]         < T1r3
[mlp:Phi3MLP]       < T1r3
[resid_mlp_dropout:Dropout]       > T1r3
[resid_mlp_dropout:Dropout]       < T1r3
[layers[0]:Phi3DecoderLayer]     < *(T1r3,)
[layers[1]:Phi3DecoderLayer]     > *(T1r3,), **dict(attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[input_layernorm:Phi3RMSNorm]       > T1r3
[input_layernorm:Phi3RMSNorm]       < T1r3
[self_attn:Phi3Attention]       > **dict(hidden_states:T1r3,attention_mask:T1r4,position_ids:T7r2,past_key_value:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),output_attentions:bool,use_cache:bool,cache_position:T7r1,position_embeddings:(T1r3,T1r3))
[qkv_proj:Linear]         > T1r3
[qkv_proj:Linear]         < T1r3
[o_proj:Linear]         > T1r3
[o_proj:Linear]         < T1r3
[self_attn:Phi3Attention]       < *(T1r3,None)
[resid_attn_dropout:Dropout]       > T1r3
[resid_attn_dropout:Dropout]       < T1r3
[post_attention_layernorm:Phi3RMSNorm]       > T1r3
[post_attention_layernorm:Phi3RMSNorm]       < T1r3
[mlp:Phi3MLP]       > T1r3
[gate_up_proj:Linear]         > T1r3
[gate_up_proj:Linear]         < T1r3
[activation_fn:SiLU]         > T1r3
[activation_fn:SiLU]         < T1r3
[down_proj:Linear]         > T1r3
[down_proj:Linear]         < T1r3
[mlp:Phi3MLP]       < T1r3
[resid_mlp_dropout:Dropout]       > T1r3
[resid_mlp_dropout:Dropout]       < T1r3
[layers[1]:Phi3DecoderLayer]     < *(T1r3,)
[norm:Phi3RMSNorm]     > T1r3
[norm:Phi3RMSNorm]     < T1r3
[model:Phi3Model]   < *BaseModelOutputWithPast(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *CausalLMOutputWithPast(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_forward_execution] traced execution of model Phi3ForCausalLM
>>> __main__: Phi3ForCausalLM
  > ((),dict(input_ids:CT7s2x3[3876,15594:A9096.5],attention_mask:CT7s2x33[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.400763034820557,4.259133338928223:A-0.004826152082167923],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x30x96[-4.405541896820068,4.668274402618408:A0.002596911304140301],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]])))
  > ((),dict(input_ids:CT7s3x4[2392,29616:A20512.083333333332],attention_mask:CT7s3x35[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.828048229217529,4.243164539337158:A0.00031101864580231346],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x31x96[-4.349869728088379,4.253993034362793:A0.0011117361406822917],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]])))
    >>> model: Phi3Model
      > ((),dict(input_ids:CT7s2x3[3876,15594:A9096.5],attention_mask:CT7s2x33[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.400763034820557,4.259133338928223:A-0.004826152082167923],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x30x96[-4.405541896820068,4.668274402618408:A0.002596911304140301],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,cache_position:None))
      > ((),dict(input_ids:CT7s3x4[2392,29616:A20512.083333333332],attention_mask:CT7s3x35[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.828048229217529,4.243164539337158:A0.00031101864580231346],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x31x96[-4.349869728088379,4.253993034362793:A0.0011117361406822917],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,cache_position:None))
        >>> embed_tokens: Embedding
          > ((CT7s2x3[3876,15594:A9096.5],),{})
          > ((CT7s3x4[2392,29616:A20512.083333333332],),{})
          < (CT1s2x3x3072[-0.08262448012828827,0.07648692280054092:A0.00013174691782688906],)
          < (CT1s3x4x3072[-0.08454269170761108,0.07825997471809387:A-2.2594149214932857e-05],)
        <<<
        >>> layers[0]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-0.08262448012828827,0.07648692280054092:A0.00013174691782688906],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.400763034820557,4.259133338928223:A-0.004826152082167923],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x30x96[-4.405541896820068,4.668274402618408:A0.002596911304140301],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-0.08454269170761108,0.07825997471809387:A-2.2594149214932857e-05],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.828048229217529,4.243164539337158:A0.00031101864580231346],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x31x96[-4.349869728088379,4.253993034362793:A0.0011117361406822917],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-4.117458343505859,3.675187587738037:A0.006496929148884729],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.400763034820557,4.259133338928223:A-0.004826152082167923],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x30x96[-4.405541896820068,4.668274402618408:A0.002596911304140301],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-4.234438896179199,3.8390731811523438:A-0.0011698477338398423],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.828048229217529,4.243164539337158:A0.00031101864580231346],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x31x96[-4.349869728088379,4.253993034362793:A0.0011117361406822917],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-1.7804882526397705,1.7273778915405273:A0.003596998118933412],),{})
                  > ((CT1s3x4x3072[-3.1012768745422363,2.9759175777435303:A0.0021097844879451383],),{})
                  < (CT1s2x3x3072[-1.6213274002075195,1.6112512350082397:A0.0015135138111317145],)
                  < (CT1s3x4x3072[-1.6543391942977905,1.7866795063018799:A-0.0027640321040425736],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-4.117458343505859,3.675187587738037:A0.006496929148884729],),{})
                  > ((CT1s3x4x3072[-4.234438896179199,3.8390731811523438:A-0.0011698477338398423],),{})
                  < (CT1s2x3x9216[-4.603874683380127,4.854994297027588:A-0.007235225691783011],)
                  < (CT1s3x4x9216[-5.057411193847656,4.701048374176025:A-0.0005786661612424487],)
                <<<
              < (CT1s2x3x3072[-1.6213274002075195,1.6112512350082397:A0.0015135138111317145],None)
              < (CT1s3x4x3072[-1.6543391942977905,1.7866795063018799:A-0.0027640321040425736],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-4.074462413787842,4.080641269683838:A0.004271472019650317],),{})
              > ((CT1s3x4x3072[-4.160494804382324,4.450273036956787:A-0.007759615481418791],),{})
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-4.074462413787842,4.080641269683838:A0.004271472019650317],),{})
                  > ((CT1s3x4x3072[-4.160494804382324,4.450273036956787:A-0.007759615481418791],),{})
                  < (CT1s2x3x16384[-4.665277481079102,4.764751434326172:A-0.003516147276544738],)
                  < (CT1s3x4x16384[-4.765247821807861,4.7090325355529785:A-0.0005196295543422972],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-12.33570671081543,11.671303749084473:A-0.0007140373650187307],),{})
                  > ((CT1s3x4x8192[-9.95230770111084,10.653127670288086:A-0.0001526525493489491],),{})
                  < (CT1s2x3x3072[-5.095609188079834,5.600772857666016:A0.017001853358547377],)
                  < (CT1s3x4x3072[-5.286900043487549,6.241175174713135:A0.00745626809162382],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.665277481079102,4.764751434326172:A-0.0084779705295972],),{})
                  > ((CT1s3x4x8192[-4.765247821807861,4.606565952301025:A0.002571067230685268],),{})
                  < (CT1s2x3x8192[-0.27846455574035645,4.724475383758545:A0.2441930089948916],)
                  < (CT1s3x4x8192[-0.27846455574035645,4.561019420623779:A0.24662724596589647],)
                <<<
              < (CT1s2x3x3072[-5.095609188079834,5.600772857666016:A0.017001853358547377],)
              < (CT1s3x4x3072[-5.286900043487549,6.241175174713135:A0.00745626809162382],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-0.08262448012828827,0.07648692280054092:A0.00013174691782688906],),{})
              > ((CT1s3x4x3072[-0.08454269170761108,0.07825997471809387:A-2.2594149214932857e-05],),{})
              < (CT1s2x3x3072[-4.117458343505859,3.675187587738037:A0.006496929148884729],)
              < (CT1s3x4x3072[-4.234438896179199,3.8390731811523438:A-0.0011698477338398423],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-1.59036123752594,1.5879346132278442:A0.0016452607151254345],),{})
              > ((CT1s3x4x3072[-1.6508663892745972,1.791300892829895:A-0.002786626275726197],),{})
              < (CT1s2x3x3072[-4.074462413787842,4.080641269683838:A0.004271472019650317],)
              < (CT1s3x4x3072[-4.160494804382324,4.450273036956787:A-0.007759615481418791],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.6213274002075195,1.6112512350082397:A0.0015135138111317145],),{})
              > ((CT1s3x4x3072[-1.6543391942977905,1.7866795063018799:A-0.0027640321040425736],),{})
              < (CT1s2x3x3072[-1.6213274002075195,1.6112512350082397:A0.0015135138111317145],)
              < (CT1s3x4x3072[-1.6543391942977905,1.7866795063018799:A-0.0027640321040425736],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.095609188079834,5.600772857666016:A0.017001853358547377],),{})
              > ((CT1s3x4x3072[-5.286900043487549,6.241175174713135:A0.00745626809162382],),{})
              < (CT1s2x3x3072[-5.095609188079834,5.600772857666016:A0.017001853358547377],)
              < (CT1s3x4x3072[-5.286900043487549,6.241175174713135:A0.00745626809162382],)
            <<<
          < (CT1s2x3x3072[-5.606287479400635,5.762608051300049:A0.018647113798023283],)
          < (CT1s3x4x3072[-5.794071197509766,5.967385292053223:A0.004669641863668201],)
        <<<
        >>> layers[1]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-5.606287479400635,5.762608051300049:A0.018647113798023283],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.061519622802734,4.763699054718018:A-0.005803250564235515],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x33x96[-4.603874683380127,4.668274402618408:A0.0014253464397422568],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-5.794071197509766,5.967385292053223:A0.004669641863668201],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.101267337799072,5.592326641082764:A0.0004619398607301088],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x35x96[-5.057411193847656,4.3394317626953125:A0.00042293995440766975],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-4.0454277992248535,4.159030914306641:A0.013164340832402624],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.061519622802734,4.763699054718018:A-0.005803250564235515],CT1s2x32x30x96[-4.0469818115234375,4.471982479095459:A0.002217433422628052]], value_cache=#2[CT1s2x32x33x96[-4.603874683380127,4.668274402618408:A0.0014253464397422568],CT1s2x32x30x96[-4.310667037963867,4.2944207191467285:A0.00028137129325799627]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-3.9066245555877686,4.0234808921813965:A0.003327302300785028],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.101267337799072,5.592326641082764:A0.0004619398607301088],CT1s3x32x31x96[-4.626999855041504,4.47794246673584:A-0.0025560081888748925]], value_cache=#2[CT1s3x32x35x96[-5.057411193847656,4.3394317626953125:A0.00042293995440766975],CT1s3x32x31x96[-4.810003280639648,4.869333267211914:A-0.00027058590800271105]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-2.426612615585327,2.5618062019348145:A0.0013105594592492447],),{})
                  > ((CT1s3x4x3072[-2.194685935974121,1.6999558210372925:A0.0006558623772832772],),{})
                  < (CT1s2x3x3072[-1.6817601919174194,1.8244985342025757:A0.007823421057183724],)
                  < (CT1s3x4x3072[-1.51430082321167,1.9903841018676758:A-0.0005759481610236131],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-4.0454277992248535,4.159030914306641:A0.013164340832402624],),{})
                  > ((CT1s3x4x3072[-3.9066245555877686,4.0234808921813965:A0.003327302300785028],),{})
                  < (CT1s2x3x9216[-4.424978733062744,4.548888206481934:A0.007687200885247539],)
                  < (CT1s3x4x9216[-5.294676780700684,4.829901695251465:A0.0011063482611651427],)
                <<<
              < (CT1s2x3x3072[-1.6817601919174194,1.8244985342025757:A0.007823421057183724],None)
              < (CT1s3x4x3072[-1.51430082321167,1.9903841018676758:A-0.0005759481610236131],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-3.9352612495422363,3.729083776473999:A0.018089336692754848],),{})
              > ((CT1s3x4x3072[-3.7487435340881348,4.227113723754883:A0.002809795542497976],),{})
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-3.9352612495422363,3.729083776473999:A0.018089336692754848],),{})
                  > ((CT1s3x4x3072[-3.7487435340881348,4.227113723754883:A0.002809795542497976],),{})
                  < (CT1s2x3x16384[-4.6556878089904785,4.95295524597168:A0.005004160443962273],)
                  < (CT1s3x4x16384[-4.945455551147461,5.058211326599121:A-0.0019263709165381708],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-8.74767017364502,9.395020484924316:A0.0006071647130518922],),{})
                  > ((CT1s3x4x8192[-13.272783279418945,10.04904842376709:A-0.0015269200680454332],),{})
                  < (CT1s2x3x3072[-5.51132869720459,5.5960211753845215:A-0.0008730196218114846],)
                  < (CT1s3x4x3072[-5.300881862640381,5.086551189422607:A-0.002262299616445615],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.588644027709961,4.672793388366699:A0.0015853615327661903],),{})
                  > ((CT1s3x4x8192[-4.798925399780273,5.058211326599121:A-0.0037998580466028407],),{})
                  < (CT1s2x3x8192[-0.27846455574035645,4.629525184631348:A0.24538704169494321],)
                  < (CT1s3x4x8192[-0.27846455574035645,5.026259422302246:A0.24493035986874703],)
                <<<
              < (CT1s2x3x3072[-5.51132869720459,5.5960211753845215:A-0.0008730196218114846],)
              < (CT1s3x4x3072[-5.300881862640381,5.086551189422607:A-0.002262299616445615],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.606287479400635,5.762608051300049:A0.018647113798023283],),{})
              > ((CT1s3x4x3072[-5.794071197509766,5.967385292053223:A0.004669641863668201],),{})
              < (CT1s2x3x3072[-4.0454277992248535,4.159030914306641:A0.013164340832402624],)
              < (CT1s3x4x3072[-3.9066245555877686,4.0234808921813965:A0.003327302300785028],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.665860176086426,5.33676815032959:A0.02647053424324339],),{})
              > ((CT1s3x4x3072[-5.680884838104248,6.474119186401367:A0.004093693750544642],),{})
              < (CT1s2x3x3072[-3.9352612495422363,3.729083776473999:A0.018089336692754848],)
              < (CT1s3x4x3072[-3.7487435340881348,4.227113723754883:A0.002809795542497976],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.6817601919174194,1.8244985342025757:A0.007823421057183724],),{})
              > ((CT1s3x4x3072[-1.51430082321167,1.9903841018676758:A-0.0005759481610236131],),{})
              < (CT1s2x3x3072[-1.6817601919174194,1.8244985342025757:A0.007823421057183724],)
              < (CT1s3x4x3072[-1.51430082321167,1.9903841018676758:A-0.0005759481610236131],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.51132869720459,5.5960211753845215:A-0.0008730196218114846],),{})
              > ((CT1s3x4x3072[-5.300881862640381,5.086551189422607:A-0.002262299616445615],),{})
              < (CT1s2x3x3072[-5.51132869720459,5.5960211753845215:A-0.0008730196218114846],)
              < (CT1s3x4x3072[-5.300881862640381,5.086551189422607:A-0.002262299616445615],)
            <<<
          < (CT1s2x3x3072[-7.9020490646362305,8.29597282409668:A0.025597514293015895],)
          < (CT1s3x4x3072[-8.080930709838867,8.061180114746094:A0.0018313946399454533],)
        <<<
        >>> norm: Phi3RMSNorm
          > ((CT1s2x3x3072[-7.9020490646362305,8.29597282409668:A0.025597514293015895],),{})
          > ((CT1s3x4x3072[-8.080930709838867,8.061180114746094:A0.0018313946399454533],),{})
          < (CT1s2x3x3072[-4.05719518661499,4.155795574188232:A0.012949258255610453],)
          < (CT1s3x4x3072[-3.9973807334899902,3.9562745094299316:A0.0008511715070787332],)
        <<<
        >>> rotary_emb: Phi3RotaryEmbedding
          > ((CT1s2x3x3072[-0.08262448012828827,0.07648692280054092:A0.00013174691782688906],CT7s1x3[30,32:A31.0]),{})
          > ((CT1s3x4x3072[-0.08454269170761108,0.07825997471809387:A-2.2594149214932857e-05],CT7s1x4[31,34:A32.5]),{})
          < (CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])
          < (CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])
        <<<
      < (dict(last_hidden_state:CT1s2x3x3072[-4.05719518661499,4.155795574188232:A0.012949258255610453],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.061519622802734,4.763699054718018:A-0.005803250564235515],CT1s2x32x33x96[-5.262935161590576,5.203064441680908:A0.001793874462405861]], value_cache=#2[CT1s2x32x33x96[-4.603874683380127,4.668274402618408:A0.0014253464397422568],CT1s2x32x33x96[-4.3247175216674805,4.2944207191467285:A0.0008252018333532062]])),)
      < (dict(last_hidden_state:CT1s3x4x3072[-3.9973807334899902,3.9562745094299316:A0.0008511715070787332],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.101267337799072,5.592326641082764:A0.0004619398607301088],CT1s3x32x35x96[-6.038402080535889,5.146603584289551:A-0.0029978743856185967]], value_cache=#2[CT1s3x32x35x96[-5.057411193847656,4.3394317626953125:A0.00042293995440766975],CT1s3x32x35x96[-5.294676780700684,4.869333267211914:A3.3132155645670344e-05]])),)
    <<<
    >>> lm_head: Linear
      > ((CT1s2x3x3072[-4.05719518661499,4.155795574188232:A0.012949258255610453],),{})
      > ((CT1s3x4x3072[-3.9973807334899902,3.9562745094299316:A0.0008511715070787332],),{})
      < (CT1s2x3x32064[-4.982978820800781,4.97531270980835:A5.55359560734437e-05],)
      < (CT1s3x4x32064[-5.064220428466797,5.067513465881348:A-0.0019615298296192804],)
    <<<
  < (dict(logits:CT1s2x3x32064[-4.982978820800781,4.97531270980835:A5.55359560734437e-05],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.061519622802734,4.763699054718018:A-0.005803250564235515],CT1s2x32x33x96[-5.262935161590576,5.203064441680908:A0.001793874462405861]], value_cache=#2[CT1s2x32x33x96[-4.603874683380127,4.668274402618408:A0.0014253464397422568],CT1s2x32x33x96[-4.3247175216674805,4.2944207191467285:A0.0008252018333532062]])),)
  < (dict(logits:CT1s3x4x32064[-5.064220428466797,5.067513465881348:A-0.0019615298296192804],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.101267337799072,5.592326641082764:A0.0004619398607301088],CT1s3x32x35x96[-6.038402080535889,5.146603584289551:A-0.0029978743856185967]], value_cache=#2[CT1s3x32x35x96[-5.057411193847656,4.3394317626953125:A0.00042293995440766975],CT1s3x32x35x96[-5.294676780700684,4.869333267211914:A3.3132155645670344e-05]])),)
<<<
[_untrace_forward_execution]  M:__main__-Phi3ForCausalLM
[_untrace_forward_execution] .. M:model-Phi3Model
[_untrace_forward_execution] .... M:embed_tokens-Embedding
[_untrace_forward_execution] .... M:layers[0]-Phi3DecoderLayer
[_untrace_forward_execution] ...... M:self_attn-Phi3Attention
[_untrace_forward_execution] ........ M:o_proj-Linear
[_untrace_forward_execution] ........ M:qkv_proj-Linear
[_untrace_forward_execution] ...... M:mlp-Phi3MLP
[_untrace_forward_execution] ........ M:gate_up_proj-Linear
[_untrace_forward_execution] ........ M:down_proj-Linear
[_untrace_forward_execution] ........ M:activation_fn-SiLU
[_untrace_forward_execution] ...... M:input_layernorm-Phi3RMSNorm
[_untrace_forward_execution] ...... M:post_attention_layernorm-Phi3RMSNorm
[_untrace_forward_execution] ...... M:resid_attn_dropout-Dropout
[_untrace_forward_execution] ...... M:resid_mlp_dropout-Dropout
[_untrace_forward_execution] .... M:layers[1]-Phi3DecoderLayer
[_untrace_forward_execution] ...... M:self_attn-Phi3Attention
[_untrace_forward_execution] ........ M:o_proj-Linear
[_untrace_forward_execution] ........ M:qkv_proj-Linear
[_untrace_forward_execution] ...... M:mlp-Phi3MLP
[_untrace_forward_execution] ........ M:gate_up_proj-Linear
[_untrace_forward_execution] ........ M:down_proj-Linear
[_untrace_forward_execution] ........ M:activation_fn-SiLU
[_untrace_forward_execution] ...... M:input_layernorm-Phi3RMSNorm
[_untrace_forward_execution] ...... M:post_attention_layernorm-Phi3RMSNorm
[_untrace_forward_execution] ...... M:resid_attn_dropout-Dropout
[_untrace_forward_execution] ...... M:resid_mlp_dropout-Dropout
[_untrace_forward_execution] .... M:norm-Phi3RMSNorm
[_untrace_forward_execution] .... M:rotary_emb-Phi3RotaryEmbedding
[_untrace_forward_execution] .. M:lm_head-Linear

Now that every input and output of the submodules is kept in memory, we can guess the dynamic shapes for each of them. The final ones for the whole model follow.
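
The shapes printed below come from a call like the following sketch; the method name guess_dynamic_shapes on the traced object is an assumption based on the diagnostic API.

# An axis whose size differs between the two recorded executions
# is marked as dynamic; all the others are kept static.
ds = diag.guess_dynamic_shapes()
print("The dynamic shapes are:")
pprint.pprint(ds)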

The dynamic shapes are:
((),
 {'attention_mask': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                 min=None,
                                 max=None,
                                 _factory=True),
                     1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                 min=None,
                                 max=None,
                                 _factory=True)},
  'input_ids': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                            min=None,
                            max=None,
                            _factory=True),
                1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                            min=None,
                            max=None,
                            _factory=True)},
  'past_key_values': [[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                        2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True)},
                       {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                        2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True)}],
                      [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                        2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True)},
                       {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True),
                        2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                    min=None,
                                    max=None,
                                    _factory=True)}]]})

And here are the dynamic shapes along all the traced submodules:

print(
    diag.pretty_text(
        with_dynamic_shape=True,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    ).replace("<_DimHint.DYNAMIC: 3>", "DYN")
)
>>> __main__: Phi3ForCausalLM
  DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'input_ids': {0: DYN, 1: DYN}, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]]})
    >>> model: Phi3Model
      DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'cache_position': None, 'input_ids': {0: DYN, 1: DYN}, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_ids': None, 'use_cache': None})
        >>> embed_tokens: Embedding: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> layers[0]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> norm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> rotary_emb: Phi3RotaryEmbedding: DS=(({0: DYN, 1: DYN}, {1: DYN}), {}) <<<
    <<<
    >>> lm_head: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<

Evaluate the export

In many cases, the export (to torch.fx.Graph, to ONNX) does not work on the first try. We need a way to measure how much of the model can be exported, and therefore how much code needs to be rewritten or patched to make it exportable. The verbosity can be increased to show dynamic shapes and the discrepancies found on the results. Let's display the module and its submodules first.

print(
    diag.pretty_text(
        with_dynamic_shape=False,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    )
)
>>> __main__: Phi3ForCausalLM
    >>> model: Phi3Model
        >>> embed_tokens: Embedding <<<
        >>> layers[0]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> norm: Phi3RMSNorm <<<
        >>> rotary_emb: Phi3RotaryEmbedding <<<
    <<<
    >>> lm_head: Linear <<<
<<<

Then we try to export, to see which submodule makes the whole model fail. We can pickle the failing piece and restore it to speed up the refactoring needed to make it work, as sketched below.
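
Here is a minimal, hypothetical sketch of that pickling workflow (the file name failing_piece.pt and the stand-in Linear module are illustrative, not part of the tool's API): any torch.nn.Module can be serialized with torch.save, so the failing piece and a set of captured inputs can be stored once and reloaded in a fresh session.

import torch

# Hypothetical sketch: store the submodule reported as failing together
# with one set of captured inputs, then reload it in a fresh session to
# iterate on a fix without rebuilding the whole model.
failing_piece = torch.nn.Linear(8, 8)  # stands in for the failing submodule
sample_inputs = (torch.randn(2, 8),)   # inputs captured for that submodule

torch.save({"module": failing_piece, "inputs": sample_inputs}, "failing_piece.pt")

# In a later session: reload and retry the export in isolation.
saved = torch.load("failing_piece.pt", weights_only=False)
ep = torch.export.export(saved["module"], saved["inputs"])
print(ep.graph)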

print("----------------------")
ep = diag.try_export(
    exporter="fx",
    use_dynamic_shapes=True,
    exporter_kwargs=dict(strict=False),
    verbose=1,
)
----------------------




def forward(self, arg0_1: "f32[32064, 3072]", arg1_1: "f32[3072, 3072]", arg2_1: "f32[9216, 3072]", arg3_1: "f32[16384, 3072]", arg4_1: "f32[3072, 8192]", arg5_1: "f32[3072]", arg6_1: "f32[3072]", arg7_1: "f32[3072, 3072]", arg8_1: "f32[9216, 3072]", arg9_1: "f32[16384, 3072]", arg10_1: "f32[3072, 8192]", arg11_1: "f32[3072]", arg12_1: "f32[3072]", arg13_1: "f32[3072]", arg14_1: "f32[32064, 3072]", arg15_1: "f32[48]", arg16_1: "i64[s47, s2]", arg17_1: "i64[s47, s10]", arg18_1: "f32[s89, 32, s66, 96]", arg19_1: "f32[s14, 32, s80, 96]", arg20_1: "f32[s62, 32, s96, 96]", arg21_1: "f32[s58, 32, s81, 96]"):
     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
    embedding: "f32[s47, s2, 3072]" = torch.ops.aten.embedding.default(arg0_1, arg16_1, 32000);  arg0_1 = embedding = None

     # File: ~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:435 in __bool__, code: return builtins.bool(self != 0)
    sym_numel_default: "Sym(3072*s66*s89)" = torch.ops.aten.sym_numel.default(arg18_1)
    ne: "Sym(True)" = sym_numel_default != 0;  ne = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:455 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    sym_size_int: "Sym(s66)" = torch.ops.aten.sym_size.int(arg18_1, 2);  arg18_1 = None
    sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(arg16_1, 1)
    add: "Sym(s2 + s66)" = sym_size_int + sym_size_int_1

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:454 in forward, code: cache_position = torch.arange(
    arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False);  add = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:459 in forward, code: position_ids = cache_position.unsqueeze(0)
    unsqueeze: "i64[1, s2]" = torch.ops.aten.unsqueeze.default(arange, 0)

     # File: ~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:435 in __bool__, code: return builtins.bool(self != 0)
    ne_1: "Sym(True)" = sym_numel_default != 0;  sym_numel_default = ne_1 = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:461 in forward, code: causal_mask = self._update_causal_mask(
    add_1: "Sym(s2 + s66)" = sym_size_int_1 + sym_size_int;  sym_size_int = None
    lt: "Sym(s2 + s66 < 262144)" = add_1 < 262144;  add_1 = lt = None
    sym_size_int_2: "Sym(s10)" = torch.ops.aten.sym_size.int(arg17_1, 1)
    full: "f32[s2, s10]" = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    arange_1: "i64[s10]" = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False)
    reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1])
    gt: "b8[s2, s10]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
    arange_2: "i64[s10]" = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False)
    reshape_1: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
    sub: "i64[s2, 1]" = torch.ops.aten.sub.Tensor(reshape_1, 262144);  reshape_1 = None
    le: "b8[s2, s10]" = torch.ops.aten.le.Tensor(arange_2, sub);  arange_2 = sub = None
    bitwise_or_: "b8[s2, s10]" = torch.ops.aten.bitwise_or_.Tensor(gt, le);  gt = le = None
    mul_: "f32[s2, s10]" = torch.ops.aten.mul_.Tensor(full, bitwise_or_);  full = bitwise_or_ = None
    unsqueeze_1: "f32[1, s2, s10]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
    unsqueeze_2: "f32[1, 1, s2, s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
    eq: "Sym(Eq(s2, 9223372036854775807))" = sym_size_int_1 == 9223372036854775807;  sym_size_int_1 = eq = None
    slice_1: "f32[1, 1, s2, s10]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
    eq_1: "Sym(Eq(s10, 9223372036854775807))" = sym_size_int_2 == 9223372036854775807;  eq_1 = None
    slice_2: "f32[1, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807)
    sym_size_int_3: "Sym(s47)" = torch.ops.aten.sym_size.int(arg16_1, 0);  arg16_1 = None
    expand: "f32[s47, 1, s2, s10]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_3, 1, -1, -1])
    clone: "f32[s47, 1, s2, s10]" = torch.ops.aten.clone.default(expand);  expand = None
    gt_1: "Sym(False)" = sym_size_int_2 > sym_size_int_2;  gt_1 = None
    slice_3: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone)
    slice_4: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_3, 1)
    slice_5: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_4, 2);  slice_4 = None
    slice_6: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, sym_size_int_2)
    sym_size_int_4: "Sym(s47)" = torch.ops.aten.sym_size.int(arg17_1, 0)
    eq_2: "Sym(Eq(s47, 9223372036854775807))" = sym_size_int_4 == 9223372036854775807;  sym_size_int_4 = eq_2 = None
    slice_7: "i64[s47, s10]" = torch.ops.aten.slice.Tensor(arg17_1, 0, 0, 9223372036854775807);  arg17_1 = None
    unsqueeze_3: "i64[s47, 1, s10]" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
    unsqueeze_4: "i64[s47, 1, 1, s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2);  unsqueeze_3 = None
    eq_3: "Sym(Eq(s10, 9223372036854775807))" = sym_size_int_2 == 9223372036854775807;  eq_3 = None
    slice_8: "i64[s47, 1, 1, s10]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807);  unsqueeze_4 = None
    to: "i64[s47, 1, 1, s10]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_8 = None
    add_2: "f32[s47, 1, s2, s10]" = torch.ops.aten.add.Tensor(slice_6, to);  to = None
    eq_4: "b8[s47, 1, s2, s10]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
    slice_9: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone)
    slice_10: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_9, 1);  slice_9 = None
    slice_11: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_10, 2);  slice_10 = None
    slice_12: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, sym_size_int_2);  slice_11 = None
    masked_fill: "f32[s47, 1, s2, s10]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_4, -3.4028234663852886e+38);  slice_12 = eq_4 = None
    eq_5: "Sym(Eq(s47, 9223372036854775807))" = sym_size_int_3 == 9223372036854775807;  sym_size_int_3 = eq_5 = None
    slice_13: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
    slice_14: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807)
    sym_size_int_5: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_1, 2);  slice_1 = None
    eq_6: "Sym(Eq(s2, 9223372036854775807))" = sym_size_int_5 == 9223372036854775807;  sym_size_int_5 = eq_6 = None
    slice_15: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807);  slice_14 = None
    sym_size_int_6: "Sym(s10)" = torch.ops.aten.sym_size.int(slice_2, 3);  slice_2 = None
    eq_7: "Sym(True)" = sym_size_int_6 == sym_size_int_2;  sym_size_int_2 = eq_7 = None
    sym_size_int_7: "Sym(s47)" = torch.ops.aten.sym_size.int(slice_13, 0);  slice_13 = None
    sym_size_int_8: "Sym(s47)" = torch.ops.aten.sym_size.int(slice_3, 0);  slice_3 = None
    eq_8: "Sym(True)" = sym_size_int_7 == sym_size_int_8;  sym_size_int_7 = sym_size_int_8 = eq_8 = None
    sym_size_int_9: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_15, 2)
    sym_size_int_10: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_5, 2);  slice_5 = None
    eq_9: "Sym(True)" = sym_size_int_9 == sym_size_int_10;  sym_size_int_9 = sym_size_int_10 = eq_9 = None
    sym_size_int_11: "Sym(s10)" = torch.ops.aten.sym_size.int(slice_6, 3);  slice_6 = None
    eq_10: "Sym(True)" = sym_size_int_6 == sym_size_int_11;  sym_size_int_6 = sym_size_int_11 = eq_10 = None
    copy_: "f32[s47, 1, s2, s10]" = torch.ops.aten.copy_.default(slice_15, masked_fill);  slice_15 = masked_fill = copy_ = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:468 in forward, code: position_embeddings = self.rotary_emb(hidden_states, position_ids)
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    max_1: "i64[]" = torch.ops.aten.max.default(unsqueeze);  unsqueeze = None
    add_3: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1);  max_1 = None
    gt_2: "b8[]" = torch.ops.aten.gt.Scalar(add_3, 4096);  add_3 = None
    ne_2: "b8[]" = torch.ops.aten.ne.Scalar(gt_2, 0);  gt_2 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne_2);  ne_2 = item = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None





[try_export-FX]  M:__main__-Phi3ForCausalLM --- FAIL, step=EXPORT, reason=Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none) ---  --- Caused by: (_export/non_strict_utils.py:973 in __torch_function__) --- For more information, run with TORCH_LOGS="dynamic" --- For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" --- If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 --- For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing ---  --- For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 ---  --- The following call raised this error: ---   File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 50, in longrope_frequency_update ---     if seq_len > original_max_position_embeddings: ---  ---  --- The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.['Traceback (most recent call last):\n', '  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/piece_by_piece.py", line 1587, in _try_export_no_bypass_export\n    ep = torch.export.export(\n         ^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export\n    raise e\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export\n    return _export(\n           ^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper\n    raise e\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper\n    ep = fn(*args, **kwargs)\n         ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper\n    return fn(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export\n    ep = _export_for_training(\n         ^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper\n    raise e\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper\n    ep = fn(*args, **kwargs)\n         ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper\n    return fn(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training\n    export_artifact = export_func(\n                      ^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1975, in _non_strict_export\n    aten_export_artifact = _to_aten_func(  # type: ignore[operator]\n                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx\n    gm, graph_signature = transform(_make_fx_helper)(\n                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1901, in 
_aot_export_non_strict\n    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)\n              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1679, in _make_fx_helper\n    gm = make_fx(\n         ^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped\n    return make_fx_tracer.trace(f, *args)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2228, in trace\n    return self._trace_inner(f, *args)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner\n    t = dispatch_trace(\n        ^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner\n    return disable_fn(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn\n    return fn(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace\n    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]\n            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1787, in trace\n    res = super().trace(root, concrete_args)\n          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn\n    return fn(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace\n    (self.create_arg(fn(*args)),),\n                     ^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped\n    out = f(*tensors)  # type:ignore[call-arg]\n          ^^^^^^^^^^^\n', '  File "<string>", line 1, in <lambda>\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1583, in wrapped_fn\n    return tuple(flat_fn(*args))\n                 ^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn\n    tree_out = fn(*args, **kwargs)\n               ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call\n    out = mod(*args[params_len:], **kwargs)\n          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\n    return self.call_module(mod, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\n    return Tracer.call_module(self, m, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\n    ret_val = forward(*args, **kwargs)\n              ^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File 
"~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\n    return _orig_module_call(mod, *args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1885, in forward\n    tree_out = mod(*args, **kwargs)\n               ^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\n    return self.call_module(mod, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\n    return Tracer.call_module(self, m, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\n    ret_val = forward(*args, **kwargs)\n              ^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\n    return _orig_module_call(mod, *args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/github/transformers/src/transformers/utils/generic.py", line 969, in wrapper\n    output = func(self, *args, **kwargs)\n             ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/github/transformers/src/transformers/models/phi3/modeling_phi3.py", line 744, in forward\n    outputs: BaseModelOutputWithPast = self.model(\n                                       ^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\n    return self.call_module(mod, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\n    return Tracer.call_module(self, m, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\n    ret_val = forward(*args, **kwargs)\n              ^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\n    return _orig_module_call(mod, *args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File 
"~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/github/transformers/src/transformers/utils/generic.py", line 969, in wrapper\n    output = func(self, *args, **kwargs)\n             ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/github/transformers/src/transformers/models/phi3/modeling_phi3.py", line 468, in forward\n    position_embeddings = self.rotary_emb(hidden_states, position_ids)\n                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\n    return self.call_module(mod, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\n    return Tracer.call_module(self, m, forward, args, kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\n    ret_val = forward(*args, **kwargs)\n              ^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\n    return _orig_module_call(mod, *args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\n    return self._call_impl(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\n    return forward_call(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context\n    return func(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 86, in wrapper\n    longrope_frequency_update(self, position_ids, device=x.device)\n', '  File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 50, in longrope_frequency_update\n    if seq_len > original_max_position_embeddings:\n       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__\n    return func(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1373, in __torch_function__\n    return func(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 973, in __torch_function__\n    return func(*args, **kwargs)\n           ^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 536, in guard_bool\n    r = self.evaluate()\n        ^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate\n    return self.shape_env.evaluate_sym_node(self, size_oblivious)\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6857, 
in evaluate_sym_node\n    return self.evaluate_expr(\n           ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6876, in evaluate_expr\n    return self._inner_evaluate_expr(\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper\n    return retlog(fn(*args, **kwargs))\n                  ^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6892, in _inner_evaluate_expr\n    return self._evaluate_expr(\n           ^^^^^^^^^^^^^^^^^^^^\n', '  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7160, in _evaluate_expr\n    raise self._make_data_dependent_error(\n', 'torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)\n\nCaused by: (_export/non_strict_utils.py:973 in __torch_function__)\nFor more information, run with TORCH_LOGS="dynamic"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n  File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 50, in longrope_frequency_update\n    if seq_len > original_max_position_embeddings:\n\n\nThe error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.\n']
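
The failure is a data-dependent guard: longrope_frequency_update computes seq_len = torch.max(position_ids) + 1, then branches on it with a Python if, and torch.export cannot guard on the resulting expression Eq(u0, 1) at trace time. A minimal sketch of the pattern and of one possible rewrite with torch.cond (the threshold 4096 plays the role of original_max_position_embeddings; this is a simplified illustration, not the transformers implementation):

import torch


class DataDependentBranch(torch.nn.Module):
    # Minimal repro: a Python ``if`` on a value read out of a tensor.
    def forward(self, position_ids):
        seq_len = torch.max(position_ids) + 1
        if seq_len > 4096:  # guard on a data-dependent value: export fails here
            return position_ids * 2
        return position_ids + 1


class CondBranch(torch.nn.Module):
    # Possible rewrite: torch.cond traces both branches, keeping the test symbolic.
    def forward(self, position_ids):
        seq_len = torch.max(position_ids) + 1
        return torch.cond(
            seq_len > 4096,
            lambda ids: ids * 2,
            lambda ids: ids + 1,
            (position_ids,),
        )


ids = torch.arange(8, dtype=torch.int64).reshape(1, -1)
# torch.export.export(DataDependentBranch(), (ids,))  # GuardOnDataDependentSymNode
ep = torch.export.export(CondBranch(), (ids,))
print(ep.graph)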



def forward(self, arg0_1: "f32[32064, 3072]", arg1_1: "f32[3072, 3072]", arg2_1: "f32[9216, 3072]", arg3_1: "f32[16384, 3072]", arg4_1: "f32[3072, 8192]", arg5_1: "f32[3072]", arg6_1: "f32[3072]", arg7_1: "f32[3072, 3072]", arg8_1: "f32[9216, 3072]", arg9_1: "f32[16384, 3072]", arg10_1: "f32[3072, 8192]", arg11_1: "f32[3072]", arg12_1: "f32[3072]", arg13_1: "f32[3072]", arg14_1: "f32[48]", arg15_1: "i64[s47, s2]", arg16_1: "i64[s47, s10]", arg17_1, arg18_1: "f32[s89, 32, s66, 96]", arg19_1: "f32[s14, 32, s80, 96]", arg20_1: "f32[s62, 32, s96, 96]", arg21_1: "f32[s58, 32, s81, 96]", arg22_1, arg23_1, arg24_1, arg25_1, arg26_1):
     # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
    embedding: "f32[s47, s2, 3072]" = torch.ops.aten.embedding.default(arg0_1, arg15_1, 32000);  arg0_1 = embedding = None

     # File: ~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:435 in __bool__, code: return builtins.bool(self != 0)
    sym_numel_default: "Sym(3072*s66*s89)" = torch.ops.aten.sym_numel.default(arg18_1)
    ne: "Sym(True)" = sym_numel_default != 0;  ne = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:455 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    sym_size_int: "Sym(s66)" = torch.ops.aten.sym_size.int(arg18_1, 2);  arg18_1 = None
    sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(arg15_1, 1)
    add: "Sym(s2 + s66)" = sym_size_int + sym_size_int_1

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:454 in forward, code: cache_position = torch.arange(
    arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False);  add = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:459 in forward, code: position_ids = cache_position.unsqueeze(0)
    unsqueeze: "i64[1, s2]" = torch.ops.aten.unsqueeze.default(arange, 0)

     # File: ~/vv/this312/lib/python3.12/site-packages/torch/__init__.py:435 in __bool__, code: return builtins.bool(self != 0)
    ne_1: "Sym(True)" = sym_numel_default != 0;  sym_numel_default = ne_1 = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:461 in forward, code: causal_mask = self._update_causal_mask(
    add_1: "Sym(s2 + s66)" = sym_size_int_1 + sym_size_int;  sym_size_int = None
    lt: "Sym(s2 + s66 < 262144)" = add_1 < 262144;  add_1 = lt = None
    sym_size_int_2: "Sym(s10)" = torch.ops.aten.sym_size.int(arg16_1, 1)
    full: "f32[s2, s10]" = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    arange_1: "i64[s10]" = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False)
    reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1])
    gt: "b8[s2, s10]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
    arange_2: "i64[s10]" = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False)
    reshape_1: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
    sub: "i64[s2, 1]" = torch.ops.aten.sub.Tensor(reshape_1, 262144);  reshape_1 = None
    le: "b8[s2, s10]" = torch.ops.aten.le.Tensor(arange_2, sub);  arange_2 = sub = None
    bitwise_or_: "b8[s2, s10]" = torch.ops.aten.bitwise_or_.Tensor(gt, le);  gt = le = None
    mul_: "f32[s2, s10]" = torch.ops.aten.mul_.Tensor(full, bitwise_or_);  full = bitwise_or_ = None
    unsqueeze_1: "f32[1, s2, s10]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
    unsqueeze_2: "f32[1, 1, s2, s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
    eq: "Sym(Eq(s2, 9223372036854775807))" = sym_size_int_1 == 9223372036854775807;  sym_size_int_1 = eq = None
    slice_1: "f32[1, 1, s2, s10]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
    eq_1: "Sym(Eq(s10, 9223372036854775807))" = sym_size_int_2 == 9223372036854775807;  eq_1 = None
    slice_2: "f32[1, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807)
    sym_size_int_3: "Sym(s47)" = torch.ops.aten.sym_size.int(arg15_1, 0);  arg15_1 = None
    expand: "f32[s47, 1, s2, s10]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_3, 1, -1, -1])
    clone: "f32[s47, 1, s2, s10]" = torch.ops.aten.clone.default(expand);  expand = None
    gt_1: "Sym(False)" = sym_size_int_2 > sym_size_int_2;  gt_1 = None
    slice_3: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone)
    slice_4: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_3, 1)
    slice_5: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_4, 2);  slice_4 = None
    slice_6: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, sym_size_int_2)
    sym_size_int_4: "Sym(s47)" = torch.ops.aten.sym_size.int(arg16_1, 0)
    eq_2: "Sym(Eq(s47, 9223372036854775807))" = sym_size_int_4 == 9223372036854775807;  sym_size_int_4 = eq_2 = None
    slice_7: "i64[s47, s10]" = torch.ops.aten.slice.Tensor(arg16_1, 0, 0, 9223372036854775807);  arg16_1 = None
    unsqueeze_3: "i64[s47, 1, s10]" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
    unsqueeze_4: "i64[s47, 1, 1, s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2);  unsqueeze_3 = None
    eq_3: "Sym(Eq(s10, 9223372036854775807))" = sym_size_int_2 == 9223372036854775807;  eq_3 = None
    slice_8: "i64[s47, 1, 1, s10]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807);  unsqueeze_4 = None
    to: "i64[s47, 1, 1, s10]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_8 = None
    add_2: "f32[s47, 1, s2, s10]" = torch.ops.aten.add.Tensor(slice_6, to);  to = None
    eq_4: "b8[s47, 1, s2, s10]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
    slice_9: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone)
    slice_10: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_9, 1);  slice_9 = None
    slice_11: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_10, 2);  slice_10 = None
    slice_12: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, sym_size_int_2);  slice_11 = None
    masked_fill: "f32[s47, 1, s2, s10]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_4, -3.4028234663852886e+38);  slice_12 = eq_4 = None
    eq_5: "Sym(Eq(s47, 9223372036854775807))" = sym_size_int_3 == 9223372036854775807;  sym_size_int_3 = eq_5 = None
    slice_13: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
    slice_14: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807)
    sym_size_int_5: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_1, 2);  slice_1 = None
    eq_6: "Sym(Eq(s2, 9223372036854775807))" = sym_size_int_5 == 9223372036854775807;  sym_size_int_5 = eq_6 = None
    slice_15: "f32[s47, 1, s2, s10]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807);  slice_14 = None
    sym_size_int_6: "Sym(s10)" = torch.ops.aten.sym_size.int(slice_2, 3);  slice_2 = None
    eq_7: "Sym(True)" = sym_size_int_6 == sym_size_int_2;  sym_size_int_2 = eq_7 = None
    sym_size_int_7: "Sym(s47)" = torch.ops.aten.sym_size.int(slice_13, 0);  slice_13 = None
    sym_size_int_8: "Sym(s47)" = torch.ops.aten.sym_size.int(slice_3, 0);  slice_3 = None
    eq_8: "Sym(True)" = sym_size_int_7 == sym_size_int_8;  sym_size_int_7 = sym_size_int_8 = eq_8 = None
    sym_size_int_9: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_15, 2)
    sym_size_int_10: "Sym(s2)" = torch.ops.aten.sym_size.int(slice_5, 2);  slice_5 = None
    eq_9: "Sym(True)" = sym_size_int_9 == sym_size_int_10;  sym_size_int_9 = sym_size_int_10 = eq_9 = None
    sym_size_int_11: "Sym(s10)" = torch.ops.aten.sym_size.int(slice_6, 3);  slice_6 = None
    eq_10: "Sym(True)" = sym_size_int_6 == sym_size_int_11;  sym_size_int_6 = sym_size_int_11 = eq_10 = None
    copy_: "f32[s47, 1, s2, s10]" = torch.ops.aten.copy_.default(slice_15, masked_fill);  slice_15 = masked_fill = copy_ = None

     # File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:468 in forward, code: position_embeddings = self.rotary_emb(hidden_states, position_ids)
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    max_1: "i64[]" = torch.ops.aten.max.default(unsqueeze);  unsqueeze = None
    add_3: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1);  max_1 = None
    gt_2: "b8[]" = torch.ops.aten.gt.Scalar(add_3, 4096);  add_3 = None
    ne_2: "b8[]" = torch.ops.aten.ne.Scalar(gt_2, 0);  gt_2 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne_2);  ne_2 = item = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None





[try_export-FX] .. M:model-Phi3Model --- FAIL, step=EXPORT, reason=Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
 --- Caused by: (_export/non_strict_utils.py:973 in __torch_function__)
 --- For more information, run with TORCH_LOGS="dynamic"
 --- For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
 --- If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
 --- For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
 --- The following call raised this error:
 ---   File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 50, in longrope_frequency_update
 ---     if seq_len > original_max_position_embeddings:
 --- The error above occurred when calling torch.export.export. Replacing the `export()` call with `draft_export()` lists all other errors that may occur in the export call.
[... full Python traceback elided: it runs through the torch.export internals (_trace.py, proxy_tensor.py, _symbolic_trace.py) and ends with torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, raised while evaluating `if seq_len > original_max_position_embeddings:` in transformers/src/transformers/modeling_rope_utils.py, line 50 ...]
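The failure is data-dependent: seq_len is computed from the values of position_ids, not from its shape, so the exporter cannot decide which branch of the `if` to trace. A minimal sketch reproducing the same class of error (the module and the 4096 threshold are made up for illustration):

import torch


class DataDependentBranch(torch.nn.Module):
    def forward(self, position_ids):
        # seq_len depends on tensor *values*, not shapes: converting the
        # comparison to a Python bool creates an unbacked symbol (u0)
        # that the exporter cannot guard on.
        seq_len = position_ids.max() + 1
        if seq_len > 4096:
            return position_ids * 2
        return position_ids + 1


# torch.export.export(DataDependentBranch(), (torch.arange(8).unsqueeze(0),))
# raises GuardOnDataDependentSymNode, just like the trace above.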
~/github/onnx-diagnostic/onnx_diagnostic/helpers/helper.py:1285: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  float(diff.max()),
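The warning above is emitted by the helper measuring the discrepancy between eager and exported outputs; the fix it suggests is the usual one (a one-line sketch, assuming diff is a tensor that requires grad):

# detach before converting to a Python scalar to silence the warning
max_diff = float(diff.detach().max())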
[try_export-FX] .... M:embed_tokens-Embedding --- OK:
[try_export-FX] .... M:layers[0]-Phi3DecoderLayer --- OK:
[try_export-FX] .... M:layers[1]-Phi3DecoderLayer --- OK:
[try_export-FX] .... M:norm-Phi3RMSNorm --- OK:
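Even though the parent Phi3Model fails, the piece-by-piece strategy keeps going: it captured each submodule's inputs during a regular forward pass, so embed_tokens, both decoder layers, and the final norm still export on their own. A hedged sketch of how such inputs can be captured (the helper name is hypothetical; the real tool, trace_execution_piece_by_piece, does more bookkeeping):

import torch


def capture_submodule_inputs(model: torch.nn.Module, args):
    # Record the first positional inputs seen by every submodule with
    # forward pre-hooks, so each piece can later be exported on its own.
    captured, handles = {}, []

    def make_hook(name):
        def hook(mod, inputs):
            captured.setdefault(name, inputs)

        return hook

    for name, mod in model.named_modules():
        handles.append(mod.register_forward_pre_hook(make_hook(name)))
    with torch.no_grad():
        model(*args)
    for handle in handles:
        handle.remove()
    return captured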



def forward(self, arg0_1: "f32[48]", arg1_1: "f32[s35, s16, 3072]", arg2_1: "i64[1, s43]"):
    # No stacktrace found for following nodes
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    max_1: "i64[]" = torch.ops.aten.max.default(arg2_1);  arg2_1 = None
    add: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1);  max_1 = None
    gt: "b8[]" = torch.ops.aten.gt.Scalar(add, 4096);  add = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(gt, 0);  gt = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
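The graph above shows exactly where tracing stops: aten.max over the position ids, + 1, the comparison with 4096 (the model's original_max_position_embeddings), then aten.item, which turns the boolean into the unbacked symbol u0 behind the guard Eq(u0, 1). If the model code could be changed, one hedged way to make such a branch exportable is torch.cond, which keeps both branches in the graph instead of asking the exporter to pick one at trace time (a minimal sketch, not the transformers fix):

import torch


class CondBranch(torch.nn.Module):
    def forward(self, position_ids):
        seq_len = position_ids.max() + 1
        # both branches are traced; the choice is deferred to runtime
        return torch.cond(
            seq_len > 4096,
            lambda x: x * 2,
            lambda x: x + 1,
            (position_ids,),
        )


ep = torch.export.export(CondBranch(), (torch.arange(8).unsqueeze(0),))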




[try_export-FX] .... M:rotary_emb-Phi3RotaryEmbedding --- FAIL, step=EXPORT, reason=Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
 --- The following call raised this error:
 ---   File "~/github/transformers/src/transformers/modeling_rope_utils.py", line 50, in longrope_frequency_update
 ---     if seq_len > original_max_position_embeddings:
[... same full Python traceback as for M:model-Phi3Model above, elided ...]
[try_export-FX] .... M:rotary_emb-Phi3RotaryEmbedding --- FAIL: Could not guard on data-depend...
[try_export-FX] .. M:lm_head-Linear --- OK:

Let’s display a report.

print(f"success: {ep.status}")
print(diag.get_export_report())
success: 2
__main__                         Phi3ForCausalLM       FAIL -- step=EXPORT, reason='Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: n...'
..model                          Phi3Model             FAIL -- step=EXPORT, reason='Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: n...'
....embed_tokens                 Embedding             OK -- ExportedProgram
....layers[0]                    Phi3DecoderLayer      OK -- ExportedProgram
......self_attn                  Phi3Attention         <OK-2i-0>
........o_proj                   Linear                <OK-2i-0>
........qkv_proj                 Linear                <OK-2i-0>
......mlp                        Phi3MLP               <OK-2i-0>
........gate_up_proj             Linear                <OK-2i-0>
........down_proj                Linear                <OK-2i-0>
........activation_fn            SiLU                  <OK-2i-0>
......input_layernorm            Phi3RMSNorm           <OK-2i-0>
......post_attention_layernorm   Phi3RMSNorm           <OK-2i-0>
......resid_attn_dropout         Dropout               <OK-2i-0>
......resid_mlp_dropout          Dropout               <OK-2i-0>
....layers[1]                    Phi3DecoderLayer      OK -- ExportedProgram
......self_attn                  Phi3Attention         <OK-2i-0>
........o_proj                   Linear                <OK-2i-0>
........qkv_proj                 Linear                <OK-2i-0>
......mlp                        Phi3MLP               <OK-2i-0>
........gate_up_proj             Linear                <OK-2i-0>
........down_proj                Linear                <OK-2i-0>
........activation_fn            SiLU                  <OK-2i-0>
......input_layernorm            Phi3RMSNorm           <OK-2i-0>
......post_attention_layernorm   Phi3RMSNorm           <OK-2i-0>
......resid_attn_dropout         Dropout               <OK-2i-0>
......resid_mlp_dropout          Dropout               <OK-2i-0>
....norm                         Phi3RMSNorm           OK -- ExportedProgram
....rotary_emb                   Phi3RotaryEmbedding   FAIL -- step=EXPORT, reason='Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: n...'
..lm_head                        Linear                OK -- ExportedProgram

Replace the failing module with a custom op

The main module is not exportable because one of its pieces is not. But if we assume that piece works, everything else may export fine. So let's replace the failing class with a custom op: torch.export treats a custom op as an opaque call and does not trace into it, as sketched below. Wiring this up for Phi3RotaryEmbedding will be the subject of another example.
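As a hedged sketch of the idea (the op name mylib::opaque_rope and its toy body are hypothetical, and torch.library.custom_op requires torch >= 2.4): the data-dependent branch stays inside the eager implementation, while export only traces the fake implementation, which is shape-only.

import torch


# Hypothetical sketch: hide the data-dependent branch inside a custom op.
@torch.library.custom_op("mylib::opaque_rope", mutates_args=())
def opaque_rope(x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    # Eager implementation: free to use Python control flow on values.
    seq_len = int(position_ids.max()) + 1
    scale = 2.0 if seq_len > 4096 else 1.0
    return x * scale


@opaque_rope.register_fake
def _(x, position_ids):
    # Fake implementation used during export: shape/dtype propagation only.
    return torch.empty_like(x)


class Wrapper(torch.nn.Module):
    def forward(self, x, position_ids):
        return torch.ops.mylib.opaque_rope(x, position_ids)


ep = torch.export.export(
    Wrapper(), (torch.randn(2, 8, 16), torch.arange(8).unsqueeze(0))
)

The exported graph then contains a single opaque call to mylib.opaque_rope, and the rest of the model exports around it.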

Total running time of the script: (0 minutes 7.551 seconds)

Related examples

Export Phi-3.5-mini-instruct with report_exportability

Export Phi-3.5-mini-instruct with draft_export

to_onnx and Phi-2

torch.onnx.export and Phi-2

Infer dynamic shapes before exporting

Gallery generated by Sphinx-Gallery