Test the export on untrained models

Checking the exporter on a full model takes time because the model is usually large, but we can create a smaller untrained version with the same architecture. Fixing export issues on such a small model is much faster.
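
get_untrained_model_with_inputs (used below) does this automatically. As a rough sketch of the idea, the pretrained configuration can be shrunk before instantiating an untrained model; the reduced values below mirror those visible in the exported graph further down, but they are illustrative, not necessarily the exact choices made by the helper.

from transformers import AutoConfig, AutoModelForCausalLM

# Shrink the pretrained configuration: same architecture, far fewer parameters.
config = AutoConfig.from_pretrained("codellama/CodeLlama-7b-Python-hf")
config.num_hidden_layers = 2      # instead of 32
config.hidden_size = 768          # instead of 4096
config.intermediate_size = 6144   # instead of 11008

# from_config builds the architecture with randomly initialized weights,
# no checkpoint is downloaded.
small_model = AutoModelForCausalLM.from_config(config)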

codellama/CodeLlama-7b-Python-hf

Let’s grab some information about this model. This relies on the huggingface_hub API.

import copy
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)
from onnx_diagnostic.torch_models.hghub.hub_api import (
    get_model_info,
    get_pretrained_config,
    task_from_id,
)
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

model_id = "codellama/CodeLlama-7b-Python-hf"
print("info", get_model_info(model_id))
info ModelInfo(id='codellama/CodeLlama-7b-Python-hf', author='codellama', sha='d4178f5d2eead875e627ec487b23679266319b7f', created_at=datetime.datetime(2023, 8, 24, 16, 31, 28, tzinfo=datetime.timezone.utc), last_modified=datetime.datetime(2024, 4, 12, 14, 16, 26, tzinfo=datetime.timezone.utc), private=False, disabled=False, downloads=9649, downloads_all_time=None, gated=False, gguf=None, inference=None, likes=140, library_name='transformers', tags=['transformers', 'pytorch', 'safetensors', 'llama', 'text-generation', 'llama-2', 'code', 'arxiv:2308.12950', 'license:llama2', 'autotrain_compatible', 'text-generation-inference', 'endpoints_compatible', 'region:us'], pipeline_tag='text-generation', mask_token=None, card_data={'base_model': None, 'datasets': None, 'eval_results': None, 'language': ['code'], 'library_name': None, 'license': 'llama2', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': 'text-generation', 'tags': ['llama-2']}, widget_data=None, model_index=None, config={'architectures': ['LlamaForCausalLM'], 'model_type': 'llama', 'tokenizer_config': {'bos_token': {'__type': 'AddedToken', 'content': '<s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'eos_token': {'__type': 'AddedToken', 'content': '</s>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}, 'pad_token': None, 'unk_token': {'__type': 'AddedToken', 'content': '<unk>', 'lstrip': False, 'normalized': True, 'rstrip': False, 'single_word': False}}}, transformers_info=TransformersInfo(auto_model='AutoModelForCausalLM', custom_class=None, pipeline_tag='text-generation', processor='AutoTokenizer'), trending_score=None, siblings=[RepoSibling(rfilename='.gitattributes', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='LICENSE', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='README.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='USE_POLICY.md', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='generation_config.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00001-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model-00002-of-00002.safetensors', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='model.safetensors.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00001-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00002-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model-00003-of-00003.bin', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='pytorch_model.bin.index.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='special_tokens_map.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.json', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer.model', size=None, blob_id=None, lfs=None), RepoSibling(rfilename='tokenizer_config.json', size=None, blob_id=None, lfs=None)], spaces=['bigcode/bigcode-models-leaderboard', 'Intel/low_bit_open_llm_leaderboard', 'BAAI/open_cn_llm_leaderboard', 'qiantong-xu/toolbench-leaderboard', 'gsaivinay/open_llm_leaderboard', 'EvanTHU/MotionLLM', 'GTBench/GTBench', 'Vikhrmodels/small-shlepa-lb', 'kz-transformers/kaz-llm-lb', 'felixz/open_llm_leaderboard', 'HemaAM/GPT_train_on_LLaMa', '21world/bigcode-models-leaderboard', 
'OPTML-Group/UnlearnCanvas-Benchmark', 'whackthejacker/CodeTuneStudio', 'anantgupta129/LitGPT-Pythia-160M', 'BAAI/open_flageval_vlm_leaderboard', 'neubla/neubla-llm-evaluation-board', 'PrarthanaTS/tsai-gpt-from-scratch', 'MadhurGarg/TSAIGPTRedPajama', 'RaviNaik/ERA-SESSION22', 'theangkko/codellama-CodeLlama-7b-Python-hf', 'rodrigomasini/data_only_open_llm_leaderboard', 'Docfile/open_llm_leaderboard', 'Sijuade/GPTNEXTWORD', 'VDebugger/VDebugger-generalist-for-VQA', 'Temuzin64/code_helper', 'piyushgrover/MiniGPT_S22', 'supra-e-acc/Pythia-160M-text-generate', 'venkyyuvy/GPT_redpajama', 'VarunSivamani/GPT-From-Scratch', 'mkthoma/GPT_From_Scratch', 'sanjanatule/GPTNext', 'RashiAgarwal/TSAIGPTRedPajama', 'neuralorbs/DialogGen', 'Navyabhat/ERAV1-Session-22', 'GunaKoppula/ERA-Session-22', 'Vaish2705/ERA_S22', 'smothiki/open_llm_leaderboard', 'aemonge/codellama-CodeLlama-7b-Python-hf', 'sid92/codellama-CodeLlama-7b-Python-hf', 'poubellearman/codellama-CodeLlama-7b-Python-hf', 'IPF/codellama-CodeLlama-7b-Python-hf', 'Chris4K/codellama-CodeLlama-7b-Python-hf', 'CuriosityPdf/codellama-CodeLlama-7b-Python-hf', 'shreefhamed/codellama-CodeLlama-7b-Python-hf', 'markl11/codellama-CodeLlama-7b-Python-hf', '0x1668/open_llm_leaderboard', 'rpratl/codellama-CodeLlama-7b-Python-hf', 'pngwn/open_llm_leaderboard-check', 'asir0z/open_llm_leaderboard', 'LovelySweet/codellama-CodeLlama-7b-Python-hf', 'kbmlcoding/open_llm_leaderboard_free', 'ToletiSri/TSAI_S22', 'aichampions/open_llm_leaderboard', 'Adeco/open_llm_leaderboard', 'anirudh937/open_llm_leaderboard', 'smothiki/open_llm_leaderboard2', 'mjalg/IFEvalTR', 'lastsamuraii/LitGPT-Pythia-160M', 'atlasas/bigcode-models-leaderboard'], safetensors=SafeTensorsInfo(parameters={'BF16': 6738415616}, total=6738415616), security_repo_status=None)

The configuration.

print("config", get_pretrained_config(model_id))
config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.51.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

The task determines the set of inputs that needs to be created for this model.

print("task", task_from_id(model_id))
task text-generation

Untrained model

The function get_untrained_model_with_inputs loads the pretrained configuration, extracts the task associated with the model and then creates random inputs and dynamic shapes for torch.export.export().

data = get_untrained_model_with_inputs(model_id, verbose=1)
print("model size:", data["size"])
print("number of weights:", data["n_weights"])
print("fields:", set(data))
[get_untrained_model_with_inputs] model_id='codellama/CodeLlama-7b-Python-hf'
[get_untrained_model_with_inputs] architecture='LlamaForCausalLM'
[get_untrained_model_with_inputs] cls='LlamaConfig'
[get_untrained_model_with_inputs] task='text-generation'
model size: 410532864
number of weights: 102633216
fields: {'dynamic_shapes', 'input_kwargs', 'size', 'configuration', 'n_weights', 'model', 'inputs', 'model_kwargs'}

Inputs

print("inputs:", string_type(data["inputs"], with_shape=True))
inputs: dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x128,T1s2x32x30x128], value_cache=#2[T1s2x32x30x128,T1s2x32x30x128]))
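
In this notation, T7s2x3 is an int64 tensor of shape (2, 3) and T1 denotes float32 (the type numbers follow the ONNX element types). As a hedged sketch, inputs equivalent to the generated ones could be built by hand as below; the shapes come from the line above and the helper may build them differently.

import torch
from transformers.cache_utils import DynamicCache

batch, seq_length, cache_length = 2, 3, 30

# A cache holding past keys and values for the two layers kept in the small model.
cache = DynamicCache()
for layer_idx in range(2):
    cache.update(
        torch.randn(batch, 32, cache_length, 128),  # key
        torch.randn(batch, 32, cache_length, 128),  # value
        layer_idx,
    )

manual_inputs = dict(
    input_ids=torch.randint(0, 32000, (batch, seq_length), dtype=torch.int64),
    attention_mask=torch.ones(batch, seq_length + cache_length, dtype=torch.int64),
    position_ids=torch.arange(cache_length, cache_length + seq_length).unsqueeze(0).expand(batch, -1),
    past_key_values=cache,
)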

Dynamic Shapes

print("dynamic shapes:", pprint.pformat(data["dynamic_shapes"]))
dynamic shapes: {'attention_mask': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                    1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
 'input_ids': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
               1: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.seq_length'>},
 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>},
                      {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>}],
                     [{0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>},
                      {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.cache_length'>}]],
 'position_ids': {0: <class 'onnx_diagnostic.torch_models.hghub.model_inputs.batch'>,
                  1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}}
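
The dimensions above are expressed with classes defined by the library (batch, seq_length, cache_length). With a recent PyTorch, a roughly equivalent specification written by hand with torch.export.Dim could look like the sketch below; the dimension names are illustrative.

from torch.export import Dim

batch = Dim("batch")
seq_length = Dim("seq_length")
cache_length = Dim("cache_length")

manual_dynamic_shapes = {
    "input_ids": {0: batch, 1: seq_length},
    "attention_mask": {0: batch, 1: Dim.DYNAMIC},
    "position_ids": {0: batch, 1: Dim.DYNAMIC},
    # one list per cache kind (keys, values), one entry per layer
    "past_key_values": [
        [{0: batch, 2: cache_length}, {0: batch, 2: cache_length}],
        [{0: batch, 2: cache_length}, {0: batch, 2: cache_length}],
    ],
}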

Let’s check the model runs. We still need to copy the inputs before calling the model because the cache is usually modified in place. The expected outputs can be used later to compute discrepancies.

inputs_copy = copy.deepcopy(data["inputs"])
model = data["model"]
expected_outputs = model(**inputs_copy)

print("outputs:", string_type(expected_outputs, with_shape=True))
outputs: dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x128,T1s2x32x33x128], value_cache=#2[T1s2x32x33x128,T1s2x32x33x128]))

It works.
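
A small illustrative check, not part of the original example: printing the cache shapes shows why the deep copy matters. The forward pass extended the cache held by inputs_copy in place (30 to 33 past tokens), while data["inputs"] kept the length expected by the dynamic shapes.

print("original inputs:", string_type(data["inputs"]["past_key_values"], with_shape=True))
print("after the call: ", string_type(inputs_copy["past_key_values"], with_shape=True))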

Export

The model uses transformers.cache_utils.DynamicCache. It still requires patches to be exportable because of the control flow the cache introduces. See onnx_diagnostic.torch_export_patches.bypass_export_some_errors().

with bypass_export_some_errors(patch_transformers=True) as f:
    ep = torch.export.export(
        model,
        (),
        kwargs=f(data["inputs"]),
        dynamic_shapes=data["dynamic_shapes"],
        strict=False,
    )
    print(ep)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 768]", p_model_layers_0_self_attn_q_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_k_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_v_proj_weight: "f32[4096, 768]", p_model_layers_0_self_attn_o_proj_weight: "f32[768, 4096]", p_model_layers_0_mlp_gate_proj_weight: "f32[6144, 768]", p_model_layers_0_mlp_up_proj_weight: "f32[6144, 768]", p_model_layers_0_mlp_down_proj_weight: "f32[768, 6144]", p_model_layers_0_input_layernorm_weight: "f32[768]", p_model_layers_0_post_attention_layernorm_weight: "f32[768]", p_model_layers_1_self_attn_q_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_k_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_v_proj_weight: "f32[4096, 768]", p_model_layers_1_self_attn_o_proj_weight: "f32[768, 4096]", p_model_layers_1_mlp_gate_proj_weight: "f32[6144, 768]", p_model_layers_1_mlp_up_proj_weight: "f32[6144, 768]", p_model_layers_1_mlp_down_proj_weight: "f32[768, 6144]", p_model_layers_1_input_layernorm_weight: "f32[768]", p_model_layers_1_post_attention_layernorm_weight: "f32[768]", p_model_norm_weight: "f32[768]", p_lm_head_weight: "f32[32000, 768]", b_model_rotary_emb_inv_freq: "f32[64]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s10]", position_ids: "i64[s0, s1]", past_key_values_key_cache_0: "f32[s0, 32, s10, 128]", past_key_values_key_cache_1: "f32[s0, 32, s10, 128]", past_key_values_value_cache_0: "f32[s0, 32, s10, 128]", past_key_values_value_cache_1: "f32[s0, 32, s10, 128]"):
             #
            sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
            sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_22: "Sym(s10)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[s0, s1, 768]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:561 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s1 + s10)" = sym_size_int_22 + sym_size_int_21

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:560 in forward, code: cache_position = torch.arange(
            arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:567 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s1, s1 + s10]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s1, s1 + s10]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s1 + s10]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False);  add = None
            reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s1, s1 + s10]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s1, s1 + s10]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze: "f32[1, s1, s1 + s10]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_1: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            slice_1: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
            slice_2: "f32[1, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]);  slice_2 = None
            clone: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_4: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
            slice_5: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
            slice_6: "i64[s0, s1 + s10]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_2: "i64[s0, 1, s1 + s10]" = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
            unsqueeze_3: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
            slice_7: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
            to: "i64[s0, 1, 1, s1 + s10]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
            add_2: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
            eq_9: "b8[s0, 1, s1, s1 + s10]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_8: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_9: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
            slice_10: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
            masked_fill: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_9, -3.4028234663852886e+38);  slice_10 = eq_9 = None
            slice_11: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_12: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
            slice_13: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
            copy_: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_20, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[s0, s1, 128]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[s0, s1, 128]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_8: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_9: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear = None
            transpose_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear_1 = None
            transpose_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear_2 = None
            transpose_3: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_7: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1)
            unsqueeze_8: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1)
            mul_4: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
            slice_17: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 64)
            slice_18: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 64, 9223372036854775807);  transpose_1 = None
            neg: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
            mul_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
            add_4: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
            slice_19: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 64)
            slice_20: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 64, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
            mul_7: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
            add_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_21: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_22: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
            slice_23: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_22, 2, 0, 9223372036854775807);  slice_22 = None
            contiguous: "f32[s0, 32, s1, 128]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            scaled_dot_product_attention: "f32[s0, 32, s1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, cat_3, cat_4, slice_23, scale = 0.08838834764831845);  contiguous = slice_23 = None
            transpose_4: "f32[s0, s1, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_1: "f32[s0, s1, 32, 128]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_1: "f32[s0, s1, 4096]" = torch.ops.aten.reshape.default(contiguous_1, [sym_size_int_20, sym_size_int_21, -1]);  contiguous_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(reshape_1, p_model_layers_0_self_attn_o_proj_weight);  reshape_1 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_10: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_11: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s0, s1, 6144]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[s0, s1, 6144]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_12: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_13: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_1_input_layernorm_weight, to_13);  p_model_layers_1_input_layernorm_weight = to_13 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_q_proj_weight);  p_model_layers_1_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_3: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_7, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear_7 = None
            transpose_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_3, 1, 2);  view_3 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_8: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_k_proj_weight);  p_model_layers_1_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_4: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_8, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear_8 = None
            transpose_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_4, 1, 2);  view_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_9: "f32[s0, s1, 4096]" = torch.ops.aten.linear.default(mul_12, p_model_layers_1_self_attn_v_proj_weight);  mul_12 = p_model_layers_1_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_5: "f32[s0, s1, 32, 128]" = torch.ops.aten.view.default(linear_9, [sym_size_int_20, sym_size_int_21, -1, 128]);  linear_9 = None
            transpose_7: "f32[s0, 32, s1, 128]" = torch.ops.aten.transpose.int(view_5, 1, 2);  view_5 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_9: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_10: "f32[s0, 1, s1, 128]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_13: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_5, unsqueeze_9)
            slice_24: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 0, 64)
            slice_25: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_5, 3, 64, 9223372036854775807);  transpose_5 = None
            neg_2: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_25);  slice_25 = None
            cat_5: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_2, slice_24], -1);  neg_2 = slice_24 = None
            mul_14: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_5, unsqueeze_10);  cat_5 = None
            add_11: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_13, mul_14);  mul_13 = mul_14 = None
            mul_15: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(transpose_6, unsqueeze_9);  unsqueeze_9 = None
            slice_26: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 0, 64)
            slice_27: "f32[s0, 32, s1, 64]" = torch.ops.aten.slice.Tensor(transpose_6, 3, 64, 9223372036854775807);  transpose_6 = None
            neg_3: "f32[s0, 32, s1, 64]" = torch.ops.aten.neg.default(slice_27);  slice_27 = None
            cat_6: "f32[s0, 32, s1, 128]" = torch.ops.aten.cat.default([neg_3, slice_26], -1);  neg_3 = slice_26 = None
            mul_16: "f32[s0, 32, s1, 128]" = torch.ops.aten.mul.Tensor(cat_6, unsqueeze_10);  cat_6 = unsqueeze_10 = None
            add_12: "f32[s0, 32, s1, 128]" = torch.ops.aten.add.Tensor(mul_15, mul_16);  mul_15 = mul_16 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_7: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_key_cache_1, add_12], -2);  past_key_values_key_cache_1 = add_12 = None
            cat_8: "f32[s0, 32, s1 + s10, 128]" = torch.ops.aten.cat.default([past_key_values_value_cache_1, transpose_7], -2);  past_key_values_value_cache_1 = transpose_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_28: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
            slice_29: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_28, 1, 0, 9223372036854775807);  slice_28 = None
            slice_30: "f32[s0, 1, s1, s1 + s10]" = torch.ops.aten.slice.Tensor(slice_29, 2, 0, 9223372036854775807);  slice_29 = None
            contiguous_2: "f32[s0, 32, s1, 128]" = torch.ops.aten.contiguous.default(add_11);  add_11 = None
            scaled_dot_product_attention_1: "f32[s0, 32, s1, 128]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous_2, cat_7, cat_8, slice_30, scale = 0.08838834764831845);  contiguous_2 = slice_30 = None
            transpose_8: "f32[s0, s1, 32, 128]" = torch.ops.aten.transpose.int(scaled_dot_product_attention_1, 1, 2);  scaled_dot_product_attention_1 = None
            contiguous_3: "f32[s0, s1, 32, 128]" = torch.ops.aten.contiguous.default(transpose_8);  transpose_8 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_2: "f32[s0, s1, 4096]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]);  contiguous_3 = sym_size_int_20 = sym_size_int_21 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_10: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_1_self_attn_o_proj_weight);  reshape_2 = p_model_layers_1_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_13: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_12, linear_10);  to_12 = linear_10 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_14: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_13, torch.float32);  add_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_4: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_3: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True);  pow_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_14: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05);  mean_3 = None
            rsqrt_3: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_14);  add_14 = None
            mul_17: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_3);  rsqrt_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_15: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_17, torch.float32);  mul_17 = None
            mul_18: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_layers_1_post_attention_layernorm_weight, to_15);  p_model_layers_1_post_attention_layernorm_weight = to_15 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_11: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_gate_proj_weight);  p_model_layers_1_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu_1: "f32[s0, s1, 6144]" = torch.ops.aten.silu.default(linear_11);  linear_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_12: "f32[s0, s1, 6144]" = torch.ops.aten.linear.default(mul_18, p_model_layers_1_mlp_up_proj_weight);  mul_18 = p_model_layers_1_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_19: "f32[s0, s1, 6144]" = torch.ops.aten.mul.Tensor(silu_1, linear_12);  silu_1 = linear_12 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_13: "f32[s0, s1, 768]" = torch.ops.aten.linear.default(mul_19, p_model_layers_1_mlp_down_proj_weight);  mul_19 = p_model_layers_1_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_15: "f32[s0, s1, 768]" = torch.ops.aten.add.Tensor(to_14, linear_13);  to_14 = linear_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_16: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(add_15, torch.float32);  add_15 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_5: "f32[s0, s1, 768]" = torch.ops.aten.pow.Tensor_Scalar(to_16, 2)
            mean_4: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True);  pow_5 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_16: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05);  mean_4 = None
            rsqrt_4: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_16);  add_16 = None
            mul_20: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(to_16, rsqrt_4);  to_16 = rsqrt_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_17: "f32[s0, s1, 768]" = torch.ops.aten.to.dtype(mul_20, torch.float32);  mul_20 = None
            mul_21: "f32[s0, s1, 768]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_17);  p_model_norm_weight = to_17 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:866 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_31: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(mul_21, 0, 0, 9223372036854775807);  mul_21 = None
            slice_32: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(slice_31, 1, 0, 9223372036854775807);  slice_31 = None
            slice_33: "f32[s0, s1, 768]" = torch.ops.aten.slice.Tensor(slice_32, 2, 0, 9223372036854775807);  slice_32 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_14: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_33, p_lm_head_weight);  slice_33 = p_lm_head_weight = None
            return (linear_14, cat_3, cat_7, cat_4, cat_8)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[64]", sym_size_int_20: "Sym(s0)", position_ids: "i64[s0, s1]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_4: "f32[1, 64]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_14: "f32[1, 64]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
                unsqueeze_5: "f32[1, 64, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
                to_1: "f32[1, 64, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
                expand_1: "f32[s0, 64, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_20, -1, 1]);  to_1 = sym_size_int_20 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_15: "i64[s0, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_6: "i64[s0, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
                slice_16: "i64[s0, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
                to_2: "f32[s0, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                cos: "f32[s0, s1, 128]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                sin: "f32[s0, s1, 128]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[s0, s1, 128]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[s0, s1, 128]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_6: "f32[s0, s1, 128]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_7: "f32[s0, s1, 128]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[s0, 64, 1]", to_2: "f32[s0, 1, s1]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
                    to_3: "f32[s0, 64, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
                    to_4: "f32[s0, 64, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
                    to_5: "f32[s0, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[s0, 64, s1]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[s0, s1, 64]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s0, s1, 128]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                    cos: "f32[s0, s1, 128]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                    sin: "f32[s0, s1, 128]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_layers_1_self_attn_q_proj_weight: PARAMETER target='model.layers.1.self_attn.q_proj.weight'
    p_model_layers_1_self_attn_k_proj_weight: PARAMETER target='model.layers.1.self_attn.k_proj.weight'
    p_model_layers_1_self_attn_v_proj_weight: PARAMETER target='model.layers.1.self_attn.v_proj.weight'
    p_model_layers_1_self_attn_o_proj_weight: PARAMETER target='model.layers.1.self_attn.o_proj.weight'
    p_model_layers_1_mlp_gate_proj_weight: PARAMETER target='model.layers.1.mlp.gate_proj.weight'
    p_model_layers_1_mlp_up_proj_weight: PARAMETER target='model.layers.1.mlp.up_proj.weight'
    p_model_layers_1_mlp_down_proj_weight: PARAMETER target='model.layers.1.mlp.down_proj.weight'
    p_model_layers_1_input_layernorm_weight: PARAMETER target='model.layers.1.input_layernorm.weight'
    p_model_layers_1_post_attention_layernorm_weight: PARAMETER target='model.layers.1.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_key_cache_1: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT
    past_key_values_value_cache_1: USER_INPUT

    # outputs
    linear_14: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_7: USER_OUTPUT
    cat_4: USER_OUTPUT
    cat_8: USER_OUTPUT

Range constraints: {s0: VR[1, 1024], s1: VR[2, 4096], s1 + s10: VR[4, 8192], s10: VR[1, 4096]}
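
Before plotting, the exported program can be validated against the eager outputs computed earlier. This is a hedged sketch, not part of the original example: it assumes ep.module() rebuilds the original output structure and it reuses the patch context so the DynamicCache serialization is still registered.

with bypass_export_some_errors(patch_transformers=True) as f:
    ep_outputs = ep.module()(**f(copy.deepcopy(data["inputs"])))
    # Fall back to positional access if the output class is not rebuilt.
    got = ep_outputs.logits if hasattr(ep_outputs, "logits") else ep_outputs[0]
    diff = (got - expected_outputs.logits).abs().max()
print("maximum absolute discrepancy on logits:", float(diff))
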
doc.plot_legend(
    "untrained\ncodellama/\nCodeLlama-7b-Python-hf", "torch.export.export", "tomato"
)
(Figure: legend plot produced by doc.plot_legend for the untrained codellama/CodeLlama-7b-Python-hf export)

Total running time of the script: (0 minutes 2.713 seconds)

Related examples

Untrained microsoft/phi-2

Steel method forward to guess the dynamic shapes (with Tiny-LLM)

Export Tiny-LLM with patches