Steal method forward to guess the dynamic shapes

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its method forward.

This example shows how to guess the inputs of this method even though the model is executed through the method generate.

We focus on the model Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_tiny_llm


# Identifier of the pretrained model on the HuggingFace hub.
MODEL_NAME = "arnir0/Tiny-LLM"
# Load the tokenizer and the pretrained causal language model for MODEL_NAME.
# Both objects are reused by the snippets below.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

We replace the forward method with a wrapper that prints its inputs, including the cache dimensions, before and after each call.

def _forward_(*args, _f=None, **kwargs):
    """Call ``_f(*args, **kwargs)`` and trace its arguments around the call.

    Unless :func:`torch.export.export` is tracing the model, the positional
    and keyword arguments are printed (with shapes and min/max values) before
    the call, prefixed with ``<-``, and printed again after the call, prefixed
    with ``->``. The second print reveals in-place changes made by the call,
    such as a cache growing by one position per step.

    :param args: positional arguments forwarded to ``_f``
    :param _f: the callable to wrap, mandatory despite its default
    :param kwargs: keyword arguments forwarded to ``_f``
    :return: whatever ``_f`` returns
    """
    assert _f is not None

    def _trace(prefix):
        # Printing is skipped while exporting so it does not pollute the graph.
        if torch.compiler.is_exporting():
            return
        print(prefix, string_type((args, kwargs), with_shape=True, with_min_max=True))

    _trace("<-")
    res = _f(*args, **kwargs)
    _trace("->")
    return res


# Keep a reference to the original bound method so it can be restored later.
keep_model_forward = model.forward
# Replace forward with a wrapper that prints its arguments around each call.
# The original method is captured through the default value of ``_f``, so the
# lambda keeps calling it even if the global name is rebound afterwards.
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)

Let’s run the model.

# Tokenize a short prompt into a tensor of token ids.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

# Sample a continuation. generate() calls model.forward repeatedly, so every
# call goes through the tracing wrapper installed above and prints its inputs.
outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)

# Decode the first generated sequence back to text.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.12333908123679672]], value_cache=#1[T1s1x1x10x96[-0.7138619422912598,0.7704185843467712:A0.006287716961484572]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.12333908123679672]], value_cache=#1[T1s1x1x10x96[-0.7138619422912598,0.7704185843467712:A0.006287716961484572]]),input_ids:T7s1x1[29967,29967:A29967.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.509540557861328,6.348220348358154:A-0.11205155890958408]], value_cache=#1[T1s1x1x11x96[-0.7138619422912598,0.7704185843467712:A0.007200541126304884]]),input_ids:T7s1x1[29967,29967:A29967.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.509540557861328,6.348220348358154:A-0.11205155890958408]], value_cache=#1[T1s1x1x11x96[-0.7138619422912598,0.7704185843467712:A0.007200541126304884]]),input_ids:T7s1x1[273,273:A273.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.509540557861328,6.348220348358154:A-0.08577997010060143]], value_cache=#1[T1s1x1x12x96[-0.7138619422912598,0.7704185843467712:A0.007518486174666982]]),input_ids:T7s1x1[273,273:A273.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.509540557861328,6.348220348358154:A-0.08577997010060143]], value_cache=#1[T1s1x1x12x96[-0.7138619422912598,0.7704185843467712:A0.007518486174666982]]),input_ids:T7s1x1[293,293:A293.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.509540557861328,6.348220348358154:A-0.06358176507367319]], value_cache=#1[T1s1x1x13x96[-0.7138619422912598,0.7704185843467712:A0.00504805397781055]]),input_ids:T7s1x1[293,293:A293.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.509540557861328,6.348220348358154:A-0.06358176507367319]], value_cache=#1[T1s1x1x13x96[-0.7138619422912598,0.7704185843467712:A0.00504805397781055]]),input_ids:T7s1x1[1608,1608:A1608.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.509540557861328,6.348220348358154:A-0.055063005458751355]], value_cache=#1[T1s1x1x14x96[-0.7138619422912598,0.7704185843467712:A0.004413282569300593]]),input_ids:T7s1x1[1608,1608:A1608.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.509540557861328,6.348220348358154:A-0.055063005458751355]], value_cache=#1[T1s1x1x14x96[-0.7138619422912598,0.7704185843467712:A0.004413282569300593]]),input_ids:T7s1x1[376,376:A376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.509540557861328,6.348220348358154:A-0.06414192461678693]], value_cache=#1[T1s1x1x15x96[-0.7138619422912598,0.7704185843467712:A0.0043507608415994685]]),input_ids:T7s1x1[376,376:A376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.509540557861328,6.348220348358154:A-0.06414192461678693]], value_cache=#1[T1s1x1x15x96[-0.7138619422912598,0.7704185843467712:A0.0043507608415994685]]),input_ids:T7s1x1[3492,3492:A3492.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.637684345245361,6.348220348358154:A-0.06626327035784622]], value_cache=#1[T1s1x1x16x96[-0.7138619422912598,0.7704185843467712:A0.004901123234266909]]),input_ids:T7s1x1[3492,3492:A3492.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.637684345245361,6.348220348358154:A-0.06626327035784622]], value_cache=#1[T1s1x1x16x96[-0.7138619422912598,0.7704185843467712:A0.004901123234266909]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.637684345245361,6.348220348358154:A-0.06418875876696285]], value_cache=#1[T1s1x1x17x96[-1.1154754161834717,0.7704185843467712:A0.0030262298805877543]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.637684345245361,6.348220348358154:A-0.06418875876696285]], value_cache=#1[T1s1x1x17x96[-1.1154754161834717,0.7704185843467712:A0.0030262298805877543]]),input_ids:T7s1x1[276,276:A276.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.637684345245361,6.348220348358154:A-0.06393031390022468]], value_cache=#1[T1s1x1x18x96[-1.1154754161834717,0.7704185843467712:A0.0022916303218136346]]),input_ids:T7s1x1[276,276:A276.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.637684345245361,6.348220348358154:A-0.06393031390022468]], value_cache=#1[T1s1x1x18x96[-1.1154754161834717,0.7704185843467712:A0.0022916303218136346]]),input_ids:T7s1x1[2675,2675:A2675.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.637684345245361,6.348220348358154:A-0.060121849416885344]], value_cache=#1[T1s1x1x19x96[-1.1154754161834717,0.7704185843467712:A0.0028799789417797384]]),input_ids:T7s1x1[2675,2675:A2675.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.637684345245361,6.348220348358154:A-0.060121849416885344]], value_cache=#1[T1s1x1x19x96[-1.1154754161834717,0.7704185843467712:A0.0028799789417797384]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.741750717163086,6.9735283851623535:A-0.057861972187614207]], value_cache=#1[T1s1x1x20x96[-1.1154754161834717,0.7704185843467712:A0.003585792931388217]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.741750717163086,6.9735283851623535:A-0.057861972187614207]], value_cache=#1[T1s1x1x20x96[-1.1154754161834717,0.7704185843467712:A0.003585792931388217]]),input_ids:T7s1x1[679,679:A679.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.741750717163086,6.9735283851623535:A-0.055930863007695675]], value_cache=#1[T1s1x1x21x96[-1.1154754161834717,0.7704185843467712:A0.0035763041296738105]]),input_ids:T7s1x1[679,679:A679.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.741750717163086,6.9735283851623535:A-0.055930863007695675]], value_cache=#1[T1s1x1x21x96[-1.1154754161834717,0.7704185843467712:A0.0035763041296738105]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.741750717163086,6.976338863372803:A-0.0555708509484358]], value_cache=#1[T1s1x1x22x96[-1.1154754161834717,0.7704185843467712:A0.004186302066231788]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.741750717163086,6.976338863372803:A-0.0555708509484358]], value_cache=#1[T1s1x1x22x96[-1.1154754161834717,0.7704185843467712:A0.004186302066231788]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.741750717163086,6.976338863372803:A-0.05874493688510932]], value_cache=#1[T1s1x1x23x96[-1.1154754161834717,0.7704185843467712:A0.004598314676414979]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.741750717163086,6.976338863372803:A-0.05874493688510932]], value_cache=#1[T1s1x1x23x96[-1.1154754161834717,0.7704185843467712:A0.004598314676414979]]),input_ids:T7s1x1[1298,1298:A1298.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.741750717163086,6.976338863372803:A-0.056680566019824456]], value_cache=#1[T1s1x1x24x96[-1.1154754161834717,0.7704185843467712:A0.004455911805602379]]),input_ids:T7s1x1[1298,1298:A1298.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.741750717163086,6.976338863372803:A-0.056680566019824456]], value_cache=#1[T1s1x1x24x96[-1.1154754161834717,0.7704185843467712:A0.004455911805602379]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.741750717163086,6.976338863372803:A-0.05747400988645192]], value_cache=#1[T1s1x1x25x96[-1.1154754161834717,0.7704185843467712:A0.004577871027813671]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.741750717163086,6.976338863372803:A-0.05747400988645192]], value_cache=#1[T1s1x1x25x96[-1.1154754161834717,0.7704185843467712:A0.004577871027813671]]),input_ids:T7s1x1[2714,2714:A2714.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.741750717163086,6.976338863372803:A-0.05311932354151139]], value_cache=#1[T1s1x1x26x96[-1.1154754161834717,0.7704185843467712:A0.00480658852845222]]),input_ids:T7s1x1[2714,2714:A2714.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.741750717163086,6.976338863372803:A-0.05311932354151139]], value_cache=#1[T1s1x1x26x96[-1.1154754161834717,0.7704185843467712:A0.00480658852845222]]),input_ids:T7s1x1[366,366:A366.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.741750717163086,6.976338863372803:A-0.05179963874921489]], value_cache=#1[T1s1x1x27x96[-1.1154754161834717,0.7704185843467712:A0.0055259858195762564]]),input_ids:T7s1x1[366,366:A366.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.741750717163086,6.976338863372803:A-0.05179963874921489]], value_cache=#1[T1s1x1x27x96[-1.1154754161834717,0.7704185843467712:A0.0055259858195762564]]),input_ids:T7s1x1[1033,1033:A1033.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.741750717163086,6.976338863372803:A-0.04720863733222186]], value_cache=#1[T1s1x1x28x96[-1.1154754161834717,0.7704185843467712:A0.005759596489903223]]),input_ids:T7s1x1[1033,1033:A1033.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.741750717163086,6.976338863372803:A-0.04720863733222186]], value_cache=#1[T1s1x1x28x96[-1.1154754161834717,0.7704185843467712:A0.005759596489903223]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.741750717163086,6.976338863372803:A-0.04382178836799697]], value_cache=#1[T1s1x1x29x96[-1.1154754161834717,0.7704185843467712:A0.004630918549621087]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.741750717163086,6.976338863372803:A-0.04382178836799697]], value_cache=#1[T1s1x1x29x96[-1.1154754161834717,0.7704185843467712:A0.004630918549621087]]),input_ids:T7s1x1[345,345:A345.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.741750717163086,6.976338863372803:A-0.03854061667981215]], value_cache=#1[T1s1x1x30x96[-1.1154754161834717,0.7704185843467712:A0.003502099651849575]]),input_ids:T7s1x1[345,345:A345.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.741750717163086,6.976338863372803:A-0.03854061667981215]], value_cache=#1[T1s1x1x30x96[-1.1154754161834717,0.7704185843467712:A0.003502099651849575]]),input_ids:T7s1x1[925,925:A925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.741750717163086,6.976338863372803:A-0.04191077151503984]], value_cache=#1[T1s1x1x31x96[-1.1154754161834717,0.7704185843467712:A0.003402401293804223]]),input_ids:T7s1x1[925,925:A925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.741750717163086,6.976338863372803:A-0.04191077151503984]], value_cache=#1[T1s1x1x31x96[-1.1154754161834717,0.7704185843467712:A0.003402401293804223]]),input_ids:T7s1x1[2309,2309:A2309.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.741750717163086,6.976338863372803:A-0.040550344240841696]], value_cache=#1[T1s1x1x32x96[-1.1154754161834717,0.7704185843467712:A0.0031453480477949824]]),input_ids:T7s1x1[2309,2309:A2309.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.741750717163086,6.976338863372803:A-0.040550344240841696]], value_cache=#1[T1s1x1x32x96[-1.1154754161834717,0.7704185843467712:A0.0031453480477949824]]),input_ids:T7s1x1[1554,1554:A1554.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.741750717163086,6.976338863372803:A-0.03767420404873326]], value_cache=#1[T1s1x1x33x96[-1.1154754161834717,0.7704185843467712:A0.003041020568625846]]),input_ids:T7s1x1[1554,1554:A1554.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.741750717163086,6.976338863372803:A-0.03767420404873326]], value_cache=#1[T1s1x1x33x96[-1.1154754161834717,0.7704185843467712:A0.003041020568625846]]),input_ids:T7s1x1[363,363:A363.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.741750717163086,7.764976501464844:A-0.03646976319374945]], value_cache=#1[T1s1x1x34x96[-1.1154754161834717,0.7704185843467712:A0.002769684347662869]]),input_ids:T7s1x1[363,363:A363.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.741750717163086,7.764976501464844:A-0.03646976319374945]], value_cache=#1[T1s1x1x34x96[-1.1154754161834717,0.7704185843467712:A0.002769684347662869]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.741750717163086,7.764976501464844:A-0.032388471439307544]], value_cache=#1[T1s1x1x35x96[-1.1154754161834717,0.7704185843467712:A0.003086334315555307]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.741750717163086,7.764976501464844:A-0.032388471439307544]], value_cache=#1[T1s1x1x35x96[-1.1154754161834717,0.7704185843467712:A0.003086334315555307]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.741750717163086,7.764976501464844:A-0.03246603250654853]], value_cache=#1[T1s1x1x36x96[-1.1154754161834717,0.7704185843467712:A0.003209072039036679]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.741750717163086,7.764976501464844:A-0.03246603250654853]], value_cache=#1[T1s1x1x36x96[-1.1154754161834717,0.7704185843467712:A0.003209072039036679]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.741750717163086,7.764976501464844:A-0.030996181259422325]], value_cache=#1[T1s1x1x37x96[-1.1154754161834717,0.7704185843467712:A0.0023933656655957224]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.741750717163086,7.764976501464844:A-0.030996181259422325]], value_cache=#1[T1s1x1x37x96[-1.1154754161834717,0.7704185843467712:A0.0023933656655957224]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.741750717163086,7.764976501464844:A-0.030970624311904477]], value_cache=#1[T1s1x1x38x96[-1.1154754161834717,0.7704185843467712:A0.0017675331577896016]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.741750717163086,7.764976501464844:A-0.030970624311904477]], value_cache=#1[T1s1x1x38x96[-1.1154754161834717,0.7704185843467712:A0.0017675331577896016]]),input_ids:T7s1x1[2675,2675:A2675.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.741750717163086,7.764976501464844:A-0.02776791333129732]], value_cache=#1[T1s1x1x39x96[-1.1154754161834717,0.7704185843467712:A0.002067602925568576]]),input_ids:T7s1x1[2675,2675:A2675.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.741750717163086,7.764976501464844:A-0.02776791333129732]], value_cache=#1[T1s1x1x39x96[-1.1154754161834717,0.7704185843467712:A0.002067602925568576]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.741750717163086,7.764976501464844:A-0.025911171886491502]], value_cache=#1[T1s1x1x40x96[-1.1154754161834717,0.7704185843467712:A0.0024408193207780945]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.741750717163086,7.764976501464844:A-0.025911171886491502]], value_cache=#1[T1s1x1x40x96[-1.1154754161834717,0.7704185843467712:A0.0024408193207780945]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.741750717163086,7.764976501464844:A-0.02411022919737824]], value_cache=#1[T1s1x1x41x96[-1.1154754161834717,0.7704185843467712:A0.001575552628860979]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.741750717163086,7.764976501464844:A-0.02411022919737824]], value_cache=#1[T1s1x1x41x96[-1.1154754161834717,0.7704185843467712:A0.001575552628860979]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.741750717163086,7.764976501464844:A-0.022309982135474065]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0019342861556500186]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.741750717163086,7.764976501464844:A-0.022309982135474065]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0019342861556500186]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.741750717163086,7.764976501464844:A-0.021771217863631117]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.002564699534753507]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.741750717163086,7.764976501464844:A-0.021771217863631117]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.002564699534753507]]),input_ids:T7s1x1[727,727:A727.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.741750717163086,7.764976501464844:A-0.02141206981168304]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A0.0015399219475863322]]),input_ids:T7s1x1[727,727:A727.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.741750717163086,7.764976501464844:A-0.02141206981168304]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A0.0015399219475863322]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.741750717163086,7.764976501464844:A-0.022240773732294286]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A0.0017627863282103607]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.741750717163086,7.764976501464844:A-0.022240773732294286]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A0.0017627863282103607]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.741750717163086,7.764976501464844:A-0.023507214221208942]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A0.0018876147202250202]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.741750717163086,7.764976501464844:A-0.023507214221208942]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A0.0018876147202250202]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.741750717163086,7.764976501464844:A-0.022914434492406197]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A0.0012735790074908978]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.741750717163086,7.764976501464844:A-0.022914434492406197]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A0.0012735790074908978]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.741750717163086,7.764976501464844:A-0.023525536320524527]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A0.0008014571608549027]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.741750717163086,7.764976501464844:A-0.023525536320524527]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A0.0008014571608549027]]),input_ids:T7s1x1[2734,2734:A2734.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.741750717163086,7.764976501464844:A-0.021361293218792123]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A0.0005911843754709443]]),input_ids:T7s1x1[2734,2734:A2734.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
Continue: it rains...
"Janicism "You're going to get to the point I thought you could've just done something for, I'm going to have a way there. I'm running this

Let’s restore the forward as it was.

# Put back the original forward method saved before it was wrapped.
model.forward = keep_model_forward

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test to ensure a specific architecture can be exported despite the many changes brought to torch or transformers.

Let’s create an untrained model from the provided config.json file with onnx_diagnostic.torch_models.llms.get_tiny_llm(). Then let’s use it.

experiment = get_tiny_llm()
# The helper returns a dictionary; pull out the untrained model, the dummy
# inputs and the dynamic shapes that go with them.
untrained_model = experiment["model"]
inputs = experiment["inputs"]
dynamic_shapes = experiment["dynamic_shapes"]

Before running it, we make a copy of the inputs, because the execution modifies the cache in place. After the run, the cache is no longer consistent with the original input_ids and attention_mask.

# Deep-copy the inputs *before* the eager run: the forward pass extends
# ``past_key_values`` in place (the cache grows from 30 to 33 positions in the
# output below), so afterwards it no longer matches ``input_ids`` and
# ``attention_mask``. The untouched copy is what gets exported later.
cloned_inputs = copy.deepcopy(inputs)

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

# Describe the structure (types and shapes) of what the model returned.
result_summary = string_type(expected_output, with_shape=True)
print("result type", result_summary)
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.

ExportedProgram

try:
    # ``cloned_inputs`` is expected to be a copy of ``inputs`` taken before the
    # eager run above, which mutated the cache held by ``inputs``.
    ep = torch.export.export(
        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
    )
except Exception as e:
    # Deliberately broad: this is a demo that reports the failure instead of
    # stopping. To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
else:
    # Only reached when the export succeeded; printing is kept out of the
    # ``try`` body so a printing error is not reported as an export failure.
    print("It worked:")
    print(ep)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
             #
            sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
            sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
            arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
            unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
            slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
            slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]);  slice_2 = None
            clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
            slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
            slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
            unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2);  unsqueeze_3 = None
            slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807);  unsqueeze_4 = None
            to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
            add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
            eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
            slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
            masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
            slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
            slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
            copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze);  submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear = None
            transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear_1 = None
            transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear_2 = None
            transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
            slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
            mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9);  cat_1 = None
            add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8);  unsqueeze_8 = None
            slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
            mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9);  cat_2 = unsqueeze_9 = None
            add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
            unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
            slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
            expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]);  slice_24 = None
            reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]);  expand_2 = None
            slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
            unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
            slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807);  unsqueeze_11 = None
            slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
            expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]);  slice_28 = None
            reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]);  expand_3 = add = None
            slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
            slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
            slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
            contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
            transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]);  contiguous_3 = sym_size_int_20 = sym_size_int_21 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13);  p_model_norm_weight = to_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
            slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
            slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight);  slice_34 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807);  unsqueeze_5 = None
                unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32);  unsqueeze_6 = None
                expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]);  to_1 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807);  unsqueeze = None
                unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
                slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807);  unsqueeze_7 = None
                to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                cos: "f32[1, s1, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                sin: "f32[1, s1, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
                    to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
                    to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
                    to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                    cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                    sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, 
arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}

Back to the original model

Let’s use the same dummy inputs, but this time with the downloaded model. The dummy inputs and dynamic shapes are created by the function onnx_diagnostic.torch_models.llms.get_tiny_llm().

# get_tiny_llm() returns a dictionary; pull out the dummy inputs and the
# dynamic shapes that describe which of their dimensions may vary.
data = get_tiny_llm()
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

Let’s print the inputs and their dynamic shapes.

# Summarize the dummy inputs (types and shapes only), then show the dynamic
# shapes mapping; the second call produces the nested-dict output below.
print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: <class 'onnx_diagnostic.torch_models.llms.batch'>,
                    1: <_DimHint.DYNAMIC: 3>},
 'input_ids': {0: <class 'onnx_diagnostic.torch_models.llms.batch'>,
               1: <class 'onnx_diagnostic.torch_models.llms.seq_length'>},
 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.llms.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.llms.cache_length'>}],
                     [{0: <class 'onnx_diagnostic.torch_models.llms.batch'>,
                       2: <class 'onnx_diagnostic.torch_models.llms.cache_length'>}]]}

Finally, let’s export the model.

# ``cloned_inputs`` is defined in an earlier step of this example
# (presumably a copy of the inputs captured while stealing ``forward``).
try:
    ep = torch.export.export(
        model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
    )
except Exception as e:
    # Export is best-effort in this example: report the failure instead of
    # aborting. To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
else:
    # Printing stays outside the ``try`` so that an error while displaying
    # the program is not misreported as an export failure.
    print("It worked:")
    print(ep)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
It worked:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s0, s1]", attention_mask: "i64[s0, s1 + s5]", past_key_values_key_cache_0: "f32[s0, 1, s5, 96]", past_key_values_value_cache_0: "f32[s0, 1, s5, 96]"):
             #
            sym_size_int_20: "Sym(s0)" = torch.ops.aten.sym_size.int(input_ids, 0)
            sym_size_int_21: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
            sym_size_int_22: "Sym(s5)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
            embedding: "f32[s0, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids);  p_model_embed_tokens_weight = input_ids = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            add: "Sym(s1 + s5)" = sym_size_int_22 + sym_size_int_21

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
            arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_22, add, device = device(type='cpu'), pin_memory = False);  sym_size_int_22 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:569 in forward, code: position_ids = cache_position.unsqueeze(0)
            unsqueeze: "i64[1, s1]" = torch.ops.aten.unsqueeze.default(arange, 0)

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
            full: "f32[s1, s1 + s5]" = torch.ops.aten.full.default([sym_size_int_21, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            triu: "f32[s1, s1 + s5]" = torch.ops.aten.triu.default(full, 1);  full = None
            arange_1: "i64[s1 + s5]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
            reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]);  arange = None
            gt: "b8[s1, s1 + s5]" = torch.ops.aten.gt.Tensor(arange_1, reshape);  arange_1 = reshape = None
            mul_: "f32[s1, s1 + s5]" = torch.ops.aten.mul_.Tensor(triu, gt);  triu = gt = None
            unsqueeze_1: "f32[1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(mul_, 0);  mul_ = None
            unsqueeze_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 1);  unsqueeze_1 = None
            slice_1: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_2, 2, 0, 9223372036854775807);  unsqueeze_2 = None
            slice_2: "f32[1, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807);  slice_1 = None
            expand: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_20, 1, -1, -1]);  slice_2 = None
            clone: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.clone.default(expand);  expand = None
            slice_3: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_4: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807);  slice_3 = None
            slice_5: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807);  slice_4 = None
            slice_6: "i64[s0, s1 + s5]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807);  attention_mask = None
            unsqueeze_3: "i64[s0, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(slice_6, 1);  slice_6 = None
            unsqueeze_4: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2);  unsqueeze_3 = None
            slice_7: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 3, 0, 9223372036854775807);  unsqueeze_4 = None
            to: "i64[s0, 1, 1, s1 + s5]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'));  slice_7 = None
            add_2: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.add.Tensor(slice_5, to);  slice_5 = to = None
            eq_7: "b8[s0, 1, s1, s1 + s5]" = torch.ops.aten.eq.Scalar(add_2, 0);  add_2 = None
            slice_8: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_9: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807);  slice_8 = None
            slice_10: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807);  slice_9 = None
            masked_fill: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38);  slice_10 = eq_7 = None
            slice_11: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
            slice_12: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807);  slice_11 = None
            slice_13: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807);  slice_12 = None
            copy_: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.copy_.default(slice_13, masked_fill);  slice_13 = masked_fill = copy_ = None

            # No stacktrace found for following nodes
            submod_3 = self.submod_1
            wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, unsqueeze);  submod_3 = b_model_rotary_emb_inv_freq = unsqueeze = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
            to_6: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[0]
            to_7: "f32[1, s1, 96]" = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_8: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32);  embedding = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_1: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
            mean: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_3: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05);  mean = None
            rsqrt: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_3);  add_3 = None
            mul_2: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt);  rsqrt = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_9: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32);  mul_2 = None
            mul_3: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9);  p_model_layers_0_input_layernorm_weight = to_9 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight);  p_model_layers_0_self_attn_q_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view: "f32[s0, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear = None
            transpose_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2);  view = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight);  p_model_layers_0_self_attn_k_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_1: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear_1 = None
            transpose_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2);  view_1 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_2: "f32[s0, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight);  mul_3 = p_model_layers_0_self_attn_v_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
            view_2: "f32[s0, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_20, sym_size_int_21, -1, 96]);  linear_2 = None
            transpose_3: "f32[s0, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2);  view_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
            unsqueeze_8: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1);  to_6 = None
            unsqueeze_9: "f32[1, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1);  to_7 = None
            mul_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_8)
            slice_17: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
            slice_18: "f32[s0, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807);  transpose_1 = None
            neg: "f32[s0, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18);  slice_18 = None
            cat_1: "f32[s0, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1);  neg = slice_17 = None
            mul_5: "f32[s0, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_9);  cat_1 = None
            add_4: "f32[s0, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5);  mul_4 = mul_5 = None
            mul_6: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_8);  unsqueeze_8 = None
            slice_19: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
            slice_20: "f32[s0, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807);  transpose_2 = None
            neg_1: "f32[s0, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20);  slice_20 = None
            cat_2: "f32[s0, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1);  neg_1 = slice_19 = None
            mul_7: "f32[s0, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_9);  cat_2 = unsqueeze_9 = None
            add_5: "f32[s0, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7);  mul_6 = mul_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            cat_3: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2);  past_key_values_key_cache_0 = add_5 = None
            cat_4: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2);  past_key_values_value_cache_0 = transpose_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
            slice_21: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
            slice_22: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807);  slice_21 = None
            unsqueeze_10: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2);  slice_22 = None
            slice_23: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807);  unsqueeze_10 = None
            slice_24: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807);  slice_23 = None
            expand_2: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_24, [sym_size_int_20, 1, 2, add, 96]);  slice_24 = None
            reshape_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_20, 2, add, 96]);  expand_2 = None
            slice_25: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
            slice_26: "f32[s0, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807);  slice_25 = None
            unsqueeze_11: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2);  slice_26 = None
            slice_27: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807);  unsqueeze_11 = None
            slice_28: "f32[s0, 1, 1, s1 + s5, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807);  slice_27 = None
            expand_3: "f32[s0, 1, 2, s1 + s5, 96]" = torch.ops.aten.expand.default(slice_28, [sym_size_int_20, 1, 2, add, 96]);  slice_28 = None
            reshape_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_20, 2, add, 96]);  expand_3 = add = None
            slice_29: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807);  clone = None
            slice_30: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807);  slice_29 = None
            slice_31: "f32[s0, 1, s1, s1 + s5]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807);  slice_30 = None
            contiguous: "f32[s0, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4);  add_4 = None
            contiguous_1: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_1);  reshape_1 = None
            contiguous_2: "f32[s0, 2, s1 + s5, 96]" = torch.ops.aten.contiguous.default(reshape_2);  reshape_2 = None
            scaled_dot_product_attention: "f32[s0, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575);  contiguous = contiguous_1 = contiguous_2 = slice_31 = None
            transpose_4: "f32[s0, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2);  scaled_dot_product_attention = None
            contiguous_3: "f32[s0, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4);  transpose_4 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            reshape_3: "f32[s0, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_20, sym_size_int_21, -1]);  contiguous_3 = sym_size_int_20 = sym_size_int_21 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_3: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight);  reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
            add_7: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3);  to_8 = linear_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_10: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32);  add_7 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_2: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
            mean_1: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True);  pow_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_8: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05);  mean_1 = None
            rsqrt_1: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_8);  add_8 = None
            mul_8: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1);  rsqrt_1 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_11: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32);  mul_8 = None
            mul_9: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11);  p_model_layers_0_post_attention_layernorm_weight = to_11 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_4: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight);  p_model_layers_0_mlp_gate_proj_weight = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
            silu: "f32[s0, s1, 1024]" = torch.ops.aten.silu.default(linear_4);  linear_4 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_5: "f32[s0, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight);  mul_9 = p_model_layers_0_mlp_up_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
            mul_10: "f32[s0, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5);  silu = linear_5 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_6: "f32[s0, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight);  mul_10 = p_model_layers_0_mlp_down_proj_weight = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
            add_9: "f32[s0, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6);  to_10 = linear_6 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
            to_12: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32);  add_9 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
            pow_3: "f32[s0, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
            mean_2: "f32[s0, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True);  pow_3 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            add_10: "f32[s0, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05);  mean_2 = None
            rsqrt_2: "f32[s0, s1, 1]" = torch.ops.aten.rsqrt.default(add_10);  add_10 = None
            mul_11: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2);  to_12 = rsqrt_2 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
            to_13: "f32[s0, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32);  mul_11 = None
            mul_12: "f32[s0, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13);  p_model_norm_weight = to_13 = None

             # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
            slice_32: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807);  mul_12 = None
            slice_33: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807);  slice_32 = None
            slice_34: "f32[s0, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807);  slice_33 = None

             # File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_7: "f32[s0, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight);  slice_34 = p_lm_head_weight = None
            return (linear_7, cat_3, cat_4)

        class submod_1(torch.nn.Module):
            def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", unsqueeze: "i64[1, s1]"):
                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
                unsqueeze_5: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0);  b_model_rotary_emb_inv_freq = None
                slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_5, 1, 0, 9223372036854775807);  unsqueeze_5 = None
                unsqueeze_6: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2);  slice_14 = None
                to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_6, torch.float32);  unsqueeze_6 = None
                expand_1: "f32[1, 48, 1]" = torch.ops.aten.expand.default(to_1, [1, -1, 1]);  to_1 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
                slice_15: "i64[1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze, 0, 0, 9223372036854775807);  unsqueeze = None
                unsqueeze_7: "i64[1, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1);  slice_15 = None
                slice_16: "i64[1, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807);  unsqueeze_7 = None
                to_2: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32);  slice_16 = None

                # No stacktrace found for following nodes
                submod_3 = self.submod_1
                wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2);  submod_3 = expand_1 = to_2 = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                cos: "f32[1, s1, 96]" = wrap_with_autocast[0]

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                sin: "f32[1, s1, 96]" = wrap_with_autocast[1];  wrap_with_autocast = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
                mul: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0);  cos = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
                mul_1: "f32[1, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0);  sin = None

                 # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
                to_6: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32);  mul = None
                to_7: "f32[1, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32);  mul_1 = None
                return (to_6, to_7)

            class submod_1(torch.nn.Module):
                def forward(self, expand_1: "f32[1, 48, 1]", to_2: "f32[1, 1, s1]"):
                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
                    to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32);  expand_1 = None
                    to_4: "f32[1, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  to_3 = None
                    to_5: "f32[1, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32);  to_2 = None
                    matmul: "f32[1, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5);  to_4 = to_5 = None
                    transpose: "f32[1, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2);  matmul = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
                    cat: "f32[1, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1);  transpose = None

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
                    cos: "f32[1, s1, 96]" = torch.ops.aten.cos.default(cat)

                     # File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
                    sin: "f32[1, s1, 96]" = torch.ops.aten.sin.default(cat);  cat = None
                    return (cos, sin)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_embed_tokens_weight'), target='model.embed_tokens.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_q_proj_weight'), target='model.layers.0.self_attn.q_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_k_proj_weight'), target='model.layers.0.self_attn.k_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_v_proj_weight'), target='model.layers.0.self_attn.v_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_self_attn_o_proj_weight'), target='model.layers.0.self_attn.o_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_gate_proj_weight'), target='model.layers.0.mlp.gate_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_up_proj_weight'), target='model.layers.0.mlp.up_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_mlp_down_proj_weight'), target='model.layers.0.mlp.down_proj.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_input_layernorm_weight'), target='model.layers.0.input_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_layers_0_post_attention_layernorm_weight'), target='model.layers.0.post_attention_layernorm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_model_norm_weight'), target='model.norm.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, 
arg=TensorArgument(name='p_lm_head_weight'), target='lm_head.weight', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_model_rotary_emb_inv_freq'), target='model.rotary_emb.inv_freq', persistent=False), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input_ids'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='attention_mask'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='past_key_values_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cat_4'), target=None)])
Range constraints: {s0: VR[1, 1024], s1: VR[1, 4096], s1 + s5: VR[4, 8192], s5: VR[1, 4096]}

Total running time of the script: (0 minutes 7.230 seconds)

Related examples

Find where a model is failing by running submodels

Export with DynamicCache and dynamic shapes

Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints

Gallery generated by Sphinx-Gallery