Note
Go to the end to download the full example code.
Export Tiny-LLM with patches
Many models from transformers cannot be exported because their implementation relies on cache classes. Let’s see how to get around that. We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we use a function that creates a random model based on the same architecture. This continues example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM).
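The function get_tiny_llm used below returns such a model. Conceptually, building a random clone of an architecture only requires its configuration; here is a minimal sketch assuming a tiny Llama-like configuration (the hyperparameters below are illustrative assumptions, not the exact values used by get_tiny_llm):

import transformers

# A sketch: build a tiny, randomly initialized Llama-like model from a
# configuration only, so that no weights need to be downloaded.
# The hyperparameter values are illustrative assumptions.
config = transformers.LlamaConfig(
    vocab_size=32000,
    hidden_size=192,
    intermediate_size=1024,
    num_hidden_layers=1,
    num_attention_heads=2,
    num_key_value_heads=1,
)
tiny = transformers.LlamaForCausalLM(config)  # random weights, same architecture
tiny.eval()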
Errors
They depend on the transformers version.
With transformers>=4.40,<4.50, it is not possible to serialize DynamicCache nor to map dynamic shapes to instances of DynamicCache. The following error would appear:
torch._dynamo.exc.UserError: Cannot associate shape
[[{0: <class '....batch'>, 2: <class '....cache_length'>}],
[{0: <class '....batch'>, 2: <class '....cache_length'>}]]
specified at `dynamic_shapes['past_key_values']`
to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
at `inputs['past_key_values']` (expected None)
For more information about this error,
see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
With transformers==4.50, it shows the following:
torch._dynamo.exc.UserError: Constraints violated (batch)!
For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][1]['input_ids'].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['attention_mask'].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0]
in the specified range batch <= 1024 are valid
because batch was inferred to be a constant (2).
Suggested fixes:
batch = 2
However, this package implements a patch mechanism which replaces the parts causing these issues.
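Conceptually, a patch is a context manager that temporarily swaps a function or method and restores the original implementation on exit. A minimal sketch of that idea (not the actual implementation of torch_export_patches):

import contextlib

@contextlib.contextmanager
def swap_method(cls, name, new_impl):
    # Temporarily replace cls.<name> by new_impl and restore the
    # original implementation when leaving the context.
    original = getattr(cls, name)
    setattr(cls, name, new_impl)
    try:
        yield
    finally:
        setattr(cls, name, original)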
Note
Restart after an export failure
If the export fails, it is better to run the whole script again, or to restart the kernel if you are in a notebook. A failed export may leave torch in an unstable state.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
cloned_inputs = copy.deepcopy(inputs)
Let’s show these inputs; they were inferred in example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM).
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
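For reference, inputs with the same structure could be built by hand roughly as follows. This is a sketch: the shapes come from the output above, and it assumes DynamicCache.update accepts key states, value states and a layer index in this version of transformers.

from transformers.cache_utils import DynamicCache

# batch=2, 3 new tokens, 30 cached tokens, 1 layer, 1 key/value head, head size 96
batch, seq, past = 2, 3, 30
manual_inputs = dict(
    input_ids=torch.randint(0, 32000, (batch, seq), dtype=torch.int64),
    attention_mask=torch.ones((batch, past + seq), dtype=torch.int64),
    position_ids=torch.arange(past, past + seq, dtype=torch.int64).expand(batch, -1),
    past_key_values=DynamicCache(),
)
manual_inputs["past_key_values"].update(
    torch.randn(batch, 1, past, 96), torch.randn(batch, 1, past, 96), 0
)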
And the dynamic shapes
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
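The dynamic shapes mix Dim objects and plain strings; the helper use_dyn_not_str used below converts this specification into something torch.export.export() accepts. A minimal sketch of what such a conversion may look like, assuming it simply replaces strings by torch.export.Dim.DYNAMIC (an assumption, not the helper’s actual code):

def strings_to_dynamic(ds):
    # Recursively replace string dimension names by torch.export.Dim.DYNAMIC.
    if isinstance(ds, str):
        return torch.export.Dim.DYNAMIC
    if isinstance(ds, dict):
        return {k: strings_to_dynamic(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(strings_to_dynamic(v) for v in ds)
    return ds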
Before exporting, we check that transformers.cache_utils.DynamicCache can be serialized and deserialized, otherwise torch.export.export() fails.
print("-- DynamicCache registered: ", is_cache_dynamic_registered())
-- DynamicCache registered: True
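Registering a cache class means telling torch how to flatten it into tensors and how to rebuild it, through torch.utils._pytree. A minimal sketch of such a registration (the flatten/unflatten logic below is an illustrative assumption, not the code used by onnx_diagnostic):

import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache

def flatten_cache(cache):
    # Expose the cache as two lists of tensors plus an empty context.
    return [cache.key_cache, cache.value_cache], None

def unflatten_cache(children, _context):
    # Rebuild a DynamicCache layer by layer from the flattened tensors.
    cache = DynamicCache()
    for layer, (key, value) in enumerate(zip(*children)):
        cache.update(key, value, layer)
    return cache

# Registering twice raises an error, which is why the check above
# (is_cache_dynamic_registered) matters. The call would look like:
# pytree.register_pytree_node(DynamicCache, flatten_cache, unflatten_cache)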
If it is not registered, the function onnx_diagnostic.torch_export_patches.torch_export_patches() should take care of it. Then we export.
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
assert is_cache_dynamic_registered() # it must be true here
ep = torch.export.export(
untrained_model,
(),
kwargs=modificator(cloned_inputs),
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)
print("It worked:")
print(ep)
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_fix_registration] DynamicCache is unregistered and registered first
[unregister_cache_serialization] unregistered DynamicCache
[register_class_serialization] ---------- register DynamicCache
[_fix_registration] DynamicCache done.
[_fix_registration] BaseModelOutput is unregistered and registered first
[unregister_cache_serialization] unregistered BaseModelOutput
[register_class_serialization] ---------- register BaseModelOutput
[_fix_registration] BaseModelOutput done.
[register_class_serialization] already registered DynamicCache
[register_class_serialization] ---------- register HybridCache
[register_class_serialization] ---------- register MambaCache
[register_class_serialization] ---------- register EncoderDecoderCache
[register_class_serialization] ---------- register SlidingWindowCache
[register_class_serialization] ---------- register StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250723+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.54.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[patch_module_or_classes] function: transformers.cache_utils.parse_processor_args
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
#
sym_size_int_15: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_18: "Sym(s23)" = torch.ops.aten.sym_size.int(position_ids, 0)
sym_size_int_20: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_21: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
sym_size_int_22: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 0)
sym_size_int_23: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
#
eq_39: "Sym(True)" = sym_size_int_18 == sym_size_int_20; sym_size_int_18 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_39, "Runtime assertion failed for expression Eq(s44, s23) on node 'eq_39'"); eq_39 = _assert_scalar_default = None
eq_40: "Sym(True)" = sym_size_int_20 == sym_size_int_22; sym_size_int_22 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_40, "Runtime assertion failed for expression Eq(s23, s4) on node 'eq_40'"); eq_40 = _assert_scalar_default_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:371 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s31 + s70)" = sym_size_int_21 + sym_size_int_15
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:370 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False); sym_size_int_21 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:377 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_20, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_20, 1, sym_size_int_15, add]); reshape = None
expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_1 = expand_1 = None
expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_2 = None
expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, le); new_ones = le = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:100 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:377 in forward, code: causal_mask = create_causal_mask(
and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, index); and_1 = index = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_20, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1026 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s23, s70, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s23, s70, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:232 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_15, -1, 96]); linear = None
transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:233 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_15, -1, 96]); linear_1 = None
transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:234 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_15, -1, 96]); linear_2 = None
transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3); slice_3 = None
cat_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1); neg = slice_2 = None
mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4); cat_1 = None
add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1); neg_1 = slice_4 = None
mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4); cat_2 = unsqueeze_4 = None
add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:242 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:248 in forward, code: attn_output, attn_weights = attention_interface(
unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_3, 2)
slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807); unsqueeze_5 = None
expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_20, 1, 2, add, 96]); slice_6 = None
reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_20, 2, add, 96]); expand_5 = None
unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_4, 2)
add_10: "Sym(s11 + s70)" = sym_size_int_23 + sym_size_int_15; sym_size_int_23 = None
slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807); unsqueeze_6 = None
expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_20, 1, 2, add_10, 96]); slice_7 = None
reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_20, 2, add_10, 96]); expand_6 = add_10 = None
slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add); and_2 = add = None
contiguous: "f32[s23, 2, s70, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.contiguous.default(reshape_4); reshape_4 = None
contiguous_2: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.contiguous.default(reshape_5); reshape_5 = None
scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_8, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_8 = None
transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s23, s70, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:259 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_15, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight); reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: hidden_states = residual + hidden_states
add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32); add_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7); add_7 = None
mul_29: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_29, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_29, torch.float32); mul_29 = None
mul_30: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_30, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_30, p_model_layers_0_mlp_up_proj_weight); mul_30 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_31: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_31, p_model_layers_0_mlp_down_proj_weight); mul_31 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:305 in forward, code: hidden_states = residual + hidden_states
add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32); add_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
mul_32: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_32, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_32, torch.float32); mul_32 = None
mul_33: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:474 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_33, 1, 0, 9223372036854775807); mul_33 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight); slice_9 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_20: "Sym(s23)", position_ids: "i64[s23, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1008 in forward, code: self.inv_freq[None, :, None]
unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1009 in forward, code: .float()
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1010 in forward, code: .expand(position_ids.shape[0], -1, 1)
expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_20, -1, 1]); to_1 = sym_size_int_20 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1011 in forward, code: .to(x.device)
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1013 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1); position_ids = None
slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32); slice_1 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3); submod_3 = to_2 = to_3 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1023 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1024 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1026 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, to_2: "f32[s23, 48, 1]", to_3: "f32[s23, 1, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1021 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1022 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1023 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1024 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[unpatch_module_or_classes] function transformers.cache_utils.parse_processor_args
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask
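Before moving on, the exported program can be executed like a regular module through ep.module(), which gives a quick sanity check. A short sketch (it assumes DynamicCache remains registered after the context manager exits, which the log of the next run confirms):

# Run the exported graph on a fresh copy of the inputs and inspect
# the structure of what it returns.
got = ep.module()(**copy.deepcopy(inputs))
print(string_type(got, with_shape=True))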
With the original model
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
cloned_inputs = copy.deepcopy(inputs)
with torch_export_patches(patch_transformers=True, verbose=10) as modificator:
ep = torch.export.export(
model,
(),
kwargs=modificator(cloned_inputs),
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False, # mandatory for torch==2.6
)
print("It worked:")
print(ep)
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[register_class_serialization] already registered DynamicCache
[register_class_serialization] already registered HybridCache
[register_class_serialization] already registered MambaCache
[register_class_serialization] already registered EncoderDecoderCache
[register_class_serialization] already registered SlidingWindowCache
[register_class_serialization] already registered StaticCache
[register_class_serialization] already registered BaseModelOutput
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250723+cu126'
[torch_export_patches] stop_if_static=0
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] transformers.__version__='4.54.0.dev0'
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[patch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[patch_module_or_classes] function: transformers.models.bart.modeling_bart.eager_attention_forward
[patch_module_or_classes] function: transformers.models.marian.modeling_marian.eager_attention_forward
[patch_module_or_classes] function: transformers.cache_utils.parse_processor_args
[torch_export_patches] patches transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] patches transformers.masking_utils.eager_mask
[torch_export_patches] done patching
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s23, s70]", attention_mask: "i64[s23, s53]", position_ids: "i64[s23, s70]", past_key_values_key_cache_0: "f32[s23, 1, s31, 96]", past_key_values_value_cache_0: "f32[s23, 1, s11, 96]"):
#
sym_size_int_15: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_18: "Sym(s23)" = torch.ops.aten.sym_size.int(position_ids, 0)
sym_size_int_20: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 0)
sym_size_int_21: "Sym(s31)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
sym_size_int_22: "Sym(s23)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 0)
sym_size_int_23: "Sym(s11)" = torch.ops.aten.sym_size.int(past_key_values_value_cache_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s23, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
#
eq_39: "Sym(True)" = sym_size_int_18 == sym_size_int_20; sym_size_int_18 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq_39, "Runtime assertion failed for expression Eq(s44, s23) on node 'eq_39'"); eq_39 = _assert_scalar_default = None
eq_40: "Sym(True)" = sym_size_int_20 == sym_size_int_22; sym_size_int_22 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_40, "Runtime assertion failed for expression Eq(s23, s4) on node 'eq_40'"); eq_40 = _assert_scalar_default_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:371 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s31 + s70)" = sym_size_int_21 + sym_size_int_15
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:370 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False); sym_size_int_21 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:377 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s23, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s31 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_: "i64[s31 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0); arange_1 = None
arange_2: "i64[s23]" = torch.ops.aten.arange.default(sym_size_int_20, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s23, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_2, [-1, 1, 1, 1]); arange_2 = None
reshape_1: "i64[1, 1, 1, 1]" = torch.ops.aten.reshape.default(arange_3, [1, -1, 1, 1]); arange_3 = None
reshape_2: "i64[1, 1, s70, 1]" = torch.ops.aten.reshape.default(arange, [1, 1, -1, 1]); arange = None
reshape_3: "i64[1, 1, 1, s31 + s70]" = torch.ops.aten.reshape.default(add_, [1, 1, 1, -1]); add_ = None
expand: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape, [sym_size_int_20, 1, sym_size_int_15, add]); reshape = None
expand_1: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_1, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_1 = expand_1 = None
expand_2: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_2, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_2 = None
expand_3: "i64[s23, 1, s70, s31 + s70]" = torch.ops.aten.expand.default(reshape_3, [sym_size_int_20, 1, sym_size_int_15, add]); reshape_3 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(expand_2, [], dtype = torch.bool, pin_memory = False)
le: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.le.Tensor(expand_3, expand_2); expand_2 = None
and_1: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, le); new_ones = le = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py:100 in forward, code: return torch.ops.aten.index(x, indices)
index: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.index.Tensor(to, [expand, expand_3]); to = expand = expand_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:377 in forward, code: causal_mask = create_causal_mask(
and_2: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.__and__.Tensor(and_1, index); and_1 = index = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_20, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1026 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s23, s70, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s23, s70, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:232 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s23, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_15, -1, 96]); linear = None
transpose_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:233 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_15, -1, 96]); linear_1 = None
transpose_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s23, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:234 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s23, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_15, -1, 96]); linear_2 = None
transpose_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:237 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_3: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_4: "f32[s23, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
slice_2: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_3: "f32[s23, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s23, 2, s70, 48]" = torch.ops.aten.neg.default(slice_3); slice_3 = None
cat_1: "f32[s23, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_2], -1); neg = slice_2 = None
mul_5: "f32[s23, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4); cat_1 = None
add_4: "f32[s23, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3); unsqueeze_3 = None
slice_4: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_5: "f32[s23, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s23, 1, s70, 48]" = torch.ops.aten.neg.default(slice_5); slice_5 = None
cat_2: "f32[s23, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_4], -1); neg_1 = slice_4 = None
mul_7: "f32[s23, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4); cat_2 = unsqueeze_4 = None
add_5: "f32[s23, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:242 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s23, 1, s31 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s23, 1, s11 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:248 in forward, code: attn_output, attn_weights = attention_interface(
unsqueeze_5: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_3, 2)
slice_6: "f32[s23, 1, 1, s31 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807); unsqueeze_5 = None
expand_5: "f32[s23, 1, 2, s31 + s70, 96]" = torch.ops.aten.expand.default(slice_6, [sym_size_int_20, 1, 2, add, 96]); slice_6 = None
reshape_4: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.reshape.default(expand_5, [sym_size_int_20, 2, add, 96]); expand_5 = None
unsqueeze_6: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.unsqueeze.default(cat_4, 2)
add_10: "Sym(s11 + s70)" = sym_size_int_23 + sym_size_int_15; sym_size_int_23 = None
slice_7: "f32[s23, 1, 1, s11 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807); unsqueeze_6 = None
expand_6: "f32[s23, 1, 2, s11 + s70, 96]" = torch.ops.aten.expand.default(slice_7, [sym_size_int_20, 1, 2, add_10, 96]); slice_7 = None
reshape_5: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.reshape.default(expand_6, [sym_size_int_20, 2, add_10, 96]); expand_6 = add_10 = None
slice_8: "b8[s23, 1, s70, s31 + s70]" = torch.ops.aten.slice.Tensor(and_2, 3, None, add); and_2 = add = None
contiguous: "f32[s23, 2, s70, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s23, 2, s31 + s70, 96]" = torch.ops.aten.contiguous.default(reshape_4); reshape_4 = None
contiguous_2: "f32[s23, 2, s11 + s70, 96]" = torch.ops.aten.contiguous.default(reshape_5); reshape_5 = None
scaled_dot_product_attention: "f32[s23, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_8, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_8 = None
transpose_4: "f32[s23, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s23, s70, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:259 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_6: "f32[s23, s70, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_15, -1]); contiguous_3 = sym_size_int_20 = sym_size_int_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(reshape_6, p_model_layers_0_self_attn_o_proj_weight); reshape_6 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: hidden_states = residual + hidden_states
add_6: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_6, torch.float32); add_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_7: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_7); add_7 = None
mul_29: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_29, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_29, torch.float32); mul_29 = None
mul_30: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_30, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s23, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s23, s70, 1024]" = torch.ops.aten.linear.default(mul_30, p_model_layers_0_mlp_up_proj_weight); mul_30 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:152 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_31: "f32[s23, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s23, s70, 192]" = torch.ops.aten.linear.default(mul_31, p_model_layers_0_mlp_down_proj_weight); mul_31 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:305 in forward, code: hidden_states = residual + hidden_states
add_8: "f32[s23, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:63 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(add_8, torch.float32); add_8 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s23, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s23, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_9: "f32[s23, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s23, s70, 1]" = torch.ops.aten.rsqrt.default(add_9); add_9 = None
mul_32: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_32, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s23, s70, 192]" = torch.ops.aten.to.dtype(mul_32, torch.float32); mul_32 = None
mul_33: "f32[s23, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:474 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_9: "f32[s23, s70, 192]" = torch.ops.aten.slice.Tensor(mul_33, 1, 0, 9223372036854775807); mul_33 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s23, s70, 32000]" = torch.ops.aten.linear.default(slice_9, p_lm_head_weight); slice_9 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_20: "Sym(s23)", position_ids: "i64[s23, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1008 in forward, code: self.inv_freq[None, :, None]
unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1009 in forward, code: .float()
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32); unsqueeze_1 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1010 in forward, code: .expand(position_ids.shape[0], -1, 1)
expand_4: "f32[s23, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_20, -1, 1]); to_1 = sym_size_int_20 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1011 in forward, code: .to(x.device)
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_4 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1013 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
unsqueeze_2: "i64[s23, 1, s70]" = torch.ops.aten.unsqueeze.default(position_ids, 1); position_ids = None
slice_1: "i64[s23, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807); unsqueeze_2 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_1, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(slice_1, torch.float32); slice_1 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3); submod_3 = to_2 = to_3 = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1023 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s23, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1024 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s23, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1026 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s23, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, to_2: "f32[s23, 48, 1]", to_3: "f32[s23, 1, s70]"):
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1021 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s23, 48, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s23, 1, s70]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
matmul: "f32[s23, 48, s70]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s23, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1022 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s23, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1023 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s23, s70, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py:1024 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s23, s70, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s23, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s23: VR[1, 1024], s70: VR[2, int_oo], s53: VR[4, int_oo], s31: VR[2, int_oo], s11: VR[2, int_oo]}
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored shape constraints
[torch_export_patches] unpatches transformers
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_AttentionMaskConverter:
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma2RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Gemma3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GemmaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_GenerationMixin: _cache_dependant_input_preparation, _cache_dependant_input_preparation_exporting, prepare_inputs_for_generation
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsAttention: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_IdeficsEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_LlamaRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MistralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_MixtralRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi3RotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_Phi4MultimodalRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_PhiRotaryEmbedding: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SamMaskDecoder: forward
[unpatch_module_or_classes] onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_SmolLM3RotaryEmbedding: forward
[unpatch_module_or_classes] function transformers.models.bart.modeling_bart.eager_attention_forward
[unpatch_module_or_classes] function transformers.models.marian.modeling_marian.eager_attention_forward
[unpatch_module_or_classes] function transformers.cache_utils.parse_processor_args
[torch_export_patches] restored transformers.masking_utils._vmap_for_bhqkv
[torch_export_patches] restored transformers.masking_utils.eager_mask
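The closing log lines show the torch_export_patches context manager undoing its work once the export succeeded: the patched rotary embeddings, attention helpers and masking utilities are restored to their original transformers and pytorch implementations.

The range constraints printed above (s23 in [1, 1024] for the batch, s70 >= 2 for the sequence, s31/s11 >= 2 for the cache length) suggest the exported program really kept those dimensions dynamic. A minimal sketch, not part of the original script, can confirm this by calling the exported module with a different batch size and sequence length than the ones used for tracing. The variable name ep for the exported program is an assumption, and the tensor shapes below are simply read off the printed inputs (1 key/value head, head dimension 96, cache length 30, vocabulary 32000).

import torch
from transformers.cache_utils import DynamicCache
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import torch_export_patches

batch, seq, cache_length, head_dim, vocab = 3, 4, 30, 96, 32000

# one-layer DynamicCache matching past_key_values: key/value of shape [batch, 1, cache, 96]
cache = DynamicCache()
cache.update(
    torch.randn(batch, 1, cache_length, head_dim),
    torch.randn(batch, 1, cache_length, head_dim),
    0,
)
new_inputs = dict(
    input_ids=torch.randint(0, vocab, (batch, seq), dtype=torch.int64),
    attention_mask=torch.ones(batch, cache_length + seq, dtype=torch.int64),
    position_ids=torch.arange(cache_length, cache_length + seq, dtype=torch.int64)
    .unsqueeze(0)
    .expand(batch, -1),
    past_key_values=cache,
)

# Depending on the transformers version, DynamicCache may only be registered
# for pytree flattening inside the patch context, so the call is wrapped again.
with torch_export_patches(patch_transformers=True):
    got = ep.module()(**new_inputs)
print(string_type(got, with_shape=True))

If the batch dimension had been specialized to 2 during export, this call would fail with a guard violation instead of returning logits of shape [3, 4, 32000] and the updated cache.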
doc.plot_legend("Tiny-LLM patched", "torch.export.export", "green")
(figure: legend "Tiny-LLM patched", torch.export.export, green)
Total running time of the script: (0 minutes 4.547 seconds)
Related examples
Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)
Export with DynamicCache and guessed dynamic shapes