Export Phi-3.5-mini-instruct with draft_export
This example tries torch.export._draft_export.draft_export().
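Unlike torch.export.export(), which raises on data-dependent control flow, draft_export traces with real tensors, specializes where it cannot decide, and collects every issue in a report attached to the exported program. Below is a minimal sketch of that behaviour on a toy module; the API is private, so the exact signature and return value may differ across torch versions.

import torch
import torch.export._draft_export

class DataDependent(torch.nn.Module):
    def forward(self, x):
        # int() on a tensor value is data-dependent and usually fails
        # under torch.export.export
        n = int(x.abs().sum())
        return torch.ones(n)

ep = torch.export._draft_export.draft_export(
    DataDependent(), (torch.ones(3),), strict=False
)
print(ep._report)  # records the specialization instead of raising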
Model
from contextlib import redirect_stderr
from io import StringIO
from typing import Any, Dict
import torch
import torch.export._draft_export
import transformers
from experimental_experiment.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import register_additional_serialization_functions
def get_phi35_untrained(batch_size: int = 2, **kwargs) -> Dict[str, Any]:
"""
Gets an untrained model (random weights) with two sets of inputs of different shapes.
:param batch_size: batch size
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
:return: dictionary with keys ``model``, ``inputs`` and ``inputs2``
See `Phi-3.5-mini-instruct/config.json
<https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json>`_.
"""
config = {
"_name_or_path": "Phi-3.5-mini-instruct",
"architectures": ["Phi3ForCausalLM"],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 131072,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"long_factor": [
1.0800000429153442,
1.1100000143051147,
1.1399999856948853,
1.340000033378601,
1.5899999141693115,
1.600000023841858,
1.6200000047683716,
2.620000123977661,
3.2300000190734863,
3.2300000190734863,
4.789999961853027,
7.400000095367432,
7.700000286102295,
9.09000015258789,
12.199999809265137,
17.670000076293945,
24.46000099182129,
28.57000160217285,
30.420001983642578,
30.840002059936523,
32.590003967285156,
32.93000411987305,
42.320003509521484,
44.96000289916992,
50.340003967285156,
50.45000457763672,
57.55000305175781,
57.93000411987305,
58.21000289916992,
60.1400032043457,
62.61000442504883,
62.62000274658203,
62.71000289916992,
63.1400032043457,
63.1400032043457,
63.77000427246094,
63.93000411987305,
63.96000289916992,
63.970001220703125,
64.02999877929688,
64.06999969482422,
64.08000183105469,
64.12000274658203,
64.41000366210938,
64.4800033569336,
64.51000213623047,
64.52999877929688,
64.83999633789062,
],
"short_factor": [
1.0,
1.0199999809265137,
1.0299999713897705,
1.0299999713897705,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0499999523162842,
1.0699999332427979,
1.0999999046325684,
1.1099998950958252,
1.1599998474121094,
1.1599998474121094,
1.1699998378753662,
1.2899998426437378,
1.339999794960022,
1.679999828338623,
1.7899998426437378,
1.8199998140335083,
1.8499997854232788,
1.8799997568130493,
1.9099997282028198,
1.9399996995925903,
1.9899996519088745,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0199997425079346,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0299997329711914,
2.0799996852874756,
2.0899996757507324,
2.189999580383301,
2.2199995517730713,
2.5899994373321533,
2.729999542236328,
2.749999523162842,
2.8399994373321533,
],
"type": "longrope",
},
"rope_theta": 10000.0,
"sliding_window": 262144,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"use_cache": True,
"attention_bias": False,
"vocab_size": 32064,
}
config.update(**kwargs)
conf = transformers.Phi3Config(**config)
model = transformers.Phi3ForCausalLM(conf)
model.eval()
cache = make_dynamic_cache(
[
(torch.randn(batch_size, 32, 30, 96), torch.randn(batch_size, 32, 30, 96))
for i in range(config["num_hidden_layers"])
]
)
cache2 = make_dynamic_cache(
[
(torch.randn(batch_size + 1, 32, 31, 96), torch.randn(batch_size + 1, 32, 31, 96))
for i in range(config["num_hidden_layers"])
]
)
inputs = dict(
input_ids=torch.randint(0, 32064, (batch_size, 3)).to(torch.int64),
attention_mask=torch.ones((batch_size, 33)).to(torch.int64),
past_key_values=cache,
)
inputs2 = dict(
input_ids=torch.randint(0, 32064, (batch_size + 1, 4)).to(torch.int64),
attention_mask=torch.ones((batch_size + 1, 35)).to(torch.int64),
past_key_values=cache2,
)
return dict(inputs=inputs, model=model, inputs2=inputs2)
data = get_phi35_untrained(num_hidden_layers=2)
model, inputs, inputs2 = data["model"], data["inputs"], data["inputs2"]
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x96,T1s2x32x30x96], value_cache=#2[T1s2x32x30x96,T1s2x32x30x96]))
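In string_type's compact notation, T7 denotes an int64 tensor and T1 a float32 tensor (ONNX element types 7 and 1), and s2x3 is the shape. The second set of inputs was built with a larger batch and longer sequences; printing it the same way shows the shape differences and is useful later to check that the export is not over-specialized.

print(string_type(inputs2, with_shape=True))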
Draft Export
The function we want to try. We redirect stderr because draft_export writes its findings there, and register_additional_serialization_functions makes DynamicCache serializable so it can be passed as an input to the exporter.
err = StringIO()
with redirect_stderr(err), register_additional_serialization_functions():
ep = torch.export._draft_export.draft_export(model, tuple(), kwargs=inputs, strict=False)
Errors, if any.
print(err.getvalue())
Let’s print the report.
print(ep._report)
###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################
1. Data dependent error.
When exporting, we were unable to evaluate the value of `Eq(u0, 1)`.
This was encountered 1 times.
This occurred at the following user stacktrace:
File ~/vv/this312/lib/python3.12/site-packages/torch/utils/_contextlib.py, lineno 116, in decorate_context
File ~/github/transformers/src/transformers/modeling_rope_utils.py, lineno 86, in wrapper
File ~/github/transformers/src/transformers/modeling_rope_utils.py, lineno 50, in longrope_frequency_update
if seq_len > original_max_position_embeddings:
Locals:
seq_len: ['Tensor(shape: torch.Size([]), stride: (), storage_offset: 0)']
original_max_position_embeddings: [None]
And the following framework stacktrace:
File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1326, in __torch_function__
File ~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py, lineno 1373, in __torch_function__
File ~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py, lineno 973, in __torch_function__
return func(*args, **kwargs)
As a result, it was specialized to a constant (e.g. `0` in the 1st occurrence), and asserts were inserted into the graph.
Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
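The fix the report asks for is a torch._check(...) on the data-dependent scalar, so the exporter knows which branch of the seq_len comparison to take. A hedged sketch of the pattern follows; it is not the actual patch to transformers, and the 4096 threshold is original_max_position_embeddings for this model.

import torch

def rope_guard(position_ids: torch.Tensor) -> bool:
    # data-dependent scalar, as in longrope_frequency_update
    seq_len = int(position_ids.max()) + 1
    # assert the assumption explicitly so export does not have to guess
    torch._check(seq_len <= 4096)
    # with the check in place, this comparison is statically False
    return seq_len > 4096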
And the exported program.
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32064, 3072]", p_model_layers_0_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_0_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_0_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_0_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_0_input_layernorm_weight: "f32[3072]", p_model_layers_0_post_attention_layernorm_weight: "f32[3072]", p_model_layers_1_self_attn_o_proj_weight: "f32[3072, 3072]", p_model_layers_1_self_attn_qkv_proj_weight: "f32[9216, 3072]", p_model_layers_1_mlp_gate_up_proj_weight: "f32[16384, 3072]", p_model_layers_1_mlp_down_proj_weight: "f32[3072, 8192]", p_model_layers_1_input_layernorm_weight: "f32[3072]", p_model_layers_1_post_attention_layernorm_weight: "f32[3072]", p_model_norm_weight: "f32[3072]", p_lm_head_weight: "f32[32064, 3072]", b_model_rotary_emb_inv_freq: "f32[48]", c_model_rotary_emb_lifted_tensor_0: "f32[48]", input_ids: "i64[2, 3]", attention_mask: "i64[2, 33]", past_key_values_key_cache_0: "f32[2, 32, 30, 96]", past_key_values_key_cache_1: "f32[2, 32, 30, 96]", past_key_values_value_cache_0: "f32[2, 32, 30, 96]", past_key_values_value_cache_1: "f32[2, 32, 30, 96]"):
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[2, 3, 3072]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids, 32000); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:454 in forward, code: cache_position = torch.arange(
arange: "i64[3]" = torch.ops.aten.arange.start(30, 33, device = device(type='cpu'), pin_memory = False)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:459 in forward, code: position_ids = cache_position.unsqueeze(0)
unsqueeze: "i64[1, 3]" = torch.ops.aten.unsqueeze.default(arange, 0)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:461 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[3, 33]" = torch.ops.aten.full.default([3, 33], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
arange_1: "i64[33]" = torch.ops.aten.arange.default(33, device = device(type='cpu'), pin_memory = False)
reshape: "i64[3, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1])
gt: "b8[3, 33]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
arange_2: "i64[33]" = torch.ops.aten.arange.default(33, device = device(type='cpu'), pin_memory = False)
reshape_1: "i64[3, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
sub: "i64[3, 1]" = torch.ops.aten.sub.Tensor(reshape_1, 262144); reshape_1 = None
le: "b8[3, 33]" = torch.ops.aten.le.Tensor(arange_2, sub); arange_2 = sub = None
bitwise_or_: "b8[3, 33]" = torch.ops.aten.bitwise_or_.Tensor(gt, le); gt = le = None
mul_: "f32[3, 33]" = torch.ops.aten.mul_.Tensor(full, bitwise_or_); full = bitwise_or_ = None
unsqueeze_1: "f32[1, 3, 33]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_2: "f32[1, 1, 3, 33]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1); unsqueeze_1 = None
slice_1: "f32[1, 1, 3, 33]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
slice_2: "f32[1, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[2, 1, 3, 33]" = torch.ops.aten.expand.default(slice_2, [2, 1, -1, -1]); slice_2 = None
clone: "f32[2, 1, 3, 33]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
slice_4: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_3, 1); slice_3 = None
slice_5: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_4, 2); slice_4 = None
slice_6: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, 33); slice_5 = None
slice_7: "i64[2, 33]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_3: "i64[2, 1, 33]" = torch.ops.aten.unsqueeze.default(slice_7, 1); slice_7 = None
unsqueeze_4: "i64[2, 1, 1, 33]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
slice_8: "i64[2, 1, 1, 33]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807); unsqueeze_4 = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[2, 1, 1, 33]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_8 = None
add: "f32[2, 1, 3, 33]" = torch.ops.aten.add.Tensor(slice_6, to); slice_6 = to = None
eq: "b8[2, 1, 3, 33]" = torch.ops.aten.eq.Scalar(add, 0); add = None
slice_9: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
slice_10: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_9, 1); slice_9 = None
slice_11: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_10, 2); slice_10 = None
slice_12: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, 33); slice_11 = None
masked_fill: "f32[2, 1, 3, 33]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq, -3.4028234663852886e+38); slice_12 = eq = None
slice_13: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_14: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807); slice_13 = None
slice_15: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807); slice_14 = None
copy_: "f32[2, 1, 3, 33]" = torch.ops.aten.copy_.default(slice_15, masked_fill); slice_15 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, unsqueeze, c_model_rotary_emb_lifted_tensor_0); submod_3 = unsqueeze = c_model_rotary_emb_lifted_tensor_0 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:385 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_7: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[0]
to_8: "f32[1, 3, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_9: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_9, 2)
mean: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_2: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_2); add_2 = None
mul_2: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_9, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_10: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_10); p_model_layers_0_input_layernorm_weight = to_10 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_qkv_proj_weight); mul_3 = p_model_layers_0_self_attn_qkv_proj_weight = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: query_states = qkv[..., :query_pos]
slice_19: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 0, 3072)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:185 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
slice_20: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 3072, 6144)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:186 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
slice_21: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear, 2, 6144, 9223372036854775807); linear = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:188 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
view: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_19, [2, 3, -1, 96]); slice_19 = None
transpose_1: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:189 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
view_1: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_20, [2, 3, -1, 96]); slice_20 = None
transpose_2: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:190 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
view_2: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_21, [2, 3, -1, 96]); slice_21 = None
transpose_3: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:193 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_8: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1)
unsqueeze_9: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1)
alias: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_1)
slice_22: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 96, 9223372036854775807); transpose_1 = None
alias_1: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_2)
slice_23: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 96, 9223372036854775807); transpose_2 = None
mul_4: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias, unsqueeze_8)
slice_24: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 0, 48)
slice_25: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias, 3, 48, 9223372036854775807); alias = None
neg: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_25); slice_25 = None
cat_1: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg, slice_24], -1); neg = slice_24 = None
mul_5: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9); cat_1 = None
add_3: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
cat_2: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_3, slice_22], -1); add_3 = slice_22 = None
mul_6: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_1, unsqueeze_8); unsqueeze_8 = None
slice_26: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 0, 48)
slice_27: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_1, 3, 48, 9223372036854775807); alias_1 = None
neg_1: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_27); slice_27 = None
cat_3: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_1, slice_26], -1); neg_1 = slice_26 = None
mul_7: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_9); cat_3 = unsqueeze_9 = None
add_4: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
cat_4: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_4, slice_23], -1); add_4 = slice_23 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_5: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, cat_4], -2); past_key_values_key_cache_0 = cat_4 = None
cat_6: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output, attn_weights = attention_interface(
slice_28: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone)
slice_29: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_28, 1); slice_28 = None
slice_30: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_29, 2); slice_29 = None
slice_31: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_30, 3, None, 33); slice_30 = None
scaled_dot_product_attention: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_2, cat_5, cat_6, slice_31, scale = 0.10206207261596575); cat_2 = slice_31 = None
transpose_4: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:222 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_2: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous, [2, 3, -1]); contiguous = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight); reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
dropout: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_1, 0.0, False); linear_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:310 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
add_5: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_9, dropout); to_9 = dropout = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_11: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_5, torch.float32); add_5 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_11, 2)
mean_1: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_6: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_6); add_6 = None
mul_8: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_11, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_12: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_12); p_model_layers_0_post_attention_layernorm_weight = to_12 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_up_proj_weight); mul_9 = p_model_layers_0_mlp_gate_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:69 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
chunk = torch.ops.aten.chunk.default(linear_2, 2, -1); linear_2 = None
getitem_10: "f32[2, 3, 8192]" = chunk[0]
getitem_11: "f32[2, 3, 8192]" = chunk[1]; chunk = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_10); getitem_10 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:70 in forward, code: up_states = up_states * self.activation_fn(gate)
mul_10: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_11, silu); getitem_11 = silu = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
dropout_1: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_3, 0.0, False); linear_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:315 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
add_7: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_11, dropout_1); to_11 = dropout_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_13: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_13, 2)
mean_2: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_11: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_13, rsqrt_2); rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_14: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_14); p_model_layers_1_input_layernorm_weight = to_14 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[2, 3, 9216]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_qkv_proj_weight); mul_12 = p_model_layers_1_self_attn_qkv_proj_weight = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:184 in forward, code: query_states = qkv[..., :query_pos]
slice_32: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 0, 3072)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:185 in forward, code: key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
slice_33: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 3072, 6144)
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:186 in forward, code: value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
slice_34: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(linear_4, 2, 6144, 9223372036854775807); linear_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:188 in forward, code: query_states = query_states.view(hidden_shape).transpose(1, 2)
view_3: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_32, [2, 3, -1, 96]); slice_32 = None
transpose_5: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_3, 1, 2); view_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:189 in forward, code: key_states = key_states.view(hidden_shape).transpose(1, 2)
view_4: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_33, [2, 3, -1, 96]); slice_33 = None
transpose_6: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_4, 1, 2); view_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:190 in forward, code: value_states = value_states.view(hidden_shape).transpose(1, 2)
view_5: "f32[2, 3, 32, 96]" = torch.ops.aten.view.default(slice_34, [2, 3, -1, 96]); slice_34 = None
transpose_7: "f32[2, 32, 3, 96]" = torch.ops.aten.transpose.int(view_5, 1, 2); view_5 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:193 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_10: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
unsqueeze_11: "f32[1, 1, 3, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
alias_2: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_5)
slice_35: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 96, 9223372036854775807); transpose_5 = None
alias_3: "f32[2, 32, 3, 96]" = torch.ops.aten.alias.default(transpose_6)
slice_36: "f32[2, 32, 3, 0]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 96, 9223372036854775807); transpose_6 = None
mul_13: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_2, unsqueeze_10)
slice_37: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_2, 3, 0, 48)
slice_38: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_2, 3, 48, 9223372036854775807); alias_2 = None
neg_2: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_38); slice_38 = None
cat_7: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_2, slice_37], -1); neg_2 = slice_37 = None
mul_14: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_7, unsqueeze_11); cat_7 = None
add_9: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_13, mul_14); mul_13 = mul_14 = None
cat_8: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_9, slice_35], -1); add_9 = slice_35 = None
mul_15: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(alias_3, unsqueeze_10); unsqueeze_10 = None
slice_39: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 0, 48)
slice_40: "f32[2, 32, 3, 48]" = torch.ops.aten.slice.Tensor(alias_3, 3, 48, 9223372036854775807); alias_3 = None
neg_3: "f32[2, 32, 3, 48]" = torch.ops.aten.neg.default(slice_40); slice_40 = None
cat_9: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([neg_3, slice_39], -1); neg_3 = slice_39 = None
mul_16: "f32[2, 32, 3, 96]" = torch.ops.aten.mul.Tensor(cat_9, unsqueeze_11); cat_9 = unsqueeze_11 = None
add_10: "f32[2, 32, 3, 96]" = torch.ops.aten.add.Tensor(mul_15, mul_16); mul_15 = mul_16 = None
cat_10: "f32[2, 32, 3, 96]" = torch.ops.aten.cat.default([add_10, slice_36], -1); add_10 = slice_36 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:198 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_11: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, cat_10], -2); past_key_values_key_cache_1 = cat_10 = None
cat_12: "f32[2, 32, 33, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2); past_key_values_value_cache_1 = transpose_7 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:210 in forward, code: attn_output, attn_weights = attention_interface(
slice_41: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(clone); clone = None
slice_42: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_41, 1); slice_41 = None
slice_43: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_42, 2); slice_42 = None
slice_44: "f32[2, 1, 3, 33]" = torch.ops.aten.slice.Tensor(slice_43, 3, None, 33); slice_43 = None
scaled_dot_product_attention_1: "f32[2, 32, 3, 96]" = torch.ops.aten.scaled_dot_product_attention.default(cat_8, cat_11, cat_12, slice_44, scale = 0.10206207261596575); cat_8 = slice_44 = None
transpose_8: "f32[2, 3, 32, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2); scaled_dot_product_attention_1 = None
contiguous_1: "f32[2, 3, 32, 96]" = torch.ops.aten.contiguous.default(transpose_8); transpose_8 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:222 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[2, 3, 3072]" = torch.ops.aten.reshape.default(contiguous_1, [2, 3, -1]); contiguous_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_1_self_attn_o_proj_weight); reshape_3 = p_model_layers_1_self_attn_o_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
dropout_2: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_5, 0.0, False); linear_5 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:310 in forward, code: hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
add_11: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_13, dropout_2); to_13 = dropout_2 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_15: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_11, torch.float32); add_11 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_4: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_15, 2)
mean_3: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_12: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None
rsqrt_3: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_12); add_12 = None
mul_17: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_15, rsqrt_3); rsqrt_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_17, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_16: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_17, torch.float32); mul_17 = None
mul_18: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_16); p_model_layers_1_post_attention_layernorm_weight = to_16 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[2, 3, 16384]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_gate_up_proj_weight); mul_18 = p_model_layers_1_mlp_gate_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:69 in forward, code: gate, up_states = up_states.chunk(2, dim=-1)
chunk_1 = torch.ops.aten.chunk.default(linear_6, 2, -1); linear_6 = None
getitem_12: "f32[2, 3, 8192]" = chunk_1[0]
getitem_13: "f32[2, 3, 8192]" = chunk_1[1]; chunk_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu_1: "f32[2, 3, 8192]" = torch.ops.aten.silu.default(getitem_12); getitem_12 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:70 in forward, code: up_states = up_states * self.activation_fn(gate)
mul_19: "f32[2, 3, 8192]" = torch.ops.aten.mul.Tensor(getitem_13, silu_1); getitem_13 = silu_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[2, 3, 3072]" = torch.ops.aten.linear.default(mul_19, p_model_layers_1_mlp_down_proj_weight); mul_19 = p_model_layers_1_mlp_down_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/dropout.py:70 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
dropout_3: "f32[2, 3, 3072]" = torch.ops.aten.dropout.default(linear_7, 0.0, False); linear_7 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:315 in forward, code: hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
add_13: "f32[2, 3, 3072]" = torch.ops.aten.add.Tensor(to_15, dropout_3); to_15 = dropout_3 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:239 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_16 = None
to_17: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(add_13, torch.float32); add_13 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:240 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_5: "f32[2, 3, 3072]" = torch.ops.aten.pow.Tensor_Scalar(to_17, 2)
mean_4: "f32[2, 3, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:241 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_14: "f32[2, 3, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None
rsqrt_4: "f32[2, 3, 1]" = torch.ops.aten.rsqrt.default(add_14); add_14 = None
mul_20: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(to_17, rsqrt_4); to_17 = rsqrt_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:242 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_20, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_17 = None
to_18: "f32[2, 3, 3072]" = torch.ops.aten.to.dtype(mul_20, torch.float32); mul_20 = None
mul_21: "f32[2, 3, 3072]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_18); p_model_norm_weight = to_18 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:760 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_45: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(mul_21); mul_21 = None
slice_46: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(slice_45, 1, 0); slice_45 = None
slice_47: "f32[2, 3, 3072]" = torch.ops.aten.slice.Tensor(slice_46, 2); slice_46 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_8: "f32[2, 3, 32064]" = torch.ops.aten.linear.default(slice_47, p_lm_head_weight); slice_47 = p_lm_head_weight = None
return (linear_8, cat_5, cat_11, cat_6, cat_12)
class submod_1(torch.nn.Module):
def forward(self, unsqueeze: "i64[1, 3]", c_model_rotary_emb_lifted_tensor_0: "f32[48]"):
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:468 in forward, code: position_embeddings = self.rotary_emb(hidden_states, position_ids)
max_1: "i64[]" = torch.ops.aten.max.default(unsqueeze)
add_1: "i64[]" = torch.ops.aten.add.Tensor(max_1, 1); max_1 = None
gt_1: "b8[]" = torch.ops.aten.gt.Scalar(add_1, 4096); add_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(gt_1, 0); gt_1 = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
to_1: "f32[48]" = torch.ops.aten.to.dtype_layout(c_model_rotary_emb_lifted_tensor_0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); c_model_rotary_emb_lifted_tensor_0 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:375 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(to_1, 0); to_1 = None
slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807); unsqueeze_5 = None
unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2); slice_16 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_2: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32); unsqueeze_6 = None
expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_2, [1, -1, 1]); to_2 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_1 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:376 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_17: "i64[1, 3]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807); unsqueeze = None
unsqueeze_7: "i64[1, 1, 3]" = torch.ops.aten.unsqueeze.default(slice_17, 1); slice_17 = None
slice_18: "i64[1, 1, 3]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_4: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(slice_18, torch.float32); slice_18 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_3, to_4); submod_3 = to_3 = to_4 = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:382 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[1, 3, 96]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:383 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[1, 3, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:385 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_7: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_8: "f32[1, 3, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_7, to_8)
class submod_1(torch.nn.Module):
def forward(self, to_3: "f32[1, 48, 1]", to_4: "f32[1, 1, 3]"):
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:380 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_5: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_6: "f32[1, 1, 3]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
matmul: "f32[1, 48, 3]" = torch.ops.aten.matmul.default(to_5, to_6); to_5 = to_6 = None
transpose: "f32[1, 3, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:381 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[1, 3, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:382 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[1, 3, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(cos, 1.1902380714238083); cos = None
# File: ~/github/transformers/src/transformers/models/phi3/modeling_phi3.py:383 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[1, 3, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[1, 3, 96]" = torch.ops.aten.mul.Tensor(sin, 1.1902380714238083); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_self_attn_qkv_proj_weight: PARAMETER target='model.layers.0.self_attn.qkv_proj.weight'
p_model_layers_0_mlp_gate_up_proj_weight: PARAMETER target='model.layers.0.mlp.gate_up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
p_model_layers_1_self_attn_qkv_proj_weight: PARAMETER target='model.layers.1.self_attn.qkv_proj.weight'
p_model_layers_1_mlp_gate_up_proj_weight: PARAMETER target='model.layers.1.mlp.gate_up_proj.weight'
p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
c_model_rotary_emb_lifted_tensor_0: CONSTANT_TENSOR target='model.rotary_emb.lifted_tensor_0'
input_ids: USER_INPUT
attention_mask: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_key_cache_1: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
past_key_values_value_cache_1: USER_INPUT
# outputs
linear_8: USER_OUTPUT
cat_5: USER_OUTPUT
cat_11: USER_OUTPUT
cat_6: USER_OUTPUT
cat_12: USER_OUTPUT
Range constraints: {u0: VR[0, 1]}
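The single range constraint comes from the Eq(u0, 1) boolean in submod_1 that draft_export specialized. The same information is available programmatically:

print(ep.range_constraints)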
Total running time of the script: (0 minutes 15.820 seconds)
Related examples

Export Phi-3.5-mini-instruct with report_exportability