Steel method forward to guess the dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its method forward. This example shows how to guess the inputs of this method even though the model is executed through generate().
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.
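For reference, this is how dynamic shapes are passed to torch.export.export() on a toy module. It is a minimal, generic sketch (the toy module and dimension names are made up for illustration), not the Tiny-LLM export performed later.
import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids):
        # stand-in for a real forward: turn the token ids into floats
        return input_ids.to(torch.float32) * 2


batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
ep = torch.export.export(
    Toy(),
    (torch.zeros((2, 8), dtype=torch.int64),),
    # one entry per positional argument, mapping axis index to a symbolic dimension
    dynamic_shapes=({0: batch, 1: seq},),
)
print(ep)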
Steel the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_test_helper import steel_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We rewrite the forward method to print its inputs and outputs, including the cache dimensions.
def _forward_(*args, _f=None, **kwargs):
assert _f is not None
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
# torch.compiler.is_exporting requires torch>=2.7
print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
res = _f(*args, **kwargs)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
print("->", string_type(res, with_shape=True, with_min_max=True))
return res
keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
*args, _f=_f, **kwargs
)
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[7523,7523:A7523.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.376155853271484,8.790515899658203:A-7.921017523506889],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13514820696434104]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007433787821338666]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.13514820696434104]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.007433787821338666]]),input_ids:T7s1x1[5853,5853:A5853.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-15.67812728881836,8.352869033813477:A-8.630772858360782],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1472933334015882]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.003732145182690753]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1472933334015882]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.003732145182690753]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-9.482733726501465,9.324060440063477:A-2.9642273675696926],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.510405540466309,6.323276519775391:A-0.1520401766288022]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.005962804197844724]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.510405540466309,6.323276519775391:A-0.1520401766288022]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.005962804197844724]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-15.92909049987793,4.663028717041016:A-8.153554003157653],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.510405540466309,6.323276519775391:A-0.1326881590736472]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.005772102248948209]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.510405540466309,6.323276519775391:A-0.1326881590736472]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.005772102248948209]]),input_ids:T7s1x1[5129,5129:A5129.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-11.310409545898438,6.781505107879639:A-5.744279077815357],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.510405540466309,6.323276519775391:A-0.13639661611439527]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.006949159104678051]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.510405540466309,6.323276519775391:A-0.13639661611439527]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.006949159104678051]]),input_ids:T7s1x1[1576,1576:A1576.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-16.06218719482422,2.2015881538391113:A-8.275736948913895],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.510405540466309,6.323276519775391:A-0.13408630436978933]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.006618083676314546]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.510405540466309,6.323276519775391:A-0.13408630436978933]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.006618083676314546]]),input_ids:T7s1x1[853,853:A853.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-10.354561805725098,11.275320053100586:A-0.6690227779871785],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.510405540466309,6.323276519775391:A-0.14283968643891665]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.006669190143040598]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.510405540466309,6.323276519775391:A-0.14283968643891665]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.006669190143040598]]),input_ids:T7s1x1[2164,2164:A2164.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-14.804096221923828,5.608332633972168:A-7.6928343164441175],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.510405540466309,6.323276519775391:A-0.13123303175772585]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.006160999315918995]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.510405540466309,6.323276519775391:A-0.13123303175772585]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.006160999315918995]]),input_ids:T7s1x1[2088,2088:A2088.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-8.98497200012207,14.32550048828125:A-1.384192180806771],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.510405540466309,6.323276519775391:A-0.1380472471685732]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.007224553581886093]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-5.510405540466309,6.323276519775391:A-0.1380472471685732]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.007224553581886093]]),input_ids:T7s1x1[342,342:A342.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-14.901308059692383,5.264919757843018:A-7.053204167680815],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.510405540466309,6.323276519775391:A-0.13216429356590972]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.00691319705859112]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-5.510405540466309,6.323276519775391:A-0.13216429356590972]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.00691319705859112]]),input_ids:T7s1x1[379,379:A379.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-7.071249008178711,12.912490844726562:A0.07057144980179146],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.66123628616333,6.323276519775391:A-0.1363222968698609]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.008607103123171998]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-5.66123628616333,6.323276519775391:A-0.1363222968698609]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.008607103123171998]]),input_ids:T7s1x1[3774,3774:A3774.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-13.730072021484375,13.2047758102417:A-2.6407147965063342],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.66123628616333,7.132113456726074:A-0.12779356389067592]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.008140969080879282]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-5.66123628616333,7.132113456726074:A-0.12779356389067592]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.008140969080879282]]),input_ids:T7s1x1[1025,1025:A1025.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-13.743938446044922,7.054136276245117:A-6.374360442082398],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.66123628616333,7.132113456726074:A-0.12017948665806981]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.007508044630419987]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-5.66123628616333,7.132113456726074:A-0.12017948665806981]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.007508044630419987]]),input_ids:T7s1x1[29873,29873:A29873.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-12.208542823791504,5.429507732391357:A-6.246416576387826],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.66123628616333,7.132113456726074:A-0.1166820352323019]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.008285623732945169]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-5.66123628616333,7.132113456726074:A-0.1166820352323019]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.008285623732945169]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-12.99802303314209,7.814536094665527:A-5.9160133217014375],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.66123628616333,7.132113456726074:A-0.11402256011084926]], value_cache=#1[T1s1x1x23x96[-1.1154754161834717,0.7704185843467712:A0.006752680797674565]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-5.66123628616333,7.132113456726074:A-0.11402256011084926]], value_cache=#1[T1s1x1x23x96[-1.1154754161834717,0.7704185843467712:A0.006752680797674565]]),input_ids:T7s1x1[29879,29879:A29879.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-16.782283782958984,2.6439943313598633:A-8.63411575705884],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.66123628616333,7.419308185577393:A-0.10728398225774072]], value_cache=#1[T1s1x1x24x96[-1.1154754161834717,0.7704185843467712:A0.006775320575033877]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-5.66123628616333,7.419308185577393:A-0.10728398225774072]], value_cache=#1[T1s1x1x24x96[-1.1154754161834717,0.7704185843467712:A0.006775320575033877]]),input_ids:T7s1x1[317,317:A317.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-11.515056610107422,9.5897216796875:A-2.475880642474862],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.738618850708008,7.419308185577393:A-0.10380303513903831]], value_cache=#1[T1s1x1x25x96[-1.1154754161834717,0.7704185843467712:A0.0077070750772873]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-5.738618850708008,7.419308185577393:A-0.10380303513903831]], value_cache=#1[T1s1x1x25x96[-1.1154754161834717,0.7704185843467712:A0.0077070750772873]]),input_ids:T7s1x1[3466,3466:A3466.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-16.978328704833984,4.344635009765625:A-8.933943966033404],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.738618850708008,7.419308185577393:A-0.09595241791663657]], value_cache=#1[T1s1x1x26x96[-1.1154754161834717,0.7704185843467712:A0.007098517292917723]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-5.738618850708008,7.419308185577393:A-0.09595241791663657]], value_cache=#1[T1s1x1x26x96[-1.1154754161834717,0.7704185843467712:A0.007098517292917723]]),input_ids:T7s1x1[313,313:A313.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-5.452112674713135,11.58664321899414:A-0.39912351543735713],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.738618850708008,7.419308185577393:A-0.09383037467885089]], value_cache=#1[T1s1x1x27x96[-1.1154754161834717,0.7704185843467712:A0.007187827831201525]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-5.738618850708008,7.419308185577393:A-0.09383037467885089]], value_cache=#1[T1s1x1x27x96[-1.1154754161834717,0.7704185843467712:A0.007187827831201525]]),input_ids:T7s1x1[1123,1123:A1123.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-13.879581451416016,10.219121932983398:A-3.6833848738148807],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.738618850708008,7.419308185577393:A-0.09727102882429574]], value_cache=#1[T1s1x1x28x96[-1.1154754161834717,0.7704185843467712:A0.00744349041005015]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-5.738618850708008,7.419308185577393:A-0.09727102882429574]], value_cache=#1[T1s1x1x28x96[-1.1154754161834717,0.7704185843467712:A0.00744349041005015]]),input_ids:T7s1x1[1125,1125:A1125.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-15.045766830444336,4.373247146606445:A-9.400006349338218],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-5.738618850708008,7.419308185577393:A-0.09172060809511492]], value_cache=#1[T1s1x1x29x96[-1.1154754161834717,0.7704185843467712:A0.007852016473422955]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-5.738618850708008,7.419308185577393:A-0.09172060809511492]], value_cache=#1[T1s1x1x29x96[-1.1154754161834717,0.7704185843467712:A0.007852016473422955]]),input_ids:T7s1x1[1619,1619:A1619.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-18.499126434326172,2.3278067111968994:A-9.242684685871005],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-5.738618850708008,7.419308185577393:A-0.08659399402737715]], value_cache=#1[T1s1x1x30x96[-1.1154754161834717,0.7704185843467712:A0.008025167971972122]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-5.738618850708008,7.419308185577393:A-0.08659399402737715]], value_cache=#1[T1s1x1x30x96[-1.1154754161834717,0.7704185843467712:A0.008025167971972122]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-12.639043807983398,12.287067413330078:A-3.877907753824722],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-5.738618850708008,7.419308185577393:A-0.08519415041147474]], value_cache=#1[T1s1x1x31x96[-1.1154754161834717,0.7704185843467712:A0.006760878917111279]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-5.738618850708008,7.419308185577393:A-0.08519415041147474]], value_cache=#1[T1s1x1x31x96[-1.1154754161834717,0.7704185843467712:A0.006760878917111279]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-19.31742286682129,6.317631721496582:A-9.84044504672382],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-5.738618850708008,7.419308185577393:A-0.08400380717887401]], value_cache=#1[T1s1x1x32x96[-1.1154754161834717,0.7704185843467712:A0.006473085077795797]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-5.738618850708008,7.419308185577393:A-0.08400380717887401]], value_cache=#1[T1s1x1x32x96[-1.1154754161834717,0.7704185843467712:A0.006473085077795797]]),input_ids:T7s1x1[29946,29946:A29946.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-19.243236541748047,5.739378929138184:A-9.226775219509378],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-5.738618850708008,7.419308185577393:A-0.08370567948172766]], value_cache=#1[T1s1x1x33x96[-1.1154754161834717,0.7704185843467712:A0.006007470984812501]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-5.738618850708008,7.419308185577393:A-0.08370567948172766]], value_cache=#1[T1s1x1x33x96[-1.1154754161834717,0.7704185843467712:A0.006007470984812501]]),input_ids:T7s1x1[386,386:A386.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-23.499183654785156,2.2744171619415283:A-11.42114406795567],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-5.738618850708008,7.419308185577393:A-0.07931487505176811]], value_cache=#1[T1s1x1x34x96[-1.1154754161834717,0.7704185843467712:A0.0060399259108073556]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-5.738618850708008,7.419308185577393:A-0.07931487505176811]], value_cache=#1[T1s1x1x34x96[-1.1154754161834717,0.7704185843467712:A0.0060399259108073556]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-8.283088684082031,11.367097854614258:A-3.30123053604085],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-5.893681526184082,7.419308185577393:A-0.07609968704799983]], value_cache=#1[T1s1x1x35x96[-1.1154754161834717,0.7704185843467712:A0.006675053580480986]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-5.893681526184082,7.419308185577393:A-0.07609968704799983]], value_cache=#1[T1s1x1x35x96[-1.1154754161834717,0.7704185843467712:A0.006675053580480986]]),input_ids:T7s1x1[29933,29933:A29933.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-11.9619140625,9.606267929077148:A-2.6959211870259607],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-5.893681526184082,7.419308185577393:A-0.07247279114350628]], value_cache=#1[T1s1x1x36x96[-1.1154754161834717,0.7704185843467712:A0.005796969355923533]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-5.893681526184082,7.419308185577393:A-0.07247279114350628]], value_cache=#1[T1s1x1x36x96[-1.1154754161834717,0.7704185843467712:A0.005796969355923533]]),input_ids:T7s1x1[1463,1463:A1463.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.82023048400879,7.7012481689453125:A-9.883137480513193],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-5.893681526184082,7.419308185577393:A-0.06942578241658928]], value_cache=#1[T1s1x1x37x96[-1.1154754161834717,0.7704185843467712:A0.005500409093193403]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-5.893681526184082,7.419308185577393:A-0.06942578241658928]], value_cache=#1[T1s1x1x37x96[-1.1154754161834717,0.7704185843467712:A0.005500409093193403]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.569393157958984,4.371037006378174:A-10.224331094395835],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-5.893681526184082,7.419308185577393:A-0.06490005454246497]], value_cache=#1[T1s1x1x38x96[-1.1154754161834717,0.7704185843467712:A0.004842044211420472]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-5.893681526184082,7.419308185577393:A-0.06490005454246497]], value_cache=#1[T1s1x1x38x96[-1.1154754161834717,0.7704185843467712:A0.004842044211420472]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-18.02338218688965,4.0899658203125:A-9.586811772121116],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-5.893681526184082,7.419308185577393:A-0.06226507910055533]], value_cache=#1[T1s1x1x39x96[-1.1154754161834717,0.7704185843467712:A0.005144616430460988]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-5.893681526184082,7.419308185577393:A-0.06226507910055533]], value_cache=#1[T1s1x1x39x96[-1.1154754161834717,0.7704185843467712:A0.005144616430460988]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-12.909448623657227,16.28973388671875:A-2.808785396105144],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-5.893681526184082,7.419308185577393:A-0.05778663022620094]], value_cache=#1[T1s1x1x40x96[-1.1154754161834717,0.7704185843467712:A0.004236806201481614]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-5.893681526184082,7.419308185577393:A-0.05778663022620094]], value_cache=#1[T1s1x1x40x96[-1.1154754161834717,0.7704185843467712:A0.004236806201481614]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-20.17303466796875,7.190042018890381:A-9.835913220126182],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-5.893681526184082,7.419308185577393:A-0.05772399016850835]], value_cache=#1[T1s1x1x41x96[-1.1154754161834717,0.7704185843467712:A0.0036665960105855765]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-5.893681526184082,7.419308185577393:A-0.05772399016850835]], value_cache=#1[T1s1x1x41x96[-1.1154754161834717,0.7704185843467712:A0.0036665960105855765]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-11.823934555053711,10.777144432067871:A-4.499593742562458],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-5.893681526184082,7.419308185577393:A-0.054385498989971]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.003666781362978457]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-5.893681526184082,7.419308185577393:A-0.054385498989971]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.003666781362978457]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-21.825679779052734,4.7603607177734375:A-11.866506573867984],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-5.893681526184082,7.419308185577393:A-0.05366887297102159]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0035245649605607217]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-5.893681526184082,7.419308185577393:A-0.05366887297102159]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0035245649605607217]]),input_ids:T7s1x1[3275,3275:A3275.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-19.200580596923828,4.240942478179932:A-10.24893242730666],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-5.893681526184082,7.419308185577393:A-0.05163268719697402]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A0.002931999661717038]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-5.893681526184082,7.419308185577393:A-0.05163268719697402]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A0.002931999661717038]]),input_ids:T7s1x1[411,411:A411.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.112947463989258,4.2529144287109375:A-9.974786050224676],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-5.893681526184082,7.72377872467041:A-0.04788672395923101]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A0.0027012119957598095]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-5.893681526184082,7.72377872467041:A-0.04788672395923101]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A0.0027012119957598095]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.24622344970703,5.872923851013184:A-8.405456542466302],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-5.893681526184082,7.72377872467041:A-0.04511458361227689]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A0.0030042804470259148]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-5.893681526184082,7.72377872467041:A-0.04511458361227689]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A0.0030042804470259148]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-10.931031227111816,16.540855407714844:A-3.121000719427131],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-5.893681526184082,7.72377872467041:A-0.040932147357222115]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A0.0022772150603080437]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-5.893681526184082,7.72377872467041:A-0.040932147357222115]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A0.0022772150603080437]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-17.932403564453125,7.150956153869629:A-9.867993058512454],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.511573791503906,7.72377872467041:A-0.0389036618837003]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A0.0021787621644477895]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.511573791503906,7.72377872467041:A-0.0389036618837003]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A0.0021787621644477895]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> dict(logits:T1s1x1x32000[-10.504631996154785,12.478915214538574:A-3.499839548948221],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.511573791503906,7.72377872467041:A-0.03418548348425396]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A0.0022092849939710293]]))
-- prompt Continue: it rains...
-- answer Continue: it rains... Read More
- ‘The Unified Guest Humboldt's Slip (Re): My 14th
Based on a 8-1 lead with a 1-0
Let’s restore the forward as it was.
model.forward = keep_model_forward
Another syntax with onnx_diagnostic.helpers.torch_test_helper.steel_forward().
with steel_forward(model):
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
.
---- stolen forward for class LlamaForCausalLM
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
--
-> dict(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
.
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test to ensure a specific architecture can still be exported despite the many changes brought to torch or transformers (a minimal test sketch is shown after the outputs below).
Let’s create an untrained model using the configuration file config.json and the function onnx_diagnostic.torch_models.llms.get_tiny_llm().
Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
Before we run it, we make a copy of the inputs, because the cache gets modified by the execution; after the call it is no longer consistent with the previous input_ids and mask.
# keep an untouched copy of the inputs for the export below
cloned_inputs = copy.deepcopy(inputs)

print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
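A quick check that the copy was left untouched (a small sketch added for illustration, using the cloned_inputs created just above):
# the copy still holds the 30-token cache, while inputs now holds the mutated 33-token cache
print("cloned inputs   ", string_type(cloned_inputs, with_shape=True))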
The outputs:
print("result type", string_type(expected_output, with_shape=True))
result type dict(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
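As mentioned above, the untrained model is handy for unit tests. Here is a minimal sketch of such a test; the test class and assertions are illustrative and not part of onnx_diagnostic.
import copy
import unittest

from onnx_diagnostic.torch_models.llms import get_tiny_llm


class TestTinyLLM(unittest.TestCase):
    def test_tiny_llm_forward(self):
        data = get_tiny_llm()
        model, inputs = data["model"], data["inputs"]
        # run on a deep copy so the dummy cache is not mutated
        got = model(**copy.deepcopy(inputs))
        self.assertIn("logits", got)
        self.assertEqual(got["logits"].shape[-1], model.config.vocab_size)


if __name__ == "__main__":
    unittest.main()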
ExportedProgram¶
try:
ep = torch.export.export(
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Cannot associate shape [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
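The failure comes from the past_key_values entry: the dynamic-shape specification can only be matched against inputs that torch.export knows how to flatten into tensors, and here the DynamicCache instance is seen as a plain non-tensor object. A small diagnostic sketch (not part of the original example) to check how the cache flattens with pytree:
import torch
import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.ones((1, 1, 8, 96)), torch.ones((1, 1, 8, 96)), 0)
flat, spec = pytree.tree_flatten(cache)
# if the cache is registered with pytree, flat lists its key/value tensors and
# torch.export can map the dynamic shapes onto them; otherwise the cache stays
# an opaque leaf and the export fails as above
print(len(flat), spec)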
Back to the original model¶
Let’s use the same dummy inputs, but with the downloaded model. Dummy inputs and dynamic shapes are created by the function onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
# keep an untouched copy of the inputs for the export call below
cloned_inputs = copy.deepcopy(inputs)
Let’s print the inputs and the dynamic shapes.
print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
'input_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.seq_length'>},
'past_key_values': [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}],
[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]],
'position_ids': {0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>,
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}}
And let’s finally export.
try:
ep = torch.export.export(
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Cannot associate shape [[{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}], [{0: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.batch'>, 2: <class 'onnx_diagnostic.torch_models.untrained.llm_tiny_llm.cache_length'>}]] specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.DynamicCache'> at `inputs['past_key_values']` (expected None)
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
If you run into any error, look at the example Export Tiny-LLM with patches.
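For reference, that example wraps the export in a patch helper from onnx_diagnostic. The sketch below shows the pattern only; the entry point name torch_export_patches and its patch_transformers argument are assumptions about the installed onnx_diagnostic version, so check that example for the exact import.
# assumption: the patch context manager is exposed as below in your version
from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
    )
print(ep)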
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 2.886 seconds)