Note
Go to the end to download the full example code.
Steal the forward method to guess inputs and dynamic shapes (with Tiny-LLM)¶
Inputs are always dynamic with LLMs, which is why dynamic shapes need to be specified when an LLM is exported with torch.export.export().
Most of the examples on HuggingFace use the method transformers.GenerationMixin.generate(), but we only want to export the model and its forward method. This example shows how to guess the inputs of this method even though the model is executed through generate().
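As a reminder of what the end goal looks like, torch.export.export() takes a dynamic_shapes argument that mirrors the structure of the inputs. Here is a minimal sketch on a toy module (the module and dimension names are made up for illustration, they are not taken from Tiny-LLM):
import torch


class ToyModel(torch.nn.Module):
    def forward(self, input_ids):
        # stand-in for an LLM forward: the output shape depends on the input shape
        return input_ids.to(torch.float32).unsqueeze(-1) * 2


batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")

ep = torch.export.export(
    ToyModel(),
    (torch.arange(8).reshape(2, 4),),
    # one specification per positional argument, one entry per dynamic axis
    dynamic_shapes=({0: batch, 1: seq},),
)
print(ep)
With a real LLM, the difficulty is precisely to build that dynamic_shapes structure, hence the need to observe what generate() actually passes to forward().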
We focus on the model arnir0/Tiny-LLM. To avoid downloading any weights, we write a function creating a random model based on the same architecture.
Steal the forward method¶
The first step is to guess the dummy inputs. Let’s use the true model for that. We use the dummy example from the model page.
import copy
import pprint
import torch
import transformers
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
We rewrite the forward method to print its inputs and outputs, including the cache dimensions.
def _forward_(*args, _f=None, **kwargs):
assert _f is not None
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
# torch.compiler.is_exporting requires torch>=2.7
print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
res = _f(*args, **kwargs)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
print("->", string_type(res, with_shape=True, with_min_max=True))
return res
keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, **kwargs: _forward_(
*args, _f=_f, **kwargs
)
Let’s run the model.
prompt = "Continue: it rains..."
inputs = tokenizer.encode(prompt, return_tensors="pt")
outputs = model.generate(
inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("-- prompt", prompt)
print("-- answer", generated_text)
<- ((),dict(cache_position:T7s8[0,7:A3.5],past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8[1,29901:A6305.375],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x8x32000[-15.516718864440918,15.75765609741211:A-3.381915190983544],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]))
<- ((),dict(cache_position:T7s1[8,8:A8.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96[-5.490959167480469,6.226877689361572:A-0.11321351693110653]], value_cache=#1[T1s1x1x8x96[-0.6787744760513306,0.49568021297454834:A0.007227749521139988]]),input_ids:T7s1x1[1323,1323:A1323.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-11.716479301452637,17.75351333618164:A-2.489861077328678],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.12675475233158473]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.005567598636153653]]))
<- ((),dict(cache_position:T7s1[9,9:A9.0],past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96[-5.490959167480469,6.226877689361572:A-0.12675475233158473]], value_cache=#1[T1s1x1x9x96[-0.6787744760513306,0.49568021297454834:A0.005567598636153653]]),input_ids:T7s1x1[14150,14150:A14150.0],inputs_embeds:None,use_cache:bool=True,return_dict:bool=True))
-> CausalLMOutputWithPast(logits:T1s1x1x32000[-14.882962226867676,7.9873762130737305:A-7.374863562272396],past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96[-5.490959167480469,6.226877689361572:A-0.1322267666200787]], value_cache=#1[T1s1x1x10x96[-0.6787744760513306,0.49568021297454834:A0.006777827040893195]]))
-- prompt Continue: it rains...
-- answer Continue: it rains...Continue
Let’s restore the forward as it was.
model.forward = keep_model_forward
Another syntax with onnx_diagnostic.helpers.torch_helper.steal_forward().
with steal_forward(model):
model.generate(inputs, max_length=50, temperature=1, top_k=50, top_p=0.95, do_sample=True)
+ -- stolen forward for class LlamaForCausalLM -- iteration 0
<- args=() --- kwargs=dict(cache_position:T7s8,past_key_values:DynamicCache(key_cache=#0[], value_cache=#0[]),input_ids:T7s1x8,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x8x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 1
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 2
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 3
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x10x96], value_cache=#1[T1s1x1x10x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 4
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x11x96], value_cache=#1[T1s1x1x11x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 5
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x12x96], value_cache=#1[T1s1x1x12x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 6
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 7
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 8
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x15x96], value_cache=#1[T1s1x1x15x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 9
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x16x96], value_cache=#1[T1s1x1x16x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 10
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x17x96], value_cache=#1[T1s1x1x17x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 11
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x18x96], value_cache=#1[T1s1x1x18x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 12
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x19x96], value_cache=#1[T1s1x1x19x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 13
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x20x96], value_cache=#1[T1s1x1x20x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 14
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x21x96], value_cache=#1[T1s1x1x21x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 15
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x22x96], value_cache=#1[T1s1x1x22x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 16
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x23x96], value_cache=#1[T1s1x1x23x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 17
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x24x96], value_cache=#1[T1s1x1x24x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 18
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x25x96], value_cache=#1[T1s1x1x25x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 19
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x26x96], value_cache=#1[T1s1x1x26x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 20
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x27x96], value_cache=#1[T1s1x1x27x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 21
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x28x96], value_cache=#1[T1s1x1x28x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 22
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x29x96], value_cache=#1[T1s1x1x29x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 23
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x30x96], value_cache=#1[T1s1x1x30x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 24
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x31x96], value_cache=#1[T1s1x1x31x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 25
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x32x96], value_cache=#1[T1s1x1x32x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 26
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x33x96], value_cache=#1[T1s1x1x33x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 27
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x34x96], value_cache=#1[T1s1x1x34x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 28
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x35x96], value_cache=#1[T1s1x1x35x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 29
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x36x96], value_cache=#1[T1s1x1x36x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 30
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x37x96], value_cache=#1[T1s1x1x37x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 31
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x38x96], value_cache=#1[T1s1x1x38x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 32
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x39x96], value_cache=#1[T1s1x1x39x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 33
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x40x96], value_cache=#1[T1s1x1x40x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 34
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x41x96], value_cache=#1[T1s1x1x41x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 35
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x42x96], value_cache=#1[T1s1x1x42x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 36
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x43x96], value_cache=#1[T1s1x1x43x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 37
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x44x96], value_cache=#1[T1s1x1x44x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 38
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x45x96], value_cache=#1[T1s1x1x45x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 39
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x46x96], value_cache=#1[T1s1x1x46x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 40
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x47x96], value_cache=#1[T1s1x1x47x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]))
-.
+ -- stolen forward for class LlamaForCausalLM -- iteration 41
<- args=() --- kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x48x96], value_cache=#1[T1s1x1x48x96]),input_ids:T7s1x1,inputs_embeds:None,use_cache:bool,return_dict:bool)
-> CausalLMOutputWithPast(logits:T1s1x1x32000,past_key_values:DynamicCache(key_cache=#1[T1s1x1x49x96], value_cache=#1[T1s1x1x49x96]))
-.
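These traces are enough to guess the dynamic dimensions: any axis whose size changes between two iterations, here the sequence length of input_ids and the cache length of the DynamicCache, must be declared dynamic. Below is a minimal sketch of that reasoning, comparing two recorded shapes (guess_dynamic_axes is a hypothetical helper written for this illustration, it is not part of onnx_diagnostic):
from typing import Dict, Tuple


def guess_dynamic_axes(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Dict[int, str]:
    """Marks every axis whose size differs between two recorded calls as dynamic."""
    assert len(shape_a) == len(shape_b), "the rank should not change between calls"
    return {i: f"dim_{i}" for i, (a, b) in enumerate(zip(shape_a, shape_b)) if a != b}


# key cache observed at iterations 0 and 1 (shapes taken from the traces above)
print(guess_dynamic_axes((1, 1, 8, 96), (1, 1, 9, 96)))  # {2: 'dim_2'}: the cache length
# input_ids observed during prefill and during decoding
print(guess_dynamic_axes((1, 8), (1, 1)))  # {1: 'dim_1'}: the sequence length
The batch dimension never changes here because generate() was called with a single prompt, but it is usually declared dynamic as well.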
Untrained model¶
This part can be skipped if you are only interested in exporting the original model. It is useful to create a unit test to ensure a specific architecture can be exported despite the many changes brought to torch or transformers.
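Such a unit test could look like the following sketch, built on the helper introduced just below (the test name and assertion are mine; as the rest of this example shows, the export may still fail depending on the installed torch and transformers versions):
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


def test_tiny_llm_architecture_exports():
    # untrained model, dummy inputs and dynamic shapes all come from the same helper
    data = get_tiny_llm()
    ep = torch.export.export(
        data["model"],
        (),
        kwargs=data["inputs"],
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
        strict=False,
    )
    assert ep is not None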
Let’s create an untrained model from the same configuration file config.json, using onnx_diagnostic.torch_models.llms.get_tiny_llm().
Then let’s use it.
experiment = get_tiny_llm()
untrained_model, inputs, dynamic_shapes = (
experiment["model"],
experiment["inputs"],
experiment["dynamic_shapes"],
)
Before we run it, we make a copy of the inputs because the cache gets modified by the execution; after that, it is no longer consistent with the previous input_ids and mask.
print("input type before", string_type(inputs, with_shape=True))
expected_output = untrained_model(**inputs)
print("input type after-", string_type(inputs, with_shape=True))
input type before dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
input type after- dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
The outputs:
print("result type", string_type(expected_output, with_shape=True))
result type CausalLMOutputWithPast(logits:T1s2x3x32000,past_key_values:DynamicCache(key_cache=#1[T1s2x1x33x96], value_cache=#1[T1s2x1x33x96]))
It works.
ExportedProgram¶
try:
ep = torch.export.export(
untrained_model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7d214da0a630> not registered
Back to the original model¶
Let’s use the same dummy inputs, but with the downloaded model. Dummy inputs and dynamic shapes are created by onnx_diagnostic.torch_models.llms.get_tiny_llm().
data = get_tiny_llm()
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
cloned_inputs = copy.deepcopy(inputs)
Let’s print the inputs and the dynamic shapes.
print(string_type(inputs, with_shape=True))
pprint.pprint(dynamic_shapes)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}
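The specification mixes torch.export.Dim objects and plain strings for the dimension names. The export call below goes through use_dyn_not_str() first; as a rough illustration of the idea (a hypothetical helper written here, not the library's implementation), one can think of it as replacing every string name by torch.export.Dim.DYNAMIC:
import torch


def strings_to_dynamic(ds):
    "Recursively replaces string dimension names by torch.export.Dim.DYNAMIC (illustrative only)."
    if isinstance(ds, str):
        # requires a recent torch providing torch.export.Dim.DYNAMIC
        return torch.export.Dim.DYNAMIC
    if isinstance(ds, dict):
        return {k: strings_to_dynamic(v) for k, v in ds.items()}
    if isinstance(ds, (list, tuple)):
        return type(ds)(strings_to_dynamic(v) for v in ds)
    return ds


# {1: 'seq_length'} becomes {1: torch.export.Dim.DYNAMIC}, Dim objects are left untouched
print(strings_to_dynamic({0: torch.export.Dim("batch", min=1, max=1024), 1: "seq_length"}))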
And let’s finally export.
try:
ep = torch.export.export(
model,
(),
kwargs=cloned_inputs,
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
strict=False,
)
print("It worked:")
print(ep)
except Exception as e:
# To work, it needs at least PRs:
# * https://github.com/huggingface/transformers/pull/36311
# * https://github.com/huggingface/transformers/pull/36652
print("It failed:", e)
It failed: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x7d214d6c7b30> not registered
If you get an error, have a look at the example Export Tiny-LLM with patches.
doc.plot_legend("Tiny-LLM\nforward inputs\nbehind generate", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 2.374 seconds)