Note
Go to the end to download the full example code.
Steal the forward method to guess the dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export().
Most of the examples on HuggingFace use the method
transformers.GenerationMixin.generate(),
but we only want to export the model and its forward method.
This example shows how to guess the inputs of this method even though the model
is executed through generate().
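For reference, here is a minimal sketch, using a hypothetical toy module rather than Tiny-LLM, of how dynamic dimensions are declared when a model is exported with torch.export.export:

import torch

class ToyModel(torch.nn.Module):
    # stand-in embedding + head, only to illustrate dynamic_shapes
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(1000, 16)
        self.head = torch.nn.Linear(16, 1000)

    def forward(self, input_ids):
        return self.head(self.embed(input_ids))

batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
ep = torch.export.export(
    ToyModel(),
    (torch.randint(0, 1000, (2, 8)),),
    dynamic_shapes=({0: batch, 1: seq},),  # one spec per positional argument
)

With an LLM, the difficulty is that forward also receives a cache whose shapes must be declared the same way.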
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_test_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We wrap the forward method to print its inputs and outputs, in particular the cache dimensions.
def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)
Let’s run the model. In the traces below, a type such as T7s1x8 denotes a tensor of ONNX element type 7 (int64) with shape 1x8, and [min,max:Aavg] gives its minimum, maximum, and average values.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache[serialized](#2[#0[],#0[]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]],#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]],#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.432564735412598,8.368535995483398:A-4.234468644971028],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]],#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]],#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.071060180664062,2.7617390155792236:A-9.465396250322462],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.11434226747017723]],#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.00897657713295909]]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.11434226747017723]],#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.00897657713295909]]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.54892349243164,12.476265907287598:A-4.136563749908004],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x11x96[-5.509540557861328,6.348220348358154:A-0.10710627211742438]],#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.005327088963716078]]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x11x96[-5.509540557861328,6.348220348358154:A-0.10710627211742438]],#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.005327088963716078]]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.19338607788086,3.1651957035064697:A-11.273599289992823],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x12x96[-5.56224250793457,7.80775785446167:A-0.09510994365973602]],#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.004679121221657725]]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x12x96[-5.56224250793457,7.80775785446167:A-0.09510994365973602]],#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.004679121221657725]]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.81157112121582,2.490788698196411:A-11.573189300978557],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x13x96[-5.562875270843506,7.80775785446167:A-0.0842344671944328]],#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.004130840824531426]]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x13x96[-5.562875270843506,7.80775785446167:A-0.0842344671944328]],#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.004130840824531426]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.792251586914062,9.415217399597168:A-6.523793722998817],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x14x96[-5.562875270843506,7.80775785446167:A-0.07965348597705695]],#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.004098236537856792]]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x14x96[-5.562875270843506,7.80775785446167:A-0.07965348597705695]],#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.004098236537856792]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-23.976877212524414,2.077348232269287:A-12.43196074456675],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x15x96[-6.845856666564941,7.80775785446167:A-0.08863463782561465]],#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.003449787268608715]]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x15x96[-6.845856666564941,7.80775785446167:A-0.08863463782561465]],#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.003449787268608715]]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.62879753112793,2.709171772003174:A-11.294309411809314],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x16x96[-6.845856666564941,7.80775785446167:A-0.08419364550559294]],#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0020143345384345443]]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x16x96[-6.845856666564941,7.80775785446167:A-0.08419364550559294]],#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0020143345384345443]]]),input_ids:T7s1x1[3786,3786:A3786.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.211437225341797,11.88216781616211:A-11.201580898320302],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x17x96[-6.845856666564941,7.80775785446167:A-0.08415293724094369]],#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0025193796739978436]]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x17x96[-6.845856666564941,7.80775785446167:A-0.08415293724094369]],#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0025193796739978436]]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.378000259399414,14.070089340209961:A-7.2346467959224245],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x18x96[-6.845856666564941,7.80775785446167:A-0.09244497385946117]],#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.0006478700960694065]]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x18x96[-6.845856666564941,7.80775785446167:A-0.09244497385946117]],#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.0006478700960694065]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.43880844116211,11.423129081726074:A-7.807880653257482],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x19x96[-6.845856666564941,7.80775785446167:A-0.08118930570095759]],#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.0003175346962307869]]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x19x96[-6.845856666564941,7.80775785446167:A-0.08118930570095759]],#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.0003175346962307869]]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.945932388305664,9.307157516479492:A-12.187897676732856],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x20x96[-6.845856666564941,7.80775785446167:A-0.0710719657554364]],#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.0003132348386088779]]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x20x96[-6.845856666564941,7.80775785446167:A-0.0710719657554364]],#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.0003132348386088779]]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.38627815246582,11.079444885253906:A-9.634620206483406],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x21x96[-6.845856666564941,7.80775785446167:A-0.07194269531828468]],#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.0003613207834152692]]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x21x96[-6.845856666564941,7.80775785446167:A-0.07194269531828468]],#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.0003613207834152692]]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.021613121032715,17.791135787963867:A-4.578210202363553],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x22x96[-6.845856666564941,7.80775785446167:A-0.07023053201572277]],#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.001071820739863335]]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x22x96[-6.845856666564941,7.80775785446167:A-0.07023053201572277]],#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.001071820739863335]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.704498291015625,13.999704360961914:A-9.975061703482643],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x23x96[-6.993858814239502,7.80775785446167:A-0.07228078447640629]],#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.0012699373381677276]]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x23x96[-6.993858814239502,7.80775785446167:A-0.07228078447640629]],#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.0012699373381677276]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.562313079833984,12.042723655700684:A-11.35273957562074],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x24x96[-6.993858814239502,7.80775785446167:A-0.0693251492926821]],#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.002489018587520301]]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x24x96[-6.993858814239502,7.80775785446167:A-0.0693251492926821]],#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.002489018587520301]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.949621200561523,14.084500312805176:A-7.02805341891339],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x25x96[-6.993858814239502,7.80775785446167:A-0.06178976752926246]],#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003610573336924669]]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x25x96[-6.993858814239502,7.80775785446167:A-0.06178976752926246]],#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003610573336924669]]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.43863868713379,5.42636775970459:A-13.913722303581425],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x26x96[-6.993858814239502,7.80775785446167:A-0.05585792070105993]],#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.004207928655706639]]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x26x96[-6.993858814239502,7.80775785446167:A-0.05585792070105993]],#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.004207928655706639]]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.474013328552246,12.021588325500488:A-3.8923398731821215],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x27x96[-6.993858814239502,7.80775785446167:A-0.0562344402579909]],#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.003005064840703266]]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x27x96[-6.993858814239502,7.80775785446167:A-0.0562344402579909]],#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.003005064840703266]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.755416870117188,2.6581664085388184:A-10.450381227273494],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x28x96[-6.993858814239502,7.80775785446167:A-0.055390838444725976]],#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.0027665132102822013]]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x28x96[-6.993858814239502,7.80775785446167:A-0.055390838444725976]],#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.0027665132102822013]]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.12242889404297,9.252721786499023:A-7.408372220959515],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x29x96[-6.993858814239502,7.80775785446167:A-0.05673461517929762]],#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.0037458676764350225]]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x29x96[-6.993858814239502,7.80775785446167:A-0.05673461517929762]],#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.0037458676764350225]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.936534881591797,4.424360275268555:A-11.144390676616691],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x30x96[-6.993858814239502,7.80775785446167:A-0.05863865203100431]],#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0038086221705826676]]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x30x96[-6.993858814239502,7.80775785446167:A-0.05863865203100431]],#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0038086221705826676]]]),input_ids:T7s1x1[29953,29953:A29953.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.961956024169922,2.967160701751709:A-11.384877558900044],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x31x96[-6.993858814239502,7.80775785446167:A-0.05885248658633309]],#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.004056779835204369]]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x31x96[-6.993858814239502,7.80775785446167:A-0.05885248658633309]],#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.004056779835204369]]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.61073112487793,6.068974494934082:A-9.463844107881654],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x32x96[-6.993858814239502,7.80775785446167:A-0.05171761695824747]],#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.003497116927107413]]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x32x96[-6.993858814239502,7.80775785446167:A-0.05171761695824747]],#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.003497116927107413]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.140470504760742,9.447004318237305:A-9.53557308629062],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x33x96[-6.993858814239502,7.80775785446167:A-0.05123007030447803]],#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.003561704368736412]]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x33x96[-6.993858814239502,7.80775785446167:A-0.05123007030447803]],#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.003561704368736412]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.547964096069336,10.584537506103516:A-9.146475921413861],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x34x96[-7.353572845458984,7.80775785446167:A-0.052414219703018274]],#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.004354827396792091]]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x34x96[-7.353572845458984,7.80775785446167:A-0.052414219703018274]],#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.004354827396792091]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.020347595214844,5.742450714111328:A-11.164005009222775],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x35x96[-7.353572845458984,7.80775785446167:A-0.05484983952632146]],#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A-0.005102629108958873]]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x35x96[-7.353572845458984,7.80775785446167:A-0.05484983952632146]],#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A-0.005102629108958873]]]),input_ids:T7s1x1[29929,29929:A29929.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.91389274597168,2.289463520050049:A-13.676366100630723],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x36x96[-7.353572845458984,7.80775785446167:A-0.05621831190509746]],#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A-0.005637514253273231]]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x36x96[-7.353572845458984,7.80775785446167:A-0.05621831190509746]],#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A-0.005637514253273231]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.167343139648438,13.443876266479492:A-5.3544019131329845],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x37x96[-7.353572845458984,7.80775785446167:A-0.05481215719529401]],#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A-0.00538584141369594]]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x37x96[-7.353572845458984,7.80775785446167:A-0.05481215719529401]],#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A-0.00538584141369594]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.998165130615234,8.360471725463867:A-10.564597876278684],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x38x96[-7.353572845458984,7.80775785446167:A-0.05163832419569652]],#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A-0.006047474200773139]]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x38x96[-7.353572845458984,7.80775785446167:A-0.05163832419569652]],#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A-0.006047474200773139]]]),input_ids:T7s1x1[29955,29955:A29955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.97073745727539,4.956542015075684:A-12.609978926122189],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x39x96[-7.353572845458984,7.80775785446167:A-0.049571250253012476]],#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A-0.006333527961456468]]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x39x96[-7.353572845458984,7.80775785446167:A-0.049571250253012476]],#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A-0.006333527961456468]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.773216247558594,13.519667625427246:A-5.161368548027705],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x40x96[-7.353572845458984,7.80775785446167:A-0.0462463174137459]],#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A-0.006083330242142892]]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x40x96[-7.353572845458984,7.80775785446167:A-0.0462463174137459]],#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A-0.006083330242142892]]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.65859603881836,6.487763404846191:A-11.850197799076327],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x41x96[-7.353572845458984,7.80775785446167:A-0.04713789616934212]],#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A-0.005994676429919509]]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x41x96[-7.353572845458984,7.80775785446167:A-0.04713789616934212]],#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A-0.005994676429919509]]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.357160568237305,5.240086078643799:A-11.694788255400956],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x42x96[-7.353572845458984,7.80775785446167:A-0.047519140047688056]],#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A-0.006144752133982437]]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x42x96[-7.353572845458984,7.80775785446167:A-0.047519140047688056]],#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A-0.006144752133982437]]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.400876998901367,11.760887145996094:A-4.6944094990096055],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x43x96[-7.353572845458984,7.80775785446167:A-0.0455281768692601]],#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A-0.0053444231459971615]]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x43x96[-7.353572845458984,7.80775785446167:A-0.0455281768692601]],#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A-0.0053444231459971615]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.441680908203125,2.1894445419311523:A-11.039535363047849],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x44x96[-7.353572845458984,7.80775785446167:A-0.04239550796694058]],#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A-0.005139450328790714]]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x44x96[-7.353572845458984,7.80775785446167:A-0.04239550796694058]],#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A-0.005139450328790714]]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.220340728759766,9.617319107055664:A-8.075642883985303],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x45x96[-7.353572845458984,7.80775785446167:A-0.038776138660156265]],#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A-0.005717857937677898]]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x45x96[-7.353572845458984,7.80775785446167:A-0.038776138660156265]],#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A-0.005717857937677898]]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.410938262939453,3.8497819900512695:A-11.438987951982766],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x46x96[-7.353572845458984,7.80775785446167:A-0.03963859379362712]],#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A-0.005715915428181952]]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x46x96[-7.353572845458984,7.80775785446167:A-0.03963859379362712]],#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A-0.005715915428181952]]]),input_ids:T7s1x1[29953,29953:A29953.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.984619140625,2.4716057777404785:A-12.045449071270413],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x47x96[-7.353572845458984,7.80775785446167:A-0.03929206428306606]],#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A-0.0058390131802175576]]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x47x96[-7.353572845458984,7.80775785446167:A-0.03929206428306606]],#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A-0.0058390131802175576]]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.83071517944336,3.7170398235321045:A-11.912842350695282],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x48x96[-7.353572845458984,7.80775785446167:A-0.035736772801581336]],#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A-0.006353364724851139]]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x48x96[-7.353572845458984,7.80775785446167:A-0.035736772801581336]],#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A-0.006353364724851139]]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.11894989013672,10.8707275390625:A-6.2240660907747225],past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x49x96[-7.353572845458984,7.80775785446167:A-0.03108322479156599]],#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A-0.006148716856770779]]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...
- 11-2 on April 25, 2008
- 26,2009-07-15
- 260-4
Let’s restore the forward method as it was.
model.forward = keep_model_forward
Another syntax with onnx_diagnostic.helpers.torch_test_helper.steal_forward().
with steal_forward(model):
    model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
---- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache[serialized](#2[#0[],#0[]]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x8x96],#1[T1s1x1x8x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x8x96],#1[T1s1x1x8x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x9x96],#1[T1s1x1x9x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x9x96],#1[T1s1x1x9x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x10x96],#1[T1s1x1x10x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x10x96],#1[T1s1x1x10x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x11x96],#1[T1s1x1x11x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x11x96],#1[T1s1x1x11x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x12x96],#1[T1s1x1x12x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x12x96],#1[T1s1x1x12x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x13x96],#1[T1s1x1x13x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x13x96],#1[T1s1x1x13x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x14x96],#1[T1s1x1x14x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x14x96],#1[T1s1x1x14x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x15x96],#1[T1s1x1x15x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x15x96],#1[T1s1x1x15x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x16x96],#1[T1s1x1x16x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x16x96],#1[T1s1x1x16x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x17x96],#1[T1s1x1x17x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x17x96],#1[T1s1x1x17x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x18x96],#1[T1s1x1x18x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x18x96],#1[T1s1x1x18x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x19x96],#1[T1s1x1x19x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x19x96],#1[T1s1x1x19x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x20x96],#1[T1s1x1x20x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x20x96],#1[T1s1x1x20x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x21x96],#1[T1s1x1x21x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x21x96],#1[T1s1x1x21x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x22x96],#1[T1s1x1x22x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x22x96],#1[T1s1x1x22x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x23x96],#1[T1s1x1x23x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x23x96],#1[T1s1x1x23x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x24x96],#1[T1s1x1x24x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x24x96],#1[T1s1x1x24x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x25x96],#1[T1s1x1x25x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x25x96],#1[T1s1x1x25x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x26x96],#1[T1s1x1x26x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x26x96],#1[T1s1x1x26x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x27x96],#1[T1s1x1x27x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x27x96],#1[T1s1x1x27x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x28x96],#1[T1s1x1x28x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x28x96],#1[T1s1x1x28x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x29x96],#1[T1s1x1x29x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x29x96],#1[T1s1x1x29x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x30x96],#1[T1s1x1x30x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x30x96],#1[T1s1x1x30x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x31x96],#1[T1s1x1x31x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x31x96],#1[T1s1x1x31x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x32x96],#1[T1s1x1x32x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x32x96],#1[T1s1x1x32x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x33x96],#1[T1s1x1x33x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x33x96],#1[T1s1x1x33x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x34x96],#1[T1s1x1x34x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x34x96],#1[T1s1x1x34x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x35x96],#1[T1s1x1x35x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x35x96],#1[T1s1x1x35x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x36x96],#1[T1s1x1x36x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x36x96],#1[T1s1x1x36x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x37x96],#1[T1s1x1x37x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x37x96],#1[T1s1x1x37x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x38x96],#1[T1s1x1x38x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x38x96],#1[T1s1x1x38x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x39x96],#1[T1s1x1x39x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x39x96],#1[T1s1x1x39x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x40x96],#1[T1s1x1x40x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x40x96],#1[T1s1x1x40x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x41x96],#1[T1s1x1x41x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x41x96],#1[T1s1x1x41x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x42x96],#1[T1s1x1x42x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x42x96],#1[T1s1x1x42x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x43x96],#1[T1s1x1x43x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x43x96],#1[T1s1x1x43x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x44x96],#1[T1s1x1x44x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x44x96],#1[T1s1x1x44x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x45x96],#1[T1s1x1x45x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x45x96],#1[T1s1x1x45x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x46x96],#1[T1s1x1x46x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x46x96],#1[T1s1x1x46x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x47x96],#1[T1s1x1x47x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x47x96],#1[T1s1x1x47x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x48x96],#1[T1s1x1x48x96]]))
.
---- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x48x96],#1[T1s1x1x48x96]]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s1x1x49x96],#1[T1s1x1x49x96]]))
.
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring that a specific architecture can still be exported despite the many changes brought to torch or transformers.
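As an illustration, such a unit test could look like the minimal sketch below, which only reuses the helpers already shown in this example:

import copy
import unittest
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm

class TestTinyLLMExport(unittest.TestCase):
    def test_export_tiny_llm(self):
        data = get_tiny_llm()
        # the inputs are deep-copied because the forward pass mutates the cache
        ep = torch.export.export(
            data["model"],
            (),
            kwargs=copy.deepcopy(data["inputs"]),
            dynamic_shapes=data["dynamic_shapes"],
            strict=False,
        )
        self.assertIsNotNone(ep)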
Let’s create an untrained model from the configuration file config.json with
onnx_diagnostic.torch_models.llms.get_tiny_llm().
Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
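To see what the helper guessed, the dynamic shapes can be printed:

print("dynamic shapes:")
pprint.pprint(dynamic_shapes)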
Before we run it, we make a copy of the inputs because the cache gets modified by the execution. After the call, the cache is no longer consistent with the previous input_ids and mask.
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x30x96],#1[T1s2x1x30x96]]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x33x96],#1[T1s2x1x33x96]]))
The cache length grew from 30 to 33 because the 3 new tokens were appended to it. The outputs:
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x33x96],#1[T1s2x1x33x96]]))
It works.
ExportedProgram¶
try:
    ep = torch.export.export(
        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s41, s2]", attention_mask: "i64[s41, s2 + s67]", position_ids: "i64[s41, s2]", past_key_values_key_cache_0: "f32[s41, 1, s67, 96]", past_key_values_value_cache_0: "f32[s41, 1, s67, 96]"):
#
sym_size_int_22: "Sym(s41)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_23: "Sym(s2)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_24: "Sym(s67)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s41, s2, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:543 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s2 + s67)" = sym_size_int_24 + sym_size_int_23
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:542 in forward, code: cache_position = torch.arange(
arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int_24, add, device = device(type='cpu'), pin_memory = False); sym_size_int_24 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:549 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s2, s2 + s67]" = torch.ops.aten.full.default([sym_size_int_23, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s2, s2 + s67]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s2 + s67]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s2, s2 + s67]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s2, s2 + s67]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze: "f32[1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_22, 1, -1, -1]); slice_2 = None
clone: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
slice_4: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_3, 1); slice_3 = None
slice_5: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_4, 2); slice_4 = None
slice_6: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, add); slice_5 = None
slice_7: "i64[s41, s2 + s67]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_2: "i64[s41, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(slice_7, 1); slice_7 = None
unsqueeze_3: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_8: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_8 = None
add_2: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.add.Tensor(slice_6, to); slice_6 = to = None
eq_4: "b8[s41, 1, s2, s2 + s67]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_9: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
slice_10: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_9, 1); slice_9 = None
slice_11: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_10, 2); slice_10 = None
slice_12: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, add); slice_11 = None
masked_fill: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_4, -3.4028234663852886e+38); slice_12 = eq_4 = None
slice_13: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_14: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807); slice_13 = None
slice_15: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807); slice_14 = None
copy_: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.copy_.default(slice_15, masked_fill); slice_15 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_22, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:255 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s41, s2, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_22, sym_size_int_23, -1, 96]); linear = None
transpose_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:256 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_22, sym_size_int_23, -1, 96]); linear_1 = None
transpose_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:257 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_22, sym_size_int_23, -1, 96]); linear_2 = None
transpose_3: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:260 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_19: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_20: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s41, 2, s2, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.cat.default([neg, slice_19], -1); neg = slice_19 = None
mul_5: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_21: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_22: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s41, 1, s2, 48]" = torch.ops.aten.neg.default(slice_22); slice_22 = None
cat_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.cat.default([neg_1, slice_21], -1); neg_1 = slice_21 = None
mul_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5: "f32[s41, 1, s2, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: attn_output, attn_weights = attention_interface(
slice_23: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_24: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_23, 1, 0, 9223372036854775807); slice_23 = None
unsqueeze_9: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_24, 2); slice_24 = None
slice_25: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_26: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_25, 4, 0, 9223372036854775807); slice_25 = None
expand_2: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_26, [sym_size_int_22, 1, 2, add, 96]); slice_26 = None
reshape_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_22, 2, add, 96]); expand_2 = None
slice_27: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_28: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_27, 1, 0, 9223372036854775807); slice_27 = None
unsqueeze_10: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_28, 2); slice_28 = None
slice_29: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_30: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_29, 4, 0, 9223372036854775807); slice_29 = None
expand_3: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_30, [sym_size_int_22, 1, 2, add, 96]); slice_30 = None
reshape_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_22, 2, add, 96]); expand_3 = None
slice_31: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone); clone = None
slice_32: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_31, 1); slice_31 = None
slice_33: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_32, 2); slice_32 = None
slice_34: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_33, 3, None, add); slice_33 = add = None
contiguous: "f32[s41, 2, s2, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s41, 2, s2, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_34, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_34 = None
transpose_4: "f32[s41, s2, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s41, s2, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:289 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s41, s2, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_22, sym_size_int_23, -1]); contiguous_3 = sym_size_int_22 = sym_size_int_23 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:332 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s41, s2, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:175 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s41, s2, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:338 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:844 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_35: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(mul_12); mul_12 = None
slice_36: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_35, 1, 0); slice_35 = None
slice_37: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_36, 2); slice_36 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s41, s2, 32000]" = torch.ops.aten.linear.default(slice_37, p_lm_head_weight); slice_37 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_22: "Sym(s41)", position_ids: "i64[s41, s2]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:116 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2); slice_16 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
expand_1: "f32[s41, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_22, -1, 1]); to_1 = sym_size_int_22 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "f32[s41, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:117 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_17: "i64[s41, s2]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_6: "i64[s41, 1, s2]" = torch.ops.aten.unsqueeze.default(slice_17, 1); slice_17 = None
slice_18: "i64[s41, 1, s2]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(slice_18, torch.float32); slice_18 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3); submod_3 = to_2 = to_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:123 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s41, s2, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:124 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s41, s2, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, to_2: "f32[s41, 48, 1]", to_3: "f32[s41, 1, s2]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:121 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s41, 48, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
matmul: "f32[s41, 48, s2]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s41, s2, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:122 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s41, s2, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:123 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s41, s2, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:124 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s41, s2, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s41: VR[1, 1024], s2: VR[2, 4096], s2 + s67: VR[4, 8192], s67: VR[1, 4096]}
Back to the original model¶
Let’s use the same dummy inputs, but this time with the downloaded model. Dummy inputs and dynamic shapes are created by the function
onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
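Under the hood, the helper returns a dictionary of tensors plus a transformers.cache_utils.DynamicCache.
As a minimal sketch (not the helper’s actual code; the shapes match the printout below and the names are illustrative), equivalent inputs could be built by hand like this:
# hypothetical manual construction of inputs equivalent to get_tiny_llm()
from transformers.cache_utils import DynamicCache

batch, seq, past, head_dim = 2, 3, 30, 96
cache = DynamicCache()
# one layer only: key and value states of shape (batch, num_key_value_heads=1, past, head_dim)
cache.update(torch.randn(batch, 1, past, head_dim), torch.randn(batch, 1, past, head_dim), 0)
manual_inputs = dict(
    input_ids=torch.randint(0, 32000, (batch, seq)),
    attention_mask=torch.ones(batch, past + seq, dtype=torch.int64),
    position_ids=torch.arange(past, past + seq, dtype=torch.int64).unsqueeze(0).expand(batch, -1),
    past_key_values=cache,
)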
Let’s print the inputs.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache[serialized](#2[#1[T1s2x1x30x96],#1[T1s2x1x30x96]]))
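And the dynamic shapes returned by the same helper. The dump below is what pretty-printing them looks like (the exact call is assumed from the output that follows):
pprint.pprint(dynamic_shapes)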
{'attention_mask': {0: Dim('batch', min=1, max=1024),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
'input_ids': {0: Dim('batch', min=1, max=1024),
1: Dim('seq_length', min=1, max=4096)},
'past_key_values': [[{0: Dim('batch', min=1, max=1024),
2: Dim('cache_length', min=1, max=4096)}],
[{0: Dim('batch', min=1, max=1024),
2: Dim('cache_length', min=1, max=4096)}]],
'position_ids': {0: Dim('batch', min=1, max=1024),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}}
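For reference, a hand-written specification equivalent to the one above could look like the sketch below. It assumes a recent torch where torch.export.Dim.DYNAMIC is available; the helper may build it slightly differently.
# hypothetical manual version of the dynamic shapes produced by get_tiny_llm()
from torch.export import Dim

batch = Dim("batch", min=1, max=1024)
seq_length = Dim("seq_length", min=1, max=4096)
cache_length = Dim("cache_length", min=1, max=4096)

manual_dynamic_shapes = {
    "input_ids": {0: batch, 1: seq_length},
    "attention_mask": {0: batch, 1: Dim.DYNAMIC},
    "position_ids": {0: batch, 1: Dim.DYNAMIC},
    "past_key_values": [
        [{0: batch, 2: cache_length}],  # key cache, layer 0
        [{0: batch, 2: cache_length}],  # value cache, layer 0
    ],
}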
And let’s finally export. The forward pass updates the cache in place, so we export a deep copy of the inputs.
cloned_inputs = copy.deepcopy(inputs)
try:
ep = torch.export.export(
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s41, s2]", attention_mask: "i64[s41, s2 + s67]", position_ids: "i64[s41, s2]", past_key_values_key_cache_0: "f32[s41, 1, s67, 96]", past_key_values_value_cache_0: "f32[s41, 1, s67, 96]"):
#
sym_size_int_22: "Sym(s41)" = torch.ops.aten.sym_size.int(input_ids, 0)
sym_size_int_23: "Sym(s2)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_24: "Sym(s67)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[s41, s2, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:543 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s2 + s67)" = sym_size_int_24 + sym_size_int_23
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:542 in forward, code: cache_position = torch.arange(
arange: "i64[s2]" = torch.ops.aten.arange.start(sym_size_int_24, add, device = device(type='cpu'), pin_memory = False); sym_size_int_24 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:549 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s2, s2 + s67]" = torch.ops.aten.full.default([sym_size_int_23, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s2, s2 + s67]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s2 + s67]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s2, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s2, s2 + s67]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s2, s2 + s67]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze: "f32[1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2: "f32[1, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
expand: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_22, 1, -1, -1]); slice_2 = None
clone: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
slice_4: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_3, 1); slice_3 = None
slice_5: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_4, 2); slice_4 = None
slice_6: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_5, 3, None, add); slice_5 = None
slice_7: "i64[s41, s2 + s67]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_2: "i64[s41, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(slice_7, 1); slice_7 = None
unsqueeze_3: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_8: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(slice_8, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[s41, 1, 1, s2 + s67]" = torch.ops.aten.to.dtype_layout(slice_8, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_8 = None
add_2: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.add.Tensor(slice_6, to); slice_6 = to = None
eq_4: "b8[s41, 1, s2, s2 + s67]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_9: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone)
slice_10: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_9, 1); slice_9 = None
slice_11: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_10, 2); slice_10 = None
slice_12: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_11, 3, None, add); slice_11 = None
masked_fill: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.masked_fill.Scalar(slice_12, eq_4, -3.4028234663852886e+38); slice_12 = eq_4 = None
slice_13: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_14: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_13, 1, 0, 9223372036854775807); slice_13 = None
slice_15: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_14, 2, 0, 9223372036854775807); slice_14 = None
copy_: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.copy_.default(slice_15, masked_fill); slice_15 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_22, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[s41, s2, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:255 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s41, s2, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_22, sym_size_int_23, -1, 96]); linear = None
transpose_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:256 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_22, sym_size_int_23, -1, 96]); linear_1 = None
transpose_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s41, s2, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:257 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s41, s2, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_22, sym_size_int_23, -1, 96]); linear_2 = None
transpose_3: "f32[s41, 1, s2, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:260 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8: "f32[s41, 1, s2, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_19: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_20: "f32[s41, 2, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s41, 2, s2, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_1: "f32[s41, 2, s2, 96]" = torch.ops.aten.cat.default([neg, slice_19], -1); neg = slice_19 = None
mul_5: "f32[s41, 2, s2, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4: "f32[s41, 2, s2, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_21: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_22: "f32[s41, 1, s2, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s41, 1, s2, 48]" = torch.ops.aten.neg.default(slice_22); slice_22 = None
cat_2: "f32[s41, 1, s2, 96]" = torch.ops.aten.cat.default([neg_1, slice_21], -1); neg_1 = slice_21 = None
mul_7: "f32[s41, 1, s2, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5: "f32[s41, 1, s2, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: attn_output, attn_weights = attention_interface(
slice_23: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_24: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_23, 1, 0, 9223372036854775807); slice_23 = None
unsqueeze_9: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_24, 2); slice_24 = None
slice_25: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_26: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_25, 4, 0, 9223372036854775807); slice_25 = None
expand_2: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_26, [sym_size_int_22, 1, 2, add, 96]); slice_26 = None
reshape_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_22, 2, add, 96]); expand_2 = None
slice_27: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_28: "f32[s41, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_27, 1, 0, 9223372036854775807); slice_27 = None
unsqueeze_10: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.unsqueeze.default(slice_28, 2); slice_28 = None
slice_29: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_30: "f32[s41, 1, 1, s2 + s67, 96]" = torch.ops.aten.slice.Tensor(slice_29, 4, 0, 9223372036854775807); slice_29 = None
expand_3: "f32[s41, 1, 2, s2 + s67, 96]" = torch.ops.aten.expand.default(slice_30, [sym_size_int_22, 1, 2, add, 96]); slice_30 = None
reshape_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_22, 2, add, 96]); expand_3 = None
slice_31: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(clone); clone = None
slice_32: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_31, 1); slice_31 = None
slice_33: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_32, 2); slice_32 = None
slice_34: "f32[s41, 1, s2, s2 + s67]" = torch.ops.aten.slice.Tensor(slice_33, 3, None, add); slice_33 = add = None
contiguous: "f32[s41, 2, s2, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[s41, 2, s2 + s67, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[s41, 2, s2, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_34, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_34 = None
transpose_4: "f32[s41, s2, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[s41, s2, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:289 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[s41, s2, 192]" = torch.ops.aten.reshape.default(contiguous_3, [sym_size_int_22, sym_size_int_23, -1]); contiguous_3 = sym_size_int_22 = sym_size_int_23 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:332 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_8, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:434 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[s41, s2, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s41, s2, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:175 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[s41, s2, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s41, s2, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:338 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s41, s2, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:83 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:84 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s41, s2, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[s41, s2, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:85 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s41, s2, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s41, s2, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:86 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_11, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s41, s2, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[s41, s2, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:844 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_35: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(mul_12); mul_12 = None
slice_36: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_35, 1, 0); slice_35 = None
slice_37: "f32[s41, s2, 192]" = torch.ops.aten.slice.Tensor(slice_36, 2); slice_36 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s41, s2, 32000]" = torch.ops.aten.linear.default(slice_37, p_lm_head_weight); slice_37 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_22: "Sym(s41)", position_ids: "i64[s41, s2]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:116 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_16: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_16, 2); slice_16 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
expand_1: "f32[s41, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_22, -1, 1]); to_1 = sym_size_int_22 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "f32[s41, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:117 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_17: "i64[s41, s2]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_6: "i64[s41, 1, s2]" = torch.ops.aten.unsqueeze.default(slice_17, 1); slice_17 = None
slice_18: "i64[s41, 1, s2]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(slice_18, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(slice_18, torch.float32); slice_18 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_2, to_3); submod_3 = to_2 = to_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:123 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s41, s2, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:124 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s41, s2, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s41, s2, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, to_2: "f32[s41, 48, 1]", to_3: "f32[s41, 1, s2]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:121 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(to_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s41, 48, 1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(to_3, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s41, 1, s2]" = torch.ops.aten.to.dtype(to_3, torch.float32); to_3 = None
matmul: "f32[s41, 48, s2]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[s41, s2, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:122 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s41, s2, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:123 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s41, s2, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:124 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s41, s2, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s41, s2, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s41: VR[1, 1024], s2: VR[2, 4096], s2 + s67: VR[4, 8192], s67: VR[1, 4096]}
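The export succeeded. As a quick sanity check (a sketch, not part of the original script), the exported module can be run on a fresh copy of the inputs and compared with the eager model; the cache is mutated in place, hence the deep copies (copy and torch are already imported at the top of the script).
# hedged sanity check: compare eager and exported outputs on the same inputs
expected = model(**copy.deepcopy(inputs))
got = ep.module()(**copy.deepcopy(inputs))
# depending on the torch/transformers versions, the exported module returns either
# the original output structure or a flat tuple (logits, key cache, value cache)
got_logits = got.logits if hasattr(got, "logits") else got[0]
torch.testing.assert_close(expected.logits, got_logits, rtol=1e-4, atol=1e-4)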
If you run into an error, have a look at the example Export Tiny-LLM with patches.
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 10.329 seconds)