to_onnx, failures Phi-3.5-mini-instruct

The example *to_onnx and Phi-2* shows how to export a simple LLM with dynamic shapes. What if the export does not work?

Model

import pprint
from typing import Any, Dict
import torch
import transformers
from experimental_experiment.helpers import string_type
from experimental_experiment.torch_interpreter.diagnose import infer_shape_type_from_execution


def get_phi2_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets a non initialized model with two sets of inputs and different shapes.

    Despite its name, the function builds an untrained ``Phi3ForCausalLM``
    following the configuration of *Phi-3.5-mini-instruct*.
    The shapes of the dummy key/value cache and the range of the token ids
    are derived from the final configuration (after applying *kwargs*), so
    overriding ``hidden_size``, ``num_attention_heads``,
    ``num_key_value_heads`` or ``vocab_size`` still produces inputs
    consistent with the model.

    :param batch_size: batch size of the first set of inputs,
        the second set uses ``batch_size + 1``
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary with keys ``model``, ``inputs``, ``inputs2``

    See `Phi-3.5-mini-instruct/config.json
    <https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_phi3.Phi3Config",
            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
        },
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": [
                1.0800000429153442,
                1.1100000143051147,
                1.1399999856948853,
                1.340000033378601,
                1.5899999141693115,
                1.600000023841858,
                1.6200000047683716,
                2.620000123977661,
                3.2300000190734863,
                3.2300000190734863,
                4.789999961853027,
                7.400000095367432,
                7.700000286102295,
                9.09000015258789,
                12.199999809265137,
                17.670000076293945,
                24.46000099182129,
                28.57000160217285,
                30.420001983642578,
                30.840002059936523,
                32.590003967285156,
                32.93000411987305,
                42.320003509521484,
                44.96000289916992,
                50.340003967285156,
                50.45000457763672,
                57.55000305175781,
                57.93000411987305,
                58.21000289916992,
                60.1400032043457,
                62.61000442504883,
                62.62000274658203,
                62.71000289916992,
                63.1400032043457,
                63.1400032043457,
                63.77000427246094,
                63.93000411987305,
                63.96000289916992,
                63.970001220703125,
                64.02999877929688,
                64.06999969482422,
                64.08000183105469,
                64.12000274658203,
                64.41000366210938,
                64.4800033569336,
                64.51000213623047,
                64.52999877929688,
                64.83999633789062,
            ],
            "short_factor": [
                1.0,
                1.0199999809265137,
                1.0299999713897705,
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0699999332427979,
                1.0999999046325684,
                1.1099998950958252,
                1.1599998474121094,
                1.1599998474121094,
                1.1699998378753662,
                1.2899998426437378,
                1.339999794960022,
                1.679999828338623,
                1.7899998426437378,
                1.8199998140335083,
                1.8499997854232788,
                1.8799997568130493,
                1.9099997282028198,
                1.9399996995925903,
                1.9899996519088745,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0799996852874756,
                2.0899996757507324,
                2.189999580383301,
                2.2199995517730713,
                2.5899994373321533,
                2.729999542236328,
                2.749999523162842,
                2.8399994373321533,
            ],
            "type": "longrope",
        },
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "use_cache": True,
        "attention_bias": False,
        "vocab_size": 32064,
    }
    config.update(**kwargs)
    conf = transformers.Phi3Config(**config)
    model = transformers.Phi3ForCausalLM(conf)
    model.eval()

    # Derive the input dimensions from the (possibly overridden) configuration
    # instead of hard-coding the defaults (32 heads, head size 96, 32064 tokens).
    n_layers = config["num_hidden_layers"]
    n_heads = config["num_attention_heads"]
    # A missing/None number of key-value heads falls back to the attention heads.
    n_kv_heads = config.get("num_key_value_heads") or n_heads
    head_dim = config["hidden_size"] // n_heads
    vocab_size = config["vocab_size"]

    def _make_cache(batch: int, past_length: int):
        """Fills a ``DynamicCache`` with random keys/values for every layer."""
        cache = transformers.cache_utils.DynamicCache(n_layers)
        for layer in range(n_layers):
            cache.update(
                torch.randn(batch, n_kv_heads, past_length, head_dim),
                torch.randn(batch, n_kv_heads, past_length, head_dim),
                layer,
            )
        return cache

    # Two sets of inputs differing on every dimension expected to be dynamic:
    # batch size, number of new tokens and length of the cached sequence.
    past_length, new_tokens = 30, 3
    past_length2, new_tokens2 = 31, 4

    # The caches are created before the token ids, as in the original
    # implementation, to keep the sequence of random draws unchanged.
    cache = _make_cache(batch_size, past_length)
    cache2 = _make_cache(batch_size + 1, past_length2)

    inputs = dict(
        input_ids=torch.randint(0, vocab_size, (batch_size, new_tokens)).to(
            torch.int64
        ),
        # the mask covers the cached tokens and the new ones
        attention_mask=torch.ones((batch_size, past_length + new_tokens)).to(
            torch.int64
        ),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, vocab_size, (batch_size + 1, new_tokens2)).to(
            torch.int64
        ),
        attention_mask=torch.ones(
            (batch_size + 1, past_length2 + new_tokens2)
        ).to(torch.int64),
        past_key_values=cache2,
    )
    return dict(inputs=inputs, model=model, inputs2=inputs2)


# Build an untrained model reduced to two hidden layers, with its two sets of inputs.
data = get_phi2_untrained(num_hidden_layers=2)
model = data["model"]
inputs = data["inputs"]
inputs2 = data["inputs2"]

# Display a compact description (types and shapes) of the first set of inputs.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))

Dynamic Shapes

We want to infer the dynamic shapes from the two sets of inputs we gave. For that, we use a function that traces the execution of the model, including its submodules. It executes the model twice, once with each set of inputs, and stores every intermediate input and output.

# Run the model once per set of inputs while recording the inputs and outputs
# of every submodule; verbose=2 prints the trace shown below.
diag = infer_shape_type_from_execution(model, [inputs, inputs2], verbose=2)
[_trace_forward_execution] __main__ - Phi3ForCausalLM
[_trace_forward_execution] ..model - Phi3Model
[_trace_forward_execution] ..lm_head - Linear
[infer_shape_type_from_execution] run with dict(args:(),kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:int,output_hidden_states:int,return_dict:int,cache_position:None)
[model:Phi3Model]   < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[infer_shape_type_from_execution] run with dict(args:(),kwargs:dict(input_ids:T7s3x4,attention_mask:T7s3x35,past_key_values:DynamicCache(key_cache=#2[T1s3x32x31x96,T1s3x32x31x96], value_cache=#2[T1s3x32x31x96,T1s3x32x31x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:int,output_hidden_states:int,return_dict:int,cache_position:None)
[model:Phi3Model]   < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_forward_execution] traced execution of model Phi3ForCausalLM
>>> __main__: Phi3ForCausalLM
  > ((),dict(input_ids:CT7s2x3[4480,30643:A16788.666666666668],attention_mask:CT7s2x33[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]])))
  > ((),dict(input_ids:CT7s3x4[3842,31602:A20204.833333333332],attention_mask:CT7s3x35[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]])))
    >>> model: Phi3Model
      > ((),dict(input_ids:CT7s2x3[4480,30643:A16788.666666666668],attention_mask:CT7s2x33[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),inputs_embeds:None,use_cache:None,output_attentions:int[False],output_hidden_states:int[False],return_dict:int[True],cache_position:None))
      > ((),dict(input_ids:CT7s3x4[3842,31602:A20204.833333333332],attention_mask:CT7s3x35[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),inputs_embeds:None,use_cache:None,output_attentions:int[False],output_hidden_states:int[False],return_dict:int[True],cache_position:None))
        >>> embed_tokens: Embedding
          > ((CT7s2x3[4480,30643:A16788.666666666668],),dict())
          > ((CT7s3x4[3842,31602:A20204.833333333332],),dict())
          < (CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],)
          < (CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],)
        <<<
        >>> layers[0]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-2.397754669189453,1.8615864515304565:A-0.00028552267273683637],),dict())
                  > ((CT1s3x4x3072[-2.073399305343628,2.3145229816436768:A0.0001120645541599568],),dict())
                  < (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],)
                  < (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],),dict())
                  > ((CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],),dict())
                  < (CT1s2x3x9216[-5.016335487365723,4.414941787719727:A-0.002557053032576193],)
                  < (CT1s3x4x9216[-5.139618396759033,4.558834552764893:A-0.001959620447271012],)
                <<<
              < (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],None)
              < (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],),dict())
              > ((CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],),dict())
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],),dict())
                  > ((CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],),dict())
                  < (CT1s2x3x16384[-4.869340419769287,5.286735534667969:A-0.00018809297783898651],)
                  < (CT1s3x4x16384[-4.958648204803467,4.989035606384277:A-0.0033986618244264597],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-8.716111183166504,8.193013191223145:A-0.0011114813027920483],),dict())
                  > ((CT1s3x4x8192[-9.956159591674805,11.3091459274292:A0.00025204877698488506],),dict())
                  < (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
                  < (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.831746578216553,4.276060581207275:A0.004578614053405279],),dict())
                  > ((CT1s3x4x8192[-4.439509391784668,4.847005367279053:A-0.004492188054667186],),dict())
                  < (CT1s2x3x8192[-0.27846455574035645,4.217449188232422:A0.24632204136495695],)
                  < (CT1s3x4x8192[-0.27846455574035645,4.809244155883789:A0.24241726201152414],)
                <<<
              < (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
              < (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],),dict())
              > ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],),dict())
              < (CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],)
              < (CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-1.7519030570983887,1.521032452583313:A-0.004998811839893986],),dict())
              > ((CT1s3x4x3072[-1.5569356679916382,1.542656421661377:A0.0019694551863583204],),dict())
              < (CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],)
              < (CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],),dict())
              > ((CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],),dict())
              < (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],)
              < (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],),dict())
              > ((CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],),dict())
              < (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
              < (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
            <<<
          < (CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],)
          < (CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],)
        <<<
        >>> layers[1]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-2.522385597229004,2.634071111679077:A0.005504650073209956],),dict())
                  > ((CT1s3x4x3072[-2.916518211364746,2.487565040588379:A0.001905495911235554],),dict())
                  < (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],)
                  < (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],),dict())
                  > ((CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],),dict())
                  < (CT1s2x3x9216[-5.106101989746094,4.490701675415039:A-0.002418912901332617],)
                  < (CT1s3x4x9216[-4.6034040451049805,4.716505527496338:A-0.0005003103740501012],)
                <<<
              < (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],None)
              < (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],),dict())
              > ((CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],),dict())
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],),dict())
                  > ((CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],),dict())
                  < (CT1s2x3x16384[-4.767377853393555,4.763617992401123:A-0.0011706767508409637],)
                  < (CT1s3x4x16384[-4.816030025482178,5.136754989624023:A0.00020893273096452467],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-10.276310920715332,8.63280963897705:A0.007607944792103967],),dict())
                  > ((CT1s3x4x8192[-9.172404289245605,10.50564956665039:A-0.0007796187043171565],),dict())
                  < (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
                  < (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.767377853393555,4.6629109382629395:A-0.0019218210885677915],),dict())
                  > ((CT1s3x4x8192[-4.816030025482178,4.801560878753662:A0.0034934405888857136],),dict())
                  < (CT1s2x3x8192[-0.27846455574035645,4.619309425354004:A0.2432932592988779],)
                  < (CT1s3x4x8192[-0.27846455574035645,4.7624287605285645:A0.24843786659936437],)
                <<<
              < (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
              < (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],),dict())
              > ((CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],),dict())
              < (CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],)
              < (CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.602095603942871,5.5913519859313965:A0.00028517031084144645],),dict())
              > ((CT1s3x4x3072[-5.972293376922607,6.098865509033203:A-0.008048094156871835],),dict())
              < (CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],)
              < (CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],),dict())
              > ((CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],),dict())
              < (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],)
              < (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],),dict())
              > ((CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],),dict())
              < (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
              < (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
            <<<
          < (CT1s2x3x3072[-8.54723834991455,7.845819473266602:A-0.0017686810511400432],)
          < (CT1s3x4x3072[-9.078064918518066,8.243967056274414:A-0.004090411359306422],)
        <<<
        >>> norm: Phi3RMSNorm
          > ((CT1s2x3x3072[-8.54723834991455,7.845819473266602:A-0.0017686810511400432],),dict())
          > ((CT1s3x4x3072[-9.078064918518066,8.243967056274414:A-0.004090411359306422],),dict())
          < (CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],)
          < (CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],)
        <<<
        >>> rotary_emb: Phi3RotaryEmbedding
          > ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],CT7s1x3[30,32:A31.0]),dict())
          > ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],CT7s1x4[31,34:A32.5]),dict())
          < (CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])
          < (CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])
        <<<
      < (dict(last_hidden_state:CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x33x96[-6.076218128204346,4.901590347290039:A-0.000825817244353194]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x33x96[-4.364691257476807,4.490701675415039:A0.0002260386739006028]])),)
      < (dict(last_hidden_state:CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x35x96[-5.430703639984131,5.21905517578125:A-0.0001790060691358257]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x35x96[-4.4542317390441895,4.643773555755615:A0.0006424695788216192]])),)
    <<<
    >>> lm_head: Linear
      > ((CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],),dict())
      > ((CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],),dict())
      < (CT1s2x3x32064[-4.8934478759765625,5.155729293823242:A-0.004496973525844078],)
      < (CT1s3x4x32064[-4.830400466918945,5.16563606262207:A0.0007996815108034524],)
    <<<
  < (dict(logits:CT1s2x3x32064[-4.8934478759765625,5.155729293823242:A-0.004496973525844078],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x33x96[-6.076218128204346,4.901590347290039:A-0.000825817244353194]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x33x96[-4.364691257476807,4.490701675415039:A0.0002260386739006028]])),)
  < (dict(logits:CT1s3x4x32064[-4.830400466918945,5.16563606262207:A0.0007996815108034524],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x35x96[-5.430703639984131,5.21905517578125:A-0.0001790060691358257]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x35x96[-4.4542317390441895,4.643773555755615:A0.0006424695788216192]])),)
<<<
[_untrace_forward_execution] __main__ - Phi3ForCausalLM
[_untrace_forward_execution] ..model - Phi3Model
[_untrace_forward_execution] ....embed_tokens - Embedding
[_untrace_forward_execution] ....layers[0] - Phi3DecoderLayer
[_untrace_forward_execution] ......self_attn - Phi3Attention
[_untrace_forward_execution] ........o_proj - Linear
[_untrace_forward_execution] ........qkv_proj - Linear
[_untrace_forward_execution] ......mlp - Phi3MLP
[_untrace_forward_execution] ........gate_up_proj - Linear
[_untrace_forward_execution] ........down_proj - Linear
[_untrace_forward_execution] ........activation_fn - SiLU
[_untrace_forward_execution] ......input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......resid_attn_dropout - Dropout
[_untrace_forward_execution] ......resid_mlp_dropout - Dropout
[_untrace_forward_execution] ....layers[1] - Phi3DecoderLayer
[_untrace_forward_execution] ......self_attn - Phi3Attention
[_untrace_forward_execution] ........o_proj - Linear
[_untrace_forward_execution] ........qkv_proj - Linear
[_untrace_forward_execution] ......mlp - Phi3MLP
[_untrace_forward_execution] ........gate_up_proj - Linear
[_untrace_forward_execution] ........down_proj - Linear
[_untrace_forward_execution] ........activation_fn - SiLU
[_untrace_forward_execution] ......input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......resid_attn_dropout - Dropout
[_untrace_forward_execution] ......resid_mlp_dropout - Dropout
[_untrace_forward_execution] ....norm - Phi3RMSNorm
[_untrace_forward_execution] ....rotary_emb - Phi3RotaryEmbedding
[_untrace_forward_execution] ..lm_head - Linear

Now that we keep in memory every input/output of the submodules, we can guess the dynamic shapes for each of them. The final ones:

dynamic_shapes = diag.guess_dynamic_shapes()
print("The dynamic shapes are:")
pprint.pprint(dynamic_shapes)
The dynamic shapes are:
((),
 {'attention_mask': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'input_ids': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'past_key_values': [[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
                       {0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}],
                      [{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
                       {0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}]]})

And here are the dynamic shapes all along the traced submodules.

print(
    diag.pretty_text(
        with_dynamic_shape=True,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    ).replace("<_DimHint.DYNAMIC: 3>", "DYN")
)
>>> __main__: Phi3ForCausalLM
  DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'input_ids': {0: DYN, 1: DYN}, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]]})
    >>> model: Phi3Model
      DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'cache_position': None, 'input_ids': {0: DYN, 1: DYN}, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_ids': None, 'return_dict': None, 'use_cache': None})
        >>> embed_tokens: Embedding: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> layers[0]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> norm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> rotary_emb: Phi3RotaryEmbedding: DS=(({0: DYN, 1: DYN}, {1: DYN}), {}) <<<
    <<<
    >>> lm_head: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<

Evaluate the export

In many cases, the export (to torch.fx.Graph, to ONNX) does not work on the first try. We need a way to understand how much of the model can be exported. It can be used to evaluate how much code needs to be rewritten or patched to be exportable. The verbosity can be increased to show dynamic shapes and the results of the discrepancies. Let’s display the module and its submodules first.

print(
    diag.pretty_text(
        with_dynamic_shape=False,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    )
)
>>> __main__: Phi3ForCausalLM
    >>> model: Phi3Model
        >>> embed_tokens: Embedding <<<
        >>> layers[0]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> norm: Phi3RMSNorm <<<
        >>> rotary_emb: Phi3RotaryEmbedding <<<
    <<<
    >>> lm_head: Linear <<<
<<<

Then we try to export.

ep = diag.try_export(
    exporter="fx",
    use_dynamic_shapes=True,
    exporter_kwargs=dict(strict=False),
    bypass_kwargs=dict(patch_transformers=True, replace_dynamic_cache=True),
    verbose=1,
)
[try_export] __main__ - Phi3ForCausalLM --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
[try_export] ..model - Phi3Model --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
[try_export] ....embed_tokens - Embedding --- OK
[try_export] ....layers[0] - Phi3DecoderLayer --- FAIL-EXPORT: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['hidden_states']` is a <class 'dict'>, but `dynamic_shapes['hidden_states']` is not
[try_export] ......self_attn - Phi3Attention --- FAIL-EXPORT: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['hidden_states', 'position_embeddings', 'attention_mask', 'past_key_value', 'cache_position', 'kwargs'] of `inputs`, but here they are ['attention_mask', 'cache_position', 'hidden_states', 'output_attentions', 'past_key_value', 'position_embeddings', 'position_ids', 'use_cache']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
[try_export] ........o_proj - Linear --- OK
[try_export] ........qkv_proj - Linear --- OK
[try_export] ......mlp - Phi3MLP --- OK
[try_export] ......input_layernorm - Phi3RMSNorm --- OK
[try_export] ......post_attention_layernorm - Phi3RMSNorm --- OK
[try_export] ......resid_attn_dropout - Dropout --- OK
[try_export] ......resid_mlp_dropout - Dropout --- OK
[try_export] ....layers[1] - Phi3DecoderLayer --- FAIL-EXPORT: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['hidden_states']` is a <class 'dict'>, but `dynamic_shapes['hidden_states']` is not
[try_export] ......self_attn - Phi3Attention --- FAIL-EXPORT: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['hidden_states', 'position_embeddings', 'attention_mask', 'past_key_value', 'cache_position', 'kwargs'] of `inputs`, but here they are ['attention_mask', 'cache_position', 'hidden_states', 'output_attentions', 'past_key_value', 'position_embeddings', 'position_ids', 'use_cache']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
[try_export] ........o_proj - Linear --- OK
[try_export] ........qkv_proj - Linear --- OK
[try_export] ......mlp - Phi3MLP --- OK
[try_export] ......input_layernorm - Phi3RMSNorm --- OK
[try_export] ......post_attention_layernorm - Phi3RMSNorm --- OK
[try_export] ......resid_attn_dropout - Dropout --- OK
[try_export] ......resid_mlp_dropout - Dropout --- OK
[try_export] ....norm - Phi3RMSNorm --- OK
[try_export] ....rotary_emb - Phi3RotaryEmbedding --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
[try_export] ..lm_head - Linear --- OK

Total running time of the script: (0 minutes 5.563 seconds)

Related examples

torch.onnx.export and Phi-2

torch.onnx.export and Phi-2

to_onnx and Phi-2

to_onnx and Phi-2

to_onnx and infer dynamic shapes

to_onnx and infer dynamic shapes

Gallery generated by Sphinx-Gallery