Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export(). Most of the examples on HuggingFace use method transformers.GenerationMixin.generate(), but we only want to export the model and its method forward.

This example shows how to guess the inputs of this method even though the model is executed through method generate.

We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights in the second part of this example, we use a function creating a random model based on the same architecture.

Steal the forward method

The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.

import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


# Checkpoint on the HuggingFace Hub used by this example.
MODEL_NAME = "arnir0/Tiny-LLM"
# The trained model and its tokenizer, both loaded from MODEL_NAME.
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

We wrap the forward method to print the shapes of its inputs and outputs, including the cache dimensions.

def _forward_(*args, _f=None, **kwargs):
    """Call ``_f(*args, **kwargs)`` and print what goes in and what comes out.

    Inputs and outputs are printed with ``string_type`` (shapes and min/max
    values) unless ``torch.compiler.is_exporting()`` says an export is running.

    :param args: positional arguments forwarded to ``_f``
    :param _f: the wrapped callable (the original ``forward``), mandatory
    :param kwargs: keyword arguments forwarded to ``_f``
    :return: whatever ``_f`` returns
    """
    assert _f is not None

    def _can_print():
        # torch.compiler.is_exporting requires torch>=2.7; without it,
        # printing is always enabled.
        return not (
            hasattr(torch.compiler, "is_exporting") and torch.compiler.is_exporting()
        )

    if _can_print():
        print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
    result = _f(*args, **kwargs)
    if _can_print():
        print("->", string_type(result, with_shape=True, with_min_max=True))
    return result


# Keep the original bound method: the wrapper calls it and it is restored later.
keep_model_forward = model.forward


def _printing_forward(*args, _f=keep_model_forward, **kwargs):
    # Delegate to _forward_, which prints inputs and outputs around the real call.
    return _forward_(*args, _f=_f, **kwargs)


model.forward = _printing_forward

Let’s run the model.

prompt = "Continue: it rains..."
# Token ids of the prompt, returned as a torch tensor.
inputs = tokenizer.encode(prompt, return_tensors="pt")

# Sampled generation; each step goes through model.forward, which is
# currently wrapped to print its inputs and outputs.
outputs = model.generate(
    inputs,
    max_length=50,
    temperature=1,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)

# Only one prompt was given, so the answer is the first generated sequence.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[2803,2803:A2803.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.139326095581055,7.989854335784912:A-7.532701101825572],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.12309071480949232]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.009172345952332829]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.12309071480949232]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.009172345952332829]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.9604549407959,2.1183395385742188:A-9.114555937874364],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1437112894514333]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.009621370567144065]]))
<- ((),dict(cache_position:T7s1[10,10:A10.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1437112894514333]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.009621370567144065]]),input_ids:T7s1x1[282,282:A282.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.790553092956543,11.349196434020996:A-2.3142352080869024],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.490959167480469,6.226877689361572:A-0.131982613208851]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.00653785850688633]]))
<- ((),dict(cache_position:T7s1[11,11:A11.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96[-5.490959167480469,6.226877689361572:A-0.131982613208851]], value_cache=#1[T1s1x1x11x96[-0.6787744760513306,0.49568021297454834:A0.00653785850688633]]),input_ids:T7s1x1[457,457:A457.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.457244873046875,8.53797721862793:A-5.519397660128306],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.490959167480469,6.226877689361572:A-0.11274745565232377]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.00612946672389272]]))
<- ((),dict(cache_position:T7s1[12,12:A12.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96[-5.490959167480469,6.226877689361572:A-0.11274745565232377]], value_cache=#1[T1s1x1x12x96[-0.6787744760513306,0.49568021297454834:A0.00612946672389272]]),input_ids:T7s1x1[373,373:A373.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.06304931640625,4.107252597808838:A-10.73238494787598],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.490959167480469,6.226877689361572:A-0.10854281926218298]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.004156626482502664]]))
<- ((),dict(cache_position:T7s1[13,13:A13.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96[-5.490959167480469,6.226877689361572:A-0.10854281926218298]], value_cache=#1[T1s1x1x13x96[-0.6787744760513306,0.49568021297454834:A0.004156626482502664]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-19.92361068725586,0.7854024171829224:A-10.78099933495],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.490959167480469,6.226877689361572:A-0.10490619811640672]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.49568021297454834:A0.004835624026641415]]))
<- ((),dict(cache_position:T7s1[14,14:A14.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96[-5.490959167480469,6.226877689361572:A-0.10490619811640672]], value_cache=#1[T1s1x1x14x96[-0.6787744760513306,0.49568021297454834:A0.004835624026641415]]),input_ids:T7s1x1[29871,29871:A29871.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-10.574679374694824,14.82539176940918:A-2.726205240119714],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.490959167480469,6.226877689361572:A-0.11733232459996037]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.49568021297454834:A0.0024353962429510528]]))
<- ((),dict(cache_position:T7s1[15,15:A15.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96[-5.490959167480469,6.226877689361572:A-0.11733232459996037]], value_cache=#1[T1s1x1x15x96[-0.6787744760513306,0.49568021297454834:A0.0024353962429510528]]),input_ids:T7s1x1[29906,29906:A29906.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-23.10850715637207,6.435230255126953:A-9.63972538440395],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.490959167480469,6.226877689361572:A-0.12745934946929083]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.49568021297454834:A0.0019314025714625889]]))
<- ((),dict(cache_position:T7s1[16,16:A16.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96[-5.490959167480469,6.226877689361572:A-0.12745934946929083]], value_cache=#1[T1s1x1x16x96[-0.6787744760513306,0.49568021297454834:A0.0019314025714625889]]),input_ids:T7s1x1[29941,29941:A29941.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.41960906982422,4.423344135284424:A-10.625910959165543],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.851004600524902,6.226877689361572:A-0.1314545562350295]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.49568021297454834:A0.0016673437111171998]]))
<- ((),dict(cache_position:T7s1[17,17:A17.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96[-6.851004600524902,6.226877689361572:A-0.1314545562350295]], value_cache=#1[T1s1x1x17x96[-0.6787744760513306,0.49568021297454834:A0.0016673437111171998]]),input_ids:T7s1x1[5499,5499:A5499.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-22.13907814025879,3.339029312133789:A-11.397767138535157],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.851004600524902,6.226877689361572:A-0.1203454865454087]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.5491868257522583:A0.0028050357574287504]]))
<- ((),dict(cache_position:T7s1[18,18:A18.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96[-6.851004600524902,6.226877689361572:A-0.1203454865454087]], value_cache=#1[T1s1x1x18x96[-0.6787744760513306,0.5491868257522583:A0.0028050357574287504]]),input_ids:T7s1x1[11904,11904:A11904.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.545089721679688,5.198848247528076:A-10.147783757550641],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.851004600524902,6.226877689361572:A-0.11333097832188319]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.5491868257522583:A0.0017503377012680606]]))
<- ((),dict(cache_position:T7s1[19,19:A19.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96[-6.851004600524902,6.226877689361572:A-0.11333097832188319]], value_cache=#1[T1s1x1x19x96[-0.6787744760513306,0.5491868257522583:A0.0017503377012680606]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.449485778808594,7.352659702301025:A-10.236630266627296],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.851004600524902,6.226877689361572:A-0.10836411154759844]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.5491868257522583:A0.0020212389178823286]]))
<- ((),dict(cache_position:T7s1[20,20:A20.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96[-6.851004600524902,6.226877689361572:A-0.10836411154759844]], value_cache=#1[T1s1x1x20x96[-0.6787744760513306,0.5491868257522583:A0.0020212389178823286]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.722732543945312,2.577401876449585:A-9.322589841015636],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.851004600524902,6.226877689361572:A-0.10628366716997652]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.5491868257522583:A0.002575589069432941]]))
<- ((),dict(cache_position:T7s1[21,21:A21.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96[-6.851004600524902,6.226877689361572:A-0.10628366716997652]], value_cache=#1[T1s1x1x21x96[-0.6787744760513306,0.5491868257522583:A0.002575589069432941]]),input_ids:T7s1x1[282,282:A282.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.723098754882812,13.465116500854492:A-3.0993351022689604],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.851004600524902,6.226877689361572:A-0.10145489390820192]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.5491868257522583:A0.001354095834654579]]))
<- ((),dict(cache_position:T7s1[22,22:A22.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96[-6.851004600524902,6.226877689361572:A-0.10145489390820192]], value_cache=#1[T1s1x1x22x96[-0.6787744760513306,0.5491868257522583:A0.001354095834654579]]),input_ids:T7s1x1[457,457:A457.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.291308403015137,6.83212423324585:A-6.595367173196748],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.851004600524902,6.226877689361572:A-0.09557346753796953]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.5491868257522583:A0.0013664028466679895]]))
<- ((),dict(cache_position:T7s1[23,23:A23.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96[-6.851004600524902,6.226877689361572:A-0.09557346753796953]], value_cache=#1[T1s1x1x23x96[-0.6787744760513306,0.5491868257522583:A0.0013664028466679895]]),input_ids:T7s1x1[284,284:A284.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.498652458190918,4.909337520599365:A-7.826288171648979],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.851004600524902,6.226877689361572:A-0.09022182994011826]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.5646868348121643:A0.0009847358078774152]]))
<- ((),dict(cache_position:T7s1[24,24:A24.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96[-6.851004600524902,6.226877689361572:A-0.09022182994011826]], value_cache=#1[T1s1x1x24x96[-0.6787744760513306,0.5646868348121643:A0.0009847358078774152]]),input_ids:T7s1x1[471,471:A471.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.66248893737793,7.503578186035156:A-5.471658054812812],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.851004600524902,6.226877689361572:A-0.08821402675156909]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.5646868348121643:A0.0013008844572505042]]))
<- ((),dict(cache_position:T7s1[25,25:A25.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96[-6.851004600524902,6.226877689361572:A-0.08821402675156909]], value_cache=#1[T1s1x1x25x96[-0.6787744760513306,0.5646868348121643:A0.0013008844572505042]]),input_ids:T7s1x1[2790,2790:A2790.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.74355697631836,4.874917030334473:A-9.682159276143647],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.851004600524902,6.226877689361572:A-0.08482534370010306]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.5646868348121643:A0.0018226306210344307]]))
<- ((),dict(cache_position:T7s1[26,26:A26.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96[-6.851004600524902,6.226877689361572:A-0.08482534370010306]], value_cache=#1[T1s1x1x26x96[-0.6787744760513306,0.5646868348121643:A0.0018226306210344307]]),input_ids:T7s1x1[322,322:A322.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.895126342773438,7.03867769241333:A-6.967157858532388],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.851004600524902,6.468419075012207:A-0.08363473911500748]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.5646868348121643:A0.001397499767869176]]))
<- ((),dict(cache_position:T7s1[27,27:A27.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96[-6.851004600524902,6.468419075012207:A-0.08363473911500748]], value_cache=#1[T1s1x1x27x96[-0.6787744760513306,0.5646868348121643:A0.001397499767869176]]),input_ids:T7s1x1[263,263:A263.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-12.927133560180664,6.663992881774902:A-4.903734662296949],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.851004600524902,6.586821556091309:A-0.08327619138396715]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.5646868348121643:A0.0019419590888024427]]))
<- ((),dict(cache_position:T7s1[28,28:A28.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96[-6.851004600524902,6.586821556091309:A-0.08327619138396715]], value_cache=#1[T1s1x1x28x96[-0.6787744760513306,0.5646868348121643:A0.0019419590888024427]]),input_ids:T7s1x1[2586,2586:A2586.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.66410255432129,9.202840805053711:A-6.697580154446885],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.851004600524902,6.586821556091309:A-0.08226862622838269]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.5646868348121643:A0.0011138470264901617]]))
<- ((),dict(cache_position:T7s1[29,29:A29.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96[-6.851004600524902,6.586821556091309:A-0.08226862622838269]], value_cache=#1[T1s1x1x29x96[-0.6787744760513306,0.5646868348121643:A0.0011138470264901617]]),input_ids:T7s1x1[310,310:A310.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.824026107788086,6.104130744934082:A-8.846507747335824],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.851004600524902,6.586821556091309:A-0.07743902921336282]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.5646868348121643:A0.0013156641933922705]]))
<- ((),dict(cache_position:T7s1[30,30:A30.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96[-6.851004600524902,6.586821556091309:A-0.07743902921336282]], value_cache=#1[T1s1x1x30x96[-0.6787744760513306,0.5646868348121643:A0.0013156641933922705]]),input_ids:T7s1x1[372,372:A372.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.086811065673828,9.512075424194336:A-6.638189948941116],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.851004600524902,6.586821556091309:A-0.07725348350868444]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.5646868348121643:A0.0007874589188896187]]))
<- ((),dict(cache_position:T7s1[31,31:A31.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96[-6.851004600524902,6.586821556091309:A-0.07725348350868444]], value_cache=#1[T1s1x1x31x96[-0.6787744760513306,0.5646868348121643:A0.0007874589188896187]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.679304122924805,8.540875434875488:A-5.918238628771855],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.851004600524902,6.586821556091309:A-0.06954358272677534]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.5646868348121643:A0.0011957393659211373]]))
<- ((),dict(cache_position:T7s1[32,32:A32.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96[-6.851004600524902,6.586821556091309:A-0.06954358272677534]], value_cache=#1[T1s1x1x32x96[-0.6787744760513306,0.5646868348121643:A0.0011957393659211373]]),input_ids:T7s1x1[541,541:A541.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.235345840454102,8.420061111450195:A-6.781005434720777],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.851004600524902,6.586821556091309:A-0.06719411371699552]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.5646868348121643:A0.0008798652613884108]]))
<- ((),dict(cache_position:T7s1[33,33:A33.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96[-6.851004600524902,6.586821556091309:A-0.06719411371699552]], value_cache=#1[T1s1x1x33x96[-0.6787744760513306,0.5646868348121643:A0.0008798652613884108]]),input_ids:T7s1x1[372,372:A372.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.133174896240234,12.845958709716797:A-4.596729983803118],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.851004600524902,6.586821556091309:A-0.06834739897217956]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.5646868348121643:A0.00041108395028316533]]))
<- ((),dict(cache_position:T7s1[34,34:A34.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96[-6.851004600524902,6.586821556091309:A-0.06834739897217956]], value_cache=#1[T1s1x1x34x96[-0.6787744760513306,0.5646868348121643:A0.00041108395028316533]]),input_ids:T7s1x1[3430,3430:A3430.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.09120750427246,11.081531524658203:A-6.440643578020856],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.851004600524902,6.586821556091309:A-0.06755391214513413]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.5646868348121643:A0.0005617578158740972]]))
<- ((),dict(cache_position:T7s1[35,35:A35.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96[-6.851004600524902,6.586821556091309:A-0.06755391214513413]], value_cache=#1[T1s1x1x35x96[-0.6787744760513306,0.5646868348121643:A0.0005617578158740972]]),input_ids:T7s1x1[2691,2691:A2691.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.141157150268555,10.224329948425293:A-8.149924114157912],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.851004600524902,6.586821556091309:A-0.06424710824982564]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.5646868348121643:A0.0009106011178043142]]))
<- ((),dict(cache_position:T7s1[36,36:A36.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96[-6.851004600524902,6.586821556091309:A-0.06424710824982564]], value_cache=#1[T1s1x1x36x96[-0.6787744760513306,0.5646868348121643:A0.0009106011178043142]]),input_ids:T7s1x1[29892,29892:A29892.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.046197891235352,8.979988098144531:A-6.355539239617064],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.851004600524902,6.586821556091309:A-0.06394850127256727]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.5646868348121643:A0.0012603803639149602]]))
<- ((),dict(cache_position:T7s1[37,37:A37.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96[-6.851004600524902,6.586821556091309:A-0.06394850127256727]], value_cache=#1[T1s1x1x37x96[-0.6787744760513306,0.5646868348121643:A0.0012603803639149602]]),input_ids:T7s1x1[372,372:A372.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.895962715148926,11.524106979370117:A-5.250418728778604],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.851004600524902,6.586821556091309:A-0.06414386183068768]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.5646868348121643:A0.0008309308986490418]]))
<- ((),dict(cache_position:T7s1[38,38:A38.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96[-6.851004600524902,6.586821556091309:A-0.06414386183068768]], value_cache=#1[T1s1x1x38x96[-0.6787744760513306,0.5646868348121643:A0.0008309308986490418]]),input_ids:T7s1x1[756,756:A756.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.512384414672852,9.491189956665039:A-6.993132332458161],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.851004600524902,6.586821556091309:A-0.06050902911104062]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.5646868348121643:A-0.00013865493480270976]]))
<- ((),dict(cache_position:T7s1[39,39:A39.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96[-6.851004600524902,6.586821556091309:A-0.06050902911104062]], value_cache=#1[T1s1x1x39x96[-0.6787744760513306,0.5646868348121643:A-0.00013865493480270976]]),input_ids:T7s1x1[385,385:A385.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-18.85635757446289,5.906864166259766:A-7.796880586587824],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.851004600524902,6.586821556091309:A-0.05756792460436676]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.5646868348121643:A0.00025918271563796225]]))
<- ((),dict(cache_position:T7s1[40,40:A40.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96[-6.851004600524902,6.586821556091309:A-0.05756792460436676]], value_cache=#1[T1s1x1x40x96[-0.6787744760513306,0.5646868348121643:A0.00025918271563796225]]),input_ids:T7s1x1[8031,8031:A8031.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.201635360717773,4.826348304748535:A-8.23904167703283],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.851004600524902,6.586821556091309:A-0.0549877954955907]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.5646868348121643:A0.00018446499429253566]]))
<- ((),dict(cache_position:T7s1[41,41:A41.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96[-6.851004600524902,6.586821556091309:A-0.0549877954955907]], value_cache=#1[T1s1x1x41x96[-0.6787744760513306,0.5646868348121643:A0.00018446499429253566]]),input_ids:T7s1x1[2058,2058:A2058.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-15.037452697753906,10.821243286132812:A-7.315198119387962],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.851004600524902,6.586821556091309:A-0.053828802398475956]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.5646868348121643:A0.0001914224410123256]]))
<- ((),dict(cache_position:T7s1[42,42:A42.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96[-6.851004600524902,6.586821556091309:A-0.053828802398475956]], value_cache=#1[T1s1x1x42x96[-0.6787744760513306,0.5646868348121643:A0.0001914224410123256]]),input_ids:T7s1x1[29889,29889:A29889.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.934222221374512,6.773665428161621:A-8.52244421484298],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.851004600524902,7.081362724304199:A-0.05387350520585071]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.5646868348121643:A0.0004560130604229136]]))
<- ((),dict(cache_position:T7s1[43,43:A43.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96[-6.851004600524902,7.081362724304199:A-0.05387350520585071]], value_cache=#1[T1s1x1x43x96[-0.6787744760513306,0.5646868348121643:A0.0004560130604229136]]),input_ids:T7s1x1[306,306:A306.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.794106483459473,12.626676559448242:A-3.4942997911535203],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.851004600524902,7.081362724304199:A-0.05344041203272465]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.5646868348121643:A0.0006162148627061359]]))
<- ((),dict(cache_position:T7s1[44,44:A44.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96[-6.851004600524902,7.081362724304199:A-0.05344041203272465]], value_cache=#1[T1s1x1x44x96[-0.6787744760513306,0.5646868348121643:A0.0006162148627061359]]),input_ids:T7s1x1[4997,4997:A4997.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-17.71812629699707,10.671961784362793:A-7.434223877867684],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.851004600524902,7.081362724304199:A-0.05234356322226059]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.5646868348121643:A0.0008026099187898379]]))
<- ((),dict(cache_position:T7s1[45,45:A45.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96[-6.851004600524902,7.081362724304199:A-0.05234356322226059]], value_cache=#1[T1s1x1x45x96[-0.6787744760513306,0.5646868348121643:A0.0008026099187898379]]),input_ids:T7s1x1[1048,1048:A1048.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.743223190307617,6.638679504394531:A-8.95858639000659],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.851004600524902,7.081362724304199:A-0.051109141329014666]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.5646868348121643:A0.0007646526549258166]]))
<- ((),dict(cache_position:T7s1[46,46:A46.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96[-6.851004600524902,7.081362724304199:A-0.051109141329014666]], value_cache=#1[T1s1x1x46x96[-0.6787744760513306,0.5646868348121643:A0.0007646526549258166]]),input_ids:T7s1x1[278,278:A278.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-16.181385040283203,3.565772294998169:A-7.133517884547822],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.851004600524902,7.081362724304199:A-0.049614630229699926]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.5646868348121643:A0.0010390768984475053]]))
<- ((),dict(cache_position:T7s1[47,47:A47.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96[-6.851004600524902,7.081362724304199:A-0.049614630229699926]], value_cache=#1[T1s1x1x47x96[-0.6787744760513306,0.5646868348121643:A0.0010390768984475053]]),input_ids:T7s1x1[298,298:A298.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-9.8783597946167,13.427725791931152:A-1.0146370281663257],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.851004600524902,7.081362724304199:A-0.0479525166055939]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.5646868348121643:A0.0016179650206361708]]))
<- ((),dict(cache_position:T7s1[48,48:A48.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96[-6.851004600524902,7.081362724304199:A-0.0479525166055939]], value_cache=#1[T1s1x1x48x96[-0.6787744760513306,0.5646868348121643:A0.0016179650206361708]]),input_ids:T7s1x1[5171,5171:A5171.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-13.05555534362793,6.818144798278809:A-6.438680187001359],past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96[-6.851004600524902,7.081362724304199:A-0.04418672443922197]], value_cache=#1[T1s1x1x49x96[-0.6787744760513306,0.5646868348121643:A0.0016891376174894843]]))
-- prompt Continue: it rains...
-- answer Continue: it rains... Let the pine on the 23rd floor of the pineal washing and a bit of it, but it looks fine, it has an interesting place. I wonder about the hitter-

Let’s restore the forward as it was.

# Put back the bound method saved in keep_model_forward, removing the wrapper.
model.forward = keep_model_forward

The same can be done with another syntax, using onnx_diagnostic.helpers.torch_helper.steal_forward().

# steal_forward intercepts the calls to model.forward made inside the ``with``
# block and prints the inputs and outputs it sees (see the output below).
with steal_forward(model):
    model.generate(
        inputs,
        max_length=50,
        temperature=1,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
  <- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
  -> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.

Untrained model

This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test ensuring a specific architecture can be exported despite the many changes brought to torch or transformers.

Let’s create an untrained model from the provided config.json file with onnx_diagnostic.torch_models.llms.get_tiny_llm(). Then let’s use it.

# Untrained model built from the Tiny-LLM configuration, together with
# dummy inputs and the dynamic shapes matching them.
experiment = get_tiny_llm()
untrained_model = experiment["model"]
inputs = experiment["inputs"]
dynamic_shapes = experiment["dynamic_shapes"]

Before we run it, we make a copy of the inputs because the cache gets modified by the execution. After that, the cache is no longer consistent with the previous input_ids and mask.

# Deep copy taken before the run. The execution below updates the cache held
# by ``inputs`` (compare the shapes printed before and after). Only the copy
# stays consistent with input_ids and attention_mask, so it is the one used
# for the export.
cloned_inputs = copy.deepcopy(inputs)

print("input type before", string_type(inputs, with_shape=True))

expected_output = untrained_model(**inputs)

print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

The outputs

# Shapes of the logits and of the updated cache returned by the untrained model.
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))

It works.

ExportedProgram

# ``cloned_inputs`` must be a copy of the dummy inputs taken before the model
# was run: that run updated the cache held by ``inputs``.
try:
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=cloned_inputs,
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
else:
    # Only the export is guarded, so an error raised while printing the
    # program is not misreported as an export failure.
    print("It worked:")
    print(ep)
It failed: 8*s72 (139836397529200)is not tracked with proxy for <torch.fx.experimental.proxy_tensor._ModuleStackTracer object at 0x7f2e5f314950>

Back to the original model

Let’s use the same dummy inputs, but with the downloaded model. Dummy inputs and dynamic shapes are created by function onnx_diagnostic.torch_models.llms.get_tiny_llm().

# Fresh dummy inputs and dynamic shapes; they were never fed to a model,
# so their cache is consistent with input_ids and attention_mask.
data = get_tiny_llm()
inputs = data["inputs"]
dynamic_shapes = data["dynamic_shapes"]

Let’s print the inputs and the dynamic shapes.

print(string_type(inputs, with_shape=True))
# Dynamic shapes associated with the inputs, shown right after them.
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: 'batch', 1: 'cache+seq'},
 'input_ids': {0: 'batch', 1: 'seq_length'},
 'past_key_values': [[{0: 'batch', 2: 'cache_length'}],
                     [{0: 'batch', 2: 'cache_length'}]],
 'position_ids': {0: 'batch', 1: 'cache+seq'}}

And let’s finally export.

# Export the downloaded model with the dummy inputs and dynamic shapes just
# created by get_tiny_llm. The inputs are deep-copied so the export cannot
# alter ``inputs`` itself.
try:
    ep = torch.export.export(
        model,
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,
    )
except Exception as e:
    # To work, it needs at least PRs:
    # * https://github.com/huggingface/transformers/pull/36311
    # * https://github.com/huggingface/transformers/pull/36652
    print("It failed:", e)
else:
    # Only the export is guarded, so an error raised while printing the
    # program is not misreported as an export failure.
    print("It worked:")
    print(ep)
It failed: 8*s72 (139836400388672)is not tracked with proxy for <torch.fx.experimental.proxy_tensor._ModuleStackTracer object at 0x7f2e5f65b260>

If you get an error, look at the example Export Tiny-LLM with patches.

# Legend image for this example in the gallery (displayed below).
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")
plot export tiny llm

Total running time of the script: (0 minutes 2.906 seconds)

Related examples

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Test the export on untrained models

Test the export on untrained models

Gallery generated by Sphinx-Gallery