Steal the forward method to guess the dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace call transformers.GenerationMixin.generate(), but we only want to export the model and its forward method. This example shows how to guess the inputs of that method even though the model is executed through generate().
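To make the expected format concrete, here is a minimal, self-contained sketch of how dynamic dimensions are declared with torch.export.Dim when calling torch.export.export. The toy module and the dimension name are illustrative, not part of this example:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

# declare the first dimension (the batch) as dynamic
batch = torch.export.Dim("batch")
ep = torch.export.export(Toy(), (torch.randn(2, 8),), dynamic_shapes=({0: batch},))
print(ep)

With a real LLM, the inputs also include an attention mask and a cache, so the dynamic_shapes specification becomes a nested structure mirroring the inputs, as shown later in this example.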
We focus on the model Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_models.llms import get_tiny_llm
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We wrap the forward method so that it prints the cache dimensions every time it is called.
def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # the inputs are printed again after the call: the cache is modified in place
        print("->", string_type((args, kwargs), with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[830,830:A830.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13926869702104508]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.008778290737794762]]),input_ids:T7s1x1[830,830:A830.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13926869702104508]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.008778290737794762]]),input_ids:T7s1x1[635,635:A635.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.13467349399519055]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008463278700810406]]),input_ids:T7s1x1[635,635:A635.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.13467349399519055]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008463278700810406]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.8197526931762695,6.226877689361572:A-0.14064827534512556]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.5150525569915771:A0.00837615266990807]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.8197526931762695,6.226877689361572:A-0.14064827534512556]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.5150525569915771:A0.00837615266990807]]),input_ids:T7s1x1[2355,2355:A2355.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.8197526931762695,6.226877689361572:A-0.1338236682991515]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.575259804725647:A0.00714091036518817]]),input_ids:T7s1x1[2355,2355:A2355.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.8197526931762695,6.226877689361572:A-0.1338236682991515]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.575259804725647:A0.00714091036518817]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.8197526931762695,6.565155982971191:A-0.13599621835535985]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.575259804725647:A0.007871791164327591]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.8197526931762695,6.565155982971191:A-0.13599621835535985]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.575259804725647:A0.007871791164327591]]),input_ids:T7s1x1[2107,2107:A2107.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.8197526931762695,6.565155982971191:A-0.13278416968655837]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.575259804725647:A0.0076829137735568934]]),input_ids:T7s1x1[2107,2107:A2107.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.8197526931762695,6.565155982971191:A-0.13278416968655837]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.575259804725647:A0.0076829137735568934]]),input_ids:T7s1x1[5376,5376:A5376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.8197526931762695,6.565155982971191:A-0.1292427302262998]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.575259804725647:A0.005242380537907189]]),input_ids:T7s1x1[5376,5376:A5376.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.8197526931762695,6.565155982971191:A-0.1292427302262998]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.575259804725647:A0.005242380537907189]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.8197526931762695,6.565155982971191:A-0.12225772721065671]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.575259804725647:A0.005362754381385078]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.8197526931762695,6.565155982971191:A-0.12225772721065671]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.575259804725647:A0.005362754381385078]]),input_ids:T7s1x1[7458,7458:A7458.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.8197526931762695,6.565155982971191:A-0.12251733124104053]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.575259804725647:A0.00661465932787406]]),input_ids:T7s1x1[7458,7458:A7458.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.8197526931762695,6.565155982971191:A-0.12251733124104053]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.575259804725647:A0.00661465932787406]]),input_ids:T7s1x1[856,856:A856.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.8197526931762695,6.565155982971191:A-0.1233257284622798]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.575259804725647:A0.0066723718805153315]]),input_ids:T7s1x1[856,856:A856.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.8197526931762695,6.565155982971191:A-0.1233257284622798]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.575259804725647:A0.0066723718805153315]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.8197526931762695,6.565155982971191:A-0.12438850456218413]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007809057273613705]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.8197526931762695,6.565155982971191:A-0.12438850456218413]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007809057273613705]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.8197526931762695,6.565155982971191:A-0.12379566539572504]], value_cache=#1[T1s1x1x20x96[-0.7138619422912598,0.7704185843467712:A0.006257893364750089]]),input_ids:T7s1x1[29908,29908:A29908.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.8197526931762695,6.565155982971191:A-0.12379566539572504]], value_cache=#1[T1s1x1x20x96[-0.7138619422912598,0.7704185843467712:A0.006257893364750089]]),input_ids:T7s1x1[29956,29956:A29956.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.8197526931762695,6.565155982971191:A-0.11858383356172006]], value_cache=#1[T1s1x1x21x96[-0.7138619422912598,0.7704185843467712:A0.006566355677001584]]),input_ids:T7s1x1[29956,29956:A29956.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.8197526931762695,6.565155982971191:A-0.11858383356172006]], value_cache=#1[T1s1x1x21x96[-0.7138619422912598,0.7704185843467712:A0.006566355677001584]]),input_ids:T7s1x1[1639,1639:A1639.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.8197526931762695,6.565155982971191:A-0.11352129460079374]], value_cache=#1[T1s1x1x22x96[-0.7138619422912598,0.7704185843467712:A0.004946331530976694]]),input_ids:T7s1x1[1639,1639:A1639.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.8197526931762695,6.565155982971191:A-0.11352129460079374]], value_cache=#1[T1s1x1x22x96[-0.7138619422912598,0.7704185843467712:A0.004946331530976694]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.987462043762207,7.052847862243652:A-0.11419484736873436]], value_cache=#1[T1s1x1x23x96[-0.7138619422912598,0.7704185843467712:A0.005234265772050212]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.987462043762207,7.052847862243652:A-0.11419484736873436]], value_cache=#1[T1s1x1x23x96[-0.7138619422912598,0.7704185843467712:A0.005234265772050212]]),input_ids:T7s1x1[510,510:A510.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.987462043762207,7.052847862243652:A-0.11391830670491901]], value_cache=#1[T1s1x1x24x96[-0.7138619422912598,0.7704185843467712:A0.00490483474958915]]),input_ids:T7s1x1[510,510:A510.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.987462043762207,7.052847862243652:A-0.11391830670491901]], value_cache=#1[T1s1x1x24x96[-0.7138619422912598,0.7704185843467712:A0.00490483474958915]]),input_ids:T7s1x1[448,448:A448.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.987462043762207,7.052847862243652:A-0.11251380358313327]], value_cache=#1[T1s1x1x25x96[-0.7138619422912598,0.7704185843467712:A0.004670132180078023]]),input_ids:T7s1x1[448,448:A448.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.987462043762207,7.052847862243652:A-0.11251380358313327]], value_cache=#1[T1s1x1x25x96[-0.7138619422912598,0.7704185843467712:A0.004670132180078023]]),input_ids:T7s1x1[1724,1724:A1724.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.987462043762207,7.052847862243652:A-0.11230627737766041]], value_cache=#1[T1s1x1x26x96[-0.7138619422912598,0.7704185843467712:A0.0046593989174174896]]),input_ids:T7s1x1[1724,1724:A1724.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.987462043762207,7.052847862243652:A-0.11230627737766041]], value_cache=#1[T1s1x1x26x96[-0.7138619422912598,0.7704185843467712:A0.0046593989174174896]]),input_ids:T7s1x1[2253,2253:A2253.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.987462043762207,7.052847862243652:A-0.10919084175403429]], value_cache=#1[T1s1x1x27x96[-0.7138619422912598,0.7704185843467712:A0.004745883260943909]]),input_ids:T7s1x1[2253,2253:A2253.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.987462043762207,7.052847862243652:A-0.10919084175403429]], value_cache=#1[T1s1x1x27x96[-0.7138619422912598,0.7704185843467712:A0.004745883260943909]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.987462043762207,7.052847862243652:A-0.1025098547061134]], value_cache=#1[T1s1x1x28x96[-0.7138619422912598,0.7704185843467712:A0.005613603910806627]]),input_ids:T7s1x1[982,982:A982.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.987462043762207,7.052847862243652:A-0.1025098547061134]], value_cache=#1[T1s1x1x28x96[-0.7138619422912598,0.7704185843467712:A0.005613603910806627]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.72218656539917,7.052847862243652:A-0.09692871276567383]], value_cache=#1[T1s1x1x29x96[-0.7138619422912598,0.7704185843467712:A0.0060061092495356854]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.72218656539917,7.052847862243652:A-0.09692871276567383]], value_cache=#1[T1s1x1x29x96[-0.7138619422912598,0.7704185843467712:A0.0060061092495356854]]),input_ids:T7s1x1[15649,15649:A15649.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.72218656539917,7.052847862243652:A-0.09255540540483101]], value_cache=#1[T1s1x1x30x96[-0.7138619422912598,0.7704185843467712:A0.005399935928891056]]),input_ids:T7s1x1[15649,15649:A15649.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.72218656539917,7.052847862243652:A-0.09255540540483101]], value_cache=#1[T1s1x1x30x96[-0.7138619422912598,0.7704185843467712:A0.005399935928891056]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.72218656539917,7.052847862243652:A-0.08525467339006687]], value_cache=#1[T1s1x1x31x96[-0.7138619422912598,0.7704185843467712:A0.005672597131955803]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.72218656539917,7.052847862243652:A-0.08525467339006687]], value_cache=#1[T1s1x1x31x96[-0.7138619422912598,0.7704185843467712:A0.005672597131955803]]),input_ids:T7s1x1[367,367:A367.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.72218656539917,7.052847862243652:A-0.08051488317592732]], value_cache=#1[T1s1x1x32x96[-0.7138619422912598,0.7704185843467712:A0.005623818974484607]]),input_ids:T7s1x1[367,367:A367.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.72218656539917,7.052847862243652:A-0.08051488317592732]], value_cache=#1[T1s1x1x32x96[-0.7138619422912598,0.7704185843467712:A0.005623818974484607]]),input_ids:T7s1x1[596,596:A596.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.72218656539917,7.052847862243652:A-0.07817058410283709]], value_cache=#1[T1s1x1x33x96[-0.7138619422912598,0.7704185843467712:A0.0061109000237391025]]),input_ids:T7s1x1[596,596:A596.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.72218656539917,7.052847862243652:A-0.07817058410283709]], value_cache=#1[T1s1x1x33x96[-0.7138619422912598,0.7704185843467712:A0.0061109000237391025]]),input_ids:T7s1x1[1914,1914:A1914.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.72218656539917,7.052847862243652:A-0.07948427112209103]], value_cache=#1[T1s1x1x34x96[-0.7138619422912598,0.7704185843467712:A0.005882191799989467]]),input_ids:T7s1x1[1914,1914:A1914.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.72218656539917,7.052847862243652:A-0.07948427112209103]], value_cache=#1[T1s1x1x34x96[-0.7138619422912598,0.7704185843467712:A0.005882191799989467]]),input_ids:T7s1x1[3699,3699:A3699.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.72218656539917,7.052847862243652:A-0.07514602015667368]], value_cache=#1[T1s1x1x35x96[-0.7138619422912598,0.7704185843467712:A0.005620337307131454]]),input_ids:T7s1x1[3699,3699:A3699.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.72218656539917,7.052847862243652:A-0.07514602015667368]], value_cache=#1[T1s1x1x35x96[-0.7138619422912598,0.7704185843467712:A0.005620337307131454]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.72218656539917,7.052847862243652:A-0.07634669775328533]], value_cache=#1[T1s1x1x36x96[-0.7138619422912598,0.7704185843467712:A0.005670651629194018]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.72218656539917,7.052847862243652:A-0.07634669775328533]], value_cache=#1[T1s1x1x36x96[-0.7138619422912598,0.7704185843467712:A0.005670651629194018]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.72218656539917,7.052847862243652:A-0.07316075543306536]], value_cache=#1[T1s1x1x37x96[-0.7138619422912598,0.7704185843467712:A0.006281428459469434]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.72218656539917,7.052847862243652:A-0.07316075543306536]], value_cache=#1[T1s1x1x37x96[-0.7138619422912598,0.7704185843467712:A0.006281428459469434]]),input_ids:T7s1x1[6295,6295:A6295.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.72218656539917,7.052847862243652:A-0.06949802690050683]], value_cache=#1[T1s1x1x38x96[-0.7138619422912598,0.7704185843467712:A0.006212282752589441]]),input_ids:T7s1x1[6295,6295:A6295.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.72218656539917,7.052847862243652:A-0.06949802690050683]], value_cache=#1[T1s1x1x38x96[-0.7138619422912598,0.7704185843467712:A0.006212282752589441]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.72218656539917,7.052847862243652:A-0.07012538888073665]], value_cache=#1[T1s1x1x39x96[-0.7138619422912598,0.7704185843467712:A0.0064081840467255635]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.72218656539917,7.052847862243652:A-0.07012538888073665]], value_cache=#1[T1s1x1x39x96[-0.7138619422912598,0.7704185843467712:A0.0064081840467255635]]),input_ids:T7s1x1[2020,2020:A2020.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.72218656539917,7.052847862243652:A-0.06562370558746504]], value_cache=#1[T1s1x1x40x96[-0.7138619422912598,0.7704185843467712:A0.006578563628465872]]),input_ids:T7s1x1[2020,2020:A2020.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.72218656539917,7.052847862243652:A-0.06562370558746504]], value_cache=#1[T1s1x1x40x96[-0.7138619422912598,0.7704185843467712:A0.006578563628465872]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.72218656539917,7.052847862243652:A-0.0628541645154012]], value_cache=#1[T1s1x1x41x96[-0.7138619422912598,0.7704185843467712:A0.005612376343678323]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.72218656539917,7.052847862243652:A-0.0628541645154012]], value_cache=#1[T1s1x1x41x96[-0.7138619422912598,0.7704185843467712:A0.005612376343678323]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.72218656539917,7.052847862243652:A-0.06273026103340343]], value_cache=#1[T1s1x1x42x96[-0.7138619422912598,0.7704185843467712:A0.005540032730943985]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.72218656539917,7.052847862243652:A-0.06273026103340343]], value_cache=#1[T1s1x1x42x96[-0.7138619422912598,0.7704185843467712:A0.005540032730943985]]),input_ids:T7s1x1[1925,1925:A1925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.72218656539917,7.052847862243652:A-0.061756996103297775]], value_cache=#1[T1s1x1x43x96[-0.7138619422912598,0.7704185843467712:A0.005460804806473387]]),input_ids:T7s1x1[1925,1925:A1925.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.72218656539917,7.052847862243652:A-0.061756996103297775]], value_cache=#1[T1s1x1x43x96[-0.7138619422912598,0.7704185843467712:A0.005460804806473387]]),input_ids:T7s1x1[1283,1283:A1283.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.72218656539917,7.052847862243652:A-0.060946521573455495]], value_cache=#1[T1s1x1x44x96[-0.7138619422912598,0.7704185843467712:A0.005347465156880775]]),input_ids:T7s1x1[1283,1283:A1283.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.72218656539917,7.052847862243652:A-0.060946521573455495]], value_cache=#1[T1s1x1x44x96[-0.7138619422912598,0.7704185843467712:A0.005347465156880775]]),input_ids:T7s1x1[777,777:A777.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.72218656539917,7.052847862243652:A-0.05882268034876217]], value_cache=#1[T1s1x1x45x96[-0.7138619422912598,0.7704185843467712:A0.00515750024143455]]),input_ids:T7s1x1[777,777:A777.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.72218656539917,7.052847862243652:A-0.05882268034876217]], value_cache=#1[T1s1x1x45x96[-0.7138619422912598,0.7704185843467712:A0.00515750024143455]]),input_ids:T7s1x1[901,901:A901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.72218656539917,7.052847862243652:A-0.05760709743605032]], value_cache=#1[T1s1x1x46x96[-0.7138619422912598,0.7704185843467712:A0.0045914474782276175]]),input_ids:T7s1x1[901,901:A901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.72218656539917,7.052847862243652:A-0.05760709743605032]], value_cache=#1[T1s1x1x46x96[-0.7138619422912598,0.7704185843467712:A0.0045914474782276175]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.72218656539917,7.052847862243652:A-0.05824511129190177]], value_cache=#1[T1s1x1x47x96[-0.7138619422912598,0.7704185843467712:A0.00465187738084796]]),input_ids:T7s1x1[29973,29973:A29973.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.72218656539917,7.052847862243652:A-0.05824511129190177]], value_cache=#1[T1s1x1x47x96[-0.7138619422912598,0.7704185843467712:A0.00465187738084796]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.72218656539917,7.052847862243652:A-0.058899736090880755]], value_cache=#1[T1s1x1x48x96[-0.7138619422912598,0.7704185843467712:A0.004711315192932059]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.72218656539917,7.052847862243652:A-0.058899736090880755]], value_cache=#1[T1s1x1x48x96[-0.7138619422912598,0.7704185843467712:A0.004711315192932059]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.72218656539917,7.052847862243652:A-0.05637273617952703]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A0.004064715622091023]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-- prompt Continue: it rains...
-- answer Continue: it rains... Really I got a great deal of trouble...
"Winter.com - What better way to buy, be your own house?
So, why have we put off some more? I've
Let’s restore the forward method as it was.
model.forward = keep_model_forward
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring a specific architecture can still be exported despite the many changes brought to torch or transformers; a sketch of such a test follows.
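A minimal sketch of such a unit test, assuming pytest (or any test runner) and that the export succeeds with the installed torch and transformers versions; the test name is illustrative:

import copy
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm


def test_export_tiny_llm():
    data = get_tiny_llm()
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    # deepcopy because a forward pass mutates the cache in place
    ep = torch.export.export(
        model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, strict=False
    )
    assert ep is not None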
Let’s create an untrained model from the configuration file config.json with onnx_diagnostic.torch_models.llms.get_tiny_llm(). Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
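The returned dynamic shapes can be inspected directly; a small sketch using pprint (imported above), the exact output depending on the installed versions:

# inspect the dynamic shapes expected by torch.export.export
pprint.pprint(dynamic_shapes)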
Before we run it, we make a copy of the inputs because the cache is modified in place during execution; afterwards it is no longer consistent with the previous input_ids and mask.
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
The outputs:
print("result type", string_type(expected_output, with_shape=True))
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
    ep = torch.export.export(
        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
[_catch_produce_guards_and_solve_constraints] ERRORproduce_guards_and_solve_constraints failed, use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping
fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f3a2d8853a0>
dynamic_shapes={'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>}, 'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]]}
equalities_inputs=EqualityConstraint(warn_only=False, source_pairs=[(TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2))], derived_equalities=[], phantom_symbols=[], relaxed_sources={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), 
index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1)}, _parents={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), 
prop=<TensorProperty.SIZE: 0>, idx=2)}, _defs={})
original_signature=(input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Union[transformers.cache_utils.Cache, List[torch.FloatTensor], NoneType] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[transformers.models.llama.modeling_llama.KwargsForCausalLM]) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]
_is_torch_jit_trace=False
exc=Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['position_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
batch = 2
L['args'][1]['position_ids'].size()[1] = seq_length
gm=<lambda>()
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1):
embedding = torch.ops.aten.embedding.default(arg0_1, arg13_1); arg0_1 = None
sym_size_int = torch.ops.aten.sym_size.int(arg16_1, 2)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg13_1, 1)
add = sym_size_int + sym_size_int_1
arange = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False); sym_size_int = add = None
sym_size_int_2 = torch.ops.aten.sym_size.int(arg14_1, 1)
full = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu = torch.ops.aten.triu.default(full, 1); full = None
arange_1 = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False); sym_size_int_2 = None
reshape = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_ = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1 = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
sym_size_int_5 = torch.ops.aten.sym_size.int(arg13_1, 0); arg13_1 = None
expand = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]); slice_2 = None
clone = torch.ops.aten.clone.default(expand); expand = None
slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4 = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5 = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6 = torch.ops.aten.slice.Tensor(arg14_1, 0, 0, 9223372036854775807); arg14_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_7 = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
to = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2 = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7 = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12 = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13 = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_ = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg12_1, 0); arg12_1 = None
slice_14 = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5 = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1 = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
sym_size_int_13 = torch.ops.aten.sym_size.int(arg15_1, 0)
expand_1 = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]); to_1 = sym_size_int_13 = None
slice_15 = torch.ops.aten.slice.Tensor(arg15_1, 0, 0, 9223372036854775807); arg15_1 = None
unsqueeze_6 = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16 = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
to_2 = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', torch.bfloat16, False, False)
to_3 = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4 = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5 = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
cat = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
cos = torch.ops.aten.cos.default(cat)
sin = torch.ops.aten.sin.default(cat); cat = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None
mul = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
mul_1 = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
to_6 = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7 = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
to_8 = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
pow_1 = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
add_3 = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2 = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
to_9 = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3 = torch.ops.aten.mul.Tensor(arg8_1, to_9); arg8_1 = to_9 = None
linear = torch.ops.aten.linear.default(mul_3, arg1_1); arg1_1 = None
view = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_1, -1, 96]); linear = None
transpose_1 = torch.ops.aten.transpose.int(view, 1, 2); view = None
linear_1 = torch.ops.aten.linear.default(mul_3, arg2_1); arg2_1 = None
view_1 = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_1, -1, 96]); linear_1 = None
transpose_2 = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
linear_2 = torch.ops.aten.linear.default(mul_3, arg3_1); mul_3 = arg3_1 = None
view_2 = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_1, -1, 96]); linear_2 = sym_size_int_5 = None
transpose_3 = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
unsqueeze_7 = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8 = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4 = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_17 = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18 = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1 = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4 = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6 = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_19 = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20 = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1 = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2 = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5 = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
cat_3 = torch.ops.aten.cat.default([arg16_1, add_5], -2); arg16_1 = add_5 = None
cat_4 = torch.ops.aten.cat.default([arg17_1, transpose_3], -2); arg17_1 = transpose_3 = None
slice_21 = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22 = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_9 = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
sym_size_int_16 = torch.ops.aten.sym_size.int(cat_3, 2)
slice_23 = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_24 = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2 = torch.ops.aten.expand.default(slice_24, [2, 1, 2, sym_size_int_16, 96]); slice_24 = None
reshape_1 = torch.ops.aten.reshape.default(expand_2, [2, 2, sym_size_int_16, 96]); expand_2 = sym_size_int_16 = None
slice_25 = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26 = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_10 = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
sym_size_int_17 = torch.ops.aten.sym_size.int(cat_4, 2)
slice_27 = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_28 = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3 = torch.ops.aten.expand.default(slice_28, [2, 1, 2, sym_size_int_17, 96]); slice_28 = None
reshape_2 = torch.ops.aten.reshape.default(expand_3, [2, 2, sym_size_int_17, 96]); expand_3 = sym_size_int_17 = None
slice_29 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30 = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31 = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1 = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2 = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4 = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3 = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
reshape_3 = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_1, -1]); contiguous_3 = sym_size_int_1 = None
linear_3 = torch.ops.aten.linear.default(reshape_3, arg4_1); reshape_3 = arg4_1 = None
add_7 = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
to_10 = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
pow_2 = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
add_8 = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1 = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8 = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
to_11 = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9 = torch.ops.aten.mul.Tensor(arg9_1, to_11); arg9_1 = to_11 = None
linear_4 = torch.ops.aten.linear.default(mul_9, arg5_1); arg5_1 = None
silu = torch.ops.aten.silu.default(linear_4); linear_4 = None
linear_5 = torch.ops.aten.linear.default(mul_9, arg6_1); mul_9 = arg6_1 = None
mul_10 = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
linear_6 = torch.ops.aten.linear.default(mul_10, arg7_1); mul_10 = arg7_1 = None
add_9 = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
to_12 = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
pow_3 = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2 = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
add_10 = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2 = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11 = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
to_13 = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12 = torch.ops.aten.mul.Tensor(arg10_1, to_13); arg10_1 = to_13 = None
slice_32 = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33 = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34 = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
linear_7 = torch.ops.aten.linear.default(slice_34, arg11_1); slice_34 = arg11_1 = None
return (linear_7, cat_3, cat_4)
# To see more debug info, please use `graph_module.print_readable()`
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[2, s1]", attention_mask: "i64[2, s1 + s7]", position_ids: "i64[2, s1]", past_key_values_key_cache_0: "f32[2, 1, s7, 96]", past_key_values_value_cache_0: "f32[2, 1, s7, 96]"):
#
sym_size_int_19: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_20: "Sym(s7)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[2, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s7)" = sym_size_int_20 + sym_size_int_19
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False); sym_size_int_20 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s7]" = torch.ops.aten.full.default([sym_size_int_19, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s7]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s7]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s7]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s7]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze: "f32[1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
sym_size_int_5: "Sym(2)" = torch.ops.aten.sym_size.int(input_ids, 0); input_ids = None
expand: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]); slice_2 = None
clone: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[2, s1 + s7]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_2: "i64[2, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_3: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_7: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
to: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[2, 1, s1, s1 + s7]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[2, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_19, -1, 96]); linear = None
transpose_1: "f32[2, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_19, -1, 96]); linear_1 = None
transpose_2: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_19, -1, 96]); linear_2 = sym_size_int_5 = None
transpose_3: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_7: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_17: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[2, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[2, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4: "f32[2, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_19: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[2, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[2, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5: "f32[2, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_9: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_24: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_24, [2, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_2, [2, 2, add, 96]); expand_2 = None
slice_25: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_10: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_28: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_28, [2, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_3, [2, 2, add, 96]); expand_3 = add = None
slice_29: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[2, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[2, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[2, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[2, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[2, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_19, -1]); contiguous_3 = sym_size_int_19 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[2, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[2, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[2, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[2, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", position_ids: "i64[2, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
sym_size_int_13: "Sym(2)" = torch.ops.aten.sym_size.int(position_ids, 0)
expand_1: "f32[2, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]); to_1 = sym_size_int_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[2, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_6: "i64[2, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[2, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
to_2: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[2, s1, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[2, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[2, 48, 1]", to_2: "f32[2, 1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[2, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[2, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[2, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[2, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[2, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[2, s1, 96]" = torch.ops.aten.cos.default(cat)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[2, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s1: VR[2, 4096], s1 + s7: VR[4, 8192], s7: VR[1, 4096]}
Back to the original model¶
Let’s use the same dummy inputs, but this time with the downloaded model.
Dummy inputs and dynamic shapes are created by the function
onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
Let’s print the inputs.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
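The mapping shown next is the dynamic_shapes dictionary returned by the same call. The exact print statement is not part of this excerpt; a minimal sketch, assuming the pprint module imported at the top of the example, would be:

# hypothetical display of the dynamic shapes returned by get_tiny_llm()
pprint.pprint(dynamic_shapes)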
{'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>},
'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}],
[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]],
'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}}
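The export below reuses a variable named cloned_inputs, which is not defined in this excerpt. It presumably holds a deep copy of the dummy inputs so that running the forward pass during export does not mutate the DynamicCache in place. A minimal sketch, assuming the copy module imported earlier:

# assumption: cloned_inputs is a deep copy of the dummy inputs,
# so torch.export.export works on its own copy of the cache
cloned_inputs = copy.deepcopy(inputs)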
And let’s finally export.
try:
ep = torch.export.export(
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
[_catch_produce_guards_and_solve_constraints] ERRORproduce_guards_and_solve_constraints failed, use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping
fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f3a2d3946e0>
dynamic_shapes={'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>}, 'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]]}
equalities_inputs=EqualityConstraint(warn_only=False, source_pairs=[(TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0)), (TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2))], derived_equalities=[], phantom_symbols=[], relaxed_sources={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), 
index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=1)}, _parents={TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='attention_mask', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='position_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='input_ids', index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=0), TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='value_cache', index_is_slice=False), index=0, index_is_slice=False), prop=<TensorProperty.SIZE: 0>, idx=2): TensorPropertySource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=GetItemSource(base=LocalSource(local_name='args', is_input=False, dynamism=None, is_derefed_cell_contents=False), index=1, index_is_slice=False), index='past_key_values', index_is_slice=False), index='key_cache', index_is_slice=False), index=0, index_is_slice=False), 
prop=<TensorProperty.SIZE: 0>, idx=2)}, _defs={})
original_signature=(input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Union[transformers.cache_utils.Cache, List[torch.FloatTensor], NoneType] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[transformers.models.llama.modeling_llama.KwargsForCausalLM]) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]
_is_torch_jit_trace=False
exc=Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['args'][1]['input_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['attention_mask'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['position_ids'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (2).
Suggested fixes:
batch = 2
L['args'][1]['position_ids'].size()[1] = seq_length
gm=<lambda>()
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1):
embedding = torch.ops.aten.embedding.default(arg0_1, arg13_1); arg0_1 = None
sym_size_int = torch.ops.aten.sym_size.int(arg16_1, 2)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg13_1, 1)
add = sym_size_int + sym_size_int_1
arange = torch.ops.aten.arange.start(sym_size_int, add, device = device(type='cpu'), pin_memory = False); sym_size_int = add = None
sym_size_int_2 = torch.ops.aten.sym_size.int(arg14_1, 1)
full = torch.ops.aten.full.default([sym_size_int_1, sym_size_int_2], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu = torch.ops.aten.triu.default(full, 1); full = None
arange_1 = torch.ops.aten.arange.default(sym_size_int_2, device = device(type='cpu'), pin_memory = False); sym_size_int_2 = None
reshape = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_ = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1 = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2 = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
sym_size_int_5 = torch.ops.aten.sym_size.int(arg13_1, 0); arg13_1 = None
expand = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]); slice_2 = None
clone = torch.ops.aten.clone.default(expand); expand = None
slice_3 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4 = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5 = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6 = torch.ops.aten.slice.Tensor(arg14_1, 0, 0, 9223372036854775807); arg14_1 = None
unsqueeze_2 = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_7 = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
to = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2 = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7 = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12 = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13 = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_ = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
_set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg12_1, 0); arg12_1 = None
slice_14 = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5 = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1 = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
sym_size_int_13 = torch.ops.aten.sym_size.int(arg15_1, 0)
expand_1 = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]); to_1 = sym_size_int_13 = None
slice_15 = torch.ops.aten.slice.Tensor(arg15_1, 0, 0, 9223372036854775807); arg15_1 = None
unsqueeze_6 = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16 = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
to_2 = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', torch.bfloat16, False, False)
to_3 = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4 = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5 = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
cat = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
cos = torch.ops.aten.cos.default(cat)
sin = torch.ops.aten.sin.default(cat); cat = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None
mul = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
mul_1 = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
to_6 = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7 = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
_set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
to_8 = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
pow_1 = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
add_3 = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2 = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
to_9 = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3 = torch.ops.aten.mul.Tensor(arg8_1, to_9); arg8_1 = to_9 = None
linear = torch.ops.aten.linear.default(mul_3, arg1_1); arg1_1 = None
view = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_1, -1, 96]); linear = None
transpose_1 = torch.ops.aten.transpose.int(view, 1, 2); view = None
linear_1 = torch.ops.aten.linear.default(mul_3, arg2_1); arg2_1 = None
view_1 = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_1, -1, 96]); linear_1 = None
transpose_2 = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
linear_2 = torch.ops.aten.linear.default(mul_3, arg3_1); mul_3 = arg3_1 = None
view_2 = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_1, -1, 96]); linear_2 = sym_size_int_5 = None
transpose_3 = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
unsqueeze_7 = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8 = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4 = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_17 = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18 = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1 = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4 = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6 = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_19 = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20 = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1 = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2 = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5 = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
cat_3 = torch.ops.aten.cat.default([arg16_1, add_5], -2); arg16_1 = add_5 = None
cat_4 = torch.ops.aten.cat.default([arg17_1, transpose_3], -2); arg17_1 = transpose_3 = None
slice_21 = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22 = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_9 = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
sym_size_int_16 = torch.ops.aten.sym_size.int(cat_3, 2)
slice_23 = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_24 = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2 = torch.ops.aten.expand.default(slice_24, [2, 1, 2, sym_size_int_16, 96]); slice_24 = None
reshape_1 = torch.ops.aten.reshape.default(expand_2, [2, 2, sym_size_int_16, 96]); expand_2 = sym_size_int_16 = None
slice_25 = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26 = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_10 = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
sym_size_int_17 = torch.ops.aten.sym_size.int(cat_4, 2)
slice_27 = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_28 = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3 = torch.ops.aten.expand.default(slice_28, [2, 1, 2, sym_size_int_17, 96]); slice_28 = None
reshape_2 = torch.ops.aten.reshape.default(expand_3, [2, 2, sym_size_int_17, 96]); expand_3 = sym_size_int_17 = None
slice_29 = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30 = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31 = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1 = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2 = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4 = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3 = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
reshape_3 = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_1, -1]); contiguous_3 = sym_size_int_1 = None
linear_3 = torch.ops.aten.linear.default(reshape_3, arg4_1); reshape_3 = arg4_1 = None
add_7 = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
to_10 = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
pow_2 = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
add_8 = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1 = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8 = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
to_11 = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9 = torch.ops.aten.mul.Tensor(arg9_1, to_11); arg9_1 = to_11 = None
linear_4 = torch.ops.aten.linear.default(mul_9, arg5_1); arg5_1 = None
silu = torch.ops.aten.silu.default(linear_4); linear_4 = None
linear_5 = torch.ops.aten.linear.default(mul_9, arg6_1); mul_9 = arg6_1 = None
mul_10 = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
linear_6 = torch.ops.aten.linear.default(mul_10, arg7_1); mul_10 = arg7_1 = None
add_9 = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
to_12 = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
pow_3 = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2 = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
add_10 = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2 = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11 = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
to_13 = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12 = torch.ops.aten.mul.Tensor(arg10_1, to_13); arg10_1 = to_13 = None
slice_32 = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33 = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34 = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
linear_7 = torch.ops.aten.linear.default(slice_34, arg11_1); slice_34 = arg11_1 = None
return (linear_7, cat_3, cat_4)
# To see more debug info, please use `graph_module.print_readable()`
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[2, s1]", attention_mask: "i64[2, s1 + s7]", position_ids: "i64[2, s1]", past_key_values_key_cache_0: "f32[2, 1, s7, 96]", past_key_values_value_cache_0: "f32[2, 1, s7, 96]"):
#
sym_size_int_19: "Sym(s1)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_20: "Sym(s7)" = torch.ops.aten.sym_size.int(past_key_values_key_cache_0, 2)
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:190 in forward, code: return F.embedding(
embedding: "f32[2, s1, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:565 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s1 + s7)" = sym_size_int_20 + sym_size_int_19
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:564 in forward, code: cache_position = torch.arange(
arange: "i64[s1]" = torch.ops.aten.arange.start(sym_size_int_20, add, device = device(type='cpu'), pin_memory = False); sym_size_int_20 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:571 in forward, code: causal_mask = self._update_causal_mask(
full: "f32[s1, s1 + s7]" = torch.ops.aten.full.default([sym_size_int_19, add], -3.4028234663852886e+38, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
triu: "f32[s1, s1 + s7]" = torch.ops.aten.triu.default(full, 1); full = None
arange_1: "i64[s1 + s7]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
reshape: "i64[s1, 1]" = torch.ops.aten.reshape.default(arange, [-1, 1]); arange = None
gt: "b8[s1, s1 + s7]" = torch.ops.aten.gt.Tensor(arange_1, reshape); arange_1 = reshape = None
mul_: "f32[s1, s1 + s7]" = torch.ops.aten.mul_.Tensor(triu, gt); triu = gt = None
unsqueeze: "f32[1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(mul_, 0); mul_ = None
unsqueeze_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze, 1); unsqueeze = None
slice_1: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_1, 2, 0, 9223372036854775807); unsqueeze_1 = None
slice_2: "f32[1, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_1, 3, 0, 9223372036854775807); slice_1 = None
sym_size_int_5: "Sym(2)" = torch.ops.aten.sym_size.int(input_ids, 0); input_ids = None
expand: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.expand.default(slice_2, [sym_size_int_5, 1, -1, -1]); slice_2 = None
clone: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.clone.default(expand); expand = None
slice_3: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_4: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_3, 1, 0, 9223372036854775807); slice_3 = None
slice_5: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_4, 2, 0, 9223372036854775807); slice_4 = None
slice_6: "i64[2, s1 + s7]" = torch.ops.aten.slice.Tensor(attention_mask, 0, 0, 9223372036854775807); attention_mask = None
unsqueeze_2: "i64[2, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(slice_6, 1); slice_6 = None
unsqueeze_3: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2); unsqueeze_2 = None
slice_7: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.slice.Tensor(unsqueeze_3, 3, 0, 9223372036854775807); unsqueeze_3 = None
to: "i64[2, 1, 1, s1 + s7]" = torch.ops.aten.to.dtype_layout(slice_7, dtype = torch.int64, layout = torch.strided, device = device(type='cpu')); slice_7 = None
add_2: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.add.Tensor(slice_5, to); slice_5 = to = None
eq_7: "b8[2, 1, s1, s1 + s7]" = torch.ops.aten.eq.Scalar(add_2, 0); add_2 = None
slice_8: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_9: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_8, 1, 0, 9223372036854775807); slice_8 = None
slice_10: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_9, 2, 0, 9223372036854775807); slice_9 = None
masked_fill: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.masked_fill.Scalar(slice_10, eq_7, -3.4028234663852886e+38); slice_10 = eq_7 = None
slice_11: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807)
slice_12: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807); slice_11 = None
slice_13: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_12, 2, 0, 9223372036854775807); slice_12 = None
copy_: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.copy_.default(slice_13, masked_fill); slice_13 = masked_fill = copy_ = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[0]
to_7: "f32[2, s1, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_8: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_8, 2)
mean: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_3: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_3); add_3 = None
mul_2: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_8, rsqrt); rsqrt = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_9: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_9); p_model_layers_0_input_layernorm_weight = to_9 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:277 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[2, s1, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_5, sym_size_int_19, -1, 96]); linear = None
transpose_1: "f32[2, 2, s1, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:278 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_5, sym_size_int_19, -1, 96]); linear_1 = None
transpose_2: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[2, s1, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:279 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[2, s1, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_5, sym_size_int_19, -1, 96]); linear_2 = sym_size_int_5 = None
transpose_3: "f32[2, 1, s1, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:282 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_7: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_6, 1); to_6 = None
unsqueeze_8: "f32[2, 1, s1, 96]" = torch.ops.aten.unsqueeze.default(to_7, 1); to_7 = None
mul_4: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_7)
slice_17: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_18: "f32[2, 2, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[2, 2, s1, 48]" = torch.ops.aten.neg.default(slice_18); slice_18 = None
cat_1: "f32[2, 2, s1, 96]" = torch.ops.aten.cat.default([neg, slice_17], -1); neg = slice_17 = None
mul_5: "f32[2, 2, s1, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_8); cat_1 = None
add_4: "f32[2, 2, s1, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_7); unsqueeze_7 = None
slice_19: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_20: "f32[2, 1, s1, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[2, 1, s1, 48]" = torch.ops.aten.neg.default(slice_20); slice_20 = None
cat_2: "f32[2, 1, s1, 96]" = torch.ops.aten.cat.default([neg_1, slice_19], -1); neg_1 = slice_19 = None
mul_7: "f32[2, 1, s1, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_8); cat_2 = unsqueeze_8 = None
add_5: "f32[2, 1, s1, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:287 in forward, code: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_key_cache_0, add_5], -2); past_key_values_key_cache_0 = add_5 = None
cat_4: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.cat.default([past_key_values_value_cache_0, transpose_3], -2); past_key_values_value_cache_0 = transpose_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:299 in forward, code: attn_output, attn_weights = attention_interface(
slice_21: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
slice_22: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_21, 1, 0, 9223372036854775807); slice_21 = None
unsqueeze_9: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_22, 2); slice_22 = None
slice_23: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_9, 3, 0, 9223372036854775807); unsqueeze_9 = None
slice_24: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_23, 4, 0, 9223372036854775807); slice_23 = None
expand_2: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_24, [2, 1, 2, add, 96]); slice_24 = None
reshape_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_2, [2, 2, add, 96]); expand_2 = None
slice_25: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
slice_26: "f32[2, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_25, 1, 0, 9223372036854775807); slice_25 = None
unsqueeze_10: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.unsqueeze.default(slice_26, 2); slice_26 = None
slice_27: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_10, 3, 0, 9223372036854775807); unsqueeze_10 = None
slice_28: "f32[2, 1, 1, s1 + s7, 96]" = torch.ops.aten.slice.Tensor(slice_27, 4, 0, 9223372036854775807); slice_27 = None
expand_3: "f32[2, 1, 2, s1 + s7, 96]" = torch.ops.aten.expand.default(slice_28, [2, 1, 2, add, 96]); slice_28 = None
reshape_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.reshape.default(expand_3, [2, 2, add, 96]); expand_3 = add = None
slice_29: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(clone, 0, 0, 9223372036854775807); clone = None
slice_30: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_29, 1, 0, 9223372036854775807); slice_29 = None
slice_31: "f32[2, 1, s1, s1 + s7]" = torch.ops.aten.slice.Tensor(slice_30, 2, 0, 9223372036854775807); slice_30 = None
contiguous: "f32[2, 2, s1, 96]" = torch.ops.aten.contiguous.default(add_4); add_4 = None
contiguous_1: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_1); reshape_1 = None
contiguous_2: "f32[2, 2, s1 + s7, 96]" = torch.ops.aten.contiguous.default(reshape_2); reshape_2 = None
scaled_dot_product_attention: "f32[2, 2, s1, 96]" = torch.ops.aten.scaled_dot_product_attention.default(contiguous, contiguous_1, contiguous_2, slice_31, scale = 0.10206207261596575); contiguous = contiguous_1 = contiguous_2 = slice_31 = None
transpose_4: "f32[2, s1, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
contiguous_3: "f32[2, s1, 2, 96]" = torch.ops.aten.contiguous.default(transpose_4); transpose_4 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:310 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_3: "f32[2, s1, 192]" = torch.ops.aten.reshape.default(contiguous_3, [2, sym_size_int_19, -1]); contiguous_3 = sym_size_int_19 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[2, s1, 192]" = torch.ops.aten.linear.default(reshape_3, p_model_layers_0_self_attn_o_proj_weight); reshape_3 = p_model_layers_0_self_attn_o_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:354 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_8, linear_3); to_8 = linear_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_10: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean_1: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_8: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt_1); rsqrt_1 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_11: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_8, torch.float32); mul_8 = None
mul_9: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_11); p_model_layers_0_post_attention_layernorm_weight = to_11 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/activation.py:432 in forward, code: return F.silu(input, inplace=self.inplace)
silu: "f32[2, s1, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[2, s1, 1024]" = torch.ops.aten.linear.default(mul_9, p_model_layers_0_mlp_up_proj_weight); mul_9 = p_model_layers_0_mlp_up_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:197 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_10: "f32[2, s1, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[2, s1, 192]" = torch.ops.aten.linear.default(mul_10, p_model_layers_0_mlp_down_proj_weight); mul_10 = p_model_layers_0_mlp_down_proj_weight = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:360 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[2, s1, 192]" = torch.ops.aten.add.Tensor(to_10, linear_6); to_10 = linear_6 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:78 in forward, code: hidden_states = hidden_states.to(torch.float32)
to_12: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:79 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[2, s1, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_2: "f32[2, s1, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:80 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[2, s1, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[2, s1, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_11: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_2); to_12 = rsqrt_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:81 in forward, code: return self.weight * hidden_states.to(input_dtype)
to_13: "f32[2, s1, 192]" = torch.ops.aten.to.dtype(mul_11, torch.float32); mul_11 = None
mul_12: "f32[2, s1, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_13); p_model_norm_weight = to_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:870 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_32: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(mul_12, 0, 0, 9223372036854775807); mul_12 = None
slice_33: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_32, 1, 0, 9223372036854775807); slice_32 = None
slice_34: "f32[2, s1, 192]" = torch.ops.aten.slice.Tensor(slice_33, 2, 0, 9223372036854775807); slice_33 = None
# File: /home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[2, s1, 32000]" = torch.ops.aten.linear.default(slice_34, p_lm_head_weight); slice_34 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", position_ids: "i64[2, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
unsqueeze_4: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
slice_14: "f32[1, 48]" = torch.ops.aten.slice.Tensor(unsqueeze_4, 1, 0, 9223372036854775807); unsqueeze_4 = None
unsqueeze_5: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(slice_14, 2); slice_14 = None
to_1: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_5, torch.float32); unsqueeze_5 = None
sym_size_int_13: "Sym(2)" = torch.ops.aten.sym_size.int(position_ids, 0)
expand_1: "f32[2, 48, 1]" = torch.ops.aten.expand.default(to_1, [sym_size_int_13, -1, 1]); to_1 = sym_size_int_13 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:134 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_15: "i64[2, s1]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_6: "i64[2, 1, s1]" = torch.ops.aten.unsqueeze.default(slice_15, 1); slice_15 = None
slice_16: "i64[2, 1, s1]" = torch.ops.aten.slice.Tensor(unsqueeze_6, 2, 0, 9223372036854775807); unsqueeze_6 = None
to_2: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(slice_16, torch.float32); slice_16 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, expand_1, to_2); submod_3 = expand_1 = to_2 = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[2, s1, 96]" = wrap_with_autocast[0]
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[2, s1, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:145 in forward, code: cos = cos * self.attention_scaling
mul: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:146 in forward, code: sin = sin * self.attention_scaling
mul_1: "f32[2, s1, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:148 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_6: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
to_7: "f32[2, s1, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_6, to_7)
class submod_1(torch.nn.Module):
def forward(self, expand_1: "f32[2, 48, 1]", to_2: "f32[2, 1, s1]"):
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:139 in forward, code: freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2)
to_3: "f32[2, 48, 1]" = torch.ops.aten.to.dtype(expand_1, torch.float32); expand_1 = None
to_4: "f32[2, 48, 1]" = torch.ops.aten.to.dtype_layout(to_3, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); to_3 = None
to_5: "f32[2, 1, s1]" = torch.ops.aten.to.dtype(to_2, torch.float32); to_2 = None
matmul: "f32[2, 48, s1]" = torch.ops.aten.matmul.default(to_4, to_5); to_4 = to_5 = None
transpose: "f32[2, s1, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:140 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[2, s1, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:141 in forward, code: cos = emb.cos()
cos: "f32[2, s1, 96]" = torch.ops.aten.cos.default(cat)
# File: /home/xadupre/github/transformers/src/transformers/models/llama/modeling_llama.py:142 in forward, code: sin = emb.sin()
sin: "f32[2, s1, 96]" = torch.ops.aten.sin.default(cat); cat = None
return (cos, sin)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_cache_0: USER_INPUT
past_key_values_value_cache_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s1: VR[2, 4096], s1 + s7: VR[4, 8192], s7: VR[1, 4096]}
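The range constraints confirm the dynamic dimensions the exporter discovered: s1 is the new sequence length, s7 the length of the cached keys and values, and the attention mask spans s1 + s7. As a rough illustration only, the sketch below marks those axes dynamic with torch.export.Dim.DYNAMIC when calling torch.export.export directly; the keyword names mirror the inputs captured above, and the nested specification for past_key_values assumes the DynamicCache pytree layout (one list for the key tensors, one for the value tensors). It is a sketch under those assumptions, not the exact recipe used by this example.

import torch

DYN = torch.export.Dim.DYNAMIC  # let torch.export infer ranges and relations

# Assumed layout: one entry per keyword argument of forward,
# dynamic axes indexed by position.
dynamic_shapes = {
    "input_ids": {0: DYN, 1: DYN},       # batch, sequence length (s1)
    "attention_mask": {0: DYN, 1: DYN},  # batch, s1 + s7
    "position_ids": {0: DYN, 1: DYN},    # batch, s1
    # DynamicCache flattens into ([keys per layer], [values per layer]);
    # Tiny-LLM has a single layer and axis 2 is the cache length (s7).
    "past_key_values": [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
}

# ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)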
If you run into an error during the export, look at the example Export Tiny-LLM with patches.
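For reference, that example wraps the export call in a patching context manager from onnx_diagnostic; the helper name and keyword shown below are assumptions based on that package, and the linked example remains the authoritative source.

# Assumed helper from onnx_diagnostic (see the linked example for the exact
# API): the context manager temporarily patches transformers classes such as
# DynamicCache so that torch.export.export can trace them.
from onnx_diagnostic.torch_export_patches import torch_export_patches

# with torch_export_patches(patch_transformers=True):
#     ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=dynamic_shapes)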
doc.plot_legend("Tiny-LLM fails", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 2.703 seconds)