Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its forward method.

This example shows how to guess the inputs of this method even though the model is executed through the method generate.

We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


MODEL_NAME = "arnir0/Tiny-LLM"
# The tokenizer and the pretrained causal language model are both fetched
# from the same checkpoint identifier (network access or a local cache is
# needed for this step).
tokenizer, model = (
    transformers.AutoTokenizer.from_pretrained(MODEL_NAME),
    transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME),
)

We wrap the forward method to print the shapes (and min/max values) of its inputs and outputs, including the cache dimensions.

def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


# Keep a handle on the original bound method: the wrapper delegates to it,
# and it is needed to restore the model afterwards.
keep_model_forward = model.forward


def _patched_forward(*args, _f=keep_model_forward, **kwargs):
    # The default value of ``_f`` is evaluated once, here, so the wrapper
    # always delegates to the forward captured above.
    return _forward_(*args, _f=_f, **kwargs)


model.forward = _patched_forward

Let’s run the model.

prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

# Sampling configuration shared by the generate call below.
generate_options = dict(
    max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
outputs = model.generate(inputs, **generate_options)

# A single prompt was encoded, so the first row is the generated sequence.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
for tag, text in (("-- prompt", prompt), ("-- answer", generated_text)):
    print(tag, text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.432564735412598,8.368535995483398:A-4.234468644971028],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[29903,29903:A29903.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.083460807800293,11.222838401794434:A-0.12068867132649758],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.540352821350098:A-0.11170021048298319]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.010558997104196048]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.509540557861328,6.540352821350098:A-0.11170021048298319]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.010558997104196048]]),input_ids:T7s1x1[870,870:A870.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.207452774047852,15.588762283325195:A-2.6809886231475977],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.509540557861328,6.540352821350098:A-0.11049965830732783]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.008882343518480135]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.509540557861328,6.540352821350098:A-0.11049965830732783]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.008882343518480135]]),input_ids:T7s1x1[388,388:A388.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.050418853759766,8.69715404510498:A-5.870951075625606],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.509540557861328,6.540352821350098:A-0.09253683880554793]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.00932457198429246]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.509540557861328,6.540352821350098:A-0.09253683880554793]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.00932457198429246]]),input_ids:T7s1x1[363,363:A363.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.969917297363281,4.7796430587768555:A-8.691098548230249],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.509540557861328,7.7615065574646:A-0.09071257812198989]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.008131573297491857]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.509540557861328,7.7615065574646:A-0.09071257812198989]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.008131573297491857]]),input_ids:T7s1x1[395,395:A395.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.560033798217773,10.845871925354004:A-5.587876864117105],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.509540557861328,7.7615065574646:A-0.10584936178955624]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.005287295096412068]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.509540557861328,7.7615065574646:A-0.10584936178955624]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.005287295096412068]]),input_ids:T7s1x1[29946,29946:A29946.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.021800994873047,7.943614959716797:A-9.891893816536758],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.6544508934021,7.7615065574646:A-0.11246078275103678]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.0043419967572744]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.6544508934021,7.7615065574646:A-0.11246078275103678]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.0043419967572744]]),input_ids:T7s1x1[29929,29929:A29929.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.839542388916016,7.427285671234131:A-9.898015989836772],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.6544508934021,7.7615065574646:A-0.11938203940091323]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0025482160659275146]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.6544508934021,7.7615065574646:A-0.11938203940091323]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0025482160659275146]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.888417720794678,9.502809524536133:A-2.749738055441761],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.6544508934021,7.7615065574646:A-0.12290454222031758]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.004061226553189686]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.6544508934021,7.7615065574646:A-0.12290454222031758]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.004061226553189686]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.51053237915039,4.5319952964782715:A-8.75580452865525],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.6544508934021,7.7615065574646:A-0.12374490506415357]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.004039735123072845]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.6544508934021,7.7615065574646:A-0.12374490506415357]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.004039735123072845]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.291577339172363,13.843061447143555:A-3.493030960181262],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.6544508934021,7.7615065574646:A-0.12551499162656368]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.002186707341399852]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.6544508934021,7.7615065574646:A-0.12551499162656368]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.002186707341399852]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.667932510375977,3.051494598388672:A-11.655316781926901],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.6544508934021,7.988457679748535:A-0.11675897925761092]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.0019549457772806513]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.6544508934021,7.988457679748535:A-0.11675897925761092]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.0019549457772806513]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.824583053588867,2.734921455383301:A-11.789864917821252],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.6544508934021,7.988457679748535:A-0.11479982524565818]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.0017452567430775651]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.6544508934021,7.988457679748535:A-0.11479982524565818]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.0017452567430775651]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.48427963256836,9.069987297058105:A-6.623168629536405],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.6544508934021,7.988457679748535:A-0.11198793186700971]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.001832944200714337]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.6544508934021,7.988457679748535:A-0.11198793186700971]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.001832944200714337]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.128677368164062,2.9294705390930176:A-12.22451045340672],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.993858814239502,7.988457679748535:A-0.11222264520372424]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.0015085334745587413]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.993858814239502,7.988457679748535:A-0.11222264520372424]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.0015085334745587413]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.06477165222168,12.539310455322266:A-4.034430594373495],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.993858814239502,7.988457679748535:A-0.10622886746947795]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.00014701988275570935]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.993858814239502,7.988457679748535:A-0.10622886746947795]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.00014701988275570935]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.963985443115234,4.890559673309326:A-11.444591832424514],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.993858814239502,7.988457679748535:A-0.10285045109044101]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A4.3198129806114597e-05]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.993858814239502,7.988457679748535:A-0.10285045109044101]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A4.3198129806114597e-05]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.188583374023438,6.214415550231934:A-11.473405036952347],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.993858814239502,7.988457679748535:A-0.09620339705478005]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-5.683249266369855e-05]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.993858814239502,7.988457679748535:A-0.09620339705478005]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-5.683249266369855e-05]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.311017990112305,9.818561553955078:A-5.837241976384539],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.993858814239502,7.988457679748535:A-0.09236260112654382]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A8.136059266038456e-05]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.993858814239502,7.988457679748535:A-0.09236260112654382]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A8.136059266038456e-05]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.469409942626953,6.609431266784668:A-12.143391815945506],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.999135494232178,7.988457679748535:A-0.09617399696413659]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-8.992426398349163e-06]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-7.999135494232178,7.988457679748535:A-0.09617399696413659]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-8.992426398349163e-06]]),input_ids:T7s1x1[29955,29955:A29955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.79451560974121,4.2854905128479:A-12.53562844231911],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.999135494232178,7.988457679748535:A-0.09886911942337702]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.0006019089243647154]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-7.999135494232178,7.988457679748535:A-0.09886911942337702]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.0006019089243647154]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.357734680175781,11.605958938598633:A-3.7251004789504223],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.999135494232178,7.988457679748535:A-0.09535541591770501]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0003604678514269229]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-7.999135494232178,7.988457679748535:A-0.09535541591770501]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0003604678514269229]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.97235870361328,2.1917269229888916:A-11.143657298058738],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.999135494232178,7.988457679748535:A-0.0921674876929413]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.0004673682694804591]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-7.999135494232178,7.988457679748535:A-0.0921674876929413]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.0004673682694804591]]),input_ids:T7s1x1[399,399:A399.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.27956771850586,11.503822326660156:A-2.8761391152418216],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.999135494232178,7.988457679748535:A-0.08664193680135668]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.001360551845166924]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-7.999135494232178,7.988457679748535:A-0.08664193680135668]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.001360551845166924]]),input_ids:T7s1x1[598,598:A598.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.697214126586914,6.44054651260376:A-8.798265672787092],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.999135494232178,7.988457679748535:A-0.08183755590454728]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.0013312865389248098]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-7.999135494232178,7.988457679748535:A-0.08183755590454728]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.0013312865389248098]]),input_ids:T7s1x1[3307,3307:A3307.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.3937931060791,6.993257522583008:A-9.127292814762331],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.999135494232178,7.988457679748535:A-0.07982303033410669]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.001324651840413222]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-7.999135494232178,7.988457679748535:A-0.07982303033410669]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.001324651840413222]]),input_ids:T7s1x1[304,304:A304.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.881256103515625,6.579377174377441:A-8.133967524012784],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.999135494232178,7.988457679748535:A-0.07773771530541819]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0017724120373713958]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-7.999135494232178,7.988457679748535:A-0.07773771530541819]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0017724120373713958]]),input_ids:T7s1x1[437,437:A437.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.573562622070312,6.619939804077148:A-9.605280767610296],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.999135494232178,7.988457679748535:A-0.07457322255157595]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0012776492868119425]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-7.999135494232178,7.988457679748535:A-0.07457322255157595]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0012776492868119425]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.93636131286621,4.505215644836426:A-8.935527173072565],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.999135494232178,7.988457679748535:A-0.07322780954023442]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0016123774709641922]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-7.999135494232178,7.988457679748535:A-0.07322780954023442]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0016123774709641922]]),input_ids:T7s1x1[1021,1021:A1021.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.266992568969727,7.826637268066406:A-8.672308790531476],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.999135494232178,7.988457679748535:A-0.06982763404139086]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0018445075649547984]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-7.999135494232178,7.988457679748535:A-0.06982763404139086]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0018445075649547984]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.838875770568848,12.080994606018066:A-3.247482618824113],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.999135494232178,7.988457679748535:A-0.06708570983155523]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.0025220687645555583]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-7.999135494232178,7.988457679748535:A-0.06708570983155523]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.0025220687645555583]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.574447631835938,3.1893486976623535:A-10.43625673301518],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.999135494232178,7.988457679748535:A-0.06332291550282511]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.002550876565718833]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-7.999135494232178,7.988457679748535:A-0.06332291550282511]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.002550876565718833]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.307563781738281,10.014108657836914:A-7.09132053546235],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.999135494232178,7.988457679748535:A-0.05868335100905558]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.001728469997561934]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-7.999135494232178,7.988457679748535:A-0.05868335100905558]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.001728469997561934]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.55433464050293,3.5906128883361816:A-11.013929116971791],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.999135494232178,7.988457679748535:A-0.05899994574916953]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0016290177609298842]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-7.999135494232178,7.988457679748535:A-0.05899994574916953]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0016290177609298842]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.947246551513672,3.281421422958374:A-10.84253434949182],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.999135494232178,7.988457679748535:A-0.05781987646205803]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.0011459752170650986]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-7.999135494232178,7.988457679748535:A-0.05781987646205803]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.0011459752170650986]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.097780227661133,9.926095962524414:A-6.567275531813968],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.999135494232178,7.988457679748535:A-0.05440785075081129]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.001203438980565586]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-7.999135494232178,7.988457679748535:A-0.05440785075081129]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.001203438980565586]]),input_ids:T7s1x1[29896,29896:A29896.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-23.342710494995117,5.49420166015625:A-12.477824332008138],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.999135494232178,7.988457679748535:A-0.05398036079738328]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.0011222842489755917]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-7.999135494232178,7.988457679748535:A-0.05398036079738328]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.0011222842489755917]]),input_ids:T7s1x1[29929,29929:A29929.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.037118911743164,3.233747959136963:A-11.089064961513504],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.999135494232178,7.988457679748535:A-0.05372195484847327]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0005683542804266493]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-7.999135494232178,7.988457679748535:A-0.05372195484847327]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.0005683542804266493]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.34065055847168,7.2644877433776855:A-10.663603161831386],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.999135494232178,7.988457679748535:A-0.05325095350373679]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.00029460512551198636]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.999135494232178,7.988457679748535:A-0.05325095350373679]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.00029460512551198636]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.227590560913086,3.387385129928589:A-13.036369040665218],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.999135494232178,7.988457679748535:A-0.05014284603933245]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A3.226218538543435e-05]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.999135494232178,7.988457679748535:A-0.05014284603933245]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A3.226218538543435e-05]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.065973281860352,12.935248374938965:A-5.813110903360881],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.999135494232178,7.988457679748535:A-0.045195296534669124]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A0.00010659113693035498]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...
Sunday for $49
- 11-2 13-17
- Ware enough to do the same
- 18-1955-1

Let’s restore the original forward method.

# Reinstall the bound method saved before patching so later calls no longer
# go through the printing wrapper.
setattr(model, "forward", keep_model_forward)

The same trace can be obtained with another syntax, onnx_diagnostic.helpers.torch_helper.steal_forward().

with steal_forward(model):
    # Only the forward calls logged by steal_forward matter here; the
    # generated token ids are discarded.
    model.generate(
        inputs,
        do_sample=True,
        max_length=50,
        temperature=1,
        top_k=50,
        top_p=0.95,
    )
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
  <- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test to ensure a specific architecture can be exported despite the many changes brought to torch or transformers.

Let’s use the provided config file config.json to create an untrained model with onnx_diagnostic.torch_models.llms.get_tiny_llm(). Then let’s use it.

# get_tiny_llm() returns a dictionary; pick the model, its dummy inputs and
# the matching dynamic shapes out of it.
experiment = get_tiny_llm()
untrained_model = experiment["model"]
inputs = experiment["inputs"]
dynamic_shapes = experiment["dynamic_shapes"]

Before running it, we make a copy of the inputs because the execution modifies the cache in place. Once modified, the cache is no longer consistent with the original input_ids and attention mask.

# Running the model grows the DynamicCache in place (the printed cache length
# goes from 30 to 33), so keep an untouched deep copy of the inputs for the
# export step below.
cloned_inputs = copy.deepcopy(inputs)

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

# Describe the output structure (logits + updated cache) with tensor shapes.
result_description = string_type(expected_output, with_shape=True)
print("result type", result_description)
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.

ExportedProgram

try:
    # All inputs are passed by keyword, so the positional tuple stays empty.
    # use_dyn_not_str converts the string placeholders in dynamic_shapes
    # before handing them to torch.export.
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as exc:
    # Export of this architecture needs at least these transformers PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", exc)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x71150fd238c0> not registered

Back to the original model

Let’s use the same dummy inputs, but this time with the downloaded model. Dummy inputs and dynamic shapes are created by function onnx_diagnostic.torch_models.llms.get_tiny_llm().

# Fresh dummy inputs and dynamic shapes, independent of the ones used above.
data = get_tiny_llm()
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

Let’s print the inputs.

print(string_type(inputs, with_shape=True))
# Also display the dynamic shapes paired with these inputs; pprint keeps the
# nested dictionary readable.
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
 'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}

And finally, let’s export.

try:
    # Export with the inputs freshly built by get_tiny_llm() above instead of
    # reusing the ones from the untrained-model section, which already went
    # through an earlier export attempt.
    ep = torch.export.export(
        model,
        (),
        kwargs=inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x71150c473920> not registered

If you get an error, look at the example Export Tiny-LLM with patches.

legend_text = "Tiny-LLM\nforward inputs\nbehind generate"
doc.plot_legend(legend_text, "torch.export.export", "tomato")
plot export tiny llm

Total running time of the script: (0 minutes 2.731 seconds)

Related examples

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Test the export on untrained models

Test the export on untrained models

Gallery generated by Sphinx-Gallery