Steal the forward method to guess the dynamic shapes (with Tiny-LLM)

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its forward method.
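
As a reminder of what such a specification looks like, here is a minimal sketch of torch.export.export() with a dynamic batch dimension on a toy model; the model and names are hypothetical and unrelated to Tiny-LLM.

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


# Mark the first dimension as dynamic so the exported program
# accepts any batch size, not only the one used at export time.
batch = torch.export.Dim("batch")
ep = torch.export.export(Toy(), (torch.randn(2, 8),), dynamic_shapes=({0: batch},))
print(ep)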

This example shows how to guess the inputs of this method even though the model is executed through generate().

We focus on the model Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.
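
Conceptually, such a function only needs the configuration, not the weights. A rough sketch, assuming a Llama-like configuration similar to Tiny-LLM's, could look like the following; the helper get_tiny_llm() used later wraps this kind of logic.

import transformers

# Randomly initialized model built from a configuration only,
# so no pretrained weights are downloaded (values are indicative).
config = transformers.LlamaConfig(
    vocab_size=32000,
    hidden_size=192,
    intermediate_size=1024,
    num_hidden_layers=1,
    num_attention_heads=2,
    num_key_value_heads=1,
)
untrained = transformers.AutoModelForCausalLM.from_config(config)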

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_tiny_llm


MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

We rewrite the forward method to print the cache dimensions.

def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # The DynamicCache is updated in place, so printing the inputs again
        # shows how its dimensions grew during the call.
        print("->", string_type((args, kwargs), with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)

Let’s run the model.

prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[830,830:A830.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13926869702104508]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.008778290737794762]]),input_ids:T7s1x1[830,830:A830.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13926869702104508]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.008778290737794762]]),input_ids:T7s1x1[635,635:A635.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.13467349399519055]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008463278700810406]]),input_ids:T7s1x1[635,635:A635.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.13467349399519055]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008463278700810406]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.8197526931762695,6.226877689361572:A-0.14064827534512556]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.5150525569915771:A0.00837615266990807]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.8197526931762695,6.226877689361572:A-0.14064827534512556]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.5150525569915771:A0.00837615266990807]]),input_ids:T7s1x1[2355,2355:A2355.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.8197526931762695,6.226877689361572:A-0.1338236682991515]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.575259804725647:A0.00714091036518817]]),input_ids:T7s1x1[2355,2355:A2355.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.8197526931762695,6.226877689361572:A-0.1338236682991515]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.575259804725647:A0.00714091036518817]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.8197526931762695,6.565155982971191:A-0.13599621835535985]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.575259804725647:A0.007871791164327591]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.8197526931762695,6.565155982971191:A-0.13599621835535985]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.575259804725647:A0.007871791164327591]]),input_ids:T7s1x1[2107,2107:A2107.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.8197526931762695,6.565155982971191:A-0.13278416968655837]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.575259804725647:A0.0076829137735568934]]),input_ids:T7s1x1[2107,2107:A2107.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.8197526931762695,6.565155982971191:A-0.13278416968655837]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.575259804725647:A0.0076829137735568934]]),input_ids:T7s1x1[5376,5376:A5376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.8197526931762695,6.565155982971191:A-0.1292427302262998]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.575259804725647:A0.005242380537907189]]),input_ids:T7s1x1[5376,5376:A5376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.8197526931762695,6.565155982971191:A-0.1292427302262998]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.575259804725647:A0.005242380537907189]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.8197526931762695,6.565155982971191:A-0.12225772721065671]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.575259804725647:A0.005362754381385078]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.8197526931762695,6.565155982971191:A-0.12225772721065671]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.575259804725647:A0.005362754381385078]]),input_ids:T7s1x1[7458,7458:A7458.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.8197526931762695,6.565155982971191:A-0.12251733124104053]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.575259804725647:A0.00661465932787406]]),input_ids:T7s1x1[7458,7458:A7458.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.8197526931762695,6.565155982971191:A-0.12251733124104053]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.575259804725647:A0.00661465932787406]]),input_ids:T7s1x1[856,856:A856.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.8197526931762695,6.565155982971191:A-0.1233257284622798]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.575259804725647:A0.0066723718805153315]]),input_ids:T7s1x1[856,856:A856.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.8197526931762695,6.565155982971191:A-0.1233257284622798]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.575259804725647:A0.0066723718805153315]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.8197526931762695,6.565155982971191:A-0.12438850456218413]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007809057273613705]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.8197526931762695,6.565155982971191:A-0.12438850456218413]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007809057273613705]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.8197526931762695,6.565155982971191:A-0.12379566539572504]], value_cache=#1[T1s1x1x20x96[-0.7138619422912598,0.7704185843467712:A0.006257893364750089]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.8197526931762695,6.565155982971191:A-0.12379566539572504]], value_cache=#1[T1s1x1x20x96[-0.7138619422912598,0.7704185843467712:A0.006257893364750089]]),input_ids:T7s1x1[29956,29956:A29956.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.8197526931762695,6.565155982971191:A-0.11858383356172006]], value_cache=#1[T1s1x1x21x96[-0.7138619422912598,0.7704185843467712:A0.006566355677001584]]),input_ids:T7s1x1[29956,29956:A29956.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.8197526931762695,6.565155982971191:A-0.11858383356172006]], value_cache=#1[T1s1x1x21x96[-0.7138619422912598,0.7704185843467712:A0.006566355677001584]]),input_ids:T7s1x1[1639,1639:A1639.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.8197526931762695,6.565155982971191:A-0.11352129460079374]], value_cache=#1[T1s1x1x22x96[-0.7138619422912598,0.7704185843467712:A0.004946331530976694]]),input_ids:T7s1x1[1639,1639:A1639.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.8197526931762695,6.565155982971191:A-0.11352129460079374]], value_cache=#1[T1s1x1x22x96[-0.7138619422912598,0.7704185843467712:A0.004946331530976694]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.987462043762207,7.052847862243652:A-0.11419484736873436]], value_cache=#1[T1s1x1x23x96[-0.7138619422912598,0.7704185843467712:A0.005234265772050212]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.987462043762207,7.052847862243652:A-0.11419484736873436]], value_cache=#1[T1s1x1x23x96[-0.7138619422912598,0.7704185843467712:A0.005234265772050212]]),input_ids:T7s1x1[510,510:A510.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.987462043762207,7.052847862243652:A-0.11391830670491901]], value_cache=#1[T1s1x1x24x96[-0.7138619422912598,0.7704185843467712:A0.00490483474958915]]),input_ids:T7s1x1[510,510:A510.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.987462043762207,7.052847862243652:A-0.11391830670491901]], value_cache=#1[T1s1x1x24x96[-0.7138619422912598,0.7704185843467712:A0.00490483474958915]]),input_ids:T7s1x1[448,448:A448.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.987462043762207,7.052847862243652:A-0.11251380358313327]], value_cache=#1[T1s1x1x25x96[-0.7138619422912598,0.7704185843467712:A0.004670132180078023]]),input_ids:T7s1x1[448,448:A448.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.987462043762207,7.052847862243652:A-0.11251380358313327]], value_cache=#1[T1s1x1x25x96[-0.7138619422912598,0.7704185843467712:A0.004670132180078023]]),input_ids:T7s1x1[1724,1724:A1724.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.987462043762207,7.052847862243652:A-0.11230627737766041]], value_cache=#1[T1s1x1x26x96[-0.7138619422912598,0.7704185843467712:A0.0046593989174174896]]),input_ids:T7s1x1[1724,1724:A1724.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.987462043762207,7.052847862243652:A-0.11230627737766041]], value_cache=#1[T1s1x1x26x96[-0.7138619422912598,0.7704185843467712:A0.0046593989174174896]]),input_ids:T7s1x1[2253,2253:A2253.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.987462043762207,7.052847862243652:A-0.10919084175403429]], value_cache=#1[T1s1x1x27x96[-0.7138619422912598,0.7704185843467712:A0.004745883260943909]]),input_ids:T7s1x1[2253,2253:A2253.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.987462043762207,7.052847862243652:A-0.10919084175403429]], value_cache=#1[T1s1x1x27x96[-0.7138619422912598,0.7704185843467712:A0.004745883260943909]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.987462043762207,7.052847862243652:A-0.1025098547061134]], value_cache=#1[T1s1x1x28x96[-0.7138619422912598,0.7704185843467712:A0.005613603910806627]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.987462043762207,7.052847862243652:A-0.1025098547061134]], value_cache=#1[T1s1x1x28x96[-0.7138619422912598,0.7704185843467712:A0.005613603910806627]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.72218656539917,7.052847862243652:A-0.09692871276567383]], value_cache=#1[T1s1x1x29x96[-0.7138619422912598,0.7704185843467712:A0.0060061092495356854]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.72218656539917,7.052847862243652:A-0.09692871276567383]], value_cache=#1[T1s1x1x29x96[-0.7138619422912598,0.7704185843467712:A0.0060061092495356854]]),input_ids:T7s1x1[15649,15649:A15649.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.72218656539917,7.052847862243652:A-0.09255540540483101]], value_cache=#1[T1s1x1x30x96[-0.7138619422912598,0.7704185843467712:A0.005399935928891056]]),input_ids:T7s1x1[15649,15649:A15649.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.72218656539917,7.052847862243652:A-0.09255540540483101]], value_cache=#1[T1s1x1x30x96[-0.7138619422912598,0.7704185843467712:A0.005399935928891056]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.72218656539917,7.052847862243652:A-0.08525467339006687]], value_cache=#1[T1s1x1x31x96[-0.7138619422912598,0.7704185843467712:A0.005672597131955803]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.72218656539917,7.052847862243652:A-0.08525467339006687]], value_cache=#1[T1s1x1x31x96[-0.7138619422912598,0.7704185843467712:A0.005672597131955803]]),input_ids:T7s1x1[367,367:A367.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.72218656539917,7.052847862243652:A-0.08051488317592732]], value_cache=#1[T1s1x1x32x96[-0.7138619422912598,0.7704185843467712:A0.005623818974484607]]),input_ids:T7s1x1[367,367:A367.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.72218656539917,7.052847862243652:A-0.08051488317592732]], value_cache=#1[T1s1x1x32x96[-0.7138619422912598,0.7704185843467712:A0.005623818974484607]]),input_ids:T7s1x1[596,596:A596.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.72218656539917,7.052847862243652:A-0.07817058410283709]], value_cache=#1[T1s1x1x33x96[-0.7138619422912598,0.7704185843467712:A0.0061109000237391025]]),input_ids:T7s1x1[596,596:A596.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.72218656539917,7.052847862243652:A-0.07817058410283709]], value_cache=#1[T1s1x1x33x96[-0.7138619422912598,0.7704185843467712:A0.0061109000237391025]]),input_ids:T7s1x1[1914,1914:A1914.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.72218656539917,7.052847862243652:A-0.07948427112209103]], value_cache=#1[T1s1x1x34x96[-0.7138619422912598,0.7704185843467712:A0.005882191799989467]]),input_ids:T7s1x1[1914,1914:A1914.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.72218656539917,7.052847862243652:A-0.07948427112209103]], value_cache=#1[T1s1x1x34x96[-0.7138619422912598,0.7704185843467712:A0.005882191799989467]]),input_ids:T7s1x1[3699,3699:A3699.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.72218656539917,7.052847862243652:A-0.07514602015667368]], value_cache=#1[T1s1x1x35x96[-0.7138619422912598,0.7704185843467712:A0.005620337307131454]]),input_ids:T7s1x1[3699,3699:A3699.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.72218656539917,7.052847862243652:A-0.07514602015667368]], value_cache=#1[T1s1x1x35x96[-0.7138619422912598,0.7704185843467712:A0.005620337307131454]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.72218656539917,7.052847862243652:A-0.07634669775328533]], value_cache=#1[T1s1x1x36x96[-0.7138619422912598,0.7704185843467712:A0.005670651629194018]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.72218656539917,7.052847862243652:A-0.07634669775328533]], value_cache=#1[T1s1x1x36x96[-0.7138619422912598,0.7704185843467712:A0.005670651629194018]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.72218656539917,7.052847862243652:A-0.07316075543306536]], value_cache=#1[T1s1x1x37x96[-0.7138619422912598,0.7704185843467712:A0.006281428459469434]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.72218656539917,7.052847862243652:A-0.07316075543306536]], value_cache=#1[T1s1x1x37x96[-0.7138619422912598,0.7704185843467712:A0.006281428459469434]]),input_ids:T7s1x1[6295,6295:A6295.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.72218656539917,7.052847862243652:A-0.06949802690050683]], value_cache=#1[T1s1x1x38x96[-0.7138619422912598,0.7704185843467712:A0.006212282752589441]]),input_ids:T7s1x1[6295,6295:A6295.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.72218656539917,7.052847862243652:A-0.06949802690050683]], value_cache=#1[T1s1x1x38x96[-0.7138619422912598,0.7704185843467712:A0.006212282752589441]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.72218656539917,7.052847862243652:A-0.07012538888073665]], value_cache=#1[T1s1x1x39x96[-0.7138619422912598,0.7704185843467712:A0.0064081840467255635]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.72218656539917,7.052847862243652:A-0.07012538888073665]], value_cache=#1[T1s1x1x39x96[-0.7138619422912598,0.7704185843467712:A0.0064081840467255635]]),input_ids:T7s1x1[2020,2020:A2020.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.72218656539917,7.052847862243652:A-0.06562370558746504]], value_cache=#1[T1s1x1x40x96[-0.7138619422912598,0.7704185843467712:A0.006578563628465872]]),input_ids:T7s1x1[2020,2020:A2020.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.72218656539917,7.052847862243652:A-0.06562370558746504]], value_cache=#1[T1s1x1x40x96[-0.7138619422912598,0.7704185843467712:A0.006578563628465872]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.72218656539917,7.052847862243652:A-0.0628541645154012]], value_cache=#1[T1s1x1x41x96[-0.7138619422912598,0.7704185843467712:A0.005612376343678323]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.72218656539917,7.052847862243652:A-0.0628541645154012]], value_cache=#1[T1s1x1x41x96[-0.7138619422912598,0.7704185843467712:A0.005612376343678323]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.72218656539917,7.052847862243652:A-0.06273026103340343]], value_cache=#1[T1s1x1x42x96[-0.7138619422912598,0.7704185843467712:A0.005540032730943985]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.72218656539917,7.052847862243652:A-0.06273026103340343]], value_cache=#1[T1s1x1x42x96[-0.7138619422912598,0.7704185843467712:A0.005540032730943985]]),input_ids:T7s1x1[1925,1925:A1925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.72218656539917,7.052847862243652:A-0.061756996103297775]], value_cache=#1[T1s1x1x43x96[-0.7138619422912598,0.7704185843467712:A0.005460804806473387]]),input_ids:T7s1x1[1925,1925:A1925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.72218656539917,7.052847862243652:A-0.061756996103297775]], value_cache=#1[T1s1x1x43x96[-0.7138619422912598,0.7704185843467712:A0.005460804806473387]]),input_ids:T7s1x1[1283,1283:A1283.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.72218656539917,7.052847862243652:A-0.060946521573455495]], value_cache=#1[T1s1x1x44x96[-0.7138619422912598,0.7704185843467712:A0.005347465156880775]]),input_ids:T7s1x1[1283,1283:A1283.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.72218656539917,7.052847862243652:A-0.060946521573455495]], value_cache=#1[T1s1x1x44x96[-0.7138619422912598,0.7704185843467712:A0.005347465156880775]]),input_ids:T7s1x1[777,777:A777.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.72218656539917,7.052847862243652:A-0.05882268034876217]], value_cache=#1[T1s1x1x45x96[-0.7138619422912598,0.7704185843467712:A0.00515750024143455]]),input_ids:T7s1x1[777,777:A777.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.72218656539917,7.052847862243652:A-0.05882268034876217]], value_cache=#1[T1s1x1x45x96[-0.7138619422912598,0.7704185843467712:A0.00515750024143455]]),input_ids:T7s1x1[901,901:A901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.72218656539917,7.052847862243652:A-0.05760709743605032]], value_cache=#1[T1s1x1x46x96[-0.7138619422912598,0.7704185843467712:A0.0045914474782276175]]),input_ids:T7s1x1[901,901:A901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.72218656539917,7.052847862243652:A-0.05760709743605032]], value_cache=#1[T1s1x1x46x96[-0.7138619422912598,0.7704185843467712:A0.0045914474782276175]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.72218656539917,7.052847862243652:A-0.05824511129190177]], value_cache=#1[T1s1x1x47x96[-0.7138619422912598,0.7704185843467712:A0.00465187738084796]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.72218656539917,7.052847862243652:A-0.05824511129190177]], value_cache=#1[T1s1x1x47x96[-0.7138619422912598,0.7704185843467712:A0.00465187738084796]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.72218656539917,7.052847862243652:A-0.058899736090880755]], value_cache=#1[T1s1x1x48x96[-0.7138619422912598,0.7704185843467712:A0.004711315192932059]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.72218656539917,7.052847862243652:A-0.058899736090880755]], value_cache=#1[T1s1x1x48x96[-0.7138619422912598,0.7704185843467712:A0.004711315192932059]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.72218656539917,7.052847862243652:A-0.05637273617952703]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A0.004064715622091023]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-- prompt Continue: it rains...
-- answer Continue: it rains... Really I got a great deal of trouble...
"Winter.com - What better way to buy, be your own house?
So, why have we put off some more? I've

Let’s restore the forward method as it was.

model.forward = keep_model_forward
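
From these traces we can tell which dimensions vary across calls: the batch size, the sequence length of input_ids, and the cache length stored in the DynamicCache. A hedged sketch of the corresponding dynamic_shapes specification (written with torch.export.Dim, assuming a recent torch for Dim.DYNAMIC) could look like this; get_tiny_llm() below returns an equivalent specification.

batch = torch.export.Dim("batch")
seq_length = torch.export.Dim("seq_length")
cache_length = torch.export.Dim("cache_length")

# One entry per forward argument we intend to feed at export time.
# The DynamicCache is specified as nested lists: [key_cache], [value_cache].
guessed_dynamic_shapes = {
    "input_ids": {0: batch, 1: seq_length},
    "attention_mask": {0: batch, 1: torch.export.Dim.DYNAMIC},
    "position_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
    "past_key_values": [
        [{0: batch, 2: cache_length}],
        [{0: batch, 2: cache_length}],
    ],
}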

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful for creating a unit test that ensures a specific architecture can still be exported despite the many changes brought to torch or transformers.
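
As an illustration, such a unit test could be as small as the following sketch, built on the get_tiny_llm() helper introduced just below (the test name is hypothetical):

import copy
import unittest
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm


class TestTinyLlmExport(unittest.TestCase):
    def test_export_tiny_llm(self):
        data = get_tiny_llm()
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        # deepcopy because the DynamicCache is modified in place by forward
        ep = torch.export.export(
            model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, strict=False
        )
        self.assertIsNotNone(ep)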

Let’s create an untrained model from the provided configuration file config.json with onnx_diagnostic.torch_models.llms.get_tiny_llm(), then use it.

experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)

Before running it, we make a copy of the inputs because the cache gets modified by the execution. After that, it is no longer consistent with the previous input_ids and mask.

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

print("result type", string_type(expected_output, with_shape=True))
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.

ExportedProgram

try:
    ep = torch.export.export(
        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
[_catch_produce_guards_and_solve_constraints] ERRORproduce_guards_and_solve_constraints failed, use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping
fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f3a2d8853a0>
dynamic_shapes={'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>}, 'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]]}
equalities_inputs=EqualityConstraint(warn_only=False, source_pairs=[(TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2))], derived_equalities=[], phantom_symbols=[], relaxed_sources={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), 
index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1)}, _parents={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), 
prop=<TensorProperty.SIZE: 0>, idx=2)}, _defs={})
original_signature=(input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Union[transformers.cache_utils.Cache, List[torch.FloatTensor], NoneType] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[transformers.models.llama.modeling_llama.KwargsForCausalLM]) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]
_is_torch_jit_trace=False
exc=Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['position_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
  batch = 2
  L['args'][1]['position_ids'].size()[1] = seq_length
gm=<lambda>()



def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1):
    embedding = torch.ops.aten.embedding.default(arg0_1, arg13_1);  arg0_1 = None
    sym_size_int = torch.ops.aten.sym_size.int(arg16_1, 2)
    sym_size_int_1 = torch.ops.aten.sym_size.int(arg13_1, 1)
    add = sym_size_int + sym_size_int_1
    arange = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False);  sym_size_int = add = None
    sym_size_int_2 = torch.ops.aten.sym_size.int(arg14_1, 1)
    full = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    triu = torch.ops.aten.triu.default(full, 1);  full = None
    arange_1 = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False);  sym_size_int_2 = None
    reshape = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
    gt = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
    mul_ = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
    unsqueeze = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
    slice_1 = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
    sym_size_int_5 = torch.ops.aten.sym_size.int(arg13_1, 0);  arg13_1 = None
    expand = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]);  slice_2 = None
    clone = torch.ops.aten.clone.default(expand);  expand = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
    slice_6 = torch.ops.aten.slice.Tensor(arg14_1, 0, 0, 9223372036854775807);  arg14_1 = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
    slice_7 = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
    to = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
    add_2 = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
    eq_7 = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
    slice_8 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
    slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
    masked_fill = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
    slice_11 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_12 = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
    slice_13 = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
    copy_ = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg12_1, 0);  arg12_1 = None
    slice_14 = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
    to_1 = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
    sym_size_int_13 = torch.ops.aten.sym_size.int(arg15_1, 0)
    expand_1 = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]);  to_1 = sym_size_int_13 = None
    slice_15 = torch.ops.aten.slice.Tensor(arg15_1, 0, 0, 9223372036854775807);  arg15_1 = None
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
    slice_16 = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
    to_2 = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', torch.bfloat16, False, False)
    to_3 = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
    to_4 = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
    to_5 = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
    matmul = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
    transpose = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None
    cat = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None
    cos = torch.ops.aten.cos.default(cat)
    sin = torch.ops.aten.sin.default(cat);  cat = None
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
    mul = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None
    mul_1 = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
    to_6 = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
    to_7 = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    to_8 = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None
    pow_1 = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
    add_3 = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
    rsqrt = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
    mul_2 = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None
    to_9 = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
    mul_3 = torch.ops.aten.mul.Tensor(arg8_1, to_9);  arg8_1 = to_9 = None
    linear = torch.ops.aten.linear.default(mul_3, arg1_1);  arg1_1 = None
    view = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear = None
    transpose_1 = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    linear_1 = torch.ops.aten.linear.default(mul_3, arg2_1);  arg2_1 = None
    view_1 = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear_1 = None
    transpose_2 = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    linear_2 = torch.ops.aten.linear.default(mul_3, arg3_1);  mul_3 = arg3_1 = None
    view_2 = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear_2 = sym_size_int_5 = None
    transpose_3 = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    unsqueeze_7 = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
    unsqueeze_8 = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
    mul_4 = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
    slice_17 = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
    slice_18 = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
    neg = torch.ops.aten.neg.default(slice_18);  slice_18 = None
    cat_1 = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
    mul_5 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
    add_4 = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
    mul_6 = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
    slice_19 = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
    slice_20 = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
    neg_1 = torch.ops.aten.neg.default(slice_20);  slice_20 = None
    cat_2 = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
    mul_7 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
    add_5 = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
    cat_3 = torch.ops.aten.cat.default([arg16_1, add_5], -2);  arg16_1 = add_5 = None
    cat_4 = torch.ops.aten.cat.default([arg17_1, transpose_3], -2);  arg17_1 = transpose_3 = None
    slice_21 = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
    slice_22 = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
    unsqueeze_9 = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
    sym_size_int_16 = torch.ops.aten.sym_size.int(cat_3, 2)
    slice_23 = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
    slice_24 = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
    expand_2 = torch.ops.aten.expand.default(slice_24, [2, 1, 2, sym_size_int_16, 96]);  slice_24 = None
    reshape_1 = torch.ops.aten.reshape.default(expand_2, [2, 2, sym_size_int_16, 96]);  expand_2 = sym_size_int_16 = None
    slice_25 = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
    slice_26 = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
    unsqueeze_10 = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
    sym_size_int_17 = torch.ops.aten.sym_size.int(cat_4, 2)
    slice_27 = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
    slice_28 = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
    expand_3 = torch.ops.aten.expand.default(slice_28, [2, 1, 2, sym_size_int_17, 96]);  slice_28 = None
    reshape_2 = torch.ops.aten.reshape.default(expand_3, [2, 2, sym_size_int_17, 96]);  expand_3 = sym_size_int_17 = None
    slice_29 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
    slice_30 = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
    slice_31 = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
    contiguous = torch.ops.aten.contiguous.default(add_4);  add_4 = None
    contiguous_1 = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
    contiguous_2 = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
    transpose_4 = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
    contiguous_3 = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None
    reshape_3 = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_1, -1]);  contiguous_3 = sym_size_int_1 = None
    linear_3 = torch.ops.aten.linear.default(reshape_3, arg4_1);  reshape_3 = arg4_1 = None
    add_7 = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None
    to_10 = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None
    pow_2 = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
    mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None
    add_8 = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
    rsqrt_1 = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
    mul_8 = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None
    to_11 = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
    mul_9 = torch.ops.aten.mul.Tensor(arg9_1, to_11);  arg9_1 = to_11 = None
    linear_4 = torch.ops.aten.linear.default(mul_9, arg5_1);  arg5_1 = None
    silu = torch.ops.aten.silu.default(linear_4);  linear_4 = None
    linear_5 = torch.ops.aten.linear.default(mul_9, arg6_1);  mul_9 = arg6_1 = None
    mul_10 = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None
    linear_6 = torch.ops.aten.linear.default(mul_10, arg7_1);  mul_10 = arg7_1 = None
    add_9 = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None
    to_12 = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None
    pow_3 = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
    mean_2 = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None
    add_10 = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
    rsqrt_2 = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
    mul_11 = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None
    to_13 = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
    mul_12 = torch.ops.aten.mul.Tensor(arg10_1, to_13);  arg10_1 = to_13 = None
    slice_32 = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
    slice_33 = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
    slice_34 = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None
    linear_7 = torch.ops.aten.linear.default(slice_34, arg11_1);  slice_34 = arg11_1 = None
    return (linear_7, cat_3, cat_4)

# To see more debug info, please use `graph_module.print_readable()`
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[2, s1]", attention_mask: "i64[2, s1 + s7]", position_ids: "i64[2, s1]", past_key_values_key_cache_0: "f32[2, 1, s7, 96]", past_key_values_value_cache_0: "f32[2, 1, s7, 96]"):
             #
            sym_size_int_19: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_20: "Sym(s7)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[2, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s1 + s7)" = sym_size_int_20 + sym_size_int_19

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
            arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_20 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s1, s1 + s7]" = torch.ops.aten.full.default([sym_size_int_19, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s1, s1 + s7]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s1 + s7]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s1, s1 + s7]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s1, s1 + s7]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze: "f32[1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            slice_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
            slice_2: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            sym_size_int_5: "Sym(2)" = torch.ops.aten.sym_size.int(input_ids, 0);  input_ids = None
            expand: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]);  slice_2 = None
            clone: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_4: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
            slice_5: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
            slice_6: "i64[2, s1 + s7]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_2: "i64[2, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
            unsqueeze_3: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
            slice_7: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
            to: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
            add_2: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
            eq_7: "b8[2, 1, s1, s1 + s7]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_8: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_9: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
            slice_10: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
            masked_fill: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
            slice_11: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_12: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
            slice_13: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
            copy_: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_8: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_9: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[2, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear = None
            transpose_1: "f32[2, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear_1 = None
            transpose_2: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear_2 = sym_size_int_5 = None
            transpose_3: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_7: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_8: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_4: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
            slice_17: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_18: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[2, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_1: "f32[2, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
            mul_5: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
            add_4: "f32[2, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
            slice_19: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_20: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[2, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_2: "f32[2, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
            mul_7: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
            add_5: "f32[2, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_21: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_22: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
            unsqueeze_9: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
            slice_23: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
            slice_24: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
            expand_2: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_24, [2, 1, 2, add, 96]);  slice_24 = None
            reshape_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_2, [2, 2, add, 96]);  expand_2 = None
            slice_25: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_26: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
            unsqueeze_10: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
            slice_27: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_28: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
            expand_3: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_28, [2, 1, 2, add, 96]);  slice_28 = None
            reshape_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_3, [2, 2, add, 96]);  expand_3 = add = None
            slice_29: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
            slice_30: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
            slice_31: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
            contiguous: "f32[2, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[2, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
            transpose_4: "f32[2, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[2, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[2, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_19, -1]);  contiguous_3 = sym_size_int_19 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[2, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_10: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_11: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[2, s1, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[2, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_12: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_13: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13);  p_model_norm_weight = to_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_32: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
            slice_33: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
            slice_34: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[2, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight);  slice_34 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", position_ids: "i64[2, s1]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
                unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
                sym_size_int_13: "Sym(2)" = torch.ops.aten.sym_size.int(position_ids, 0)
                expand_1: "f32[2, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]);  to_1 = sym_size_int_13 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_15: "i64[2, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_6: "i64[2, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
                slice_16: "i64[2, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
                to_2: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                cos: "f32[2, s1, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                sin: "f32[2, s1, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_6: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_7: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[2, 48, 1]", to_2: "f32[2, 1, s1]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
                    to_3: "f32[2, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
                    to_4: "f32[2, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
                    to_5: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[2, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[2, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[2, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                    cos: "f32[2, s1, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                    sin: "f32[2, s1, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s1: VR[2, 4096], s1 + s7: VR[4, 8192], s7: VR[1, 4096]}
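
The same information is also available programmatically on the exported program. A minimal sketch, assuming the exported program printed above is bound to a variable named ep, as in the export call further below:

print(ep.range_constraints)  # e.g. {s1: VR[2, 4096], s1 + s7: VR[4, 8192], s7: VR[1, 4096]}
print(ep.graph_signature)    # the "Graph signature" section printed above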

Back to the original model

Let’s use the same dummy inputs, but this time with the downloaded model. The dummy inputs and dynamic shapes are created by the function onnx_diagnostic.torch_models.llms.get_tiny_llm().

data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]

Let’s print the inputs and the dynamic shapes.

print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
                    1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
 'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
               1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>},
 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}],
                     [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]],
 'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
                  1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}}
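
The classes batch, seq_length and cache_length printed above are simply named dimensions produced by get_tiny_llm(). For comparison, here is a minimal sketch of an equivalent specification written directly with torch.export.Dim; the variable name manual_dynamic_shapes and the 1024/4096 bounds are illustrative assumptions, not values taken from get_tiny_llm().

import torch

# Named dimensions with hypothetical bounds (assumptions for illustration).
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
cache_length = torch.export.Dim("cache_length", min=1, max=4096)

manual_dynamic_shapes = {
    "input_ids": {0: batch, 1: seq_length},
    # Dim.DYNAMIC (torch>=2.6) corresponds to the _DimHint printed above.
    "attention_mask": {0: batch, 1: torch.export.Dim.DYNAMIC},
    "position_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
    "past_key_values": [
        [{0: batch, 2: cache_length}],  # key_cache
        [{0: batch, 2: cache_length}],  # value_cache
    ],
}

Either form can be passed as dynamic_shapes to torch.export.export; the rest of this example keeps the dictionary returned by get_tiny_llm().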

And let’s finally export.

cloned_inputs = copy.deepcopy(inputs)  # deep copy: the forward pass mutates the DynamicCache in place

try:
    ep = torch.export.export(
        model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
[_catch_produce_guards_and_solve_constraints] ERRORproduce_guards_and_solve_constraints failed, use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping
fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f3a2d3946e0>
dynamic_shapes={'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>}, 'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]]}
equalities_inputs=EqualityConstraint(warn_only=False, source_pairs=[(TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2))], derived_equalities=[], phantom_symbols=[], relaxed_sources={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), 
index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1)}, _parents={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), 
prop=<TensorProperty.SIZE: 0>, idx=2)}, _defs={})
original_signature=(input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Union[transformers.cache_utils.Cache, List[torch.FloatTensor], NoneType] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[transformers.models.llama.modeling_llama.KwargsForCausalLM]) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]
_is_torch_jit_trace=False
exc=Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['position_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
  - Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
  batch = 2
  L['args'][1]['position_ids'].size()[1] = seq_length
gm=<lambda>()



def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1):
    embedding = torch.ops.aten.embedding.default(arg0_1, arg13_1);  arg0_1 = None
    sym_size_int = torch.ops.aten.sym_size.int(arg16_1, 2)
    sym_size_int_1 = torch.ops.aten.sym_size.int(arg13_1, 1)
    add = sym_size_int + sym_size_int_1
    arange = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False);  sym_size_int = add = None
    sym_size_int_2 = torch.ops.aten.sym_size.int(arg14_1, 1)
    full = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    triu = torch.ops.aten.triu.default(full, 1);  full = None
    arange_1 = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False);  sym_size_int_2 = None
    reshape = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
    gt = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
    mul_ = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
    unsqueeze = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
    slice_1 = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
    slice_2 = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
    sym_size_int_5 = torch.ops.aten.sym_size.int(arg13_1, 0);  arg13_1 = None
    expand = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]);  slice_2 = None
    clone = torch.ops.aten.clone.default(expand);  expand = None
    slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
    slice_5 = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
    slice_6 = torch.ops.aten.slice.Tensor(arg14_1, 0, 0, 9223372036854775807);  arg14_1 = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
    slice_7 = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
    to = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
    add_2 = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
    eq_7 = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
    slice_8 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
    slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
    masked_fill = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
    slice_11 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
    slice_12 = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
    slice_13 = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
    copy_ = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None
    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg12_1, 0);  arg12_1 = None
    slice_14 = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
    to_1 = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
    sym_size_int_13 = torch.ops.aten.sym_size.int(arg15_1, 0)
    expand_1 = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]);  to_1 = sym_size_int_13 = None
    slice_15 = torch.ops.aten.slice.Tensor(arg15_1, 0, 0, 9223372036854775807);  arg15_1 = None
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
    slice_16 = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
    to_2 = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', torch.bfloat16, False, False)
    to_3 = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
    to_4 = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
    to_5 = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
    matmul = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
    transpose = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None
    cat = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None
    cos = torch.ops.aten.cos.default(cat)
    sin = torch.ops.aten.sin.default(cat);  cat = None
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
    mul = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None
    mul_1 = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None
    to_6 = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
    to_7 = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
    to_8 = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None
    pow_1 = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None
    add_3 = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
    rsqrt = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
    mul_2 = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None
    to_9 = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
    mul_3 = torch.ops.aten.mul.Tensor(arg8_1, to_9);  arg8_1 = to_9 = None
    linear = torch.ops.aten.linear.default(mul_3, arg1_1);  arg1_1 = None
    view = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear = None
    transpose_1 = torch.ops.aten.transpose.int(view, 1, 2);  view = None
    linear_1 = torch.ops.aten.linear.default(mul_3, arg2_1);  arg2_1 = None
    view_1 = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear_1 = None
    transpose_2 = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None
    linear_2 = torch.ops.aten.linear.default(mul_3, arg3_1);  mul_3 = arg3_1 = None
    view_2 = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_1, -1, 96]);  linear_2 = sym_size_int_5 = None
    transpose_3 = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None
    unsqueeze_7 = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
    unsqueeze_8 = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
    mul_4 = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
    slice_17 = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
    slice_18 = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
    neg = torch.ops.aten.neg.default(slice_18);  slice_18 = None
    cat_1 = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
    mul_5 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
    add_4 = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
    mul_6 = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
    slice_19 = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
    slice_20 = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
    neg_1 = torch.ops.aten.neg.default(slice_20);  slice_20 = None
    cat_2 = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
    mul_7 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
    add_5 = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None
    cat_3 = torch.ops.aten.cat.default([arg16_1, add_5], -2);  arg16_1 = add_5 = None
    cat_4 = torch.ops.aten.cat.default([arg17_1, transpose_3], -2);  arg17_1 = transpose_3 = None
    slice_21 = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
    slice_22 = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
    unsqueeze_9 = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
    sym_size_int_16 = torch.ops.aten.sym_size.int(cat_3, 2)
    slice_23 = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
    slice_24 = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
    expand_2 = torch.ops.aten.expand.default(slice_24, [2, 1, 2, sym_size_int_16, 96]);  slice_24 = None
    reshape_1 = torch.ops.aten.reshape.default(expand_2, [2, 2, sym_size_int_16, 96]);  expand_2 = sym_size_int_16 = None
    slice_25 = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
    slice_26 = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
    unsqueeze_10 = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
    sym_size_int_17 = torch.ops.aten.sym_size.int(cat_4, 2)
    slice_27 = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
    slice_28 = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
    expand_3 = torch.ops.aten.expand.default(slice_28, [2, 1, 2, sym_size_int_17, 96]);  slice_28 = None
    reshape_2 = torch.ops.aten.reshape.default(expand_3, [2, 2, sym_size_int_17, 96]);  expand_3 = sym_size_int_17 = None
    slice_29 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
    slice_30 = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
    slice_31 = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
    contiguous = torch.ops.aten.contiguous.default(add_4);  add_4 = None
    contiguous_1 = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
    contiguous_2 = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
    scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
    transpose_4 = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
    contiguous_3 = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None
    reshape_3 = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_1, -1]);  contiguous_3 = sym_size_int_1 = None
    linear_3 = torch.ops.aten.linear.default(reshape_3, arg4_1);  reshape_3 = arg4_1 = None
    add_7 = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None
    to_10 = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None
    pow_2 = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
    mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None
    add_8 = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
    rsqrt_1 = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
    mul_8 = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None
    to_11 = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
    mul_9 = torch.ops.aten.mul.Tensor(arg9_1, to_11);  arg9_1 = to_11 = None
    linear_4 = torch.ops.aten.linear.default(mul_9, arg5_1);  arg5_1 = None
    silu = torch.ops.aten.silu.default(linear_4);  linear_4 = None
    linear_5 = torch.ops.aten.linear.default(mul_9, arg6_1);  mul_9 = arg6_1 = None
    mul_10 = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None
    linear_6 = torch.ops.aten.linear.default(mul_10, arg7_1);  mul_10 = arg7_1 = None
    add_9 = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None
    to_12 = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None
    pow_3 = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
    mean_2 = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None
    add_10 = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
    rsqrt_2 = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
    mul_11 = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None
    to_13 = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
    mul_12 = torch.ops.aten.mul.Tensor(arg10_1, to_13);  arg10_1 = to_13 = None
    slice_32 = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
    slice_33 = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
    slice_34 = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None
    linear_7 = torch.ops.aten.linear.default(slice_34, arg11_1);  slice_34 = arg11_1 = None
    return (linear_7, cat_3, cat_4)

# To see more debug info, please use `graph_module.print_readable()`
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[2, s1]", attention_mask: "i64[2, s1 + s7]", position_ids: "i64[2, s1]", past_key_values_key_cache_0: "f32[2, 1, s7, 96]", past_key_values_value_cache_0: "f32[2, 1, s7, 96]"):
             #
            sym_size_int_19: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_20: "Sym(s7)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[2, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s1 + s7)" = sym_size_int_20 + sym_size_int_19

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
            arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_20 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s1, s1 + s7]" = torch.ops.aten.full.default([sym_size_int_19, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s1, s1 + s7]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s1 + s7]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s1, s1 + s7]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s1, s1 + s7]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze: "f32[1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            slice_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807);  unsqueeze_1 = None
            slice_2: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            sym_size_int_5: "Sym(2)" = torch.ops.aten.sym_size.int(input_ids, 0);  input_ids = None
            expand: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]);  slice_2 = None
            clone: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_4: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
            slice_5: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
            slice_6: "i64[2, s1 + s7]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_2: "i64[2, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
            unsqueeze_3: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2);  unsqueeze_2 = None
            slice_7: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807);  unsqueeze_3 = None
            to: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
            add_2: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
            eq_7: "b8[2, 1, s1, s1 + s7]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_8: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_9: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
            slice_10: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
            masked_fill: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
            slice_11: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_12: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
            slice_13: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
            copy_: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, position_ids);  submod_3 = b_model_rotary_emb_inv_freq = position_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_8: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_9: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[2, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear = None
            transpose_1: "f32[2, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear_1 = None
            transpose_2: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_19, -1, 96]);  linear_2 = sym_size_int_5 = None
            transpose_3: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_7: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_8: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_4: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
            slice_17: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_18: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[2, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_1: "f32[2, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
            mul_5: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8);  cat_1 = None
            add_4: "f32[2, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7);  unsqueeze_7 = None
            slice_19: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_20: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[2, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_2: "f32[2, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
            mul_7: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8);  cat_2 = unsqueeze_8 = None
            add_5: "f32[2, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_21: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_22: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
            unsqueeze_9: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
            slice_23: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807);  unsqueeze_9 = None
            slice_24: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
            expand_2: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_24, [2, 1, 2, add, 96]);  slice_24 = None
            reshape_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_2, [2, 2, add, 96]);  expand_2 = None
            slice_25: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_26: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
            unsqueeze_10: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
            slice_27: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_28: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
            expand_3: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_28, [2, 1, 2, add, 96]);  slice_28 = None
            reshape_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_3, [2, 2, add, 96]);  expand_3 = add = None
            slice_29: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
            slice_30: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
            slice_31: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
            contiguous: "f32[2, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[2, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
            transpose_4: "f32[2, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[2, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[2, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_19, -1]);  contiguous_3 = sym_size_int_19 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[2, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_10: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_11: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[2, s1, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[2, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_12: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_13: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13);  p_model_norm_weight = to_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_32: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
            slice_33: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
            slice_34: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[2, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight);  slice_34 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", position_ids: "i64[2, s1]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807);  unsqueeze_4 = None
                unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32);  unsqueeze_5 = None
                sym_size_int_13: "Sym(2)" = torch.ops.aten.sym_size.int(position_ids, 0)
                expand_1: "f32[2, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]);  to_1 = sym_size_int_13 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_15: "i64[2, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807);  position_ids = None
                unsqueeze_6: "i64[2, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
                slice_16: "i64[2, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807);  unsqueeze_6 = None
                to_2: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                cos: "f32[2, s1, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                sin: "f32[2, s1, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_6: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_7: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[2, 48, 1]", to_2: "f32[2, 1, s1]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
                    to_3: "f32[2, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
                    to_4: "f32[2, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
                    to_5: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[2, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[2, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[2, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                    cos: "f32[2, s1, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                    sin: "f32[2, s1, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature:
    # inputs
    p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
    p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
    p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
    p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
    p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
    p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
    p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
    p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
    p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
    p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
    p_model_norm_weight: PARAMETER target='model.norm.weight'
    p_lm_head_weight: PARAMETER target='lm_head.weight'
    b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
    input_ids: USER_INPUT
    attention_mask: USER_INPUT
    position_ids: USER_INPUT
    past_key_values_key_cache_0: USER_INPUT
    past_key_values_value_cache_0: USER_INPUT

    # outputs
    linear_7: USER_OUTPUT
    cat_3: USER_OUTPUT
    cat_4: USER_OUTPUT

Range constraints: {s1: VR[2, 4096], s1 + s7: VR[4, 8192], s7: VR[1, 4096]}
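
The range constraints list two free symbols: s1, the number of new tokens, and s7, the length of the cached sequence; the attention mask spans their sum s1 + s7, and the batch dimension was specialized to 2. As a minimal, hedged sketch of where such symbols come from (the names seq and cache and the exact mapping are illustrative assumptions, not necessarily the specification used earlier in this example), dynamic dimensions are declared through the dynamic_shapes argument of torch.export.export:

import torch
from torch.export import Dim

# Illustrative assumption: these names and bounds mirror the symbols s1 and s7
# reported in the range constraints above.
seq = Dim("seq", min=2, max=4096)      # shows up as a symbol such as s1
cache = Dim("cache", min=1, max=4096)  # shows up as a symbol such as s7

dynamic_shapes = {
    "input_ids": {1: seq},
    "position_ids": {1: seq},
    # The mask covers cache + new tokens; Dim.AUTO (torch>=2.6) lets the
    # exporter infer the s1 + s7 relationship on its own.
    "attention_mask": {1: Dim.AUTO},
    # Assumed flattening of the DynamicCache: one key and one value tensor
    # for the single layer of Tiny-LLM, dynamic along the sequence axis.
    "past_key_values": [[{2: cache}], [{2: cache}]],
}

# The export call would then look like (not executed here):
# ep = torch.export.export(model, (), kwargs=export_inputs,
#                          dynamic_shapes=dynamic_shapes, strict=False)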

If you run into an error, look at the example Export Tiny-LLM with patches.
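That example wraps the export call in a patch context so that transformers-specific objects such as DynamicCache can be handled by torch.export. A hedged sketch of what it looks like, assuming the context manager torch_export_patches from onnx_diagnostic and hypothetical names model, export_inputs and dynamic_shapes for the objects prepared above:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Sketch only: torch_export_patches and patch_transformers are described in
# the "Export Tiny-LLM with patches" example; refer to it for the exact API.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        (),
        kwargs=export_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )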

doc.plot_legend("Tiny-LLM fails", "torch.export.export", "tomato")
(figure: plot export tiny llm)

Total running time of the script: (0 minutes 2.703 seconds)

Related examples

Export Tiny-LLM with patches

Export with DynamicCache and dynamic shapes

Find and fix an export issue due to dynamic shapes

Gallery generated by Sphinx-Gallery