Test the export on untrained models
Checking the exporter on a whole model takes time because such a model is usually big, but we can create a smaller version with the same architecture. Fixing export issues on this small model is then much faster.
codellama/CodeLlama-7b-Python-hf
Let’s grab some information about this model. This reuses the huggingface_hub API.
import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_model_info,
    get_pretrained_config,
    task_from_id,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
model_id = "codellama/CodeLlama-7b-Python-hf"
print("info", get_model_info(model_id))
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=19482, downloads_all_time=None, gated=False, gguf=None, inference=None, inference_provider_mapping=None, likes=141, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'Vikhrmodels/DOoM-lb', 'kz-transformers/kaz-llm-lb', 'felixz/open_llm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', 
'21world/bigcode-models-leaderboard', 'OPTML-Group/UnlearnCanvas-Benchmark', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/open_flageval_vlm_leaderboard', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'VarunSivamani/GPT-From-Scratch', 'mkthoma/GPT_From_Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'atlasas/bigcode-models-leaderboard', 'Wazahat/Pyco', 'feryelb/python-coder'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None, xet_enabled=None)
The configuration.
print("config", get_pretrained_config(model_id))
config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.53.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}
The task determines the set of inputs that needs to be created for this model.
print("task", task_from_id(model_id))
task text-generation
Untrained model
The function get_untrained_model_with_inputs() loads the pretrained configuration, extracts the task associated with the model, and then creates random inputs and dynamic shapes for torch.export.export().
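The call that produced the log below does not appear in the rendered page. A minimal reconstruction, assuming the usual signature of get_untrained_model_with_inputs and judging from the printed fields, would be:

# Reconstructed call (assumption): build the small untrained model together with
# random inputs and dynamic shapes; verbose=1 is assumed to produce the log below.
data = get_untrained_model_with_inputs(model_id, verbose=1)
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))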
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] use preinstalled 'codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architectures=['LlamaForCausalLM']
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
[get_untrained_model_with_inputs] use fct=<function get_inputs at 0x7a4f195cda80>
model size: 547377152
number of weights: 136844288
fields: {'size', 'task', 'model_kwargs', 'n_weights', 'configuration', 'input_kwargs', 'dynamic_shapes', 'inputs', 'model'}
Inputs
print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))
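The string_type notation is compact: assuming it follows the ONNX element-type codes, T7 denotes an int64 tensor and T1 a float32 tensor, followed by the shape. A small sanity check of that reading (assuming string_type also accepts a plain tensor):

# Hypothetical check of the notation, not part of the original example.
print(string_type(torch.zeros((2, 3), dtype=torch.int64), with_shape=True))
# expected to print something like T7s2x3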
Dynamic Shapes
print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
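The export call below wraps these shapes with use_dyn_not_str. To see what is actually passed to the exporter, one can print the converted structure (a sketch; presumably the string axis names are replaced with dynamic dimensions understood by torch.export.export()):

# Sketch: inspect the dynamic shapes handed to the exporter.
print(pprint.pformat(use_dyn_not_str(data["dynamic_shapes"])))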
Let’s check that the model runs. We still need to copy the inputs before calling the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.
inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)
print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))
It works.
Export
The model uses transformers.cache_utils.DynamicCache. It still requires patches to be exportable (control flow). See onnx_diagnostic.torch_export_patches.torch_export_patches().
with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(
        model,
        (),
        kwargs=f(data["inputs"]),
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
        strict=False,
    )
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 1024]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_o_proj_weight: "f32[1024, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[6144, 1024]", p_model_layers_0_mlp_up_proj_weight: "f32[6144, 1024]", p_model_layers_0_mlp_down_proj_weight: "f32[1024, 6144]", p_model_layers_0_input_layernorm_weight: "f32[1024]", p_model_layers_0_post_attention_layernorm_weight: "f32[1024]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_o_proj_weight: "f32[1024, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[6144, 1024]", p_model_layers_1_mlp_up_proj_weight: "f32[6144, 1024]", p_model_layers_1_mlp_down_proj_weight: "f32[1024, 6144]", p_model_layers_1_input_layernorm_weight: "f32[1024]", p_model_layers_1_post_attention_layernorm_weight: "f32[1024]", p_model_norm_weight: "f32[1024]", p_lm_head_weight: "f32[32000, 1024]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s14, s2]", attention_mask: "i64[s14, s10]", position_ids: "i64[s14, s2]", past_key_values_key_cache_0: "f32[s14, 32, s67, 128]", past_key_values_key_cache_1: "f32[s14, 32, s67, 128]", past_key_values_value_cache_0: "f32[s14, 32, s96, 128]", past_key_values_value_cache_1: "f32[s14, 32, s81, 128]"):
#
sym_size_int_15: "Sym(s2)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_18: "Sym(s14)" = torch.ops.aten.sym_size.int(position_ids, 0)
sym_size_int_20: "Sym(s14)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_21: "Sym(s67)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
sym_size_int_22: "Sym(s14)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_1, 0)
sym_size_int_24: "Sym(s14)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 0)
sym_size_int_26: "Sym(s14)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_1, 0)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s14, s2, 1024]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
#
eq_16: "Sym(True)" = sym_size_int_18 == sym_size_int_20; sym_size_int_20 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_16, "Runtime assertion failed for expression Eq(s41, s89) on node 'eq_16'"); eq_16 = _assert_scalar_default = None
eq_17: "Sym(True)" = sym_size_int_18 == sym_size_int_24; sym_size_int_24 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_17, "Runtime assertion failed for expression Eq(s41, s62) on node 'eq_17'"); eq_17 = _assert_scalar_default_1 = None
eq_18: "Sym(True)" = sym_size_int_18 == sym_size_int_22; sym_size_int_18 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_18, "Runtime assertion failed for expression Eq(s41, s14) on node 'eq_18'"); eq_18 = _assert_scalar_default_2 = None
eq_19: "Sym(True)" = sym_size_int_22 == sym_size_int_26; sym_size_int_26 = None
_assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_19, "Runtime assertion failed for expression Eq(s14, s58) on node 'eq_19'"); eq_19 = _assert_scalar_default_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:416 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s2 + s67)" = sym_size_int_21 + sym_size_int_15
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:415 in forward, code: cache_position = torch.arange(
arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False); sym_size_int_21 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:422 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s14, s10]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s2 + s67]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s2 + s67]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s14]" = torch.ops.aten.arange.default(sym_size_int_22, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s14, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s2, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s2 + s67]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s14, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(reshape, [sym_size_int_22, 1, sym_size_int_15, add]); reshape = None
expand_1: "i64[s14, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_22, 1, sym_size_int_15, add]); reshape_1 = expand_1 = None
expand_2: "i64[s14, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_22, 1, sym_size_int_15, add]); reshape_2 = None
expand_3: "i64[s14, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_22, 1, sym_size_int_15, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
and_1: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.__and__.Tensor(new_ones, le); new_ones = le = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:100 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:422 in forward, code: causal_mask = create_causal_mask(
and_2: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.__and__.Tensor(and_1, index); and_1 = index = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_22, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s14, s2, 128]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s14, s2, 128]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s14, s2, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s14, s2, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s14, s2, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s14, s2, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:235 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_22, sym_size_int_15, -1, 128]); linear = None
transpose_1: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_22, sym_size_int_15, -1, 128]); linear_1 = None
transpose_2: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_22, sym_size_int_15, -1, 128]); linear_2 = None
transpose_3: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:240 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s14, 1, s2, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1)
unsqueeze_4: "f32[s14, 1, s2, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1)
mul_4: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_4: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
slice_5: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807); transpose_1 = None
neg: "f32[s14, 32, s2, 64]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_1: "f32[s14, 32, s2, 128]" = torch.ops.aten.cat.default([neg, slice_4], -1); neg = slice_4 = None
mul_5: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4); cat_1 = None
add_4: "f32[s14, 32, s2, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_6: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
slice_7: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807); transpose_2 = None
neg_1: "f32[s14, 32, s2, 64]" = torch.ops.aten.neg.default(slice_7); slice_7 = None
cat_2: "f32[s14, 32, s2, 128]" = torch.ops.aten.cat.default([neg_1, slice_6], -1); neg_1 = slice_6 = None
mul_7: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4); cat_2 = unsqueeze_4 = None
add_5: "f32[s14, 32, s2, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:245 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s14, 32, s2 + s67, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s14, 32, s2 + s96, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:251 in forward, code: attn_output, attn_weights = attention_interface(
slice_8: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(and_2)
slice_9: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_8, 1); slice_8 = None
slice_10: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_9, 2); slice_9 = None
slice_11: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_10, 3, None, add); slice_10 = None
contiguous: "f32[s14, 32, s2, 128]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
scaled_dot_product_attention: "f32[s14, 32, s2, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, cat_3, cat_4, slice_11, scale = 0.08838834764831845); contiguous = slice_11 = None
transpose_4: "f32[s14, s2, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_1: "f32[s14, s2, 32, 128]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:262 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_4: "f32[s14, s2, 4096]" = torch.ops.aten.reshape.default(contiguous_1, [sym_size_int_22, sym_size_int_15, -1]); contiguous_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s14, s2, 1024]" = torch.ops.aten.linear.default(reshape_4, p_model_layers_0_self_attn_o_proj_weight); reshape_4 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:305 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s14, s2, 1024]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s14, s2, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s14, s2, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s14, s2, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s14, s2, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_10: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_10, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(mul_10, torch.float32); mul_10 = None
mul_11: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s14, s2, 6144]" = torch.ops.aten.linear.default(mul_11, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s14, s2, 6144]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s14, s2, 6144]" = torch.ops.aten.linear.default(mul_11, p_model_layers_0_mlp_up_proj_weight); mul_11 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_12: "f32[s14, s2, 6144]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s14, s2, 1024]" = torch.ops.aten.linear.default(mul_12, p_model_layers_0_mlp_down_proj_weight); mul_12 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:311 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s14, s2, 1024]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s14, s2, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s14, s2, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s14, s2, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s14, s2, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_13: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(mul_13, torch.float32); mul_13 = None
mul_14: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_13); p_model_layers_1_input_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_14, p_model_layers_1_self_attn_q_proj_weight); p_model_layers_1_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:235 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_3: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_22, sym_size_int_15, -1, 128]); linear_7 = None
transpose_5: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2); view_3 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_8: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_14, p_model_layers_1_self_attn_k_proj_weight); p_model_layers_1_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_4: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_22, sym_size_int_15, -1, 128]); linear_8 = None
transpose_6: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2); view_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_9: "f32[s14, s2, 4096]" = torch.ops.aten.linear.default(mul_14, p_model_layers_1_self_attn_v_proj_weight); mul_14 = p_model_layers_1_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_5: "f32[s14, s2, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_22, sym_size_int_15, -1, 128]); linear_9 = None
transpose_7: "f32[s14, 32, s2, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2); view_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:240 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_5: "f32[s14, 1, s2, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_6: "f32[s14, 1, s2, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_15: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_5)
slice_12: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
slice_13: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807); transpose_5 = None
neg_2: "f32[s14, 32, s2, 64]" = torch.ops.aten.neg.default(slice_13); slice_13 = None
cat_5: "f32[s14, 32, s2, 128]" = torch.ops.aten.cat.default([neg_2, slice_12], -1); neg_2 = slice_12 = None
mul_16: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_6); cat_5 = None
add_11: "f32[s14, 32, s2, 128]" = torch.ops.aten.add.Tensor(mul_15, mul_16); mul_15 = mul_16 = None
mul_17: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_5); unsqueeze_5 = None
slice_14: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
slice_15: "f32[s14, 32, s2, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807); transpose_6 = None
neg_3: "f32[s14, 32, s2, 64]" = torch.ops.aten.neg.default(slice_15); slice_15 = None
cat_6: "f32[s14, 32, s2, 128]" = torch.ops.aten.cat.default([neg_3, slice_14], -1); neg_3 = slice_14 = None
mul_18: "f32[s14, 32, s2, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_6); cat_6 = unsqueeze_6 = None
add_12: "f32[s14, 32, s2, 128]" = torch.ops.aten.add.Tensor(mul_17, mul_18); mul_17 = mul_18 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:245 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_7: "f32[s14, 32, s2 + s67, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, add_12], -2); past_key_values_key_cache_1 = add_12 = None
cat_8: "f32[s14, 32, s2 + s81, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2); past_key_values_value_cache_1 = transpose_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:251 in forward, code: attn_output, attn_weights = attention_interface(
slice_16: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(and_2); and_2 = None
slice_17: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_16, 1); slice_16 = None
slice_18: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_17, 2); slice_17 = None
slice_19: "b8[s14, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_18, 3, None, add); slice_18 = add = None
contiguous_2: "f32[s14, 32, s2, 128]" = torch.ops.aten.contiguous.default(add_11); add_11 = None
scaled_dot_product_attention_1: "f32[s14, 32, s2, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous_2, cat_7, cat_8, slice_19, scale = 0.08838834764831845); contiguous_2 = slice_19 = None
transpose_8: "f32[s14, s2, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2); scaled_dot_product_attention_1 = None
contiguous_3: "f32[s14, s2, 32, 128]" = torch.ops.aten.contiguous.default(transpose_8); transpose_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:262 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_5: "f32[s14, s2, 4096]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_22, sym_size_int_15, -1]); contiguous_3 = sym_size_int_22 = sym_size_int_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_10: "f32[s14, s2, 1024]" = torch.ops.aten.linear.default(reshape_5, p_model_layers_1_self_attn_o_proj_weight); reshape_5 = p_model_layers_1_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:305 in forward, code: hidden_states = residual + hidden_states
add_13: "f32[s14, s2, 1024]" = torch.ops.aten.add.Tensor(to_12, linear_10); to_12 = linear_10 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(add_13, torch.float32); add_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_4: "f32[s14, s2, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_3: "f32[s14, s2, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_14: "f32[s14, s2, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None
rsqrt_3: "f32[s14, s2, 1]" = torch.ops.aten.rsqrt.default(add_14); add_14 = None
mul_21: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_3); rsqrt_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_21, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(mul_21, torch.float32); mul_21 = None
mul_22: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_15); p_model_layers_1_post_attention_layernorm_weight = to_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_11: "f32[s14, s2, 6144]" = torch.ops.aten.linear.default(mul_22, p_model_layers_1_mlp_gate_proj_weight); p_model_layers_1_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu_1: "f32[s14, s2, 6144]" = torch.ops.aten.silu.default(linear_11); linear_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_12: "f32[s14, s2, 6144]" = torch.ops.aten.linear.default(mul_22, p_model_layers_1_mlp_up_proj_weight); mul_22 = p_model_layers_1_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_23: "f32[s14, s2, 6144]" = torch.ops.aten.mul.Tensor(silu_1, linear_12); silu_1 = linear_12 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_13: "f32[s14, s2, 1024]" = torch.ops.aten.linear.default(mul_23, p_model_layers_1_mlp_down_proj_weight); mul_23 = p_model_layers_1_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:311 in forward, code: hidden_states = residual + hidden_states
add_15: "f32[s14, s2, 1024]" = torch.ops.aten.add.Tensor(to_14, linear_13); to_14 = linear_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_15, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_16 = None
to_16: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(add_15, torch.float32); add_15 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_5: "f32[s14, s2, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
mean_4: "f32[s14, s2, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_16: "f32[s14, s2, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None
rsqrt_4: "f32[s14, s2, 1]" = torch.ops.aten.rsqrt.default(add_16); add_16 = None
mul_24: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_4); to_16 = rsqrt_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_24, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_17 = None
to_17: "f32[s14, s2, 1024]" = torch.ops.aten.to.dtype(mul_24, torch.float32); mul_24 = None
mul_25: "f32[s14, s2, 1024]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_17); p_model_norm_weight = to_17 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_20: "f32[s14, s2, 1024]" = torch.ops.aten.slice.Tensor(mul_25); mul_25 = None
slice_21: "f32[s14, s2, 1024]" = torch.ops.aten.slice.Tensor(slice_20, 1, 0); slice_20 = None
slice_22: "f32[s14, s2, 1024]" = torch.ops.aten.slice.Tensor(slice_21, 2); slice_21 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_14: "f32[s14, s2, 32000]" = torch.ops.aten.linear.default(slice_22, p_lm_head_weight); slice_22 = p_lm_head_weight = None
return (linear_14, cat_3, cat_7, cat_4, cat_8)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", sym_size_int_22: "Sym(s14)", position_ids: "i64[s14, s2]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:96 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_1: "f32[1, 64]" = torch.ops.aten.slice.Tensor(unsqueeze, 1, 0, 9223372036854775807); unsqueeze = None
unsqueeze_1: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 2); slice_1 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
expand_4: "f32[s14, 64, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_22, -1, 1]); to_1 = sym_size_int_22 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "f32[s14, 64, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:97 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_2: "i64[s14, s2]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_2: "i64[s14, 1, s2]" = torch.ops.aten.unsqueeze.default(slice_2, 1); slice_2 = None
slice_3: "i64[s14, 1, s2]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[s14, 1, s2]" = torch.ops.aten.to.dtype(slice_3, torch.float32); slice_3 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3); submod_3 = to_2 = to_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s14, s2, 128]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s14, s2, 128]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:106 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s14, s2, 128]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s14, s2, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, to_2: "f32[s14, 64, 1]", to_3: "f32[s14, 1, s2]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:101 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s14, 64, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s14, 1, s2]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
matmul: "f32[s14, 64, s2]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s14, s2, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:102 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s14, s2, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:103 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s14, s2, 128]" = torch.ops.aten.cos.default(cat)
mul: "f32[s14, s2, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:104 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s14, s2, 128]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s14, s2, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_key_cache_1: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
past_key_values_value_cache_1: USER_INPUT
# outputs
linear_14: USER_OUTPUT
cat_3: USER_OUTPUT
cat_7: USER_OUTPUT
cat_4: USER_OUTPUT
cat_8: USER_OUTPUT
Range constraints: {s14: VR[1, 1024], s2: VR[2, int_oo], s10: VR[4, int_oo], s67: VR[2, int_oo], s96: VR[2, int_oo], s81: VR[2, int_oo]}
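The exported program can then be checked against the expected outputs computed earlier. This is a sketch that is not part of the original example; it assumes ep.module() returns a callable module and that the first element of both output structures is the logits tensor.

# Verification sketch (assumption): re-run the exported module on a fresh copy of
# the inputs and measure the discrepancy with the eager outputs. The patches are
# re-applied in case the cache serialization they register is needed again.
with torch_export_patches(patch_transformers=True) as f:
    ep_outputs = ep.module()(**f(copy.deepcopy(data["inputs"])))
diff = (ep_outputs[0] - expected_outputs[0]).abs().max()
print("max absolute difference on logits:", float(diff))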
doc.plot_legend(
    "untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)

Total running time of the script: (0 minutes 4.596 seconds)
Related examples
Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)
Export with DynamicCache and guessed dynamic shapes