Test the export on untrained models

Checking the exporter on a whole model takes time as it is usually big, but we can create a smaller version with the same architecture. Fixing export issues on such a small model is then much faster.

codellama/CodeLlama-7b-Python-hf

Let’s grab some information about this model. This reuses the huggingface_hub API.

import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_model_info,
    get_pretrained_config,
    task_from_id,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

model_id = (
    "HuggingFaceM4/tiny-random-idefics"
    if unit_test_going()
    else "codellama/CodeLlama-7b-Python-hf"
)
print(f"model_id={model_id!r}")
print("info", get_model_info(model_id))
model_id='codellama/CodeLlama-7b-Python-hf'
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=7877, downloads_all_time=None, gated=False, gguf=None, inference=None, inference_provider_mapping=None, likes=144, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), 
RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'Vikhrmodels/DOoM-lb', 'kz-transformers/kaz-llm-lb', '21world/bigcode-models-leaderboard', 'felixz/open_llm_leaderboard', 'BAAI/open_flageval_vlm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', 'OPTML-Group/UnlearnCanvas-Benchmark', 'atlasas/bigcode-models-leaderboard', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/EmbodiedVerse', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 
'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'Agents-MCP-Hackathon/universal-api-translator', 'Noveumai/NovaEval', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'mkthoma/GPT_From_Scratch', 'VarunSivamani/GPT-From-Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'Wazahat/Pyco', 'feryelb/python-coder', 'moshabann/Virtual_Teachers', 'ahmedsqrd/model_trace'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None, xet_enabled=None)

The configuration.

print("config", get_pretrained_config(model_id))
config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.56.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

The task determines the set of inputs which need to be created for this model.

print("task", task_from_id(model_id))
task text-generation

Untrained model

The function get_untrained_model_with_inputs loads the pretrained configuration, extracts the task associated with the model and then creates random inputs and dynamic shapes for torch.export.export().

data = get_untrained_model_with_inputs(model_id, verbose=1)
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] use preinstalled 'codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architectures=['LlamaForCausalLM']
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
[get_untrained_model_with_inputs] default config._attn_implementation=None
[get_untrained_model_with_inputs] use fct=<function get_inputs at 0x79b44f864c20>
model size: 2667659264
number of weights: 666914816
fields: {'input_kwargs', 'configuration', 'model', 'model_kwargs', 'size', 'inputs2', 'task', 'n_weights', 'inputs', 'dynamic_shapes'}

Inputs

print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))

Dynamic Shapes

print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
 'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}

Let’s check the model runs. We still need to copy the inputs before using the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.

inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)

print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))

It works.

Export

The model uses transformers.cache_utils.DynamicCache. It still requires patches to be exportable (because of control flow). See onnx_diagnostic.torch_export_patches.torch_export_patches().

with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(
        model,
        (),
        kwargs=f(data["inputs"]),
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
        strict=False,
    )
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 4096]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 4096]", p_model_layers_0_self_attn_o_proj_weight: "f32[4096, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_up_proj_weight: "f32[11008, 4096]", p_model_layers_0_mlp_down_proj_weight: "f32[4096, 11008]", p_model_layers_0_input_layernorm_weight: "f32[4096]", p_model_layers_0_post_attention_layernorm_weight: "f32[4096]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 4096]", p_model_layers_1_self_attn_o_proj_weight: "f32[4096, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[11008, 4096]", p_model_layers_1_mlp_up_proj_weight: "f32[11008, 4096]", p_model_layers_1_mlp_down_proj_weight: "f32[4096, 11008]", p_model_layers_1_input_layernorm_weight: "f32[4096]", p_model_layers_1_post_attention_layernorm_weight: "f32[4096]", p_model_norm_weight: "f32[4096]", p_lm_head_weight: "f32[32000, 4096]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 32, s31, 128]", past_key_values_key_cache_1: "f32[s23, 32, s31, 128]", past_key_values_value_cache_0: "f32[s23, 32, s11, 128]", past_key_values_value_cache_1: "f32[s23, 32, s54, 128]"):
             #
            sym_size_int_16: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_19: "Sym(s23)" = torch.ops.aten.sym_size.int(position_ids, 0)
            sym_size_int_21: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
            sym_size_int_22: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
            sym_size_int_23: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_1, 0)
            eq_8: "Sym(True)" = sym_size_int_19 == sym_size_int_21;  sym_size_int_19 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(s44, s23) on node 'eq_8'");  eq_8 = _assert_scalar_default = None
            eq_9: "Sym(True)" = sym_size_int_21 == sym_size_int_23;  sym_size_int_23 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_9, "Runtime assertion failed for expression Eq(s23, s27) on node 'eq_9'");  eq_9 = _assert_scalar_default_1 = None

            # No stacktrace found for following nodes
            empty: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s23, 32, s31, 128]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2);  empty = past_key_values_key_cache_0 = None
            cat_1: "f32[s23, 32, s11, 128]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2);  empty_1 = past_key_values_value_cache_0 = None
            empty_2: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_3: "f32[s23, 32, 0, 128]" = torch.ops.aten.empty.memory_format([sym_size_int_21, 32, 0, 128], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat_2: "f32[s23, 32, s31, 128]" = torch.ops.aten.cat.default([empty_2, past_key_values_key_cache_1], -2);  empty_2 = past_key_values_key_cache_1 = None
            cat_3: "f32[s23, 32, s54, 128]" = torch.ops.aten.cat.default([empty_3, past_key_values_value_cache_1], -2);  empty_3 = past_key_values_value_cache_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s23, s70, 4096]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s31 + s70)" = sym_size_int_22 + sym_size_int_16

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_21, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]);  arange_2 = None
            reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]);  arange_3 = None
            reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]);  arange = None
            reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]);  add_ = None
            expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_21, 1, sym_size_int_16, add]);  reshape = None
            expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_21, 1, sym_size_int_16, add]);  reshape_1 = expand_1 = None
            expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_21, 1, sym_size_int_16, add]);  reshape_2 = None
            expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_21, 1, sym_size_int_16, add]);  reshape_3 = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2);  expand_2 = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
            index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]);  to = expand = expand_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  index = None
            and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_21, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1057 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_8: "f32[s23, s70, 128]" = wrap_with_set_grad_enabled[0]
            to_9: "f32[s23, s70, 128]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear = None
            transpose_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear_1 = None
            transpose_2: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear_2 = None
            transpose_3: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_8, 1)
            unsqueeze_4: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_9, 1)
            mul_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_2: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
            slice_3: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807);  transpose_1 = None
            neg: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_3);  slice_3 = None
            cat_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg, slice_2], -1);  neg = slice_2 = None
            mul_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_4);  cat_5 = None
            add_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_4: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
            slice_5: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_5);  slice_5 = None
            cat_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_1, slice_4], -1);  neg_1 = slice_4 = None
            mul_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_4);  cat_6 = unsqueeze_4 = None
            add_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_7: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([cat, add_5], -2);  cat = add_5 = None
            cat_8: "f32[s23, 32, s11 + s70, 128]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2);  cat_1 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
            slice_6: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add)
            scaled_dot_product_attention: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, cat_7, cat_8, slice_6, scale = 0.08838834764831845);  add_4 = slice_6 = None
            transpose_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_4: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_21, sym_size_int_16, -1]);  transpose_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(reshape_4, p_model_layers_0_self_attn_o_proj_weight);  reshape_4 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
            add_6: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
            mul_16: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_16, torch.float32);  mul_16 = None
            mul_17: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13);  p_model_layers_0_post_attention_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s23, s70, 11008]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight);  mul_17 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_18: "f32[s23, s70, 11008]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight);  mul_18 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
            add_8: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_12, linear_6);  to_12 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_8, torch.float32);  add_8 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9);  add_9 = None
            mul_19: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2);  rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_19, torch.float32);  mul_19 = None
            mul_20: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_15);  p_model_layers_1_input_layernorm_weight = to_15 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_q_proj_weight);  p_model_layers_1_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_3: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear_7 = None
            transpose_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2);  view_3 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_8: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_k_proj_weight);  p_model_layers_1_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear_8 = None
            transpose_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2);  view_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_9: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_20, p_model_layers_1_self_attn_v_proj_weight);  mul_20 = p_model_layers_1_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_5: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_21, sym_size_int_16, -1, 128]);  linear_9 = None
            transpose_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2);  view_5 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_5: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            unsqueeze_6: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
            mul_21: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_5)
            slice_7: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
            slice_8: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807);  transpose_5 = None
            neg_2: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_8);  slice_8 = None
            cat_9: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_2, slice_7], -1);  neg_2 = slice_7 = None
            mul_22: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_9, unsqueeze_6);  cat_9 = None
            add_10: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_21, mul_22);  mul_21 = mul_22 = None
            mul_23: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_5);  unsqueeze_5 = None
            slice_9: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
            slice_10: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807);  transpose_6 = None
            neg_3: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_10);  slice_10 = None
            cat_10: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_3, slice_9], -1);  neg_3 = slice_9 = None
            mul_24: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_10, unsqueeze_6);  cat_10 = unsqueeze_6 = None
            add_11: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_23, mul_24);  mul_23 = mul_24 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_11: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([cat_2, add_11], -2);  cat_2 = add_11 = None
            cat_12: "f32[s23, 32, s54 + s70, 128]" = torch.ops.aten.cat.default([cat_3, transpose_7], -2);  cat_3 = transpose_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
            slice_11: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add);  and_2 = add = None
            scaled_dot_product_attention_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(add_10, cat_11, cat_12, slice_11, scale = 0.08838834764831845);  add_10 = slice_11 = None
            transpose_8: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2);  scaled_dot_product_attention_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_5: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(transpose_8, [sym_size_int_21, sym_size_int_16, -1]);  transpose_8 = sym_size_int_21 = sym_size_int_16 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_10: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(reshape_5, p_model_layers_1_self_attn_o_proj_weight);  reshape_5 = p_model_layers_1_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
            add_12: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_14, linear_10);  to_14 = linear_10 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_12, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_16 = None
            to_16: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_12, torch.float32);  add_12 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_4: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
            mean_3: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True);  pow_4 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_13: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05);  mean_3 = None
            rsqrt_3: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_13);  add_13 = None
            mul_25: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_3);  rsqrt_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_25, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_17 = None
            to_17: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_25, torch.float32);  mul_25 = None
            mul_26: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_17);  p_model_layers_1_post_attention_layernorm_weight = to_17 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_11: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_26, p_model_layers_1_mlp_gate_proj_weight);  p_model_layers_1_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu_1: "f32[s23, s70, 11008]" = torch.ops.aten.silu.default(linear_11);  linear_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_12: "f32[s23, s70, 11008]" = torch.ops.aten.linear.default(mul_26, p_model_layers_1_mlp_up_proj_weight);  mul_26 = p_model_layers_1_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_27: "f32[s23, s70, 11008]" = torch.ops.aten.mul.Tensor(silu_1, linear_12);  silu_1 = linear_12 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_13: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_27, p_model_layers_1_mlp_down_proj_weight);  mul_27 = p_model_layers_1_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
            add_14: "f32[s23, s70, 4096]" = torch.ops.aten.add.Tensor(to_16, linear_13);  to_16 = linear_13 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_18 = torch.ops.aten._assert_tensor_metadata.default(add_14, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_18 = None
            to_18: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(add_14, torch.float32);  add_14 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_5: "f32[s23, s70, 4096]" = torch.ops.aten.pow.Tensor_Scalar(to_18, 2)
            mean_4: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True);  pow_5 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_15: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05);  mean_4 = None
            rsqrt_4: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_15);  add_15 = None
            mul_28: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(to_18, rsqrt_4);  to_18 = rsqrt_4 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_19 = torch.ops.aten._assert_tensor_metadata.default(mul_28, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_19 = None
            to_19: "f32[s23, s70, 4096]" = torch.ops.aten.to.dtype(mul_28, torch.float32);  mul_28 = None
            mul_29: "f32[s23, s70, 4096]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_19);  p_model_norm_weight = to_19 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:479 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_12: "f32[s23, s70, 4096]" = torch.ops.aten.slice.Tensor(mul_29, 1, 0, 9223372036854775807);  mul_29 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_14: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_12, p_lm_head_weight);  slice_12 = p_lm_head_weight = None
            return (linear_14, cat_7, cat_11, cat_8, cat_12)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", sym_size_int_21: "Sym(s23)", position_ids: "i64[s23, s70]"):
                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1039 in forward, code: self.inv_freq[None, :, None]
                unsqueeze: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                unsqueeze_1: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2);  unsqueeze = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: .float()
                _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                to_3: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .expand(position_ids.shape[0], -1, 1)
                expand_4: "f32[s23, 64, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_21, -1, 1]);  to_3 = sym_size_int_21 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .to(x.device)
                _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                to_4: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_4 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1044 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1);  position_ids = None
                slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32);  slice_1 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[s23, s70, 128]" = wrap_with_autocast[0]

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[s23, s70, 128]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1057 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
                to_8: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
                to_9: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_8, to_9)

            class submod_1(torch.nn.Module):
                def forward(self, to_4: "f32[s23, 64, 1]", to_5: "f32[s23, 1, s70]"):
                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1052 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                    to_6: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                    to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                    matmul: "f32[s23, 64, s70]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                    transpose: "f32[s23, s70, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat_4: "f32[s23, s70, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[s23, s70, 128]" = torch.ops.aten.cos.default(cat_4)
                    mul: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[s23, s70, 128]" = torch.ops.aten.sin.default(cat_4);  cat_4 = None
                    mul_1: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
    p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
    p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
    p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
    p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
    p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
    p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
    p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
    p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_key_cache_1: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT
    past_key_values_value_cache_1: USER_INPUT

    # outputs
    linear_14: USER_OUTPUT
    cat_7: USER_OUTPUT
    cat_11: USER_OUTPUT
    cat_8: USER_OUTPUT
    cat_12: USER_OUTPUT

Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo], s54: VR[2, int_oo]}
doc.plot_legend(
    "untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)
plot export hub codellama

Total running time of the script: (0 minutes 10.954 seconds)

Related examples

Export microsoft/phi-2

Export microsoft/phi-2

Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Intermediate results with (ONNX) ReferenceEvaluator

Intermediate results with (ONNX) ReferenceEvaluator

Gallery generated by Sphinx-Gallery