Note
Go to the end to download the full example code.
Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes
need to be specified when an LLM is exported with torch.export.export()
.
Most of the examples on HuggingFace use method
transformers.GenerationMixin.generate()
but we only want to
export the model and its method forward
.
This example shows how to guess the inputs of this method even though the model
is executed through method generate
.
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
# Model identifier shared by the tokenizer and the model loaded below.
MODEL_NAME = "arnir0/Tiny-LLM"
# Load the tokenizer and the true (pretrained) model for that identifier;
# both are used further down to run a real generation.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We rewrite the forward method to print the cache dimension.
def _forward_(*args, _f=None, **kwargs):
    """Call ``_f(*args, **kwargs)`` and print a summary of its inputs and outputs.

    The inputs are printed (prefixed with ``<-``) before the call and the
    result (prefixed with ``->``) after it, both rendered by ``string_type``
    with ``with_shape=True`` and ``with_min_max=True``.  Nothing is printed
    while ``torch.compiler.is_exporting()`` reports an export in progress.

    :param args: positional arguments forwarded unchanged to ``_f``
    :param _f: the wrapped callable, must not be None
    :param kwargs: keyword arguments forwarded unchanged to ``_f``
    :return: whatever ``_f`` returns
    """
    assert _f is not None

    def _printing_allowed():
        # torch.compiler.is_exporting requires torch>=2.7,
        # when it is missing we always print.
        return not (
            hasattr(torch.compiler, "is_exporting") and torch.compiler.is_exporting()
        )

    if _printing_allowed():
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    outcome = _f(*args, **kwargs)
    # The check is evaluated again after the call, exactly like before it.
    if _printing_allowed():
        print("->", string_type(outcome, with_shape=True, with_min_max=True))
    return outcome
# Keep a reference to the original forward so that it can be restored later.
keep_model_forward = model.forward
# Replace the attribute ``forward`` on this model instance with a wrapper delegating
# to _forward_; the original forward is captured as the default value of ``_f``.
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
*args, _f=_f, **kwargs
)
Let’s run the model.
# Tokenize a short prompt into a PyTorch tensor of token ids.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
# generate() is the entry point here, not forward(): every call it makes to
# model.forward goes through the printing wrapper installed above.
# NOTE(review): do_sample=True enables sampling, so the answer presumably
# differs from one run to another.
outputs = model.generate(
inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
# Decode the first (and only) generated sequence back to text.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.200005531311035,13.318134307861328:A-3.0123733444297685],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.11562127664324685]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.002961578160045098]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.11562127664324685]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.002961578160045098]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.586483001708984,4.537554740905762:A-10.816070999450982],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-6.134088039398193,6.226877689361572:A-0.11618857977773586]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.002102570093954152]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-6.134088039398193,6.226877689361572:A-0.11618857977773586]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.002102570093954152]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.339534759521484,3.94206166267395:A-11.31890943210572],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-6.383307456970215,6.226877689361572:A-0.10467741398833548]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.511397659778595:A0.00017126266493505682]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-6.383307456970215,6.226877689361572:A-0.10467741398833548]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.511397659778595:A0.00017126266493505682]]),input_ids:T7s1x1[2739,2739:A2739.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.786590576171875,12.760019302368164:A-8.096025096898899],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-6.383307456970215,6.226877689361572:A-0.09766931737037036]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.511397659778595:A0.0017530560903019755]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-6.383307456970215,6.226877689361572:A-0.09766931737037036]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.511397659778595:A0.0017530560903019755]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.460124969482422,18.50583267211914:A-2.61176041782368],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-6.383307456970215,6.226877689361572:A-0.09679381578266308]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.511397659778595:A-0.0007793168957761783]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-6.383307456970215,6.226877689361572:A-0.09679381578266308]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.511397659778595:A-0.0007793168957761783]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.56883430480957,17.521331787109375:A-5.0011909737486855],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.383307456970215,6.580713272094727:A-0.09281789121961906]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.511397659778595:A-0.0011256872961396204]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.383307456970215,6.580713272094727:A-0.09281789121961906]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.511397659778595:A-0.0011256872961396204]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.019512176513672,11.016201972961426:A-9.780955639850347],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-7.0830888748168945,6.580713272094727:A-0.09829636609607001]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.511397659778595:A-0.0030858339645722785]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-7.0830888748168945,6.580713272094727:A-0.09829636609607001]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.511397659778595:A-0.0030858339645722785]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.470309257507324,15.61859130859375:A-6.633849481329322],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-7.0830888748168945,6.580713272094727:A-0.11563793529868842]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.511397659778595:A-0.003046002088098021]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-7.0830888748168945,6.580713272094727:A-0.11563793529868842]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.511397659778595:A-0.003046002088098021]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.500776290893555,6.285541534423828:A-10.962694770040923],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-7.0830888748168945,6.580713272094727:A-0.1198103299906435]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.511397659778595:A-0.0031979138770842764]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-7.0830888748168945,6.580713272094727:A-0.1198103299906435]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.511397659778595:A-0.0031979138770842764]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.235013961791992,9.970125198364258:A-3.036244198039174],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-7.0830888748168945,6.580713272094727:A-0.12205251992573149]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.001449730086724904]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-7.0830888748168945,6.580713272094727:A-0.12205251992573149]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.001449730086724904]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.281158447265625,2.9456660747528076:A-9.769279797134455],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-7.0830888748168945,6.580713272094727:A-0.1189443595937056]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0011800400394716697]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-7.0830888748168945,6.580713272094727:A-0.1189443595937056]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0011800400394716697]]),input_ids:T7s1x1[1085,1085:A1085.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.374403953552246,11.528715133666992:A-5.849875963516533],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-7.0830888748168945,6.580713272094727:A-0.11131518831213422]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.0006671041037482913]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-7.0830888748168945,6.580713272094727:A-0.11131518831213422]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.0006671041037482913]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.48352336883545,14.460716247558594:A-6.279674670537003],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-7.0830888748168945,6.580713272094727:A-0.10925453701692897]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.0021195178477942784]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-7.0830888748168945,6.580713272094727:A-0.10925453701692897]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.0021195178477942784]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.75263786315918,11.090300559997559:A-8.510559297380038],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-7.0830888748168945,6.580713272094727:A-0.10923761380526588]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.0022790171502065555]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-7.0830888748168945,6.580713272094727:A-0.10923761380526588]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.0022790171502065555]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.390655517578125,8.281238555908203:A-10.236852982913144],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-7.34796667098999,6.580713272094727:A-0.10873454754329177]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.0035072288968336393]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-7.34796667098999,6.580713272094727:A-0.10873454754329177]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.0035072288968336393]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.17488670349121,15.653326988220215:A-8.630881659234873],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-7.34796667098999,6.580713272094727:A-0.10522823639272853]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0035956152970015864]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-7.34796667098999,6.580713272094727:A-0.10522823639272853]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0035956152970015864]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.672569274902344,3.797863721847534:A-10.746239509989508],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-7.34796667098999,7.215419769287109:A-0.09625673114530703]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.004672906178026703]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-7.34796667098999,7.215419769287109:A-0.09625673114530703]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.004672906178026703]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.594993591308594,8.698394775390625:A-6.974384758137166],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-7.34796667098999,7.215419769287109:A-0.0961442707299657]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003960393124106514]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-7.34796667098999,7.215419769287109:A-0.0961442707299657]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003960393124106514]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.360509872436523,16.70062255859375:A-4.671268068693113],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-7.34796667098999,7.215419769287109:A-0.09505437406973394]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.004968074590943829]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-7.34796667098999,7.215419769287109:A-0.09505437406973394]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.004968074590943829]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.814577102661133,8.049570083618164:A-11.26390791450441],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.999135494232178,7.215419769287109:A-0.09876963515935563]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.004878090639159555]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.999135494232178,7.215419769287109:A-0.09876963515935563]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.004878090639159555]]),input_ids:T7s1x1[29929,29929:A29929.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.522127151489258,9.503350257873535:A-9.24007676622225],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.999135494232178,7.215419769287109:A-0.10050203466175583]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.005549828696577355]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.999135494232178,7.215419769287109:A-0.10050203466175583]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.005549828696577355]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.51691436767578,10.691946983337402:A-9.585153352168389],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.999135494232178,7.215419769287109:A-0.10208049081957142]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.006002894794199973]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.999135494232178,7.215419769287109:A-0.10208049081957142]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.006002894794199973]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.16791534423828,4.051937103271484:A-13.73162283246778],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.999135494232178,7.215419769287109:A-0.10111829210040679]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.006426730821008229]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.999135494232178,7.215419769287109:A-0.10111829210040679]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.006426730821008229]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.925310134887695,6.715033531188965:A-9.641330544466152],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.999135494232178,7.215419769287109:A-0.09266261605000636]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.005793006944604902]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.999135494232178,7.215419769287109:A-0.09266261605000636]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.005793006944604902]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.01685619354248,16.82293701171875:A-5.22482418841403],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.999135494232178,7.215419769287109:A-0.0924762174399435]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.006561939847153663]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.999135494232178,7.215419769287109:A-0.0924762174399435]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.006561939847153663]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.561767578125,8.897604942321777:A-11.16644848463498],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.999135494232178,7.215419769287109:A-0.09316551498537307]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.006534485807556919]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.999135494232178,7.215419769287109:A-0.09316551498537307]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.006534485807556919]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.213354110717773,9.359787940979004:A-12.702782304516061],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.999135494232178,7.215419769287109:A-0.09443681208632326]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A-0.007220011565130422]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.999135494232178,7.215419769287109:A-0.09443681208632326]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A-0.007220011565130422]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.211776733398438,9.062095642089844:A-11.117917354448698],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.999135494232178,7.215419769287109:A-0.09605957950390331]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A-0.00708747024223747]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.999135494232178,7.215419769287109:A-0.09605957950390331]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A-0.00708747024223747]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.347116470336914,3.921221971511841:A-13.276751260987483],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.999135494232178,7.215419769287109:A-0.09427604100677821]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A-0.007048038681659818]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.999135494232178,7.215419769287109:A-0.09427604100677821]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A-0.007048038681659818]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.48412322998047,6.587777137756348:A-8.29412822028529],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.999135494232178,7.215419769287109:A-0.09317054887507861]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A-0.006558121372256465]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.999135494232178,7.215419769287109:A-0.09317054887507861]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A-0.006558121372256465]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.575267791748047,9.075438499450684:A-5.38603269260982],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.999135494232178,7.215419769287109:A-0.08983008838745613]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A-0.005665108148624647]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.999135494232178,7.215419769287109:A-0.08983008838745613]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A-0.005665108148624647]]),input_ids:T7s1x1[29902,29902:A29902.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.226329803466797,10.05836296081543:A-4.170653190152254],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.999135494232178,7.215419769287109:A-0.08659607141510908]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A-0.005644659130518903]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.999135494232178,7.215419769287109:A-0.08659607141510908]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A-0.005644659130518903]]),input_ids:T7s1x1[1016,1016:A1016.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.434805870056152,13.4017333984375:A-6.080810847351095],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.999135494232178,7.215419769287109:A-0.08494088391931817]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A-0.006044374190671404]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.999135494232178,7.215419769287109:A-0.08494088391931817]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A-0.006044374190671404]]),input_ids:T7s1x1[30010,30010:A30010.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.16743278503418,17.66503143310547:A-5.702449502617819],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.999135494232178,7.215419769287109:A-0.08180633239874349]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A-0.005923025541054784]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.999135494232178,7.215419769287109:A-0.08180633239874349]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A-0.005923025541054784]]),input_ids:T7s1x1[29873,29873:A29873.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.457927703857422,11.149272918701172:A-6.717263021916617],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.999135494232178,7.215419769287109:A-0.07929274646028703]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A-0.005212843903216906]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.999135494232178,7.215419769287109:A-0.07929274646028703]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A-0.005212843903216906]]),input_ids:T7s1x1[505,505:A505.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.53079605102539,6.203318119049072:A-9.729566184808967],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.999135494232178,7.215419769287109:A-0.07651930623502046]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A-0.005845168247412513]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.999135494232178,7.215419769287109:A-0.07651930623502046]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A-0.005845168247412513]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.876670837402344,5.129977226257324:A-7.769785232942551],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.999135494232178,7.215419769287109:A-0.07260213183523134]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A-0.0053454453807144425]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.999135494232178,7.215419769287109:A-0.07260213183523134]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A-0.0053454453807144425]]),input_ids:T7s1x1[1472,1472:A1472.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.18552017211914,6.626327037811279:A-8.271938080507331],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.999135494232178,7.215419769287109:A-0.07038920751377493]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A-0.00537434951540187]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.999135494232178,7.215419769287109:A-0.07038920751377493]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A-0.00537434951540187]]),input_ids:T7s1x1[4955,4955:A4955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.18909454345703,9.84488296508789:A-7.494754958472215],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.999135494232178,7.215419769287109:A-0.06860058786312316]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A-0.005502881686412198]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.999135494232178,7.215419769287109:A-0.06860058786312316]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A-0.005502881686412198]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.823881149291992,5.0159502029418945:A-7.559749979640357],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.999135494232178,7.215419769287109:A-0.06748739368622915]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A-0.005238897442246248]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.999135494232178,7.215419769287109:A-0.06748739368622915]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A-0.005238897442246248]]),input_ids:T7s1x1[4955,4955:A4955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.734375,11.287750244140625:A-6.233110616709106],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.999135494232178,7.215419769287109:A-0.06416769155869273]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A-0.005364947730218513]]))
-- prompt Continue: it rains...
-- answer Continue: it rains... 28 Jul 2012
- Mar 2020, 1988, 2012.
I don’t have a long history of history because
Let’s restore the forward as it was.
# Remove the printing wrapper: put back the forward saved in keep_model_forward.
model.forward = keep_model_forward
Another syntax with onnx_diagnostic.helpers.torch_helper.steal_forward()
.
# Same generation as before, this time run inside the steal_forward context manager.
with steal_forward(model):
    model.generate(
        inputs,
        max_length=50,
        temperature=1,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful for creating a unit test that ensures a specific architecture can still be exported despite the many changes brought to torch or transformers.
Let’s create an untrained model from the configuration file
config.json
of the original model, using function
onnx_diagnostic.torch_models.llms.get_tiny_llm()
.
Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
Before we run it, we make a copy of the inputs (for instance cloned_inputs = copy.deepcopy(inputs)) because the cache gets modified by the execution. Once modified, the cache is no longer consistent with the previous input_ids and attention mask.
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
The outputs
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
ep = torch.export.export(
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7a4efb121790> not registered
Back to the original model¶
Let’s use the same dummy inputs but we use the downloaded model.
Dummy inputs and dynamic shapes are created by function
onnx_diagnostic.torch_models.llms.get_tiny_llm()
.
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
Let’s print the inputs and the dynamic shapes.
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: Dim('batch', min=1, max=1024),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
'input_ids': {0: Dim('batch', min=1, max=1024),
1: Dim('seq_length', min=1, max=4096)},
'past_key_values': [[{0: Dim('batch', min=1, max=1024),
2: Dim('cache_length', min=1, max=4096)}],
[{0: Dim('batch', min=1, max=1024),
2: Dim('cache_length', min=1, max=4096)}]],
'position_ids': {0: Dim('batch', min=1, max=1024),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}}
And let’s finally export.
try:
ep = torch.export.export(
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7a4efb534560> not registered
If you get an error, have a look at the example Export Tiny-LLM with patches.
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 3.026 seconds)
Related examples