onnx_diagnostic.investigate.input_observer

A version of this module also lives in torch/onnx/_internal/exporter/_input_observer.

class onnx_diagnostic.investigate.input_observer.InputCandidate(args: tuple[Any, ...], kwargs: dict[str, Any], clone: bool, cst_kwargs: dict[str, int | str | float | bool])[source][source]

Retains one set of inputs given to the forward method, or to any other method the class InputObserver steals inputs from. Any input type is allowed as long as it can be flattened.

Args:
args:

Positional arguments.

kwargs:

Optional arguments.

clone:

If True, clones the inputs before storing them. Some tensors may be modified in place, so the original values must be retained.

cst_kwargs:

Optional arguments that remain constant over multiple calls. int, float, str, and bool values must be stored here.

The constructor flattens the received arguments. Any necessary flattening function should have been registered first.
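The flattening is pytree-style: nested containers are walked recursively and every leaf is collected into a flat list. A minimal pure-Python sketch of the idea (flatten_inputs is a hypothetical stand-in, not this library's API; the real implementation relies on torch's pytree machinery, which is why custom classes must register a flatten function first):

```python
def flatten_inputs(obj):
    """Recursively collect leaves from nested tuples, lists and dicts.

    Stand-in for pytree flattening: tensors (or None) are the leaves;
    an unregistered custom class could not be traversed this way.
    """
    if isinstance(obj, (tuple, list)):
        flat = []
        for item in obj:
            flat.extend(flatten_inputs(item))
        return flat
    if isinstance(obj, dict):
        flat = []
        for key in sorted(obj):  # deterministic key order
            flat.extend(flatten_inputs(obj[key]))
        return flat
    return [obj]  # a leaf

args = (("t0", "t1"), {"mask": "t2", "ids": None})
print(flatten_inputs(args))  # ['t0', 't1', None, 't2']
```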

align_with(best_candidate: InputCandidate, captured_inputs: dict[int | str, int], signature_names: list[str])[source][source]

Two candidates are considered aligned if, after being flattened, they have the same number of tensors (None values allowed).
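The alignment idea can be sketched in pure Python under simplified assumptions (align is a hypothetical helper, not the method's real signature; the real code maps flat positions back to args/kwargs): a candidate missing some optional inputs is padded with None at the positions the best candidate does fill.

```python
def align(flat, present_names, all_names):
    """Insert None for every input of the best candidate missing here.

    `flat` holds one value per name in `present_names`; the result has
    one entry per name in `all_names`, None where the input was absent.
    """
    by_name = dict(zip(present_names, flat))
    return [by_name.get(name) for name in all_names]

best_names = ["input_ids", "attention_mask", "past_key_values"]
short = align(["ids", "mask"], ["input_ids", "attention_mask"], best_names)
print(short)  # ['ids', 'mask', None]
```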

property n_tensors_for_args_kwargs: dict[int | str, int]

Returns the number of flat tensors contained in each positional or named argument.

property position_to_args_kwargs: list[int | str]

Returns the corresponding args or kwargs for every tensor in the flattened inputs.

str_obs() str[source][source]

Returns a string with some information about the observations.

class onnx_diagnostic.investigate.input_observer.InputObserver(missing: dict[str, Any] | None = None)[source][source]

Steals the forward method to collect inputs and outputs. This information is used to infer dynamic shapes and export arguments.

Args:
missing: If a named argument (in kwargs) is missing,

a default value is taken from this dictionary. This is used when, after the prefill step, an argument disappears (such as pixel_values) and another one appears (such as past_key_values). The values are only used to infer dynamic shapes and arguments, not to run the model.

Examples

>>> input_observer = InputObserver()
>>> with input_observer(model):
...     model(x1, y1)
...     model(x2, y2)
>>> ep = torch.export.export(  # or torch.onnx.export
...     model,
...     input_observer.infer_arguments(),
...     dynamic_shapes=input_observer.infer_dynamic_shapes(),
... )

With LLM:

>>> input_observer = InputObserver()
>>> with input_observer(model):
...     model.generate(input_ids)
>>> ep = torch.export.export(  # or torch.onnx.export
...     model,
...     (),
...     kwargs=input_observer.infer_arguments(),
...     dynamic_shapes=input_observer.infer_dynamic_shapes(),
... )

Examples can be found in Export a LLM with InputObserver (with Tiny-LLM), Export whisper-tiny with InputObserver, Export Gemma3 tiny random with InputObserver.

check_discrepancies(onnx_model: str | ModelProto, atol: float = 0.0001, rtol: float = 0.1, hist=(0.1, 0.01), progress_bar: bool = False, include_io: bool = True) list[dict[str, str | int | float | bool]][source][source]

Computes the discrepancies between the saved outputs and those produced by the given ONNX model on the saved inputs.

Args:
onnx_model:

ONNX Model to verify.

atol:

Absolute tolerance. Recommended values: 1e-4 for float32, 1e-2 for float16.

rtol:

Relative tolerance.

hist:

Thresholds: the function counts the number of discrepancies above each of them.

progress_bar:

Shows a progress bar (requires tqdm).

include_io:

Shows inputs/outputs shapes in the summary returned by this function.

Returns:

A list of dictionaries, ready to be consumed by a dataframe.

The function catches exceptions and reports the error in the returned summary.
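Since the rows are plain dictionaries, they can be fed to pandas.DataFrame or filtered by hand. A sketch with made-up rows (the column names used here are assumptions for illustration, not the function's documented schema):

```python
# Hypothetical rows mimicking the shape of the returned summary.
rows = [
    {"obs": 0, "output": "logits", "abs_err": 3e-5, "rel_err": 1e-3},
    {"obs": 1, "output": "logits", "abs_err": 2e-4, "rel_err": 5e-2},
]

atol = 1e-4
failing = [r for r in rows if r["abs_err"] > atol]
print(len(failing))  # 1
print(max(r["abs_err"] for r in rows))  # 0.0002
```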

infer_arguments(index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None, flat: bool = False) list[Tensor] | tuple[Tensor, ...] | dict[str, Tensor][source][source]

Infers arguments based on the collected tensors.

Args:
index_or_args_or_kwargs: If missing, the method selects one set of inputs

among the available ones, usually the stored set with the highest number of tensors. It then replaces None values and missing tensors with empty tensors. If given, it can be an integer indexing one of the stored sets, or an explicit set of inputs.

flat: If True, it returns a flattened list of tensors,

if False, it returns a tuple or a dictionary preserving the nested structures.

Returns:

Inferred arguments; every optional tensor is replaced by an empty tensor.
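The replacement of optional inputs can be sketched as follows (fill_missing is a hypothetical helper, not the library's API; the real method builds actual empty torch tensors matching the expected dtype and device):

```python
def fill_missing(flat, make_empty=lambda: "empty"):
    """Replace None entries by an empty placeholder so the exported
    signature always receives a value for every optional input."""
    return [make_empty() if t is None else t for t in flat]

print(fill_missing(["t0", None, "t2"]))  # ['t0', 'empty', 't2']
```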

infer_dynamic_shapes(set_batch_dimension_for: set[int | str] | bool | None = None) tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None][source][source]

Infers dynamic shapes. Most models support a batch dimension, but that dimension often has the same value for every observed input sample, so it cannot be detected as dynamic. Instead of running inference on new samples, the argument set_batch_dimension_for can be used to declare the first dimension dynamic for a particular set of inputs referenced by their name (str) or their position (int).

Args:
set_batch_dimension_for (set[int | str] | None): A set of input

identifiers (by position as int or by name as str) for which the first dimension should be treated as a dynamic batch dimension. If None, no dimensions are explicitly marked as dynamic.
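The inference itself boils down to comparing shapes across observations: any axis whose size varies between calls becomes dynamic. A pure-Python sketch of that rule (infer_dynamic_axes is hypothetical; the real method additionally honors set_batch_dimension_for and torch's dynamic-shape objects):

```python
def infer_dynamic_axes(shapes):
    """Given the shapes one input took over several calls, return the
    axes whose size changed; those are the dynamic dimensions."""
    first = shapes[0]
    return {
        axis: f"dim_{axis}"
        for axis in range(len(first))
        if any(s[axis] != first[axis] for s in shapes[1:])
    }

# input_ids seen as (2, 7) at prefill then (2, 1) at each decoding step:
# the sequence axis varies, the batch axis does not.
print(infer_dynamic_axes([(2, 7), (2, 1), (2, 1)]))  # {1: 'dim_1'}
```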

num_obs() int[source][source]

Returns the number of stored sets of inputs.

class onnx_diagnostic.investigate.input_observer.InputObserverInfo(signature_names: list[str], default_values: dict[str, int | bool | str | float], missing: dict[str, Any])[source][source]

Contains all the necessary information to infer dynamic shapes and the arguments to send to torch.export.export().

Args:
signature_names: Names of the arguments of the method

the collected tensors come from. They are used if it becomes necessary to move positional arguments to named ones. They are used a second time because torch.export.export() cares about the order of kwargs and dynamic shapes: it must match the order of the ordered dictionaries add_inputs receives.

default_values: Default values defined by the signature of the function;

any value equal to its default is ignored to simplify the export.

missing: If a named argument (in kwargs) is missing,

a default value is taken from this dictionary. This is used when, after the prefill step, an argument disappears (such as pixel_values) and another one appears (such as past_key_values). The values are only used to infer dynamic shapes and arguments, not to run the model.

add_inputs(args: tuple[Any, ...], kwargs: dict[str, Any])[source][source]

Stores one set of inputs. They are deepcopied.

Args:

args:

Positional arguments.

kwargs:

Named arguments.

add_outputs(res: Tensor | tuple[Tensor, ...], latency: float)[source][source]

Stores outputs. They are deepcopied.

align_inputs_none_values()[source][source]

Once the best candidate is chosen, this method aligns every set of inputs with it, inserting None at the right positions when optional inputs are not specified. A set of inputs is considered aligned if this method does not change its original flattened inputs.

infer_arguments(index_or_candidate: InputCandidate | int | None = None, flat: bool = False) list[Tensor] | tuple[Tensor, ...] | dict[str, Tensor][source][source]

Infers arguments based on the collected tensors.

infer_dynamic_shapes(set_batch_dimension_for: set[int | str] | bool | None = None, return_flat: bool = False) tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None][source][source]

Infers dynamic shapes based on the collected tensors. Most models support a batch dimension, but that dimension often has the same value for every observed input sample, so it cannot be detected as dynamic. Instead of running inference on new samples, the argument set_batch_dimension_for can be used to declare the first dimension dynamic for a particular set of inputs referenced by their name (str) or their position (int).

Args:
set_batch_dimension_for (set[int | str] | None): Set of input identifiers,

by name (str) or position (int), for which the first dimension should be treated as a dynamic batch dimension. If None or empty, no additional batch dimensions are marked as dynamic.

return_flat: Tells the function to return a flat tuple instead of

nested structures.