Export an LLM through method generate (with Tiny-LLM)
The main issue when exporting an LLM is that the example provided on HuggingFace relies on the method generate, while we only need to export the forward method. The example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) explains how to guess dummy inputs and dynamic shapes to do so. Let's see how to simplify that.
Dummy Example
Let’s use the example provided on arnir0/Tiny-LLM.
import pandas
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import method_to_onnx
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def generate_text(
    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
# Define your prompt
prompt = "Continue: it rains..."
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
print("-----------------")
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 1018.43it/s, Materializing param=model.norm.weight]
-----------------
Continue: it rains... Continue
Posted 41:34 PM - March 26, 2012 | Comments
2000-16-06-0202
-----------------
Replace forward method
We now modify the model so that it can be exported by replacing its forward method.
We still call method generate, but it will invoke a replacement function
created by onnx_diagnostic.export.api.method_to_onnx().
This replacement captures the inputs of the forward method: at least 2 calls are needed,
and 3 are recommended for LLMs since the first call does not contain any cache.
If the default settings do not work, skip_kwargs_names and dynamic_shapes
can be changed to remove undesired inputs or to add more dynamic dimensions.
filename = "plot_export_tiny_llm_method_generate.onnx"
forward_replacement = method_to_onnx(
    model,
    method_name="forward",  # default value
    exporter="custom",  # use "onnx-dynamo" for the official exporter
    filename=filename,  # onnx file to create
    patch_kwargs=dict(patch_transformers=True),  # patches applied before exporting
    # On a first try, it is recommended to set verbose=1 to follow the progress
    # and to see how to set ``skip_kwargs_names`` and ``dynamic_shapes`` if needed.
    verbose=1,
    # Triggers the ONNX conversion after 3 calls to the forward method.
    # The conversion uses the last call; the previous ones are used to infer
    # the dynamic shapes if they are not specified explicitly.
    convert_after_n_calls=3,
    # The input used in the example has a batch size equal to 1, and all
    # inputs going through the forward method will have the same batch size.
    # To force this dimension to be dynamic, we need to indicate
    # which inputs have a batch dimension.
    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
Dynamic shapes can be inferred from at least two calls to the forward method;
3 is better for LLMs (the first call is the prefill step, so the cache is missing).
You can see the inferred values with verbose=1.
If a value is not the expected one (to change the names, for example),
it can be overwritten:
dynamic_shapes = {
    "cache_position": {0: "sequence_length"},
    "past_key_values": [
        {0: "batch_size", 2: "past_sequence_length"},
        {0: "batch_size", 2: "past_sequence_length"},
    ],
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
}
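For reference, such a dictionary would be given through the dynamic_shapes argument mentioned above instead of relying on the inferred values. The following is a hypothetical variant of the earlier call, not executed in this example, reusing the dictionary just defined:
# Hypothetical call: bypass the shape inference and use the names defined
# in the dictionary above (the other arguments are the same as before).
forward_replacement = method_to_onnx(
    model,
    exporter="custom",
    filename=filename,
    patch_kwargs=dict(patch_transformers=True),
    convert_after_n_calls=3,
    dynamic_shapes=dynamic_shapes,
)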
Finally, we need to replace the forward method.
As forward_replacement is a module of type
onnx_diagnostic.export.api.WrapperToExportMethodToOnnx,
a lambda function must be used to prevent it from being registered
as a submodule (which would create an infinite loop).
print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
type(forward_replacement)=<class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx'>
Let's call generate again. The conversion is triggered after
convert_after_n_calls=3 calls to the method forward,
which is exactly what the method generate does.
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)
[method_to_onnx] input[0]: ((),dict(input_ids:T7s1x8,attention_mask:T7s1x8,cache_position:T7s8))
[method_to_onnx] input[1]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x9,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),cache_position:T7s1))
[method_to_onnx] input[2]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x10,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),cache_position:T7s1))
[method_to_onnx] save 3 inputs in 'plot_export_tiny_llm_method_generate.inputs.pt'
[method_to_onnx] guess_dynamic_shapes=((),dict(input_ids:{1:DYNAMIC},attention_mask:{1:DYNAMIC},past_key_values:#2[{2:DYNAMIC},{2:DYNAMIC}],cache_position:{0:DYNAMIC}))
[method_to_onnx.rename_dynamic_shapes] apply pattern shapes 'LLM.text'
[method_to_onnx] dynamic_batch_for={'attention_mask', 'input_ids', 'past_key_values'}
[method_to_onnx] dynamic_shapes with batch=((), {'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}})
[method_to_onnx] export args=()
[method_to_onnx] export kwargs=dict(input_ids:T7s1x1,attention_mask:T7s1x10,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),cache_position:T7s1)
[method_to_onnx] dynamic_shapes=((),dict(input_ids:{0:DYN(batch),1:DYN(seqlength)},attention_mask:{0:DYN(batch),1:DYN(totallength)},past_key_values:#2[{0:DYN(batch),2:DYN(pastlength)},{0:DYN(batch),2:DYN(pastlength)}],cache_position:{0:DYN(seqlength)}))
[to_onnx] build the graph module from <class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx._convert_method_to_onnx.<locals>.WrapWithExactSignature'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[_make_builder_interpreter] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'WrapWithExactSignature'
[ExportOptions.export] aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'input_ids': {0: 'd=[1]', 1: 'd=[1]'}, 'attention_mask': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'cache_position': {0: 'd=[1]'}}
[torch_export] export starts with backed_size_oblivious={'input_ids': {0: 'd=[1]', 1: 'd=[1]'}, 'attention_mask': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'cache_position': {0: 'd=[1]'}}
[ExportOptions.export] export done in 1.637513773000137
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 9 slices nodes were removed
[CustomTracer.remove_inplace] starts with 170 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 13 inplace nodes
[CustomTracer.remove_inplace] S2: 6 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 141 nodes (n_inplace=6)
[ExportOptions.export] inplaces: 13 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.004672862000006717, modified=13
[ExportOptions.export] done with no decomposition in 0.004789915000401379
[to_onnx] graph module done in 1.6521961380003631 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
100%|██████████| 141/141 [00:00<00:00, 1137.78it/s]
[to_onnx] 203 onnx nodes done in 0.20099439499972505 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True]
[GraphBuilder-QOM.inline_functions] begin inlining graph
[GraphBuilder-QOM.inline_functions] skip_functions=set()
[GraphBuilder-QOM._inline_functions_iterations] inline function 'submod_3' domain 'local_functions' [n_replacements=1]
[GraphBuilder-QOM._inline_functions_iterations] done with 9 new nodes for 'submod_3', 'local_functions'
[GraphBuilder-QOM.inline_functions] done inlining graph 128532838118320 in 0.0038044259999878705
[GraphBuilder-QOM._add_shape_information] dynamic shapes replacements={'seqlength': 'seqlength', 'batch': 'batch', 'pastlength': 'pastlength', 'totallength': 'totallength', 's61': 'batch', 's43': 'batch', 'batch^s61^batch^s61': 'batch', 's67': 'batch', 's72': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's44': 'pastlength', 's21': 'pastlength'}
[GraphBuilder-QOM.optimize] start with 211 nodes
[GraphBuilder-QOM.optimize] #patterns=109
[GraphBuilder-QOM.optimize] start with subgraphs
[GraphBuilder-QOM.optimize] done with subgraphs
[GraphBuilderPatternOptimization-QOM.optimize] start with 161 nodes, 32 initializers, 109 patterns, priorities=[0, 1, 2, 3], max_iter=644
[GraphBuilderPatternOptimization-QOM.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-QOM.optimize] iteration 0: 161 nodes, priority=0
[GraphBuilderPatternOptimization-QOM.optimize] applies 33 matches, 11*CastPattern, 2*IdentityPattern, 3*ShapeBasedReshapeIsSqueezePattern, 2*ShapeBasedStaticExpandPattern, 3*ShapeBasedEditDistanceReshapePattern, 4*SameChildrenPattern, 1*SameChildrenFromInputPattern, 2*SqueezeAddPattern, 1*SqueezeUnsqueezePattern, 3*UnsqueezeUnsqueezePattern, 1*FunctionAttentionPattern - time=0.013 | max_time=GeluOrtPattern:0.002
[GraphBuilderPatternOptimization-QOM.optimize] reapply {'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-QOM.optimize] n_added=6, n_removed=8, n_applied=35 applied patterns, 110 nodes left with 2 iterations
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-QOM.optimize] iteration 1: 110 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 16 matches, 2*ConcatTwiceUnaryPattern, 1*ConstantToInitializerPattern, 1*IdentityPattern, 2*SlicesSplitPattern, 1*SqueezeBinaryUnsqueezePattern, 2*SwapUnsqueezeTransposePattern, 3*UnsqueezeUnsqueezePattern, 1*QuickGeluPattern, 3*SimplifiedLayerNormalizationPattern - time=0.016 | max_time=GeluOrtPattern:0.002
[GraphBuilderPatternOptimization-QOM.optimize] iteration 2: 87 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 8 matches, 2*IdentityPattern, 1*UnsqueezeUnsqueezePattern, 2*FunctionHalfRotaryEmbeddingPattern, 3*SimplifiedLayerNormalizationMulPattern - time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-QOM.optimize] iteration 3: 71 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 2 matches, 2*SkipSimplifiedLayerNormalizationPattern - time=0.007 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-QOM.optimize] iteration 4: 69 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-QOM.optimize] iteration 5: 69 nodes, priority=2
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-QOM.optimize] iteration 6: 69 nodes, priority=3
[GraphBuilderPatternOptimization-QOM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-QOM.optimize] done after 7 iterations with 69 nodes in 0.114
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 69 nodes, 31 initializers
[OrderOptimization.shape_order] done after in 0.00037146499926166143s with changed=5 scale=25
[GraphBuilder-QOM.optimize] done with 69 nodes in 0.131
[GraphBuilder-QOM.to_onnx] make_model 35 inits 12 params
[GraphBuilder-QOM.time_evaluation_constants_] 0.0009111650006161653
[GraphBuilder-QOM._build_initializers] start with 35 initializers, large_model=True, external_threshold=1024
[GraphBuilder-QOM._build_initializers] switch low/high order
[GraphBuilder-QOM._build_initializers] done in 2.322000000276603e-06s with 31 initializers, 9 large initializers
[GraphBuilder-QOM._add_shape_information] dynamic shapes replacements={'seqlength': 'seqlength', 'batch': 'batch', 'pastlength': 'pastlength', 'totallength': 'totallength', 's61': 'batch', 's43': 'batch', 'batch^s61^batch^s61': 'batch', 's67': 'batch', 's72': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's44': 'pastlength', 's21': 'pastlength'}
[to_onnx] to_onnx done in 0.1486357189996852s and 69 nodes, 31 initializers, 5 inputs, 3 outputs
[method_to_onnx] save 3 outputs in 'plot_export_tiny_llm_method_generate.outputs.pt'
Continue: it rains...
We have been in crisis
Tarally I have been working on the first week at the first one.
Sat: “Besides the next day we will
them all the best
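The export has now been triggered. Overriding model.forward with a lambda only set an instance attribute, so the original method is still available on the class; if you want to keep using the model as a plain PyTorch model afterwards, deleting that attribute is enough. A small sketch, not part of the original example:
# Drop the instance attribute set earlier so that attribute lookup falls
# back to the class, i.e. the original forward method of the model.
del model.forward
print(type(model.forward))  # a bound method of the original class again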
We finally need to check the discrepancies. The export produced an ONNX file and dumped the inputs and outputs of the torch model. We now run the ONNX model to check that it produces the same results. This is done afterwards because both models may not fit in memory at the same time (torch and onnxruntime). verbose=2 shows more information about the expected outputs.
data = forward_replacement.check_discrepancies(verbose=1)
df = pandas.DataFrame(data)
print(df)
[method_to_onnx.check_discrepancies] register classes [<class 'transformers.cache_utils.DynamicCache'>, <class 'transformers.cache_utils.DynamicLayer'>, <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>, <class 'torch.dtype'>]
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.inputs.pt'
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.outputs.pt'
[method_to_onnx.check_discrepancies] create onnx session 'plot_export_tiny_llm_method_generate.onnx'
[method_to_onnx.check_discrepancies] input_names=['input_ids', 'attention_mask', 'past_key_values_key_0', 'past_key_values_value_0', 'cache_position']
[method_to_onnx.check_discrepancies] onnx_shapes=INT64[batch,seqlength], INT64[batch,totallength], FLOAT[batch,1,pastlength,96], FLOAT[batch,1,pastlength,96], INT64[seqlength]
[method_to_onnx.check_discrepancies] process input 0 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 1 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 2 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] done
        abs       rel       sum         n  dnan  dev  >0.1  >0.01  SUCCESS  index  duration_torch  ort_duration  n_inputs
0  0.000016  0.003386  0.482308  257536.0     0    0     0      0     True      0        0.002885      1.846359         5
1  0.000013  0.003116  0.106110   33728.0     0    0     0      0     True      1        0.002364      0.001560         5
2  0.000010  0.001086  0.051389   33920.0     0    0     0      0     True      2        0.005172      0.001325         5
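As an extra sanity check, the exported file can also be opened directly with onnxruntime to verify that the dynamic dimensions (batch, seqlength, totallength, pastlength) made it into the graph signature. A small sketch, not part of the original example:
# Open the exported model and print its input/output signature; the symbolic
# dimension names should match the dynamic shapes reported in the logs above.
import onnxruntime

sess = onnxruntime.InferenceSession(
    "plot_export_tiny_llm_method_generate.onnx",
    providers=["CPUExecutionProvider"],
)
for inp in sess.get_inputs():
    print("input ", inp.name, inp.type, inp.shape)
for out in sess.get_outputs():
    print("output", out.name, out.type, out.shape)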

Total running time of the script: (0 minutes 11.585 seconds)
Related examples
Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)