Note
Go to the end to download the full example code.
Export Tiny-LLM with patches¶
Many models from transformers cannot be exported because their implementation relies on cache classes. Let’s see how to work around that. We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we use a function that creates a randomly initialized model with the same architecture. This continues example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM).
Errors¶
They depend on the transformers version. transformers>=4.40,<4.50 cannot serialize DynamicCache and cannot map dynamic shapes to instances of DynamicCache. The following error would appear:
torch._dynamo.exc.UserError: Cannot associate shape
[[{0: <class '....batch'>, 2: <class '....cache_length'>}],
[{0: <class '....batch'>, 2: <class '....cache_length'>}]]
specified at `dynamic_shapes['past_key_values']`
to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
at `inputs['past_key_values']` (expected None)
For more information about this error,
see: https://docs.pytorch.org/docs/stable/generated/exportdb/index.html#dynamic-shapes-validation
With transformers==4.50, it shows the following:
torch._dynamo.exc.UserError: Constraints violated (batch)!
For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][1]['input_ids'].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['attention_mask'].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
Suggested fixes:
batch = 2
However, this package implements a patch mechanism which replaces the parts causing these issues.
Note
restart after an export failure
If the export fails, it is better to run the script again from the beginning, or to restart the kernel if you are in a notebook. A failed export may leave torch in an unstable state.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
cloned_inputs = copy.deepcopy(inputs)
Let’s show the inputs; they were inferred in example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM). In the notation below, T7s2x3 stands for an int64 tensor of shape 2x3 and T1s2x1x30x96 for a float32 tensor of shape 2x1x30x96.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
And the dynamic shapes:
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
Before exporting, we check that transformers.cache_utils.DynamicCache can be serialized and deserialized; otherwise torch.export.export() fails.
print("-- DynamicCache registered: ", is_cache_dynamic_registered())
-- DynamicCache registered: True
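A sketch of how such a check can be implemented: a class becomes export-friendly once it appears in torch’s pytree registry. SUPPORTED_NODES is an internal torch structure, so this is illustrative only, not the library’s actual implementation.

```python
# Illustrative check: a class is flattenable by torch.export once it is
# registered as a pytree node. ``SUPPORTED_NODES`` is torch-internal.
import torch.utils._pytree as pytree


def is_registered(cls) -> bool:
    return cls in pytree.SUPPORTED_NODES


print(is_registered(dict))  # built-in containers are always registered
```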
If it is not registered, the function
onnx_diagnostic.torch_export_patches.torch_export_patches()
should take care of it. Then we export.
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
assert is_cache_dynamic_registered() # it must be true here
ep = torch.export.export(
untrained_model,
(),
kwargs=modificator(cloned_inputs),
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)
print("It worked:")
print(ep)
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_fix_registration] BaseModelOutput is unregistered and registered first
[unregister_cache_serialization] unregistered BaseModelOutput
[register_class_serialization] ---------- register BaseModelOutput
[_fix_registration] BaseModelOutput done.
[register_class_serialization] ---------- register DynamicCache
[register_class_serialization] ---------- register HybridCache
[register_class_serialization] ---------- register MambaCache
[register_class_serialization] ---------- register EncoderDecoderCache
[register_class_serialization] ---------- register SlidingWindowCache
[register_class_serialization] ---------- register StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250828+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.56.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
#
sym_size_int_14: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_19: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_20: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
sym_size_int_21: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)
# No stacktrace found for following nodes
empty: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat: "f32[s23, 1, s31, 96]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2); empty = past_key_values_key_cache_0 = None
cat_1: "f32[s23, 1, s11, 96]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2); empty_1 = past_key_values_value_cache_0 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s31 + s70)" = sym_size_int_20 + sym_size_int_14
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False); sym_size_int_20 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_19, 1, sym_size_int_14, add]); reshape = None
expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_1 = expand_1 = None
expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_2 = None
expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); le = None
and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1); new_ones = to_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); index = None
and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2); and_1 = to_2 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: self.inv_freq[None, :, None]
unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .float()
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .expand(position_ids.shape[0], -1, 1)
expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]); to_3 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1043 in forward, code: .to(x.device)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1045 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1); position_ids = None
slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32); slice_1 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5); submod_3 = to_4 = to_5 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1058 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11); p_model_layers_0_input_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_14, -1, 96]); linear = None
transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_14, -1, 96]); linear_1 = None
transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_14, -1, 96]); linear_2 = None
transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1); to_9 = None
mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3); slice_3 = None
cat_3: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1); neg = slice_2 = None
mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_4); cat_3 = None
add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1); neg_1 = slice_4 = None
mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_4, unsqueeze_4); cat_4 = unsqueeze_4 = None
add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_5: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([cat, add_5], -2); cat = add_5 = None
cat_6: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2); cat_1 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_5, 2)
slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807); unsqueeze_5 = None
expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_19, 1, 2, add, 96]); slice_6 = None
reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_19, 2, add, 96]); expand_5 = None
unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_6, 2)
add_10: "Sym(s11 + s70)" = sym_size_int_21 + sym_size_int_14; sym_size_int_21 = None
slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807); unsqueeze_6 = None
expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_19, 1, 2, add_10, 96]); slice_7 = None
reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_19, 2, add_10, 96]); expand_6 = add_10 = None
slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add); and_2 = add = None
scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, reshape_4, reshape_5, slice_8, scale = 0.10206207261596575); add_4 = reshape_4 = reshape_5 = slice_8 = None
transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_14, -1]); transpose_4 = sym_size_int_19 = sym_size_int_14 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight); reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3); to_10 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32); add_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7); add_7 = None
mul_16: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32); mul_16 = None
mul_17: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13); p_model_layers_0_post_attention_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight); mul_17 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_18: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight); mul_18 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6); to_12 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32); add_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
mul_19: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2); to_14 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32); mul_19 = None
mul_20: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15); p_model_norm_weight = to_15 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:473 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 1, 0, 9223372036854775807); mul_20 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight); slice_9 = p_lm_head_weight = None
return (linear_7, cat_5, cat_6)
class submod_1(torch.nn.Module):
def forward(self, to_4: "f32[s23, 48, 1]", to_5: "f32[s23, 1, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32); to_5 = None
matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7); to_6 = to_7 = None
transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat_2: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat_2)
mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat_2); cat_2 = None
mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_5: USER_OUTPUT
cat_6: USER_OUTPUT
Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask
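The `[patch_module_or_classes]` / `[unpatch_module_or_classes]` log lines above show the core mechanism at work: a method is temporarily swapped on a class while the export runs, then restored on exit. The idea can be sketched with a small stdlib-only context manager (`patch_class`, `Rotary`, and `patched_forward` are hypothetical names for illustration, not the library's API):

```python
import contextlib


class Rotary:
    def forward(self):
        return "original"


def patched_forward(self):
    return "patched"


@contextlib.contextmanager
def patch_class(cls, name, replacement):
    # Save the original attribute, install the patch, and restore it on
    # exit -- the same save/replace/restore cycle the logs above report
    # for the transformers classes.
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        setattr(cls, name, original)


r = Rotary()
with patch_class(Rotary, "forward", patched_forward):
    inside = r.forward()   # patched while "exporting"
outside = r.forward()      # original behaviour restored afterwards
print(inside, outside)
```

The `try`/`finally` matters: the original method is restored even if the export raises, which is why the `remove patches` lines appear regardless of the outcome.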
With the original model¶
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# The forward pass updates the cache in place, so export a deep copy
# and keep the original inputs untouched.
cloned_inputs = copy.deepcopy(inputs)

with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
    ep = torch.export.export(
        model,
        (),
        kwargs=modificator(cloned_inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,  # mandatory for torch==2.6
    )
    print("It worked:")
    print(ep)
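The export call passes `dynamic_shapes` through `use_dyn_not_str`, which swaps string dimension names for dynamic dimension objects that `torch.export` accepts. The recursive substitution pattern can be sketched in plain Python (the sentinel tuples stand in for `torch.export.Dim` objects; `replace_strings` is a hypothetical helper, not the library's implementation):

```python
def replace_strings(spec, table):
    # Walk the nested dynamic-shapes spec and replace every string
    # dimension name with a shared sentinel object, so that two axes
    # named "batch" end up pointing at the same symbolic dimension.
    if isinstance(spec, str):
        return table.setdefault(spec, ("DYN", spec))
    if isinstance(spec, dict):
        return {k: replace_strings(v, table) for k, v in spec.items()}
    if isinstance(spec, (list, tuple)):
        return type(spec)(replace_strings(v, table) for v in spec)
    return spec


table = {}
shapes = {"input_ids": {0: "batch", 1: "seq"}, "past": [[{0: "batch"}]]}
out = replace_strings(shapes, table)
print(out)
```

Sharing one object per name is the point: it tells the exporter that the batch axis of `input_ids` and the batch axis of the cache must stay equal.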
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_fix_registration] DynamicCache is unregistered and registered first
[unregister_cache_serialization] unregistered DynamicCache
[register_class_serialization] ---------- register DynamicCache
[_fix_registration] DynamicCache done.
[register_class_serialization] already registered DynamicCache
[register_class_serialization] already registered HybridCache
[register_class_serialization] already registered MambaCache
[register_class_serialization] already registered EncoderDecoderCache
[register_class_serialization] already registered SlidingWindowCache
[register_class_serialization] already registered StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250828+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.56.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
#
sym_size_int_14: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_19: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_20: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
sym_size_int_21: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)
# No stacktrace found for following nodes
empty: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1: "f32[s23, 1, 0, 96]" = torch.ops.aten.empty.memory_format([sym_size_int_19, 1, 0, 96], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat: "f32[s23, 1, s31, 96]" = torch.ops.aten.cat.default([empty, past_key_values_key_cache_0], -2); empty = past_key_values_key_cache_0 = None
cat_1: "f32[s23, 1, s11, 96]" = torch.ops.aten.cat.default([empty_1, past_key_values_value_cache_0], -2); empty_1 = past_key_values_value_cache_0 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:376 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s31 + s70)" = sym_size_int_20 + sym_size_int_14
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:375 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False); sym_size_int_20 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_19, 1, sym_size_int_14, add]); reshape = None
expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_1 = expand_1 = None
expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_2 = None
expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_19, 1, sym_size_int_14, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); le = None
and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1); new_ones = to_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:99 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:382 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); index = None
and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2); and_1 = to_2 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1040 in forward, code: self.inv_freq[None, :, None]
unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1041 in forward, code: .float()
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1042 in forward, code: .expand(position_ids.shape[0], -1, 1)
expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]); to_3 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1043 in forward, code: .to(x.device)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1045 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1); position_ids = None
slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32); slice_1 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5); submod_3 = to_4 = to_5 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1058 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11); p_model_layers_0_input_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:236 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_14, -1, 96]); linear = None
transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_14, -1, 96]); linear_1 = None
transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:238 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_14, -1, 96]); linear_2 = None
transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:241 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1); to_9 = None
mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3); slice_3 = None
cat_3: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1); neg = slice_2 = None
mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_3, unsqueeze_4); cat_3 = None
add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1); neg_1 = slice_4 = None
mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_4, unsqueeze_4); cat_4 = unsqueeze_4 = None
add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:246 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_5: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([cat, add_5], -2); cat = add_5 = None
cat_6: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([cat_1, transpose_3], -2); cat_1 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:252 in forward, code: attn_output, attn_weights = attention_interface(
unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_5, 2)
slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807); unsqueeze_5 = None
expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_19, 1, 2, add, 96]); slice_6 = None
reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_19, 2, add, 96]); expand_5 = None
unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_6, 2)
add_10: "Sym(s11 + s70)" = sym_size_int_21 + sym_size_int_14; sym_size_int_21 = None
slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807); unsqueeze_6 = None
expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_19, 1, 2, add_10, 96]); slice_7 = None
reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_19, 2, add_10, 96]); expand_6 = add_10 = None
slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add); and_2 = add = None
scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_4, reshape_4, reshape_5, slice_8, scale = 0.10206207261596575); add_4 = reshape_4 = reshape_5 = slice_8 = None
transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:263 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_14, -1]); transpose_4 = sym_size_int_19 = sym_size_int_14 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight); reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:304 in forward, code: hidden_states = residual + hidden_states
add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3); to_10 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32); add_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7); add_7 = None
mul_16: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32); mul_16 = None
mul_17: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13); p_model_layers_0_post_attention_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:473 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight); mul_17 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:155 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_18: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight); mul_18 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: hidden_states = residual + hidden_states
add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6); to_12 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32); add_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
mul_19: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2); to_14 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32); mul_19 = None
mul_20: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15); p_model_norm_weight = to_15 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:473 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 1, 0, 9223372036854775807); mul_20 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight); slice_9 = p_lm_head_weight = None
return (linear_7, cat_5, cat_6)
class submod_1(torch.nn.Module):
def forward(self, to_4: "f32[s23, 48, 1]", to_5: "f32[s23, 1, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1053 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32); to_5 = None
matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7); to_6 = to_7 = None
transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1054 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat_2: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1055 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat_2)
mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1056 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat_2); cat_2 = None
mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_5: USER_OUTPUT
cat_6: USER_OUTPUT
Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}
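The `submod_1` graph printed above is the patched rotary embedding: the inverse frequencies are multiplied against the position ids, the result is duplicated along the last dimension (`emb = torch.cat((freqs, freqs), dim=-1)`), and cosine/sine are scaled by `attention_scaling`. A minimal pure-Python sketch of that same computation for a single sequence (illustration only, not the library's implementation):

```python
import math

def rotary_cos_sin(inv_freq, position_ids, attention_scaling=1.0):
    """Rotary embedding tables as traced in submod_1 above:
    freqs[i][j] = inv_freq[i] * position_ids[j]; each per-position
    frequency row is duplicated, then cos/sin are scaled."""
    freqs = [[f * p for p in position_ids] for f in inv_freq]
    # one row per position; the frequency row is concatenated with itself
    cos = [[math.cos(freqs[i][j]) for i in range(len(inv_freq))] * 2
           for j in range(len(position_ids))]
    sin = [[math.sin(freqs[i][j]) for i in range(len(inv_freq))] * 2
           for j in range(len(position_ids))]
    cos = [[c * attention_scaling for c in row] for row in cos]
    sin = [[s * attention_scaling for s in row] for row in sin]
    return cos, sin

cos, sin = rotary_cos_sin([1.0, 0.5], [0.0, 1.0, 2.0])
```

With `attention_scaling = 1.0` the two final `mul.Tensor(..., 1.0)` nodes in the graph are no-ops, which matches the traced output.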
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicLayer: lazy_initialization
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Qwen3MoeSparseMoeBlock: forward, _forward_expert_loop
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_VisionAttention: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask
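The log lines above show `torch_export_patches` swapping patched classes in on entry and restoring every original method on exit. A minimal stdlib sketch of that patch-and-restore pattern (hypothetical names; the real context manager patches many transformers classes at once):

```python
from contextlib import contextmanager

@contextmanager
def patch_method(cls, name, replacement):
    """Temporarily replace cls.name, restoring the original on exit,
    mirroring the [torch_export_patches] / [unpatch_module_or_classes]
    enter/exit pairs in the log above (illustration only)."""
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield
    finally:
        # always restore, even if the export inside the block raises
        setattr(cls, name, original)

class Rotary:
    def forward(self):
        return "original"

with patch_method(Rotary, "forward", lambda self: "patched"):
    inside = Rotary().forward()   # patched behavior during export
outside = Rotary().forward()      # original behavior restored
```

Because the restore happens in a `finally` block, the patched classes are undone even when `torch.export.export` fails inside the context, which is why the restore messages appear unconditionally at the end of the run.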
doc.plot_legend("Tiny-LLM patched", "torch.export.export", "green")
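The exported graph printed above decomposes each `LlamaRMSNorm` into primitive ops: `pow(2)`, `mean(-1, keepdim=True)`, `rsqrt(variance + eps)`, and a final multiply by the layernorm weight. An equivalent pure-Python sketch of that normalization for one hidden vector (a simplified illustration, not the transformers implementation):

```python
import math

def rms_norm(hidden, weight, eps=1e-5):
    """RMS normalization as decomposed in the traced graph:
    variance = mean(x**2); x = x * rsqrt(variance + eps); out = weight * x."""
    variance = sum(h * h for h in hidden) / len(hidden)
    scale = 1.0 / math.sqrt(variance + eps)  # rsqrt(variance + eps)
    return [w * h * scale for w, h in zip(weight, hidden)]

normed = rms_norm([3.0, 4.0], [1.0, 1.0])
```

The `_assert_tensor_metadata` and `to.dtype` nodes surrounding these ops in the graph only enforce the float32 round-trip from `hidden_states.to(torch.float32)` and carry no arithmetic.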

Total running time of the script: (0 minutes 0.819 seconds)
Related examples

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Find and fix an export issue due to dynamic shapes