Test the export on untrained models¶
Checking the exporter on a whole model takes time as such models are usually big, but we can create a smaller version with the same architecture. Fixing export issues on such a small model is then much faster.
codellama/CodeLlama-7b-Python-hf¶
Let’s grab some information about this model. This reuses the huggingface_hub API.
import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_model_info,
    get_pretrained_config,
    task_from_id,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

model_id = (
    "HuggingFaceM4/tiny-random-idefics"
    if unit_test_going()
    else "codellama/CodeLlama-7b-Python-hf"
)
print(f"model_id={model_id!r}")
print("info", get_model_info(model_id))
model_id='codellama/CodeLlama-7b-Python-hf'
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=6711, downloads_all_time=None, gated=False, gguf=None, inference=None, inference_provider_mapping=None, likes=144, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'Vikhrmodels/DOoM-lb', 'kz-transformers/kaz-llm-lb', '21world/bigcode-models-leaderboard', 
'felixz/open_llm_leaderboard', 'BAAI/open_flageval_vlm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', 'OPTML-Group/UnlearnCanvas-Benchmark', 'atlasas/bigcode-models-leaderboard', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/EmbodiedVerse', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'Agents-MCP-Hackathon/universal-api-translator', 'Noveumai/NovaEval', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'mkthoma/GPT_From_Scratch', 'VarunSivamani/GPT-From-Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'Wazahat/Pyco', 'feryelb/python-coder', 'moshabann/Virtual_Teachers', 'ahmedsqrd/model_trace'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None, xet_enabled=None)
The configuration.
print("config", get_pretrained_config(model_id))
config LlamaConfig {
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 1,
"dtype": "bfloat16",
"eos_token_id": 2,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 16384,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 1000000,
"tie_word_embeddings": false,
"transformers_version": "4.56.0.dev0",
"use_cache": true,
"vocab_size": 32000
}
The task determines the set of inputs which need to be created for this model.
print("task", task_from_id(model_id))
task text-generation
Untrained model¶
The function get_untrained_model_with_inputs loads the pretrained configuration, extracts the task associated with the model, and then creates random inputs and dynamic shapes for torch.export.export().
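The call itself does not appear on this page, only its verbose output. A minimal sketch of what it presumably looks like, assuming the function accepts a verbosity flag and returns a dictionary with the fields printed below:

data = get_untrained_model_with_inputs(model_id, verbose=1)
# the returned dictionary contains the untrained model, its inputs and dynamic shapes
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))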
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] use preinstalled 'codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architectures=['LlamaForCausalLM']
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
[get_untrained_model_with_inputs] default config._attn_implementation=None
[get_untrained_model_with_inputs] use fct=<function get_inputs at 0x711526ee6700>
model size: 2667659264
number of weights: 666914816
fields: {'configuration', 'model_kwargs', 'dynamic_shapes', 'inputs2', 'n_weights', 'input_kwargs', 'inputs', 'model', 'size', 'task'}
Inputs
print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))
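The compact notation produced by string_type encodes the dtype and the shape: T7s2x3 presumably denotes an int64 tensor of shape 2x3 and T1 a float32 tensor (an assumption based on the ONNX element type codes). A quick illustrative check:

# assumption: T7 = int64 (ONNX code 7), T1 = float32 (ONNX code 1)
print(string_type(torch.zeros((2, 3), dtype=torch.int64), with_shape=True))
print(string_type(torch.zeros((2, 32, 30, 128), dtype=torch.float32), with_shape=True))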
Dynamic Shapes
print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
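Before calling torch.export.export(), the export step below passes these dynamic shapes through use_dyn_not_str; presumably each string dimension name is replaced with a dynamic dimension object such as torch.export.Dim.DYNAMIC (an assumption, see the function's documentation for the exact behaviour). It can be inspected directly:

# illustrative only: look at what the conversion produces before exporting
converted = use_dyn_not_str(data["dynamic_shapes"])
print(pprint.pformat(converted))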
Let’s check that the model runs. We still need to copy the inputs before calling the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.
inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)
print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))
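The in-place modification mentioned above is easy to observe: the cache inside the copied inputs should now hold 33 tokens while the original inputs keep their initial 30. An illustrative check, not part of the original script:

# the cache inside the copy passed to the model is updated in place by the forward pass
print("cache in copied inputs: ", string_type(inputs_copy["past_key_values"], with_shape=True))
# the original inputs are untouched because the model ran on a deep copy
print("cache in original inputs:", string_type(data["inputs"]["past_key_values"], with_shape=True))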
It works.
Export¶
The model uses transformers.cache_utils.DynamicCache. It still requires patches to be exportable (control flow). See onnx_diagnostic.torch_export_patches.torch_export_patches().
with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(
        model,
        (),
        kwargs=f(data["inputs"]),
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
        strict=False,
    )
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 4096]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_o_proj_weight: "f32[4096, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_up_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_down_proj_weight: "f32[4096, 11008]", p_model_layers_0_input_layernorm_weight: "f32[4096]", p_model_layers_0_post_attention_layernorm_weight: "f32[4096]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_o_proj_weight: "f32[4096, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[11008, 4096]", p_model_layers_1_mlp_up_proj_weight: "f32[11008, 4096]", p_model_layers_1_mlp_down_proj_weight: "f32[4096, 11008]", p_model_layers_1_input_layernorm_weight: "f32[4096]", p_model_layers_1_post_attention_layernorm_weight: "f32[4096]", p_model_norm_weight: "f32[4096]", p_lm_head_weight: "f32[32000, 4096]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 32, s31, 128]", past_key_values_key_cache_1: "f32[s23, 32, s31, 128]", past_key_values_value_cache_0: "f32[s23, 32, s11, 128]", past_key_values_value_cache_1: "f32[s23, 32, s54, 128]"):
#
sym_size_int_16: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_21: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_22: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# No stacktrace found for following nodes
empty: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat: "f32[s23, 32, s31, 128]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2); empty = past_key_values_key_cache_0 = None
cat_1: "f32[s23, 32, s11, 128]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2); empty_1 = past_key_values_value_cache_0 = None
empty_2: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_3: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat_2: "f32[s23, 32, s31, 128]" = torch.ops.aten.cat.default([empty_2, past_key_values_key_cache_1], -2); empty_2 = past_key_values_key_cache_1 = None
cat_3: "f32[s23, 32, s54, 128]" = torch.ops.aten.cat.default([empty_3, past_key_values_value_cache_1], -2); empty_3 = past_key_values_value_cache_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s23, s70, 4096]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s31 + s70)" = sym_size_int_22 + sym_size_int_16
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False); sym_size_int_22 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_21, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_21, 1, sym_size_int_16, add]); reshape = None
expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_21, 1, sym_size_int_16, add]); reshape_1 = expand_1 = None
expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_21, 1, sym_size_int_16, add]); reshape_2 = None
expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_21, 1, sym_size_int_16, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); le = None
and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1); new_ones = to_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); index = None
and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2); and_1 = to_2 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: self.inv_freq[None, :, None]
unsqueeze: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_1: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .float()
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .expand(position_ids.shape[0], -1, 1)
expand_4: "f32[s23, 64, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_21, -1, 1]); to_3 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1043 in forward, code: .to(x.device)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1045 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1); position_ids = None
slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32); slice_1 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5); submod_3 = to_4 = to_5 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s23, s70, 128]" = wrap_with_autocast[0]
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s23, s70, 128]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1058 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_10, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11); p_model_layers_0_input_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_21, sym_size_int_16, -1, 128]); linear = None
transpose_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_21, sym_size_int_16, -1, 128]); linear_1 = None
transpose_2: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_21, sym_size_int_16, -1, 128]); linear_2 = None
transpose_3: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_8, 1)
unsqueeze_4: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_9, 1)
mul_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_2: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
slice_3: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807); transpose_1 = None
neg: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_3); slice_3 = None
cat_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg, slice_2], -1); neg = slice_2 = None
mul_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_4); cat_5 = None
add_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_4: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
slice_5: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807); transpose_2 = None
neg_1: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_1, slice_4], -1); neg_1 = slice_4 = None
mul_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_4); cat_6 = unsqueeze_4 = None
add_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_7: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([cat, add_5], -2); cat = add_5 = None
cat_8: "f32[s23, 32, s11 + s70, 128]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2); cat_1 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
slice_6: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add)
scaled_dot_product_attention: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, cat_7, cat_8, slice_6, scale = 0.08838834764831845); add_4 = slice_6 = None
transpose_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_4: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_21, sym_size_int_16, -1]); transpose_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(reshape_4, p_model_layers_0_self_attn_o_proj_weight); reshape_4 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
add_6: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_10, linear_3); to_10 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_6, torch.float32); add_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7); add_7 = None
mul_16: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_16, torch.float32); mul_16 = None
mul_17: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13); p_model_layers_0_post_attention_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s23, s70, 11008]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight); mul_17 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_18: "f32[s23, s70, 11008]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight); mul_18 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
add_8: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_12, linear_6); to_12 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_8, torch.float32); add_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
mul_19: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2); rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_19, torch.float32); mul_19 = None
mul_20: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_15); p_model_layers_1_input_layernorm_weight = to_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_q_proj_weight); p_model_layers_1_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_3: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_21, sym_size_int_16, -1, 128]); linear_7 = None
transpose_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2); view_3 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_8: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_k_proj_weight); p_model_layers_1_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_21, sym_size_int_16, -1, 128]); linear_8 = None
transpose_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2); view_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_9: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_v_proj_weight); mul_20 = p_model_layers_1_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_5: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_21, sym_size_int_16, -1, 128]); linear_9 = None
transpose_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2); view_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_5: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
unsqueeze_6: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_9, 1); to_9 = None
mul_21: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_5)
slice_7: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
slice_8: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807); transpose_5 = None
neg_2: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_8); slice_8 = None
cat_9: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_2, slice_7], -1); neg_2 = slice_7 = None
mul_22: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_9, unsqueeze_6); cat_9 = None
add_10: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_21, mul_22); mul_21 = mul_22 = None
mul_23: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_5); unsqueeze_5 = None
slice_9: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
slice_10: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807); transpose_6 = None
neg_3: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_10); slice_10 = None
cat_10: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_3, slice_9], -1); neg_3 = slice_9 = None
mul_24: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_10, unsqueeze_6); cat_10 = unsqueeze_6 = None
add_11: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_23, mul_24); mul_23 = mul_24 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_11: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([cat_2, add_11], -2); cat_2 = add_11 = None
cat_12: "f32[s23, 32, s54 + s70, 128]" = torch.ops.aten.cat.default([cat_3, transpose_7], -2); cat_3 = transpose_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
slice_11: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add); and_2 = add = None
scaled_dot_product_attention_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_10, cat_11, cat_12, slice_11, scale = 0.08838834764831845); add_10 = slice_11 = None
transpose_8: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2); scaled_dot_product_attention_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_5: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(transpose_8, [sym_size_int_21, sym_size_int_16, -1]); transpose_8 = sym_size_int_21 = sym_size_int_16 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_10: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(reshape_5, p_model_layers_1_self_attn_o_proj_weight); reshape_5 = p_model_layers_1_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
add_12: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_14, linear_10); to_14 = linear_10 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_12, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_16 = None
to_16: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_12, torch.float32); add_12 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_4: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
mean_3: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_13: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None
rsqrt_3: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_13); add_13 = None
mul_25: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_3); rsqrt_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_25, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_17 = None
to_17: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_25, torch.float32); mul_25 = None
mul_26: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_17); p_model_layers_1_post_attention_layernorm_weight = to_17 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_11: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_26, p_model_layers_1_mlp_gate_proj_weight); p_model_layers_1_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
silu_1: "f32[s23, s70, 11008]" = torch.ops.aten.silu.default(linear_11); linear_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_12: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_26, p_model_layers_1_mlp_up_proj_weight); mul_26 = p_model_layers_1_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_27: "f32[s23, s70, 11008]" = torch.ops.aten.mul.Tensor(silu_1, linear_12); silu_1 = linear_12 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_13: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_27, p_model_layers_1_mlp_down_proj_weight); mul_27 = p_model_layers_1_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
add_14: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_16, linear_13); to_16 = linear_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_18 = torch.ops.aten._assert_tensor_metadata.default(add_14, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_18 = None
to_18: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_14, torch.float32); add_14 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_5: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_18, 2)
mean_4: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_15: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None
rsqrt_4: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_15); add_15 = None
mul_28: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_18, rsqrt_4); to_18 = rsqrt_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_19 = torch.ops.aten._assert_tensor_metadata.default(mul_28, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_19 = None
to_19: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_28, torch.float32); mul_28 = None
mul_29: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_19); p_model_norm_weight = to_19 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:473 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_12: "f32[s23, s70, 4096]" = torch.ops.aten.slice.Tensor(mul_29, 1, 0, 9223372036854775807); mul_29 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_14: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_12, p_lm_head_weight); slice_12 = p_lm_head_weight = None
return (linear_14, cat_7, cat_11, cat_8, cat_12)
class submod_1(torch.nn.Module):
def forward(self, to_4: "f32[s23, 64, 1]", to_5: "f32[s23, 1, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32); to_5 = None
matmul: "f32[s23, 64, s70]" = torch.ops.aten.matmul.default(to_6, to_7); to_6 = to_7 = None
transpose: "f32[s23, s70, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat_4: "f32[s23, s70, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s23, s70, 128]" = torch.ops.aten.cos.default(cat_4)
mul: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s23, s70, 128]" = torch.ops.aten.sin.default(cat_4); cat_4 = None
mul_1: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_key_cache_1: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
past_key_values_value_cache_1: USER_INPUT
# outputs
linear_14: USER_OUTPUT
cat_7: USER_OUTPUT
cat_11: USER_OUTPUT
cat_8: USER_OUTPUT
cat_12: USER_OUTPUT
Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo], s54: VR[2, int_oo]}
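A natural follow-up, not part of this script, is to verify that the exported program reproduces the eager outputs. A sketch, assuming the exported module accepts the same keyword inputs and returns the same output structure as the eager model:

# run the exported graph on a fresh copy of the inputs; the patches are kept
# active in case the cache serialization still relies on them (assumption)
with torch_export_patches(patch_transformers=True) as f:
    got = ep.module()(**f(copy.deepcopy(data["inputs"])))
print("max discrepancy on logits:", (expected_outputs.logits - got.logits).abs().max().item())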
doc.plot_legend(
    "untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)

Total running time of the script: (0 minutes 7.231 seconds)
Related examples

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Intermediate results with (ONNX) ReferenceEvaluator