Test the export on untrained models

Checking the exporter on a whole model takes time because such a model is usually big, but we can create a smaller version with the same architecture. Fixing export issues on such a small model is then much faster.
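A small model with the same architecture can be obtained by loading the pretrained configuration, shrinking a few size-related fields, and instantiating an untrained model from that reduced configuration. A minimal sketch with transformers (the fields to shrink and the values below are illustrative, not the exact recipe used by onnx_diagnostic):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("codellama/CodeLlama-7b-Python-hf")
config.num_hidden_layers = 2      # keep only two layers
config.hidden_size = 1024         # shrink the hidden dimension
config.intermediate_size = 6144   # shrink the MLP accordingly
small_model = AutoModelForCausalLM.from_config(config)  # random weights, same architecture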

codellama/CodeLlama-7b-Python-hf

Let’s grab some information about this model. This reuses the huggingface_hub API.

import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_model_info,
    get_pretrained_config,
    task_from_id,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

model_id = (
    "HuggingFaceM4/tiny-random-idefics"
    if unit_test_going()
    else "codellama/CodeLlama-7b-Python-hf"
)
print(f"model_id={model_id!r}")
print("info", get_model_info(model_id))
model_id='codellama/CodeLlama-7b-Python-hf'
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=22965, downloads_all_time=None, gated=False, gguf=None, inference=None, inference_provider_mapping=None, likes=142, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'Vikhrmodels/DOoM-lb', 'kz-transformers/kaz-llm-lb', 'felixz/open_llm_leaderboard', 
'BAAI/open_flageval_vlm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', '21world/bigcode-models-leaderboard', 'OPTML-Group/UnlearnCanvas-Benchmark', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/EmbodiedVerse', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'Agents-MCP-Hackathon/universal-api-translator', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'mkthoma/GPT_From_Scratch', 'VarunSivamani/GPT-From-Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'atlasas/bigcode-models-leaderboard', 'Wazahat/Pyco', 'feryelb/python-coder', 'moshabann/Virtual_Teachers'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None, xet_enabled=None)
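The returned ModelInfo is verbose but a few fields are enough to characterize the model. A short sketch based on the fields printed above:

info = get_model_info(model_id)
print("task:", info.pipeline_tag)             # 'text-generation'
print("library:", info.library_name)          # 'transformers'
print("parameters:", info.safetensors.total)  # 6738415616 for the real model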

The configuration.

print("config", get_pretrained_config(model_id))
config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.54.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

The task determines the set of inputs that needs to be created for this model.

print("task", task_from_id(model_id))
task text-generation

Untrained model

The function get_untrained_model_with_inputs automates this. It loads the pretrained configuration, extracts the task associated with the model, and then creates random inputs and dynamic shapes for torch.export.export().

data = get_untrained_model_with_inputs(model_id, verbose=1)
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] use preinstalled 'codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architectures=['LlamaForCausalLM']
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
[get_untrained_model_with_inputs] use fct=<function get_inputs at 0x71530ffbc4a0>
model size: 547377152
number of weights: 136844288
fields: {'configuration', 'task', 'inputs', 'dynamic_shapes', 'n_weights', 'model', 'input_kwargs', 'size', 'model_kwargs'}

Inputs

print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))
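The compact notation produced by string_type encodes the dtype and the shape of every tensor; the digit after T appears to follow ONNX element types (1 for float32, 7 for int64). A small sketch to check it:

# Sketch: T7s2x3 should denote an int64 tensor of shape 2x3,
# T1s2x32x30x128 a float32 tensor of shape 2x32x30x128.
print(string_type(torch.zeros((2, 3), dtype=torch.int64), with_shape=True))
print(string_type(torch.zeros((2, 32, 30, 128), dtype=torch.float32), with_shape=True))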

Dynamic Shapes

print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
 'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
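For reference, the same kind of specification can be written by hand with torch.export.Dim. The strings such as 'seq_length' above are replaced by dynamic dimensions through use_dyn_not_str before calling the exporter. A minimal hand-written sketch for input_ids only:

batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length")
manual_shapes = {"input_ids": {0: batch, 1: seq_length}}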

Let’s check the model runs. We still need to copy the inputs before calling the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.

inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)

print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))

It works.
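Because the forward pass ran on a deep copy, the original inputs kept their cache length of 30; calling the model directly on data["inputs"] would have grown the cache in place to 33. A quick check (sketch):

print("original inputs:", string_type(data["inputs"], with_shape=True))
# the cache entries still show ...x30x128, unlike the outputs above (...x33x128)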

Export

The model uses transformers.cache_utils.DynamicCache, which still requires patches to be exportable (control flow). See onnx_diagnostic.torch_export_patches.torch_export_patches().

with torch_export_patches(patch_transformers=True) as f:
    ep = torch.export.export(
        model,
        (),
        kwargs=f(data["inputs"]),
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
        strict=False,
    )
    print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 1024]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 1024]", p_model_layers_0_self_attn_o_proj_weight: "f32[1024, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[6144, 1024]", p_model_layers_0_mlp_up_proj_weight: "f32[6144, 1024]", p_model_layers_0_mlp_down_proj_weight: "f32[1024, 6144]", p_model_layers_0_input_layernorm_weight: "f32[1024]", p_model_layers_0_post_attention_layernorm_weight: "f32[1024]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 1024]", p_model_layers_1_self_attn_o_proj_weight: "f32[1024, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[6144, 1024]", p_model_layers_1_mlp_up_proj_weight: "f32[6144, 1024]", p_model_layers_1_mlp_down_proj_weight: "f32[1024, 6144]", p_model_layers_1_input_layernorm_weight: "f32[1024]", p_model_layers_1_post_attention_layernorm_weight: "f32[1024]", p_model_norm_weight: "f32[1024]", p_lm_head_weight: "f32[32000, 1024]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 32, s31, 128]", past_key_values_key_cache_1: "f32[s23, 32, s31, 128]", past_key_values_value_cache_0: "f32[s23, 32, s11, 128]", past_key_values_value_cache_1: "f32[s23, 32, s54, 128]"):
             #
            sym_size_int_15: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_18: "Sym(s23)" = torch.ops.aten.sym_size.int(position_ids, 0)
            sym_size_int_20: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
            sym_size_int_21: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
            sym_size_int_22: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_1, 0)
            sym_size_int_24: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 0)
            sym_size_int_26: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_1, 0)

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s23, s70, 1024]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             #
            eq_35: "Sym(True)" = sym_size_int_18 == sym_size_int_20;  sym_size_int_18 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_35, "Runtime assertion failed for expression Eq(s44, s23) on node 'eq_35'");  eq_35 = _assert_scalar_default = None
            eq_36: "Sym(True)" = sym_size_int_20 == sym_size_int_24;  sym_size_int_24 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_36, "Runtime assertion failed for expression Eq(s23, s4) on node 'eq_36'");  eq_36 = _assert_scalar_default_1 = None
            eq_37: "Sym(True)" = sym_size_int_20 == sym_size_int_22;  sym_size_int_22 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_37, "Runtime assertion failed for expression Eq(s23, s27) on node 'eq_37'");  eq_37 = _assert_scalar_default_2 = None
            eq_38: "Sym(True)" = sym_size_int_20 == sym_size_int_26;  sym_size_int_26 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_38, "Runtime assertion failed for expression Eq(s23, s40) on node 'eq_38'");  eq_38 = _assert_scalar_default_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:413 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s31 + s70)" = sym_size_int_21 + sym_size_int_15

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:412 in forward, code: cache_position = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_21 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:419 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_20, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]);  arange_2 = None
            reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]);  arange_3 = None
            reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]);  arange = None
            reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]);  add_ = None
            expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_20, 1, sym_size_int_15, add]);  reshape = None
            expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_20, 1, sym_size_int_15, add]);  reshape_1 = expand_1 = None
            expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_20, 1, sym_size_int_15, add]);  reshape_2 = None
            expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_20, 1, sym_size_int_15, add]);  reshape_3 = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2);  expand_2 = None
            and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, le);  new_ones = le = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:100 in forward, code: return torch.ops.aten.index(x, indices)
            index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]);  to = expand = expand_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:419 in forward, code: causal_mask = create_causal_mask(
            and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, index);  and_1 = index = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_20, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:841 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[s23, s70, 128]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[s23, s70, 128]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:62 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_8: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s23, s70, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_9: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:231 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear = None
            transpose_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:232 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear_1 = None
            transpose_2: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:233 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear_2 = None
            transpose_3: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1)
            unsqueeze_4: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1)
            mul_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_4: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
            slice_5: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807);  transpose_1 = None
            neg: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_5);  slice_5 = None
            cat_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg, slice_4], -1);  neg = slice_4 = None
            mul_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4);  cat_1 = None
            add_4: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_6: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
            slice_7: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_7);  slice_7 = None
            cat_2: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_1, slice_6], -1);  neg_1 = slice_6 = None
            mul_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4);  cat_2 = unsqueeze_4 = None
            add_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s23, 32, s11 + s70, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:247 in forward, code: attn_output, attn_weights = attention_interface(
            slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2)
            slice_9: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_8, 1);  slice_8 = None
            slice_10: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_9, 2);  slice_9 = None
            slice_11: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_10, 3, None, add);  slice_10 = None
            contiguous: "f32[s23, 32, s70, 128]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            scaled_dot_product_attention: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, cat_3, cat_4, slice_11, scale = 0.08838834764831845);  contiguous = slice_11 = None
            transpose_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_1: "f32[s23, s70, 32, 128]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:258 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_4: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(contiguous_1, [sym_size_int_20, sym_size_int_15, -1]);  contiguous_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(reshape_4, p_model_layers_0_self_attn_o_proj_weight);  reshape_4 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:301 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s23, s70, 1024]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:62 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s23, s70, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_17: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_17, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(mul_17, torch.float32);  mul_17 = None
            mul_18: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s23, s70, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s23, s70, 6144]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s23, s70, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_up_proj_weight);  mul_18 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:151 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_19: "f32[s23, s70, 6144]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_19, p_model_layers_0_mlp_down_proj_weight);  mul_19 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:307 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s23, s70, 1024]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:62 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s23, s70, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_20: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_20, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(mul_20, torch.float32);  mul_20 = None
            mul_21: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_13);  p_model_layers_1_input_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_21, p_model_layers_1_self_attn_q_proj_weight);  p_model_layers_1_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:231 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_3: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear_7 = None
            transpose_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2);  view_3 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_8: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_21, p_model_layers_1_self_attn_k_proj_weight);  p_model_layers_1_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:232 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_4: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear_8 = None
            transpose_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2);  view_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_9: "f32[s23, s70, 4096]" = torch.ops.aten.linear.default(mul_21, p_model_layers_1_self_attn_v_proj_weight);  mul_21 = p_model_layers_1_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:233 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_5: "f32[s23, s70, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_20, sym_size_int_15, -1, 128]);  linear_9 = None
            transpose_7: "f32[s23, 32, s70, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2);  view_5 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_5: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_6: "f32[s23, 1, s70, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_22: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_5)
            slice_12: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
            slice_13: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807);  transpose_5 = None
            neg_2: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_13);  slice_13 = None
            cat_5: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_2, slice_12], -1);  neg_2 = slice_12 = None
            mul_23: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_6);  cat_5 = None
            add_11: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_22, mul_23);  mul_22 = mul_23 = None
            mul_24: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_5);  unsqueeze_5 = None
            slice_14: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
            slice_15: "f32[s23, 32, s70, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807);  transpose_6 = None
            neg_3: "f32[s23, 32, s70, 64]" = torch.ops.aten.neg.default(slice_15);  slice_15 = None
            cat_6: "f32[s23, 32, s70, 128]" = torch.ops.aten.cat.default([neg_3, slice_14], -1);  neg_3 = slice_14 = None
            mul_25: "f32[s23, 32, s70, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_6);  cat_6 = unsqueeze_6 = None
            add_12: "f32[s23, 32, s70, 128]" = torch.ops.aten.add.Tensor(mul_24, mul_25);  mul_24 = mul_25 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_7: "f32[s23, 32, s31 + s70, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, add_12], -2);  past_key_values_key_cache_1 = add_12 = None
            cat_8: "f32[s23, 32, s54 + s70, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2);  past_key_values_value_cache_1 = transpose_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:247 in forward, code: attn_output, attn_weights = attention_interface(
            slice_16: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2);  and_2 = None
            slice_17: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_16, 1);  slice_16 = None
            slice_18: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_17, 2);  slice_17 = None
            slice_19: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(slice_18, 3, None, add);  slice_18 = add = None
            contiguous_2: "f32[s23, 32, s70, 128]" = torch.ops.aten.contiguous.default(add_11);  add_11 = None
            scaled_dot_product_attention_1: "f32[s23, 32, s70, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous_2, cat_7, cat_8, slice_19, scale = 0.08838834764831845);  contiguous_2 = slice_19 = None
            transpose_8: "f32[s23, s70, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2);  scaled_dot_product_attention_1 = None
            contiguous_3: "f32[s23, s70, 32, 128]" = torch.ops.aten.contiguous.default(transpose_8);  transpose_8 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:258 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_5: "f32[s23, s70, 4096]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_15, -1]);  contiguous_3 = sym_size_int_20 = sym_size_int_15 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_10: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(reshape_5, p_model_layers_1_self_attn_o_proj_weight);  reshape_5 = p_model_layers_1_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:301 in forward, code: hidden_states = residual + hidden_states
            add_13: "f32[s23, s70, 1024]" = torch.ops.aten.add.Tensor(to_12, linear_10);  to_12 = linear_10 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:62 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(add_13, torch.float32);  add_13 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_4: "f32[s23, s70, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_3: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True);  pow_4 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_14: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05);  mean_3 = None
            rsqrt_3: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_14);  add_14 = None
            mul_33: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_3);  rsqrt_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_33, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(mul_33, torch.float32);  mul_33 = None
            mul_34: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_15);  p_model_layers_1_post_attention_layernorm_weight = to_15 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_11: "f32[s23, s70, 6144]" = torch.ops.aten.linear.default(mul_34, p_model_layers_1_mlp_gate_proj_weight);  p_model_layers_1_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
            silu_1: "f32[s23, s70, 6144]" = torch.ops.aten.silu.default(linear_11);  linear_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_12: "f32[s23, s70, 6144]" = torch.ops.aten.linear.default(mul_34, p_model_layers_1_mlp_up_proj_weight);  mul_34 = p_model_layers_1_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:151 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_35: "f32[s23, s70, 6144]" = torch.ops.aten.mul.Tensor(silu_1, linear_12);  silu_1 = linear_12 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_13: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_35, p_model_layers_1_mlp_down_proj_weight);  mul_35 = p_model_layers_1_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:307 in forward, code: hidden_states = residual + hidden_states
            add_15: "f32[s23, s70, 1024]" = torch.ops.aten.add.Tensor(to_14, linear_13);  to_14 = linear_13 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:62 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_16 = torch.ops.aten._assert_tensor_metadata.default(add_15, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_16 = None
            to_16: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(add_15, torch.float32);  add_15 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_5: "f32[s23, s70, 1024]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
            mean_4: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True);  pow_5 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_16: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05);  mean_4 = None
            rsqrt_4: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_16);  add_16 = None
            mul_36: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_4);  to_16 = rsqrt_4 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_17 = torch.ops.aten._assert_tensor_metadata.default(mul_36, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_17 = None
            to_17: "f32[s23, s70, 1024]" = torch.ops.aten.to.dtype(mul_36, torch.float32);  mul_36 = None
            mul_37: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_17);  p_model_norm_weight = to_17 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:568 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_20: "f32[s23, s70, 1024]" = torch.ops.aten.slice.Tensor(mul_37);  mul_37 = None
            slice_21: "f32[s23, s70, 1024]" = torch.ops.aten.slice.Tensor(slice_20, 1, 0);  slice_20 = None
            slice_22: "f32[s23, s70, 1024]" = torch.ops.aten.slice.Tensor(slice_21, 2);  slice_21 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_14: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_22, p_lm_head_weight);  slice_22 = p_lm_head_weight = None
            return (linear_14, cat_3, cat_7, cat_4, cat_8)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", sym_size_int_20: "Sym(s23)", position_ids: "i64[s23, s70]"):
                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:823 in forward, code: self.inv_freq[None, :, None]
                unsqueeze: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_1: "f32[1, 64]" = torch.ops.aten.slice.Tensor(unsqueeze, 1, 0, 9223372036854775807);  unsqueeze = None
                unsqueeze_1: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 2);  slice_1 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:824 in forward, code: .float()
                _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
                to_1: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:825 in forward, code: .expand(position_ids.shape[0], -1, 1)
                expand_4: "f32[s23, 64, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_20, -1, 1]);  to_1 = sym_size_int_20 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:826 in forward, code: .to(x.device)
                _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
                to_2: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_4 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:828 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_2: "i64[s23, s70]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(slice_2, 1);  slice_2 = None
                slice_3: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
                _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_3, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                to_3: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_3, torch.float32);  slice_3 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3);  submod_3 = to_2 = to_3 = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:838 in forward, code: cos = emb.cos() * self.attention_scaling
                mul: "f32[s23, s70, 128]" = wrap_with_autocast[0]

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:839 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_1: "f32[s23, s70, 128]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:841 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                to_7: "f32[s23, s70, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, to_2: "f32[s23, 64, 1]", to_3: "f32[s23, 1, s70]"):
                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:836 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                    to_4: "f32[s23, 64, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                    to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_3, torch.float32);  to_3 = None
                    matmul: "f32[s23, 64, s70]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[s23, s70, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:837 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s23, s70, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:838 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[s23, s70, 128]" = torch.ops.aten.cos.default(cat)
                    mul: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                     # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:839 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[s23, s70, 128]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_1: "f32[s23, s70, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
    p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
    p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
    p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
    p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
    p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
    p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
    p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
    p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_key_cache_1: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT
    past_key_values_value_cache_1: USER_INPUT

    # outputs
    linear_14: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_7: USER_OUTPUT
    cat_4: USER_OUTPUT
    cat_8: USER_OUTPUT

Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo], s54: VR[2, int_oo]}
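Once exported, a quick sanity check is to run the exported program on a fresh copy of the inputs and compare its logits with the eager outputs computed earlier. A sketch (depending on the transformers version, this may need to run inside the same torch_export_patches context):

got = ep.module()(**copy.deepcopy(data["inputs"]))
diff = (got.logits - expected_outputs.logits).abs().max()
print("max absolute difference on logits:", float(diff))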
doc.plot_legend(
    "untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)
[figure: plot export hub codellama]

Total running time of the script: (0 minutes 5.319 seconds)

Related examples

Export microsoft/phi-2

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Export with DynamicCache and guessed dynamic shapes