Note
Go to the end to download the full example code.
Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes
need to be specified when an LLM is exported with torch.export.export().
Most of the examples on HuggingFace use the method
transformers.GenerationMixin.generate(),
but we only want to
export the model and its method forward.
This example shows how to guess the inputs of this method even though the model
is executed through the method generate.
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
# Identifier of the model this example works with.
MODEL_NAME = "arnir0/Tiny-LLM"
# Load the tokenizer and the causal language model registered under that
# identifier; both are used below to build a prompt and run ``generate``.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We rewrite the forward method to print the cache dimension.
def _forward_(*args, _f=None, **kwargs):
    """Call ``_f`` and print a summary of its inputs and of its result.

    The positional and keyword arguments are printed (prefixed with ``<-``)
    before the call and the returned value (prefixed with ``->``) after it.
    Both are rendered with ``string_type`` using ``with_shape=True`` and
    ``with_min_max=True``. Nothing is printed while
    ``torch.compiler.is_exporting()`` returns true.

    :param args: positional arguments forwarded unchanged to ``_f``
    :param _f: the wrapped callable, must not be ``None``
    :param kwargs: keyword arguments forwarded unchanged to ``_f``
    :return: the value returned by ``_f``
    """
    assert _f is not None

    def _tracing_enabled():
        # torch.compiler.is_exporting requires torch>=2.7; when the function
        # is missing the call is always traced.
        exporting = hasattr(torch.compiler, "is_exporting") and torch.compiler.is_exporting()
        return not exporting

    if _tracing_enabled():
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    # The export state is evaluated again after the call, as two separate checks.
    if _tracing_enabled():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res
# Keep a reference to the current forward before replacing it.
keep_model_forward = model.forward
# The original forward is bound as the default value of ``_f`` so the wrapper
# delegates to it through ``_forward_``, which prints inputs and outputs.
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
*args, _f=_f, **kwargs
)
Let’s run the model.
# Encode a short prompt, let ``generate`` sample a continuation (every call
# it makes to ``model.forward`` goes through the wrapper installed above),
# then decode the first returned sequence back to text.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
    inputs,
    do_sample=True,
    max_length=50,
    temperature=1,
    top_k=50,
    top_p=0.95,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.432564735412598,8.368535995483398:A-4.234468644971028],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[3644,3644:A3644.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.676687240600586,7.27504825592041:A-8.850758515120019],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.12691959182981616]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.009876168944038]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.348220348358154:A-0.12691959182981616]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.009876168944038]]),input_ids:T7s1x1[366,366:A366.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.188640594482422,10.638813018798828:A-6.258877615917474],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-6.004913806915283,6.348220348358154:A-0.13589673094346919]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.01118109134810738]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-6.004913806915283,6.348220348358154:A-0.13589673094346919]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.01118109134810738]]),input_ids:T7s1x1[526,526:A526.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.33739471435547,8.60517692565918:A-7.591759837625548],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-6.004913806915283,6.348220348358154:A-0.14156292788084102]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.010105416223862247]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-6.004913806915283,6.348220348358154:A-0.14156292788084102]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.010105416223862247]]),input_ids:T7s1x1[2534,2534:A2534.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.11432647705078,5.904876708984375:A-9.151046005945652],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-6.004913806915283,6.423257827758789:A-0.14045577719372934]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.008647813524301174]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-6.004913806915283,6.423257827758789:A-0.14045577719372934]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.008647813524301174]]),input_ids:T7s1x1[7458,7458:A7458.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.7952823638916,7.502084732055664:A-6.758856386695057],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.004913806915283,6.423257827758789:A-0.1394986053910543]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.009933336734829504]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.004913806915283,6.423257827758789:A-0.1394986053910543]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.009933336734829504]]),input_ids:T7s1x1[411,411:A411.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.175168991088867,3.348450183868408:A-10.076951823784038],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.004913806915283,7.722805023193359:A-0.13086280752217538]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.008774217932083654]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.004913806915283,7.722805023193359:A-0.13086280752217538]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.008774217932083654]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.300609588623047,1.5349633693695068:A-9.108510399873369],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.004913806915283,7.722805023193359:A-0.1289421314440157]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.00907974131760625]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.004913806915283,7.722805023193359:A-0.1289421314440157]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.00907974131760625]]),input_ids:T7s1x1[1492,1492:A1492.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.805042266845703,2.2429909706115723:A-10.377611739800777],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.004913806915283,7.722805023193359:A-0.13040056353394688]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.008177679533049492]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.004913806915283,7.722805023193359:A-0.13040056353394688]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.008177679533049492]]),input_ids:T7s1x1[2305,2305:A2305.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.009395599365234,9.020483016967773:A-8.41079969012877],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.004913806915283,7.722805023193359:A-0.1238072641973068]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.008250927643050292]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.004913806915283,7.722805023193359:A-0.1238072641973068]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.008250927643050292]]),input_ids:T7s1x1[472,472:A472.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-23.251914978027344,5.268628120422363:A-11.289321211283095],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.004913806915283,7.722805023193359:A-0.11749377748549848]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007056356020600282]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.004913806915283,7.722805023193359:A-0.11749377748549848]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.007056356020600282]]),input_ids:T7s1x1[937,937:A937.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.645509719848633,8.125471115112305:A-8.086463120871224],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.004913806915283,7.722805023193359:A-0.1168980686679788]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.007018144187100006]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.004913806915283,7.722805023193359:A-0.1168980686679788]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.007018144187100006]]),input_ids:T7s1x1[470,470:A470.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.912551879882812,10.241974830627441:A-3.8326050875000655],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.004913806915283,7.722805023193359:A-0.11400177848906698]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.006563755542174476]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.004913806915283,7.722805023193359:A-0.11400177848906698]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.006563755542174476]]),input_ids:T7s1x1[4654,4654:A4654.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.213184356689453,7.596614837646484:A-8.155910748695023],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.004913806915283,7.722805023193359:A-0.11244513507760227]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.005980289497026179]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.004913806915283,7.722805023193359:A-0.11244513507760227]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.005980289497026179]]),input_ids:T7s1x1[6263,6263:A6263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.154296875,6.49399471282959:A-9.270631306155584],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.004913806915283,7.722805023193359:A-0.10651536178208873]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.0055122166952415955]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.004913806915283,7.722805023193359:A-0.10651536178208873]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.0055122166952415955]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.9627628326416,4.663084030151367:A-10.78550848642271],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.004913806915283,7.722805023193359:A-0.10453992799067338]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.0044693136490157786]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.004913806915283,7.722805023193359:A-0.10453992799067338]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.0044693136490157786]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.589021682739258,3.069247245788574:A-9.503916243720335],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.004913806915283,7.722805023193359:A-0.10742598550059483]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.004837044787072955]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.004913806915283,7.722805023193359:A-0.10742598550059483]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.004837044787072955]]),input_ids:T7s1x1[1510,1510:A1510.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.98994255065918,10.435498237609863:A-6.825805324685061],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.004913806915283,7.722805023193359:A-0.10189088881111857]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.0038105725595286434]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.004913806915283,7.722805023193359:A-0.10189088881111857]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.0038105725595286434]]),input_ids:T7s1x1[322,322:A322.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.791205406188965,7.87973690032959:A-6.549245668119518],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.004913806915283,7.722805023193359:A-0.10006822699968909]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.003311814227159899]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.004913806915283,7.722805023193359:A-0.10006822699968909]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.003311814227159899]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.047138214111328,13.057881355285645:A-3.6654508894905447],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.004913806915283,7.722805023193359:A-0.0977025575129162]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.003285461740791193]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.004913806915283,7.722805023193359:A-0.0977025575129162]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.003285461740791193]]),input_ids:T7s1x1[674,674:A674.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.111164093017578,8.688833236694336:A-8.262362815548666],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.004913806915283,7.722805023193359:A-0.09106498731957748]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.0038887540596580696]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.004913806915283,7.722805023193359:A-0.09106498731957748]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.0038887540596580696]]),input_ids:T7s1x1[367,367:A367.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.561283111572266,6.930385589599609:A-6.774146137303672],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.004913806915283,7.722805023193359:A-0.08655843766767551]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0038961854607653854]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.004913806915283,7.722805023193359:A-0.08655843766767551]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0038961854607653854]]),input_ids:T7s1x1[2805,2805:A2805.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.426450729370117,6.377567291259766:A-8.199113059743308],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.004913806915283,7.722805023193359:A-0.08060185193604137]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.004169290036579989]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.004913806915283,7.722805023193359:A-0.08060185193604137]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.004169290036579989]]),input_ids:T7s1x1[697,697:A697.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.203289031982422,5.623708724975586:A-8.658177069942466],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.004913806915283,7.722805023193359:A-0.07838606702424993]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004470653610605761]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.004913806915283,7.722805023193359:A-0.07838606702424993]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004470653610605761]]),input_ids:T7s1x1[1833,1833:A1833.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.355812072753906,8.754438400268555:A-6.3964791208186655],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.004913806915283,7.722805023193359:A-0.07597313534580565]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.004249567550452797]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.004913806915283,7.722805023193359:A-0.07597313534580565]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.004249567550452797]]),input_ids:T7s1x1[11015,11015:A11015.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.791301727294922,7.957080364227295:A-8.04135798576777],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.004913806915283,7.722805023193359:A-0.06986021991189559]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.004329011404581929]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.004913806915283,7.722805023193359:A-0.06986021991189559]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.004329011404581929]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.727907180786133,7.9141998291015625:A-7.306731516623636],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.004913806915283,7.722805023193359:A-0.0648249151083638]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.004601109170848109]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.004913806915283,7.722805023193359:A-0.0648249151083638]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.004601109170848109]]),input_ids:T7s1x1[697,697:A697.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.604541778564453,7.640566349029541:A-7.185191930130357],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.004913806915283,7.722805023193359:A-0.06394224635877671]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.004856992927363569]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.004913806915283,7.722805023193359:A-0.06394224635877671]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.004856992927363569]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.543540954589844,10.047779083251953:A-7.751862312250072],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.004913806915283,7.722805023193359:A-0.06277707978638769]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.004919462362665997]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.004913806915283,7.722805023193359:A-0.06277707978638769]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.004919462362665997]]),input_ids:T7s1x1[1749,1749:A1749.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.439579010009766,5.922304630279541:A-6.655277061746456],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.106344699859619,7.722805023193359:A-0.061771074560486965]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.004828778836378676]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.106344699859619,7.722805023193359:A-0.061771074560486965]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.004828778836378676]]),input_ids:T7s1x1[1757,1757:A1757.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.706258773803711,10.057670593261719:A-4.6446524666668845],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.106344699859619,7.722805023193359:A-0.05803686417335158]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.005118629730442726]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.106344699859619,7.722805023193359:A-0.05803686417335158]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.005118629730442726]]),input_ids:T7s1x1[526,526:A526.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.891840934753418,6.157626628875732:A-6.936900479043368],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.106344699859619,7.722805023193359:A-0.056969892494907984]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.004947488733610801]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.106344699859619,7.722805023193359:A-0.056969892494907984]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.004947488733610801]]),input_ids:T7s1x1[1985,1985:A1985.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.089534759521484,9.370197296142578:A-7.689837488615885],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.106344699859619,7.722805023193359:A-0.0554071209260404]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.004887221839124514]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.106344699859619,7.722805023193359:A-0.0554071209260404]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.004887221839124514]]),input_ids:T7s1x1[408,408:A408.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.932441711425781,8.202560424804688:A-6.631210417136084],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.106344699859619,7.722805023193359:A-0.05302953876684663]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.005100728132699967]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.106344699859619,7.722805023193359:A-0.05302953876684663]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.005100728132699967]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.183387756347656,5.482445240020752:A-6.923980790627189],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.106344699859619,7.722805023193359:A-0.05024713480040452]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.005369138193660448]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.106344699859619,7.722805023193359:A-0.05024713480040452]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.005369138193660448]]),input_ids:T7s1x1[8455,8455:A8455.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.479686737060547,10.227513313293457:A-8.05372774467012],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.106344699859619,7.722805023193359:A-0.049373298938862165]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.005049302342764818]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.106344699859619,7.722805023193359:A-0.049373298938862165]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.005049302342764818]]),input_ids:T7s1x1[363,363:A363.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.16100788116455,6.885959148406982:A-8.176626010833308],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.106344699859619,7.75748348236084:A-0.04626136697401374]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.004799664269723036]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.106344699859619,7.75748348236084:A-0.04626136697401374]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.004799664269723036]]),input_ids:T7s1x1[963,963:A963.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.823324203491211,12.436174392700195:A-5.595570003254107],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.106344699859619,7.75748348236084:A-0.0461938070601046]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0047170781391794635]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.106344699859619,7.75748348236084:A-0.0461938070601046]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0047170781391794635]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.377718925476074,9.742720603942871:A-6.410888299258891],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.659117221832275,7.75748348236084:A-0.045671093396396426]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.004978337300770312]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.659117221832275,7.75748348236084:A-0.045671093396396426]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.004978337300770312]]),input_ids:T7s1x1[437,437:A437.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.776954650878906,7.525132179260254:A-8.930040388823254],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.659117221832275,7.75748348236084:A-0.04461738154984434]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.004540475128196577]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.659117221832275,7.75748348236084:A-0.04461738154984434]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.004540475128196577]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.089251518249512,9.531167984008789:A-5.4318459506535435],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.659117221832275,7.75748348236084:A-0.044132965958876606]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A0.004730515089333346]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...
If you are having trouble with the right people at first or third party on the show and we will be getting one last minute, one of our men are working as a manager for them to do, and
Let’s restore the forward as it was.
# Undo the patch: put the forward saved in ``keep_model_forward`` back on the model.
model.forward = keep_model_forward
The same can be achieved with the context manager onnx_diagnostic.helpers.torch_helper.steal_forward().
# Run the same sampling call inside ``steal_forward``; the calls to the
# model's forward made while the block runs are reported on standard output.
with steal_forward(model):
    model.generate(
        inputs,
        do_sample=True,
        max_length=50,
        temperature=1,
        top_k=50,
        top_p=0.95,
    )
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test to ensure a specific architecture can still be exported despite the many changes brought to torch or transformers.
Let’s create an untrained model from the provided config file config.json with function onnx_diagnostic.torch_models.llms.get_tiny_llm().
Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
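Before we run it, we make a copy of the inputs because the cache gets modified by the execution: once the model has run, the cache is no longer consistent with the original input_ids and attention mask. The copy (cloned_inputs) is what is given to torch.export.export() later on.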
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
The outputs
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
ep = torch.export.export(
untrained_model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x79629eb3b560> not registered
Back to the original model¶
Let’s use the same dummy inputs, this time with the downloaded model.
Dummy inputs and dynamic shapes are created by function onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
Let’s print the inputs.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
And let’s finally export.
try:
ep = torch.export.export(
model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x79629fd2bbf0> not registered
If you get an error, look at the example “Export Tiny-LLM with patches”.
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 2.553 seconds)
Related examples