Validate an LLM export and inspect discrepancies#
validate_model is a convenience
function that bundles the entire export-and-verify pipeline into a single call:

1. Load the model config (bundled copy preferred, then local HF cache, then network).
2. Optionally apply config_overrides to reduce the model size for fast testing.
3. Load the pre-trained weights, or, when random_weights is True, instantiate the model from the (possibly modified) config with random weights so that no large checkpoint is downloaded.
4. Run model.generate with a text prompt inside an InputObserver context to capture real input/output tensors.
5. Infer export arguments and dynamic shapes from the captured tensors.
6. Export the model to ONNX.
7. Run ONNX Runtime on every captured input set, compare the outputs against the original PyTorch outputs, and report the per-step discrepancies.

The function returns a ValidateSummary
with status flags and error messages, and a ValidateData
with all intermediate artefacts, including the raw discrepancy records.
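The capture in step 4 can be illustrated with a generic stand-in. The sketch below is not the real InputObserver API from yobx; it only shows the idea of recording every input/output pair seen while a model runs, so that the same feeds can later be replayed through ONNX Runtime.

```python
# Minimal sketch of the capture idea: wrap a callable and record every
# (args, kwargs, output) triple it sees. Illustrative only, not yobx code.
class CallRecorder:
    """Records every call made to a wrapped function."""

    def __init__(self):
        self.captured = []  # list of (args, kwargs, output) tuples

    def wrap(self, fn):
        def wrapped(*args, **kwargs):
            out = fn(*args, **kwargs)
            self.captured.append((args, kwargs, out))
            return out

        return wrapped


recorder = CallRecorder()
step = recorder.wrap(lambda x: x + 1)  # stand-in for one forward call

# Simulate three "generation steps", as with max_new_tokens=3.
value = 0
for _ in range(3):
    value = step(value)

print(len(recorder.captured))  # one record per forward call
```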
Command-line options
All steps after step 2 can be exercised offline:
python plot_validate_model.py --no-trained # random weights, no download
python plot_validate_model.py --num-hidden-layers 1 # 1-layer model, faster
python plot_validate_model.py --model arnir0/Tiny-LLM # different model
Imports#
import argparse
import sys
import pandas
from yobx.torch.validate import validate_model, ValidateSummary, ValidateData
Command-line arguments#
_DEFAULT_MODEL = "arnir0/Tiny-LLM"
parser = argparse.ArgumentParser(description="Validate a HuggingFace LLM export to ONNX.")
parser.add_argument(
"--model",
default=_DEFAULT_MODEL,
metavar="MODEL_ID",
help=f"HuggingFace model ID (default: {_DEFAULT_MODEL!r}).",
)
parser.add_argument(
"--trained",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Load pre-trained weights (default). "
"Pass --no-trained to use random weights (no large download)."
),
)
parser.add_argument(
"--num-hidden-layers",
type=int,
default=None,
metavar="LAYERS",
help="Override config.num_hidden_layers (reduces model size for testing).",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=3,
metavar="N",
help="Number of tokens generated during input capture (default: 3).",
)
# parse_known_args avoids failures when sphinx-gallery passes extra arguments.
args, _ = parser.parse_known_args(sys.argv[1:])
config_overrides = {}
if args.num_hidden_layers is not None:
config_overrides["num_hidden_layers"] = args.num_hidden_layers
Run validate_model#
validate_model orchestrates the
entire pipeline. Setting verbose=2 prints a one-line status for each
captured input set during the discrepancy check (index, SUCCESS, absolute and
relative differences). Use verbose=3 to additionally print tensor shapes.
summary: ValidateSummary
data: ValidateData
summary, data = validate_model(
args.model,
random_weights=not args.trained,
max_new_tokens=args.max_new_tokens,
config_overrides=config_overrides or None,
quiet=True,
verbose=2,
)
[validate_model] loading config for 'arnir0/Tiny-LLM'
[validate_model] loading tokenizer for 'arnir0/Tiny-LLM'
[validate_model] loading model for 'arnir0/Tiny-LLM'
Loading weights: 0%| | 0/12 [00:00<?, ?it/s]
Loading weights: 100%|██████████| 12/12 [00:00<00:00, 486.70it/s]
[validate_model] capturing inputs with InputObserver (prompt='Continue: it rains, what should I do?')
[validate_model] captured 3 input set(s)
[validate_model] kwargs: dict(input_ids:T7s1x13,attention_mask:T7s1x13,position_ids:T7s1x13,past_key_values:DynamicCache(key_cache=#1[T1s1x1x0x96], value_cache=#1[T1s1x1x0x96]),logits_to_keep:int)
[validate_model] dynamic_shapes: {'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
[validate_model] exporting to ONNX (exporter='yobx', opset=22) -> '/tmp/tmp_3tykcu_/arnir0-Tiny-LLM.yobx.22.default.patch.onnx'
[to_onnx] build the graph module from <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,position_ids:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),logits_to_keep:int)
[_make_builder_interpreter] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)) - export 'LlamaForCausalLM'
[ExportOptions.export] torch.export.export strict=False
[ExportOptions.export] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,position_ids:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),logits_to_keep:int)
[ExportOptions.export] export start with strict=False backed_size_oblivious=auto
[ExportOptions.export] export done in 2.8982437800004845
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 9 slices nodes were removed
[CustomTracer.remove_inplace] starts with 175 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 15 inplace nodes
[CustomTracer.remove_inplace] S2: 7 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 144 nodes (n_inplace=7)
[ExportOptions.export] inplaces: 15 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.010209355999904801, modified=15
[ExportOptions.export] done with no decomposition in 0.010503989000426373
[to_onnx] graph module done in 2.915339342000152 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
0%| | 0/144 [00:00<?, ?it/s]
38%|███▊ | 54/144 [00:00<00:00, 530.68it/s]
75%|███████▌ | 108/144 [00:00<00:00, 223.35it/s]
97%|█████████▋| 140/144 [00:00<00:00, 211.00it/s]
100%|██████████| 144/144 [00:00<00:00, 225.89it/s]
[to_onnx] 208 onnx nodes done in 0.7456645170004776 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True]
[GraphBuilder-YDK.inline_functions] begin inlining graph
[GraphBuilder-YDK.inline_functions] skip_functions=set()
[GraphBuilder-YDK._inline_functions_iterations] inline function 'submod_3' domain 'local_functions' [n_replacements=1]
[GraphBuilder-YDK._inline_functions_iterations] done with 9 new nodes for 'submod_3', 'local_functions'
[GraphBuilder-YDK.inline_functions] done inlining graph 123941052969952 in 0.018863514998884057
[GraphBuilder-YDK._add_shape_information] dynamic shapes replacements={'batch_3': 'batch_3', 'batch_2': 'batch_2', 'channel': 'channel', 'batch': 'batch', 'D0_1': 'D0_1', 'D0': 'D0', 'batch_1': 'batch_1', 'channel_2': 'channel_2', 'batch_4': 'batch_4', 'channel_1': 'channel_1', 'DYN0': 'batch', 's67': 'batch', 's61': 'batch', 'DYN0^s61': 'batch', 'DYN8': 'batch_4', 's72': 'batch', 'DYN1': 'channel', 's70': 'channel', 'DYN2': 'batch_1', 's43': 'batch_1', 'DYN3': 'channel_1', 's53': 'channel_1', 'DYN4': 'batch_2', 's44': 'batch_2', 'DYN5': 'channel_2', 's9': 'channel_2', 'DYN6': 'batch_3', 'DYN7': 'D0', 's45': 'D0', 'DYN9': 'D0_1', 's21': 'D0_1'}
[GraphBuilder-YDK.optimize] start with 216 nodes
[GraphBuilder-YDK.optimize] #patterns=98
[GraphBuilder-YDK.optimize] start with subgraphs
[GraphBuilder-YDK.optimize] done with subgraphs
[GraphBuilderPatternOptimization-YDK.optimize] start with 170 nodes, 32 initializers, 98 patterns, priorities=[0, 1, 2, 3], max_iter=680
[GraphBuilderPatternOptimization-YDK.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-YDK.optimize] iteration 0: 170 nodes, priority=0
[GraphBuilderPatternOptimization-YDK.optimize] applies 32 matches, 12*CastPattern, 2*IdentityPattern, 4*ShapeBasedReshapeIsSqueezePattern, 2*ShapeBasedStaticExpandPattern, 1*ShapeBasedEditDistanceReshapePattern, 4*SameChildrenPattern, 1*SqueezeAddPattern, 2*SqueezeUnsqueezePattern, 1*SwapUnaryPattern, 3*UnsqueezeUnsqueezePattern - time=0.033 | max_time=SoftmaxCrossEntropyLossCastPattern:0.005
[GraphBuilderPatternOptimization-YDK.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-YDK.optimize] n_added=3, n_removed=4, n_applied=33 applied patterns, 133 nodes left with 2 iterations
[GraphBuilderPatternOptimization-YDK.optimize] increase priority to 1
[GraphBuilderPatternOptimization-YDK.optimize] iteration 1: 133 nodes, priority=1
[GraphBuilderPatternOptimization-YDK.optimize] applies 19 matches, 2*ConcatTwiceUnaryPattern, 2*ConstantToInitializerPattern, 1*IdentityPattern, 2*SlicesSplitPattern, 3*SqueezeUnsqueezePattern, 1*SwapExpandReshapePattern, 1*SwapRangeAddScalarPattern, 2*SwapUnsqueezeTransposePattern, 3*UnsqueezeUnsqueezePattern, 1*WhereAddPattern, 1*FunctionCosSinCachePattern - time=0.033 | max_time=GeluPattern:0.001
[GraphBuilderPatternOptimization-YDK.optimize] iteration 2: 117 nodes, priority=1
[GraphBuilderPatternOptimization-YDK.optimize] applies 6 matches, 1*ConcatReshapePattern, 2*IdentityPattern, 1*SameChildrenPattern, 2*FunctionHalfRotaryEmbeddingPattern - time=0.054 | max_time=SoftmaxCrossEntropyLossCastPattern:0.007
[GraphBuilderPatternOptimization-YDK.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-YDK.optimize] n_added=0, n_removed=0, n_applied=58 applied patterns, 104 nodes left with 1 iterations
[GraphBuilderPatternOptimization-YDK.optimize] increase priority to 2
[GraphBuilderPatternOptimization-YDK.optimize] iteration 3: 104 nodes, priority=2
[GraphBuilderPatternOptimization-YDK.optimize] applies 1 matches, [0]=MatchResult: SqueezeAddPattern replaces ['Squeeze', 'Squeeze', 'Add'] - time=0.051 | max_time=GeluPattern:0.002
[GraphBuilderPatternOptimization-YDK.optimize] iteration 4: 105 nodes, priority=2
[GraphBuilderPatternOptimization-YDK.optimize] applies 1 matches, [0]=MatchResult: SameChildrenPattern replaces ['Add', 'Add', 'Squeeze', 'Squeeze'] - time=0.046 | max_time=SameChildrenPattern:0.002
[GraphBuilderPatternOptimization-YDK.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-YDK.optimize] n_added=0, n_removed=0, n_applied=60 applied patterns, 103 nodes left with 1 iterations
[GraphBuilderPatternOptimization-YDK.optimize] increase priority to 3
[GraphBuilderPatternOptimization-YDK.optimize] iteration 5: 103 nodes, priority=3
[GraphBuilderPatternOptimization-YDK.optimize] applies 1 matches, [0]=MatchResult: FunctionCausalMaskPattern replaces ['Squeeze', 'Squeeze', 'Range', 'Range', 'Unsqueeze', 'Unsqueeze', 'LessOrEqual'] - time=0.031 | max_time=ShapeBasedEditDistanceReshapePattern:0.001
[GraphBuilderPatternOptimization-YDK.optimize] iteration 6: 101 nodes, priority=3
[GraphBuilderPatternOptimization-YDK.optimize] applies 2 matches, 1*SqueezeBinaryUnsqueezePattern, 1*FunctionCausalMaskMulAddPattern - time=0.041 | max_time=GeluPattern:0.002
[GraphBuilderPatternOptimization-YDK.optimize] iteration 7: 93 nodes, priority=3
[GraphBuilderPatternOptimization-YDK.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-YDK.optimize] done after 8 iterations with 93 nodes in 0.651
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.shape_order] -- starts with 92 nodes, 26 initializers
[OrderOptimization.shape_order] done after in 0.0015182059996732278s with changed=6 scale=31
[GraphBuilder-YDK.optimize] done with 92 nodes in 0.984
[GraphBuilder-YDK.to_onnx] make_model 30 inits 12 params
[GraphBuilder-YDK.time_evaluation_constants_] 0.0016301680007018149
[GraphBuilder-YDK._build_initializers] start with 30 initializers, large_model=False, external_threshold=1024
[GraphBuilder-YDK._build_initializers] switch low/high order
[GraphBuilder-YDK._build_initializers] done in 2.121099896612577e-05s with 26 initializers, 0 large initializers
[GraphBuilder-YDK._add_shape_information] dynamic shapes replacements={'D0_2': 'D0_2', 'batch_5': 'batch_5', 'batch_3': 'batch_3', 'D0_3': 'D0_3', 'channel_3': 'channel_3', 'channel_5': 'channel_5', 'batch_4': 'batch_4', 'batch_7': 'batch_7', 'channel_4': 'channel_4', 'batch_6': 'batch_6', 'DYN0': 'batch_3', 's67': 'batch_3', 'DYN0^s61': 'batch_3', 's72': 'batch_3', 'DYN6': 'batch_6', 's61': 'batch_3', 'DYN8': 'batch_7', 'batch': 'batch_3', 'DYN1': 'channel_3', 's70': 'channel_3', 'channel': 'channel_3', 'DYN2': 'batch_4', 's43': 'batch_4', 'batch_1': 'batch_4', 'DYN3': 'channel_4', 's53': 'channel_4', 'channel_1': 'channel_4', 'DYN4': 'batch_5', 'batch_2': 'batch_5', 's44': 'batch_5', 'DYN5': 'channel_5', 's9': 'channel_5', 'channel_2': 'channel_5', 'DYN7': 'D0_2', 'D0': 'D0_2', 's45': 'D0_2', 'DYN9': 'D0_3', 's21': 'D0_3', 'D0_1': 'D0_3'}
[to_onnx] to_onnx done in 1.4072204899985081s and 92 nodes, 26 initializers, 5 inputs, 3 outputs
[validate_model] export succeeded -> '/tmp/tmp_3tykcu_/arnir0-Tiny-LLM.yobx.22.default.patch.onnx'
[validate_model] checking discrepancies ...
[validate_model] discrepancies: 3/3 OK
[0] OK abs=1.24e-05 rel=0.000433
[1] OK abs=1e-05 rel=0.00141
[2] OK abs=1.05e-05 rel=0.00184
Summary#
ValidateSummary stores
high-level status flags and error messages. Every field that was not reached
(e.g. because an earlier step failed) remains None and is omitted from
items().
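The None-skipping behaviour of items() can be mimicked with a plain dict standing in for ValidateSummary; the field names below mirror the summary printed underneath, but the dict itself is illustrative.

```python
# Sketch: fields that were never reached stay None and are skipped on
# iteration, so a successful run shows no error_* entries.
summary_fields = {
    "model_id": "arnir0/Tiny-LLM",
    "export": "OK",
    "discrepancies": "OK",
    "error_export": None,  # never set on a successful run
}

items = {k: v for k, v in summary_fields.items() if v is not None}
print(items)
```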
-- summary --
config_from_cache: bundled
discrepancies: OK
discrepancies_ok: 3
discrepancies_total: 3
export: OK
model_from_cache: True
model_id: arnir0/Tiny-LLM
n_captured: 3
prompt: Continue: it rains, what should I do?
Discrepancies#
ValidateData.discrepancies is the
raw list of dicts returned by
InputObserver.check_discrepancies.
Each row corresponds to one forward call captured during model.generate.
The most important columns are:
- SUCCESS: True when the absolute difference is below atol and the relative difference is below rtol.
- abs: maximum absolute element-wise difference across all outputs.
- rel: maximum relative element-wise difference.
- index: position in the capture sequence (0 = first forward call, i.e. the prefill step).
- inputs / outputs_torch / outputs_ort: shape strings for the feeds and outputs.
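One plausible definition of these columns can be written in a few lines of numpy. The exact formula used by InputObserver.check_discrepancies may differ (in particular the epsilon guarding the relative difference); the thresholds here are illustrative.

```python
import numpy as np


def diff_stats(expected, got, atol=1e-4, rtol=1e-2):
    """Max absolute/relative element-wise difference between two outputs."""
    expected = np.asarray(expected, dtype=np.float64)
    got = np.asarray(got, dtype=np.float64)
    d = np.abs(expected - got)
    abs_diff = float(d.max())
    # Small epsilon avoids division by zero; the real check may use another guard.
    rel_diff = float((d / (np.abs(expected) + 1e-9)).max())
    return {
        "SUCCESS": abs_diff < atol and rel_diff < rtol,
        "abs": abs_diff,
        "rel": rel_diff,
    }


print(diff_stats([1.0, 2.0], [1.00001, 2.0]))
```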
A pandas.DataFrame gives a compact overview:
if data.discrepancies is not None:
df = pandas.DataFrame(data.discrepancies)
# Only show the most informative columns if they exist.
cols = [c for c in ("index", "SUCCESS", "abs", "rel", "n_inputs") if c in df.columns]
print(df[cols].to_string(index=False))
else:
print("(no discrepancy data — export may have failed)")
index SUCCESS abs rel n_inputs
0 True 0.000012 0.000433 5
1 True 0.000010 0.001411 5
2 True 0.000010 0.001841 5
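With the same kind of DataFrame, failing rows (if any) can be isolated with a boolean mask. The records below are made up for illustration; a real run would use data.discrepancies instead.

```python
import pandas

# Hypothetical discrepancy records: one passing row, one failing row.
records = [
    {"index": 0, "SUCCESS": True, "abs": 1.2e-5, "rel": 4.3e-4},
    {"index": 1, "SUCCESS": False, "abs": 2.0e-1, "rel": 9.0e-1},
]
df = pandas.DataFrame(records)

# Keep only rows where the tolerance check failed.
failing = df[~df["SUCCESS"]]
print(failing.to_string(index=False))
```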
Interpreting the results#
A typical successful run shows SUCCESS=True for every row with very small
abs and rel values (well below 1e-4).
When the export fails, summary.export will be "FAILED" and
summary.error_export will contain the exception message. The discrepancy
check is skipped in that case.
When the export succeeds but outputs diverge, summary.discrepancies will
be "FAILED" with summary.discrepancies_ok < summary.discrepancies_total.
Increase verbose to 3 to print the input and output shapes for every
failing row.
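The branching described above can be acted on programmatically. The sketch below uses a plain dict in place of a real ValidateSummary object; the field names follow the text, but the helper itself is hypothetical.

```python
# Sketch: turn the status fields into a one-line verdict.
def report(summary):
    if summary.get("export") == "FAILED":
        return f"export failed: {summary.get('error_export')}"
    ok = summary.get("discrepancies_ok")
    total = summary.get("discrepancies_total")
    if summary.get("discrepancies") == "FAILED":
        return f"outputs diverge: {ok}/{total} OK"
    return f"all good: {ok}/{total} OK"


print(report({"export": "OK", "discrepancies": "OK",
              "discrepancies_ok": 3, "discrepancies_total": 3}))
```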
Total running time of the script: (0 minutes 7.736 seconds)