Export an LLM through method generate (with Tiny-LLM)

The main issue when exporting an LLM is that the example on HuggingFace relies on method generate, whereas only the forward method needs to be exported. Example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) explains how to guess dummy inputs and dynamic shapes for that purpose. Let’s see how to simplify that.
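The idea this example builds on can be sketched without any real model: generate is essentially a Python loop around forward, so replacing forward on the instance is enough to observe every set of inputs it receives. ToyLM and spying_forward are hypothetical names used only for this illustration; they are not part of transformers or onnx_diagnostic.

```python
class ToyLM:
    """Hypothetical stand-in for a causal LM, not a transformers class."""

    def forward(self, input_ids, past_length=0):
        # A real model returns logits and an updated cache; this toy returns
        # a fake next token and the new cache length.
        return sum(input_ids) % 100, past_length + len(input_ids)

    def generate(self, input_ids, max_new_tokens=3):
        # generate is a Python loop around forward: one call for the prompt
        # (prefill, no cache yet), then one call per additional token.
        tokens = list(input_ids)
        next_token, past_length = self.forward(input_ids=tokens)
        tokens.append(next_token)
        for _ in range(max_new_tokens - 1):
            next_token, past_length = self.forward(
                input_ids=[next_token], past_length=past_length
            )
            tokens.append(next_token)
        return tokens


model = ToyLM()
original_forward = model.forward  # bound method, kept before the replacement
captured = []


def spying_forward(*args, **kwargs):
    # Record (sequence length, cache length) for each call, then delegate.
    captured.append((len(kwargs["input_ids"]), kwargs.get("past_length", 0)))
    return original_forward(*args, **kwargs)


model.forward = spying_forward  # the instance attribute shadows the method
output = model.generate([3, 1, 4])
print(captured)  # [(3, 0), (1, 3), (1, 4)]
```

The first recorded call has no cache, which is why a single call does not describe the inputs forward receives while decoding.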

Dummy Example

Let’s use the example provided on arnir0/Tiny-LLM.

import pandas
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import method_to_onnx


MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


def generate_text(
    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


# Define your prompt
prompt = "Continue: it rains..."
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
print("-----------------")
-----------------
Continue: it rains... Continue
Posted 41:34 PM - March 26, 2012 | Comments
2000-16-06-0202
-----------------

Replace forward method

We now prepare the export by replacing the forward method of the model. We still call method generate, but it now calls a different function created by onnx_diagnostic.export.api.method_to_onnx(). This function captures the inputs of the forward method: at least 2 calls are needed, and 3 are recommended for LLMs because the first call does not contain any cache. If the default settings do not work, skip_kwargs_names and dynamic_shapes can be changed to remove undesired inputs or to add more dynamic dimensions.
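Why several calls are needed can be sketched in plain Python: an axis is treated as dynamic when its size differs between calls. guess_dynamic_dims is a hypothetical helper, not the function used by onnx_diagnostic, and past_key stands for the key tensor of the cache; the shapes are those printed in the verbose log later in this example.

```python
def guess_dynamic_dims(shapes_per_call):
    """Return, for each input, the axes whose size changes across calls.

    ``shapes_per_call`` is a list of dicts mapping an input name to a shape
    tuple. An input missing from a call (the cache during prefill) is
    skipped for that call.
    """
    names = {name for call in shapes_per_call for name in call}
    dynamic = {}
    for name in sorted(names):
        shapes = [call[name] for call in shapes_per_call if name in call]
        rank = len(shapes[0])
        # An axis is dynamic if at least two calls disagree on its size.
        dynamic[name] = {
            axis for axis in range(rank) if len({s[axis] for s in shapes}) > 1
        }
    return dynamic


# Shapes of the three calls to forward logged by method_to_onnx.
calls = [
    {"input_ids": (1, 8), "attention_mask": (1, 8), "cache_position": (8,)},
    {"input_ids": (1, 1), "attention_mask": (1, 9), "cache_position": (1,),
     "past_key": (1, 1, 8, 96)},
    {"input_ids": (1, 1), "attention_mask": (1, 10), "cache_position": (1,),
     "past_key": (1, 1, 9, 96)},
]
print(guess_dynamic_dims(calls))
# With two calls only, the cache is seen once and nothing is detected for it.
print(guess_dynamic_dims(calls[:2])["past_key"])  # set()
```

Axis 0 (the batch) is never detected because every call uses a batch size of 1, which is the reason for parameter dynamic_batch_for.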

filename = "plot_export_tiny_llm_method_generate.onnx"
forward_replacement = method_to_onnx(
    model,
    method_name="forward",  # default value
    exporter="custom",  # or "onnx-dynamo" to use the official exporter
    filename=filename,  # onnx file to create
    patch_kwargs=dict(patch_transformers=True),  # patches applied before exporting
    # verbose=1 displays the progress; it is recommended on the first try
    # to see whether ``skip_kwargs_names`` and ``dynamic_shapes`` need to be set
    verbose=1,
    # triggers the ONNX conversion after 3 calls to the forward method;
    # the export is run with the inputs of the last call, and the
    # collected calls are used to infer the dynamic shapes if they
    # are not given with parameter dynamic_shapes
    convert_after_n_calls=3,
    # The input used in the example has a batch size equal to 1, so all
    # inputs going through method forward have the same batch size.
    # To make this dimension dynamic, we need to indicate
    # which inputs have a batch dimension.
    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
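What dynamic_batch_for changes can be illustrated on plain dictionaries. add_batch_dimension is a hypothetical helper written for this sketch; the input reuses the dimension names from the verbose log of this example, and the result mirrors the "dynamic_shapes with batch" line printed there. cache_position is not listed because its axis 0 is the sequence length, not the batch.

```python
def add_batch_dimension(dynamic_shapes, dynamic_batch_for, batch_name="batch"):
    """Return a copy of ``dynamic_shapes`` where axis 0 is dynamic for the
    inputs listed in ``dynamic_batch_for``."""
    result = {}
    for name, axes in dynamic_shapes.items():
        add_batch = name in dynamic_batch_for
        if isinstance(axes, list):
            # A cache is described by a list of dicts, one per tensor.
            result[name] = [
                {0: batch_name, **a} if add_batch else dict(a) for a in axes
            ]
        else:
            result[name] = {0: batch_name, **axes} if add_batch else dict(axes)
    return result


# Dimensions guessed from the calls: axis 0 was never detected because
# the batch size was always 1.
guessed = {
    "input_ids": {1: "seqlength"},
    "attention_mask": {1: "totallength"},
    "past_key_values": [{2: "pastlength"}, {2: "pastlength"}],
    "cache_position": {0: "seqlength"},
}
with_batch = add_batch_dimension(
    guessed, {"input_ids", "attention_mask", "past_key_values"}
)
print(with_batch)
```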

Dynamic shapes can be inferred from at least two calls to the forward method; three are better for LLMs because the first call (prefill) has no cache. With verbose=1, the inferred shapes are printed. If they are not the expected ones (or to change the dimension names, for example), they can be overridden:

# to be passed with method_to_onnx(..., dynamic_shapes=dynamic_shapes)
dynamic_shapes = {
    "cache_position": {0: "sequence_length"},
    # one dict per cache tensor: keys and values of the only layer
    "past_key_values": [
        {0: "batch_size", 2: "past_sequence_length"},
        {0: "batch_size", 2: "past_sequence_length"},
    ],
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
}
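These names are not independent: with a cache, the attention mask covers the cached tokens plus the new ones, so total_sequence_length equals past_sequence_length + sequence_length. A quick check on the shapes logged in this example, with a hypothetical helper lengths (past_key stands for the key tensor of the cache):

```python
def lengths(step):
    """Return (sequence_length, past_sequence_length, total_sequence_length)."""
    sequence_length = step["input_ids"][1]
    # During prefill there is no cache yet, so the past length is 0.
    past_sequence_length = step["past_key"][2] if "past_key" in step else 0
    total_sequence_length = step["attention_mask"][1]
    return sequence_length, past_sequence_length, total_sequence_length


# Shapes of the three calls to forward logged by method_to_onnx.
calls = [
    {"input_ids": (1, 8), "attention_mask": (1, 8)},
    {"input_ids": (1, 1), "attention_mask": (1, 9), "past_key": (1, 1, 8, 96)},
    {"input_ids": (1, 1), "attention_mask": (1, 10), "past_key": (1, 1, 9, 96)},
]
for call in calls:
    seq, past, total = lengths(call)
    print(seq, past, total, total == past + seq)
```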

Finally, we need to replace the forward method. Since forward_replacement is a module of type onnx_diagnostic.export.api.WrapperToExportMethodToOnnx, assigning it directly to model.forward would register it as a submodule of the model and create an infinite loop. A lambda function avoids that.

print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
type(forward_replacement)=<class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx'>
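Why the lambda matters can be sketched with a tiny imitation of how torch.nn.Module handles attributes: assigning a module to an attribute registers it as a child, whereas a lambda is stored as a plain attribute. MiniModule and Wrapper are hypothetical classes for this illustration; the real torch.nn.Module.__setattr__ also handles parameters and buffers.

```python
class MiniModule:
    """Minimal imitation of submodule registration in torch.nn.Module."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            self._modules[name] = value  # registered as a child module
        else:
            object.__setattr__(self, name, value)  # plain attribute

    def __getattr__(self, name):
        # Only called when the normal lookup fails.
        modules = object.__getattribute__(self, "_modules")
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Wrapper(MiniModule):
    """Hypothetical wrapper keeping a reference to the model it replaces."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # the wrapper owns the model

    def __call__(self, *args, **kwargs):
        return "called"


model = MiniModule()
wrapper = Wrapper(model)
model.forward = wrapper  # direct assignment: the model now owns the wrapper
print(list(model._modules))  # ['forward']: cycle model -> wrapper -> model

other = MiniModule()
other_wrapper = Wrapper(other)
other.forward = lambda *args, **kwargs: other_wrapper(*args, **kwargs)
print(list(other._modules))  # []: no cycle
```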

Let’s call generate again. The conversion is triggered after convert_after_n_calls=3 calls to method forward, which is exactly what method generate does: it calls forward once for the prompt, then once per generated token.

generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)
[method_to_onnx] input[0]: ((),dict(input_ids:T7s1x8,attention_mask:T7s1x8,cache_position:T7s8))
[method_to_onnx] input[1]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x9,past_key_values:DynamicCache(key_cache=#1[T1s1x1x8x96], value_cache=#1[T1s1x1x8x96]),cache_position:T7s1))
[method_to_onnx] input[2]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x10,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),cache_position:T7s1))
[method_to_onnx] save 3 inputs in 'plot_export_tiny_llm_method_generate.inputs.pt'
[method_to_onnx] guess_dynamic_shapes=((),dict(input_ids:{1:DYNAMIC},attention_mask:{1:DYNAMIC},past_key_values:#2[{2:DYNAMIC},{2:DYNAMIC}],cache_position:{0:DYNAMIC}))
[method_to_onnx.rename_dynamic_shapes] apply pattern shapes 'LLM.text'
[method_to_onnx] dynamic_batch_for={'attention_mask', 'input_ids', 'past_key_values'}
[method_to_onnx] dynamic_shapes with batch=((), {'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}})
[method_to_onnx] export args=()
[method_to_onnx] export kwargs=dict(input_ids:T7s1x1,attention_mask:T7s1x10,past_key_values:DynamicCache(key_cache=#1[T1s1x1x9x96], value_cache=#1[T1s1x1x9x96]),cache_position:T7s1)
[method_to_onnx] dynamic_shapes=((),dict(input_ids:{0:DYN(batch),1:DYN(seqlength)},attention_mask:{0:DYN(batch),1:DYN(totallength)},past_key_values:#2[{0:DYN(batch),2:DYN(pastlength)},{0:DYN(batch),2:DYN(pastlength)}],cache_position:{0:DYN(seqlength)}))
[to_onnx] build the graph module from <class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx._convert_method_to_onnx.<locals>.WrapWithExactSignature'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[_make_builder_interpreter] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'WrapWithExactSignature'
[ExportOptions.export] aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'input_ids': {0: 'd=[1]', 1: 'd=[1]'}, 'attention_mask': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'cache_position': {0: 'd=[1]'}}
[torch_export] export starts with backed_size_oblivious={'input_ids': {0: 'd=[1]', 1: 'd=[1]'}, 'attention_mask': {0: 'd=[1]'}, 'past_key_values': [{0: 'd=[1]'}, {0: 'd=[1]'}], 'cache_position': {0: 'd=[1]'}}
[ExportOptions.export] export done in 1.637513773000137
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 9 slices nodes were removed
[CustomTracer.remove_inplace] starts with 170 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 13 inplace nodes
[CustomTracer.remove_inplace] S2: 6 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 141 nodes (n_inplace=6)
[ExportOptions.export] inplaces: 13 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.004672862000006717, modified=13
[ExportOptions.export] done with no decomposition in 0.004789915000401379
[to_onnx] graph module done in 1.6521961380003631 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)

100%|██████████| 141/141 [00:00<00:00, 1137.78it/s]
[to_onnx] 203 onnx nodes done in 0.20099439499972505 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True]
[GraphBuilder-QOM.inline_functions] begin inlining graph
[GraphBuilder-QOM.inline_functions] skip_functions=set()
[GraphBuilder-QOM._inline_functions_iterations] inline function 'submod_3' domain 'local_functions' [n_replacements=1]
[GraphBuilder-QOM._inline_functions_iterations] done with 9 new nodes for 'submod_3', 'local_functions'
[GraphBuilder-QOM.inline_functions] done inlining graph 128532838118320 in 0.0038044259999878705
[GraphBuilder-QOM._add_shape_information] dynamic shapes replacements={'seqlength': 'seqlength', 'batch': 'batch', 'pastlength': 'pastlength', 'totallength': 'totallength', 's61': 'batch', 's43': 'batch', 'batch^s61^batch^s61': 'batch', 's67': 'batch', 's72': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's44': 'pastlength', 's21': 'pastlength'}
[GraphBuilder-QOM.optimize] start with 211 nodes
[GraphBuilder-QOM.optimize] #patterns=109
[GraphBuilder-QOM.optimize] start with subgraphs
[GraphBuilder-QOM.optimize] done with subgraphs
[GraphBuilderPatternOptimization-QOM.optimize] start with 161 nodes, 32 initializers, 109 patterns, priorities=[0, 1, 2, 3], max_iter=644
[GraphBuilderPatternOptimization-QOM.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-QOM.optimize] iteration 0: 161 nodes, priority=0
[GraphBuilderPatternOptimization-QOM.optimize] applies 33 matches, 11*CastPattern, 2*IdentityPattern, 3*ShapeBasedReshapeIsSqueezePattern, 2*ShapeBasedStaticExpandPattern, 3*ShapeBasedEditDistanceReshapePattern, 4*SameChildrenPattern, 1*SameChildrenFromInputPattern, 2*SqueezeAddPattern, 1*SqueezeUnsqueezePattern, 3*UnsqueezeUnsqueezePattern, 1*FunctionAttentionPattern - time=0.013 | max_time=GeluOrtPattern:0.002
[GraphBuilderPatternOptimization-QOM.optimize] reapply {'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-QOM.optimize] n_added=6, n_removed=8, n_applied=35 applied patterns, 110 nodes left with 2 iterations
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 1
[GraphBuilderPatternOptimization-QOM.optimize] iteration 1: 110 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 16 matches, 2*ConcatTwiceUnaryPattern, 1*ConstantToInitializerPattern, 1*IdentityPattern, 2*SlicesSplitPattern, 1*SqueezeBinaryUnsqueezePattern, 2*SwapUnsqueezeTransposePattern, 3*UnsqueezeUnsqueezePattern, 1*QuickGeluPattern, 3*SimplifiedLayerNormalizationPattern - time=0.016 | max_time=GeluOrtPattern:0.002
[GraphBuilderPatternOptimization-QOM.optimize] iteration 2: 87 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 8 matches, 2*IdentityPattern, 1*UnsqueezeUnsqueezePattern, 2*FunctionHalfRotaryEmbeddingPattern, 3*SimplifiedLayerNormalizationMulPattern - time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-QOM.optimize] iteration 3: 71 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] applies 2 matches, 2*SkipSimplifiedLayerNormalizationPattern - time=0.007 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-QOM.optimize] iteration 4: 69 nodes, priority=1
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 2
[GraphBuilderPatternOptimization-QOM.optimize] iteration 5: 69 nodes, priority=2
[GraphBuilderPatternOptimization-QOM.optimize] increase priority to 3
[GraphBuilderPatternOptimization-QOM.optimize] iteration 6: 69 nodes, priority=3
[GraphBuilderPatternOptimization-QOM.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-QOM.optimize] done after 7 iterations with 69 nodes in 0.114
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 69 nodes, 31 initializers
[OrderOptimization.shape_order] done after in 0.00037146499926166143s with changed=5 scale=25
[GraphBuilder-QOM.optimize] done with 69 nodes in 0.131
[GraphBuilder-QOM.to_onnx] make_model 35 inits 12 params
[GraphBuilder-QOM.time_evaluation_constants_] 0.0009111650006161653
[GraphBuilder-QOM._build_initializers] start with 35 initializers, large_model=True, external_threshold=1024
[GraphBuilder-QOM._build_initializers] switch low/high order
[GraphBuilder-QOM._build_initializers] done in 2.322000000276603e-06s with 31 initializers, 9 large initializers
[GraphBuilder-QOM._add_shape_information] dynamic shapes replacements={'seqlength': 'seqlength', 'batch': 'batch', 'pastlength': 'pastlength', 'totallength': 'totallength', 's61': 'batch', 's43': 'batch', 'batch^s61^batch^s61': 'batch', 's67': 'batch', 's72': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's44': 'pastlength', 's21': 'pastlength'}
[to_onnx] to_onnx done in 0.1486357189996852s and 69 nodes, 31 initializers, 5 inputs, 3 outputs
[method_to_onnx] save 3 outputs in 'plot_export_tiny_llm_method_generate.outputs.pt'
Continue: it rains...
We have been in crisis
Tarally I have been working on the first week at the first one.
Sat: “Besides the next day we will
them all the best

Finally, we need to check the discrepancies. The export produced an ONNX file and dumped the inputs and outputs of the torch model. We now run the ONNX model to check that it produces the same results. This is done afterwards because the model may not fit twice in memory (once for torch, once for onnxruntime). verbose=2 shows more information about the expected outputs.

data = forward_replacement.check_discrepancies(verbose=1)
df = pandas.DataFrame(data)
print(df)
[method_to_onnx.check_discrepancies] register classes [<class 'transformers.cache_utils.DynamicCache'>, <class 'transformers.cache_utils.DynamicLayer'>, <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>, <class 'torch.dtype'>]
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.inputs.pt'
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.outputs.pt'
[method_to_onnx.check_discrepancies] create onnx session 'plot_export_tiny_llm_method_generate.onnx'
[method_to_onnx.check_discrepancies] input_names=['input_ids', 'attention_mask', 'past_key_values_key_0', 'past_key_values_value_0', 'cache_position']
[method_to_onnx.check_discrepancies] onnx_shapes=INT64[batch,seqlength], INT64[batch,totallength], FLOAT[batch,1,pastlength,96], FLOAT[batch,1,pastlength,96], INT64[seqlength]
[method_to_onnx.check_discrepancies] process input 0 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 1 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 2 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] done
        abs       rel       sum         n  dnan  dev  >0.1  >0.01  SUCCESS  index  duration_torch  ort_duration  n_inputs
0  0.000016  0.003386  0.482308  257536.0     0    0     0      0     True      0        0.002885      1.846359         5
1  0.000013  0.003116  0.106110   33728.0     0    0     0      0     True      1        0.002364      0.001560         5
2  0.000010  0.001086  0.051389   33920.0     0    0     0      0     True      2        0.005172      0.001325         5
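The columns of this table can be read with a small sketch of such a comparison. The function discrepancies is hypothetical and its metric definitions (in particular rel and dnan) are assumptions; the exact formulas used by check_discrepancies may differ.

```python
import math


def discrepancies(expected, got, thresholds=(0.1, 0.01), eps=1e-8):
    """Summarize the differences between two flat sequences of floats.

    Assumed definitions: abs = max |e - g|, rel = max |e - g| / (|e| + eps),
    sum = sum of |e - g|, n = number of values, dnan = number of positions
    where exactly one value is NaN, ">t" = number of |e - g| greater than t.
    """
    if len(expected) != len(got):
        raise ValueError("expected and got must have the same length")
    diffs, rels, dnan = [], [], 0
    for e, g in zip(expected, got):
        if math.isnan(e) or math.isnan(g):
            # NaN on one side only is a discrepancy, NaN on both is a match.
            dnan += math.isnan(e) != math.isnan(g)
            continue
        d = abs(e - g)
        diffs.append(d)
        rels.append(d / (abs(e) + eps))
    result = {
        "abs": max(diffs, default=0.0),
        "rel": max(rels, default=0.0),
        "sum": sum(diffs),
        "n": float(len(expected)),
        "dnan": dnan,
    }
    for t in thresholds:
        result[f">{t}"] = sum(d > t for d in diffs)
    return result


expected = [1.0, 2.0, -0.5]
print(discrepancies(expected, [1.0, 2.00002, -0.5001]))  # small differences
print(discrepancies(expected, [1.2, 2.0, -0.5]))  # one value off by 0.2
```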
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
[Figure: plot export tiny llm method generate, graph of the exported ONNX model]

Total running time of the script: (0 minutes 11.585 seconds)

Related examples

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Export Tiny-LLM with patches

Export microsoft/phi-2
