yobx.torch.validate#

Validates an ONNX export for a HuggingFace model using InputObserver and a default text prompt.

Unlike onnx_diagnostic.torch_models.validate.validate_model(), which uses random/dummy tensors as inputs, this module captures real model inputs by running the model on a default prompt through an InputObserver. The observed inputs and inferred dynamic shapes are then used for the ONNX export.

yobx.torch.validate.DEFAULT_PROMPT = 'Continue: it rains, what should I do?'#

Default text prompt used when validating text-generation models.

class yobx.torch.validate.ValidateData(config: Any | None = None, model: Any | None = None, input_ids: Any | None = None, attention_mask: Any | None = None, observer: Any | None = None, kwargs: Dict[str, Any] | None = None, dynamic_shapes: Any | None = None, filename: str | None = None, discrepancies: List[Dict[str, Any]] | None = None)[source]#

Intermediate artefacts collected by validate_model().

All fields default to None and are populated progressively as validation proceeds.

attention_mask: Any | None = None#

Attention mask tensor used during capture.

config: Any | None = None#

Loaded transformers config object.

discrepancies: List[Dict[str, Any]] | None = None#

Per-input-set discrepancy records from InputObserver.check_discrepancies().

dynamic_shapes: Any | None = None#

Inferred dynamic shapes passed to the exporter.

filename: str | None = None#

Path to the exported .onnx file.

input_ids: Any | None = None#

Input token ids tensor used during capture.

items()[source]#

Yield (field_name, value) pairs for every non-None field.

This mirrors dict.items() so that existing code such as for k, v in sorted(data.items()) keeps working without modification.

kwargs: Dict[str, Any] | None = None#

Inferred export keyword arguments.

model: Any | None = None#

Loaded (or randomly-initialised) PyTorch model.

observer: Any | None = None#

InputObserver instance after capture.

class yobx.torch.validate.ValidateSummary(model_id: str, prompt: str, config_from_cache: str | bool | None = None, config_overrides: str | None = None, error_config: str | None = None, error_tokenizer: str | None = None, model_from_cache: bool | None = None, error_model: str | None = None, n_captured: int | None = None, error_observer: str | None = None, export: str | None = None, error_export: str | None = None, discrepancies_ok: int | None = None, discrepancies_total: int | None = None, discrepancies: str | None = None, error_discrepancies: str | None = None)[source]#

Flat, dictionary-like summary returned by validate_model().

Contains status flags and error messages collected during validation. Fields that were not reached (e.g. because an earlier step failed) remain None.

Parameters:
  • model_id – HuggingFace model identifier

  • prompt – Text prompt used during validation

  • config_from_cache – "bundled" / "local" when the config was loaded from cache, False for network.

  • config_overrides – String representation of the config overrides applied.

  • error_config – Error message if config loading failed.

  • error_tokenizer – Error message if tokenizer loading failed.

  • model_from_cache – True when the model was loaded from the local HF cache.

  • error_model – Error message if model loading failed.

  • n_captured – Number of input sets captured by the InputObserver.

  • error_observer – Error message if input capture failed.

  • export – "OK" or "FAILED" depending on whether the ONNX export succeeded.

  • error_export – Error message if the ONNX export failed.

  • discrepancies_ok – Number of input sets where ONNX Runtime results matched PyTorch.

  • discrepancies_total – Total number of input sets checked for discrepancies.

  • discrepancies – "OK" or "FAILED" for the overall discrepancy check.

  • error_discrepancies – Error message if the discrepancy check raised an exception.
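Because fields belonging to steps that were never reached stay None, the first populated error_* field identifies the step that failed. A small sketch of that inspection, using a plain dict as a stand-in for the summary (the helper name and field ordering are assumptions for illustration):

```python
def first_failed_step(summary_items):
    """Return the name of the first error_* field carrying a message,
    or None when every step either succeeded or was not reached."""
    # Order matters: earlier pipeline steps come first.
    order = [
        "error_config", "error_tokenizer", "error_model",
        "error_observer", "error_export", "error_discrepancies",
    ]
    d = dict(summary_items)
    for name in order:
        if d.get(name) is not None:
            return name
    return None

# Example: model loading failed, so later steps never ran.
summary = {"model_id": "some/model", "error_model": "OSError: not found"}
print(first_failed_step(summary.items()))
```

This works against anything exposing items(), including ValidateSummary itself.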

get(key: str, default: Any = None) → Any[source]#

Return the value for key if present; otherwise default.

items()[source]#

Yield (field_name, value) pairs for every non-None field.

This mirrors dict.items() so that existing code such as for k, v in sorted(summary.items()) keeps working without modification.

keys()[source]#

Return an iterator over field names, mirroring dict.keys().

yobx.torch.validate.validate_model(model_id: str, prompt: str = 'Continue: it rains, what should I do?', exporter: str = 'yobx', optimization: str | None = 'default', verbose: int = 0, dump_folder: str | None = None, opset: int = 22, dtype: str | None = None, device: str | None = None, max_new_tokens: int = 10, do_run: bool = True, patch: bool = True, quiet: bool = False, tokenized_inputs: Dict[str, Any] | None = None, config_overrides: Dict[str, Any] | None = None, random_weights: bool = False) → Tuple[ValidateSummary, ValidateData][source]#

Validates an ONNX export for any HuggingFace model_id by capturing real model inputs with InputObserver and a text prompt, instead of relying on dummy / random tensors.

The function:

  1. Loads the model config for model_id (cached copy preferred).

  2. Optionally applies config_overrides to tweak the architecture (e.g. reduce num_hidden_layers for a faster test).

  3. Loads the tokenizer and the pretrained model weights, or, when random_weights is True, instantiates the model directly from the (potentially modified) config without downloading any weights.

  4. Runs model.generate() with prompt inside an InputObserver context to collect real input/output tensors.

  5. Exports the model to ONNX using the observed inputs and the inferred dynamic shapes.

  6. Computes discrepancies between the original PyTorch outputs and the ONNX runtime outputs for every captured input set.

Parameters:
  • model_id – HuggingFace model id, e.g. "arnir0/Tiny-LLM"

  • prompt – Text prompt used to drive the generation step. Defaults to DEFAULT_PROMPT.

  • exporter – ONNX exporter to use, e.g. "yobx" (default), "modelbuilder", or "onnx-dynamo".

  • optimization – Optimisation level applied after export. Passed directly to yobx.torch.to_onnx(). None means no optimisation; "default" applies the default set.

  • verbose – Verbosity level (0 = silent).

  • dump_folder – When given, all artefacts (ONNX file, export logs …) are saved under this directory.

  • opset – ONNX opset version to target (default 22).

  • dtype – Cast the model (and inputs) to this dtype before exporting, e.g. "float16". None keeps the default (float32).

  • device – Run on this device, e.g. "cpu" or "cuda". None defaults to CPU.

  • max_new_tokens – Number of tokens generated by model.generate() during input capture (default 10). Larger values capture more past-key-value shapes.

  • do_run – When True (default), checks that the ONNX model can be run after export and computes discrepancies.

  • patch – Apply apply_patches_for_model and register_flattening_functions during export (default True).

  • quiet – When True, exceptions are caught and reported in the returned summary dictionary rather than re-raised.

  • tokenized_inputs – Optional pre-tokenized inputs to use instead of running the tokenizer on prompt. Should be a dict with at least "input_ids" and optionally "attention_mask" (mirrors the output of a HuggingFace tokenizer). When provided the tokenizer is not loaded and prompt is only stored in the summary for reference.

  • config_overrides – Optional mapping of config attribute names to new values applied to the model config before the model is instantiated, e.g. {"num_hidden_layers": 2}. Useful to create a smaller model for testing without changing the architecture definition on disk.

  • random_weights – When True, instantiate the model from the (possibly modified) config with random weights instead of downloading the pretrained weights. This avoids any network access for the model itself, which is useful for fast unit-testing or CI validation.
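Conceptually, config_overrides amounts to setting attributes on the loaded config object before the model is built. A minimal sketch of that idea with a generic object (the helper and its typo check are assumptions for illustration, not necessarily what validate_model does internally):

```python
from types import SimpleNamespace

def apply_overrides(config, overrides):
    # Set each override as an attribute on the config object.
    # Unknown names are rejected to catch typos early (an assumed
    # safeguard, not documented behaviour of validate_model).
    for name, value in overrides.items():
        if not hasattr(config, name):
            raise AttributeError(f"config has no attribute {name!r}")
        setattr(config, name, value)
    return config

# Shrink a toy config for a fast test run:
cfg = SimpleNamespace(num_hidden_layers=22, hidden_size=2048)
apply_overrides(cfg, {"num_hidden_layers": 2})
print(cfg.num_hidden_layers)
```

Combined with random_weights=True, such a shrunken config lets the whole export pipeline be exercised in CI without any network access.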

Returns:

A 2-tuple (summary, data) where summary is a ValidateSummary instance with status flags and error messages, and data is a ValidateData instance that collects all intermediate artefacts.

Example:

from yobx.torch.validate import validate_model

summary, data = validate_model("arnir0/Tiny-LLM", verbose=1)
for k, v in sorted(summary.items()):
    print(f":{k},{v};")