Note
Go to the end to download the full example code.
Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)
Inputs are always dynamic with LLMs, which is why dynamic shapes
need to be specified when an LLM is exported with torch.export.export().
Most of the examples on HuggingFace use the method
transformers.GenerationMixin.generate(),
but we only want to
export the model and its method forward.
This example shows how to guess the inputs of this method even though the model
is executed through the method generate.
We focus on the model arnir0/Tiny-LLM. Later, to avoid downloading any weights, we write a function creating a random model based on the same architecture.
Steal the forward method
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
MODEL_NAME = "arnir0/Tiny-LLM"
# Load the tokenizer first, then the pretrained causal LM, both for MODEL_NAME.
tokenizer, model = (
    transformers.AutoTokenizer.from_pretrained(MODEL_NAME),
    transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME),
)
We replace the forward method with a wrapper that prints its inputs and outputs, including the cache dimensions.
def _forward_(*args, _f=None, **kwargs):
assert _f is not None
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
# torch.compiler.is_exporting requires torch>=2.7
print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
res = _f(*args, **kwargs)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
print("->", string_type(res, with_shape=True, with_min_max=True))
return res
# Keep a handle on the original bound method: the wrapper delegates to it and
# it is used later to restore the model.
keep_model_forward = model.forward


def _patched_forward(*args, _f=keep_model_forward, **kwargs):
    """Replacement for ``model.forward`` that routes every call through ``_forward_``."""
    return _forward_(*args, _f=_f, **kwargs)


model.forward = _patched_forward
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
# Sampled generation; each decoding step goes through the patched forward
# method, which prints what it receives and returns.
outputs = model.generate(
    inputs,
    do_sample=True,
    max_length=50,
    temperature=1,
    top_k=50,
    top_p=0.95,
)
first_sequence = outputs[0]
generated_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
for label, text in (("-- prompt", prompt), ("-- answer", generated_text)):
    print(label, text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[910,910:A910.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.646947860717773,6.200076103210449:A-8.397947782306],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.14168291143505485]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.0072832345568645835]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.14168291143505485]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.0072832345568645835]]),input_ids:T7s1x1[338,338:A338.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.090131759643555,7.717521667480469:A-8.146808001592756],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.556069374084473,6.226877689361572:A-0.14703989434189377]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.006896714328574186]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.556069374084473,6.226877689361572:A-0.14703989434189377]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.006896714328574186]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.002761840820312,5.229127883911133:A-6.962000927583082],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.615224838256836,6.561506271362305:A-0.1616650741086756]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.007782682185431137]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.615224838256836,6.561506271362305:A-0.1616650741086756]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.007782682185431137]]),input_ids:T7s1x1[760,760:A760.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.227304458618164,10.62122917175293:A-7.661171555686276],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.615224838256836,6.561506271362305:A-0.15541725033542914]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.004062080485697252]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.615224838256836,6.561506271362305:A-0.15541725033542914]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.004062080485697252]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.49190902709961,5.520626068115234:A-9.510941162860021],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.615224838256836,6.561506271362305:A-0.14535933936741247]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.004301025220147727]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.615224838256836,6.561506271362305:A-0.14535933936741247]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.004301025220147727]]),input_ids:T7s1x1[1749,1749:A1749.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.034587860107422,3.760286331176758:A-7.952430804436561],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.0752644538879395,6.561506271362305:A-0.1403430000343741]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.49568021297454834:A0.004099058301833446]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-6.0752644538879395,6.561506271362305:A-0.1403430000343741]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.49568021297454834:A0.004099058301833446]]),input_ids:T7s1x1[3889,3889:A3889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.837169647216797,2.7178168296813965:A-9.845356908234768],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.0752644538879395,6.561506271362305:A-0.13389094234080404]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.49568021297454834:A0.0034263806791790963]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-6.0752644538879395,6.561506271362305:A-0.13389094234080404]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.49568021297454834:A0.0034263806791790963]]),input_ids:T7s1x1[664,664:A664.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.211637496948242,7.891732692718506:A-8.29793752584234],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.0752644538879395,6.561506271362305:A-0.12716216281258616]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.49568021297454834:A0.0034651014380623715]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-6.0752644538879395,6.561506271362305:A-0.12716216281258616]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.49568021297454834:A0.0034651014380623715]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.742939949035645,7.792896270751953:A-8.512116801181342],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.0752644538879395,7.040541648864746:A-0.13139735067492367]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.5450473427772522:A0.00394179071086268]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.0752644538879395,7.040541648864746:A-0.13139735067492367]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.5450473427772522:A0.00394179071086268]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.509566307067871,9.730469703674316:A-2.8481635262952185],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.0752644538879395,7.040541648864746:A-0.13299581723866274]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.005293324246336111]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.0752644538879395,7.040541648864746:A-0.13299581723866274]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.7704185843467712:A0.005293324246336111]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.276447296142578,3.863849639892578:A-9.108061160747893],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.0752644538879395,7.040541648864746:A-0.12931169389016678]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.005208116697112449]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.0752644538879395,7.040541648864746:A-0.12931169389016678]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.7704185843467712:A0.005208116697112449]]),input_ids:T7s1x1[1334,1334:A1334.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.339250564575195,10.697063446044922:A-3.8610484419988933],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.0752644538879395,7.040541648864746:A-0.12824437141504555]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.00502735465263413]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.0752644538879395,7.040541648864746:A-0.12824437141504555]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.7704185843467712:A0.00502735465263413]]),input_ids:T7s1x1[526,526:A526.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.613191604614258,6.067843437194824:A-8.398026848909911],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.0752644538879395,7.040541648864746:A-0.12427503910322908]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.004705718233802303]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.0752644538879395,7.040541648864746:A-0.12427503910322908]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.7704185843467712:A0.004705718233802303]]),input_ids:T7s1x1[1407,1407:A1407.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.604442596435547,9.673364639282227:A-6.13026818135893],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.0752644538879395,7.040541648864746:A-0.12393082087923672]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.004316238149723369]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.0752644538879395,7.040541648864746:A-0.12393082087923672]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.7704185843467712:A0.004316238149723369]]),input_ids:T7s1x1[1781,1781:A1781.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.008285522460938,5.844202995300293:A-10.460380604167934],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.0752644538879395,7.040541648864746:A-0.12341701357761031]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.005046067914977554]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.0752644538879395,7.040541648864746:A-0.12341701357761031]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.7704185843467712:A0.005046067914977554]]),input_ids:T7s1x1[472,472:A472.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.88067626953125,4.26597785949707:A-8.971358215967658],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.0752644538879395,7.040541648864746:A-0.11997644997422362]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.004233901202540993]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.0752644538879395,7.040541648864746:A-0.11997644997422362]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.7704185843467712:A0.004233901202540993]]),input_ids:T7s1x1[664,664:A664.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.084211349487305,8.734560012817383:A-7.706683804829372],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.0752644538879395,7.040541648864746:A-0.11596168465694064]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.004226381667291813]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.0752644538879395,7.040541648864746:A-0.11596168465694064]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.7704185843467712:A0.004226381667291813]]),input_ids:T7s1x1[322,322:A322.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.806533813476562,10.124608993530273:A-5.58848192453105],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.0752644538879395,7.040541648864746:A-0.11482213054370778]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.003692447664148765]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.0752644538879395,7.040541648864746:A-0.11482213054370778]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.7704185843467712:A0.003692447664148765]]),input_ids:T7s1x1[664,664:A664.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.121829986572266,9.74388599395752:A-6.678059865375748],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.0752644538879395,7.040541648864746:A-0.10892830754968555]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.003705538966265903]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.0752644538879395,7.040541648864746:A-0.10892830754968555]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.7704185843467712:A0.003705538966265903]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-20.720945358276367,4.704853057861328:A-10.75672876830306],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.0752644538879395,7.040541648864746:A-0.10388165370676732]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.0028761462741071922]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.0752644538879395,7.040541648864746:A-0.10388165370676732]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.7704185843467712:A0.0028761462741071922]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.579059600830078,3.1370785236358643:A-8.706778356491588],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.0752644538879395,7.040541648864746:A-0.09978157794450097]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.0032480926819119503]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.0752644538879395,7.040541648864746:A-0.09978157794450097]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.7704185843467712:A0.0032480926819119503]]),input_ids:T7s1x1[4700,4700:A4700.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.605297088623047,9.580116271972656:A-8.856629673856544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.0752644538879395,7.040541648864746:A-0.09491640068150041]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0031573915695187475]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.0752644538879395,7.040541648864746:A-0.09491640068150041]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.7704185843467712:A0.0031573915695187475]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.236988067626953,8.914348602294922:A-7.639233135399874],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.0752644538879395,7.066565036773682:A-0.0882099473115319]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.0034287279406848403]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.0752644538879395,7.066565036773682:A-0.0882099473115319]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.7704185843467712:A0.0034287279406848403]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-7.898897171020508,12.787775039672852:A-2.7137405212041923],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.0752644538879395,7.066565036773682:A-0.08407474773847905]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004204998765956702]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.0752644538879395,7.066565036773682:A-0.08407474773847905]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.7704185843467712:A0.004204998765956702]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.565919876098633,4.15845251083374:A-9.402753400890157],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.0752644538879395,7.066565036773682:A-0.08286440813856763]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.00418891943399094]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.0752644538879395,7.066565036773682:A-0.08286440813856763]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.7704185843467712:A0.00418891943399094]]),input_ids:T7s1x1[910,910:A910.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.804019927978516,8.278521537780762:A-6.805950582148973],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.0752644538879395,7.066565036773682:A-0.08245940126578395]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.004292984004834187]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.0752644538879395,7.066565036773682:A-0.08245940126578395]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.7704185843467712:A0.004292984004834187]]),input_ids:T7s1x1[2794,2794:A2794.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.040019989013672,6.763426780700684:A-8.730520411090925],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.0752644538879395,7.066565036773682:A-0.0794041071187946]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0044168401827889845]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.0752644538879395,7.066565036773682:A-0.0794041071187946]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.7704185843467712:A0.0044168401827889845]]),input_ids:T7s1x1[591,591:A591.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.592232704162598,12.337525367736816:A-4.852083723993972],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.0752644538879395,7.066565036773682:A-0.07899568123425356]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.004365648639068072]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.0752644538879395,7.066565036773682:A-0.07899568123425356]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.7704185843467712:A0.004365648639068072]]),input_ids:T7s1x1[508,508:A508.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.29562759399414,9.427640914916992:A-6.5143384498339145],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.0752644538879395,7.066565036773682:A-0.07725866622855097]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0042003935156965265]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.0752644538879395,7.066565036773682:A-0.07725866622855097]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.7704185843467712:A0.0042003935156965265]]),input_ids:T7s1x1[871,871:A871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.615734100341797,10.038257598876953:A-6.6629056220497],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.0752644538879395,7.066565036773682:A-0.07804625507136127]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.004511671301846339]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.0752644538879395,7.066565036773682:A-0.07804625507136127]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.7704185843467712:A0.004511671301846339]]),input_ids:T7s1x1[664,664:A664.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.111228942871094,9.62582015991211:A-7.555076593886129],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.0752644538879395,7.066565036773682:A-0.074877238089116]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00449972877670442]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.0752644538879395,7.066565036773682:A-0.074877238089116]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.7704185843467712:A0.00449972877670442]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.5626220703125,4.899341106414795:A-10.349103797799907],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.0752644538879395,7.066565036773682:A-0.06973190112133427]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0038992991469323592]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.0752644538879395,7.066565036773682:A-0.06973190112133427]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.7704185843467712:A0.0038992991469323592]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.737903594970703,5.488068580627441:A-7.779609435026534],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.0752644538879395,7.066565036773682:A-0.06732287836481501]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.004137428243359501]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.0752644538879395,7.066565036773682:A-0.06732287836481501]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.7704185843467712:A0.004137428243359501]]),input_ids:T7s1x1[3234,3234:A3234.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.344017028808594,7.672732353210449:A-8.02427680779621],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.0752644538879395,7.066565036773682:A-0.06437035002798236]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0038962554182100056]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.0752644538879395,7.066565036773682:A-0.06437035002798236]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.7704185843467712:A0.0038962554182100056]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.541034698486328,9.532099723815918:A-7.512001441150438],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.0752644538879395,7.081362724304199:A-0.06416990056490347]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.004074687131174136]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.0752644538879395,7.081362724304199:A-0.06416990056490347]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.7704185843467712:A0.004074687131174136]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-8.409038543701172,13.753633499145508:A-2.9449270719778724],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.136015892028809,7.081362724304199:A-0.06187563871514034]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.0046245668406789155]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.136015892028809,7.081362724304199:A-0.06187563871514034]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.7704185843467712:A0.0046245668406789155]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.204063415527344,4.511193752288818:A-9.766363083241973],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.136015892028809,7.081362724304199:A-0.05911586385937645]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.004603451595576863]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.136015892028809,7.081362724304199:A-0.05911586385937645]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.7704185843467712:A0.004603451595576863]]),input_ids:T7s1x1[1334,1334:A1334.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.411291122436523,11.425795555114746:A-3.5492393767207395],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.136015892028809,7.081362724304199:A-0.06036636752957049]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.004538004295837063]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.136015892028809,7.081362724304199:A-0.06036636752957049]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.7704185843467712:A0.004538004295837063]]),input_ids:T7s1x1[437,437:A437.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.047588348388672,8.31674575805664:A-8.738448999475688],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.136015892028809,7.081362724304199:A-0.05944317272933533]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.004100194694164595]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.136015892028809,7.081362724304199:A-0.05944317272933533]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.7704185843467712:A0.004100194694164595]]),input_ids:T7s1x1[13,13:A13.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-6.876320838928223,14.70174789428711:A-1.077073920846451],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.136015892028809,7.081362724304199:A-0.0581472310956441]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.004603719686981676]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.136015892028809,7.081362724304199:A-0.0581472310956441]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.7704185843467712:A0.004603719686981676]]),input_ids:T7s1x1[29899,29899:A29899.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.975006103515625,4.7387590408325195:A-9.138142121592072],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.136015892028809,7.081362724304199:A-0.05303632679391319]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.7704185843467712:A0.004584753587473611]]))
-- prompt Continue: it rains...
-- answer Continue: it rains... This is a part of our free work.
- We are very good at work and work on the website.
- This means we can only work on the product.
- We do
- Our
Let’s restore the forward as it was.
# Undo the monkey-patch: put the saved original forward method back on the model.
restored_forward = keep_model_forward
model.forward = restored_forward
The same can be done with onnx_diagnostic.helpers.torch_helper.steal_forward().
# While this block is active, steal_forward intercepts model.forward and
# prints the inputs and outputs of every call made by generate.
with steal_forward(model):
    model.generate(
        inputs,
        max_length=50,
        temperature=1,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring that a specific architecture can still be exported despite the many changes brought to torch or transformers.
Let’s create an untrained model from the provided configuration file config.json, using onnx_diagnostic.torch_models.llms.get_tiny_llm().
Then let’s use it.
# get_tiny_llm() returns a dictionary bundling an untrained model with dummy
# inputs and the matching dynamic shapes; pull each entry out by key.
experiment = get_tiny_llm()
untrained_model = experiment["model"]
inputs = experiment["inputs"]
dynamic_shapes = experiment["dynamic_shapes"]
Before running it, we make a copy of the inputs because the execution modifies the cache: afterwards, the cache is no longer consistent with the previous input_ids and attention mask.
# Deep-copy the inputs before the first call: running the model grows the
# DynamicCache in place (30 -> 33 positions in the log below), after which the
# cache no longer matches input_ids / attention_mask. The untouched copy is the
# one given to torch.export.export afterwards.
cloned_inputs = copy.deepcopy(inputs)
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
Let’s look at the outputs.
# Summarize the output structure (logits and updated cache) with tensor shapes.
result_description = string_type(expected_output, with_shape=True)
print("result type", result_description)
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
    # cloned_inputs is meant to be a copy of the inputs taken before the eager
    # run above, since that run mutated the cache held by `inputs`.
    # use_dyn_not_str presumably maps the string dimension names of
    # dynamic_shapes (e.g. 'cache_length') to dynamic dims — TODO confirm.
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # Deliberately best-effort: the example only reports the failure.
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x79b44f83b440> not registered
Back to the original model¶
Let’s use the same dummy inputs, but this time with the downloaded model.
Dummy inputs and dynamic shapes are created by function
onnx_diagnostic.torch_models.llms.get_tiny_llm()
.
# Fetch a fresh set of dummy inputs and their dynamic shapes; only these two
# entries are needed since the downloaded model is used below.
data = get_tiny_llm()
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]
Let’s print the inputs and the dynamic shapes.
# Show the structure and shapes of the dummy inputs, then the dynamic shapes
# associated with them (the output below displays both).
print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
And let’s finally export.
try:
    # Export the downloaded model with a deep copy of the dummy inputs fetched
    # just above (not the untrained model's earlier copy), so that `inputs`
    # and its cache stay untouched by the export.
    ep = torch.export.export(
        model,
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
    print("It worked:")
    print(ep)
except Exception as e:
    # Deliberately best-effort: the example only reports the failure.
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x79b443405340> not registered
If you get an error, look at the example Export Tiny-LLM with patches.
# Draw the legend figure for this example.
# NOTE(review): arguments look like (title, API used, color) judging from the
# call site — confirm against onnx_diagnostic.doc.plot_legend.
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 3.097 seconds)
Related examples