Steal the forward method to guess inputs and dynamic shapes (with Tiny-LLM)

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its forward method.
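
As a reminder, torch.export.export() receives those dynamic dimensions through its dynamic_shapes argument. Below is a minimal, self-contained sketch on a toy module (not the Tiny-LLM model), only to recall the syntax; the argument and dimension names are illustrative.

import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2


# mark the batch and sequence dimensions of input_ids as dynamic
batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
ep = torch.export.export(
    Toy(),
    (torch.zeros((2, 8), dtype=torch.int64),),
    dynamic_shapes={"input_ids": {0: batch, 1: seq}},
)
print(ep)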

This example shows how to guess the inputs of this method even though the model is executed through the method generate().

We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

We wrap the forward method to print its inputs and outputs, including the cache dimensions.

def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
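# replace forward with a wrapper that calls the original method through _f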
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)

Let’s run the model.

prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.432564735412598,8.368535995483398:A-4.234468644971028],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.135845184326172,4.338468074798584:A-9.799261990881526],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.811427116394043,6.348220348358154:A-0.1250392806283344]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.008353379246409531]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.811427116394043,6.348220348358154:A-0.1250392806283344]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.008353379246409531]]),input_ids:T7s1x1[2610,2610:A2610.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.630133628845215,11.018857955932617:A-7.585419899705564],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.811427116394043,6.348220348358154:A-0.12102800812234901]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.009305229967787565]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.811427116394043,6.348220348358154:A-0.12102800812234901]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.009305229967787565]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.729768753051758,15.795294761657715:A-6.686570847850293],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.811427116394043,6.348220348358154:A-0.11321421906611956]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.0059324780764124325]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.811427116394043,6.348220348358154:A-0.11321421906611956]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.0059324780764124325]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.151559829711914,12.831623077392578:A-8.03768296736572],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.811427116394043,6.348220348358154:A-0.09604190427237952]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.005043171878160371]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.811427116394043,6.348220348358154:A-0.09604190427237952]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.005043171878160371]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.921062469482422,10.401754379272461:A-10.987481685367879],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.811427116394043,6.6982550621032715:A-0.08939146100766222]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.0033156730564877805]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.811427116394043,6.6982550621032715:A-0.08939146100766222]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.0033156730564877805]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.109092712402344,11.451603889465332:A-8.615241368453717],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.811427116394043,6.6982550621032715:A-0.08828816101069809]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.004018123734315143]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.811427116394043,6.6982550621032715:A-0.08828816101069809]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.004018123734315143]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.535235404968262,18.55143928527832:A-4.359501328858081],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.811427116394043,6.6982550621032715:A-0.10365996128249814]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0018190039553758197]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.811427116394043,6.6982550621032715:A-0.10365996128249814]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0018190039553758197]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.208690643310547,12.71769905090332:A-10.155632720829919],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.856588363647461,6.6982550621032715:A-0.10853694268128795]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0013809153403028676]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.856588363647461,6.6982550621032715:A-0.10853694268128795]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0013809153403028676]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.913084030151367,11.392576217651367:A-11.69873876139149],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.856588363647461,6.6982550621032715:A-0.1023646623671084]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.00039179591874893014]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.856588363647461,6.6982550621032715:A-0.1023646623671084]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.00039179591874893014]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.21382713317871,13.484505653381348:A-7.445791564540472],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.856588363647461,6.6982550621032715:A-0.08908409716247677]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0019779059926373806]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.856588363647461,6.6982550621032715:A-0.08908409716247677]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0019779059926373806]]),input_ids:T7s1x1[29955,29955:A29955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.9020938873291,5.044755458831787:A-13.776848735461943],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.856588363647461,6.840986728668213:A-0.07805764829069327]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.00273918923637666]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.856588363647461,6.840986728668213:A-0.07805764829069327]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.00273918923637666]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.70694351196289,7.221324920654297:A-8.892709849011153],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.856588363647461,6.840986728668213:A-0.07859572630424361]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.001949111976363571]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.856588363647461,6.840986728668213:A-0.07859572630424361]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.001949111976363571]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.547388076782227,15.494963645935059:A-5.068759219610598],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.856588363647461,6.840986728668213:A-0.07658115250231992]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.003277233828743137]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.856588363647461,6.840986728668213:A-0.07658115250231992]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.003277233828743137]]),input_ids:T7s1x1[29946,29946:A29946.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.133716583251953,6.478907585144043:A-11.908872234937736],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.882454872131348,6.840986728668213:A-0.07968059192137224]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.003521361922739216]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.882454872131348,6.840986728668213:A-0.07968059192137224]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.003521361922739216]]),input_ids:T7s1x1[29901,29901:A29901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.94466209411621,7.7586669921875:A-12.086376501435414],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.882454872131348,6.840986728668213:A-0.07993684639399766]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0027230712880800135]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.882454872131348,6.840986728668213:A-0.07993684639399766]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0027230712880800135]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.87329864501953,9.426261901855469:A-10.317357872662134],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.882454872131348,6.840986728668213:A-0.07471688298367857]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003106062676579313]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.882454872131348,6.840986728668213:A-0.07471688298367857]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003106062676579313]]),input_ids:T7s1x1[29953,29953:A29953.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.021434783935547,5.666491508483887:A-11.43309489883529],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.882454872131348,6.840986728668213:A-0.06830911466549412]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003428964488013009]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.882454872131348,6.840986728668213:A-0.06830911466549412]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003428964488013009]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.502959251403809,8.647699356079102:A-3.666581474106759],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.882454872131348,6.840986728668213:A-0.06822447889040902]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.00225495119773903]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.882454872131348,6.840986728668213:A-0.06822447889040902]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.00225495119773903]]),input_ids:T7s1x1[29928,29928:A29928.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.725600242614746,10.529221534729004:A-1.5172752494900488],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.882454872131348,6.840986728668213:A-0.06527995995752246]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.00138839243755315]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.882454872131348,6.840986728668213:A-0.06527995995752246]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.00138839243755315]]),input_ids:T7s1x1[485,485:A485.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.5223307609558105,12.079378128051758:A-2.200265917599667],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.882454872131348,6.840986728668213:A-0.06148312655263333]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.001493190796611151]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.882454872131348,6.840986728668213:A-0.06148312655263333]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.001493190796611151]]),input_ids:T7s1x1[262,262:A262.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.054697036743164,3.195920944213867:A-9.912151685696095],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.882454872131348,6.840986728668213:A-0.0526473199523025]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0016696463759874151]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.882454872131348,6.840986728668213:A-0.0526473199523025]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0016696463759874151]]),input_ids:T7s1x1[435,435:A435.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.00013542175293,7.802364826202393:A-6.514995807819068],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.882454872131348,6.840986728668213:A-0.04733127825504628]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.0008047377704882507]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.882454872131348,6.840986728668213:A-0.04733127825504628]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.0008047377704882507]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.103503227233887,2.7568225860595703:A-9.391665586687624],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.882454872131348,6.840986728668213:A-0.04055644638731337]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.0003467011769136737]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.882454872131348,6.840986728668213:A-0.04055644638731337]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.0003467011769136737]]),input_ids:T7s1x1[322,322:A322.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.71291732788086,2.768527030944824:A-9.84994037065329],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.882454872131348,6.840986728668213:A-0.040666309959186665]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.0006287981841110609]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.882454872131348,6.840986728668213:A-0.040666309959186665]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.0006287981841110609]]),input_ids:T7s1x1[390,390:A390.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.72878646850586,8.927671432495117:A-5.238189476897475],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.882454872131348,6.840986728668213:A-0.04073903500626758]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.00030211803927355217]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.882454872131348,6.840986728668213:A-0.04073903500626758]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.00030211803927355217]]),input_ids:T7s1x1[29968,29968:A29968.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.1890926361084,4.8266496658325195:A-9.480396195146255],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.882454872131348,6.840986728668213:A-0.036860077806375645]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.00034705521119186806]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.882454872131348,6.840986728668213:A-0.036860077806375645]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.00034705521119186806]]),input_ids:T7s1x1[29928,29928:A29928.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.828540802001953,5.075605869293213:A-10.123816122239456],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.882454872131348,6.840986728668213:A-0.03573528471056138]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0009487674021994719]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.882454872131348,6.840986728668213:A-0.03573528471056138]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0009487674021994719]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.334418296813965,3.0961825847625732:A-9.918788560789078],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.882454872131348,6.840986728668213:A-0.03620726755868852]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0012975151271102485]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.882454872131348,6.840986728668213:A-0.03620726755868852]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0012975151271102485]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.548851013183594,10.435674667358398:A-3.7397975996453314],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.882454872131348,6.840986728668213:A-0.036007717760679826]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0013311053766293505]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.882454872131348,6.840986728668213:A-0.036007717760679826]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0013311053766293505]]),input_ids:T7s1x1[1073,1073:A1073.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.840539932250977,6.937159061431885:A-10.19300847663451],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.882454872131348,6.840986728668213:A-0.03415937627835222]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00108061570339347]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.882454872131348,6.840986728668213:A-0.03415937627835222]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00108061570339347]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.949212074279785,6.899430274963379:A-8.692059776168783],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.882454872131348,7.078313827514648:A-0.03448053422941181]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0013428207877003236]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.882454872131348,7.078313827514648:A-0.03448053422941181]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0013428207877003236]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.425172805786133,9.370534896850586:A-6.1916607732344415],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.882454872131348,7.078313827514648:A-0.03414592427012631]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0014931152163145766]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.882454872131348,7.078313827514648:A-0.03414592427012631]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0014931152163145766]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.616035461425781,15.954569816589355:A-4.68313079584483],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.882454872131348,7.078313827514648:A-0.03241555252936941]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0008153728593004503]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.882454872131348,7.078313827514648:A-0.03241555252936941]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0008153728593004503]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.34793472290039,5.864483833312988:A-8.361768380252295],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.882454872131348,7.078313827514648:A-0.03249438322728925]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0002990090804553964]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.882454872131348,7.078313827514648:A-0.03249438322728925]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0002990090804553964]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.470327377319336,15.861299514770508:A-2.115428224620875],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.882454872131348,7.078313827514648:A-0.028590843711478527]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A-0.0004161455061166359]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.882454872131348,7.078313827514648:A-0.028590843711478527]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A-0.0004161455061166359]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.115264892578125,7.146152496337891:A-8.258798646918498],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.882454872131348,7.078313827514648:A-0.028340438807127714]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A-0.000832271419875286]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.882454872131348,7.078313827514648:A-0.028340438807127714]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A-0.000832271419875286]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.093294143676758,6.9178547859191895:A-8.125556804327294],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.882454872131348,7.078313827514648:A-0.028565603406525992]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A-0.0014778282873264645]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.882454872131348,7.078313827514648:A-0.028565603406525992]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A-0.0014778282873264645]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.948810577392578,3.4617862701416016:A-10.526158623103983],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.198397636413574,7.078313827514648:A-0.028059341673402186]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A-0.0020959146497797204]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.198397636413574,7.078313827514648:A-0.028059341673402186]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A-0.0020959146497797204]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.243667602539062,1.9742556810379028:A-12.392495226606727],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.198397636413574,7.078313827514648:A-0.02473806524628546]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A-0.002688247413797424]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.198397636413574,7.078313827514648:A-0.02473806524628546]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A-0.002688247413797424]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.82381820678711,3.515730857849121:A-12.101337580995635],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.198397636413574,7.078313827514648:A-0.01913464772082896]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A-0.002685581100340518]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...
3 May 28, 2007, 4:56
Davin J, and RKD, we know. I'm 800036

Let’s restore the forward method as it was.

model.forward = keep_model_forward

The same information can be obtained with a simpler syntax, relying on onnx_diagnostic.helpers.torch_helper.steal_forward().

with steal_forward(model):
    model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
  <- args=() --- kwargs=dict(cache_position:T7s8,input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring a specific architecture can still be exported despite the many changes brought to torch or transformers.

Let’s create an untrained model with onnx_diagnostic.torch_models.llms.get_tiny_llm(), which builds it from the configuration file config.json provided with the model. Then let’s use it.

experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
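
The same dictionary also contains the dynamic shapes matching these inputs. They can be displayed with pprint (imported at the beginning); the exact content may vary with the installed version of transformers.

pprint.pprint(dynamic_shapes)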

Before we run it, we make a copy of the inputs because the cache gets modified by the execution; after that, the cache is no longer consistent with the previous input_ids and mask.

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.
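
The unit test mentioned at the beginning of this section would essentially wrap get_tiny_llm() and the export performed in the next section. A minimal sketch, assuming onnx_diagnostic and recent enough torch and transformers are installed (test name and assertion are arbitrary):

import copy
import unittest
import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm


class TestTinyLlmExport(unittest.TestCase):
    def test_export_forward(self):
        # build the untrained model, its dummy inputs and dynamic shapes
        data = get_tiny_llm()
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        # copy the inputs because the cache is modified inplace by a run
        ep = torch.export.export(
            model,
            (),
            kwargs=copy.deepcopy(inputs),
            dynamic_shapes=use_dyn_not_str(ds),
            strict=False,
        )
        self.assertIsInstance(ep, torch.export.ExportedProgram)


if __name__ == "__main__":
    unittest.main()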

ExportedProgram
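
The dynamic shapes returned by get_tiny_llm() use strings to name the dynamic dimensions. use_dyn_not_str() replaces those strings with dynamic dimensions torch.export.export() understands (that is at least what its name suggests; see the documentation of onnx_diagnostic for details). The converted specification can be inspected before exporting:

pprint.pprint(use_dyn_not_str(dynamic_shapes))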

try:
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s44, s70]", attention_mask: "i64[s43, s53]", position_ids: "i64[s44, s70]", past_key_values_key_0: "f32[s44, 1, s45, 96]", past_key_values_value_0: "f32[s44, 1, s21, 96]"):
             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
            function_const_func_spec0 = self.function_const_func_spec0
            torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0

             #
            sym_size_int_16: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_19: "Sym(s44)" = torch.ops.aten.sym_size.int(position_ids, 0)
            sym_size_int_22: "Sym(s45)" = torch.ops.aten.sym_size.int(past_key_values_key_0, 2)
            sym_size_int_24: "Sym(s21)" = torch.ops.aten.sym_size.int(past_key_values_value_0, 2)

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s44, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:403 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s45 + s70)" = sym_size_int_22 + sym_size_int_16

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:402 in forward, code: cache_position: torch.Tensor = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s43, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s45 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s45 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s44]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions = None
            _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(sym_size_int_19, 'error');  _vmap_increment_nesting = None
            _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
            lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_1 = None
            _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error');  _vmap_increment_nesting_1 = None
            _add_batch_dim_1: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_3, 0, 2);  arange_3 = _add_batch_dim_1 = None
            lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_2 = None
            _vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(sym_size_int_16, 'error');  _vmap_increment_nesting_2 = None
            _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
            lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_3 = None
            _vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(add, 'error');  _vmap_increment_nesting_3 = None
            _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None
            flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(flat_apply, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
            and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
            _remove_batch_dim: "b8[s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(and_2, 4, add, 0);  and_2 = None
            _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
            _remove_batch_dim_1: "b8[s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, sym_size_int_16, 0);  _remove_batch_dim = None
            _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
            _remove_batch_dim_2: "b8[1, s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0)
            expand: "b8[1, s70, s45 + s70]" = torch.ops.aten.expand.default(_remove_batch_dim_1, [1, sym_size_int_16, add]);  _remove_batch_dim_1 = expand = None
            _vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_2 = None
            _remove_batch_dim_3: "b8[s44, 1, s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, sym_size_int_19, 0);  _remove_batch_dim_2 = None
            _vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_3 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_19, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_8: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[0]
            to_9: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_4: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_4);  add_4 = None
            mul_7: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_7, torch.float32);  mul_7 = None
            mul_8: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:264 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s44, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear = None
            transpose_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear_1 = None
            transpose_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_v_proj_weight);  mul_8 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:266 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear_2 = None
            transpose_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:269 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            unsqueeze_4: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
            mul_9: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_3: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_4: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s44, 2, s70, 48]" = torch.ops.aten.neg.default(slice_4);  slice_4 = None
            cat_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_3], -1);  neg = slice_3 = None
            mul_10: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4);  cat_1 = None
            add_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_9, mul_10);  mul_9 = mul_10 = None
            mul_11: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_5: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_6: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s44, 1, s70, 48]" = torch.ops.aten.neg.default(slice_6);  slice_6 = None
            cat_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_5], -1);  neg_1 = slice_5 = None
            mul_12: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4);  cat_2 = unsqueeze_4 = None
            add_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_11, mul_12);  mul_11 = mul_12 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:274 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_0, add_6], -2);  past_key_values_key_0 = add_6 = None
            cat_4: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_0, transpose_3], -2);  past_key_values_value_0 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:280 in forward, code: attn_output, attn_weights = attention_interface(
            slice_7: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            unsqueeze_5: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_7, 2);  slice_7 = None
            slice_8: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807);  unsqueeze_5 = None
            expand_2: "f32[s44, 1, 2, s45 + s70, 96]" = torch.ops.aten.expand.default(slice_8, [sym_size_int_19, 1, 2, add, 96]);  slice_8 = None
            reshape: "f32[s44, 2, s45 + s70, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_19, 2, add, 96]);  expand_2 = None
            slice_9: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            unsqueeze_6: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_9, 2);  slice_9 = None
            add_11: "Sym(s21 + s70)" = sym_size_int_24 + sym_size_int_16;  sym_size_int_24 = None
            slice_10: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807);  unsqueeze_6 = None
            expand_3: "f32[s44, 1, 2, s21 + s70, 96]" = torch.ops.aten.expand.default(slice_10, [sym_size_int_19, 1, 2, add_11, 96]);  slice_10 = None
            reshape_1: "f32[s44, 2, s21 + s70, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_19, 2, add_11, 96]);  expand_3 = add_11 = None
            slice_11: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.slice.Tensor(_remove_batch_dim_3, 3, None, add);  _remove_batch_dim_3 = add = None
            scaled_dot_product_attention: "f32[s44, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_5, reshape, reshape_1, slice_11, scale = 0.10206207261596575);  add_5 = reshape = reshape_1 = slice_11 = None
            transpose_4: "f32[s44, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_2: "f32[s44, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_16, -1]);  transpose_4 = sym_size_int_19 = sym_size_int_16 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight);  reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:331 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_1: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_21: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_21, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_21, torch.float32);  mul_21 = None
            mul_22: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13);  p_model_layers_0_post_attention_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_22, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/github/transformers/src/transformers/activations.py:103 in forward, code: return nn.functional.silu(input)
            silu: "f32[s44, s70, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_22, p_model_layers_0_mlp_up_proj_weight);  mul_22 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:184 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_23: "f32[s44, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_23, p_model_layers_0_mlp_down_proj_weight);  mul_23 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:337 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6);  to_12 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_2: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_24: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2);  to_14 = rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_24, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_24, torch.float32);  mul_24 = None
            mul_25: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15);  p_model_norm_weight = to_15 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:500 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_12: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(mul_25, 0, 0, 9223372036854775807);  mul_25 = None
            slice_13: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(slice_12, 1, 0, 9223372036854775807);  slice_12 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s44, s70, 32000]" = torch.ops.aten.linear.default(slice_13, p_lm_head_weight);  slice_13 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_19: "Sym(s44)", position_ids: "i64[s44, s70]"):
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:125 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2);  unsqueeze = None
                _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None
                expand_1: "f32[s44, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]);  to_3 = sym_size_int_19 = None
                _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                to_4: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_1: "i64[s44, s70]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_2: "i64[s44, 1, s70]" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
                slice_2: "i64[s44, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_2, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(slice_2, torch.float32);  slice_2 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
                mul_5: "f32[s44, s70, 96]" = wrap_with_autocast[0]

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_6: "f32[s44, s70, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
                to_8: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_5, torch.float32);  mul_5 = None
                _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
                to_9: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_6, torch.float32);  mul_6 = None
                return (to_8, to_9)

            class submod_1(torch.nn.Module):
                def forward(self, to_4: "f32[s44, 48, 1]", to_5: "f32[s44, 1, s70]"):
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:130 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                    to_6: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                    to_7: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                    matmul: "f32[s44, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                    transpose: "f32[s44, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:131 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s44, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[s44, s70, 96]" = torch.ops.aten.cos.default(cat)
                    mul_5: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[s44, s70, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_6: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul_5, mul_6)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_0: USER_INPUT
    past_key_values_value_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s44: VR[0, int_oo], s70: VR[0, int_oo], s43: VR[0, int_oo], s53: VR[0, int_oo], s45: VR[0, int_oo], s21: VR[0, int_oo]}

Back to the original model

Let’s now export the downloaded model with the same dummy inputs. The dummy inputs and dynamic shapes are created by the function onnx_diagnostic.torch_models.llms.get_tiny_llm().

data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]

Let’s print the inputs and the dynamic shapes.

print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: 'batch', 1: 'cache+seq'},
 'input_ids': {0: 'batch', 1: 'seq_length'},
 'past_key_values': [{0: 'batch', 2: 'cache_length'},
                     {0: 'batch', 2: 'cache_length'}],
 'position_ids': {0: 'batch', 1: 'seq_length'}}
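The dynamic shapes above use string names such as 'batch' or 'seq_length'. torch.export.export() expects torch.export.Dim objects (or Dim.DYNAMIC), which is why the export call below wraps the dictionary with use_dyn_not_str. As a rough illustration only, a simplified sketch of such a conversion (not the actual implementation of use_dyn_not_str) could look like this:

def dyn_from_str(ds):
    # Simplified sketch: recursively replace every string dimension name
    # by torch.export.Dim.DYNAMIC so that torch.export.export accepts it.
    if isinstance(ds, str):
        return torch.export.Dim.DYNAMIC
    if isinstance(ds, dict):
        return {k: dyn_from_str(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(dyn_from_str(v) for v in ds)
    return ds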

And let’s finally export. We export a deep copy of the inputs so the originals are left untouched.

cloned_inputs = copy.deepcopy(inputs)

try:
    ep = torch.export.export(
        model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s44, s70]", attention_mask: "i64[s43, s53]", position_ids: "i64[s44, s70]", past_key_values_key_0: "f32[s44, 1, s45, 96]", past_key_values_value_0: "f32[s44, 1, s21, 96]"):
             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
            function_const_func_spec0 = self.function_const_func_spec0
            torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = self.torch__dynamo__trace_wrapped_higher_order_op_ModIndex0

             #
            sym_size_int_16: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_19: "Sym(s44)" = torch.ops.aten.sym_size.int(position_ids, 0)
            sym_size_int_22: "Sym(s45)" = torch.ops.aten.sym_size.int(past_key_values_key_0, 2)
            sym_size_int_24: "Sym(s21)" = torch.ops.aten.sym_size.int(past_key_values_value_0, 2)

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
            embedding: "f32[s44, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:403 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s45 + s70)" = sym_size_int_22 + sym_size_int_16

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:402 in forward, code: cache_position: torch.Tensor = torch.arange(
            arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
            _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default = None
            to: "b8[s43, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool);  attention_mask = None
            arange_1: "i64[s45 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            add_: "i64[s45 + s70]" = torch.ops.aten.add_.Tensor(arange_1, 0);  arange_1 = None
            arange_2: "i64[s44]" = torch.ops.aten.arange.default(sym_size_int_19, device = device(type='cpu'), pin_memory = False)
            arange_3: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
            lazy_load_decompositions = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions = None
            _vmap_increment_nesting = torch._functorch.predispatch._vmap_increment_nesting(sym_size_int_19, 'error');  _vmap_increment_nesting = None
            _add_batch_dim: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_2, 0, 1);  arange_2 = None
            lazy_load_decompositions_1 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_1 = None
            _vmap_increment_nesting_1 = torch._functorch.predispatch._vmap_increment_nesting(1, 'error');  _vmap_increment_nesting_1 = None
            _add_batch_dim_1: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange_3, 0, 2);  arange_3 = _add_batch_dim_1 = None
            lazy_load_decompositions_2 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_2 = None
            _vmap_increment_nesting_2 = torch._functorch.predispatch._vmap_increment_nesting(sym_size_int_16, 'error');  _vmap_increment_nesting_2 = None
            _add_batch_dim_2: "i64[]" = torch._functorch.predispatch._add_batch_dim(arange, 0, 3);  arange = None
            lazy_load_decompositions_3 = torch._functorch.predispatch.lazy_load_decompositions();  lazy_load_decompositions_3 = None
            _vmap_increment_nesting_3 = torch._functorch.predispatch._vmap_increment_nesting(add, 'error');  _vmap_increment_nesting_3 = None
            _add_batch_dim_3: "i64[]" = torch._functorch.predispatch._add_batch_dim(add_, 0, 4);  add_ = None
            new_ones: "b8[]" = torch.ops.aten.new_ones.default(_add_batch_dim_2, [], dtype = torch.bool, pin_memory = False)
            le: "b8[]" = torch.ops.aten.le.Tensor(_add_batch_dim_3, _add_batch_dim_2);  _add_batch_dim_2 = None
            _assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_1 = None
            to_1: "b8[]" = torch.ops.aten.to.dtype_layout(le, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  le = None
            and_1: "b8[]" = torch.ops.aten.__and__.Tensor(new_ones, to_1);  new_ones = to_1 = None
            flat_apply: "b8[]" = torch.ops.higher_order.flat_apply(function_const_func_spec0, torch__dynamo__trace_wrapped_higher_order_op_mod_index0, 'torch._dynamo._trace_wrapped_higher_order_op.ModIndex', to, _add_batch_dim, _add_batch_dim_3);  function_const_func_spec0 = torch__dynamo__trace_wrapped_higher_order_op_mod_index0 = to = _add_batch_dim = _add_batch_dim_3 = None
            _assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(flat_apply, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_2 = None
            to_2: "b8[]" = torch.ops.aten.to.dtype_layout(flat_apply, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'));  flat_apply = None
            and_2: "b8[]" = torch.ops.aten.__and__.Tensor(and_1, to_2);  and_1 = to_2 = None
            _remove_batch_dim: "b8[s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(and_2, 4, add, 0);  and_2 = None
            _vmap_decrement_nesting = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
            _remove_batch_dim_1: "b8[s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim, 3, sym_size_int_16, 0);  _remove_batch_dim = None
            _vmap_decrement_nesting_1 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_1 = None
            _remove_batch_dim_2: "b8[1, s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_1, 2, 1, 0)
            expand: "b8[1, s70, s45 + s70]" = torch.ops.aten.expand.default(_remove_batch_dim_1, [1, sym_size_int_16, add]);  _remove_batch_dim_1 = expand = None
            _vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_2 = None
            _remove_batch_dim_3: "b8[s44, 1, s70, s45 + s70]" = torch._functorch.predispatch._remove_batch_dim(_remove_batch_dim_2, 1, sym_size_int_19, 0);  _remove_batch_dim_2 = None
            _vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting();  _vmap_decrement_nesting_3 = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_19, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_8: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[0]
            to_9: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_10 = None
            to_10: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_4: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_4);  add_4 = None
            mul_7: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt);  rsqrt = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_11 = None
            to_11: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_7, torch.float32);  mul_7 = None
            mul_8: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11);  p_model_layers_0_input_layernorm_weight = to_11 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:264 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s44, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear = None
            transpose_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear_1 = None
            transpose_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_8, p_model_layers_0_self_attn_v_proj_weight);  mul_8 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:266 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_19, sym_size_int_16, -1, 96]);  linear_2 = None
            transpose_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:269 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1);  to_8 = None
            unsqueeze_4: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1);  to_9 = None
            mul_9: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_3)
            slice_3: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_4: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s44, 2, s70, 48]" = torch.ops.aten.neg.default(slice_4);  slice_4 = None
            cat_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_3], -1);  neg = slice_3 = None
            mul_10: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_4);  cat_1 = None
            add_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_9, mul_10);  mul_9 = mul_10 = None
            mul_11: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_3);  unsqueeze_3 = None
            slice_5: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_6: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s44, 1, s70, 48]" = torch.ops.aten.neg.default(slice_6);  slice_6 = None
            cat_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_5], -1);  neg_1 = slice_5 = None
            mul_12: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_4);  cat_2 = unsqueeze_4 = None
            add_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_11, mul_12);  mul_11 = mul_12 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:274 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_0, add_6], -2);  past_key_values_key_0 = add_6 = None
            cat_4: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_0, transpose_3], -2);  past_key_values_value_0 = transpose_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:280 in forward, code: attn_output, attn_weights = attention_interface(
            slice_7: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            unsqueeze_5: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_7, 2);  slice_7 = None
            slice_8: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 3, 0, 9223372036854775807);  unsqueeze_5 = None
            expand_2: "f32[s44, 1, 2, s45 + s70, 96]" = torch.ops.aten.expand.default(slice_8, [sym_size_int_19, 1, 2, add, 96]);  slice_8 = None
            reshape: "f32[s44, 2, s45 + s70, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_19, 2, add, 96]);  expand_2 = None
            slice_9: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            unsqueeze_6: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_9, 2);  slice_9 = None
            add_11: "Sym(s21 + s70)" = sym_size_int_24 + sym_size_int_16;  sym_size_int_24 = None
            slice_10: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 3, 0, 9223372036854775807);  unsqueeze_6 = None
            expand_3: "f32[s44, 1, 2, s21 + s70, 96]" = torch.ops.aten.expand.default(slice_10, [sym_size_int_19, 1, 2, add_11, 96]);  slice_10 = None
            reshape_1: "f32[s44, 2, s21 + s70, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_19, 2, add_11, 96]);  expand_3 = add_11 = None
            slice_11: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.slice.Tensor(_remove_batch_dim_3, 3, None, add);  _remove_batch_dim_3 = add = None
            scaled_dot_product_attention: "f32[s44, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_5, reshape, reshape_1, slice_11, scale = 0.10206207261596575);  add_5 = reshape = reshape_1 = slice_11 = None
            transpose_4: "f32[s44, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_2: "f32[s44, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_19, sym_size_int_16, -1]);  transpose_4 = sym_size_int_19 = sym_size_int_16 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight);  reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:331 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3);  to_10 = linear_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_12 = None
            to_12: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_1: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_21: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1);  rsqrt_1 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_21, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_13 = None
            to_13: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_21, torch.float32);  mul_21 = None
            mul_22: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13);  p_model_layers_0_post_attention_layernorm_weight = to_13 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_22, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: ~/github/transformers/src/transformers/activations.py:103 in forward, code: return nn.functional.silu(input)
            silu: "f32[s44, s70, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_22, p_model_layers_0_mlp_up_proj_weight);  mul_22 = p_model_layers_0_mlp_up_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:184 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_23: "f32[s44, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_23, p_model_layers_0_mlp_down_proj_weight);  mul_23 = p_model_layers_0_mlp_down_proj_weight = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:337 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6);  to_12 = linear_6 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
            _assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_14 = None
            to_14: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
            mean_2: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_24: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2);  to_14 = rsqrt_2 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
            _assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_24, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_15 = None
            to_15: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_24, torch.float32);  mul_24 = None
            mul_25: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15);  p_model_norm_weight = to_15 = None

             # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:500 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_12: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(mul_25, 0, 0, 9223372036854775807);  mul_25 = None
            slice_13: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(slice_12, 1, 0, 9223372036854775807);  slice_12 = None

             # File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s44, s70, 32000]" = torch.ops.aten.linear.default(slice_13, p_lm_head_weight);  slice_13 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_19: "Sym(s44)", position_ids: "i64[s44, s70]"):
                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:125 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
                unsqueeze: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                unsqueeze_1: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2);  unsqueeze = None
                _assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_3 = None
                to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_1, torch.float32);  unsqueeze_1 = None
                expand_1: "f32[s44, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_19, -1, 1]);  to_3 = sym_size_int_19 = None
                _assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_4 = None
                to_4: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  expand_1 = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_1: "i64[s44, s70]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_2: "i64[s44, 1, s70]" = torch.ops.aten.unsqueeze.default(slice_1, 1);  slice_1 = None
                slice_2: "i64[s44, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
                _assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_2, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_5 = None
                to_5: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(slice_2, torch.float32);  slice_2 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5);  submod_3 = to_4 = to_5 = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
                mul_5: "f32[s44, s70, 96]" = wrap_with_autocast[0]

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
                mul_6: "f32[s44, s70, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                _assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_8 = None
                to_8: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_5, torch.float32);  mul_5 = None
                _assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_6, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_9 = None
                to_9: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_6, torch.float32);  mul_6 = None
                return (to_8, to_9)

            class submod_1(torch.nn.Module):
                def forward(self, to_4: "f32[s44, 48, 1]", to_5: "f32[s44, 1, s70]"):
                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:130 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
                    _assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_6 = None
                    to_6: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32);  to_4 = None
                    _assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided);  _assert_tensor_metadata_default_7 = None
                    to_7: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32);  to_5 = None
                    matmul: "f32[s44, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7);  to_6 = to_7 = None
                    transpose: "f32[s44, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:131 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[s44, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
                    cos: "f32[s44, s70, 96]" = torch.ops.aten.cos.default(cat)
                    mul_5: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                     # File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
                    sin: "f32[s44, s70, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    mul_6: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
                    return (mul_5, mul_6)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_0: USER_INPUT
    past_key_values_value_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s44: VR[0, int_oo], s70: VR[0, int_oo], s43: VR[0, int_oo], s53: VR[0, int_oo], s45: VR[0, int_oo], s21: VR[0, int_oo]}
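
The graph signature above shows how torch.export flattened the inputs and outputs: the DynamicCache becomes two plain tensors, past_key_values_key_0 and past_key_values_value_0, and the three outputs are the logits (linear_7) plus the updated key and value caches (cat_3, cat_4). Every symbol in the range constraints is a dimension that was declared dynamic at export time; in the graph, s44 appears as the batch dimension and s70 as the sequence dimension of position_ids. A quick sanity check is to rebuild a module from the exported program and run it. This is only a sketch: the names ep (the ExportedProgram printed above) and inputs (the dummy inputs it was exported with) are assumptions about variables defined earlier in this example.

import copy
from onnx_diagnostic.helpers import string_type

# Assumed names: ``ep`` is the ExportedProgram printed above,
# ``inputs`` the dummy inputs it was exported with.
restored = ep.module()                    # rebuild a callable nn.Module from the graph
got = restored(**copy.deepcopy(inputs))   # deep copy because the cache is mutated in place
print(string_type(got, with_shape=True))  # logits followed by the updated key/value cache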

If you get any error during the export, look at the example Export Tiny-LLM with patches.
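
That example applies the export patches provided by onnx_diagnostic before calling torch.export.export(). A rough sketch of the idea follows; the context-manager name torch_export_patches, its patch_transformers argument, and the variables inputs and dynamic_shapes are assumptions to be checked against that example.

import copy
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

# Sketch only: the patching context manager and its arguments are assumptions,
# see the example "Export Tiny-LLM with patches" for the exact API.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        (),
        kwargs=copy.deepcopy(inputs),
        # use_dyn_not_str replaces string axis names by dynamic dimensions
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )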

doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")
(figure: plot export tiny llm — legend generated by doc.plot_legend)

Total running time of the script: (0 minutes 5.078 seconds)

Related examples

Export microsoft/phi-2

Export Tiny-LLM with patches

Export with DynamicCache and guessed dynamic shapes