Export Phi-3.5-mini-instruct piece by piece

torch.export.export() often breaks on big models because control flow or specific instructions break the propagation of dynamic shapes (see …). The function usually indicates where the model implementation can be fixed but, when that is not possible, we can try to export the model piece by piece: every module is converted separately from its submodules. A module can then be exported even if one of its submodules cannot.
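As a minimal sketch of the idea (the toy model below is invented for illustration), a parent module with data-dependent control flow fails to export while its submodule exports fine:

import torch


class Inner(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        # data-dependent control flow, a common reason torch.export breaks
        if x.sum() > 0:
            return self.inner(x)
        return -self.inner(x)


x = torch.randn(2, 3)
try:
    torch.export.export(Outer(), (x,))
except Exception as e:
    print("parent fails:", type(e).__name__)
# the submodule alone still exports
ep_inner = torch.export.export(Inner(), (x,))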

Model

import pprint
from typing import Any, Dict
import torch
import torch._export.tools
import transformers
from experimental_experiment.helpers import string_type
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
)


def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
    """
    Gets an untrained model with two sets of inputs of different shapes.

    :param batch_size: batch size
    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
    :return: dictionary

    See `Phi-3.5-mini-instruct/config.json
    <https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
    """
    config = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": {
            "AutoConfig": "configuration_phi3.Phi3Config",
            "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
        },
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": {
            "long_factor": [
                1.0800000429153442,
                1.1100000143051147,
                1.1399999856948853,
                1.340000033378601,
                1.5899999141693115,
                1.600000023841858,
                1.6200000047683716,
                2.620000123977661,
                3.2300000190734863,
                3.2300000190734863,
                4.789999961853027,
                7.400000095367432,
                7.700000286102295,
                9.09000015258789,
                12.199999809265137,
                17.670000076293945,
                24.46000099182129,
                28.57000160217285,
                30.420001983642578,
                30.840002059936523,
                32.590003967285156,
                32.93000411987305,
                42.320003509521484,
                44.96000289916992,
                50.340003967285156,
                50.45000457763672,
                57.55000305175781,
                57.93000411987305,
                58.21000289916992,
                60.1400032043457,
                62.61000442504883,
                62.62000274658203,
                62.71000289916992,
                63.1400032043457,
                63.1400032043457,
                63.77000427246094,
                63.93000411987305,
                63.96000289916992,
                63.970001220703125,
                64.02999877929688,
                64.06999969482422,
                64.08000183105469,
                64.12000274658203,
                64.41000366210938,
                64.4800033569336,
                64.51000213623047,
                64.52999877929688,
                64.83999633789062,
            ],
            "short_factor": [
                1.0,
                1.0199999809265137,
                1.0299999713897705,
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0499999523162842,
                1.0699999332427979,
                1.0999999046325684,
                1.1099998950958252,
                1.1599998474121094,
                1.1599998474121094,
                1.1699998378753662,
                1.2899998426437378,
                1.339999794960022,
                1.679999828338623,
                1.7899998426437378,
                1.8199998140335083,
                1.8499997854232788,
                1.8799997568130493,
                1.9099997282028198,
                1.9399996995925903,
                1.9899996519088745,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0199997425079346,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0299997329711914,
                2.0799996852874756,
                2.0899996757507324,
                2.189999580383301,
                2.2199995517730713,
                2.5899994373321533,
                2.729999542236328,
                2.749999523162842,
                2.8399994373321533,
            ],
            "type": "longrope",
        },
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": False,
        "torch_dtype": "bfloat16",
        "use_cache": True,
        "attention_bias": False,
        "vocab_size": 32064,
    }
    config.update(**kwargs)
    conf = transformers.Phi3Config(**config)
    model = transformers.Phi3ForCausalLM(conf)
    model.eval()

    cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache.update(
            torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96), i
        )
    cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
    for i in range(config["num_hidden_layers"]):
        cache2.update(
            torch.randn(batch_size + 1, 32, 31, 96),
            torch.randn(batch_size + 1, 32, 31, 96),
            i,
        )

    inputs = dict(
        input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
        attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
        past_key_values=cache,
    )
    inputs2 = dict(
        input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
        attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
        past_key_values=cache2,
    )
    return dict(inputs=inputs, model=model, inputs2=inputs2)


data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
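The notation is compact: assuming the onnx dtype codes (1 is float32, 7 is int64), T7s2x3 reads as an int64 tensor of shape 2x3. A quick check:

print(string_type(torch.randn(2, 3), with_shape=True))  # expected: T1s2x3
print(string_type(torch.zeros(2, 33, dtype=torch.int64), with_shape=True))  # expected: T7s2x33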

Dynamic Shapes

We want to infer the dynamic shapes from the two sets of inputs we gave. For that, we use a function that traces the execution of the model, including its submodules. It executes the model twice with the two sets of inputs and stores every intermediate input and output.

diag = trace_execution_piece_by_piece(model, [inputs, inputs2], verbose=2)
[_trace_forward_execution]  __main__ - Phi3ForCausalLM
[_trace_forward_execution] .. model - Phi3Model
[_trace_forward_execution] .. lm_head - Linear
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,return_dict:bool,cache_position:None)
[model:Phi3Model]   < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s3x4,attention_mask:T7s3x35,past_key_values:DynamicCache(key_cache=#2[T1s3x32x31x96,T1s3x32x31x96], value_cache=#2[T1s3x32x31x96,T1s3x32x31x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model]   > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,return_dict:bool,cache_position:None)
[model:Phi3Model]   < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear]   > T1r3
[lm_head:Linear]   < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_forward_execution] traced execution of model Phi3ForCausalLM
>>> __main__: Phi3ForCausalLM
  > ((),dict(input_ids:CT7s2x3[4664,31570:A17677.166666666668],attention_mask:CT7s2x33[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]])))
  > ((),dict(input_ids:CT7s3x4[608,30304:A20014.666666666668],attention_mask:CT7s3x35[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]])))
    >>> model: Phi3Model
      > ((),dict(input_ids:CT7s2x3[4664,31570:A17677.166666666668],attention_mask:CT7s2x33[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,return_dict:bool=True,cache_position:None))
      > ((),dict(input_ids:CT7s3x4[608,30304:A20014.666666666668],attention_mask:CT7s3x35[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,return_dict:bool=True,cache_position:None))
        >>> embed_tokens: Embedding
          > ((CT7s2x3[4664,31570:A17677.166666666668],),{})
          > ((CT7s3x4[608,30304:A20014.666666666668],),{})
          < (CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],)
          < (CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],)
        <<<
        >>> layers[0]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-2.0857884883880615,2.227522850036621:A-0.00166854337468071],),{})
                  > ((CT1s3x4x3072[-1.9925907850265503,2.740478515625:A-4.3969838686671224e-05],),{})
                  < (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],)
                  < (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],),{})
                  > ((CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],),{})
                  < (CT1s2x3x9216[-4.55996561050415,4.785744667053223:A-0.0007797792564158752],)
                  < (CT1s3x4x9216[-5.212952136993408,4.804689407348633:A0.003576188889704634],)
                <<<
              < (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],None)
              < (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],),{})
              > ((CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],),{})
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],),{})
                  > ((CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],),{})
                  < (CT1s2x3x16384[-4.626908779144287,4.744266033172607:A0.0004144351753060012],)
                  < (CT1s3x4x16384[-5.071014404296875,5.014136791229248:A0.002499560848283622],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-9.626221656799316,10.6359224319458:A-0.003712694137561807],),{})
                  > ((CT1s3x4x8192[-10.604706764221191,13.27551555633545:A-0.0009531583248854041],),{})
                  < (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
                  < (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.296870231628418,4.744266033172607:A-0.00018980784703141276],),{})
                  > ((CT1s3x4x8192[-5.071014404296875,5.014136791229248:A-0.002123918648494557],),{})
                  < (CT1s2x3x8192[-0.27846455574035645,4.70334005355835:A0.2433877349765435],)
                  < (CT1s3x4x8192[-0.27846455574035645,4.981045722961426:A0.24466585338180846],)
                <<<
              < (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
              < (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],),{})
              > ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],),{})
              < (CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],)
              < (CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-1.716892957687378,1.7488868236541748:A-0.00032424742236179956],),{})
              > ((CT1s3x4x3072[-1.5904587507247925,1.5101318359375:A-0.0038784623282745023],),{})
              < (CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],)
              < (CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],),{})
              > ((CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],),{})
              < (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],)
              < (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],),{})
              > ((CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],),{})
              < (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
              < (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
            <<<
          < (CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],)
          < (CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],)
        <<<
        >>> layers[1]: Phi3DecoderLayer
          > ((CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
          > ((CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
            >>> self_attn: Phi3Attention
              > ((),dict(hidden_states:CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
              > ((),dict(hidden_states:CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
                >>> o_proj: Linear
                  > ((CT1s2x3x3072[-2.0965793132781982,2.4613635540008545:A0.001581031349573831],),{})
                  > ((CT1s3x4x3072[-2.2842745780944824,2.885471820831299:A-0.00010379927480992067],),{})
                  < (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],)
                  < (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],)
                <<<
                >>> qkv_proj: Linear
                  > ((CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],),{})
                  > ((CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],),{})
                  < (CT1s2x3x9216[-4.554075717926025,4.405378818511963:A-0.004489510061266292],)
                  < (CT1s3x4x9216[-4.208725929260254,5.463718891143799:A0.004831148010973047],)
                <<<
              < (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],None)
              < (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],None)
            <<<
            >>> mlp: Phi3MLP
              > ((CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],),{})
              > ((CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],),{})
                >>> gate_up_proj: Linear
                  > ((CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],),{})
                  > ((CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],),{})
                  < (CT1s2x3x16384[-4.71592903137207,4.544133186340332:A0.0010123075836598143],)
                  < (CT1s3x4x16384[-5.153237819671631,5.296940803527832:A0.0010891516607924128],)
                <<<
                >>> down_proj: Linear
                  > ((CT1s2x3x8192[-9.733989715576172,9.656818389892578:A-0.00035077506349431273],),{})
                  > ((CT1s3x4x8192[-10.086795806884766,11.529891014099121:A0.004240172249181878],),{})
                  < (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
                  < (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
                <<<
                >>> activation_fn: SiLU
                  > ((CT1s2x3x8192[-4.71592903137207,4.544133186340332:A0.0006271271668936151],),{})
                  > ((CT1s3x4x8192[-4.761664867401123,4.957876682281494:A-0.002479409858256195],),{})
                  < (CT1s2x3x8192[-0.27846455574035645,4.496339797973633:A0.24618416685490122],)
                  < (CT1s3x4x8192[-0.27846455574035645,4.923276424407959:A0.2438934481976652],)
                <<<
              < (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
              < (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
            <<<
            >>> input_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],),{})
              > ((CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],),{})
              < (CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],)
              < (CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],)
            <<<
            >>> post_attention_layernorm: Phi3RMSNorm
              > ((CT1s2x3x3072[-5.136703968048096,5.932700157165527:A0.010719243169357165],),{})
              > ((CT1s3x4x3072[-5.9217209815979,6.042370319366455:A-0.006805091526985052],),{})
              < (CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],)
              < (CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],)
            <<<
            >>> resid_attn_dropout: Dropout
              > ((CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],),{})
              > ((CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],),{})
              < (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],)
              < (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],)
            <<<
            >>> resid_mlp_dropout: Dropout
              > ((CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],),{})
              > ((CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],),{})
              < (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
              < (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
            <<<
          < (CT1s2x3x3072[-8.349817276000977,9.60705280303955:A0.014413142231104657],)
          < (CT1s3x4x3072[-7.1743621826171875,7.833809852600098:A-0.004469779496957926],)
        <<<
        >>> norm: Phi3RMSNorm
          > ((CT1s2x3x3072[-8.349817276000977,9.60705280303955:A0.014413142231104657],),{})
          > ((CT1s3x4x3072[-7.1743621826171875,7.833809852600098:A-0.004469779496957926],),{})
          < (CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],)
          < (CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],)
        <<<
        >>> rotary_emb: Phi3RotaryEmbedding
          > ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],CT7s1x3[30,32:A31.0]),{})
          > ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],CT7s1x4[31,34:A32.5]),{})
          < (CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])
          < (CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])
        <<<
      < (dict(last_hidden_state:CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x33x96[-5.159628391265869,5.257333278656006:A-0.003930519807112116]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x33x96[-4.677217960357666,5.006890296936035:A0.0021848251059774755]])),)
      < (dict(last_hidden_state:CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x35x96[-5.084636211395264,4.978662967681885:A-0.002626009440932474]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x35x96[-5.0299248695373535,5.066045761108398:A0.001982277264531371]])),)
    <<<
    >>> lm_head: Linear
      > ((CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],),{})
      > ((CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],),{})
      < (CT1s2x3x32064[-4.806385040283203,4.561767101287842:A-7.084682205524456e-05],)
      < (CT1s3x4x32064[-4.770487308502197,5.134904384613037:A0.0011637009553284994],)
    <<<
  < (dict(logits:CT1s2x3x32064[-4.806385040283203,4.561767101287842:A-7.084682205524456e-05],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x33x96[-5.159628391265869,5.257333278656006:A-0.003930519807112116]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x33x96[-4.677217960357666,5.006890296936035:A0.0021848251059774755]])),)
  < (dict(logits:CT1s3x4x32064[-4.770487308502197,5.134904384613037:A0.0011637009553284994],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x35x96[-5.084636211395264,4.978662967681885:A-0.002626009440932474]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x35x96[-5.0299248695373535,5.066045761108398:A0.001982277264531371]])),)
<<<
[_untrace_forward_execution]  __main__ - Phi3ForCausalLM
[_untrace_forward_execution] .. model - Phi3Model
[_untrace_forward_execution] .... embed_tokens - Embedding
[_untrace_forward_execution] .... layers[0] - Phi3DecoderLayer
[_untrace_forward_execution] ...... self_attn - Phi3Attention
[_untrace_forward_execution] ........ o_proj - Linear
[_untrace_forward_execution] ........ qkv_proj - Linear
[_untrace_forward_execution] ...... mlp - Phi3MLP
[_untrace_forward_execution] ........ gate_up_proj - Linear
[_untrace_forward_execution] ........ down_proj - Linear
[_untrace_forward_execution] ........ activation_fn - SiLU
[_untrace_forward_execution] ...... input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... resid_attn_dropout - Dropout
[_untrace_forward_execution] ...... resid_mlp_dropout - Dropout
[_untrace_forward_execution] .... layers[1] - Phi3DecoderLayer
[_untrace_forward_execution] ...... self_attn - Phi3Attention
[_untrace_forward_execution] ........ o_proj - Linear
[_untrace_forward_execution] ........ qkv_proj - Linear
[_untrace_forward_execution] ...... mlp - Phi3MLP
[_untrace_forward_execution] ........ gate_up_proj - Linear
[_untrace_forward_execution] ........ down_proj - Linear
[_untrace_forward_execution] ........ activation_fn - SiLU
[_untrace_forward_execution] ...... input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... resid_attn_dropout - Dropout
[_untrace_forward_execution] ...... resid_mlp_dropout - Dropout
[_untrace_forward_execution] .... norm - Phi3RMSNorm
[_untrace_forward_execution] .... rotary_emb - Phi3RotaryEmbedding
[_untrace_forward_execution] .. lm_head - Linear

Now that every input/output of the submodules is kept in memory, we can guess the dynamic shapes for every one of them. The final ones:

dynamic_shapes = diag.guess_dynamic_shapes()
print("The dynamic shapes are:")
pprint.pprint(dynamic_shapes)
The dynamic shapes are:
((),
 {'attention_mask': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'input_ids': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'past_key_values': [[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
                       {0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}],
                      [{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
                       {0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}]]})

And the dynamic shapes along all the traced submodules:

print(
    diag.pretty_text(
        with_dynamic_shape=True,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    ).replace("<_DimHint.DYNAMIC: 3>", "DYN")
)
>>> __main__: Phi3ForCausalLM
  DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'input_ids': {0: DYN, 1: DYN}, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]]})
    >>> model: Phi3Model
      DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'cache_position': None, 'input_ids': {0: DYN, 1: DYN}, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_ids': None, 'return_dict': None, 'use_cache': None})
        >>> embed_tokens: Embedding: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> layers[0]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
          DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
            >>> self_attn: Phi3Attention
              DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
                >>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> mlp: Phi3MLP
              DS=(({0: DYN, 1: DYN},), {})
                >>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
                >>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
            >>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
        <<<
        >>> norm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
        >>> rotary_emb: Phi3RotaryEmbedding: DS=(({0: DYN, 1: DYN}, {1: DYN}), {}) <<<
    <<<
    >>> lm_head: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
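
The guessed structure follows the (args, kwargs) convention of the traced calls. As a hedged sketch, the direct export call would look as follows; it is expected to fail here on the DynamicCache input, which is exactly what the next section diagnoses module by module:

ds_args, ds_kwargs = dynamic_shapes
try:
    torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds_kwargs, strict=False)
except Exception as e:
    print("export fails:", type(e).__name__)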

Evaluate the export

In many cases, the export (to torch.fx.Graph, to ONNX) does not work on the first try. We need a way to understand how much of the model can be exported, which also tells us how much code needs to be rewritten or patched to make it exportable. The verbosity can be increased to show dynamic shapes, results, and discrepancies. Let's display the module and its submodules first.

print(
    diag.pretty_text(
        with_dynamic_shape=False,
        with_shape=False,
        with_min_max=False,
        with_device=False,
        with_inputs=False,
    )
)
>>> __main__: Phi3ForCausalLM
    >>> model: Phi3Model
        >>> embed_tokens: Embedding <<<
        >>> layers[0]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> layers[1]: Phi3DecoderLayer
            >>> self_attn: Phi3Attention
                >>> o_proj: Linear <<<
                >>> qkv_proj: Linear <<<
            <<<
            >>> mlp: Phi3MLP
                >>> gate_up_proj: Linear <<<
                >>> down_proj: Linear <<<
                >>> activation_fn: SiLU <<<
            <<<
            >>> input_layernorm: Phi3RMSNorm <<<
            >>> post_attention_layernorm: Phi3RMSNorm <<<
            >>> resid_attn_dropout: Dropout <<<
            >>> resid_mlp_dropout: Dropout <<<
        <<<
        >>> norm: Phi3RMSNorm <<<
        >>> rotary_emb: Phi3RotaryEmbedding <<<
    <<<
    >>> lm_head: Linear <<<
<<<

Then we try to export to see which submodule makes the whole model fail. We can pickle the failing submodule and restore it to speed up the refactoring needed to make it work.
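A hedged sketch of that pickling trick, where failing_module and captured_args are hypothetical names standing for the submodule reported as failing and the inputs the tracer recorded for it:

import pickle

# hypothetical names: failing_module / captured_args are placeholders
with open("failing_module.pkl", "wb") as f:
    pickle.dump((failing_module, captured_args), f)

# later, restore it without rebuilding the whole model
with open("failing_module.pkl", "rb") as f:
    failing_module, captured_args = pickle.load(f)
torch.export.export(failing_module, captured_args, strict=False)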

print("----------------------")
ep = diag.try_export(
    exporter="fx",
    use_dynamic_shapes=True,
    exporter_kwargs=dict(strict=False),
    verbose=1,
)
print(f"success: {ep.status}")
print(diag.get_export_report())
----------------------

[try_export-FX]  __main__ - Phi3ForCausalLM --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
[try_export-FX] .. model - Phi3Model --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
[try_export-FX] .... embed_tokens - Embedding --- OK
[try_export-FX] .... layers[0] - Phi3DecoderLayer --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ...... self_attn - Phi3Attention --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ........ o_proj - Linear --- OK
[try_export-FX] ........ qkv_proj - Linear --- OK
[try_export-FX] ...... mlp - Phi3MLP --- OK
[try_export-FX] ...... input_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... post_attention_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... resid_attn_dropout - Dropout --- OK
[try_export-FX] ...... resid_mlp_dropout - Dropout --- OK
[try_export-FX] .... layers[1] - Phi3DecoderLayer --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ...... self_attn - Phi3Attention --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ........ o_proj - Linear --- OK
[try_export-FX] ........ qkv_proj - Linear --- OK
[try_export-FX] ...... mlp - Phi3MLP --- OK
[try_export-FX] ...... input_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... post_attention_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... resid_attn_dropout - Dropout --- OK
[try_export-FX] ...... resid_mlp_dropout - Dropout --- OK
[try_export-FX] .... norm - Phi3RMSNorm --- OK
[try_export-FX] .... rotary_emb - Phi3RotaryEmbedding --- FAIL, step=EXPORT, reason=Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)
[try_export-FX] .... rotary_emb - Phi3RotaryEmbedding --- FAIL
[try_export-FX] .. lm_head - Linear --- OK
success: 2
__main__                         Phi3ForCausalLM       FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
..model                          Phi3Model             FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
....embed_tokens                 Embedding             OK -- ExportedProgram
....layers[0]                    Phi3DecoderLayer      FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
......self_attn                  Phi3Attention         FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
........o_proj                   Linear                OK -- ExportedProgram
........qkv_proj                 Linear                OK -- ExportedProgram
......mlp                        Phi3MLP               OK -- ExportedProgram
........gate_up_proj             Linear                OK as part of its owner
........down_proj                Linear                OK as part of its owner
........activation_fn            SiLU                  OK as part of its owner
......input_layernorm            Phi3RMSNorm           OK -- ExportedProgram
......post_attention_layernorm   Phi3RMSNorm           OK -- ExportedProgram
......resid_attn_dropout         Dropout               OK -- ExportedProgram
......resid_mlp_dropout          Dropout               OK -- ExportedProgram
....layers[1]                    Phi3DecoderLayer      FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
......self_attn                  Phi3Attention         FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
........o_proj                   Linear                OK -- ExportedProgram
........qkv_proj                 Linear                OK -- ExportedProgram
......mlp                        Phi3MLP               OK -- ExportedProgram
........gate_up_proj             Linear                OK as part of its owner
........down_proj                Linear                OK as part of its owner
........activation_fn            SiLU                  OK as part of its owner
......input_layernorm            Phi3RMSNorm           OK -- ExportedProgram
......post_attention_layernorm   Phi3RMSNorm           OK -- ExportedProgram
......resid_attn_dropout         Dropout               OK -- ExportedProgram
......resid_mlp_dropout          Dropout               OK -- ExportedProgram
....norm                         Phi3RMSNorm           OK -- ExportedProgram
....rotary_emb                   Phi3RotaryEmbedding   FAIL -- step=EXPORT, reason='Could not guard on data-depend...'
..lm_head                        Linear                OK -- ExportedProgram
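
The recurring FAIL reason means torch.export does not know how to flatten a DynamicCache into tensors matching the given shapes. One known workaround, not pursued here, is to register the class as a pytree node; a hedged sketch (recent transformers releases already ship such a registration, in which case this call raises because the class is registered twice):

import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache

def flatten_cache(cache):
    # children first, then a context (none needed here)
    return [cache.key_cache, cache.value_cache], None

def unflatten_cache(children, _context):
    cache = DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache

pytree.register_pytree_node(DynamicCache, flatten_cache, unflatten_cache)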

Replace the failing module by a custom op

The main module is not exportable because one of its pieces cannot be exported. But if we assume that piece works, maybe everything else does too. So let's try to replace this class with a custom op. This will be the topic of another example.
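
While the full recipe is deferred to that example, a minimal hedged sketch of the mechanism (the op name mylib::opaque_block and its body are invented) shows how a custom op hides a non-exportable computation from the exporter:

import torch

@torch.library.custom_op("mylib::opaque_block", mutates_args=())
def opaque_block(x: torch.Tensor) -> torch.Tensor:
    # eager path: the original, non-exportable code would run here
    return x.cos()

@opaque_block.register_fake
def _(x):
    # shape/dtype propagation used at export time instead of the real code
    return torch.empty_like(x)

class Wrapped(torch.nn.Module):
    def forward(self, x):
        return torch.ops.mylib.opaque_block(x)

ep = torch.export.export(Wrapped(), (torch.randn(2, 3),))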

Total running time of the script: (0 minutes 9.513 seconds)

Related examples

Export Phi-3.5-mini-instruct with report_exportability

Export Phi-3.5-mini-instruct with draft_export

torch.onnx.export and Phi-2