Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace call transformers.GenerationMixin.generate(), but we only want to export the model and its forward method.
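
As a quick reminder of what that argument looks like, here is a minimal sketch on a toy module (unrelated to Tiny-LLM and not part of the original example): every dynamic dimension is declared with torch.export.Dim and mapped to the input name and the axis it belongs to.

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=-1)


# only the first axis of ``x`` is declared dynamic, the second one stays static
batch = torch.export.Dim("batch")
ep = torch.export.export(Toy(), (torch.randn(3, 8),), dynamic_shapes={"x": {0: batch}})
print(ep)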

This example shows how to guess the inputs of this method even though the model is executed through generate().

We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function that creates a random model based on the same architecture.
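
The function onnx_diagnostic.torch_models.llms.get_tiny_llm() used below takes care of that (and also builds dummy inputs and dynamic shapes). The core idea is roughly the following sketch: only the configuration is fetched, the weights stay randomly initialized.

import transformers

# only the small configuration file is downloaded, not the weights
config = transformers.AutoConfig.from_pretrained("arnir0/Tiny-LLM")
# random weights, same architecture
untrained = transformers.AutoModelForCausalLM.from_config(config)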

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)

We wrap the forward method to print its inputs and outputs, including the cache dimensions.

def _forward_(*args, _f=None, **kwargs):
    assert _f is not None
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("->", string_type(res, with_shape=True, with_min_max=True))
    return res


keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
    *args, _f=_f, **kwargs
)

Let’s run the model. In the traces below, string_type prints each tensor as T<type>s<shape> (T7 stands for int64, T1 for float32), followed by its minimum, maximum and average values in brackets.

prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")

outputs = model.generate(
    inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.432564735412598,8.368535995483398:A-4.234468644971028],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.509540557861328,6.348220348358154:A-0.12195695057461206]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.7704185843467712:A0.009565710057611594]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.135845184326172,4.338468074798584:A-9.799261990881526],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.811427116394043,6.348220348358154:A-0.1250392806283344]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.008353379246409531]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.811427116394043,6.348220348358154:A-0.1250392806283344]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.7704185843467712:A0.008353379246409531]]),input_ids:T7s1x1[2610,2610:A2610.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.630133628845215,11.018857955932617:A-7.585419899705564],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.811427116394043,6.348220348358154:A-0.12102800812234901]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.009305229967787565]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.811427116394043,6.348220348358154:A-0.12102800812234901]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.7704185843467712:A0.009305229967787565]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.729768753051758,15.795294761657715:A-6.686570847850293],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.811427116394043,6.348220348358154:A-0.11321421906611956]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.0059324780764124325]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.811427116394043,6.348220348358154:A-0.11321421906611956]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.7704185843467712:A0.0059324780764124325]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.151559829711914,12.831623077392578:A-8.03768296736572],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.811427116394043,6.348220348358154:A-0.09604190427237952]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.005043171878160371]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.811427116394043,6.348220348358154:A-0.09604190427237952]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.7704185843467712:A0.005043171878160371]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.921062469482422,10.401754379272461:A-10.987481685367879],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.811427116394043,6.6982550621032715:A-0.08939146100766222]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.0033156730564877805]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.811427116394043,6.6982550621032715:A-0.08939146100766222]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.7704185843467712:A0.0033156730564877805]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.109092712402344,11.451603889465332:A-8.615241368453717],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.811427116394043,6.6982550621032715:A-0.08828816101069809]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.004018123734315143]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.811427116394043,6.6982550621032715:A-0.08828816101069809]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.7704185843467712:A0.004018123734315143]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.535235404968262,18.55143928527832:A-4.359501328858081],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.811427116394043,6.6982550621032715:A-0.10365996128249814]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0018190039553758197]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.811427116394043,6.6982550621032715:A-0.10365996128249814]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.7704185843467712:A0.0018190039553758197]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.208690643310547,12.71769905090332:A-10.155632720829919],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.856588363647461,6.6982550621032715:A-0.10853694268128795]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0013809153403028676]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.856588363647461,6.6982550621032715:A-0.10853694268128795]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.7704185843467712:A0.0013809153403028676]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.913084030151367,11.392576217651367:A-11.69873876139149],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.856588363647461,6.6982550621032715:A-0.1023646623671084]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.00039179591874893014]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.856588363647461,6.6982550621032715:A-0.1023646623671084]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A-0.00039179591874893014]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.21382713317871,13.484505653381348:A-7.445791564540472],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.856588363647461,6.6982550621032715:A-0.08908409716247677]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0019779059926373806]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.856588363647461,6.6982550621032715:A-0.08908409716247677]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A-0.0019779059926373806]]),input_ids:T7s1x1[29955,29955:A29955.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.9020938873291,5.044755458831787:A-13.776848735461943],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.856588363647461,6.840986728668213:A-0.07805764829069327]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.00273918923637666]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.856588363647461,6.840986728668213:A-0.07805764829069327]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A-0.00273918923637666]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.70694351196289,7.221324920654297:A-8.892709849011153],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.856588363647461,6.840986728668213:A-0.07859572630424361]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.001949111976363571]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.856588363647461,6.840986728668213:A-0.07859572630424361]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A-0.001949111976363571]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.547388076782227,15.494963645935059:A-5.068759219610598],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.856588363647461,6.840986728668213:A-0.07658115250231992]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.003277233828743137]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.856588363647461,6.840986728668213:A-0.07658115250231992]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A-0.003277233828743137]]),input_ids:T7s1x1[29946,29946:A29946.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.133716583251953,6.478907585144043:A-11.908872234937736],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.882454872131348,6.840986728668213:A-0.07968059192137224]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.003521361922739216]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.882454872131348,6.840986728668213:A-0.07968059192137224]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A-0.003521361922739216]]),input_ids:T7s1x1[29901,29901:A29901.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.94466209411621,7.7586669921875:A-12.086376501435414],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.882454872131348,6.840986728668213:A-0.07993684639399766]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0027230712880800135]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.882454872131348,6.840986728668213:A-0.07993684639399766]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A-0.0027230712880800135]]),input_ids:T7s1x1[29945,29945:A29945.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.87329864501953,9.426261901855469:A-10.317357872662134],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.882454872131348,6.840986728668213:A-0.07471688298367857]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003106062676579313]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.882454872131348,6.840986728668213:A-0.07471688298367857]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A-0.003106062676579313]]),input_ids:T7s1x1[29953,29953:A29953.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.021434783935547,5.666491508483887:A-11.43309489883529],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.882454872131348,6.840986728668213:A-0.06830911466549412]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003428964488013009]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.882454872131348,6.840986728668213:A-0.06830911466549412]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A-0.003428964488013009]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.502959251403809,8.647699356079102:A-3.666581474106759],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.882454872131348,6.840986728668213:A-0.06822447889040902]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.00225495119773903]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.882454872131348,6.840986728668213:A-0.06822447889040902]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A-0.00225495119773903]]),input_ids:T7s1x1[29928,29928:A29928.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.725600242614746,10.529221534729004:A-1.5172752494900488],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.882454872131348,6.840986728668213:A-0.06527995995752246]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.00138839243755315]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.882454872131348,6.840986728668213:A-0.06527995995752246]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A-0.00138839243755315]]),input_ids:T7s1x1[485,485:A485.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.5223307609558105,12.079378128051758:A-2.200265917599667],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.882454872131348,6.840986728668213:A-0.06148312655263333]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.001493190796611151]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.882454872131348,6.840986728668213:A-0.06148312655263333]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A-0.001493190796611151]]),input_ids:T7s1x1[262,262:A262.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.054697036743164,3.195920944213867:A-9.912151685696095],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.882454872131348,6.840986728668213:A-0.0526473199523025]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0016696463759874151]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.882454872131348,6.840986728668213:A-0.0526473199523025]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A-0.0016696463759874151]]),input_ids:T7s1x1[435,435:A435.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.00013542175293,7.802364826202393:A-6.514995807819068],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.882454872131348,6.840986728668213:A-0.04733127825504628]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.0008047377704882507]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.882454872131348,6.840986728668213:A-0.04733127825504628]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A-0.0008047377704882507]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.103503227233887,2.7568225860595703:A-9.391665586687624],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.882454872131348,6.840986728668213:A-0.04055644638731337]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.0003467011769136737]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.882454872131348,6.840986728668213:A-0.04055644638731337]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A-0.0003467011769136737]]),input_ids:T7s1x1[322,322:A322.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.71291732788086,2.768527030944824:A-9.84994037065329],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.882454872131348,6.840986728668213:A-0.040666309959186665]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.0006287981841110609]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.882454872131348,6.840986728668213:A-0.040666309959186665]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A-0.0006287981841110609]]),input_ids:T7s1x1[390,390:A390.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.72878646850586,8.927671432495117:A-5.238189476897475],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.882454872131348,6.840986728668213:A-0.04073903500626758]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.00030211803927355217]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.882454872131348,6.840986728668213:A-0.04073903500626758]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A-0.00030211803927355217]]),input_ids:T7s1x1[29968,29968:A29968.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.1890926361084,4.8266496658325195:A-9.480396195146255],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.882454872131348,6.840986728668213:A-0.036860077806375645]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.00034705521119186806]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.882454872131348,6.840986728668213:A-0.036860077806375645]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.00034705521119186806]]),input_ids:T7s1x1[29928,29928:A29928.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.828540802001953,5.075605869293213:A-10.123816122239456],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.882454872131348,6.840986728668213:A-0.03573528471056138]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0009487674021994719]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.882454872131348,6.840986728668213:A-0.03573528471056138]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.0009487674021994719]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.334418296813965,3.0961825847625732:A-9.918788560789078],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.882454872131348,6.840986728668213:A-0.03620726755868852]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0012975151271102485]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.882454872131348,6.840986728668213:A-0.03620726755868852]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0012975151271102485]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.548851013183594,10.435674667358398:A-3.7397975996453314],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.882454872131348,6.840986728668213:A-0.036007717760679826]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0013311053766293505]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.882454872131348,6.840986728668213:A-0.036007717760679826]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.0013311053766293505]]),input_ids:T7s1x1[1073,1073:A1073.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.840539932250977,6.937159061431885:A-10.19300847663451],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.882454872131348,6.840986728668213:A-0.03415937627835222]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00108061570339347]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.882454872131348,6.840986728668213:A-0.03415937627835222]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00108061570339347]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.949212074279785,6.899430274963379:A-8.692059776168783],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.882454872131348,7.078313827514648:A-0.03448053422941181]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0013428207877003236]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.882454872131348,7.078313827514648:A-0.03448053422941181]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0013428207877003236]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.425172805786133,9.370534896850586:A-6.1916607732344415],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.882454872131348,7.078313827514648:A-0.03414592427012631]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0014931152163145766]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.882454872131348,7.078313827514648:A-0.03414592427012631]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.0014931152163145766]]),input_ids:T7s1x1[29915,29915:A29915.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.616035461425781,15.954569816589355:A-4.68313079584483],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.882454872131348,7.078313827514648:A-0.03241555252936941]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0008153728593004503]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.882454872131348,7.078313827514648:A-0.03241555252936941]], value_cache=#1[T1s1x1x42x96[-1.1154754161834717,0.7704185843467712:A0.0008153728593004503]]),input_ids:T7s1x1[29885,29885:A29885.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.34793472290039,5.864483833312988:A-8.361768380252295],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.882454872131348,7.078313827514648:A-0.03249438322728925]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0002990090804553964]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.882454872131348,7.078313827514648:A-0.03249438322728925]], value_cache=#1[T1s1x1x43x96[-1.1154754161834717,0.7704185843467712:A0.0002990090804553964]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.470327377319336,15.861299514770508:A-2.115428224620875],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.882454872131348,7.078313827514648:A-0.028590843711478527]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A-0.0004161455061166359]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.882454872131348,7.078313827514648:A-0.028590843711478527]], value_cache=#1[T1s1x1x44x96[-1.1154754161834717,0.7704185843467712:A-0.0004161455061166359]]),input_ids:T7s1x1[29947,29947:A29947.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.115264892578125,7.146152496337891:A-8.258798646918498],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.882454872131348,7.078313827514648:A-0.028340438807127714]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A-0.000832271419875286]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.882454872131348,7.078313827514648:A-0.028340438807127714]], value_cache=#1[T1s1x1x45x96[-1.1154754161834717,0.7704185843467712:A-0.000832271419875286]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.093294143676758,6.9178547859191895:A-8.125556804327294],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.882454872131348,7.078313827514648:A-0.028565603406525992]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A-0.0014778282873264645]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.882454872131348,7.078313827514648:A-0.028565603406525992]], value_cache=#1[T1s1x1x46x96[-1.1154754161834717,0.7704185843467712:A-0.0014778282873264645]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.948810577392578,3.4617862701416016:A-10.526158623103983],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.198397636413574,7.078313827514648:A-0.028059341673402186]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A-0.0020959146497797204]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-7.198397636413574,7.078313827514648:A-0.028059341673402186]], value_cache=#1[T1s1x1x47x96[-1.1154754161834717,0.7704185843467712:A-0.0020959146497797204]]),input_ids:T7s1x1[29900,29900:A29900.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-21.243667602539062,1.9742556810379028:A-12.392495226606727],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.198397636413574,7.078313827514648:A-0.02473806524628546]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A-0.002688247413797424]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-7.198397636413574,7.078313827514648:A-0.02473806524628546]], value_cache=#1[T1s1x1x48x96[-1.1154754161834717,0.7704185843467712:A-0.002688247413797424]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.82381820678711,3.515730857849121:A-12.101337580995635],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-7.198397636413574,7.078313827514648:A-0.01913464772082896]], value_cache=#1[T1s1x1x49x96[-1.1154754161834717,0.7704185843467712:A-0.002685581100340518]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...
3 May 28, 2007, 4:56
Davin J, and RKD, we know. I'm 800036

Let’s restore the forward as it was.

model.forward = keep_model_forward

The same can be done with a simpler syntax, using onnx_diagnostic.helpers.torch_helper.steal_forward().

with steal_forward(model):
    model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
  <- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
  <- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful for writing a unit test that ensures a specific architecture can still be exported despite the many changes brought to torch or transformers.
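
Such a unit test could look like the following sketch (the test names are made up for illustration, and, as shown later in this example, the export itself may still require patches):

import unittest
import torch
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_tiny_llm


class TestTinyLlmExport(unittest.TestCase):
    def test_export_forward(self):
        data = get_tiny_llm()
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        # raises if the architecture cannot be exported anymore
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
        )
        self.assertIsNotNone(ep)


if __name__ == "__main__":
    unittest.main()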

Let’s create an untrained model from the provided configuration file config.json with onnx_diagnostic.torch_models.llms.get_tiny_llm(), then use it.

experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
    experiment["model"],
    experiment["inputs"],
    experiment["dynamic_shapes"],
)

Before we run it, we make a copy of the inputs because the cache gets modified by the execution; once modified, it is no longer consistent with the previous input_ids and mask. Running the model below makes the cache grow from 30 to 33 entries.

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.

ExportedProgram

try:
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It failed: 8*s72 (127382405373088)is not tracked with proxy for <torch.fx.experimental.proxy_tensor._ModuleStackTracer object at 0x73da86e0b650>

Back to the original model

Let’s use the same dummy inputs, but this time with the downloaded model. Dummy inputs and dynamic shapes are created by the function onnx_diagnostic.torch_models.llms.get_tiny_llm().

data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]

Let’s print the inputs and the dynamic shapes.

print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: 'batch', 1: 'cache+seq'},
 'input_ids': {0: 'batch', 1: 'seq_length'},
 'past_key_values': [[{0: 'batch', 2: 'cache_length'}],
                     [{0: 'batch', 2: 'cache_length'}]],
 'position_ids': {0: 'batch', 1: 'cache+seq'}}
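
The dimensions above are given human-readable string names. The helper use_dyn_not_str() used below turns those strings into dimension markers that torch.export.export() accepts. A hand-written equivalent could look like this sketch, assuming a torch version recent enough to provide torch.export.Dim.DYNAMIC (the name manual_dynamic_shapes is only illustrative):

DYN = torch.export.Dim.DYNAMIC  # assumption: available in recent torch versions

manual_dynamic_shapes = {
    "input_ids": {0: DYN, 1: DYN},
    "attention_mask": {0: DYN, 1: DYN},
    "position_ids": {0: DYN, 1: DYN},
    "past_key_values": [[{0: DYN, 2: DYN}], [{0: DYN, 2: DYN}]],
}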

And let’s finally export.

try:
    ep = torch.export.export(
        model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It failed: 8*s72 (127382406994672)is not tracked with proxy for <torch.fx.experimental.proxy_tensor._ModuleStackTracer object at 0x73da87081fa0>

If you get an error, have a look at the example Export Tiny-LLM with patches.
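
That example wraps the export in a patching context manager from onnx_diagnostic. As a rough sketch of the idea (the helper is assumed here to be called torch_export_patches with a patch_transformers flag; check that example for the exact API of your version):

from onnx_diagnostic.torch_export_patches import torch_export_patches

with torch_export_patches(patch_transformers=True):  # assumed name and flag
    ep = torch.export.export(
        model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )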

doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")
(figure: plot export tiny llm)

Total running time of the script: (0 minutes 2.731 seconds)

Related examples

Export microsoft/phi-2

Export Tiny-LLM with patches

Export with DynamicCache and guessed dynamic shapes

Gallery generated by Sphinx-Gallery