Export Phi-3.5-mini-instruct piece by piece
torch.export.export()
often breaks on big models because control flow or certain instructions
break the propagation of dynamic shapes (see …). The function usually
gives an indication of where the model implementation can be fixed, but
in case that is not possible, we can try to export the model piece by
piece: every module is converted separately from its submodules. A model
can be exported even if one of its submodules cannot.
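To make the idea concrete, here is a minimal sketch on a toy model (not Phi-3.5): when exporting the full model fails, each submodule can still be exported on its own with its captured inputs.
import torch

class Sub(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

class Big(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x) + 1

big = Big()
x = torch.randn(2, 3)
# if ``torch.export.export(big, (x,))`` failed, the submodule
# can still be tried alone with the inputs it received
ep_sub = torch.export.export(big.sub, (x,))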
Model
import pprint
from typing import Any, Dict
import torch
import torch._export.tools
import transformers
from experimental_experiment.helpers import string_type
from experimental_experiment.torch_interpreter.piece_by_piece import (
trace_execution_piece_by_piece,
)
def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
"""
Gets an uninitialized model with two sets of inputs of different shapes.
:param batch_size: batch size
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
:return: dictionary
See `Phi-3.5-mini-instruct/config.json
<https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
"""
config = {
"_name_or_path": "Phi-3.5-mini-instruct",
"architectures": ["Phi3ForCausalLM"],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1.0800000429153442,
1.1100000143051147,
1.1399999856948853,
1.340000033378601,
1.5899999141693115,
1.600000023841858,
1.6200000047683716,
2.620000123977661,
3.2300000190734863,
3.2300000190734863,
4.789999961853027,
7.400000095367432,
7.700000286102295,
9.09000015258789,
12.199999809265137,
17.670000076293945,
24.46000099182129,
28.57000160217285,
30.420001983642578,
30.840002059936523,
32.590003967285156,
32.93000411987305,
42.320003509521484,
44.96000289916992,
50.340003967285156,
50.45000457763672,
57.55000305175781,
57.93000411987305,
58.21000289916992,
60.1400032043457,
62.61000442504883,
62.62000274658203,
62.71000289916992,
63.1400032043457,
63.1400032043457,
63.77000427246094,
63.93000411987305,
63.96000289916992,
63.970001220703125,
64.02999877929688,
64.06999969482422,
64.08000183105469,
64.12000274658203,
64.41000366210938,
64.4800033569336,
64.51000213623047,
64.52999877929688,
64.83999633789062,
],
"short_factor": [
1.0,
1.0199999809265137,
1.0299999713897705,
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0699999332427979,
1.0999999046325684,
1.1099998950958252,
1.1599998474121094,
1.1599998474121094,
1.1699998378753662,
1.2899998426437378,
1.339999794960022,
1.679999828338623,
1.7899998426437378,
1.8199998140335083,
1.8499997854232788,
1.8799997568130493,
1.9099997282028198,
1.9399996995925903,
1.9899996519088745,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0799996852874756,
2.0899996757507324,
2.189999580383301,
2.2199995517730713,
2.5899994373321533,
2.729999542236328,
2.749999523162842,
2.8399994373321533,
],
"type": "longrope",
},
"rope_theta": 10000.0,
"sliding_window": 262144,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"use_cache": True,
"attention_bias": False,
"vocab_size": 32064,
}
config.update(**kwargs)
conf = transformers.Phi3Config(**config)
model = transformers.Phi3ForCausalLM(conf)
model.eval()
cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache.update(
torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96), i
)
cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache2.update(
torch.randn(batch_size + 1, 32, 31, 96),
torch.randn(batch_size + 1, 32, 31, 96),
i,
)
inputs = dict(
input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
past_key_values=cache,
)
inputs2 = dict(
input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
past_key_values=cache2,
)
return dict(inputs=inputs, model=model, inputs2=inputs2)
data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
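In this notation, ``T7s2x3`` stands for an int64 tensor of shape 2x3 and ``T1`` for a float32 tensor (the digit follows onnx element types, 1 for float32, 7 for int64); ``DynamicCache(...)`` lists the tensors stored in the cache.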
Dynamic Shapes
We want to infer the dynamic shapes from the two sets of inputs we built. For that, we use a function that traces the execution of the model, including its submodules. It runs the model twice with the two sets of inputs and stores every intermediate input and output.
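The trace below is produced by a call of this form (the verbosity level is an assumption):
diag = trace_execution_piece_by_piece(model, [inputs, inputs2], verbose=2)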
[_trace_forward_execution] __main__ - Phi3ForCausalLM
[_trace_forward_execution] .. model - Phi3Model
[_trace_forward_execution] .. lm_head - Linear
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model] > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,return_dict:bool,cache_position:None)
[model:Phi3Model] < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear] > T1r3
[lm_head:Linear] < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_execution_piece_by_piece] run with dict(args:(),kwargs:dict(input_ids:T7s3x4,attention_mask:T7s3x35,past_key_values:DynamicCache(key_cache=#2[T1s3x32x31x96,T1s3x32x31x96], value_cache=#2[T1s3x32x31x96,T1s3x32x31x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model] > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:bool,output_hidden_states:bool,return_dict:bool,cache_position:None)
[model:Phi3Model] < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear] > T1r3
[lm_head:Linear] < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_forward_execution] traced execution of model Phi3ForCausalLM
>>> __main__: Phi3ForCausalLM
> ((),dict(input_ids:CT7s2x3[4664,31570:A17677.166666666668],attention_mask:CT7s2x33[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]])))
> ((),dict(input_ids:CT7s3x4[608,30304:A20014.666666666668],attention_mask:CT7s3x35[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]])))
>>> model: Phi3Model
> ((),dict(input_ids:CT7s2x3[4664,31570:A17677.166666666668],attention_mask:CT7s2x33[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,return_dict:bool=True,cache_position:None))
> ((),dict(input_ids:CT7s3x4[608,30304:A20014.666666666668],attention_mask:CT7s3x35[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),inputs_embeds:None,use_cache:None,output_attentions:bool=False,output_hidden_states:bool=False,return_dict:bool=True,cache_position:None))
>>> embed_tokens: Embedding
> ((CT7s2x3[4664,31570:A17677.166666666668],),{})
> ((CT7s3x4[608,30304:A20014.666666666668],),{})
< (CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],)
< (CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],)
<<<
>>> layers[0]: Phi3DecoderLayer
> ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> self_attn: Phi3Attention
> ((),dict(hidden_states:CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.524319648742676,4.446798324584961:A0.00045197112875037246],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x30x96[-4.324870586395264,4.5727949142456055:A-0.0008027056451969537],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((),dict(hidden_states:CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.5164971351623535,4.401163578033447:A-0.0006134297668116055],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x31x96[-4.715044021606445,4.725687026977539:A0.0017272961305985642],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> o_proj: Linear
> ((CT1s2x3x3072[-2.0857884883880615,2.227522850036621:A-0.00166854337468071],),{})
> ((CT1s3x4x3072[-1.9925907850265503,2.740478515625:A-4.3969838686671224e-05],),{})
< (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],)
< (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],)
<<<
>>> qkv_proj: Linear
> ((CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],),{})
> ((CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],),{})
< (CT1s2x3x9216[-4.55996561050415,4.785744667053223:A-0.0007797792564158752],)
< (CT1s3x4x9216[-5.212952136993408,4.804689407348633:A0.003576188889704634],)
<<<
< (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],None)
< (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],None)
<<<
>>> mlp: Phi3MLP
> ((CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],),{})
> ((CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],),{})
>>> gate_up_proj: Linear
> ((CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],),{})
> ((CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],),{})
< (CT1s2x3x16384[-4.626908779144287,4.744266033172607:A0.0004144351753060012],)
< (CT1s3x4x16384[-5.071014404296875,5.014136791229248:A0.002499560848283622],)
<<<
>>> down_proj: Linear
> ((CT1s2x3x8192[-9.626221656799316,10.6359224319458:A-0.003712694137561807],),{})
> ((CT1s3x4x8192[-10.604706764221191,13.27551555633545:A-0.0009531583248854041],),{})
< (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
< (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
<<<
>>> activation_fn: SiLU
> ((CT1s2x3x8192[-4.296870231628418,4.744266033172607:A-0.00018980784703141276],),{})
> ((CT1s3x4x8192[-5.071014404296875,5.014136791229248:A-0.002123918648494557],),{})
< (CT1s2x3x8192[-0.27846455574035645,4.70334005355835:A0.2433877349765435],)
< (CT1s3x4x8192[-0.27846455574035645,4.981045722961426:A0.24466585338180846],)
<<<
< (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
< (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
<<<
>>> input_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],),{})
> ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],),{})
< (CT1s2x3x3072[-3.8076090812683105,3.682769775390625:A0.00216633024901035],)
< (CT1s3x4x3072[-3.9569787979125977,4.13375186920166:A0.0011474127630010224],)
<<<
>>> post_attention_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-1.716892957687378,1.7488868236541748:A-0.00032424742236179956],),{})
> ((CT1s3x4x3072[-1.5904587507247925,1.5101318359375:A-0.0038784623282745023],),{})
< (CT1s2x3x3072[-3.766619920730591,3.8079004287719727:A-0.0007900578330600884],)
< (CT1s3x4x3072[-4.20332670211792,4.075958251953125:A-0.010097015604570991],)
<<<
>>> resid_attn_dropout: Dropout
> ((CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],),{})
> ((CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],),{})
< (CT1s2x3x3072[-1.7465921640396118,1.7467150688171387:A-0.00036754129455908295],)
< (CT1s3x4x3072[-1.576282024383545,1.5121502876281738:A-0.0039008252983229064],)
<<<
>>> resid_mlp_dropout: Dropout
> ((CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],),{})
> ((CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],),{})
< (CT1s2x3x3072[-5.091585636138916,5.638833045959473:A0.013511776874464785],)
< (CT1s3x4x3072[-5.19716215133667,5.227738857269287:A-0.003959400700043463],)
<<<
< (CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],)
< (CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],)
<<<
>>> layers[1]: Phi3DecoderLayer
> ((CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> self_attn: Phi3Attention
> ((),dict(hidden_states:CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x30x96[-4.23585319519043,4.715499401092529:A-0.005312503596610355]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x30x96[-4.677217960357666,5.006890296936035:A0.0019749061087091166]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((),dict(hidden_states:CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x31x96[-4.406899452209473,4.658656597137451:A-0.0018093928141991926]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x31x96[-5.0299248695373535,5.066045761108398:A0.0021334321277751466]]),output_attentions:bool=False,use_cache:bool=True,cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> o_proj: Linear
> ((CT1s2x3x3072[-2.0965793132781982,2.4613635540008545:A0.001581031349573831],),{})
> ((CT1s3x4x3072[-2.2842745780944824,2.885471820831299:A-0.00010379927480992067],),{})
< (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],)
< (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],)
<<<
>>> qkv_proj: Linear
> ((CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],),{})
> ((CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],),{})
< (CT1s2x3x9216[-4.554075717926025,4.405378818511963:A-0.004489510061266292],)
< (CT1s3x4x9216[-4.208725929260254,5.463718891143799:A0.004831148010973047],)
<<<
< (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],None)
< (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],None)
<<<
>>> mlp: Phi3MLP
> ((CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],),{})
> ((CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],),{})
>>> gate_up_proj: Linear
> ((CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],),{})
> ((CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],),{})
< (CT1s2x3x16384[-4.71592903137207,4.544133186340332:A0.0010123075836598143],)
< (CT1s3x4x16384[-5.153237819671631,5.296940803527832:A0.0010891516607924128],)
<<<
>>> down_proj: Linear
> ((CT1s2x3x8192[-9.733989715576172,9.656818389892578:A-0.00035077506349431273],),{})
> ((CT1s3x4x8192[-10.086795806884766,11.529891014099121:A0.004240172249181878],),{})
< (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
< (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
<<<
>>> activation_fn: SiLU
> ((CT1s2x3x8192[-4.71592903137207,4.544133186340332:A0.0006271271668936151],),{})
> ((CT1s3x4x8192[-4.761664867401123,4.957876682281494:A-0.002479409858256195],),{})
< (CT1s2x3x8192[-0.27846455574035645,4.496339797973633:A0.24618416685490122],)
< (CT1s3x4x8192[-0.27846455574035645,4.923276424407959:A0.2438934481976652],)
<<<
< (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
< (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
<<<
>>> input_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-5.3348307609558105,6.06136417388916:A0.013187529813725027],),{})
> ((CT1s3x4x3072[-5.88392448425293,5.575860977172852:A-0.00783786298630831],),{})
< (CT1s2x3x3072[-3.71167254447937,4.351285934448242:A0.009349946014120251],)
< (CT1s3x4x3072[-4.0989861488342285,3.8373234272003174:A-0.005813611555114445],)
<<<
>>> post_attention_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-5.136703968048096,5.932700157165527:A0.010719243169357165],),{})
> ((CT1s3x4x3072[-5.9217209815979,6.042370319366455:A-0.006805091526985052],),{})
< (CT1s2x3x3072[-3.549666166305542,4.110150337219238:A0.007339953932229193],)
< (CT1s3x4x3072[-4.205822467803955,4.059736251831055:A-0.004850264546526262],)
<<<
>>> resid_attn_dropout: Dropout
> ((CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],),{})
> ((CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],),{})
< (CT1s2x3x3072[-1.727698802947998,1.9539368152618408:A-0.0024682869261406872],)
< (CT1s3x4x3072[-1.7709965705871582,1.708268642425537:A0.0010327714551263195],)
<<<
>>> resid_mlp_dropout: Dropout
> ((CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],),{})
> ((CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],),{})
< (CT1s2x3x3072[-5.305871963500977,5.178980827331543:A0.0036938989775308073],)
< (CT1s3x4x3072[-6.050964832305908,5.4576826095581055:A0.0023353122698810897],)
<<<
< (CT1s2x3x3072[-8.349817276000977,9.60705280303955:A0.014413142231104657],)
< (CT1s3x4x3072[-7.1743621826171875,7.833809852600098:A-0.004469779496957926],)
<<<
>>> norm: Phi3RMSNorm
> ((CT1s2x3x3072[-8.349817276000977,9.60705280303955:A0.014413142231104657],),{})
> ((CT1s3x4x3072[-7.1743621826171875,7.833809852600098:A-0.004469779496957926],),{})
< (CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],)
< (CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],)
<<<
>>> rotary_emb: Phi3RotaryEmbedding
> ((CT1s2x3x3072[-0.0765041932463646,0.07427593320608139:A4.3293902790490314e-05],CT7s1x3[30,32:A31.0]),{})
> ((CT1s3x4x3072[-0.07915466278791428,0.08391024172306061:A2.2362938498739677e-05],CT7s1x4[31,34:A32.5]),{})
< (CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])
< (CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])
<<<
< (dict(last_hidden_state:CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x33x96[-5.159628391265869,5.257333278656006:A-0.003930519807112116]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x33x96[-4.677217960357666,5.006890296936035:A0.0021848251059774755]])),)
< (dict(last_hidden_state:CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x35x96[-5.084636211395264,4.978662967681885:A-0.002626009440932474]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x35x96[-5.0299248695373535,5.066045761108398:A0.001982277264531371]])),)
<<<
>>> lm_head: Linear
> ((CT1s2x3x3072[-4.329084396362305,4.684192657470703:A0.0072901403956184335],),{})
> ((CT1s3x4x3072[-3.6584055423736572,3.9267525672912598:A-0.002269768673042044],),{})
< (CT1s2x3x32064[-4.806385040283203,4.561767101287842:A-7.084682205524456e-05],)
< (CT1s3x4x32064[-4.770487308502197,5.134904384613037:A0.0011637009553284994],)
<<<
< (dict(logits:CT1s2x3x32064[-4.806385040283203,4.561767101287842:A-7.084682205524456e-05],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-5.269703388214111,5.904300689697266:A-0.0008714742347480582],CT1s2x32x33x96[-5.159628391265869,5.257333278656006:A-0.003930519807112116]], value_cache=#2[CT1s2x32x33x96[-4.365392684936523,4.5727949142456055:A-0.0001598371759914454],CT1s2x32x33x96[-4.677217960357666,5.006890296936035:A0.0021848251059774755]])),)
< (dict(logits:CT1s3x4x32064[-4.770487308502197,5.134904384613037:A0.0011637009553284994],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-5.257526397705078,6.271761417388916:A5.3596344517194334e-05],CT1s3x32x35x96[-5.084636211395264,4.978662967681885:A-0.002626009440932474]], value_cache=#2[CT1s3x32x35x96[-4.715044021606445,4.725687026977539:A0.0016024406523245105],CT1s3x32x35x96[-5.0299248695373535,5.066045761108398:A0.001982277264531371]])),)
<<<
[_untrace_forward_execution] __main__ - Phi3ForCausalLM
[_untrace_forward_execution] .. model - Phi3Model
[_untrace_forward_execution] .... embed_tokens - Embedding
[_untrace_forward_execution] .... layers[0] - Phi3DecoderLayer
[_untrace_forward_execution] ...... self_attn - Phi3Attention
[_untrace_forward_execution] ........ o_proj - Linear
[_untrace_forward_execution] ........ qkv_proj - Linear
[_untrace_forward_execution] ...... mlp - Phi3MLP
[_untrace_forward_execution] ........ gate_up_proj - Linear
[_untrace_forward_execution] ........ down_proj - Linear
[_untrace_forward_execution] ........ activation_fn - SiLU
[_untrace_forward_execution] ...... input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... resid_attn_dropout - Dropout
[_untrace_forward_execution] ...... resid_mlp_dropout - Dropout
[_untrace_forward_execution] .... layers[1] - Phi3DecoderLayer
[_untrace_forward_execution] ...... self_attn - Phi3Attention
[_untrace_forward_execution] ........ o_proj - Linear
[_untrace_forward_execution] ........ qkv_proj - Linear
[_untrace_forward_execution] ...... mlp - Phi3MLP
[_untrace_forward_execution] ........ gate_up_proj - Linear
[_untrace_forward_execution] ........ down_proj - Linear
[_untrace_forward_execution] ........ activation_fn - SiLU
[_untrace_forward_execution] ...... input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ...... resid_attn_dropout - Dropout
[_untrace_forward_execution] ...... resid_mlp_dropout - Dropout
[_untrace_forward_execution] .... norm - Phi3RMSNorm
[_untrace_forward_execution] .... rotary_emb - Phi3RotaryEmbedding
[_untrace_forward_execution] .. lm_head - Linear
Now that every input/output of the submodules is kept in memory, we can guess the dynamic shapes for each of them. The final ones:
dynamic_shapes = diag.guess_dynamic_shapes()
print("The dynamic shapes are:")
pprint.pprint(dynamic_shapes)
The dynamic shapes are:
((),
{'attention_mask': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
'input_ids': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
'past_key_values': [[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}],
[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}]]})
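The ``_DimHint.DYNAMIC`` markers are ``torch.export.Dim.DYNAMIC``. For a single submodule, the same specification could be written by hand. A minimal sketch for a standalone ``Linear`` shaped like ``lm_head`` (sizes taken from the configuration above, not captured from the trace):
import torch

DYN = torch.export.Dim.DYNAMIC

lm_head = torch.nn.Linear(3072, 32064, bias=False)
x = torch.randn(2, 3, 3072)
# batch and sequence dimensions are dynamic, the hidden dimension stays fixed
ep = torch.export.export(lm_head, (x,), dynamic_shapes=({0: DYN, 1: DYN},))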
Here are the dynamic shapes along all the traced submodules:
print(
diag.pretty_text(
with_dynamic_shape=True,
with_shape=False,
with_min_max=False,
with_device=False,
with_inputs=False,
).replace("<_DimHint.DYNAMIC: 3>", "DYN")
)
>>> __main__: Phi3ForCausalLM
DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'input_ids': {0: DYN, 1: DYN}, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]]})
>>> model: Phi3Model
DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'cache_position': None, 'input_ids': {0: DYN, 1: DYN}, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_ids': None, 'return_dict': None, 'use_cache': None})
>>> embed_tokens: Embedding: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> layers[0]: Phi3DecoderLayer
DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> self_attn: Phi3Attention
DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> mlp: Phi3MLP
DS=(({0: DYN, 1: DYN},), {})
>>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> layers[1]: Phi3DecoderLayer
DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> self_attn: Phi3Attention
DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> mlp: Phi3MLP
DS=(({0: DYN, 1: DYN},), {})
>>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> norm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> rotary_emb: Phi3RotaryEmbedding: DS=(({0: DYN, 1: DYN}, {1: DYN}), {}) <<<
<<<
>>> lm_head: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
Evaluate the export
In many cases, the export (to torch.fx.Graph, to ONNX)
does not work on the first try. We need a way to understand
how much of the model can be exported. That helps estimate
how much code needs to be rewritten or patched to make it exportable.
The verbosity can be increased to show dynamic shapes and
discrepancies between results.
Let's display the module and its submodules first.
print(
diag.pretty_text(
with_dynamic_shape=False,
with_shape=False,
with_min_max=False,
with_device=False,
with_inputs=False,
)
)
>>> __main__: Phi3ForCausalLM
>>> model: Phi3Model
>>> embed_tokens: Embedding <<<
>>> layers[0]: Phi3DecoderLayer
>>> self_attn: Phi3Attention
>>> o_proj: Linear <<<
>>> qkv_proj: Linear <<<
<<<
>>> mlp: Phi3MLP
>>> gate_up_proj: Linear <<<
>>> down_proj: Linear <<<
>>> activation_fn: SiLU <<<
<<<
>>> input_layernorm: Phi3RMSNorm <<<
>>> post_attention_layernorm: Phi3RMSNorm <<<
>>> resid_attn_dropout: Dropout <<<
>>> resid_mlp_dropout: Dropout <<<
<<<
>>> layers[1]: Phi3DecoderLayer
>>> self_attn: Phi3Attention
>>> o_proj: Linear <<<
>>> qkv_proj: Linear <<<
<<<
>>> mlp: Phi3MLP
>>> gate_up_proj: Linear <<<
>>> down_proj: Linear <<<
>>> activation_fn: SiLU <<<
<<<
>>> input_layernorm: Phi3RMSNorm <<<
>>> post_attention_layernorm: Phi3RMSNorm <<<
>>> resid_attn_dropout: Dropout <<<
>>> resid_mlp_dropout: Dropout <<<
<<<
>>> norm: Phi3RMSNorm <<<
>>> rotary_emb: Phi3RotaryEmbedding <<<
<<<
>>> lm_head: Linear <<<
<<<
Then we try the export to see which submodule makes the whole model fail. We can pickle the failing module and restore it to speed up the refactoring needed to make it work.
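A possible way to snapshot a failing piece between two sessions (a sketch; the file name is hypothetical):
# save the failing submodule once...
torch.save(model.model.layers[0], "failing_layer.pt")
# ...and reload it in a fresh session to iterate on the fix faster
layer = torch.load("failing_layer.pt", weights_only=False)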
print("----------------------")
ep = diag.try_export(
exporter="fx",
use_dynamic_shapes=True,
exporter_kwargs=dict(strict=False),
verbose=1,
)
print(f"success: {ep.status}")
print(diag.get_export_report())
----------------------
[try_export-FX] __main__ - Phi3ForCausalLM --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
[try_export-FX] .. model - Phi3Model --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
[try_export-FX] .... embed_tokens - Embedding --- OK
[try_export-FX] .... layers[0] - Phi3DecoderLayer --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ...... self_attn - Phi3Attention --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ........ o_proj - Linear --- OK
[try_export-FX] ........ qkv_proj - Linear --- OK
[try_export-FX] ...... mlp - Phi3MLP --- OK
[try_export-FX] ...... input_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... post_attention_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... resid_attn_dropout - Dropout --- OK
[try_export-FX] ...... resid_mlp_dropout - Dropout --- OK
[try_export-FX] .... layers[1] - Phi3DecoderLayer --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ...... self_attn - Phi3Attention --- FAIL, step=EXPORT, reason=Cannot associate shape [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]] specified at `dynamic_shapes['past_key_value']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_value']` (expected None)
[try_export-FX] ........ o_proj - Linear --- OK
[try_export-FX] ........ qkv_proj - Linear --- OK
[try_export-FX] ...... mlp - Phi3MLP --- OK
[try_export-FX] ...... input_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... post_attention_layernorm - Phi3RMSNorm --- OK
[try_export-FX] ...... resid_attn_dropout - Dropout --- OK
[try_export-FX] ...... resid_mlp_dropout - Dropout --- OK
[try_export-FX] .... norm - Phi3RMSNorm --- OK
[try_export-FX] .... rotary_emb - Phi3RotaryEmbedding --- FAIL, step=EXPORT, reason=Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
[try_export-FX] .... rotary_emb - Phi3RotaryEmbedding --- FAIL
[try_export-FX] .. lm_head - Linear --- OK
success: 2
__main__ Phi3ForCausalLM FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
..model Phi3Model FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
....embed_tokens Embedding OK -- ExportedProgram
....layers[0] Phi3DecoderLayer FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
......self_attn Phi3Attention FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
........o_proj Linear OK -- ExportedProgram
........qkv_proj Linear OK -- ExportedProgram
......mlp Phi3MLP OK -- ExportedProgram
........gate_up_proj Linear OK as part of its owner
........down_proj Linear OK as part of its owner
........activation_fn SiLU OK as part of its owner
......input_layernorm Phi3RMSNorm OK -- ExportedProgram
......post_attention_layernorm Phi3RMSNorm OK -- ExportedProgram
......resid_attn_dropout Dropout OK -- ExportedProgram
......resid_mlp_dropout Dropout OK -- ExportedProgram
....layers[1] Phi3DecoderLayer FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
......self_attn Phi3Attention FAIL -- step=EXPORT, reason='Cannot associate shape [[{0: D...'
........o_proj Linear OK -- ExportedProgram
........qkv_proj Linear OK -- ExportedProgram
......mlp Phi3MLP OK -- ExportedProgram
........gate_up_proj Linear OK as part of its owner
........down_proj Linear OK as part of its owner
........activation_fn SiLU OK as part of its owner
......input_layernorm Phi3RMSNorm OK -- ExportedProgram
......post_attention_layernorm Phi3RMSNorm OK -- ExportedProgram
......resid_attn_dropout Dropout OK -- ExportedProgram
......resid_mlp_dropout Dropout OK -- ExportedProgram
....norm Phi3RMSNorm OK -- ExportedProgram
....rotary_emb Phi3RotaryEmbedding FAIL -- step=EXPORT, reason='Could not guard on data-depend...'
..lm_head Linear OK -- ExportedProgram
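Two patterns explain all the failures: every module receiving a ``DynamicCache`` fails because the guessed dynamic shapes cannot be associated with a non-tensor input, and the rotary embedding fails on the data-dependent guard ``Eq(u0, 1)``.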
Replace the failing module with a custom op
The main module is not exportable because one of its pieces cannot be exported. But if we assume that piece works, everything else might export as well. So let's try to replace this class with a custom op. That will be the topic of another example.
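A minimal sketch of the idea (the op name ``mylib::opaque_piece`` and its body are placeholders, not the mechanism used by this library): the failing piece is hidden behind a custom op with a fake implementation, so the exporter only needs its output shapes.
import torch

@torch.library.custom_op("mylib::opaque_piece", mutates_args=())
def opaque_piece(x: torch.Tensor) -> torch.Tensor:
    # the original, non-exportable implementation would be called here
    return x * 2.0

@opaque_piece.register_fake
def _(x):
    # shape/dtype propagation used by the exporter instead of running the op
    return torch.empty_like(x)

class Wrapped(torch.nn.Module):
    def forward(self, x):
        return opaque_piece(x)

ep = torch.export.export(Wrapped(), (torch.randn(2, 3),))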
Total running time of the script: (0 minutes 9.513 seconds)
Related examples
Export Phi-3.5-mini-instruct with report_exportability
Export Phi-3.5-mini-instruct with draft_export