Export a model through method generate (with Tiny-LLM)¶
The main issue when exporting an LLM is that the example on HuggingFace relies on the generate method, while only the forward method needs to be exported. The example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) gives details on how to guess the dummy inputs and dynamic shapes required to do so. Let's see how to simplify that.
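Conceptually, generate is a loop around forward, which is why exporting forward is enough to rebuild generation outside PyTorch. A minimal greedy-decoding sketch, ignoring caches and sampling (the helper greedy_generate below is hypothetical, not part of this example):

import torch

def greedy_generate(model, input_ids, max_new_tokens=10):
    # generate boils down to repeated calls to forward
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids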
Dummy Example¶
Let’s use the example provided on arnir0/Tiny-LLM.
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import method_to_onnx
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def generate_text(
    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
):
    inputs = tokenizer.encode(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
# Define your prompt
prompt = "Continue: it rains..."
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
print("-----------------")
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 955.44it/s, Materializing param=model.norm.weight]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
-----------------
Continue: it rains... The World War II: 6-68.
Knook and the Warpaced 110-125-1987 (R-0.1-0
-----------------
Replace forward method¶
We now make the model exportable by replacing its forward method.
We still call generate, but it now goes through a wrapper
created by onnx_diagnostic.export.api.method_to_onnx().
This wrapper captures the inputs of the forward method: at least 2 calls
are needed, and 3 are recommended for LLMs because the first call does not carry any cache.
If the default settings do not work, skip_kwargs_names and dynamic_shapes
can be changed to remove undesired inputs or add more dynamic dimensions.
filename = "plot_export_tiny_llm_method_generate.onnx"
forward_replacement = method_to_onnx(
model,
method_name="forward", # default value
exporter="custom", # onnx-dynamo to use the official exporter
filename=filename, # onnx file to create
patch_kwargs=dict(patch_transformers=True), # patches before eporting
# to see the progress, it is recommended on the first try to see
# how to set ``skip_kwargs_names`` and ``dynamic_shapes`` if it is needed
verbose=1,
# triggers the ONNX conversion after 3 calls to forward method,
# the onnx version is triggered with the last one,
# the others are used to infer the dynamic shape if they are not
# specified below
convert_after_n_calls=3,
# skips the following inputs even though they are captured,
# these ones are filled with default values we don't want in
# the onnx model
skip_kwargs_names={"kwargs", "use_cache", "return_dict", "inputs_embeds"},
# dynamic shape can be inferred from at least two calls to the forward method,
# 3 is better for LLMs, you can see the inference results with ``verbose=1``,
# this parameter is used to overwrite the inferred values,
# this is usually needed because the inferred dynamic shapes contains
# less dynamic dimension than requested.
dynamic_shapes={
"cache_position": {0: "total_sequence_length"},
"past_key_values": [
{0: "batch_size", 2: "past_sequence_length"},
{0: "batch_size", 2: "past_sequence_length"},
],
"input_ids": {0: "batch_size", 1: "sequence_length"},
},
)
The lambda function cannot be skipped because forward_replacement is a module, not a plain function.
print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
type(forward_replacement)=<class 'onnx_diagnostic.export.api._WrapperToExportMethodToOnnx'>
Let’s call generate again.
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
[method_to_onnx] input[0]: ((),dict(cache_position:T7s8,input_ids:T7s1x8))
[method_to_onnx] input[1]: ((),dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),input_ids:T7s1x1))
[method_to_onnx] input[2]: ((),dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1))
[method_to_onnx] export args=()
[method_to_onnx] export kwargs=dict(cache_position:T7s1,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),input_ids:T7s1x1)
[method_to_onnx] dynamic_shapes=#1[dict(cache_position:{0:DYN(total_sequence_length)},past_key_values:#2[{0:DYN(batch_size),2:DYN(past_sequence_length)},{0:DYN(batch_size),2:DYN(past_sequence_length)}],input_ids:{0:DYN(batch_size),1:DYN(sequence_length)})]
[to_onnx] build the graph module from <class 'onnx_diagnostic.export.api._WrapperToExportMethodToOnnx._convert_method_to_onnx.<locals>.WrapWithExactSignature'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'cache_position': {0: 'total_sequence_length'}, 'past_key_values': [{0: 'batch_size', 2: 'past_sequence_length'}, {0: 'batch_size', 2: 'past_sequence_length'}], 'input_ids': {0: 'batch_size', 1: 'sequence_length'}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(cache_position:T7r1,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),input_ids:T7r2)
[_make_builder_interpreter] dynamic_shapes={'cache_position': {0: 'total_sequence_length'}, 'past_key_values': [{0: 'batch_size', 2: 'past_sequence_length'}, {0: 'batch_size', 2: 'past_sequence_length'}], 'input_ids': {0: 'batch_size', 1: 'sequence_length'}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'WrapWithExactSignature'
[ExportOptions.export] aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'cache_position': {0: 'total_sequence_length'}, 'past_key_values': [{0: 'batch_size', 2: 'past_sequence_length'}, {0: 'batch_size', 2: 'past_sequence_length'}], 'input_ids': {0: 'batch_size', 1: 'sequence_length'}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(cache_position:T7r1,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),input_ids:T7r2)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'cache_position': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'input_ids': {0: 'd=[1]', 1: 'd=[1]'}}
[torch_export] export starts with backed_size_oblivious={'cache_position': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'input_ids': {0: 'd=[1]', 1: 'd=[1]'}}
[ExportOptions.export] export done in 2.819908008998027
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 9 slices nodes were removed
[CustomTracer.remove_inplace] starts with 131 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 2 inplace nodes
[CustomTracer.remove_inplace] S2: 2 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 123 nodes (n_inplace=2)
[ExportOptions.export] inplaces: 2 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.004581423010677099, modified=2
[ExportOptions.export] done with no decomposition in 0.0047416860033990815
[to_onnx] graph module done in 3.2434142760030227 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
0%| | 0/123 [00:00<?, ?it/s]
56%|█████▌ | 69/123 [00:00<00:00, 680.73it/s]
100%|██████████| 123/123 [00:00<00:00, 570.89it/s]
[to_onnx] 165 onnx nodes done in 0.3251985069946386 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True]
[GraphBuilder-QRW.inline_functions] begin inlining graph
[GraphBuilder-QRW.inline_functions] skip_functions=set()
[GraphBuilder-QRW._inline_functions_iterations] inline function 'submod_3' domain 'local_functions' [n_replacements=1]
[GraphBuilder-QRW._inline_functions_iterations] done with 9 new nodes for 'submod_3', 'local_functions'
[GraphBuilder-QRW.inline_functions] done inlining graph 134160969948560 in 0.0065539570059627295
[GraphBuilder-QRW._add_shape_information] dynamic shapes replacements={'total_sequence_length': 'total_sequence_length', 'batch_size': 'batch_size', 'sequence_length': 'sequence_length', 'past_sequence_length': 'past_sequence_length', 's58': 'total_sequence_length', 's72': 'batch_size', 's67': 'batch_size', 's61': 'batch_size', 'batch_size^s61^batch_size^s61': 'batch_size', 's70': 'sequence_length', 's21': 'past_sequence_length', 's43': 'past_sequence_length'}
[GraphBuilder-QRW.optimize] start with 173 nodes
[GraphBuilder-QRW.optimize] #patterns=109
[GraphBuilder-QRW.optimize] start with subgraphs
[GraphBuilder-QRW.optimize] done with subgraphs
[GraphBuilderPatternOptimization-QRW.optimize] start with 143 nodes, 31 initializers, 109 patterns, priorities=[0, 1, 2, 3], max_iter=572
[GraphBuilderPatternOptimization-QRW.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-QRW.optimize] iteration 0: 143 nodes, priority=0
[GraphBuilderPatternOptimization-QRW.optimize] applies 26 matches, 8*CastPattern, 1*IdentityPattern, 3*ShapeBasedReshapeIsSqueezePattern, 1*ShapeBasedStaticExpandPattern, 2*ShapeBasedEditDistanceReshapePattern, 4*SameChildrenPattern, 1*SameChildrenFromInputPattern, 2*SqueezeAddPattern, 1*SqueezeUnsqueezePattern, 2*UnsqueezeUnsqueezePattern, 1*FunctionAttentionPattern - time=0.023 | max_time=GeluErfPattern:0.003
[GraphBuilderPatternOptimization-QRW.optimize] reapply {'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-QRW.optimize] n_added=6, n_removed=8, n_applied=28 applied patterns, 105 nodes left with 2 iterations
[GraphBuilderPatternOptimization-QRW.optimize] increase priority to 1
[GraphBuilderPatternOptimization-QRW.optimize] iteration 1: 105 nodes, priority=1
[GraphBuilderPatternOptimization-QRW.optimize] applies 19 matches, 2*ConcatTwiceUnaryPattern, 1*ConstantToInitializerPattern, 1*ShapeBasedConcatExpandPattern, 2*SlicesSplitPattern, 1*SqueezeBinaryUnsqueezePattern, 4*SqueezeUnsqueezePattern, 2*SwapUnsqueezeTransposePattern, 2*UnsqueezeUnsqueezePattern, 1*QuickGeluPattern, 3*SimplifiedLayerNormalizationPattern - time=0.020 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization-QRW.optimize] iteration 2: 79 nodes, priority=1
[GraphBuilderPatternOptimization-QRW.optimize] applies 8 matches, 2*IdentityPattern, 1*UnsqueezeUnsqueezePattern, 2*FunctionHalfRotaryEmbeddingPattern, 3*SimplifiedLayerNormalizationMulPattern - time=0.019 | max_time=SoftmaxCrossEntropyLossCastPattern:0.003
[GraphBuilderPatternOptimization-QRW.optimize] iteration 3: 63 nodes, priority=1
[GraphBuilderPatternOptimization-QRW.optimize] applies 2 matches, 2*SkipSimplifiedLayerNormalizationPattern - time=0.009 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-QRW.optimize] iteration 4: 61 nodes, priority=1
[GraphBuilderPatternOptimization-QRW.optimize] increase priority to 2
[GraphBuilderPatternOptimization-QRW.optimize] iteration 5: 61 nodes, priority=2
[GraphBuilderPatternOptimization-QRW.optimize] increase priority to 3
[GraphBuilderPatternOptimization-QRW.optimize] iteration 6: 61 nodes, priority=3
[GraphBuilderPatternOptimization-QRW.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-QRW.optimize] done after 7 iterations with 61 nodes in 0.168
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 61 nodes, 29 initializers
[OrderOptimization.shape_order] done after in 0.0006044109904905781s with changed=4 scale=29
[GraphBuilder-QRW.optimize] done with 61 nodes in 0.196
[GraphBuilder-QRW.to_onnx] make_model 33 inits 12 params
[GraphBuilder-QRW.time_evaluation_constants_] 0.0007528060086769983
[GraphBuilder-QRW._build_initializers] start with 33 initializers, large_model=True, external_threshold=1024
[GraphBuilder-QRW._build_initializers] switch low/high order
[GraphBuilder-QRW._build_initializers] done in 3.9000005926936865e-06s with 29 initializers, 9 large initializers
[GraphBuilder-QRW._add_shape_information] dynamic shapes replacements={'total_sequence_length': 'total_sequence_length', 'batch_size': 'batch_size', 'sequence_length': 'sequence_length', 'past_sequence_length': 'past_sequence_length', 's58': 'total_sequence_length', 's72': 'batch_size', 's67': 'batch_size', 's61': 'batch_size', 'batch_size^s61^batch_size^s61': 'batch_size', 's70': 'sequence_length', 's21': 'past_sequence_length', 's43': 'past_sequence_length'}
[to_onnx] to_onnx done in 0.22395322599913925s and 61 nodes, 29 initializers, 4 inputs, 3 outputs
Continue: it rains...
I got to be sooo!
You have to come here and try out with your kids - if you have anything like this....
Now just go see the kids......
I do
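After the third call to forward, the ONNX file exists on disk. A minimal sanity check with onnxruntime (assuming it is installed); the input and output names are read from the model itself rather than assumed, since the exporter may rename the past_key_values tensors:

import onnxruntime

sess = onnxruntime.InferenceSession(
    filename, providers=["CPUExecutionProvider"]
)
# the log above reports 4 inputs and 3 outputs
print([i.name for i in sess.get_inputs()])
print([o.name for o in sess.get_outputs()])

To restore the original behavior afterwards, the instance attribute can simply be removed with del model.forward, which un-shadows the class-level method (a plain Python mechanism, not part of onnx_diagnostic).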
doc.plot_legend("Tiny-LLM\nforward inputs\nthrough generate", "onnx export", "tomato")

Total running time of the script: (0 minutes 13.709 seconds)
Related examples
Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)