Export Tiny-LLM with patches

Many models from transformers cannot be exported directly because their implementation relies on cache classes. Let's see how to work around that. We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we use a function that creates a random (untrained) model with the same architecture. This continues the example Steel method forward to guess the dynamic shapes (with Tiny-LLM).
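
The helper get_tiny_llm used below takes care of that. A minimal sketch of the idea, assuming the LLaMA architecture and hyperparameters consistent with the exported graph shown further down (the actual values used by get_tiny_llm may differ):

import transformers

# Hypothetical configuration for a tiny LLaMA-like model; building the
# model from a config creates random weights, nothing is downloaded.
config = transformers.LlamaConfig(
    vocab_size=32000,
    hidden_size=192,
    intermediate_size=1024,
    num_hidden_layers=1,
    num_attention_heads=2,
    num_key_value_heads=1,
)
untrained = transformers.LlamaForCausalLM(config)
print(sum(p.numel() for p in untrained.parameters()))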

Errors

They depend on the transformers version.

With transformers>=4.40,<4.50, DynamicCache cannot be serialized and dynamic shapes cannot be mapped to DynamicCache instances. The following error appears:

torch._dynamo.exc.UserError: Cannot associate shape
    [[{0: <class '....batch'>, 2: <class '....cache_length'>}],
     [{0: <class '....batch'>, 2: <class '....cache_length'>}]]
    specified at `dynamic_shapes['past_key_values']`
    to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
    at `inputs['past_key_values']` (expected None)
For more information about this error,
see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

With transformers==4.50, the error becomes:

torch._dynamo.exc.UserError: Constraints violated (batch)!
For more information, run with TORCH_LOGS="+dynamic".
    - Not all values of batch = L['args'][1]['input_ids'].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['attention_mask'].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
 Suggested fixes:
     batch = 2
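
This failure mode is not specific to DynamicCache: any code path that compares a dimension declared dynamic to a concrete value forces torch.export to specialize it. A toy reproducer of the same class of error (a sketch with a made-up module, not the transformers code):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # Branching on the batch size makes the exporter specialize
        # the symbolic dimension to the traced value (2).
        if x.shape[0] == 2:
            return x + 1
        return x - 1

try:
    torch.export.export(
        Toy(),
        (torch.randn(2, 3),),
        dynamic_shapes=({0: torch.export.Dim("batch", min=1, max=1024)},),
        strict=False,
    )
except Exception as e:
    # Expected: a constraint violation similar to the one above,
    # "batch was inferred to be a constant (2)".
    print(type(e).__name__)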

However, this package implements a patch mechanism which replaces the pieces of code causing these issues.
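
Part of what these patches do is register DynamicCache as a pytree node so that torch.export knows how to flatten it into plain tensors and rebuild it afterwards. A minimal sketch of such a registration, assuming a recent torch exposing torch.utils._pytree.register_pytree_node and a DynamicCache with key_cache/value_cache lists (the actual patches are more careful about versions and edge cases):

import torch
from transformers.cache_utils import DynamicCache

def flatten_dynamic_cache(cache):
    # children first, then a context describing how to rebuild the object
    return [cache.key_cache, cache.value_cache], None

def unflatten_dynamic_cache(children, context):
    cache = DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache

torch.utils._pytree.register_pytree_node(
    DynamicCache,
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
    serialized_type_name="transformers.cache_utils.DynamicCache",
)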

Note

restart after an export failure

If the export fails, it is better to run the script again from scratch, or to restart the kernel if you are working in a notebook: a failed export may leave torch in an unstable state.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
from onnx_diagnostic.torch_models.llms import get_tiny_llm


experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)

cloned_inputs = copy.deepcopy(inputs)

Let's print the inputs; they were inferred in the example Steel method forward to guess the dynamic shapes (with Tiny-LLM). In the notation below, T7 denotes an int64 tensor and T1 a float32 tensor, followed by its shape.

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x30x96],#1[T1s2x1x30x96]]))

And here are the dynamic shapes:

{'attention_mask': {0: Dim('batch', min=1, max=1024),
                    1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                                min=None,
                                max=None,
                                _factory=True)},
 'input_ids': {0: Dim('batch', min=1, max=1024),
               1: Dim('seq_length', min=1, max=4096)},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024),
                       2: Dim('cache_length', min=1, max=4096)}],
                     [{0: Dim('batch', min=1, max=1024),
                       2: Dim('cache_length', min=1, max=4096)}]],
 'position_ids': {0: Dim('batch', min=1, max=1024),
                  1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                              min=None,
                              max=None,
                              _factory=True)}}
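
Written by hand, the same specification would look like the sketch below; Dim.DYNAMIC stands in for the explicit dynamic hints printed above for attention_mask and position_ids (names are illustrative, one list entry per layer):

from torch.export import Dim

batch = Dim("batch", min=1, max=1024)
seq_length = Dim("seq_length", min=1, max=4096)
cache_length = Dim("cache_length", min=1, max=4096)

dynamic_shapes = {
    "input_ids": {0: batch, 1: seq_length},
    "attention_mask": {0: batch, 1: Dim.DYNAMIC},
    "position_ids": {0: batch, 1: Dim.DYNAMIC},
    # first the key caches, then the value caches, one entry per layer
    "past_key_values": [
        [{0: batch, 2: cache_length}],
        [{0: batch, 2: cache_length}],
    ],
}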

Before exporting, we check that transformers.cache_utils.DynamicCache can be serialized and deserialized; otherwise torch.export.export() fails.

print("-- DynamicCache registered: ", is_cache_dynamic_registered())
-- DynamicCache registered:  True
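
To check it by hand, one can flatten a cache with torch's pytree utilities and see whether it comes back as a DynamicCache. This is only a sketch of what the helper verifies; its actual implementation may differ:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(2, 1, 30, 96), torch.randn(2, 1, 30, 96), 0)

flat, spec = tree_flatten(cache)
# If DynamicCache is registered, the leaves are the key/value tensors;
# otherwise the cache itself is the single leaf.
print(len(flat))
restored = tree_unflatten(flat, spec)
print(type(restored))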

If it is not registered, the function onnx_diagnostic.torch_export_patches.bypass_export_some_errors() takes care of it. Then we export.

with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
    assert is_cache_dynamic_registered()  # it must be true here
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=modificator(cloned_inputs),
        dynamic_shapes=dynamic_shapes,
        strict=False,  # mandatory for torch==2.6
    )
    print("It worked:")
    print(ep)
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_register_cache_serialization] register MambaCache
[_register_cache_serialization] DynamicCache is unregistered and registered first.
[_register_cache_serialization] <class 'transformers.cache_utils.DynamicCache'> already registered
[bypass_export_some_errors] sympy.__version__='1.13.3'
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] torch.__version__='2.8.0.dev20250404+cu126'
[bypass_export_some_errors] stop_if_static=0
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] transformers.__version__='4.51.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_AttentionMaskConverter: _make_causal_mask
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[bypass_export_some_errors] done patching
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s41, s2]", attention_mask: "i64[s41, s2 + s67]", position_ids: "i64[s41, s2]", past_key_values_key_cache_0: "f32[s41, 1, s67, 96]", past_key_values_value_cache_0: "f32[s41, 1, s67, 96]"):
             #
            sym_size_int_22: "Sym(s41)" = torch.ops.aten.sym_size.int(input_ids, 0)
            sym_size_int_23: "Sym(s2)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_24: "Sym(s67)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[s41, s2, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s2 + s67)" = sym_size_int_24 + sym_size_int_23

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:563 in forward, code: cache_position = torch.arange(
            arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int_24, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_24 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:570 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s2, s2 + s67]" = torch.ops.aten.full.default([sym_size_int_23, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s2, s2 + s67]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s2 + s67]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s2, s2 + s67]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s2, s2 + s67]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze: "f32[1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            slice_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
            slice_2: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_22, 1, -1, -1]);  slice_2 = None
            clone: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
            slice_4: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_3, 1);  slice_3 = None
            slice_5: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_4, 2);  slice_4 = None
            slice_6: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, add);  slice_5 = None
            slice_7: "i64[s41, s2 + s67]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_2: "i64[s41, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
            unsqueeze_3: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
            slice_8: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_8 = None
            add_2: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.add.Tensor(slice_6, to);  slice_6 = to = None
            eq_6: "b8[s41, 1, s2, s2 + s67]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_9: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
            slice_10: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_9, 1);  slice_9 = None
            slice_11: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_10, 2);  slice_10 = None
            slice_12: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, add);  slice_11 = None
            masked_fill: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_6, -3.4028234663852886e+38);  slice_12 = eq_6 = None
            slice_13: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_14: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807);  slice_13 = None
            slice_15: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807);  slice_14 = None
            copy_: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.copy_.default(slice_15, masked_fill);  slice_15 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_22, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_5: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[0]
            to_6: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
            to_7: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_7, 2)
            mean: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_7, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_8: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_8);  p_model_layers_0_input_layernorm_weight = to_8 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:281 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s41, s2, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear = None
            transpose_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear_1 = None
            transpose_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:283 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear_2 = None
            transpose_3: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:286 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_5, 1);  to_5 = None
            unsqueeze_8: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            mul_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
            slice_19: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_20: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s41, 2, s2, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.cat.default([neg, slice_19], -1);  neg = slice_19 = None
            mul_5: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
            add_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
            slice_21: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_22: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s41, 1, s2, 48]" = torch.ops.aten.neg.default(slice_22);  slice_22 = None
            cat_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.cat.default([neg_1, slice_21], -1);  neg_1 = slice_21 = None
            mul_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
            add_5: "f32[s41, 1, s2, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:303 in forward, code: attn_output, attn_weights = attention_interface(
            slice_23: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_24: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_23, 1, 0, 9223372036854775807);  slice_23 = None
            unsqueeze_9: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_24, 2);  slice_24 = None
            slice_25: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
            slice_26: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_25, 4, 0, 9223372036854775807);  slice_25 = None
            expand_2: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_26, [sym_size_int_22, 1, 2, add, 96]);  slice_26 = None
            reshape_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_22, 2, add, 96]);  expand_2 = None
            slice_27: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_28: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_27, 1, 0, 9223372036854775807);  slice_27 = None
            unsqueeze_10: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_28, 2);  slice_28 = None
            slice_29: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_30: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_29, 4, 0, 9223372036854775807);  slice_29 = None
            expand_3: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_30, [sym_size_int_22, 1, 2, add, 96]);  slice_30 = None
            reshape_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_22, 2, add, 96]);  expand_3 = None
            slice_31: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone);  clone = None
            slice_32: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_31, 1);  slice_31 = None
            slice_33: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_32, 2);  slice_32 = None
            slice_34: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_33, 3, None, add);  slice_33 = add = None
            contiguous: "f32[s41, 2, s2, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[s41, 2, s2, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_34, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_34 = None
            transpose_4: "f32[s41, s2, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[s41, s2, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:314 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[s41, s2, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_22, sym_size_int_23, -1]);  contiguous_3 = sym_size_int_22 = sym_size_int_23 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:358 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_7, linear_3);  to_7 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_9: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_9, 2)
            mean_1: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_9, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_10);  p_model_layers_0_post_attention_layernorm_weight = to_10 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s41, s2, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:201 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[s41, s2, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:364 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_9, linear_6);  to_9 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_11, 2)
            mean_2: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_11, rsqrt_2);  to_11 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_12);  p_model_norm_weight = to_12 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:866 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_35: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(mul_12);  mul_12 = None
            slice_36: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_35, 1, 0);  slice_35 = None
            slice_37: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_36, 2);  slice_36 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s41, s2, 32000]" = torch.ops.aten.linear.default(slice_37, p_lm_head_weight);  slice_37 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_22: "Sym(s41)", position_ids: "i64[s41, s2]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
                unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2);  slice_16 = None
                _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
                expand_1: "f32[s41, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_22, -1, 1]);  to_1 = sym_size_int_22 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:136 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_17: "i64[s41, s2]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_6: "i64[s41, 1, s2]" = torch.ops.aten.unsqueeze.default(slice_17, 1);  slice_17 = None
                slice_18: "i64[s41, 1, s2]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
                _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
                to_2: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(slice_18, torch.float32);  slice_18 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = emb.cos()
                cos: "f32[s41, s2, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = emb.sin()
                sin: "f32[s41, s2, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:149 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:150 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_5, to_6)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[s41, 48, 1]", to_2: "f32[s41, 1, s2]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
                    _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                    to_3: "f32[s41, 48, 1]" = torch.ops.aten.to.device(expand_1, device(type='cpu'), torch.float32);  expand_1 = None
                    _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                    to_4: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[s41, 48, s2]" = torch.ops.aten.matmul.default(to_3, to_4);  to_3 = to_4 = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:143 in forward, code: ).transpose(1, 2)
                    transpose: "f32[s41, s2, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:144 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s41, s2, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = emb.cos()
                    cos: "f32[s41, s2, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = emb.sin()
                    sin: "f32[s41, s2, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s41: VR[1, 1024], s2: VR[2, 4096], s2 + s67: VR[4, 8192], s67: VR[1, 4096]}

[bypass_export_some_errors] remove patches
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored shape constraints
[bypass_export_some_errors] unpatch transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_AttentionMaskConverter: _make_causal_mask
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[_unregister_cache_serialization] unregistered MambaCache
[_unregister_cache_serialization] skip unregister DynamicCache

Let's now do the same export with the original model and its pretrained weights.

MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

cloned_inputs = copy.deepcopy(inputs)

with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
    ep = torch.export.export(
        model,
        (),
        kwargs=modificator(cloned_inputs),
        dynamic_shapes=dynamic_shapes,
        strict=False,  # mandatory for torch==2.6
    )
    print("It worked:")
    print(ep)
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_register_cache_serialization] register MambaCache
[_register_cache_serialization] <class 'transformers.cache_utils.DynamicCache'> already registered
[bypass_export_some_errors] sympy.__version__='1.13.3'
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] torch.__version__='2.8.0.dev20250404+cu126'
[bypass_export_some_errors] stop_if_static=0
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] transformers.__version__='4.51.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_AttentionMaskConverter: _make_causal_mask
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[bypass_export_some_errors] done patching
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s41, s2]", attention_mask: "i64[s41, s2 + s67]", position_ids: "i64[s41, s2]", past_key_values_key_cache_0: "f32[s41, 1, s67, 96]", past_key_values_value_cache_0: "f32[s41, 1, s67, 96]"):
             #
            sym_size_int_22: "Sym(s41)" = torch.ops.aten.sym_size.int(input_ids, 0)
            sym_size_int_23: "Sym(s2)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_24: "Sym(s67)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[s41, s2, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s2 + s67)" = sym_size_int_24 + sym_size_int_23

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:563 in forward, code: cache_position = torch.arange(
            arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int_24, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_24 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:570 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s2, s2 + s67]" = torch.ops.aten.full.default([sym_size_int_23, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s2, s2 + s67]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s2 + s67]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s2, s2 + s67]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s2, s2 + s67]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze: "f32[1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            slice_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
            slice_2: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_22, 1, -1, -1]);  slice_2 = None
            clone: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
            slice_4: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_3, 1);  slice_3 = None
            slice_5: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_4, 2);  slice_4 = None
            slice_6: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, add);  slice_5 = None
            slice_7: "i64[s41, s2 + s67]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_2: "i64[s41, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(slice_7, 1);  slice_7 = None
            unsqueeze_3: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
            slice_8: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_8 = None
            add_2: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.add.Tensor(slice_6, to);  slice_6 = to = None
            eq_6: "b8[s41, 1, s2, s2 + s67]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_9: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
            slice_10: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_9, 1);  slice_9 = None
            slice_11: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_10, 2);  slice_10 = None
            slice_12: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, add);  slice_11 = None
            masked_fill: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_6, -3.4028234663852886e+38);  slice_12 = eq_6 = None
            slice_13: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_14: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807);  slice_13 = None
            slice_15: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807);  slice_14 = None
            copy_: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.copy_.default(slice_15, masked_fill);  slice_15 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_22, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_5: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[0]
            to_6: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
            to_7: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_7, 2)
            mean: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_7, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_8: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_8);  p_model_layers_0_input_layernorm_weight = to_8 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:281 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s41, s2, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear = None
            transpose_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear_1 = None
            transpose_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:283 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_22, sym_size_int_23, -1, 96]);  linear_2 = None
            transpose_3: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:286 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_5, 1);  to_5 = None
            unsqueeze_8: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            mul_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
            slice_19: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_20: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s41, 2, s2, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.cat.default([neg, slice_19], -1);  neg = slice_19 = None
            mul_5: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
            add_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
            slice_21: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_22: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s41, 1, s2, 48]" = torch.ops.aten.neg.default(slice_22);  slice_22 = None
            cat_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.cat.default([neg_1, slice_21], -1);  neg_1 = slice_21 = None
            mul_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
            add_5: "f32[s41, 1, s2, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:303 in forward, code: attn_output, attn_weights = attention_interface(
            slice_23: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_24: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_23, 1, 0, 9223372036854775807);  slice_23 = None
            unsqueeze_9: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_24, 2);  slice_24 = None
            slice_25: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
            slice_26: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_25, 4, 0, 9223372036854775807);  slice_25 = None
            expand_2: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_26, [sym_size_int_22, 1, 2, add, 96]);  slice_26 = None
            reshape_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_22, 2, add, 96]);  expand_2 = None
            slice_27: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_28: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_27, 1, 0, 9223372036854775807);  slice_27 = None
            unsqueeze_10: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_28, 2);  slice_28 = None
            slice_29: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_30: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_29, 4, 0, 9223372036854775807);  slice_29 = None
            expand_3: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_30, [sym_size_int_22, 1, 2, add, 96]);  slice_30 = None
            reshape_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_22, 2, add, 96]);  expand_3 = None
            slice_31: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone);  clone = None
            slice_32: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_31, 1);  slice_31 = None
            slice_33: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_32, 2);  slice_32 = None
            slice_34: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_33, 3, None, add);  slice_33 = add = None
            contiguous: "f32[s41, 2, s2, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[s41, 2, s2, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_34, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_34 = None
            transpose_4: "f32[s41, s2, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[s41, s2, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:314 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[s41, s2, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_22, sym_size_int_23, -1]);  contiguous_3 = sym_size_int_22 = sym_size_int_23 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:358 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_7, linear_3);  to_7 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_9: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_9, 2)
            mean_1: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_9, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_10);  p_model_layers_0_post_attention_layernorm_weight = to_10 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s41, s2, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:201 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[s41, s2, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:364 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_9, linear_6);  to_9 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_11, 2)
            mean_2: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:82 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_11, rsqrt_2);  to_11 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_12);  p_model_norm_weight = to_12 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:866 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_35: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(mul_12);  mul_12 = None
            slice_36: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_35, 1, 0);  slice_35 = None
            slice_37: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_36, 2);  slice_36 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s41, s2, 32000]" = torch.ops.aten.linear.default(slice_37, p_lm_head_weight);  slice_37 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_22: "Sym(s41)", position_ids: "i64[s41, s2]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
                unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2);  slice_16 = None
                _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
                expand_1: "f32[s41, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_22, -1, 1]);  to_1 = sym_size_int_22 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:136 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_17: "i64[s41, s2]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_6: "i64[s41, 1, s2]" = torch.ops.aten.unsqueeze.default(slice_17, 1);  slice_17 = None
                slice_18: "i64[s41, 1, s2]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
                _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
                to_2: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(slice_18, torch.float32);  slice_18 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = emb.cos()
                cos: "f32[s41, s2, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = emb.sin()
                sin: "f32[s41, s2, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:149 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:150 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_5, to_6)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[s41, 48, 1]", to_2: "f32[s41, 1, s2]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float()
                    _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                    to_3: "f32[s41, 48, 1]" = torch.ops.aten.to.device(expand_1, device(type='cpu'), torch.float32);  expand_1 = None
                    _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                    to_4: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[s41, 48, s2]" = torch.ops.aten.matmul.default(to_3, to_4);  to_3 = to_4 = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:143 in forward, code: ).transpose(1, 2)
                    transpose: "f32[s41, s2, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:144 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s41, s2, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = emb.cos()
                    cos: "f32[s41, s2, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = emb.sin()
                    sin: "f32[s41, s2, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s41: VR[1, 1024], s2: VR[2, 4096], s2 + s67: VR[4, 8192], s67: VR[1, 4096]}
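
The graph signature shows that the exported program returns the logits (linear_7) together with the updated key and value caches (cat_3, cat_4). The range constraints list the symbolic dimensions: s41 appears to correspond to the batch dimension, s2 to the length of the new tokens and s67 to the length of the cached tokens. They can also be inspected programmatically. A minimal sketch, assuming ep is the ExportedProgram printed above:

pprint.pprint(ep.range_constraints)

When the bypass_export_some_errors context exits, the patches are rolled back, which produces the log lines below.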

[bypass_export_some_errors] remove patches
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored shape constraints
[bypass_export_some_errors] unpatch transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_AttentionMaskConverter: _make_causal_mask
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_DynamicCache: reorder_cache, update, crop, from_batch_splits, get_seq_length
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers - patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[_unregister_cache_serialization] unregistered MambaCache
[_unregister_cache_serialization] skip unregister DynamicCache
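
The serialization registered for MambaCache is removed while the one for DynamicCache is kept (last log line). This can be checked with the helper imported at the beginning of the example. A minimal sketch:

print("DynamicCache serialization registered:", is_cache_dynamic_registered())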
doc.plot_legend("Tiny-LLM patched", "torch.export.export", "green")
(figure: plot export tiny llm patched)

Total running time of the script: (0 minutes 7.626 seconds)

Related examples

Steel method forward to guess the dynamic shapes (with Tiny-LLM)

Export with DynamicCache and dynamic shapes

Untrained microsoft/phi-2

Gallery generated by Sphinx-Gallery