Export an LLM through method generate (with Tiny-LLM)
The main issue when exporting an LLM is that the example on HuggingFace is based on method generate, but we only need to export the forward method. Example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) explains how to guess dummy inputs and dynamic shapes for that purpose. Let’s see how to simplify that.
Dummy Example
Let’s use the example provided on arnir0/Tiny-LLM.
import pandas
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import method_to_onnx
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def generate_text(
    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
# Define your prompt
prompt = "Continue: it rains, what should I do?"
generated_text = generate_text(prompt, model, tokenizer)
print("-----------------")
print(generated_text)
print("-----------------")
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 278.84it/s, Materializing param=model.norm.weight]
-----------------
Continue: it rains, what should I do?
- I have the right amount of the rats for all sides.
- In the back of the 17th, 14.
- If you have been
-----------------
Replace forward method
We now modify the model so that it exports itself by replacing its forward method.
We still call method generate, but generate now calls a different function
created by onnx_diagnostic.export.api.method_to_onnx().
This function captures the inputs of the forward method. At least 2 calls are needed,
and 3 are recommended for LLMs because the first call does not contain any cache.
If the default settings do not work, skip_kwargs_names and dynamic_shapes
can be changed to remove some undesired inputs or to add more dynamic dimensions.
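The capture mechanism can be pictured with a toy wrapper. This is a minimal sketch under our own assumptions and not the implementation of method_to_onnx. The class RecordingForward, its on_ready callback and its handling of skip_kwargs_names are hypothetical, and plain Python lists stand in for tensors.

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple

Call = Tuple[tuple, Dict[str, Any]]


class RecordingForward:
    """Hypothetical toy wrapper: records the inputs of the first n_calls calls,
    then hands them to on_ready (where an exporter would run)."""

    def __init__(
        self,
        method: Callable[..., Any],
        n_calls: int,
        on_ready: Callable[[List[Call]], None],
        skip_kwargs_names: Iterable[str] = (),
    ):
        self.method = method
        self.n_calls = n_calls
        self.on_ready = on_ready
        self.skip = set(skip_kwargs_names)
        self.recorded: List[Call] = []
        self.done = False

    def __call__(self, *args, **kwargs):
        if not self.done:
            # keep every keyword argument except the ones to ignore
            kept = {k: v for k, v in kwargs.items() if k not in self.skip}
            self.recorded.append((args, kept))
            if len(self.recorded) == self.n_calls:
                # enough calls: hand over everything recorded so far
                self.on_ready(self.recorded)
                self.done = True
        # the original method still produces the result
        return self.method(*args, **kwargs)


def fake_forward(input_ids, attention_mask, use_cache=True):
    return len(input_ids)


ready = []
wrapped = RecordingForward(fake_forward, n_calls=3, on_ready=ready.append,
                           skip_kwargs_names={"use_cache"})
for step in range(5):
    # prefill sees 13 tokens, then one token per step
    wrapped(input_ids=[0] * (13 if step == 0 else 1),
            attention_mask=[1] * (13 + step), use_cache=True)
print(len(ready), len(ready[0]))  # 1 3
```

The wrapper stays transparent for the caller. Every call still reaches the original method, so generation is unaffected by the recording.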
filename = "plot_export_tiny_llm_method_generate.custom.onnx"
forward_replacement = method_to_onnx(
    model,
    method_name="forward",  # default value
    exporter="custom",  # use "onnx-dynamo" for the official exporter
    filename=filename,  # onnx file to create
    patch_kwargs=dict(patch_transformers=True),  # patches applied before exporting
    # verbose=1 shows the progress; it is recommended on the first try to see
    # whether ``skip_kwargs_names`` and ``dynamic_shapes`` need to be set
    verbose=1,
    # triggers the ONNX conversion after 3 calls to the forward method:
    # the export is run with the inputs of the last call,
    # and the recorded calls are used to infer the dynamic shapes
    # if they are not given with ``dynamic_shapes``
    convert_after_n_calls=3,
    # The inputs used in this example have a batch size equal to 1, and all
    # inputs going through the forward method have the same batch size.
    # To make this dimension dynamic, we need to indicate
    # which inputs have a batch dimension.
    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
    # Earlier versions of pytorch did not accept a dynamic batch size equal to 1;
    # this last parameter expands the listed inputs if their batch size is 1.
    # The exporter should work without it.
    expand_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
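The two batch-related parameters can be illustrated on the shape specifications printed in the log of the generate call below. The helpers add_batch_dim and expand_batch are hypothetical and are not part of onnx_diagnostic.

- add_batch_dim adds axis 0, named "batch", to the inputs listed in dynamic_batch_for.
- expand_batch replaces a batch size of 1 by 2. This is one way to avoid torch.export specializing a dimension whose example size is 1.

```python
from typing import Dict, Iterable, List, Tuple, Union

Spec = Union[Dict[int, str], List[Dict[int, str]]]


def add_batch_dim(dynamic_shapes: Dict[str, Spec], dynamic_batch_for: Iterable[str],
                  name: str = "batch") -> Dict[str, Spec]:
    """Hypothetical helper: declares axis 0 dynamic for the listed inputs."""
    selected = set(dynamic_batch_for)
    result: Dict[str, Spec] = {}
    for key, spec in dynamic_shapes.items():
        if key not in selected:
            result[key] = spec
        elif isinstance(spec, list):
            # a cache holds one spec per tensor, each one gets a batch axis
            result[key] = [{0: name, **s} for s in spec]
        else:
            result[key] = {0: name, **spec}
    return result


def expand_batch(shape: Tuple[int, ...], new_batch: int = 2) -> Tuple[int, ...]:
    """Hypothetical helper: replaces a leading batch size of 1 by new_batch."""
    if shape and shape[0] == 1:
        return (new_batch, *shape[1:])
    return shape


# dynamic shapes as guessed from the calls, batch axis not yet dynamic
guessed = {
    "input_ids": {1: "seqlength"},
    "attention_mask": {1: "totallength"},
    "past_key_values": [{2: "pastlength"}, {2: "pastlength"}],
    "cache_position": {0: "seqlength"},
}
with_batch = add_batch_dim(guessed, {"input_ids", "attention_mask", "past_key_values"})
print(with_batch["input_ids"])       # {0: 'batch', 1: 'seqlength'}
print(expand_batch((1, 1, 14, 96)))  # (2, 1, 14, 96)
```

cache_position is left untouched because it is not listed in dynamic_batch_for and has no batch axis.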
Dynamic shapes can be inferred from at least two calls to the forward method.
Three are better for LLMs because the first call (prefill) has no cache.
verbose=1 shows the inferred values.
If they are not the expected ones (for example, to change the dimension names),
they can be overwritten with parameter dynamic_shapes:
dynamic_shapes={
    "cache_position": {0: "sequence_length"},
    "past_key_values": [
        {0: "batch_size", 2: "past_sequence_length"},
        {0: "batch_size", 2: "past_sequence_length"},
    ],
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "total_sequence_length"},
}
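The inference of dynamic dimensions from several calls can be sketched as follows. The sketch assumes only shapes are compared: any axis whose size changes between calls becomes dynamic.

- guess_dynamic_dims is a hypothetical helper.
- The shapes are those logged for the three calls of this example.
- The cache is reduced to its key tensor, here named past_key.

The sketch shows why three calls are recommended. The cache is absent from the prefill call, so with only two calls it is observed once and nothing can be inferred for it.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def guess_dynamic_dims(
    calls: Sequence[Dict[str, Shape]]
) -> Dict[str, Optional[Dict[int, str]]]:
    """Hypothetical helper: an axis is dynamic if its size varies between calls.

    Returns None for an input seen fewer than twice (nothing to compare).
    """
    names: List[str] = []
    for call in calls:
        for name in call:
            if name not in names:
                names.append(name)
    result: Dict[str, Optional[Dict[int, str]]] = {}
    for name in names:
        seen = [call[name] for call in calls if name in call]
        if len(seen) < 2:
            result[name] = None
            continue
        rank = len(seen[0])
        if any(len(s) != rank for s in seen):
            raise ValueError(f"the rank of {name!r} changes between calls")
        result[name] = {
            axis: "DYNAMIC" for axis in range(rank) if len({s[axis] for s in seen}) > 1
        }
    return result


# shapes logged for the three calls of this example (key cache only)
calls = [
    {"input_ids": (1, 13), "attention_mask": (1, 13), "cache_position": (13,)},
    {"input_ids": (1, 1), "attention_mask": (1, 14), "past_key": (1, 1, 13, 96),
     "cache_position": (1,)},
    {"input_ids": (1, 1), "attention_mask": (1, 15), "past_key": (1, 1, 14, 96),
     "cache_position": (1,)},
]
print(guess_dynamic_dims(calls))
print(guess_dynamic_dims(calls[:2])["past_key"])  # None
```

The batch axis is never detected because it stays equal to 1 in every call. This is what dynamic_batch_for compensates for.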
Finally, we need to replace the forward method.
As forward_replacement is a module of type
onnx_diagnostic.export.api.WrapperToExportMethodToOnnx,
a lambda function must be used to prevent it from being registered
as a submodule of the model (which would create an infinite loop).
print(f"type(forward_replacement)={type(forward_replacement)}")
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
type(forward_replacement)=<class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx'>
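In PyTorch, assigning an nn.Module instance as an attribute of another module registers it as a submodule, because torch.nn.Module.__setattr__ stores it in _modules. A lambda is a plain function and is not registered.

The toy classes below imitate that rule without torch. ToyModule, ToyWrapper and has_cycle are hypothetical. They show the cycle a direct assignment would create: the wrapper holds the model and the model would hold the wrapper. Any recursive walk over children that does not track visited modules would then never end.

```python
class ToyModule:
    """Hypothetical, much simplified imitation of torch.nn.Module attribute
    handling: a ToyModule assigned as an attribute becomes a registered child."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, ToyModule):
            self._modules[name] = value  # registered as a submodule
        object.__setattr__(self, name, value)

    def children(self):
        return list(self._modules.values())


class ToyWrapper(ToyModule):
    """Plays the role of the wrapper: it keeps a reference to the model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, *args, **kwargs):
        return "wrapped forward"


def has_cycle(module, path=None):
    """Depth-first search: True if a module is its own (indirect) child."""
    path = set() if path is None else path
    if id(module) in path:
        return True
    path.add(id(module))
    found = any(has_cycle(child, path) for child in module.children())
    path.discard(id(module))
    return found


direct = ToyModule()
wrapper = ToyWrapper(direct)
direct.forward = wrapper  # the wrapper becomes a child of the model
print(has_cycle(direct))  # True

with_lambda = ToyModule()
wrapper2 = ToyWrapper(with_lambda)
with_lambda.forward = lambda *args, **kwargs: wrapper2(*args, **kwargs)
print(has_cycle(with_lambda), with_lambda.forward())  # False wrapped forward
```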
Let’s call generate again. The conversion is triggered after
convert_after_n_calls=3 calls to the forward method.
Method generate calls forward once to process the prompt and then once per new token,
so the export happens during the generation.
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)
[method_to_onnx] input[0]: ((),dict(input_ids:T7s1x13,attention_mask:T7s1x13,cache_position:T7s13))
[method_to_onnx] input[1]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x14,past_key_values:DynamicCache(key_cache=#1[T1s1x1x13x96], value_cache=#1[T1s1x1x13x96]),cache_position:T7s1))
[method_to_onnx] input[2]: ((),dict(input_ids:T7s1x1,attention_mask:T7s1x15,past_key_values:DynamicCache(key_cache=#1[T1s1x1x14x96], value_cache=#1[T1s1x1x14x96]),cache_position:T7s1))
[method_to_onnx] save 3 inputs in 'plot_export_tiny_llm_method_generate.custom.inputs.pt'
[method_to_onnx] guess_dynamic_shapes=((),dict(input_ids:{1:DYNAMIC},attention_mask:{1:DYNAMIC},past_key_values:#2[{2:DYNAMIC},{2:DYNAMIC}],cache_position:{0:DYNAMIC}))
[method_to_onnx.rename_dynamic_shapes] apply pattern shapes 'LLM.text'
[method_to_onnx] dynamic_batch_for={'past_key_values', 'attention_mask', 'input_ids'}
[method_to_onnx] dynamic_shapes with batch=((), {'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}})
[method_to_onnx] export args=()
[method_to_onnx] export kwargs=dict(input_ids:T7s2x1,attention_mask:T7s2x15,past_key_values:DynamicCache(key_cache=#1[T1s2x1x14x96], value_cache=#1[T1s2x1x14x96]),cache_position:T7s1)
[method_to_onnx] dynamic_shapes=((),dict(input_ids:{0:DYN(batch),1:DYN(seqlength)},attention_mask:{0:DYN(batch),1:DYN(totallength)},past_key_values:#2[{0:DYN(batch),2:DYN(pastlength)},{0:DYN(batch),2:DYN(pastlength)}],cache_position:{0:DYN(seqlength)}))
[to_onnx] build the graph module from <class 'onnx_diagnostic.export.api.WrapperToExportMethodToOnnx._convert_method_to_onnx.<locals>.WrapWithExactSignature'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[_make_builder_interpreter] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'WrapWithExactSignature'
[ExportOptions.export] aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seqlength'}, 'attention_mask': {0: 'batch', 1: 'totallength'}, 'past_key_values': [{0: 'batch', 2: 'pastlength'}, {0: 'batch', 2: 'pastlength'}], 'cache_position': {0: 'seqlength'}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),cache_position:T7r1)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'input_ids': {1: 'd=[1]'}, 'cache_position': {0: 'd=[1]'}}
[torch_export] export starts with backed_size_oblivious={'input_ids': {1: 'd=[1]'}, 'cache_position': {0: 'd=[1]'}}
[ExportOptions.export] export done in 3.032506055998965
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 10 slices nodes were removed
[CustomTracer.remove_inplace] starts with 171 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 13 inplace nodes
[CustomTracer.remove_inplace] S2: 7 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 141 nodes (n_inplace=7)
[ExportOptions.export] inplaces: 13 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.004061824001837522, modified=13
[ExportOptions.export] done with no decomposition in 0.004180184001597809
[to_onnx] graph module done in 3.055646485001489 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
100%|██████████| 141/141 [00:00<00:00, 1056.34it/s]
[to_onnx] 203 onnx nodes done in 0.24642151299849502 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True]
[GraphBuilder-HFK.inline_functions] begin inlining graph
[GraphBuilder-HFK.inline_functions] skip_functions=set()
[GraphBuilder-HFK._inline_functions_iterations] inline function 'submod_3' domain 'local_functions' [n_replacements=1]
[GraphBuilder-HFK._inline_functions_iterations] done with 9 new nodes for 'submod_3', 'local_functions'
[GraphBuilder-HFK.inline_functions] done inlining graph 138359794490720 in 0.004071620998729486
[GraphBuilder-HFK._add_shape_information] dynamic shapes replacements={'pastlength': 'pastlength', 'batch': 'batch', 'totallength': 'totallength', 'seqlength': 'seqlength', 's67': 'batch', 's72': 'batch', 's43': 'batch', 's61': 'batch', 'batch^s61^batch^s61': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's21': 'pastlength', 's44': 'pastlength'}
[GraphBuilder-HFK.optimize] start with 211 nodes
[GraphBuilder-HFK.optimize] #patterns=109
[GraphBuilder-HFK.optimize] start with subgraphs
[GraphBuilder-HFK.optimize] done with subgraphs
[GraphBuilderPatternOptimization-HFK.optimize] start with 161 nodes, 32 initializers, 109 patterns, priorities=[0, 1, 2, 3], max_iter=644
[GraphBuilderPatternOptimization-HFK.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-HFK.optimize] iteration 0: 161 nodes, priority=0
[GraphBuilderPatternOptimization-HFK.optimize] applies 33 matches, 11*CastPattern, 2*IdentityPattern, 3*ShapeBasedReshapeIsSqueezePattern, 2*ShapeBasedStaticExpandPattern, 3*ShapeBasedEditDistanceReshapePattern, 4*SameChildrenPattern, 1*SameChildrenFromInputPattern, 2*SqueezeAddPattern, 1*SqueezeUnsqueezePattern, 3*UnsqueezeUnsqueezePattern, 1*FunctionAttentionPattern - time=0.013 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-HFK.optimize] reapply {'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-HFK.optimize] n_added=6, n_removed=8, n_applied=35 applied patterns, 110 nodes left with 2 iterations
[GraphBuilderPatternOptimization-HFK.optimize] increase priority to 1
[GraphBuilderPatternOptimization-HFK.optimize] iteration 1: 110 nodes, priority=1
[GraphBuilderPatternOptimization-HFK.optimize] applies 16 matches, 2*ConcatTwiceUnaryPattern, 1*ConstantToInitializerPattern, 1*IdentityPattern, 2*SlicesSplitPattern, 1*SqueezeBinaryUnsqueezePattern, 2*SwapUnsqueezeTransposePattern, 3*UnsqueezeUnsqueezePattern, 1*QuickGeluPattern, 3*SimplifiedLayerNormalizationPattern - time=0.018 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-HFK.optimize] iteration 2: 87 nodes, priority=1
[GraphBuilderPatternOptimization-HFK.optimize] applies 8 matches, 2*IdentityPattern, 1*UnsqueezeUnsqueezePattern, 2*FunctionHalfRotaryEmbeddingPattern, 3*SimplifiedLayerNormalizationMulPattern - time=0.012 | max_time=SoftmaxCrossEntropyLossCastPattern:0.001
[GraphBuilderPatternOptimization-HFK.optimize] iteration 3: 71 nodes, priority=1
[GraphBuilderPatternOptimization-HFK.optimize] applies 2 matches, 2*SkipSimplifiedLayerNormalizationPattern - time=0.007 | max_time=ShapeBasedEditDistanceReshapePattern:0.000
[GraphBuilderPatternOptimization-HFK.optimize] iteration 4: 69 nodes, priority=1
[GraphBuilderPatternOptimization-HFK.optimize] increase priority to 2
[GraphBuilderPatternOptimization-HFK.optimize] iteration 5: 69 nodes, priority=2
[GraphBuilderPatternOptimization-HFK.optimize] increase priority to 3
[GraphBuilderPatternOptimization-HFK.optimize] iteration 6: 69 nodes, priority=3
[GraphBuilderPatternOptimization-HFK.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-HFK.optimize] done after 7 iterations with 69 nodes in 0.125
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 69 nodes, 31 initializers
[OrderOptimization.shape_order] done after in 0.0003772980016947258s with changed=5 scale=25
[GraphBuilder-HFK.optimize] done with 69 nodes in 0.143
[GraphBuilder-HFK.to_onnx] make_model 35 inits 12 params
[GraphBuilder-HFK.time_evaluation_constants_] 0.0008324000009451993
[GraphBuilder-HFK._build_initializers] start with 35 initializers, large_model=True, external_threshold=1024
[GraphBuilder-HFK._build_initializers] switch low/high order
[GraphBuilder-HFK._build_initializers] done in 2.34400067711249e-06s with 31 initializers, 9 large initializers
[GraphBuilder-HFK._add_shape_information] dynamic shapes replacements={'pastlength': 'pastlength', 'batch': 'batch', 'totallength': 'totallength', 'seqlength': 'seqlength', 's67': 'batch', 's72': 'batch', 's43': 'batch', 's61': 'batch', 'batch^s61^batch^s61': 'batch', 's70': 'seqlength', 'Max(s58,s70)': 'seqlength', 's58': 'seqlength', 's53': 'totallength', 's21': 'pastlength', 's44': 'pastlength'}
[to_onnx] to_onnx done in 0.16169063299821573s and 69 nodes, 31 initializers, 5 inputs, 3 outputs
[method_to_onnx] save 3 outputs in 'plot_export_tiny_llm_method_generate.custom.outputs.pt'
Continue: it rains, what should I do?
• Hee?
• You should do the best?
“Is there a way a better one? Are there other way back to that?”
• “The answer to
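The shapes of the three recorded calls follow the way generate drives forward when a cache is used. The toy loop below is a simplified, hypothetical model of it. toy_generate is not a transformers function, and shape tuples replace tensors. With a 13-token prompt it reproduces the shapes logged above for input[0] to input[2].

```python
from typing import Any, Callable, Dict, List


def toy_generate(forward: Callable[..., Any], prompt_len: int, max_new_tokens: int) -> None:
    """Hypothetical, simplified model of how generate calls forward with a cache.

    Shape tuples stand in for tensors; the batch size is 1.
    """
    cached = 0        # number of tokens already stored in the cache
    seq = prompt_len  # tokens fed to this call (whole prompt at prefill)
    for _ in range(max_new_tokens):
        total = cached + seq
        forward(
            input_ids=(1, seq),
            attention_mask=(1, total),
            past_length=cached if cached else None,  # no cache at prefill
            cache_position=(seq,),
        )
        cached = total  # everything seen so far is now cached
        seq = 1         # afterwards only the newly generated token is fed


logged: List[Dict[str, Any]] = []
toy_generate(lambda **kw: logged.append(kw), prompt_len=13, max_new_tokens=3)
for i, call in enumerate(logged):
    print(i, call)
```

Only the first call lacks a cache, which is why the recorded inputs contain past_key_values twice out of three calls.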
We finally need to check the discrepancies. The export produced an ONNX file and saved the inputs and outputs of the torch model. We now run the ONNX model on these inputs to check that it produces the same results. This check happens afterwards because the model may not fit twice in memory (once for torch, once for onnxruntime). verbose=2 shows more information about the expected outputs.
data = forward_replacement.check_discrepancies(verbose=1)
df = pandas.DataFrame(data)
print(df)
[method_to_onnx.check_discrepancies] register classes [<class 'transformers.modeling_outputs.CausalLMOutputWithPast'>, <class 'transformers.cache_utils.DynamicLayer'>, <class 'transformers.cache_utils.DynamicCache'>]
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.custom.inputs.pt'
[method_to_onnx.check_discrepancies] load 'plot_export_tiny_llm_method_generate.custom.outputs.pt'
[method_to_onnx.check_discrepancies] create onnx session 'plot_export_tiny_llm_method_generate.custom.onnx'
[method_to_onnx.check_discrepancies] input_names=['input_ids', 'attention_mask', 'past_key_values_key_0', 'past_key_values_value_0', 'cache_position']
[method_to_onnx.check_discrepancies] onnx_shapes=INT64[batch,seqlength], INT64[batch,totallength], FLOAT[batch,1,pastlength,96], FLOAT[batch,1,pastlength,96], INT64[seqlength]
[method_to_onnx.check_discrepancies] process input 0 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 1 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] process input 2 #args=0 #kwargs=4
[method_to_onnx.check_discrepancies] done
        abs       rel       sum         n  dnan  dev  >0.1  >0.01  SUCCESS  index  duration_torch  ort_duration  n_inputs
0  0.000019  0.004534  0.927152  418496.0     0    0     0      0     True      0        0.002990      2.114623         5
1  0.000010  0.001035  0.066861   34688.0     0    0     0      0     True      1        0.001989      0.002303         5
2  0.000012  0.001107  0.064074   34880.0     0    0     0      0     True      2        0.006405      0.001317         5
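The first columns of this table compare expected and computed outputs. The helper below is a hypothetical sketch of such statistics, and check_discrepancies may define them differently. Here:

- abs is the largest absolute difference.
- rel is the largest relative difference.
- sum is the sum of absolute differences.
- n is the number of compared values.
- >0.1 and >0.01 count differences above those thresholds.
- SUCCESS compares abs with an arbitrary tolerance atol.

```python
from typing import Any, Dict, Sequence


def compare(expected: Sequence[float], got: Sequence[float],
            atol: float = 1e-4, eps: float = 1e-12) -> Dict[str, Any]:
    """Hypothetical helper computing statistics similar to some table columns.

    The definitions are assumptions, not those of onnx_diagnostic.
    """
    if len(expected) != len(got):
        raise ValueError("expected and got must have the same length")
    diffs = [abs(e - g) for e, g in zip(expected, got)]
    # eps avoids a division by zero when the expected value is 0
    rels = [d / (abs(e) + eps) for d, e in zip(diffs, expected)]
    worst = max(diffs, default=0.0)
    return {
        "abs": worst,
        "rel": max(rels, default=0.0),
        "sum": sum(diffs),
        "n": float(len(diffs)),
        ">0.1": sum(d > 0.1 for d in diffs),
        ">0.01": sum(d > 0.01 for d in diffs),
        "SUCCESS": worst <= atol,
    }


stats = compare([1.0, 2.0, -0.5], [1.0, 2.001, -0.5])
print(stats)
```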
Minimal script to export an LLM
The following lines are a condensed copy with fewer comments. This version uses the official exporter (onnx-dynamo) instead of the custom one.
# from HuggingFace
print("----------------")
MODEL_NAME = "arnir0/Tiny-LLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# to export into onnx
forward_replacement = method_to_onnx(
    model,
    method_name="forward",
    exporter="onnx-dynamo",
    filename="plot_export_tiny_llm_method_generate.dynamo.onnx",
    patch_kwargs=dict(patch_transformers=True),
    verbose=0,
    convert_after_n_calls=3,
    dynamic_batch_for={"input_ids", "attention_mask", "past_key_values"},
)
model.forward = lambda *args, **kwargs: forward_replacement(*args, **kwargs)
# from HuggingFace again
prompt = "Continue: it rains, what should I do?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=100,
    temperature=1,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("prompt answer:", generated_text)
# to check discrepancies
data = forward_replacement.check_discrepancies()
df = pandas.DataFrame(data)
print(df)
----------------
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 504.97it/s, Materializing param=model.norm.weight]
~/github/onnx-diagnostic/onnx_diagnostic/export/api.py:229: UserWarning: Exporting a model while it is in training mode. Please ensure that this is intended, as it may lead to different behavior during inference. Calling model.eval() before export is recommended.
epo = torch.onnx.export(
[torch.onnx] Obtain model graph for `WrapWithExactSignature([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `WrapWithExactSignature([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
return cls.__new__(cls, *args)
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
~/vv/this312/lib/python3.12/site-packages/torch/onnx/_internal/exporter/_onnx_program.py:460: UserWarning: # The axis name: batch will not be used, since it shares the same shape constraints with another axis: batch.
rename_mapping = _dynamic_shapes.create_rename_mapping(
Applied 32 of general pattern rewrite rules.
prompt answer: Continue: it rains, what should I do?
HOL: We should have the right balance (think you have, pink, and pink).
What do we know for the best and most important things in terms of our history?
We’re so excited that our time will be a very important one. As one of the highest and most importantly, the best possible outcome is to go through the past. We hope we have you ready to give it
        abs       rel       sum         n  dnan  dev  >0.1  >0.01  SUCCESS  index  duration_torch  ort_duration  n_inputs
0  0.000019  0.004534  0.927152  418496.0     0    0     0      0     True      0        0.021793      0.187167         5
1  0.000010  0.001035  0.066861   34688.0     0    0     0      0     True      1        0.004354      0.001967         5
2  0.000019  0.004635  0.161854   34880.0     0    0     0      0     True      2        0.015524      0.001975         5

Total running time of the script: (0 minutes 23.623 seconds)
Related examples
Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)