Export attention from arnir0/Tiny-LLM with InputObserver
This example shows how to export only the attention part of the model arnir0/Tiny-LLM. It builds on what was shown in the example Export a LLM with InputObserver (with Tiny-LLM).
Let's create a model with random weights
import pandas
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver
device = "cuda"
model_id = "arnir0/Tiny-LLM"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} subdmodules.")
model = model.to(device).to(torch.float16)
get tokenizer 'arnir0/Tiny-LLM'
get config 'arnir0/Tiny-LLM'
create model from config for 'arnir0/Tiny-LLM'
the model is created with 20 submodules.
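To see which classes those 20 submodules come from, which is useful before the next step, a quick sketch:

# List the distinct submodule classes present in the model.
print(sorted({sub.__class__.__name__ for _, sub in model.named_modules()}))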
We need to export only the class LlamaAttention
export_module = None
for _name, sub in model.named_modules():
    if sub.__class__.__name__ == "LlamaAttention":
        export_module = sub
assert export_module is not None, (
    f"Unable to find a submodule from class LlamaAttention in "
    f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)
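As a side note, for this architecture the same module can usually be reached directly by attribute path. A minimal sketch, assuming the standard transformers Llama layout (the loop above is more robust because it does not depend on attribute names):

# Hypothetical direct path, equivalent to the loop above for Llama models.
direct_attention = model.model.layers[0].self_attn
assert direct_attention.__class__.__name__ == "LlamaAttention"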
Let’s run the model and capture the inputs and outputs of the attention part.
def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    # Tokenize the prompt and move the tensors to the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Generate a completion and decode it back to text.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
with (
    register_additional_serialization_functions(patch_transformers=True),
    observer(export_module),
):
    generate_text(prompt, model, tokenizer)
Export
First, let's look at what the observer inferred.
kwargs = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print("attention type:", type(export_module))
print(f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")
attention type: <class 'transformers.models.llama.modeling_llama.LlamaAttention'>
kwargs=dict(hidden_states:GT10s1x13x192,position_embeddings:(GT10s1x13x96,GT10s1x13x96),past_key_values:DynamicCache(key_cache=#1[GT10s1x1x0x96], value_cache=#1[GT10s1x1x0x96]),cache_position:GT7s13,position_ids:GT7s1x13)
dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
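For reference, the inferred specification mirrors what one could write by hand with torch.export dimension hints; note that the observer nests position_ids under a 'kwargs' entry. A minimal manual sketch (the inferred version above is the one actually used for the export):

from torch.export import Dim

DYN = Dim.DYNAMIC  # resolved to DimHint(DYNAMIC), as printed above
manual_dynamic_shapes = {
    "hidden_states": {1: DYN},                  # dynamic sequence length
    "position_embeddings": ({1: DYN}, {1: DYN}),
    "past_key_values": [{2: DYN}, {2: DYN}],    # dynamic cache length
    "cache_position": {0: DYN},
    "kwargs": {"position_ids": {1: DYN}},
}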
Next, the export.
filename = "plot_export_tiny_llm_attention_input_observer.onnx"
with torch_export_patches(patch_torch=True, patch_transformers=True):
    to_onnx(
        export_module,
        args=(),
        kwargs=kwargs,
        filename=filename,
        dynamic_shapes=dynamic_shapes,
        exporter="custom",
        verbose=1,
    )
[to_onnx] build the graph module from <class 'transformers.models.llama.modeling_llama.LlamaAttention'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(hidden_states:T10r3,position_embeddings:(T10r3,T10r3),past_key_values:DynamicCache(key_cache=#1[T10r4], value_cache=#1[T10r4]),cache_position:T7r1,position_ids:T7r2)
[_make_builder_interpreter] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'LlamaAttention'
[ExportOptions.export] aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(hidden_states:T10r3,position_embeddings:(T10r3,T10r3),past_key_values:DynamicCache(key_cache=#1[T10r4], value_cache=#1[T10r4]),cache_position:T7r1,position_ids:T7r2)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'past_key_values': [{2: 'd=[0]'}, {2: 'd=[0]'}]}
[torch_export] export starts with backed_size_oblivious={'past_key_values': [{2: 'd=[0]'}, {2: 'd=[0]'}]}
[ExportOptions.export] export done in 2.7888713870015636
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[CustomTracer.remove_inplace] starts with 97 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 12 inplace nodes
[CustomTracer.remove_inplace] S2: 4 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 98 iterations and 73 nodes (n_inplace=4)
[ExportOptions.export] inplaces: 12 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.002953422001155559, modified=12
[ExportOptions.export] done with no decomposition in 0.003064266998990206
[to_onnx] graph module done in 2.805061394999939 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
[to_onnx] 109 onnx nodes done in 0.1518585379999422 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True]
[GraphBuilder-APS.inline_functions] begin inlining graph
[GraphBuilder-APS.inline_functions] skip_functions=set()
[GraphBuilder-APS._inline_functions_iterations] replace local functions in node 'If', name='cond'
[_inline_functions_subgraph_iteration] begin with 136584469886464
[_inline_functions_subgraph_iteration] inline function 'false_graph_0' domain 'local_functions'
[_inline_functions_subgraph_iteration] 20 new nodes for 'false_graph_0', 'local_functions'
[_inline_functions_subgraph_iteration] done with 136584469886464 and 1 replacements
[_inline_functions_subgraph_iteration] begin with 136584469886464
[_inline_functions_subgraph_iteration] done with 136584469886464 and 0 replacements
[GraphBuilder-APS._inline_functions_iterations] replace local functions in node 'If', name='cond'
[_inline_functions_subgraph_iteration] begin with 136584458331984
[_inline_functions_subgraph_iteration] inline function 'true_graph_0' domain 'local_functions'
[_inline_functions_subgraph_iteration] 30 new nodes for 'true_graph_0', 'local_functions'
[_inline_functions_subgraph_iteration] done with 136584458331984 and 1 replacements
[_inline_functions_subgraph_iteration] begin with 136584458331984
[_inline_functions_subgraph_iteration] done with 136584458331984 and 0 replacements
[GraphBuilder-APS.inline_functions] done inlining graph 136584460864688 in 0.009420833997864975
[GraphBuilder-APS._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'channel_2': 'channel_2', 'channel_1': 'channel_1', 'channel_3': 'channel_3', 'D0': 'D0', 'channel': 'channel', 'D0_1': 'D0_1', 'DYN0': 'channel', 's87': 'channel', 'DYN1': 'channel_1', 's65': 'channel_1', 'DYN2': 'channel_2', 's29': 'channel_2', 'DYN3': 'D0', 's43': 'D0', 'DYN4': 'D0_1', 's21': 'D0_1', 'DYN6': 'channel_3', 's70': 'channel_3'}
[GraphBuilder-APS.optimize] start with 109 nodes
[GraphBuilder-APS.optimize] #patterns=113
[GraphBuilder-APS.optimize] start with subgraphs
[GraphBuilder-APS.optimize] done with subgraphs
[GraphBuilderPatternOptimization-APS.optimize] start with 45 nodes, 22 initializers, 113 patterns, priorities=[0, 1, 2, 3], max_iter=180
[GraphBuilderPatternOptimization-APS.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-APS.optimize] iteration 0: 45 nodes, priority=0
[GraphBuilderPatternOptimization-APS.optimize] applies 4 matches, 3*ShapeBasedEditDistanceReshapePattern, 1*SameChildrenPattern - time=0.003 | max_time=ShapeBasedEditDistanceReshapePattern:0.000
[GraphBuilderPatternOptimization-APS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-APS.optimize] n_added=0, n_removed=0, n_applied=4 applied patterns, 43 nodes left with 1 iterations
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 1
[GraphBuilderPatternOptimization-APS.optimize] iteration 1: 43 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] applies 6 matches, 1*ConstantToInitializerPattern, 2*SlicesSplitPattern, 1*SqueezeUnsqueezePattern, 2*TransposeEqualReshapePattern - time=0.005 | max_time=ShapeBasedEditDistanceReshapePattern:0.000
[GraphBuilderPatternOptimization-APS.optimize] iteration 2: 40 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] applies 2 matches, 2*FunctionHalfRotaryEmbeddingPattern - time=0.005 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-APS.optimize] iteration 3: 30 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 2
[GraphBuilderPatternOptimization-APS.optimize] iteration 4: 30 nodes, priority=2
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-APS.optimize] iteration 5: 30 nodes, priority=3
[GraphBuilderPatternOptimization-APS.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-APS.optimize] done after 6 iterations with 30 nodes in 0.060
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 30 nodes, 17 initializers
[OrderOptimization.shape_order] done after in 0.00018964599803439341s with changed=3 scale=12
[GraphBuilder-APS.optimize] done with 30 nodes in 0.156
[GraphBuilder-APS.to_onnx] make_model 17 inits 4 params
[GraphBuilder-APS.time_evaluation_constants_] 0.00027100200168206356
[GraphBuilder-APS._build_initializers] start with 17 initializers, large_model=True, external_threshold=1024
[GraphBuilder-APS._build_initializers] switch low/high order
[GraphBuilder-APS._build_initializers] done in 1.5019977581687272e-06s with 17 initializers, 4 large initializers
[GraphBuilder-APS._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'channel_2': 'channel_2', 'channel_1': 'channel_1', 'channel_3': 'channel_3', 'D0': 'D0', 'channel': 'channel', 'D0_1': 'D0_1', 'DYN0': 'channel', 's87': 'channel', 'DYN1': 'channel_1', 's65': 'channel_1', 'DYN2': 'channel_2', 's29': 'channel_2', 'DYN3': 'D0', 's43': 'D0', 'DYN4': 'D0_1', 's21': 'D0_1', 'DYN6': 'channel_3', 's70': 'channel_3'}
[to_onnx] to_onnx done in 0.17553250700075296s and 30 nodes, 17 initializers, 7 inputs, 2 outputs
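Before measuring discrepancies, the exported file can be loaded to check its signature, which should expose the 7 inputs and 2 outputs reported above. A minimal sketch with onnxruntime (assuming it is installed; this step is not part of the original script):

import onnxruntime

sess = onnxruntime.InferenceSession(filename, providers=["CPUExecutionProvider"])
print("inputs:", [(i.name, i.shape) for i in sess.get_inputs()])
print("outputs:", [(o.name, o.shape) for o in sess.get_outputs()])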
Let’s measure the discrepancies.
data = observer.check_discrepancies(
    filename, progress_bar=True, atol=1e-2, include_io=True, skip_none=True
)
df = pandas.DataFrame(data)
df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx")
print(df)
100%|██████████| 3/3 [00:00<00:00,  4.65it/s]
abs rel sum n dnan dev >0.1 >0.01 ... duration_torch ort_duration n_inputs n_none n_empty inputs outputs_torch outputs_ort
0 0.000122 0.039040 0.021232 2496.0 0 0 0 0 ... 0.214568 0.548897 7 2 2 dict(hidden_states:T10s1x13x192,position_embed... #2[T10s1x13x192,None] #2[T10s1x13x192,T1s]
1 0.000061 0.015270 0.001334 192.0 0 0 0 0 ... 0.029705 0.035655 7 0 0 dict(hidden_states:T10s1x1x192,position_embedd... #2[T10s1x1x192,None] #2[T10s1x1x192,T1s]
2 0.000061 0.009899 0.001285 192.0 0 0 0 0 ... 0.000708 0.001154 7 0 0 dict(hidden_states:T10s1x1x192,position_embedd... #2[T10s1x1x192,None] #2[T10s1x1x192,T1s]
[3 rows x 18 columns]
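The abs and rel columns report absolute and relative differences between the PyTorch and onnxruntime outputs. A hedged sketch of how such metrics are typically computed (an illustration only, not necessarily the exact formula used by check_discrepancies):

import numpy as np

def max_abs_rel(expected, got, eps=1e-6):
    # Maximum absolute and relative differences between two arrays.
    e = expected.astype(np.float64)
    diff = np.abs(e - got.astype(np.float64))
    return float(diff.max()), float((diff / (np.abs(e) + eps)).max())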
Let's show the errors.

[figure: plot of the measured errors]
Total running time of the script: (0 minutes 10.723 seconds)