to_onnx, failures with Phi-3.5-mini-instruct
The example to_onnx and Phi-2 shows how to export a simple LLM model with dynamic shapes. What can be done when it does not work?
Model
import pprint
from typing import Any, Dict
import torch
import transformers
from experimental_experiment.helpers import string_type
from experimental_experiment.torch_interpreter.diagnose import infer_shape_type_from_execution
def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
"""
Gets an uninitialized model with two sets of inputs of different shapes.
:param batch_size: batch size
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
:return: dictionary
See `Phi-3.5-mini-instruct/config.json
<https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
"""
config = {
"_name_or_path": "Phi-3.5-mini-instruct",
"architectures": ["Phi3ForCausalLM"],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1.0800000429153442,
1.1100000143051147,
1.1399999856948853,
1.340000033378601,
1.5899999141693115,
1.600000023841858,
1.6200000047683716,
2.620000123977661,
3.2300000190734863,
3.2300000190734863,
4.789999961853027,
7.400000095367432,
7.700000286102295,
9.09000015258789,
12.199999809265137,
17.670000076293945,
24.46000099182129,
28.57000160217285,
30.420001983642578,
30.840002059936523,
32.590003967285156,
32.93000411987305,
42.320003509521484,
44.96000289916992,
50.340003967285156,
50.45000457763672,
57.55000305175781,
57.93000411987305,
58.21000289916992,
60.1400032043457,
62.61000442504883,
62.62000274658203,
62.71000289916992,
63.1400032043457,
63.1400032043457,
63.77000427246094,
63.93000411987305,
63.96000289916992,
63.970001220703125,
64.02999877929688,
64.06999969482422,
64.08000183105469,
64.12000274658203,
64.41000366210938,
64.4800033569336,
64.51000213623047,
64.52999877929688,
64.83999633789062,
],
"short_factor": [
1.0,
1.0199999809265137,
1.0299999713897705,
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0699999332427979,
1.0999999046325684,
1.1099998950958252,
1.1599998474121094,
1.1599998474121094,
1.1699998378753662,
1.2899998426437378,
1.339999794960022,
1.679999828338623,
1.7899998426437378,
1.8199998140335083,
1.8499997854232788,
1.8799997568130493,
1.9099997282028198,
1.9399996995925903,
1.9899996519088745,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0799996852874756,
2.0899996757507324,
2.189999580383301,
2.2199995517730713,
2.5899994373321533,
2.729999542236328,
2.749999523162842,
2.8399994373321533,
],
"type": "longrope",
},
"rope_theta": 10000.0,
"sliding_window": 262144,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"use_cache": True,
"attention_bias": False,
"vocab_size": 32064,
}
config.update(**kwargs)
conf = transformers.Phi3Config(**config)
model = transformers.Phi3ForCausalLM(conf)
model.eval()
cache = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache.update(
torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96), i
)
cache2 = transformers.cache_utils.DynamicCache(config["num_hidden_layers"])
for i in range(config["num_hidden_layers"]):
cache2.update(
torch.randn(batch_size + 1, 32, 31, 96),
torch.randn(batch_size + 1, 32, 31, 96),
i,
)
inputs = dict(
input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
past_key_values=cache,
)
inputs2 = dict(
input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
past_key_values=cache2,
)
return dict(inputs=inputs, model=model, inputs2=inputs2)
data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
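The notation printed by string_type is compact. The following reading is an interpretation rather than documented behaviour: T<dtype>s<shape> encodes the ONNX element type and the shape, so T7s2x3 is an int64 tensor (TensorProto.INT64 = 7) of shape 2x3, and T1s2x32x30x96 a float32 tensor (TensorProto.FLOAT = 1). In the traces below, a leading C and a [min,max:Amean] suffix add value statistics when with_min_max is enabled. A quick check of that reading:
print(string_type(torch.ones((2, 3), dtype=torch.int64), with_shape=True))
# expected to print T7s2x3 under that reading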
Dynamic Shapes
We want to infer the dynamic shapes from the two sets of inputs we gave. For that, we use a function that traces the execution of the model, including its submodules. It runs the model twice, once with each set of inputs, and stores every intermediate input and output.
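The call producing the trace below does not appear on this page. A plausible invocation, assuming infer_shape_type_from_execution takes the model and the two sets of inputs and returns the diagnostic object diag used in the rest of the example (the exact signature may differ):
diag = infer_shape_type_from_execution(model, inputs, inputs2, verbose=1)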
[_trace_forward_execution] __main__ - Phi3ForCausalLM
[_trace_forward_execution] ..model - Phi3Model
[_trace_forward_execution] ..lm_head - Linear
[infer_shape_type_from_execution] run with dict(args:(),kwargs:dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model] > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:int,output_hidden_states:int,return_dict:int,cache_position:None)
[model:Phi3Model] < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear] > T1r3
[lm_head:Linear] < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[infer_shape_type_from_execution] run with dict(args:(),kwargs:dict(input_ids:T7s3x4,attention_mask:T7s3x35,past_key_values:DynamicCache(key_cache=#2[T1s3x32x31x96,T1s3x32x31x96], value_cache=#2[T1s3x32x31x96,T1s3x32x31x96])))
[__main__:Phi3ForCausalLM] > **dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[model:Phi3Model] > **dict(input_ids:T7r2,attention_mask:T7r2,position_ids:None,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]),inputs_embeds:None,use_cache:None,output_attentions:int,output_hidden_states:int,return_dict:int,cache_position:None)
[model:Phi3Model] < *dict(last_hidden_state:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[lm_head:Linear] > T1r3
[lm_head:Linear] < T1r3
[__main__:Phi3ForCausalLM] < *dict(logits:T1r3,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
[trace_forward_execution] traced execution of model Phi3ForCausalLM
>>> __main__: Phi3ForCausalLM
> ((),dict(input_ids:CT7s2x3[4480,30643:A16788.666666666668],attention_mask:CT7s2x33[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]])))
> ((),dict(input_ids:CT7s3x4[3842,31602:A20204.833333333332],attention_mask:CT7s3x35[1,1:A1.0],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]])))
>>> model: Phi3Model
> ((),dict(input_ids:CT7s2x3[4480,30643:A16788.666666666668],attention_mask:CT7s2x33[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),inputs_embeds:None,use_cache:None,output_attentions:int[False],output_hidden_states:int[False],return_dict:int[True],cache_position:None))
> ((),dict(input_ids:CT7s3x4[3842,31602:A20204.833333333332],attention_mask:CT7s3x35[1,1:A1.0],position_ids:None,past_key_values:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),inputs_embeds:None,use_cache:None,output_attentions:int[False],output_hidden_states:int[False],return_dict:int[True],cache_position:None))
>>> embed_tokens: Embedding
> ((CT7s2x3[4480,30643:A16788.666666666668],),dict())
> ((CT7s3x4[3842,31602:A20204.833333333332],),dict())
< (CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],)
< (CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],)
<<<
>>> layers[0]: Phi3DecoderLayer
> ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> self_attn: Phi3Attention
> ((),dict(hidden_states:CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x30x96[-4.999748229980469,4.596332550048828:A-0.002525397799765299],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x30x96[-4.537422180175781,4.253913402557373:A0.001188100440696327],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((),dict(hidden_states:CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x31x96[-4.707752704620361,4.471818447113037:A0.0017012020740154023],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x31x96[-4.775568962097168,4.293818473815918:A0.00015770170416937328],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> o_proj: Linear
> ((CT1s2x3x3072[-2.397754669189453,1.8615864515304565:A-0.00028552267273683637],),dict())
> ((CT1s3x4x3072[-2.073399305343628,2.3145229816436768:A0.0001120645541599568],),dict())
< (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],)
< (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],)
<<<
>>> qkv_proj: Linear
> ((CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],),dict())
> ((CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],),dict())
< (CT1s2x3x9216[-5.016335487365723,4.414941787719727:A-0.002557053032576193],)
< (CT1s3x4x9216[-5.139618396759033,4.558834552764893:A-0.001959620447271012],)
<<<
< (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],None)
< (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],None)
<<<
>>> mlp: Phi3MLP
> ((CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],),dict())
> ((CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],),dict())
>>> gate_up_proj: Linear
> ((CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],),dict())
> ((CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],),dict())
< (CT1s2x3x16384[-4.869340419769287,5.286735534667969:A-0.00018809297783898651],)
< (CT1s3x4x16384[-4.958648204803467,4.989035606384277:A-0.0033986618244264597],)
<<<
>>> down_proj: Linear
> ((CT1s2x3x8192[-8.716111183166504,8.193013191223145:A-0.0011114813027920483],),dict())
> ((CT1s3x4x8192[-9.956159591674805,11.3091459274292:A0.00025204877698488506],),dict())
< (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
< (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
<<<
>>> activation_fn: SiLU
> ((CT1s2x3x8192[-4.831746578216553,4.276060581207275:A0.004578614053405279],),dict())
> ((CT1s3x4x8192[-4.439509391784668,4.847005367279053:A-0.004492188054667186],),dict())
< (CT1s2x3x8192[-0.27846455574035645,4.217449188232422:A0.24632204136495695],)
< (CT1s3x4x8192[-0.27846455574035645,4.809244155883789:A0.24241726201152414],)
<<<
< (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
< (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
<<<
>>> input_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],),dict())
> ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],),dict())
< (CT1s2x3x3072[-3.6068482398986816,4.1260199546813965:A0.0041334081171583315],)
< (CT1s3x4x3072[-3.7820181846618652,4.1085357666015625:A0.0015415984938473353],)
<<<
>>> post_attention_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-1.7519030570983887,1.521032452583313:A-0.004998811839893986],),dict())
> ((CT1s3x4x3072[-1.5569356679916382,1.542656421661377:A0.0019694551863583204],),dict())
< (CT1s2x3x3072[-4.1014404296875,3.9666833877563477:A-0.012948900215022855],)
< (CT1s3x4x3072[-4.106596946716309,4.0618391036987305:A0.004948446015071421],)
<<<
>>> resid_attn_dropout: Dropout
> ((CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],),dict())
> ((CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],),dict())
< (CT1s2x3x3072[-1.7544070482254028,1.5175296068191528:A-0.0050823212642399085],)
< (CT1s3x4x3072[-1.564698338508606,1.5298670530319214:A0.001937195043135868],)
<<<
>>> resid_mlp_dropout: Dropout
> ((CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],),dict())
> ((CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],),dict())
< (CT1s2x3x3072[-5.0135064125061035,5.036469459533691:A0.010313347476135782],)
< (CT1s3x4x3072[-5.809861183166504,5.423994064331055:A-0.008804775849689072],)
<<<
< (CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],)
< (CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],)
<<<
>>> layers[1]: Phi3DecoderLayer
> ((CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],),dict(attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],),dict(attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> self_attn: Phi3Attention
> ((),dict(hidden_states:CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],attention_mask:CT1s2x1x3x33[-3.4028234663852886e+38,-0.0:A-1.0311586261773601e+37],position_ids:CT7s1x3[30,32:A31.0],past_key_value:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x30x96[-4.244330883026123,4.201772212982178:A-0.0015612097777758596]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x30x96[-4.228041648864746,4.426188945770264:A0.001298994530771347]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s3[30,32:A31.0],position_embeddings:(CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])))
> ((),dict(hidden_states:CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],attention_mask:CT1s3x1x4x35[-3.4028234663852886e+38,-0.0:A-1.4583529141651237e+37],position_ids:CT7s1x4[31,34:A32.5],past_key_value:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x31x96[-5.308402061462402,4.379249572753906:A-0.00011979189954858169]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x31x96[-4.303168773651123,4.225716590881348:A0.0009151201830652209]]),output_attentions:int[False],use_cache:int[True],cache_position:CT7s4[31,34:A32.5],position_embeddings:(CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])))
>>> o_proj: Linear
> ((CT1s2x3x3072[-2.522385597229004,2.634071111679077:A0.005504650073209956],),dict())
> ((CT1s3x4x3072[-2.916518211364746,2.487565040588379:A0.001905495911235554],),dict())
< (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],)
< (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],)
<<<
>>> qkv_proj: Linear
> ((CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],),dict())
> ((CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],),dict())
< (CT1s2x3x9216[-5.106101989746094,4.490701675415039:A-0.002418912901332617],)
< (CT1s3x4x9216[-4.6034040451049805,4.716505527496338:A-0.0005003103740501012],)
<<<
< (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],None)
< (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],None)
<<<
>>> mlp: Phi3MLP
> ((CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],),dict())
> ((CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],),dict())
>>> gate_up_proj: Linear
> ((CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],),dict())
> ((CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],),dict())
< (CT1s2x3x16384[-4.767377853393555,4.763617992401123:A-0.0011706767508409637],)
< (CT1s3x4x16384[-4.816030025482178,5.136754989624023:A0.00020893273096452467],)
<<<
>>> down_proj: Linear
> ((CT1s2x3x8192[-10.276310920715332,8.63280963897705:A0.007607944792103967],),dict())
> ((CT1s3x4x8192[-9.172404289245605,10.50564956665039:A-0.0007796187043171565],),dict())
< (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
< (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
<<<
>>> activation_fn: SiLU
> ((CT1s2x3x8192[-4.767377853393555,4.6629109382629395:A-0.0019218210885677915],),dict())
> ((CT1s3x4x8192[-4.816030025482178,4.801560878753662:A0.0034934405888857136],),dict())
< (CT1s2x3x8192[-0.27846455574035645,4.619309425354004:A0.2432932592988779],)
< (CT1s3x4x8192[-0.27846455574035645,4.7624287605285645:A0.24843786659936437],)
<<<
< (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
< (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
<<<
>>> input_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-5.366197109222412,5.397191524505615:A0.0053145358553794925],),dict())
> ((CT1s3x4x3072[-6.367547035217285,5.992779731750488:A-0.006835320722656333],),dict())
< (CT1s2x3x3072[-3.7633559703826904,3.8344295024871826:A0.003907856796488652],)
< (CT1s3x4x3072[-4.528147220611572,4.306519031524658:A-0.0049134756377199945],)
<<<
>>> post_attention_layernorm: Phi3RMSNorm
> ((CT1s2x3x3072[-5.602095603942871,5.5913519859313965:A0.00028517031084144645],),dict())
> ((CT1s3x4x3072[-5.972293376922607,6.098865509033203:A-0.008048094156871835],),dict())
< (CT1s2x3x3072[-3.7263343334198,3.9119250774383545:A0.0003562870679852084],)
< (CT1s3x4x3072[-4.118927955627441,4.210121154785156:A-0.005571656055501102],)
<<<
>>> resid_attn_dropout: Dropout
> ((CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],),dict())
> ((CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],),dict())
< (CT1s2x3x3072[-1.6478346586227417,1.5206420421600342:A-0.005029365316684359],)
< (CT1s3x4x3072[-1.4533982276916504,1.5127313137054443:A-0.0012127736726231181],)
<<<
>>> resid_mlp_dropout: Dropout
> ((CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],),dict())
> ((CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],),dict())
< (CT1s2x3x3072[-5.244140148162842,5.897292137145996:A-0.0020538516822499434],)
< (CT1s3x4x3072[-5.963540077209473,5.767940998077393:A0.003957682628348873],)
<<<
< (CT1s2x3x3072[-8.54723834991455,7.845819473266602:A-0.0017686810511400432],)
< (CT1s3x4x3072[-9.078064918518066,8.243967056274414:A-0.004090411359306422],)
<<<
>>> norm: Phi3RMSNorm
> ((CT1s2x3x3072[-8.54723834991455,7.845819473266602:A-0.0017686810511400432],),dict())
> ((CT1s3x4x3072[-9.078064918518066,8.243967056274414:A-0.004090411359306422],),dict())
< (CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],)
< (CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],)
<<<
>>> rotary_emb: Phi3RotaryEmbedding
> ((CT1s2x3x3072[-0.07310392707586288,0.0832141563296318:A8.350938669193169e-05],CT7s1x3[30,32:A31.0]),dict())
> ((CT1s3x4x3072[-0.07719504088163376,0.08286085724830627:A3.226016894325445e-05],CT7s1x4[31,34:A32.5]),dict())
< (CT1s1x3x96[-1.1855769157409668,1.1902371644973755:A0.746652018013669],CT1s1x3x96[-1.1887905597686768,1.190193772315979:A0.1589894221542636])
< (CT1s1x4x96[-1.1855769157409668,1.190237045288086:A0.7129333875218435],CT1s1x4x96[-1.1719439029693604,1.1902378797531128:A0.18296290554159592])
<<<
< (dict(last_hidden_state:CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x33x96[-6.076218128204346,4.901590347290039:A-0.000825817244353194]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x33x96[-4.364691257476807,4.490701675415039:A0.0002260386739006028]])),)
< (dict(last_hidden_state:CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x35x96[-5.430703639984131,5.21905517578125:A-0.0001790060691358257]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x35x96[-4.4542317390441895,4.643773555755615:A0.0006424695788216192]])),)
<<<
>>> lm_head: Linear
> ((CT1s2x3x3072[-4.311580657958984,3.90199875831604:A-0.0008492474549339275],),dict())
> ((CT1s3x4x3072[-4.496620178222656,4.30377721786499:A-0.0018928649814797928],),dict())
< (CT1s2x3x32064[-4.8934478759765625,5.155729293823242:A-0.004496973525844078],)
< (CT1s3x4x32064[-4.830400466918945,5.16563606262207:A0.0007996815108034524],)
<<<
< (dict(logits:CT1s2x3x32064[-4.8934478759765625,5.155729293823242:A-0.004496973525844078],past_key_values:DynamicCache(key_cache=#2[CT1s2x32x33x96[-4.999748229980469,5.135319232940674:A-0.0014785649989419458],CT1s2x32x33x96[-6.076218128204346,4.901590347290039:A-0.000825817244353194]], value_cache=#2[CT1s2x32x33x96[-4.537422180175781,4.31105375289917:A0.000781268099316412],CT1s2x32x33x96[-4.364691257476807,4.490701675415039:A0.0002260386739006028]])),)
< (dict(logits:CT1s3x4x32064[-4.830400466918945,5.16563606262207:A0.0007996815108034524],past_key_values:DynamicCache(key_cache=#2[CT1s3x32x35x96[-6.123331069946289,5.262447834014893:A0.0019904564969595038],CT1s3x32x35x96[-5.430703639984131,5.21905517578125:A-0.0001790060691358257]], value_cache=#2[CT1s3x32x35x96[-4.818883895874023,4.558834552764893:A-2.3198660056494583e-05],CT1s3x32x35x96[-4.4542317390441895,4.643773555755615:A0.0006424695788216192]])),)
<<<
[_untrace_forward_execution] __main__ - Phi3ForCausalLM
[_untrace_forward_execution] ..model - Phi3Model
[_untrace_forward_execution] ....embed_tokens - Embedding
[_untrace_forward_execution] ....layers[0] - Phi3DecoderLayer
[_untrace_forward_execution] ......self_attn - Phi3Attention
[_untrace_forward_execution] ........o_proj - Linear
[_untrace_forward_execution] ........qkv_proj - Linear
[_untrace_forward_execution] ......mlp - Phi3MLP
[_untrace_forward_execution] ........gate_up_proj - Linear
[_untrace_forward_execution] ........down_proj - Linear
[_untrace_forward_execution] ........activation_fn - SiLU
[_untrace_forward_execution] ......input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......resid_attn_dropout - Dropout
[_untrace_forward_execution] ......resid_mlp_dropout - Dropout
[_untrace_forward_execution] ....layers[1] - Phi3DecoderLayer
[_untrace_forward_execution] ......self_attn - Phi3Attention
[_untrace_forward_execution] ........o_proj - Linear
[_untrace_forward_execution] ........qkv_proj - Linear
[_untrace_forward_execution] ......mlp - Phi3MLP
[_untrace_forward_execution] ........gate_up_proj - Linear
[_untrace_forward_execution] ........down_proj - Linear
[_untrace_forward_execution] ........activation_fn - SiLU
[_untrace_forward_execution] ......input_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......post_attention_layernorm - Phi3RMSNorm
[_untrace_forward_execution] ......resid_attn_dropout - Dropout
[_untrace_forward_execution] ......resid_mlp_dropout - Dropout
[_untrace_forward_execution] ....norm - Phi3RMSNorm
[_untrace_forward_execution] ....rotary_emb - Phi3RotaryEmbedding
[_untrace_forward_execution] ..lm_head - Linear
Now that every input and output of the submodules is kept in memory, we can guess the dynamic shapes for each of them. The final ones are:
dynamic_shapes = diag.guess_dynamic_shapes()
print("The dynamic shapes are:")
pprint.pprint(dynamic_shapes)
The dynamic shapes are:
((),
{'attention_mask': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
'input_ids': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
'past_key_values': [[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}],
[{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>},
{0: <_DimHint.DYNAMIC: 3>, 2: <_DimHint.DYNAMIC: 3>}]]})
And here are the dynamic shapes for all the traced submodules:
print(
diag.pretty_text(
with_dynamic_shape=True,
with_shape=False,
with_min_max=False,
with_device=False,
with_inputs=False,
).replace("<_DimHint.DYNAMIC: 3>", "DYN")
)
>>> __main__: Phi3ForCausalLM
DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'input_ids': {0: DYN, 1: DYN}, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]]})
>>> model: Phi3Model
DS=((), {'attention_mask': {0: DYN, 1: DYN}, 'cache_position': None, 'input_ids': {0: DYN, 1: DYN}, 'inputs_embeds': None, 'output_attentions': None, 'output_hidden_states': None, 'past_key_values': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_ids': None, 'return_dict': None, 'use_cache': None})
>>> embed_tokens: Embedding: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> layers[0]: Phi3DecoderLayer
DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> self_attn: Phi3Attention
DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> mlp: Phi3MLP
DS=(({0: DYN, 1: DYN},), {})
>>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> layers[1]: Phi3DecoderLayer
DS=(({0: DYN, 1: DYN},), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> self_attn: Phi3Attention
DS=((), {'attention_mask': {0: DYN, 2: DYN, 3: DYN}, 'cache_position': {0: DYN}, 'hidden_states': {0: DYN, 1: DYN}, 'output_attentions': None, 'past_key_value': [[{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}]], 'position_embeddings': ({1: DYN}, {1: DYN}), 'position_ids': {1: DYN}, 'use_cache': None})
>>> o_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> qkv_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> mlp: Phi3MLP
DS=(({0: DYN, 1: DYN},), {})
>>> gate_up_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> down_proj: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> activation_fn: SiLU: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> input_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> post_attention_layernorm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_attn_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> resid_mlp_dropout: Dropout: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
>>> norm: Phi3RMSNorm: DS=(({0: DYN, 1: DYN},), {}) <<<
>>> rotary_emb: Phi3RotaryEmbedding: DS=(({0: DYN, 1: DYN}, {1: DYN}), {}) <<<
<<<
>>> lm_head: Linear: DS=(({0: DYN, 1: DYN},), {}) <<<
<<<
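For reference, the pair returned by guess_dynamic_shapes follows the (args, kwargs) convention of torch.export.export, so it would plug in as sketched below. This is only an illustration: for this model the direct call fails on a data-dependent guard, as the next section shows.
ds_args, ds_kwargs = dynamic_shapes  # ((), {...}) as printed above
# Sketch: the guessed kwargs shapes go straight into torch.export.export;
# expected to fail here with "Could not guard on data-dependent expression".
ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds_kwargs, strict=False)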
Evaluate the export
In many cases, the export (to torch.fx.Graph, to ONNX) does not work on the first try. We need a way to understand how much of the model can be exported, which can then be used to evaluate how much code needs to be rewritten or patched to make the model exportable. The verbosity can be increased to show the dynamic shapes and the results of the discrepancy checks.
Let’s display the module and its submodules first.
print(
diag.pretty_text(
with_dynamic_shape=False,
with_shape=False,
with_min_max=False,
with_device=False,
with_inputs=False,
)
)
>>> __main__: Phi3ForCausalLM
>>> model: Phi3Model
>>> embed_tokens: Embedding <<<
>>> layers[0]: Phi3DecoderLayer
>>> self_attn: Phi3Attention
>>> o_proj: Linear <<<
>>> qkv_proj: Linear <<<
<<<
>>> mlp: Phi3MLP
>>> gate_up_proj: Linear <<<
>>> down_proj: Linear <<<
>>> activation_fn: SiLU <<<
<<<
>>> input_layernorm: Phi3RMSNorm <<<
>>> post_attention_layernorm: Phi3RMSNorm <<<
>>> resid_attn_dropout: Dropout <<<
>>> resid_mlp_dropout: Dropout <<<
<<<
>>> layers[1]: Phi3DecoderLayer
>>> self_attn: Phi3Attention
>>> o_proj: Linear <<<
>>> qkv_proj: Linear <<<
<<<
>>> mlp: Phi3MLP
>>> gate_up_proj: Linear <<<
>>> down_proj: Linear <<<
>>> activation_fn: SiLU <<<
<<<
>>> input_layernorm: Phi3RMSNorm <<<
>>> post_attention_layernorm: Phi3RMSNorm <<<
>>> resid_attn_dropout: Dropout <<<
>>> resid_mlp_dropout: Dropout <<<
<<<
>>> norm: Phi3RMSNorm <<<
>>> rotary_emb: Phi3RotaryEmbedding <<<
<<<
>>> lm_head: Linear <<<
<<<
Then we try to export.
ep = diag.try_export(
exporter="fx",
use_dynamic_shapes=True,
exporter_kwargs=dict(strict=False),
bypass_kwargs=dict(patch_transformers=True, replace_dynamic_cache=True),
verbose=1,
)
[try_export] __main__ - Phi3ForCausalLM --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
[try_export] ..model - Phi3Model --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
[try_export] ....embed_tokens - Embedding --- OK
[try_export] ....layers[0] - Phi3DecoderLayer --- FAIL-EXPORT: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['hidden_states']` is a <class 'dict'>, but `dynamic_shapes['hidden_states']` is not
[try_export] ......self_attn - Phi3Attention --- FAIL-EXPORT: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['hidden_states', 'position_embeddings', 'attention_mask', 'past_key_value', 'cache_position', 'kwargs'] of `inputs`, but here they are ['attention_mask', 'cache_position', 'hidden_states', 'output_attentions', 'past_key_value', 'position_embeddings', 'position_ids', 'use_cache']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
[try_export] ........o_proj - Linear --- OK
[try_export] ........qkv_proj - Linear --- OK
[try_export] ......mlp - Phi3MLP --- OK
[try_export] ......input_layernorm - Phi3RMSNorm --- OK
[try_export] ......post_attention_layernorm - Phi3RMSNorm --- OK
[try_export] ......resid_attn_dropout - Dropout --- OK
[try_export] ......resid_mlp_dropout - Dropout --- OK
[try_export] ....layers[1] - Phi3DecoderLayer --- FAIL-EXPORT: Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs['hidden_states']` is a <class 'dict'>, but `dynamic_shapes['hidden_states']` is not
[try_export] ......self_attn - Phi3Attention --- FAIL-EXPORT: When `dynamic_shapes` is specified as a dict, its top-level keys must be the arg names ['hidden_states', 'position_embeddings', 'attention_mask', 'past_key_value', 'cache_position', 'kwargs'] of `inputs`, but here they are ['attention_mask', 'cache_position', 'hidden_states', 'output_attentions', 'past_key_value', 'position_embeddings', 'position_ids', 'use_cache']. Alternatively, you could also ignore arg names entirely and specify `dynamic_shapes` as a list/tuple matching `inputs`. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
[try_export] ........o_proj - Linear --- OK
[try_export] ........qkv_proj - Linear --- OK
[try_export] ......mlp - Phi3MLP --- OK
[try_export] ......input_layernorm - Phi3RMSNorm --- OK
[try_export] ......post_attention_layernorm - Phi3RMSNorm --- OK
[try_export] ......resid_attn_dropout - Dropout --- OK
[try_export] ......resid_mlp_dropout - Dropout --- OK
[try_export] ....norm - Phi3RMSNorm --- OK
[try_export] ....rotary_emb - Phi3RotaryEmbedding --- FAIL-EXPORT: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
[try_export] ..lm_head - Linear --- OK
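The signature-mismatch failures on self_attn suggest one easy mitigation: drop the guessed entries whose names are not parameters of the submodule's forward. A minimal, hypothetical sketch (filter_dynamic_shapes is not a library helper):
import inspect

def filter_dynamic_shapes(module: torch.nn.Module, ds_kwargs: dict) -> dict:
    # Keep only the entries torch.export accepts as top-level keys,
    # i.e. names that are actual parameters of module.forward.
    names = set(inspect.signature(module.forward).parameters)
    return {k: v for k, v in ds_kwargs.items() if k in names}
The remaining failures (Could not guard on data-dependent expression Eq(u0, 1)) come from the modeling code itself; they are the kind of issue the bypass_kwargs=dict(patch_transformers=True, ...) option is presumably meant to patch around.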
Total running time of the script: (0 minutes 5.563 seconds)