.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_transformers/plot_validate_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_transformers_plot_validate_model.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_transformers_plot_validate_model.py:

.. _l-plot-validate-model:

Validate an LLM export and inspect discrepancies
================================================

:func:`validate_model <yobx.torch.validate.validate_model>` is a convenience
function that bundles the entire export-and-verify pipeline into a single call:

1. Load the model config (bundled copy preferred, then local HF cache, then network).
2. Optionally apply *config_overrides* to reduce the model size for fast testing.
3. Load the pre-trained weights — or, when *random_weights* is ``True``,
   instantiate the model from the (possibly modified) config with random weights
   so that no large checkpoint is downloaded.
4. Run ``model.generate`` with a text prompt inside an
   :class:`InputObserver` context to capture real input/output tensors.
5. Infer export arguments and dynamic shapes from the captured tensors.
6. Export the model to ONNX.
7. Run ONNX Runtime on every captured input set, compare the outputs against the
   original PyTorch outputs, and report the per-step discrepancies.

The function returns a
:class:`ValidateSummary <yobx.torch.validate.ValidateSummary>` with status flags
and error messages, and a
:class:`ValidateData <yobx.torch.validate.ValidateData>` with all intermediate
artefacts, including the raw discrepancy records.

**Command-line options**

All steps after step 2 can be exercised offline::

    python plot_validate_model.py --no-trained            # random weights, no download
    python plot_validate_model.py --num-hidden-layers 1   # 1-layer model, faster
    python plot_validate_model.py --model arnir0/Tiny-LLM # different model

.. GENERATED FROM PYTHON SOURCE LINES 39-41

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 41-47

.. code-block:: Python

    import argparse
    import sys

    import pandas

    from yobx.torch.validate import validate_model, ValidateSummary, ValidateData

.. GENERATED FROM PYTHON SOURCE LINES 48-50

Command-line arguments
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 50-91

.. code-block:: Python

    _DEFAULT_MODEL = "arnir0/Tiny-LLM"

    parser = argparse.ArgumentParser(description="Validate a HuggingFace LLM export to ONNX.")
    parser.add_argument(
        "--model",
        default=_DEFAULT_MODEL,
        metavar="MODEL_ID",
        help=f"HuggingFace model ID (default: {_DEFAULT_MODEL!r}).",
    )
    parser.add_argument(
        "--trained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Load pre-trained weights (default). "
            "Pass --no-trained to use random weights (no large download)."
        ),
    )
    parser.add_argument(
        "--num-hidden-layers",
        type=int,
        default=None,
        metavar="LAYERS",
        help="Override config.num_hidden_layers (reduces model size for testing).",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=3,
        metavar="N",
        help="Number of tokens generated during input capture (default: 3).",
    )

    # parse_known_args avoids failures when sphinx-gallery passes extra arguments.
    args, _ = parser.parse_known_args(sys.argv[1:])

    config_overrides = {}
    if args.num_hidden_layers is not None:
        config_overrides["num_hidden_layers"] = args.num_hidden_layers

.. GENERATED FROM PYTHON SOURCE LINES 92-99

Run validate_model
------------------

:func:`validate_model <yobx.torch.validate.validate_model>` orchestrates the
entire pipeline. Setting *verbose=2* prints a one-line status for each captured
input set during the discrepancy check (index, SUCCESS, absolute and relative
differences). Use *verbose=3* to additionally print tensor shapes.

.. GENERATED FROM PYTHON SOURCE LINES 99-112

.. code-block:: Python

    summary: ValidateSummary
    data: ValidateData
    summary, data = validate_model(
        args.model,
        random_weights=not args.trained,
        max_new_tokens=args.max_new_tokens,
        config_overrides=config_overrides or None,
        quiet=True,
        verbose=2,
    )
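The per-input-set comparison that *verbose=2* reports can be sketched as
follows. This is an illustrative reimplementation, not the library's actual
code; the ``compare_outputs`` helper and the ``atol``/``rtol`` defaults are
made up for the example, and real outputs are compared tensor by tensor rather
than on a single array.

.. code-block:: Python

    import numpy as np

    def compare_outputs(expected: np.ndarray, got: np.ndarray,
                        atol: float = 1e-3, rtol: float = 1e-2) -> dict:
        """Hypothetical sketch of one discrepancy row: max abs/rel difference."""
        # Maximum absolute element-wise difference across the output.
        abs_diff = float(np.abs(expected - got).max())
        # Maximum relative difference; a small epsilon avoids division by zero.
        rel_diff = float((np.abs(expected - got) / (np.abs(expected) + 1e-8)).max())
        return {"abs": abs_diff, "rel": rel_diff,
                "SUCCESS": abs_diff <= atol and rel_diff <= rtol}

    torch_logits = np.array([1.0, -2.0, 0.5])   # stand-in for the PyTorch output
    ort_logits = np.array([1.00001, -2.00002, 0.5])  # stand-in for the ONNX Runtime output
    print(compare_outputs(torch_logits, ort_logits))

A row is marked ``SUCCESS`` only when both thresholds hold, which is why the
output below reports an ``abs`` and a ``rel`` value per captured input set.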
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [validate_model] loading config for 'arnir0/Tiny-LLM'
    [validate_model] loading tokenizer for 'arnir0/Tiny-LLM'
    [validate_model] loading model for 'arnir0/Tiny-LLM'
    Loading weights: 0%| | 0/12 [00:00 '/tmp/tmp_3tykcu_/arnir0-Tiny-LLM.yobx.22.default.patch.onnx'
    [to_onnx] build the graph module from , type(args)=
    [to_onnx] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
    [_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', ))
    [_make_builder_interpreter] input args=()
    [_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,position_ids:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),logits_to_keep:int)
    [_make_builder_interpreter] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
    [_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
    [ExportOptions.export] ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', )) - export 'LlamaForCausalLM'
    [ExportOptions.export] torch.export.export strict=False
    [ExportOptions.export] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
    [ExportOptions.export] args=()
    [ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,position_ids:T7r2,past_key_values:DynamicCache(key_cache=#1[T1r4], value_cache=#1[T1r4]),logits_to_keep:int)
    [ExportOptions.export] export start with strict=False backed_size_oblivious=auto
    [ExportOptions.export] export done in 2.8982437800004845
    [ExportOptions.export] post_process_exported_program with decomposition_table=None
    [ExportOptions.export] remove inplace nodes
    [ExportOptions.export] slices: 9 slices nodes were removed
    [CustomTracer.remove_inplace] starts with 175 nodes (n_inplace_submobules=0)
    [CustomTracer.remove_inplace] S1: 15 inplace nodes
    [CustomTracer.remove_inplace] S2: 7 inplace nodes and 100 iterations
    [CustomTracer.remove_inplace] end with 100 iterations and 144 nodes (n_inplace=7)
    [ExportOptions.export] inplaces: 15 inplaced nodes were removed
    [ExportOptions.export] done remove inplace in 0.010209355999904801, modified=15
    [ExportOptions.export] done with no decomposition in 0.010503989000426373
    [to_onnx] graph module done in 2.915339342000152 s
    [to_onnx] start creating the onnx nodes
    [to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
    0%| | 0/144 [00:00 '/tmp/tmp_3tykcu_/arnir0-Tiny-LLM.yobx.22.default.patch.onnx'
    [validate_model] checking discrepancies ...
    [validate_model] discrepancies: 3/3 OK
    [0] OK abs=1.24e-05 rel=0.000433
    [1] OK abs=1e-05 rel=0.00141
    [2] OK abs=1.05e-05 rel=0.00184

.. GENERATED FROM PYTHON SOURCE LINES 113-120

Summary
-------

:class:`ValidateSummary <yobx.torch.validate.ValidateSummary>` stores high-level
status flags and error messages. Every field that was not reached (e.g. because
an earlier step failed) remains ``None`` and is omitted from :meth:`items`.

.. GENERATED FROM PYTHON SOURCE LINES 120-125

.. code-block:: Python

    print("-- summary --")
    for k, v in sorted(summary.items()):
        print(f"  {k}: {v}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    -- summary --
      config_from_cache: bundled
      discrepancies: OK
      discrepancies_ok: 3
      discrepancies_total: 3
      export: OK
      model_from_cache: True
      model_id: arnir0/Tiny-LLM
      n_captured: 3
      prompt: Continue: it rains, what should I do?

.. GENERATED FROM PYTHON SOURCE LINES 126-146

Discrepancies
-------------

:attr:`ValidateData.discrepancies <yobx.torch.validate.ValidateData.discrepancies>`
is the raw list of dicts returned by :meth:`InputObserver.check_discrepancies`.
Each row corresponds to one forward call captured during ``model.generate``.
The most important columns are:

* ``SUCCESS`` — ``True`` when the absolute difference is below *atol* and the
  relative difference is below *rtol*.
* ``abs`` — maximum absolute element-wise difference across all outputs.
* ``rel`` — maximum relative element-wise difference.
* ``index`` — position in the capture sequence (0 = first forward call,
  i.e. the prefill step).
* ``inputs`` / ``outputs_torch`` / ``outputs_ort`` — shape strings for the
  feeds and outputs.

A :class:`pandas.DataFrame` gives a compact overview:

.. GENERATED FROM PYTHON SOURCE LINES 146-155

.. code-block:: Python

    if data.discrepancies is not None:
        df = pandas.DataFrame(data.discrepancies)
        # Only show the most informative columns if they exist.
        cols = [c for c in ("index", "SUCCESS", "abs", "rel", "n_inputs") if c in df.columns]
        print(df[cols].to_string(index=False))
    else:
        print("(no discrepancy data — export may have failed)")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

     index  SUCCESS       abs       rel  n_inputs
         0     True  0.000012  0.000433         5
         1     True  0.000010  0.001411         5
         2     True  0.000010  0.001841         5
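When some rows fail the check, it helps to look at only those rows. The sketch
below filters a discrepancy table with plain pandas; the ``records`` list is a
synthetic stand-in for ``data.discrepancies`` (the run above has no failures to
show), and the column names mirror the ones described earlier.

.. code-block:: Python

    import pandas

    # Synthetic stand-in for data.discrepancies with one failing row.
    records = [
        {"index": 0, "SUCCESS": True,  "abs": 1.2e-05, "rel": 4.3e-04},
        {"index": 1, "SUCCESS": False, "abs": 3.1e-02, "rel": 8.9e-01},
        {"index": 2, "SUCCESS": True,  "abs": 1.0e-05, "rel": 1.8e-03},
    ]
    df = pandas.DataFrame(records)

    # Keep only the failing rows and show the largest absolute error first.
    failing = df[~df["SUCCESS"]].sort_values("abs", ascending=False)
    print(failing.to_string(index=False))

Sorting by ``abs`` surfaces the worst step first, which is usually the one
worth re-running with a higher *verbose* level.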
.. GENERATED FROM PYTHON SOURCE LINES 156-170

Interpreting the results
------------------------

A typical successful run shows ``SUCCESS=True`` for every row with very small
differences (in the run above, ``abs`` stays around ``1e-5`` and ``rel`` below
``2e-3``).

When the export fails, ``summary.export`` will be ``"FAILED"`` and
``summary.error_export`` will contain the exception message. The discrepancy
check is skipped in that case.

When the export succeeds but outputs diverge, ``summary.discrepancies`` will be
``"FAILED"`` with ``summary.discrepancies_ok < summary.discrepancies_total``.
Increase *verbose* to 3 to print the input and output shapes for every failing
row.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 7.736 seconds)

.. _sphx_glr_download_auto_examples_transformers_plot_validate_model.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_validate_model.ipynb <plot_validate_model.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_validate_model.py <plot_validate_model.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: plot_validate_model.zip <plot_validate_model.zip>`

.. include:: plot_validate_model.recommendations

.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_