Export Tiny-LLM with patches

Many models from transformers cannot be converted because their implementation uses cache classes. Let’s see how to get around that. We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we use a function (get_tiny_llm) that creates a random model with the same architecture. This continues the example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM).

Errors

They depend on the version of transformers.

transformers>=4.40,<4.50 can neither serialize DynamicCache nor map dynamic shapes to instances of DynamicCache. The following error appears:

torch._dynamo.exc.UserError: Cannot associate shape
    [[{0: <class '....batch'>, 2: <class '....cache_length'>}],
     [{0: <class '....batch'>, 2: <class '....cache_length'>}]]
    specified at `dynamic_shapes['past_key_values']`
    to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
    at `inputs['past_key_values']` (expected None)
For more information about this error,
see: https://docs.pytorch.org/docs/stable/generated/exportdb/index.html#dynamic-shapes-validation
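To see where this error comes from, remember that torch.export.export() flattens the inputs into tensors and walks dynamic_shapes with the same structure. The pure-Python sketch below is a toy model of that matching, not the torch implementation; ToyTensor, ToyCache, FLATTEN_RULES and match_shapes are hypothetical names. As long as no flatten rule is known for the cache class, the cache is an opaque object and a nested list of shapes cannot be associated with it. Once a rule is registered, every tensor inside the cache receives its own shape specification.

```python
class ToyTensor:
    """Stand-in for a tensor: it only carries a shape."""

    def __init__(self, *shape):
        self.shape = shape


class ToyCache:
    """Stand-in for DynamicCache: one list of keys, one list of values."""

    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache


# Hypothetical registry: class -> function returning the children of an instance.
FLATTEN_RULES = {}


def match_shapes(value, spec, path="inputs"):
    """Pair every tensor found in `value` with its entry in `spec`."""
    if isinstance(value, ToyTensor):
        return [(path, value.shape, spec)]
    if type(value) in FLATTEN_RULES:
        # A registered class is replaced by the list of its children.
        value = FLATTEN_RULES[type(value)](value)
    if isinstance(value, dict) and isinstance(spec, dict):
        pairs = []
        for key, item in value.items():
            pairs += match_shapes(item, spec.get(key), f"{path}[{key!r}]")
        return pairs
    if isinstance(value, (list, tuple)) and isinstance(spec, (list, tuple)):
        if len(value) != len(spec):
            raise ValueError(f"length mismatch at {path}")
        pairs = []
        for i, (item, sub_spec) in enumerate(zip(value, spec)):
            pairs += match_shapes(item, sub_spec, f"{path}[{i}]")
        return pairs
    # Anything else is an opaque object: a shape cannot be attached to it.
    raise TypeError(
        f"Cannot associate shape {spec!r} to non-tensor type "
        f"{type(value).__name__} at {path}"
    )


toy_inputs = {
    "input_ids": ToyTensor(2, 3),
    "past_key_values": ToyCache([ToyTensor(2, 1, 30, 96)], [ToyTensor(2, 1, 30, 96)]),
}
toy_shapes = {
    "input_ids": {0: "batch", 1: "seq_length"},
    "past_key_values": [[{0: "batch", 2: "cache_length"}],
                        [{0: "batch", 2: "cache_length"}]],
}

try:
    match_shapes(toy_inputs, toy_shapes)
except TypeError as e:
    print(e)  # the unregistered cache is treated as an opaque object

# Registering a flatten rule turns the cache into [key_cache, value_cache].
FLATTEN_RULES[ToyCache] = lambda c: [c.key_cache, c.value_cache]
for path, shape, spec in match_shapes(toy_inputs, toy_shapes):
    print(path, shape, spec)
```

This is also why the dynamic shapes for past_key_values are written as two nested lists: they mirror the flattened cache, one list for the keys and one for the values.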

With transformers==4.50, it shows the following:

torch._dynamo.exc.UserError: Constraints violated (batch)!
For more information, run with TORCH_LOGS="+dynamic".
    - Not all values of batch = L['args'][1]['input_ids'].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['attention_mask'].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
    - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0]
        in the specified range batch <= 1024 are valid
        because batch was inferred to be a constant (2).
 Suggested fixes:
     batch = 2
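Such a violation means that some code executed during tracing made the batch size concrete (for example a comparison or a branch on its value), so the exporter recorded a guard equivalent to batch == 2, which contradicts the declared range batch <= 1024. The brute-force check below, built around a hypothetical check_range helper, mimics that reasoning on a toy scale; the real exporter reasons symbolically and does not enumerate values.

```python
def check_range(name, lo, hi, guards):
    """Return one message per guard that fails somewhere on [lo, hi].

    `guards` maps a human-readable reason to a predicate on the dimension
    value, mimicking the conditions recorded while tracing.
    """
    messages = []
    for reason, predicate in guards.items():
        if not all(predicate(value) for value in range(lo, hi + 1)):
            messages.append(
                f"Not all values of {name} in the specified range "
                f"{lo} <= {name} <= {hi} are valid because {reason}."
            )
    return messages


# A guard recorded after the traced code depended on the concrete batch size.
specialized = {"batch was inferred to be a constant (2)": lambda b: b == 2}
# A guard that holds for every batch size in the range.
generic = {"batch must be positive": lambda b: b >= 1}

print(check_range("batch", 1, 1024, specialized))  # a list with one message
print(check_range("batch", 1, 1024, generic))      # []
```

Applying the suggested fix (batch = 2) would silence the error but make the batch dimension static; the patch mechanism introduced next avoids the specialization instead.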

However, this package implements a patch mechanism which replaces the parts of the code causing these issues.
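The principle of such a patch can be sketched with a context manager that swaps a method for an export-friendly version on entry and always restores the original on exit. This is a simplified illustration under that assumption, not the code of torch_export_patches; RotaryEmbedding, patched_forward and patch_method are hypothetical names (the verbose log further down lists the real patches, for example the forward method of LlamaRotaryEmbedding).

```python
import contextlib


class RotaryEmbedding:
    """Hypothetical class whose forward method does not export well."""

    def forward(self, x):
        return f"original forward({x})"


def patched_forward(self, x):
    # An export-friendly rewrite of forward would go here.
    return f"patched forward({x})"


@contextlib.contextmanager
def patch_method(cls, name, replacement):
    """Temporarily replace cls.<name> by `replacement`."""
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        # Restore the original even if the body raises, e.g. a failed export.
        setattr(cls, name, original)


module = RotaryEmbedding()
with patch_method(RotaryEmbedding, "forward", patched_forward):
    print(module.forward(1))  # patched forward(1)
print(module.forward(1))      # original forward(1)
```

Patching the class rather than the instance means every module of that type inside the model picks up the rewrite during export.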

Note

restart after an export failure

If the export fails, it is better to run the script again from the beginning, or to restart the kernel if you are in a notebook. A failed export may leave torch in an unstable state.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm


experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)

cloned_inputs = copy.deepcopy(inputs)
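A likely reason for deep-copying the inputs is that the cache is a mutable object: during the forward pass, the model appends the new keys and values to it (see the past_key_values.update call in the exported graph below). A shallow copy would still share the cache with the original inputs. The toy example below illustrates the difference with a hypothetical ToyCache class.

```python
import copy


class ToyCache:
    """Hypothetical stand-in for a cache the model updates in place."""

    def __init__(self):
        self.key_cache = [[0.1, 0.2]]

    def update(self, new_keys):
        # Like a KV cache: concatenate the new entries to the stored ones.
        self.key_cache[0] = self.key_cache[0] + new_keys


toy_inputs = {"input_ids": [1, 2, 3], "past_key_values": ToyCache()}
shallow = copy.copy(toy_inputs)   # new dict, same ToyCache object
deep = copy.deepcopy(toy_inputs)  # new dict, independent ToyCache object

shallow["past_key_values"].update([0.3])
print(len(toy_inputs["past_key_values"].key_cache[0]))  # 3: the original changed

deep["past_key_values"].update([0.3])
print(len(toy_inputs["past_key_values"].key_cache[0]))  # still 3
```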

Let’s show these inputs; they were inferred in the example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM).

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
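In this notation, T7s2x3 reads as a tensor of ONNX element type 7 (int64) with shape 2x3, T1 stands for float32, and #1[...] is a list with one element (one layer). The shapes are consistent: the attention mask covers the 30 cached positions plus the 3 new tokens, hence 33. The decoder below is a hypothetical helper, not part of onnx_diagnostic; its type table is a subset of the onnx TensorProto enumeration.

```python
import re

# Subset of the element types defined by onnx.TensorProto.
ONNX_DTYPES = {1: "float32", 6: "int32", 7: "int64", 9: "bool",
               10: "float16", 11: "float64", 16: "bfloat16"}

# 'T' + element type + 's' + dimensions separated by 'x'
TOKEN = re.compile(r"T(\d+)s(\d+(?:x\d+)*)")


def decode_tensor_token(token):
    """Decode a token such as 'T7s2x3' into (dtype name, shape)."""
    match = TOKEN.fullmatch(token)
    if match is None:
        raise ValueError(f"not a tensor token: {token!r}")
    code = int(match.group(1))
    dtype = ONNX_DTYPES.get(code, f"onnx type {code}")
    shape = tuple(int(d) for d in match.group(2).split("x"))
    return dtype, shape


summary = (
    "dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,"
    "past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], "
    "value_cache=#1[T1s2x1x30x96]))"
)
for found in TOKEN.finditer(summary):
    print(found.group(0), "->", decode_tensor_token(found.group(0)))
```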

And the dynamic shapes:

pprint.pprint(dynamic_shapes)

{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
 'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
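Each entry maps an axis index either to a torch.export.Dim object or to a string naming the dimension. Before calling torch.export.export(), the example wraps these shapes with use_dyn_not_str. Judging from its name, it replaces the strings with a dynamic marker such as torch.export.Dim.DYNAMIC; the recursive walk below sketches that assumption, with a hypothetical DYNAMIC sentinel and a hypothetical ToyDim class standing in for the torch objects.

```python
DYNAMIC = object()  # hypothetical stand-in for torch.export.Dim.DYNAMIC


class ToyDim:
    """Hypothetical stand-in for torch.export.Dim('batch', min=1, max=1024)."""

    def __init__(self, name, min, max):
        self.name, self.min, self.max = name, min, max


def replace_str_dims(shapes, marker=DYNAMIC):
    """Return a copy of `shapes` where every string dimension becomes `marker`."""
    if isinstance(shapes, str):
        return marker
    if isinstance(shapes, dict):
        # Keys (input names or axis indices) are kept, values are converted.
        return {k: replace_str_dims(v, marker) for k, v in shapes.items()}
    if isinstance(shapes, (list, tuple)):
        return type(shapes)(replace_str_dims(v, marker) for v in shapes)
    return shapes  # Dim objects, ints and None are kept as they are


batch = ToyDim("batch", min=1, max=1024)
toy_shapes = {
    "input_ids": {0: batch, 1: "seq_length"},
    "past_key_values": [[{0: batch, 2: "cache_length"}],
                        [{0: batch, 2: "cache_length"}]],
}
converted = replace_str_dims(toy_shapes)
print(converted["input_ids"][0] is batch)                # True
print(converted["input_ids"][1] is DYNAMIC)              # True
print(converted["past_key_values"][1][0][2] is DYNAMIC)  # True
```

Replacing names with an anonymous marker drops the link between dimensions sharing the same name, which may explain why the exported graph below uses two distinct symbols (s31 and s11) for the key and value cache lengths.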

Before exporting, we check that transformers.cache_utils.DynamicCache can be serialized and deserialized; otherwise torch.export.export() fails.

print("-- DynamicCache registered: ", is_cache_dynamic_registered())
-- DynamicCache registered:  True
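Being registered means that the exporter knows how to decompose a DynamicCache into a list of tensors plus some context (serialization) and how to rebuild it from them (deserialization); torch.export relies on the pytree rules of torch.utils._pytree for this. The toy registry below (hypothetical register, is_registered and ToyCache, not the onnx_diagnostic implementation) shows the kind of round-trip such a check can verify.

```python
REGISTRY = {}


def register(cls, flatten, unflatten):
    """Record how to turn an instance into (children, context) and back."""
    REGISTRY[cls] = (flatten, unflatten)


def is_registered(obj):
    """True if type(obj) is registered and survives a flatten/unflatten round-trip."""
    if type(obj) not in REGISTRY:
        return False
    flatten, unflatten = REGISTRY[type(obj)]
    children, context = flatten(obj)
    rebuilt = unflatten(children, context)
    return type(rebuilt) is type(obj) and flatten(rebuilt) == (children, context)


class ToyCache:
    """Hypothetical stand-in for DynamicCache."""

    def __init__(self, key_cache, value_cache):
        self.key_cache = key_cache
        self.value_cache = value_cache


cache = ToyCache([[1.0, 2.0]], [[3.0, 4.0]])
print(is_registered(cache))  # False

register(
    ToyCache,
    flatten=lambda c: ([c.key_cache, c.value_cache], ["key_cache", "value_cache"]),
    unflatten=lambda children, context: ToyCache(*children),
)
print(is_registered(cache))  # True
```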

If it is not registered, the function onnx_diagnostic.torch_export_patches.torch_export_patches() should take care of it. Then we export.

with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
    assert is_cache_dynamic_registered()  # it must be true here
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=modificator(cloned_inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,  # mandatory for torch==2.6
    )
    print("It worked:")
    print(ep)
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_fix_registration] BaseModelOutput is unregistered and registered first
[unregister_cache_serialization] unregistered BaseModelOutput
[register_class_serialization] ---------- register BaseModelOutput
[_fix_registration] BaseModelOutput done.
[register_class_serialization] ---------- register DynamicCache
[register_class_serialization] ---------- register HybridCache
[register_class_serialization] ---------- register MambaCache
[register_class_serialization] ---------- register EncoderDecoderCache
[register_class_serialization] ---------- register SlidingWindowCache
[register_class_serialization] ---------- register StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250828+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.56.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
             #
            sym_size_int_14: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_19: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
            sym_size_int_20: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
            sym_size_int_21: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)

            # No stacktrace found for following nodes
            empty: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s23, 1, s31, 96]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2);  empty = past_key_values_key_cache_0 = None
            cat_1: "f32[s23, 1, s11, 96]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2);  empty_1 = past_key_values_value_cache_0 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s31 + s70)" = sym_size_int_20 + sym_size_int_14

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_20 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]);  arange_2 = None
            reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]);  arange_3 = None
            reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]);  arange = None
            reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]);  add_ = None
            expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape = None
            expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_1 = expand_1 = None
            expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_2 = None
            expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_3 = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2);  expand_2 = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
            index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]);  to = expand = expand_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  index = None
            and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: self.inv_freq[None, :, None]
            unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
            unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2);  unsqueeze = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .float()
            _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
            to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .expand(position_ids.shape[0], -1, 1)
            expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]);  to_3 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1043 in forward, code: .to(x.device)
            _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
            to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_4 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1045 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
            unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1);  position_ids = None
            slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
            _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
            to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32);  slice_1 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
            mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
            mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1058 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_8: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_9: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear = None
            transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear_1 = None
            transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear_2 = None
            transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
            mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3);  slice_3 = None
            cat_3: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1);  neg = slice_2 = None
            mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_4);  cat_3 = None
            add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5);  slice_5 = None
            cat_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1);  neg_1 = slice_4 = None
            mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_4, unsqueeze_4);  cat_4 = unsqueeze_4 = None
            add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_5: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([cat, add_5], -2);  cat = add_5 = None
            cat_6: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2);  cat_1 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
            unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_5, 2)
            slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807);  unsqueeze_5 = None
            expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_19, 1, 2, add, 96]);  slice_6 = None
            reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_19, 2, add, 96]);  expand_5 = None
            unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_6, 2)
            add_10: "Sym(s11 + s70)" = sym_size_int_21 + sym_size_int_14;  sym_size_int_21 = None
            slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807);  unsqueeze_6 = None
            expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_19, 1, 2, add_10, 96]);  slice_7 = None
            reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_19, 2, add_10, 96]);  expand_6 = add_10 = None
            slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add);  and_2 = add = None
            scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, reshape_4, reshape_5, slice_8, scale = 0.10206207261596575);  add_4 = reshape_4 = reshape_5 = slice_8 = None
            transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_14, -1]);  transpose_4 = sym_size_int_19 = sym_size_int_14 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight);  reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
            add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
            mul_16: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32);  mul_16 = None
            mul_17: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13);  p_model_layers_0_post_attention_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight);  mul_17 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_18: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight);  mul_18 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
            add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6);  to_12 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32);  add_8 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9);  add_9 = None
            mul_19: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2);  to_14 = rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32);  mul_19 = None
            mul_20: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15);  p_model_norm_weight = to_15 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:473 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 1, 0, 9223372036854775807);  mul_20 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight);  slice_9 = p_lm_head_weight = None
            return (linear_7, cat_5, cat_6)

        class submod_1(torch.nn.Module):
            def forward(self, to_4: "f32[s23, 48, 1]", to_5: "f32[s23, 1, s70]"):
                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                cat_2: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
                cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat_2)
                mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
                sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat_2);  cat_2 = None
                mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_5: USER_OUTPUT
    cat_6: USER_OUTPUT

Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}

[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask

With the original model

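The parameter shapes in the exported signature printed below are enough to recover the attention layout of this checkpoint. The short computation here uses only numbers visible in that output: `q_proj` is `f32[192, 192]`, `k_proj` and `v_proj` are `f32[96, 192]`, the projected states are viewed with a last dimension of 96, the rotary buffer `inv_freq` is `f32[48]`, and `scaled_dot_product_attention` receives `scale = 0.10206207261596575`. The variable names are illustrative; they are not read from the model configuration.

```python
import math

# Numbers read from the exported graph of arnir0/Tiny-LLM (not from its config).
hidden_size = 192   # q_proj weight is f32[192, 192]
kv_proj_size = 96   # k_proj and v_proj weights are f32[96, 192]
head_dim = 96       # projected states are viewed as [..., -1, 96]

num_attention_heads = hidden_size // head_dim        # query heads
num_key_value_heads = kv_proj_size // head_dim       # shared key/value heads
n_rep = num_attention_heads // num_key_value_heads   # repetitions of each key/value head
rotary_freqs = head_dim // 2                         # size of the inv_freq buffer
scale = head_dim ** -0.5                             # 1 / sqrt(head_dim)

print(num_attention_heads, num_key_value_heads, n_rep, rotary_freqs, scale)
assert math.isclose(scale, 1 / math.sqrt(head_dim), rel_tol=1e-12)
```

The `expand` to `[s23, 1, 2, s31 + s70, 96]` followed by a `reshape` in the graph corresponds to this repetition of the single key/value head across the two query heads (grouped-query attention).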
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

cloned_inputs = copy.deepcopy(inputs)

with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
    ep = torch.export.export(
        model,
        (),
        kwargs=modificator(cloned_inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,  # mandatory for torch==2.6
    )
    print("It worked:")
    print(ep)
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_fix_registration] DynamicCache is unregistered and registered first
[unregister_cache_serialization] unregistered DynamicCache
[register_class_serialization] ---------- register DynamicCache
[_fix_registration] DynamicCache done.
[register_class_serialization] already registered DynamicCache
[register_class_serialization] already registered HybridCache
[register_class_serialization] already registered MambaCache
[register_class_serialization] already registered EncoderDecoderCache
[register_class_serialization] already registered SlidingWindowCache
[register_class_serialization] already registered StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250828+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.56.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
             #
            sym_size_int_14: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_19: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
            sym_size_int_20: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
            sym_size_int_21: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)

            # No stacktrace found for following nodes
            empty: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s23, 1, s31, 96]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2);  empty = past_key_values_key_cache_0 = None
            cat_1: "f32[s23, 1, s11, 96]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2);  empty_1 = past_key_values_value_cache_0 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s31 + s70)" = sym_size_int_20 + sym_size_int_14

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_20 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]);  arange_2 = None
            reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]);  arange_3 = None
            reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]);  arange = None
            reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]);  add_ = None
            expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape = None
            expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_1 = expand_1 = None
            expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_2 = None
            expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_19, 1, sym_size_int_14, add]);  reshape_3 = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2);  expand_2 = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
            index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]);  to = expand = expand_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  index = None
            and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: self.inv_freq[None, :, None]
            unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
            unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2);  unsqueeze = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .float()
            _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
            to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .expand(position_ids.shape[0], -1, 1)
            expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]);  to_3 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1043 in forward, code: .to(x.device)
            _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
            to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_4 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1045 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
            unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1);  position_ids = None
            slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
            _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
            to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32);  slice_1 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
            mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
            mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

             # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1058 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
            to_8: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
            _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
            to_9: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear = None
            transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear_1 = None
            transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_14, -1, 96]);  linear_2 = None
            transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
            mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3);  slice_3 = None
            cat_3: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1);  neg = slice_2 = None
            mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_4);  cat_3 = None
            add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5);  slice_5 = None
            cat_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1);  neg_1 = slice_4 = None
            mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_4, unsqueeze_4);  cat_4 = unsqueeze_4 = None
            add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_5: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([cat, add_5], -2);  cat = add_5 = None
            cat_6: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2);  cat_1 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
            unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_5, 2)
            slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807);  unsqueeze_5 = None
            expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_19, 1, 2, add, 96]);  slice_6 = None
            reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_19, 2, add, 96]);  expand_5 = None
            unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_6, 2)
            add_10: "Sym(s11 + s70)" = sym_size_int_21 + sym_size_int_14;  sym_size_int_21 = None
            slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807);  unsqueeze_6 = None
            expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_19, 1, 2, add_10, 96]);  slice_7 = None
            reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_19, 2, add_10, 96]);  expand_6 = add_10 = None
            slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add);  and_2 = add = None
            scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, reshape_4, reshape_5, slice_8, scale = 0.10206207261596575);  add_4 = reshape_4 = reshape_5 = slice_8 = None
            transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_14, -1]);  transpose_4 = sym_size_int_19 = sym_size_int_14 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight);  reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
            add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32);  add_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7);  add_7 = None
            mul_16: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32);  mul_16 = None
            mul_17: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13);  p_model_layers_0_post_attention_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight);  mul_17 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_18: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight);  mul_18 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
            add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6);  to_12 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32);  add_8 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9);  add_9 = None
            mul_19: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2);  to_14 = rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32);  mul_19 = None
            mul_20: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15);  p_model_norm_weight = to_15 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:473 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 1, 0, 9223372036854775807);  mul_20 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight);  slice_9 = p_lm_head_weight = None
            return (linear_7, cat_5, cat_6)

        class submod_1(torch.nn.Module):
            def forward(self, to_4: "f32[s23, 48, 1]", to_5: "f32[s23, 1, s70]"):
                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                to_6: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                cat_2: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
                cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat_2)
                mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
                sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat_2);  cat_2 = None
                mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                return (mul, mul_1)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_5: USER_OUTPUT
    cat_6: USER_OUTPUT

Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}

[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask
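The printed graph inlines the patched rotary embedding (`submod_1`, traced from `patch_transformers.py`) and its application to queries and keys (`slice_2` through `add_5`). Below is a minimal eager-mode sketch of the same computation. It assumes the sizes read off the printed shapes: head_dim 96 (so 48 frequencies), 2 query heads and 1 key/value head. It also assumes the common base-10000 schedule for `inv_freq`. The function names are hypothetical, and this is not the library's own code.

```python
import torch

# Sizes read off the printed graph (assumption, not taken from the model config):
# head_dim 96 (so 48 rotary frequencies), 2 query heads, 1 key/value head.
HEAD_DIM, N_HEADS, N_KV_HEADS = 96, 2, 1


def rotary_cos_sin(inv_freq, position_ids, attention_scaling=1.0):
    """Hypothetical helper mirroring submod_1: matmul, transpose, cat, cos, sin."""
    batch = position_ids.shape[0]
    # [48] -> [batch, 48, 1] and [batch, seq] -> [batch, 1, seq]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(batch, -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # [batch, 48, seq] -> [batch, seq, 48]
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # [batch, seq, 96]
    return emb.cos() * attention_scaling, emb.sin() * attention_scaling


def rotate_half(x):
    """slice, neg, cat in the graph: [x1, x2] -> [-x2, x1] on the last axis."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin):
    """mul/add pattern applied to the queries (transpose_1) and keys (transpose_2)."""
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head axis
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


torch.manual_seed(0)
batch, seq = 2, 3
# Assumed frequency schedule with base 10000 (common for Llama-style models).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, HEAD_DIM, 2).float() / HEAD_DIM))
position_ids = torch.arange(seq).expand(batch, -1)
cos, sin = rotary_cos_sin(inv_freq, position_ids)
q = torch.randn(batch, N_HEADS, seq, HEAD_DIM)
k = torch.randn(batch, N_KV_HEADS, seq, HEAD_DIM)
q_rot, k_rot = apply_rotary(q, k, cos, sin)
print(cos.shape, q_rot.shape, k_rot.shape)
```

`emb` repeats the 48 frequencies twice. As a result, coordinates `i` and `i + 48` are rotated by the same angle, so the rotation preserves vector norms and is the identity at position 0.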
doc.plot_legend("Tiny-LLM patched", "torch.export.export", "green")
[Figure: plot export tiny llm patched]

Total running time of the script: (0 minutes 0.819 seconds)
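Besides attention, the printed graph contains an RMS normalization (`pow`, `mean`, `add` with 1e-05, `rsqrt`, `mul`, then a multiplication by the learned weight). It is followed by a gated MLP: `silu` of the gate projection, times the up projection, then the down projection, plus the residual. The same normalization is applied again with `model.norm.weight` before `lm_head`. The sketch below shows both pieces, assuming the hidden size 192 and MLP width 1024 printed in the graph. `rms_norm` and `gated_mlp` are hypothetical helpers, and the weights are random.

```python
import torch
import torch.nn.functional as F

# Sizes read off the printed graph (assumption): hidden 192, MLP width 1024, eps 1e-5.
HIDDEN, INTERMEDIATE, EPS = 192, 1024, 1e-5


def rms_norm(x, weight, eps=EPS):
    """pow(2) -> mean(-1) -> + eps -> rsqrt -> mul, then scale by the learned weight."""
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps)).to(input_dtype)


def gated_mlp(x, gate_w, up_w, down_w):
    """down_proj(silu(gate_proj(x)) * up_proj(x)); weights use F.linear's [out, in] layout."""
    return F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)


torch.manual_seed(0)
batch, seq = 2, 3
hidden = torch.randn(batch, seq, HIDDEN)
normed = rms_norm(hidden, torch.ones(HIDDEN))
# Residual connection, as in "hidden_states = residual + hidden_states".
out = hidden + gated_mlp(
    normed,
    torch.randn(INTERMEDIATE, HIDDEN) * 0.02,
    torch.randn(INTERMEDIATE, HIDDEN) * 0.02,
    torch.randn(HIDDEN, INTERMEDIATE) * 0.02,
)
print(normed.shape, out.shape)
```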

Related examples

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Export microsoft/phi-2

Find and fix an export issue due to dynamic shapes

Gallery generated by Sphinx-Gallery