Test the export on untrained models
Checking the exporter on a whole model takes time because the model is usually big, but we can create a smaller version with the same architecture. Fixing export issues on such a small model is then much faster.
codellama/CodeLlama-7b-Python-hf
Let’s grab some information about this model. This reuses the huggingface_hub API.
import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
get_model_info,
get_pretrained_config,
task_from_id,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
model_id = "codellama/CodeLlama-7b-Python-hf"
print("info", get_model_info(model_id))
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=9649, downloads_all_time=None, gated=False, gguf=None, inference=None, likes=140, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'kz-transformers/kaz-llm-lb', 'felixz/open_llm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', '21world/bigcode-models-leaderboard', 
'OPTML-Group/UnlearnCanvas-Benchmark', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/open_flageval_vlm_leaderboard', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'VarunSivamani/GPT-From-Scratch', 'mkthoma/GPT_From_Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'atlasas/bigcode-models-leaderboard'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None)
The configuration.
print("config", get_pretrained_config(model_id))
config LlamaConfig {
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 16384,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 1000000,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0.dev0",
"use_cache": true,
"vocab_size": 32000
}
The task determines the set of inputs that need to be created for this model.
print("task", task_from_id(model_id))
task text-generation
Untrained model
The function get_untrained_model_with_inputs loads the pretrained configuration, extracts the task associated with the model, and then creates random inputs and dynamic shapes for torch.export.export().
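The call producing the output below is a minimal sketch reconstructed from that output; the verbose argument is an assumption.

data = get_untrained_model_with_inputs(model_id, verbose=1)  # verbose=1 assumed to emit the progress lines below
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))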
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architecture='LlamaForCausalLM'
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
model size: 410532864
number of weights: 102633216
fields: {'dynamic_shapes', 'input_kwargs', 'size', 'configuration', 'n_weights', 'model', 'inputs', 'model_kwargs'}
Inputs. In the string_type output below, T7 denotes an int64 tensor and T1 a float32 tensor, and the digits after s give the shape (for instance T7s2x3 is an int64 tensor of shape 2x3).
print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))
Dynamic Shapes
print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
'input_ids': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
1: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.seq_length'>},
'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>},
{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>}],
[{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>},
{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>}]],
'position_ids': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}}
Let’s check the model runs. We still need to copy the inputs before calling the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.
inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)
print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))
It works.
Export
The model uses transformers.cache_utils.DynamicCache. It still requires patches to be exportable because of control flow. See onnx_diagnostic.torch_export_patches.bypass_export_some_errors().
with bypass_export_some_errors(patch_transformers=True) as f:
ep = torch.export.export(
model,
(),
kwargs=f(data["inputs"]),
dynamic_shapes=data["dynamic_shapes"],
strict=False,
)
print(ep)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 768]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_o_proj_weight: "f32[768, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[6144, 768]", p_model_layers_0_mlp_up_proj_weight: "f32[6144, 768]", p_model_layers_0_mlp_down_proj_weight: "f32[768, 6144]", p_model_layers_0_input_layernorm_weight: "f32[768]", p_model_layers_0_post_attention_layernorm_weight: "f32[768]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_o_proj_weight: "f32[768, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[6144, 768]", p_model_layers_1_mlp_up_proj_weight: "f32[6144, 768]", p_model_layers_1_mlp_down_proj_weight: "f32[768, 6144]", p_model_layers_1_input_layernorm_weight: "f32[768]", p_model_layers_1_post_attention_layernorm_weight: "f32[768]", p_model_norm_weight: "f32[768]", p_lm_head_weight: "f32[32000, 768]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s10]", position_ids: "i64[s0, s1]", past_key_values_key_cache_0: "f32[s0, 32, s10, 128]", past_key_values_key_cache_1: "f32[s0, 32, s10, 128]", past_key_values_value_cache_0: "f32[s0, 32, s10, 128]", past_key_values_value_cache_1: "f32[s0, 32, s10, 128]"):
#
sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_22: "Sym(s10)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s0, s1, 768]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:561 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s10)" = sym_size_int_22 + sym_size_int_21
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:560 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:567 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s10]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s10]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s10]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False); add = None
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s10]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s10]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze: "f32[1, s1, s1 + s10]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]); slice_2 = None
clone: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[s0, s1 + s10]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_2: "i64[s0, 1, s1 + s10]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_3: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_7: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
to: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_9: "b8[s0, 1, s1, s1 + s10]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_9, -3.4028234663852886e+38); slice_10 = eq_9 = None
slice_11: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_20, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s0, s1, 128]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s0, s1, 128]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 128]); linear = None
transpose_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 128]); linear_1 = None
transpose_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 128]); linear_2 = None
transpose_3: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_7: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1)
unsqueeze_8: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1)
mul_4: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_17: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
slice_18: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807); transpose_1 = None
neg: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_19: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
slice_20: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807); transpose_2 = None
neg_1: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_22: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
slice_23: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_22, 2, 0, 9223372036854775807); slice_22 = None
contiguous: "f32[s0, 32, s1, 128]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
scaled_dot_product_attention: "f32[s0, 32, s1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, cat_3, cat_4, slice_23, scale = 0.08838834764831845); contiguous = slice_23 = None
transpose_4: "f32[s0, s1, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_1: "f32[s0, s1, 32, 128]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_1: "f32[s0, s1, 4096]" = torch.ops.aten.reshape.default(contiguous_1, [sym_size_int_20, sym_size_int_21, -1]); contiguous_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(reshape_1, p_model_layers_0_self_attn_o_proj_weight); reshape_1 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s0, s1, 6144]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s0, s1, 6144]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_13); p_model_layers_1_input_layernorm_weight = to_13 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_q_proj_weight); p_model_layers_1_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_3: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_20, sym_size_int_21, -1, 128]); linear_7 = None
transpose_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2); view_3 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_8: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_k_proj_weight); p_model_layers_1_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_4: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_20, sym_size_int_21, -1, 128]); linear_8 = None
transpose_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2); view_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_9: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_v_proj_weight); mul_12 = p_model_layers_1_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_5: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_20, sym_size_int_21, -1, 128]); linear_9 = None
transpose_7: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2); view_5 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_9: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_10: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_13: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_9)
slice_24: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
slice_25: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807); transpose_5 = None
neg_2: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_25); slice_25 = None
cat_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_2, slice_24], -1); neg_2 = slice_24 = None
mul_14: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_10); cat_5 = None
add_11: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_13, mul_14); mul_13 = mul_14 = None
mul_15: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_9); unsqueeze_9 = None
slice_26: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
slice_27: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807); transpose_6 = None
neg_3: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_27); slice_27 = None
cat_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_3, slice_26], -1); neg_3 = slice_26 = None
mul_16: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_10); cat_6 = unsqueeze_10 = None
add_12: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_15, mul_16); mul_15 = mul_16 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_7: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, add_12], -2); past_key_values_key_cache_1 = add_12 = None
cat_8: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2); past_key_values_value_cache_1 = transpose_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_28: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_29: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_28, 1, 0, 9223372036854775807); slice_28 = None
slice_30: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_29, 2, 0, 9223372036854775807); slice_29 = None
contiguous_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.contiguous.default(add_11); add_11 = None
scaled_dot_product_attention_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous_2, cat_7, cat_8, slice_30, scale = 0.08838834764831845); contiguous_2 = slice_30 = None
transpose_8: "f32[s0, s1, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2); scaled_dot_product_attention_1 = None
contiguous_3: "f32[s0, s1, 32, 128]" = torch.ops.aten.contiguous.default(transpose_8); transpose_8 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_2: "f32[s0, s1, 4096]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_21 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_10: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_1_self_attn_o_proj_weight); reshape_2 = p_model_layers_1_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_13: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_12, linear_10); to_12 = linear_10 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_14: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_13, torch.float32); add_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_4: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_3: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_14: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None
rsqrt_3: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_14); add_14 = None
mul_17: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_3); rsqrt_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_15: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_17, torch.float32); mul_17 = None
mul_18: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_15); p_model_layers_1_post_attention_layernorm_weight = to_15 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_11: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_gate_proj_weight); p_model_layers_1_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu_1: "f32[s0, s1, 6144]" = torch.ops.aten.silu.default(linear_11); linear_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_12: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_up_proj_weight); mul_18 = p_model_layers_1_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_19: "f32[s0, s1, 6144]" = torch.ops.aten.mul.Tensor(silu_1, linear_12); silu_1 = linear_12 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_13: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(mul_19, p_model_layers_1_mlp_down_proj_weight); mul_19 = p_model_layers_1_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_15: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_14, linear_13); to_14 = linear_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_16: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_15, torch.float32); add_15 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_5: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
mean_4: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_16: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None
rsqrt_4: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_16); add_16 = None
mul_20: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_4); to_16 = rsqrt_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_17: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_20, torch.float32); mul_20 = None
mul_21: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_17); p_model_norm_weight = to_17 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:866 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_31: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(mul_21, 0, 0, 9223372036854775807); mul_21 = None
slice_32: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(slice_31, 1, 0, 9223372036854775807); slice_31 = None
slice_33: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(slice_32, 2, 0, 9223372036854775807); slice_32 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_14: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_33, p_lm_head_weight); slice_33 = p_lm_head_weight = None
return (linear_14, cat_3, cat_7, cat_4, cat_8)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", sym_size_int_20: "Sym(s0)", position_ids: "i64[s0, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_4: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 64]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
expand_1: "f32[s0, 64, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_20, -1, 1]); to_1 = sym_size_int_20 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[s0, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_6: "i64[s0, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[s0, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
to_2: "f32[s0, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[s0, s1, 128]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[s0, s1, 128]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[s0, s1, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[s0, s1, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s0, s1, 128]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[s0, s1, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[s0, 64, 1]", to_2: "f32[s0, 1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[s0, 64, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[s0, 64, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[s0, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[s0, 64, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s0, s1, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s0, s1, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[s0, s1, 128]" = torch.ops.aten.cos.default(cat)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[s0, s1, 128]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_key_cache_1: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
past_key_values_value_cache_1: USER_INPUT
# outputs
linear_14: USER_OUTPUT
cat_3: USER_OUTPUT
cat_7: USER_OUTPUT
cat_4: USER_OUTPUT
cat_8: USER_OUTPUT
Range constraints: {s0: VR[1, 1024], s1: VR[2, 4096], s1 + s10: VR[4, 8192], s10: VR[1, 4096]}
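As a quick sanity check (a sketch, not part of the original script), the exported program can be run on a fresh copy of the inputs to confirm it returns outputs with the expected structure and shapes. It is rerun inside the same patch context so the cache serialization registered by the patches is still available; whether that is strictly needed depends on the torch and transformers versions.

with bypass_export_some_errors(patch_transformers=True) as f:
    # ep.module() returns a callable module that follows the original calling convention
    got = ep.module()(**f(copy.deepcopy(data["inputs"])))
print("exported outputs:", string_type(got, with_shape=True))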
doc.plot_legend(
"untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)

Total running time of the script: (0 minutes 2.713 seconds)
Related examples

Steel method forward to guess the dynamic shapes (with Tiny-LLM)