Steal the forward method to guess inputs and dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs; that is why dynamic shapes need to be
specified when an LLM is exported with torch.export.export().
Most of the examples on HuggingFace use the method
transformers.GenerationMixin.generate(), but we only want to
export the model and its forward method.
This example shows how to guess the inputs of this method even though the model
is executed through generate().
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We overwrite the forward method so that it prints its inputs and outputs, including the cache dimensions.
def _forward_(*args, _f=None, **kwargs):
    # _f holds the original forward method stolen from the model
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[2866,2866:A2866.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.076018333435059,16.944217681884766:A-2.592398094831733],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.1302779765439347]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007744434695858352]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.20236587524414,6.324185371398926:A-8.229752516841982],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1353976684111937]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.008736979494627425]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.76398468017578,3.467536449432373:A-9.690286429880652],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.490959167480469,6.226877689361572:A-0.15163987638218465]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.009184762458792675]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.490959167480469,6.226877689361572:A-0.15163987638218465]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.009184762458792675]]),input_ids:T7s1x1[2446,2446:A2446.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.323726654052734,5.4670844078063965:A-7.865214796903078],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.490959167480469,6.226877689361572:A-0.14727237609086943]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.009050533552087674]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.490959167480469,6.226877689361572:A-0.14727237609086943]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.009050533552087674]]),input_ids:T7s1x1[29991,29991:A29991.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.184151649475098,7.759481430053711:A-9.368468943121377],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.490959167480469,6.226877689361572:A-0.14272359549319774]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.009175964193731581]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.490959167480469,6.226877689361572:A-0.14272359549319774]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.009175964193731581]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.932031631469727,9.351970672607422:A-3.3105489786481486],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.511512279510498,6.282632827758789:A-0.1452894162609612]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.010539780633421071]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.511512279510498,6.282632827758789:A-0.1452894162609612]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.010539780633421071]]),input_ids:T7s1x1[29940,29940:A29940.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.800661087036133,11.163089752197266:A-1.0843636879529803],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.511512279510498,6.282632827758789:A-0.14685925084462118]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.009878809622579057]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.511512279510498,6.282632827758789:A-0.14685925084462118]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.009878809622579057]]),input_ids:T7s1x1[711,711:A711.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.302396774291992,12.005083084106445:A-2.538844227918424],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.511512279510498,6.282632827758789:A-0.14132002103807886]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.007490537232939687]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.511512279510498,6.282632827758789:A-0.14132002103807886]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.007490537232939687]]),input_ids:T7s1x1[1486,1486:A1486.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.352119445800781,4.841743469238281:A-7.302925526224077],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.511512279510498,6.282632827758789:A-0.12528488661786574]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.006690134637219493]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.511512279510498,6.282632827758789:A-0.12528488661786574]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.006690134637219493]]),input_ids:T7s1x1[29901,29901:A29901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.473407745361328,3.7950451374053955:A-9.055886989319232],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.511512279510498,6.282632827758789:A-0.12704408231224484]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.007187216785656279]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.511512279510498,6.282632827758789:A-0.12704408231224484]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.007187216785656279]]),input_ids:T7s1x1[6439,6439:A6439.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.637372970581055,8.48501205444336:A-6.666262734174262],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.511512279510498,6.282632827758789:A-0.12709062617687505]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007587694405126393]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.511512279510498,6.282632827758789:A-0.12709062617687505]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007587694405126393]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.878786087036133,6.441829204559326:A-7.103221681401133],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.511512279510498,6.282632827758789:A-0.12636854723211097]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.007900931346064984]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.511512279510498,6.282632827758789:A-0.12636854723211097]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.007900931346064984]]),input_ids:T7s1x1[590,590:A590.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.573440551757812,4.970290184020996:A-8.706373463791795],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.10756778717041,6.282632827758789:A-0.12324468264491842]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.007973992656980766]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.10756778717041,6.282632827758789:A-0.12324468264491842]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.007973992656980766]]),input_ids:T7s1x1[7339,7339:A7339.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.548786163330078,8.159546852111816:A-6.970149930046405],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.10756778717041,6.282632827758789:A-0.11999419265422065]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.008538347612880953]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.10756778717041,6.282632827758789:A-0.11999419265422065]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.008538347612880953]]),input_ids:T7s1x1[16846,16846:A16846.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.448102951049805,13.524051666259766:A-3.142901279407088],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.10756778717041,6.282632827758789:A-0.12003202550297737]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.008957836799689412]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.10756778717041,6.282632827758789:A-0.12003202550297737]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.008957836799689412]]),input_ids:T7s1x1[29876,29876:A29876.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.23011302947998,5.398580551147461:A-7.507193362160586],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.10756778717041,6.282632827758789:A-0.11118035720407231]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.00852774141735482]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.10756778717041,6.282632827758789:A-0.11118035720407231]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.00852774141735482]]),input_ids:T7s1x1[29991,29991:A29991.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.851299285888672,5.789419651031494:A-9.332220061377622],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.10756778717041,6.282632827758789:A-0.11118051617323847]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.008613877036398966]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.10756778717041,6.282632827758789:A-0.11118051617323847]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.008613877036398966]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.931562423706055,11.070415496826172:A-3.9143916586777197],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.10756778717041,6.282632827758789:A-0.10740538675240831]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.00708381281466385]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.10756778717041,6.282632827758789:A-0.10740538675240831]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.00708381281466385]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.610342025756836,5.410274505615234:A-9.61340464850422],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.10756778717041,6.282632827758789:A-0.10291547219646789]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.006612986321496139]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.10756778717041,6.282632827758789:A-0.10291547219646789]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.006612986321496139]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.976022720336914,7.121726036071777:A-9.940053578583523],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.410092830657959,6.282632827758789:A-0.10118713294819567]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.005286526548491652]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.410092830657959,6.282632827758789:A-0.10118713294819567]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.005286526548491652]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.976371765136719,11.98576545715332:A-8.17162115960475],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.410092830657959,6.282632827758789:A-0.10303693728038242]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.004910146236444893]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.410092830657959,6.282632827758789:A-0.10303693728038242]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.004910146236444893]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.45410919189453,4.403109550476074:A-10.726719597123564],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.410092830657959,6.282632827758789:A-0.10299661416336474]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.003728878451142413]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.410092830657959,6.282632827758789:A-0.10299661416336474]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.003728878451142413]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.6409077644348145,10.016256332397461:A-2.7200095351033378],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.410092830657959,6.282632827758789:A-0.09867025076805817]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.004520507996246995]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.410092830657959,6.282632827758789:A-0.09867025076805817]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.004520507996246995]]),input_ids:T7s1x1[29924,29924:A29924.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.641586303710938,10.785520553588867:A-1.2355172414700502],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.410092830657959,6.282632827758789:A-0.09192560731022088]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004210412411296716]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.410092830657959,6.282632827758789:A-0.09192560731022088]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004210412411296716]]),input_ids:T7s1x1[858,858:A858.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.559972763061523,11.023469924926758:A-3.6022008529237937],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.410092830657959,6.282632827758789:A-0.0856626743322219]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.0039097890680750425]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.410092830657959,6.282632827758789:A-0.0856626743322219]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.0039097890680750425]]),input_ids:T7s1x1[279,279:A279.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.961400985717773,5.305764198303223:A-7.576864839469082],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.410092830657959,6.282632827758789:A-0.07616603146611763]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.00391728923921387]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.410092830657959,6.282632827758789:A-0.07616603146611763]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.00391728923921387]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.459526062011719,12.080596923828125:A-5.501866738987156],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.410092830657959,6.282632827758789:A-0.07680716320724709]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0029148583258445018]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.410092830657959,6.282632827758789:A-0.07680716320724709]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0029148583258445018]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.406972885131836,6.467724323272705:A-9.910090550930239],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.839562892913818,6.282632827758789:A-0.07891964309369036]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.002765875485099261]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.839562892913818,6.282632827758789:A-0.07891964309369036]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.002765875485099261]]),input_ids:T7s1x1[29946,29946:A29946.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.70451545715332,7.5244140625:A-10.066174281597137],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.839562892913818,6.67236852645874:A-0.07827593383941299]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0024507929046735256]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.839562892913818,6.67236852645874:A-0.07827593383941299]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0024507929046735256]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.15230369567871,3.797675848007202:A-11.954967962278053],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.839562892913818,6.67236852645874:A-0.07536061911451454]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0022381798676856866]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.839562892913818,6.67236852645874:A-0.07536061911451454]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0022381798676856866]]),input_ids:T7s1x1[448,448:A448.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.453864097595215,5.4443745613098145:A-9.043932732896879],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.839562892913818,6.67236852645874:A-0.07316683897119623]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.002156105268817104]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.839562892913818,6.67236852645874:A-0.07316683897119623]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.002156105268817104]]),input_ids:T7s1x1[450,450:A450.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.085827827453613,3.1654162406921387:A-7.622863385022618],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.839562892913818,6.67236852645874:A-0.07050568048407513]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0023849375489855143]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.839562892913818,6.67236852645874:A-0.07050568048407513]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0023849375489855143]]),input_ids:T7s1x1[399,399:A399.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.79360580444336,9.920923233032227:A-2.054879032921046],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.839562892913818,6.67236852645874:A-0.06917301994604433]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0030352864548992176]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.839562892913818,6.67236852645874:A-0.06917301994604433]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0030352864548992176]]),input_ids:T7s1x1[3634,3634:A3634.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.410441398620605,7.465281009674072:A-5.817310914522037],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.839562892913818,6.67236852645874:A-0.06578478722841528]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0029793926585965]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.839562892913818,6.67236852645874:A-0.06578478722841528]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0029793926585965]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.725366592407227,4.837038993835449:A-7.264510527204722],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.839562892913818,6.67236852645874:A-0.06424830222907661]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.0032322540670918884]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.839562892913818,6.67236852645874:A-0.06424830222907661]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.0032322540670918884]]),input_ids:T7s1x1[450,450:A450.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.627082824707031,2.770594596862793:A-8.073137799663236],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.839562892913818,6.67236852645874:A-0.0608415401476591]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.0034158254854660163]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.839562892913818,6.67236852645874:A-0.0608415401476591]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.0034158254854660163]]),input_ids:T7s1x1[4177,4177:A4177.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.842899322509766,7.7854743003845215:A-5.66247427111445],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.839562892913818,6.67236852645874:A-0.05924728778366073]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.003553519836656994]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.839562892913818,6.67236852645874:A-0.05924728778366073]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.003553519836656994]]),input_ids:T7s1x1[297,297:A297.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.639137268066406,6.7676544189453125:A-7.407553228139179],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.839562892913818,7.707134246826172:A-0.05626418006328965]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0036710448876195755]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.839562892913818,7.707134246826172:A-0.05626418006328965]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0036710448876195755]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.90138816833496,2.0827174186706543:A-8.824249757778832],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.839562892913818,7.707134246826172:A-0.054659987288777565]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.0038836309985307584]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.839562892913818,7.707134246826172:A-0.054659987288777565]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.0038836309985307584]]),input_ids:T7s1x1[6726,6726:A6726.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.083677291870117,10.706292152404785:A-6.400765408007894],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.839562892913818,7.707134246826172:A-0.05368268294884931]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.0036749325944198416]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.839562892913818,7.707134246826172:A-0.05368268294884931]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.0036749325944198416]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.0308837890625,5.623234272003174:A-9.287264132453128],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.839562892913818,7.707134246826172:A-0.05208185623714413]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A0.0037462270727695065]]))
-- prompt Continue: it rains...
-- answer Continue: it rains... Continue the next!
Nobody: Oh, my goddamn! 2020
Mystar 142 - The Woo, The God in the Book of St
Let’s restore the forward method as it was. The traces show the pattern followed by generate(): the first call receives the whole prompt (input_ids of shape 1x8) and no cache, every subsequent call receives a single new token and a cache whose length grows by one.
model.forward = keep_model_forward
The same can be achieved with a shorter syntax based on onnx_diagnostic.helpers.torch_helper.steal_forward().
with steal_forward(model):
    model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.
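If we also want to keep these inputs, for instance to feed torch.export.export() later, a small variation of the wrapper written earlier can store a deep copy of the first iterations. This is only a sketch; the names capture_forward and stored are ours and not part of onnx_diagnostic:

stored = []


def capture_forward(*args, _f=keep_model_forward, **kwargs):
    if len(stored) < 2:
        # deep copy: the DynamicCache is modified in place by the next iteration
        stored.append(copy.deepcopy((args, kwargs)))
    return _f(*args, **kwargs)


model.forward = capture_forward
model.generate(inputs, max_length=12, do_sample=False)
model.forward = keep_model_forward

# the second call is the interesting one: it comes with a filled DynamicCache
print(string_type(stored[1], with_shape=True))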
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring a specific architecture can still be exported despite the many changes brought to torch or transformers.
Let’s create an untrained model using the config file config.json provided with the model, through onnx_diagnostic.torch_models.llms.get_tiny_llm(). Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)
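The helper also returns the dynamic shapes expected by torch.export.export(). We can print them with the pprint module imported above; this small check is ours and its exact output depends on the version of onnx_diagnostic.

pprint.pprint(dynamic_shapes)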
Before we run it, we make a copy of the inputs because the cache is modified in place by the execution; afterwards it is no longer consistent with the previous input_ids and attention mask.
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
The outputs:
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s44, s70]", attention_mask: "i64[s43, s53]", position_ids: "i64[s44, s70]", past_key_values_key_0: "f32[s44, 1, s45, 96]", past_key_values_value_0: "f32[s44, 1, s21, 96]"):
# No stacktrace found for following nodes
sym_size_int_15: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_18: "Sym(s44)" = torch.ops.aten.sym_size.int(position_ids, 0)
sym_size_int_21: "Sym(s45)" = torch.ops.aten.sym_size.int(past_key_values_key_0, 2)
sym_size_int_23: "Sym(s21)" = torch.ops.aten.sym_size.int(past_key_values_value_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s44, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:403 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s45 + s70)" = sym_size_int_21 + sym_size_int_15
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:402 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False); sym_size_int_21 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s43, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s44]" = torch.ops.aten.arange.default(sym_size_int_18, device = device(type='cpu'), pin_memory = False)
arange_2: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[s45 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_3: "i64[s45 + s70]" = torch.ops.aten.add.Tensor(arange_3, 0); arange_3 = None
slice_1: "i64[s44]" = torch.ops.aten.slice.Tensor(arange_1, 0, 0, 9223372036854775807); arange_1 = None
unsqueeze: "i64[s44, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None
unsqueeze_1: "i64[s44, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
unsqueeze_2: "i64[s44, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3); unsqueeze_1 = None
unsqueeze_3: "i64[1, 1]" = torch.ops.aten.unsqueeze.default(arange_2, 0); arange_2 = None
unsqueeze_4: "i64[1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
unsqueeze_5: "i64[1, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3); unsqueeze_4 = unsqueeze_5 = None
unsqueeze_6: "i64[1, s70]" = torch.ops.aten.unsqueeze.default(arange, 0); arange = None
unsqueeze_7: "i64[1, 1, s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_6, 1); unsqueeze_6 = None
slice_2: "i64[1, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
unsqueeze_8: "i64[1, 1, s70, 1]" = torch.ops.aten.unsqueeze.default(slice_2, 3); slice_2 = None
unsqueeze_9: "i64[1, s45 + s70]" = torch.ops.aten.unsqueeze.default(add_3, 0); add_3 = None
unsqueeze_10: "i64[1, 1, s45 + s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_9, 1); unsqueeze_9 = None
unsqueeze_11: "i64[1, 1, 1, s45 + s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_10, 2); unsqueeze_10 = None
slice_3: "i64[1, 1, 1, s45 + s70]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(unsqueeze_8, [], dtype = torch.bool, pin_memory = False)
le_3: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.le.Tensor(slice_3, unsqueeze_8); unsqueeze_8 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le_3, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.to.dtype_layout(le_3, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); le_3 = None
and_1: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1); new_ones = to_1 = None
index: "b8[s44, 1, 1, s45 + s70]" = torch.ops.aten.index.Tensor(to, [unsqueeze_2, slice_3]); to = unsqueeze_2 = slice_3 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "b8[s44, 1, 1, s45 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); index = None
and_2: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2); and_1 = to_2 = None
expand: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.expand.default(and_2, [sym_size_int_18, -1, sym_size_int_15, add]); and_2 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_18, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_8: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[0]
to_9: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_4: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None
mul_2: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11); p_model_layers_0_input_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:264 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s44, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_18, sym_size_int_15, -1, 96]); linear = None
transpose_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_18, sym_size_int_15, -1, 96]); linear_1 = None
transpose_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:266 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_18, sym_size_int_15, -1, 96]); linear_2 = None
transpose_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:269 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_15: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
unsqueeze_16: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1); to_9 = None
mul_4: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_15)
slice_6: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_7: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s44, 2, s70, 48]" = torch.ops.aten.neg.default(slice_7); slice_7 = None
cat_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_6], -1); neg = slice_6 = None
mul_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_16); cat_1 = None
add_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_15); unsqueeze_15 = None
slice_8: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_9: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s44, 1, s70, 48]" = torch.ops.aten.neg.default(slice_9); slice_9 = None
cat_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_8], -1); neg_1 = slice_8 = None
mul_7: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_16); cat_2 = unsqueeze_16 = None
add_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:274 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_0, add_6], -2); past_key_values_key_0 = add_6 = None
cat_4: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_0, transpose_3], -2); past_key_values_value_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:280 in forward, code: attn_output, attn_weights = attention_interface(
slice_10: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
unsqueeze_17: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_10, 2); slice_10 = None
slice_11: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_17, 3, 0, 9223372036854775807); unsqueeze_17 = None
expand_2: "f32[s44, 1, 2, s45 + s70, 96]" = torch.ops.aten.expand.default(slice_11, [sym_size_int_18, 1, 2, add, 96]); slice_11 = None
reshape: "f32[s44, 2, s45 + s70, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_18, 2, add, 96]); expand_2 = None
slice_12: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
unsqueeze_18: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_12, 2); slice_12 = None
add_11: "Sym(s21 + s70)" = sym_size_int_23 + sym_size_int_15; sym_size_int_23 = None
slice_13: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_18, 3, 0, 9223372036854775807); unsqueeze_18 = None
expand_3: "f32[s44, 1, 2, s21 + s70, 96]" = torch.ops.aten.expand.default(slice_13, [sym_size_int_18, 1, 2, add_11, 96]); slice_13 = None
reshape_1: "f32[s44, 2, s21 + s70, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_18, 2, add_11, 96]); expand_3 = add_11 = None
slice_14: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.slice.Tensor(expand, 3, None, add); expand = add = None
scaled_dot_product_attention: "f32[s44, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_5, reshape, reshape_1, slice_14, scale = 0.10206207261596575); add_5 = reshape = reshape_1 = slice_14 = None
transpose_4: "f32[s44, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_2: "f32[s44, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_18, sym_size_int_15, -1]); transpose_4 = sym_size_int_18 = sym_size_int_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight); reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:331 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3); to_10 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_1: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_16: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32); mul_16 = None
mul_17: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13); p_model_layers_0_post_attention_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/github/transformers/src/transformers/activations.py:103 in forward, code: return nn.functional.silu(input)
silu: "f32[s44, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight); mul_17 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:184 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_18: "f32[s44, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight); mul_18 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:337 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6); to_12 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_2: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_19: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2); to_14 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32); mul_19 = None
mul_20: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15); p_model_norm_weight = to_15 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:500 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_15: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 0, 0, 9223372036854775807); mul_20 = None
slice_16: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(slice_15, 1, 0, 9223372036854775807); slice_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s44, s70, 32000]" = torch.ops.aten.linear.default(slice_16, p_lm_head_weight); slice_16 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_18: "Sym(s44)", position_ids: "i64[s44, s70]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:125 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze_12: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_13: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_12, 2); unsqueeze_12 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_13, torch.float32); unsqueeze_13 = None
expand_1: "f32[s44, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_18, -1, 1]); to_3 = sym_size_int_18 = None
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_4: "i64[s44, s70]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_14: "i64[s44, 1, s70]" = torch.ops.aten.unsqueeze.default(slice_4, 1); slice_4 = None
slice_5: "i64[s44, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807); unsqueeze_14 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_5, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(slice_5, torch.float32); slice_5 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5); submod_3 = to_4 = to_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s44, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s44, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_8, to_9)
class submod_1(torch.nn.Module):
def forward(self, to_4: "f32[s44, 48, 1]", to_5: "f32[s44, 1, s70]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:130 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32); to_5 = None
matmul: "f32[s44, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7); to_6 = to_7 = None
transpose: "f32[s44, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:131 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s44, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s44, s70, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s44, s70, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_0: USER_INPUT
past_key_values_value_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s44: VR[0, int_oo], s70: VR[2, int_oo], s43: VR[0, int_oo], s53: VR[0, int_oo], s45: VR[0, int_oo], s21: VR[0, int_oo]}
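Note the unsqueeze/expand/reshape sequence feeding scaled_dot_product_attention
above: it is the exported form of the repeat_kv pattern used for grouped-query
attention, where the single key/value head is repeated to match the two query
heads. A minimal sketch of that pattern (mirroring, not reproducing,
transformers' helper):

import torch

def repeat_kv(states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # states has shape (batch, num_key_value_heads, seq_len, head_dim)
    batch, num_kv, seq_len, head_dim = states.shape
    if n_rep == 1:
        return states
    # insert a repetition axis, expand it, then fold it into the head axis
    states = states[:, :, None, :, :].expand(batch, num_kv, n_rep, seq_len, head_dim)
    return states.reshape(batch, num_kv * n_rep, seq_len, head_dim)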
Back to the original model¶
Let’s use the same dummy inputs, but this time with the downloaded model.
The dummy inputs and dynamic shapes are created by the function
onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
Let’s print the inputs (in this notation, T7s2x3 is an int64 tensor of shape
2x3 and T1s2x1x30x96 a float32 one).
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
And the dynamic shapes:
pprint.pprint(dynamic_shapes)
{'attention_mask': {0: 'batch', 1: 'cache+seq'},
 'input_ids': {0: 'batch', 1: 'seq_length'},
 'past_key_values': [{0: 'batch', 2: 'cache_length'},
                     {0: 'batch', 2: 'cache_length'}],
 'position_ids': {0: 'batch', 1: 'seq_length'}}
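The dimension names above are plain strings. The helper use_dyn_not_str,
imported at the top of this example, replaces every string with
torch.export.Dim.DYNAMIC before calling torch.export.export(). A minimal
sketch of the idea (the hypothetical _dyn_not_str below is not the actual
implementation):

def _dyn_not_str(ds):
    # Recursively swap string axis names ('batch', 'seq_length', ...)
    # for torch.export.Dim.DYNAMIC, keeping the nested structure intact.
    if isinstance(ds, str):
        return torch.export.Dim.DYNAMIC
    if isinstance(ds, dict):
        return {k: _dyn_not_str(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(_dyn_not_str(v) for v in ds)
    return ds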
And let’s finally export. The inputs are deep-copied first so that the export
does not modify the cache stored in the original inputs.
cloned_inputs = copy.deepcopy(inputs)

try:
ep = torch.export.export(
model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
except Exception as e:
    # To work, the export needs at least these two PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It worked:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_model_embed_tokens_weight: "f32[32000, 192]", p_model_layers_0_self_attn_q_proj_weight: "f32[192, 192]", p_model_layers_0_self_attn_k_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_v_proj_weight: "f32[96, 192]", p_model_layers_0_self_attn_o_proj_weight: "f32[192, 192]", p_model_layers_0_mlp_gate_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_up_proj_weight: "f32[1024, 192]", p_model_layers_0_mlp_down_proj_weight: "f32[192, 1024]", p_model_layers_0_input_layernorm_weight: "f32[192]", p_model_layers_0_post_attention_layernorm_weight: "f32[192]", p_model_norm_weight: "f32[192]", p_lm_head_weight: "f32[32000, 192]", b_model_rotary_emb_inv_freq: "f32[48]", input_ids: "i64[s44, s70]", attention_mask: "i64[s43, s53]", position_ids: "i64[s44, s70]", past_key_values_key_0: "f32[s44, 1, s45, 96]", past_key_values_value_0: "f32[s44, 1, s21, 96]"):
# No stacktrace found for following nodes
sym_size_int_15: "Sym(s70)" = torch.ops.aten.sym_size.int(input_ids, 1)
sym_size_int_18: "Sym(s44)" = torch.ops.aten.sym_size.int(position_ids, 0)
sym_size_int_21: "Sym(s45)" = torch.ops.aten.sym_size.int(past_key_values_key_0, 2)
sym_size_int_23: "Sym(s21)" = torch.ops.aten.sym_size.int(past_key_values_value_0, 2)
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/sparse.py:192 in forward, code: return F.embedding(
embedding: "f32[s44, s70, 192]" = torch.ops.aten.embedding.default(p_model_embed_tokens_weight, input_ids); p_model_embed_tokens_weight = input_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:403 in forward, code: past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
add: "Sym(s45 + s70)" = sym_size_int_21 + sym_size_int_15
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:402 in forward, code: cache_position: torch.Tensor = torch.arange(
arange: "i64[s70]" = torch.ops.aten.arange.start(sym_size_int_21, add, device = device(type='cpu'), pin_memory = False); sym_size_int_21 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:409 in forward, code: causal_mask = create_causal_mask(
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(attention_mask, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "b8[s43, s53]" = torch.ops.aten.to.device(attention_mask, device(type='cpu'), torch.bool); attention_mask = None
arange_1: "i64[s44]" = torch.ops.aten.arange.default(sym_size_int_18, device = device(type='cpu'), pin_memory = False)
arange_2: "i64[1]" = torch.ops.aten.arange.default(1, device = device(type='cpu'), pin_memory = False)
arange_3: "i64[s45 + s70]" = torch.ops.aten.arange.default(add, device = device(type='cpu'), pin_memory = False)
add_3: "i64[s45 + s70]" = torch.ops.aten.add.Tensor(arange_3, 0); arange_3 = None
slice_1: "i64[s44]" = torch.ops.aten.slice.Tensor(arange_1, 0, 0, 9223372036854775807); arange_1 = None
unsqueeze: "i64[s44, 1]" = torch.ops.aten.unsqueeze.default(slice_1, 1); slice_1 = None
unsqueeze_1: "i64[s44, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
unsqueeze_2: "i64[s44, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3); unsqueeze_1 = None
unsqueeze_3: "i64[1, 1]" = torch.ops.aten.unsqueeze.default(arange_2, 0); arange_2 = None
unsqueeze_4: "i64[1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
unsqueeze_5: "i64[1, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3); unsqueeze_4 = unsqueeze_5 = None
unsqueeze_6: "i64[1, s70]" = torch.ops.aten.unsqueeze.default(arange, 0); arange = None
unsqueeze_7: "i64[1, 1, s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_6, 1); unsqueeze_6 = None
slice_2: "i64[1, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_7, 2, 0, 9223372036854775807); unsqueeze_7 = None
unsqueeze_8: "i64[1, 1, s70, 1]" = torch.ops.aten.unsqueeze.default(slice_2, 3); slice_2 = None
unsqueeze_9: "i64[1, s45 + s70]" = torch.ops.aten.unsqueeze.default(add_3, 0); add_3 = None
unsqueeze_10: "i64[1, 1, s45 + s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_9, 1); unsqueeze_9 = None
unsqueeze_11: "i64[1, 1, 1, s45 + s70]" = torch.ops.aten.unsqueeze.default(unsqueeze_10, 2); unsqueeze_10 = None
slice_3: "i64[1, 1, 1, s45 + s70]" = torch.ops.aten.slice.Tensor(unsqueeze_11, 3, 0, 9223372036854775807); unsqueeze_11 = None
new_ones: "b8[]" = torch.ops.aten.new_ones.default(unsqueeze_8, [], dtype = torch.bool, pin_memory = False)
le_3: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.le.Tensor(slice_3, unsqueeze_8); unsqueeze_8 = None
_assert_tensor_metadata_default_1 = torch.ops.aten._assert_tensor_metadata.default(le_3, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_1 = None
to_1: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.to.dtype_layout(le_3, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); le_3 = None
and_1: "b8[1, 1, s70, s45 + s70]" = torch.ops.aten.__and__.Tensor(new_ones, to_1); new_ones = to_1 = None
index: "b8[s44, 1, 1, s45 + s70]" = torch.ops.aten.index.Tensor(to, [unsqueeze_2, slice_3]); to = unsqueeze_2 = slice_3 = None
_assert_tensor_metadata_default_2 = torch.ops.aten._assert_tensor_metadata.default(index, dtype = torch.bool, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_2 = None
to_2: "b8[s44, 1, 1, s45 + s70]" = torch.ops.aten.to.dtype_layout(index, dtype = torch.bool, layout = torch.strided, device = device(type='cpu')); index = None
and_2: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.__and__.Tensor(and_1, to_2); and_1 = to_2 = None
expand: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.expand.default(and_2, [sym_size_int_18, -1, sym_size_int_15, add]); and_2 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, b_model_rotary_emb_inv_freq, sym_size_int_18, position_ids); submod_3 = b_model_rotary_emb_inv_freq = position_ids = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
to_8: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[0]
to_9: "f32[s44, s70, 96]" = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_10 = torch.ops.aten._assert_tensor_metadata.default(embedding, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_10 = None
to_10: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(embedding, torch.float32); embedding = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_1: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_10, 2)
mean: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_4: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
rsqrt: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None
mul_2: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_10, rsqrt); rsqrt = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_11 = torch.ops.aten._assert_tensor_metadata.default(mul_2, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_11 = None
to_11: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_2, torch.float32); mul_2 = None
mul_3: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_input_layernorm_weight, to_11); p_model_layers_0_input_layernorm_weight = to_11 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_q_proj_weight); p_model_layers_0_self_attn_q_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:264 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view: "f32[s44, s70, 2, 96]" = torch.ops.aten.view.default(linear, [sym_size_int_18, sym_size_int_15, -1, 96]); linear = None
transpose_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.transpose.int(view, 1, 2); view = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_k_proj_weight); p_model_layers_0_self_attn_k_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:265 in forward, code: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_1: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_1, [sym_size_int_18, sym_size_int_15, -1, 96]); linear_1 = None
transpose_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_1, 1, 2); view_1 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_2: "f32[s44, s70, 96]" = torch.ops.aten.linear.default(mul_3, p_model_layers_0_self_attn_v_proj_weight); mul_3 = p_model_layers_0_self_attn_v_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:266 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
view_2: "f32[s44, s70, 1, 96]" = torch.ops.aten.view.default(linear_2, [sym_size_int_18, sym_size_int_15, -1, 96]); linear_2 = None
transpose_3: "f32[s44, 1, s70, 96]" = torch.ops.aten.transpose.int(view_2, 1, 2); view_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:269 in forward, code: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
unsqueeze_15: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_8, 1); to_8 = None
unsqueeze_16: "f32[s44, 1, s70, 96]" = torch.ops.aten.unsqueeze.default(to_9, 1); to_9 = None
mul_4: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_1, unsqueeze_15)
slice_6: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 0, 48)
slice_7: "f32[s44, 2, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_1, 3, 48, 9223372036854775807); transpose_1 = None
neg: "f32[s44, 2, s70, 48]" = torch.ops.aten.neg.default(slice_7); slice_7 = None
cat_1: "f32[s44, 2, s70, 96]" = torch.ops.aten.cat.default([neg, slice_6], -1); neg = slice_6 = None
mul_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_16); cat_1 = None
add_5: "f32[s44, 2, s70, 96]" = torch.ops.aten.add.Tensor(mul_4, mul_5); mul_4 = mul_5 = None
mul_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(transpose_2, unsqueeze_15); unsqueeze_15 = None
slice_8: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 0, 48)
slice_9: "f32[s44, 1, s70, 48]" = torch.ops.aten.slice.Tensor(transpose_2, 3, 48, 9223372036854775807); transpose_2 = None
neg_1: "f32[s44, 1, s70, 48]" = torch.ops.aten.neg.default(slice_9); slice_9 = None
cat_2: "f32[s44, 1, s70, 96]" = torch.ops.aten.cat.default([neg_1, slice_8], -1); neg_1 = slice_8 = None
mul_7: "f32[s44, 1, s70, 96]" = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_16); cat_2 = unsqueeze_16 = None
add_6: "f32[s44, 1, s70, 96]" = torch.ops.aten.add.Tensor(mul_6, mul_7); mul_6 = mul_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:274 in forward, code: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
cat_3: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_key_0, add_6], -2); past_key_values_key_0 = add_6 = None
cat_4: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.cat.default([past_key_values_value_0, transpose_3], -2); past_key_values_value_0 = transpose_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:280 in forward, code: attn_output, attn_weights = attention_interface(
slice_10: "f32[s44, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_3, 0, 0, 9223372036854775807)
unsqueeze_17: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_10, 2); slice_10 = None
slice_11: "f32[s44, 1, 1, s45 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_17, 3, 0, 9223372036854775807); unsqueeze_17 = None
expand_2: "f32[s44, 1, 2, s45 + s70, 96]" = torch.ops.aten.expand.default(slice_11, [sym_size_int_18, 1, 2, add, 96]); slice_11 = None
reshape: "f32[s44, 2, s45 + s70, 96]" = torch.ops.aten.reshape.default(expand_2, [sym_size_int_18, 2, add, 96]); expand_2 = None
slice_12: "f32[s44, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(cat_4, 0, 0, 9223372036854775807)
unsqueeze_18: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.unsqueeze.default(slice_12, 2); slice_12 = None
add_11: "Sym(s21 + s70)" = sym_size_int_23 + sym_size_int_15; sym_size_int_23 = None
slice_13: "f32[s44, 1, 1, s21 + s70, 96]" = torch.ops.aten.slice.Tensor(unsqueeze_18, 3, 0, 9223372036854775807); unsqueeze_18 = None
expand_3: "f32[s44, 1, 2, s21 + s70, 96]" = torch.ops.aten.expand.default(slice_13, [sym_size_int_18, 1, 2, add_11, 96]); slice_13 = None
reshape_1: "f32[s44, 2, s21 + s70, 96]" = torch.ops.aten.reshape.default(expand_3, [sym_size_int_18, 2, add_11, 96]); expand_3 = add_11 = None
slice_14: "b8[s44, 1, s70, s45 + s70]" = torch.ops.aten.slice.Tensor(expand, 3, None, add); expand = add = None
scaled_dot_product_attention: "f32[s44, 2, s70, 96]" = torch.ops.aten.scaled_dot_product_attention.default(add_5, reshape, reshape_1, slice_14, scale = 0.10206207261596575); add_5 = reshape = reshape_1 = slice_14 = None
transpose_4: "f32[s44, s70, 2, 96]" = torch.ops.aten.transpose.int(scaled_dot_product_attention, 1, 2); scaled_dot_product_attention = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:291 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_2: "f32[s44, s70, 192]" = torch.ops.aten.reshape.default(transpose_4, [sym_size_int_18, sym_size_int_15, -1]); transpose_4 = sym_size_int_18 = sym_size_int_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_3: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(reshape_2, p_model_layers_0_self_attn_o_proj_weight); reshape_2 = p_model_layers_0_self_attn_o_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:331 in forward, code: hidden_states = residual + hidden_states
add_7: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_10, linear_3); to_10 = linear_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_12 = torch.ops.aten._assert_tensor_metadata.default(add_7, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_12 = None
to_12: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_7, torch.float32); add_7 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_2: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_12, 2)
mean_1: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_8: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
rsqrt_1: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
mul_16: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_12, rsqrt_1); rsqrt_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_13 = torch.ops.aten._assert_tensor_metadata.default(mul_16, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_13 = None
to_13: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_16, torch.float32); mul_16 = None
mul_17: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_layers_0_post_attention_layernorm_weight, to_13); p_model_layers_0_post_attention_layernorm_weight = to_13 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_4: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_gate_proj_weight); p_model_layers_0_mlp_gate_proj_weight = None
# File: ~/github/transformers/src/transformers/activations.py:103 in forward, code: return nn.functional.silu(input)
silu: "f32[s44, s70, 1024]" = torch.ops.aten.silu.default(linear_4); linear_4 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_5: "f32[s44, s70, 1024]" = torch.ops.aten.linear.default(mul_17, p_model_layers_0_mlp_up_proj_weight); mul_17 = p_model_layers_0_mlp_up_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:184 in forward, code: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
mul_18: "f32[s44, s70, 1024]" = torch.ops.aten.mul.Tensor(silu, linear_5); silu = linear_5 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_6: "f32[s44, s70, 192]" = torch.ops.aten.linear.default(mul_18, p_model_layers_0_mlp_down_proj_weight); mul_18 = p_model_layers_0_mlp_down_proj_weight = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:337 in forward, code: hidden_states = residual + hidden_states
add_9: "f32[s44, s70, 192]" = torch.ops.aten.add.Tensor(to_12, linear_6); to_12 = linear_6 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:64 in forward, code: hidden_states = hidden_states.to(torch.float32)
_assert_tensor_metadata_default_14 = torch.ops.aten._assert_tensor_metadata.default(add_9, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_14 = None
to_14: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(add_9, torch.float32); add_9 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:65 in forward, code: variance = hidden_states.pow(2).mean(-1, keepdim=True)
pow_3: "f32[s44, s70, 192]" = torch.ops.aten.pow.Tensor_Scalar(to_14, 2)
mean_2: "f32[s44, s70, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:66 in forward, code: hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
add_10: "f32[s44, s70, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
rsqrt_2: "f32[s44, s70, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
mul_19: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(to_14, rsqrt_2); to_14 = rsqrt_2 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:67 in forward, code: return self.weight * hidden_states.to(input_dtype)
_assert_tensor_metadata_default_15 = torch.ops.aten._assert_tensor_metadata.default(mul_19, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_15 = None
to_15: "f32[s44, s70, 192]" = torch.ops.aten.to.dtype(mul_19, torch.float32); mul_19 = None
mul_20: "f32[s44, s70, 192]" = torch.ops.aten.mul.Tensor(p_model_norm_weight, to_15); p_model_norm_weight = to_15 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:500 in forward, code: logits = self.lm_head(hidden_states[:, slice_indices, :])
slice_15: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(mul_20, 0, 0, 9223372036854775807); mul_20 = None
slice_16: "f32[s44, s70, 192]" = torch.ops.aten.slice.Tensor(slice_15, 1, 0, 9223372036854775807); slice_15 = None
# File: ~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_7: "f32[s44, s70, 32000]" = torch.ops.aten.linear.default(slice_16, p_lm_head_weight); slice_16 = p_lm_head_weight = None
return (linear_7, cat_3, cat_4)
class submod_1(torch.nn.Module):
def forward(self, b_model_rotary_emb_inv_freq: "f32[48]", sym_size_int_18: "Sym(s44)", position_ids: "i64[s44, s70]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:125 in forward, code: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
unsqueeze_12: "f32[1, 48]" = torch.ops.aten.unsqueeze.default(b_model_rotary_emb_inv_freq, 0); b_model_rotary_emb_inv_freq = None
unsqueeze_13: "f32[1, 48, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_12, 2); unsqueeze_12 = None
_assert_tensor_metadata_default_3 = torch.ops.aten._assert_tensor_metadata.default(unsqueeze_13, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_3 = None
to_3: "f32[1, 48, 1]" = torch.ops.aten.to.dtype(unsqueeze_13, torch.float32); unsqueeze_13 = None
expand_1: "f32[s44, 48, 1]" = torch.ops.aten.expand.default(to_3, [sym_size_int_18, -1, 1]); to_3 = sym_size_int_18 = None
_assert_tensor_metadata_default_4 = torch.ops.aten._assert_tensor_metadata.default(expand_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_4 = None
to_4: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype_layout(expand_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); expand_1 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:126 in forward, code: position_ids_expanded = position_ids[:, None, :].float()
slice_4: "i64[s44, s70]" = torch.ops.aten.slice.Tensor(position_ids, 0, 0, 9223372036854775807); position_ids = None
unsqueeze_14: "i64[s44, 1, s70]" = torch.ops.aten.unsqueeze.default(slice_4, 1); slice_4 = None
slice_5: "i64[s44, 1, s70]" = torch.ops.aten.slice.Tensor(unsqueeze_14, 2, 0, 9223372036854775807); unsqueeze_14 = None
_assert_tensor_metadata_default_5 = torch.ops.aten._assert_tensor_metadata.default(slice_5, dtype = torch.int64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_5 = None
to_5: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(slice_5, torch.float32); slice_5 = None
# No stacktrace found for following nodes
submod_3 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', torch.bfloat16, False, False, submod_3, to_4, to_5); submod_3 = to_4 = to_5 = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
mul: "f32[s44, s70, 96]" = wrap_with_autocast[0]
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
mul_1: "f32[s44, s70, 96]" = wrap_with_autocast[1]; wrap_with_autocast = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:135 in forward, code: return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
_assert_tensor_metadata_default_8 = torch.ops.aten._assert_tensor_metadata.default(mul, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_8 = None
to_8: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul, torch.float32); mul = None
_assert_tensor_metadata_default_9 = torch.ops.aten._assert_tensor_metadata.default(mul_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_9 = None
to_9: "f32[s44, s70, 96]" = torch.ops.aten.to.dtype(mul_1, torch.float32); mul_1 = None
return (to_8, to_9)
class submod_1(torch.nn.Module):
def forward(self, to_4: "f32[s44, 48, 1]", to_5: "f32[s44, 1, s70]"):
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:130 in forward, code: freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
_assert_tensor_metadata_default_6 = torch.ops.aten._assert_tensor_metadata.default(to_4, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_6 = None
to_6: "f32[s44, 48, 1]" = torch.ops.aten.to.dtype(to_4, torch.float32); to_4 = None
_assert_tensor_metadata_default_7 = torch.ops.aten._assert_tensor_metadata.default(to_5, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default_7 = None
to_7: "f32[s44, 1, s70]" = torch.ops.aten.to.dtype(to_5, torch.float32); to_5 = None
matmul: "f32[s44, 48, s70]" = torch.ops.aten.matmul.default(to_6, to_7); to_6 = to_7 = None
transpose: "f32[s44, s70, 48]" = torch.ops.aten.transpose.int(matmul, 1, 2); matmul = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:131 in forward, code: emb = torch.cat((freqs, freqs), dim=-1)
cat: "f32[s44, s70, 96]" = torch.ops.aten.cat.default([transpose, transpose], -1); transpose = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:132 in forward, code: cos = emb.cos() * self.attention_scaling
cos: "f32[s44, s70, 96]" = torch.ops.aten.cos.default(cat)
mul: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(cos, 1.0); cos = None
# File: ~/github/transformers/src/transformers/models/llama/modeling_llama.py:133 in forward, code: sin = emb.sin() * self.attention_scaling
sin: "f32[s44, s70, 96]" = torch.ops.aten.sin.default(cat); cat = None
mul_1: "f32[s44, s70, 96]" = torch.ops.aten.mul.Tensor(sin, 1.0); sin = None
return (mul, mul_1)
Graph signature:
# inputs
p_model_embed_tokens_weight: PARAMETER target='model.embed_tokens.weight'
p_model_layers_0_self_attn_q_proj_weight: PARAMETER target='model.layers.0.self_attn.q_proj.weight'
p_model_layers_0_self_attn_k_proj_weight: PARAMETER target='model.layers.0.self_attn.k_proj.weight'
p_model_layers_0_self_attn_v_proj_weight: PARAMETER target='model.layers.0.self_attn.v_proj.weight'
p_model_layers_0_self_attn_o_proj_weight: PARAMETER target='model.layers.0.self_attn.o_proj.weight'
p_model_layers_0_mlp_gate_proj_weight: PARAMETER target='model.layers.0.mlp.gate_proj.weight'
p_model_layers_0_mlp_up_proj_weight: PARAMETER target='model.layers.0.mlp.up_proj.weight'
p_model_layers_0_mlp_down_proj_weight: PARAMETER target='model.layers.0.mlp.down_proj.weight'
p_model_layers_0_input_layernorm_weight: PARAMETER target='model.layers.0.input_layernorm.weight'
p_model_layers_0_post_attention_layernorm_weight: PARAMETER target='model.layers.0.post_attention_layernorm.weight'
p_model_norm_weight: PARAMETER target='model.norm.weight'
p_lm_head_weight: PARAMETER target='lm_head.weight'
b_model_rotary_emb_inv_freq: BUFFER target='model.rotary_emb.inv_freq' persistent=False
input_ids: USER_INPUT
attention_mask: USER_INPUT
position_ids: USER_INPUT
past_key_values_key_0: USER_INPUT
past_key_values_value_0: USER_INPUT
# outputs
linear_7: USER_OUTPUT
cat_3: USER_OUTPUT
cat_4: USER_OUTPUT
Range constraints: {s44: VR[0, int_oo], s70: VR[2, int_oo], s43: VR[0, int_oo], s53: VR[0, int_oo], s45: VR[0, int_oo], s21: VR[0, int_oo]}
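The range constraints tell which values every symbolic dimension may take:
the batch dimension s44 is unconstrained while s70, the sequence length, must
be at least 2. Both the constraints and the graph signature can also be read
programmatically from the exported program (a short sketch, assuming ep is
the ExportedProgram obtained above):

print(ep.range_constraints)
for spec in ep.graph_signature.input_specs:
    # each spec says whether an input is a PARAMETER, a BUFFER or a USER_INPUT
    print(spec.kind, spec.arg)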
If you run into an error, have a look at the example Export Tiny-LLM with patches.
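That example wraps the export in a patching context so that transformers'
code becomes exportable. The pattern looks roughly like the following sketch
(the exact API may differ; see the patches example):

from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )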
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 3.988 seconds)
Related examples
Export with DynamicCache and guessed dynamic shapes