Export an LLM with InputObserver (with Tiny-LLM)
The main issue when exporting an LLM is that the example on HuggingFace relies on the generate method, while we only need to export the forward method. The example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) gives details on how to guess dummy inputs and dynamic shapes for that purpose. Let's see how to simplify it.
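Before diving in, here is a minimal sketch (not part of the original example) illustrating the difference: generate runs a decoding loop that calls forward repeatedly, while the export only needs a single forward call with explicit inputs.

# Hypothetical sketch: a single forward call, which is what the export captures.
# The prompt and the printed shape are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
mdl = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
enc = tok("it rains", return_tensors="pt")
with torch.no_grad():
    out = mdl(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
print(out.logits.shape)  # (batch_size, sequence_length, vocab_size)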
Dummy Example
Let’s use the example provided on arnir0/Tiny-LLM.
import pandas
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.investigate.input_observer import InputObserver
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
# Define your prompt
prompt = "Continue: it rains, what should I do?"
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
print("-----------------")
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 973.80it/s, Materializing param=model.norm.weight]
-----------------
Continue: it rains, what should I do?
I have a lot of people who are in the world. I have a lot of people who are in the world, and I have a lot of people who are in the world
-----------------
Replace forward method
We first capture inputs and outputs with an InputObserver (onnx_diagnostic.investigate.input_observer). We also need to register additional serialization functions for transformers so that PyTorch knows how to flatten and unflatten the inputs.
observer = InputObserver()
with register_additional_serialization_functions(patch_transformers=True), observer(model):
    generate_text(prompt, model, tokenizer)
print(f"number of stored inputs: {len(observer.info)}")
number of stored inputs: 3
Exports
The InputObserver now has enough data to infer arguments and dynamic shapes. Serialization functions alone are not enough to export the model; patches are required as well. The inferred dynamic shapes look like this:
print(observer.infer_dynamic_shapes(set_batch_dimension_for=True))
{'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'logits_to_keep': None}
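DimHint(DYNAMIC) is the printed form of torch.export.Dim.DYNAMIC. As a sketch (written by hand, not produced by the example), the same dynamic shapes could be spelled out explicitly:

# Hand-written equivalent of the inferred dynamic shapes (a sketch, assuming
# torch.export.Dim.DYNAMIC is available in the installed torch version).
import torch

DYN = torch.export.Dim.DYNAMIC
manual_dynamic_shapes = {
    "input_ids": {0: DYN, 1: DYN},        # batch size, sequence length
    "attention_mask": {0: DYN, 1: DYN},
    "position_ids": {0: DYN, 1: DYN},
    "past_key_values": [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}],  # batch, cache length
    "cache_position": {0: DYN},
    "logits_to_keep": None,
}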
and inferred arguments:
print(string_type(observer.infer_arguments(), with_shape=True))
dict(input_ids:T7s1x13,attention_mask:T7s1x13,position_ids:T7s1x13,past_key_values:DynamicCache(key_cache=#1[T1s1x1x0x96], value_cache=#1[T1s1x1x0x96]),cache_position:T7s13,logits_to_keep:int)
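The compact notation comes from string_type: T7 refers to ONNX element type 7 (int64), T1 to float32, and s1x13 to the shape. A small sketch to check this (the exact rendering may vary slightly with the installed version):

# A small sketch: string_type renders an int64 tensor of shape (1, 13)
# roughly as "T7s1x13" (7 is the ONNX element type for int64).
import torch
from onnx_diagnostic.helpers import string_type

print(string_type(torch.zeros((1, 13), dtype=torch.int64), with_shape=True))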
Let’s export.
filenamec = "plot_export_tiny_llm_input_observer.custom.onnx"
with torch_export_patches(patch_transformers=True):
    to_onnx(
        model,
        (),
        kwargs=observer.infer_arguments(),
        dynamic_shapes=observer.infer_dynamic_shapes(set_batch_dimension_for=True),
        filename=filenamec,
        exporter="custom",
    )
Check discrepancies
The model is now exported to ONNX. We reuse the stored inputs and outputs to verify that the exported model produces the same outputs.
data = observer.check_discrepancies(filenamec, progress_bar=True)
print(pandas.DataFrame(data))
100%|██████████| 3/3 [00:00<00:00, 11.28it/s]
abs rel sum n ... n_empty inputs outputs_torch outputs_ort
0 0.000018 0.000843 0.176129 34496.0 ... 2 dict(input_ids:T7s1x13,attention_mask:T7s1x13,... #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96] #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96]
1 0.000010 0.001482 0.063151 34688.0 ... 0 dict(input_ids:T7s1x1,attention_mask:T7s1x14,p... #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96] #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96]
2 0.000011 0.003134 0.062501 34880.0 ... 0 dict(input_ids:T7s1x1,attention_mask:T7s1x15,p... #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96] #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96]
[3 rows x 18 columns]
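The columns outputs_torch and outputs_ort show that the stored PyTorch outputs are compared against the ones computed by onnxruntime. As a complement, here is a minimal sketch (assuming onnxruntime is installed) to open the exported file and inspect its signature:

# A minimal sketch, assuming onnxruntime is installed, to inspect the
# inputs and outputs of the exported model.
import onnxruntime

sess = onnxruntime.InferenceSession(filenamec, providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])   # input names, cache tensors appear as flat tensors
print([o.name for o in sess.get_outputs()])  # output names (logits and present key/values)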
Minimal script to export an LLM
The following lines are a condensed copy of the steps above, with fewer comments.
# from HuggingFace
print("----------------")
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# from HuggingFace again
prompt = "Continue: it rains, what should I do?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=False,
)

observer = InputObserver()
with register_additional_serialization_functions(patch_transformers=True), observer(model):
    generate_text(prompt, model, tokenizer)

filename = "plot_export_tiny_llm_input_observer.onnx"
with torch_export_patches(patch_transformers=True):
    torch.onnx.export(
        model,
        (),
        filename,
        kwargs=observer.infer_arguments(),
        dynamic_shapes=observer.infer_dynamic_shapes(set_batch_dimension_for=True),
    )

data = observer.check_discrepancies(filename, progress_bar=True)
print(pandas.DataFrame(data))
----------------
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 891.90it/s, Materializing param=model.norm.weight]
[torch.onnx] Obtain model graph for `LlamaForCausalLM([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `LlamaForCausalLM([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decompositions...
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
[torch.onnx] Run decompositions... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Optimize the ONNX graph...
Applied 30 of general pattern rewrite rules.
[torch.onnx] Optimize the ONNX graph... ✅
100%|██████████| 3/3 [00:00<00:00, 12.28it/s]
abs rel sum n ... n_empty inputs outputs_torch outputs_ort
0 0.000010 0.000650 0.061900 34496.0 ... 2 dict(input_ids:T7s1x13,attention_mask:T7s1x13,... #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96] #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96]
1 0.000011 0.000800 0.063265 34688.0 ... 0 dict(input_ids:T7s1x1,attention_mask:T7s1x14,p... #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96] #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96]
2 0.000011 0.001836 0.051332 34880.0 ... 0 dict(input_ids:T7s1x1,attention_mask:T7s1x15,p... #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96] #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96]
[3 rows x 18 columns]
ONNX Prompt
Finally, let's run the generation loop directly on the exported ONNX model with onnx_generate.
onnx_tokens = onnx_generate(
    filenamec,
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    eos_token_id=model.config.eos_token_id,
    max_new_tokens=50,
)

onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True)
print("-----------------")
print("\n".join(onnx_generated_text))
print("-----------------")
-----------------
Continue: it rains, what should I do?
I have a lot of people who are in the world. I have a lot of people who are in the world, and I have a lot of people who are in the world. I have a lot of people who are in the world,
-----------------
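To double-check the ONNX generation, a short sketch (assuming greedy decoding on both sides) runs the same prompt through the original PyTorch model:

# A short sketch comparing with the PyTorch model (greedy decoding assumed).
torch_tokens = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=50,
    do_sample=False,
)
print(tokenizer.decode(torch_tokens[0], skip_special_tokens=True))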

Total running time of the script: (0 minutes 35.403 seconds)
Related examples
Export a LLM through method generate (with Tiny-LLM)