Export attention from arnir0/Tiny-LLM with InputObserver

This example shows how to export only the attention module of the model arnir0/Tiny-LLM. It builds on the example Export a LLM with InputObserver (with Tiny-LLM).

Let’s create a model with randomly initialized weights (only the tokenizer and the configuration are downloaded).

import pandas
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

# Target device for the model and its inputs; ``generate_text`` below also
# reads this global.
device = "cuda"
model_id = "arnir0/Tiny-LLM"

print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)

# Only the configuration is fetched: the model is built from it, so its
# weights are not the pretrained ones.
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
n_submodules = sum(1 for _ in model.named_modules())
print(f"the model is created with {n_submodules} subdmodules.")

# Move the model to the target device, then cast it to half precision.
model = model.to(device)
model = model.to(torch.float16)
get tokenizer 'arnir0/Tiny-LLM'
get config 'arnir0/Tiny-LLM'
create model from config for 'arnir0/Tiny-LLM'
the model is created with 20 subdmodules.

We only want to export the submodule of class LlamaAttention.

# Select the LlamaAttention submodule to export. The loop does not stop at the
# first match, so if the model held several such submodules, the last one
# visited by ``named_modules`` would be kept.
export_module = None
for _name, sub in model.named_modules():
    if sub.__class__.__name__ == "LlamaAttention":
        export_module = sub

# Explicit check instead of ``assert``: assertions are removed when Python runs
# with ``-O``, and the missing module would then only surface later as an
# obscure failure inside the exporter.
if export_module is None:
    found_classes = {m.__class__.__name__ for _, m in model.named_modules()}
    raise LookupError(
        f"Unable to find a submodule from class LlamaAttention in "
        f"{found_classes}"
    )

Let’s run the model and capture the inputs and outputs of the attention module.

def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    """Generate a continuation of ``prompt`` and return it as text.

    The prompt is tokenized into PyTorch tensors. Only ``input_ids`` and
    ``attention_mask`` are moved to the module-level ``device`` and forwarded
    to ``model.generate``, together with the sampling settings below. Only the
    first returned sequence is decoded, with special tokens removed.

    :param prompt: text to continue
    :param model: causal language model exposing ``generate``
    :param tokenizer: tokenizer matching ``model``
    :param max_length: forwarded to ``generate`` as ``max_length``
    :param temperature: forwarded to ``generate`` as ``temperature``
    :param top_k: forwarded to ``generate`` as ``top_k``
    :param top_p: forwarded to ``generate`` as ``top_p``
    :param do_sample: forwarded to ``generate`` as ``do_sample``
    :return: decoded text of the first generated sequence
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    sequences = model.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()

# Generate text while the extra serialization functions are registered and
# the observer is attached to ``export_module``, so that the inputs the
# attention module receives during generation are captured.
with register_additional_serialization_functions(patch_transformers=True):
    with observer(export_module):
        generate_text(prompt, model, tokenizer)

Export

First, let’s look at what was inferred: the keyword arguments and the dynamic shapes.

# Ask the observer for the arguments it recorded and the dynamic shapes it
# inferred from them; both are passed to the exporter below.
kwargs = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()

print("attention type:", type(export_module))
described_kwargs = string_type(kwargs, with_shape=True, with_device=True)
print(f"kwargs={described_kwargs}")
print(f"dynamic_shapes={dynamic_shapes}")
attention type: <class 'transformers.models.llama.modeling_llama.LlamaAttention'>
kwargs=dict(hidden_states:GT10s1x13x192,position_embeddings:(GT10s1x13x96,GT10s1x13x96),past_key_values:DynamicCache(key_cache=#1[GT10s1x1x0x96], value_cache=#1[GT10s1x1x0x96]),cache_position:GT7s13,position_ids:GT7s1x13)
dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}

Next, let’s export the attention module to ONNX.

# Every captured input is passed by keyword, hence the empty positional
# ``args``; the export runs with the torch and transformers patches enabled.
filename = "plot_export_tiny_llm_attention_input_observer.onnx"
export_patches = torch_export_patches(patch_torch=True, patch_transformers=True)
with export_patches:
    to_onnx(
        export_module,
        filename=filename,
        exporter="custom",
        args=(),
        kwargs=kwargs,
        dynamic_shapes=dynamic_shapes,
        verbose=1,
    )
[to_onnx] build the graph module from <class 'transformers.models.llama.modeling_llama.LlamaAttention'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(hidden_states:T10r3,position_embeddings:(T10r3,T10r3),past_key_values:DynamicCache(key_cache=#1[T10r4], value_cache=#1[T10r4]),cache_position:T7r1,position_ids:T7r2)
[_make_builder_interpreter] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'LlamaAttention'
[ExportOptions.export] aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'hidden_states': {1: DimHint(DYNAMIC)}, 'position_embeddings': ({1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)}), 'past_key_values': [{2: DimHint(DYNAMIC)}, {2: DimHint(DYNAMIC)}], 'cache_position': {0: DimHint(DYNAMIC)}, 'kwargs': {'position_ids': {1: DimHint(DYNAMIC)}}}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(hidden_states:T10r3,position_embeddings:(T10r3,T10r3),past_key_values:DynamicCache(key_cache=#1[T10r4], value_cache=#1[T10r4]),cache_position:T7r1,position_ids:T7r2)
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious={'past_key_values': [{2: 'd=[0]'}, {2: 'd=[0]'}]}
[torch_export] export starts with backed_size_oblivious={'past_key_values': [{2: 'd=[0]'}, {2: 'd=[0]'}]}
[ExportOptions.export] export done in 2.7888713870015636
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[CustomTracer.remove_inplace] starts with 97 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 12 inplace nodes
[CustomTracer.remove_inplace] S2: 4 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 98 iterations and 73 nodes (n_inplace=4)
[ExportOptions.export] inplaces: 12 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.002953422001155559, modified=12
[ExportOptions.export] done with no decomposition in 0.003064266998990206
[to_onnx] graph module done in 2.805061394999939 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
[to_onnx] 109 onnx nodes done in 0.1518585379999422 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True]
[GraphBuilder-APS.inline_functions] begin inlining graph
[GraphBuilder-APS.inline_functions] skip_functions=set()
[GraphBuilder-APS._inline_functions_iterations] replace local functions in node 'If', name='cond'
[_inline_functions_subgraph_iteration] begin with 136584469886464
[_inline_functions_subgraph_iteration] inline function 'false_graph_0' domain 'local_functions'
[_inline_functions_subgraph_iteration] 20 new nodes for 'false_graph_0', 'local_functions'
[_inline_functions_subgraph_iteration] done with 136584469886464 and 1 replacements
[_inline_functions_subgraph_iteration] begin with 136584469886464
[_inline_functions_subgraph_iteration] done with 136584469886464 and 0 replacements
[GraphBuilder-APS._inline_functions_iterations] replace local functions in node 'If', name='cond'
[_inline_functions_subgraph_iteration] begin with 136584458331984
[_inline_functions_subgraph_iteration] inline function 'true_graph_0' domain 'local_functions'
[_inline_functions_subgraph_iteration] 30 new nodes for 'true_graph_0', 'local_functions'
[_inline_functions_subgraph_iteration] done with 136584458331984 and 1 replacements
[_inline_functions_subgraph_iteration] begin with 136584458331984
[_inline_functions_subgraph_iteration] done with 136584458331984 and 0 replacements
[GraphBuilder-APS.inline_functions] done inlining graph 136584460864688 in 0.009420833997864975
[GraphBuilder-APS._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'channel_2': 'channel_2', 'channel_1': 'channel_1', 'channel_3': 'channel_3', 'D0': 'D0', 'channel': 'channel', 'D0_1': 'D0_1', 'DYN0': 'channel', 's87': 'channel', 'DYN1': 'channel_1', 's65': 'channel_1', 'DYN2': 'channel_2', 's29': 'channel_2', 'DYN3': 'D0', 's43': 'D0', 'DYN4': 'D0_1', 's21': 'D0_1', 'DYN6': 'channel_3', 's70': 'channel_3'}
[GraphBuilder-APS.optimize] start with 109 nodes
[GraphBuilder-APS.optimize] #patterns=113
[GraphBuilder-APS.optimize] start with subgraphs
[GraphBuilder-APS.optimize] done with subgraphs
[GraphBuilderPatternOptimization-APS.optimize] start with 45 nodes, 22 initializers, 113 patterns, priorities=[0, 1, 2, 3], max_iter=180
[GraphBuilderPatternOptimization-APS.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-APS.optimize] iteration 0: 45 nodes, priority=0
[GraphBuilderPatternOptimization-APS.optimize] applies 4 matches, 3*ShapeBasedEditDistanceReshapePattern, 1*SameChildrenPattern - time=0.003 | max_time=ShapeBasedEditDistanceReshapePattern:0.000
[GraphBuilderPatternOptimization-APS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-APS.optimize] n_added=0, n_removed=0, n_applied=4 applied patterns, 43 nodes left with 1 iterations
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 1
[GraphBuilderPatternOptimization-APS.optimize] iteration 1: 43 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] applies 6 matches, 1*ConstantToInitializerPattern, 2*SlicesSplitPattern, 1*SqueezeUnsqueezePattern, 2*TransposeEqualReshapePattern - time=0.005 | max_time=ShapeBasedEditDistanceReshapePattern:0.000
[GraphBuilderPatternOptimization-APS.optimize] iteration 2: 40 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] applies 2 matches, 2*FunctionHalfRotaryEmbeddingPattern - time=0.005 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-APS.optimize] iteration 3: 30 nodes, priority=1
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 2
[GraphBuilderPatternOptimization-APS.optimize] iteration 4: 30 nodes, priority=2
[GraphBuilderPatternOptimization-APS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-APS.optimize] iteration 5: 30 nodes, priority=3
[GraphBuilderPatternOptimization-APS.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-APS.optimize] done after 6 iterations with 30 nodes in 0.060
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 30 nodes, 17 initializers
[OrderOptimization.shape_order] done after in 0.00018964599803439341s with changed=3 scale=12
[GraphBuilder-APS.optimize] done with 30 nodes in 0.156
[GraphBuilder-APS.to_onnx] make_model 17 inits 4 params
[GraphBuilder-APS.time_evaluation_constants_] 0.00027100200168206356
[GraphBuilder-APS._build_initializers] start with 17 initializers, large_model=True, external_threshold=1024
[GraphBuilder-APS._build_initializers] switch low/high order
[GraphBuilder-APS._build_initializers] done in 1.5019977581687272e-06s with 17 initializers, 4 large initializers
[GraphBuilder-APS._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'channel_2': 'channel_2', 'channel_1': 'channel_1', 'channel_3': 'channel_3', 'D0': 'D0', 'channel': 'channel', 'D0_1': 'D0_1', 'DYN0': 'channel', 's87': 'channel', 'DYN1': 'channel_1', 's65': 'channel_1', 'DYN2': 'channel_2', 's29': 'channel_2', 'DYN3': 'D0', 's43': 'D0', 'DYN4': 'D0_1', 's21': 'D0_1', 'DYN6': 'channel_3', 's70': 'channel_3'}
[to_onnx] to_onnx done in 0.17553250700075296s and 30 nodes, 17 initializers, 7 inputs, 2 outputs

Let’s measure the discrepancies between the PyTorch module and the exported ONNX model on the captured inputs.

# Compare the exported model with the PyTorch module on the captured inputs;
# each returned row describes one comparison (see the table printed below).
data = observer.check_discrepancies(
    filename,
    atol=1e-2,
    progress_bar=True,
    include_io=True,
    skip_none=True,
)
df = pandas.DataFrame(data)
# NOTE: ``DataFrame.to_excel`` needs an Excel writer engine such as openpyxl.
excel_name = "plot_export_tiny_llm_attention_input_observer.xlsx"
df.to_excel(excel_name)
print(df)
  0%|          | 0/3 [00:00<?, ?it/s]
 33%|███▎      | 1/3 [00:00<00:01,  1.65it/s]
100%|██████████| 3/3 [00:00<00:00,  4.65it/s]
        abs       rel       sum       n  dnan  dev  >0.1  >0.01  ...  duration_torch  ort_duration  n_inputs  n_none  n_empty                                             inputs          outputs_torch           outputs_ort
0  0.000122  0.039040  0.021232  2496.0     0    0     0      0  ...        0.214568      0.548897         7       2        2  dict(hidden_states:T10s1x13x192,position_embed...  #2[T10s1x13x192,None]  #2[T10s1x13x192,T1s]
1  0.000061  0.015270  0.001334   192.0     0    0     0      0  ...        0.029705      0.035655         7       0        0  dict(hidden_states:T10s1x1x192,position_embedd...   #2[T10s1x1x192,None]   #2[T10s1x1x192,T1s]
2  0.000061  0.009899  0.001285   192.0     0    0     0      0  ...        0.000708      0.001154         7       0        0  dict(hidden_states:T10s1x1x192,position_embedd...   #2[T10s1x1x192,None]   #2[T10s1x1x192,T1s]

[3 rows x 18 columns]

Let’s show the errors, if any, and draw the exported graph.

# Print the error message of every comparison that did not succeed.
for row in data:
    if row["SUCCESS"] or "error" not in row:
        continue
    print(row["error"])

# Render the exported ONNX graph and save it as ``<filename>.png``.
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
plot export tiny llm attention input observer

Total running time of the script: (0 minutes 10.723 seconds)

Related examples

Export OptiMind-SFT with InputObserver

Export OptiMind-SFT with InputObserver

Export Gemma3 tiny random with InputObserver

Export Gemma3 tiny random with InputObserver

Export a LLM with InputObserver (with Tiny-LLM)

Export a LLM with InputObserver (with Tiny-LLM)

Gallery generated by Sphinx-Gallery