Source code for onnx_diagnostic.investigate.input_observer

import contextlib
import inspect
import time
from typing import Any, Callable, Sequence
import onnx
import torch
from ..helpers import max_diff, string_type
from ..reference import OnnxruntimeEvaluator

EOL = "\n"


def _flatten_unflatten_for_dynamic_shapes(
    obj: Any,
    use_dict: bool = True,
    change_function: Callable[[torch.Tensor], Any] | None = None,
) -> Any:
    """Returns the object in a different structure similar to what
    the definition of the dynamic shapes should use.

    Args:
        obj:
            object from a custom class
        use_dict:
            keeps a structure closer to the original result, but
            :func:`torch.export.export` only considers the values;
            the context gives the dictionary keys but it is not expressed
            in the dynamic shapes, and these specifications seem to differ
            between the strict and non-strict modes. It also preserves tuples.
        change_function:
            a function applied to every tensor in the structure,
            e.g. to replace it with its shape

    Returns:
        the flattened object
    """
    if isinstance(obj, torch.Tensor):
        return change_function(obj) if change_function else obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = _flatten_unflatten_for_dynamic_shapes(
            value, use_dict=use_dict, change_function=change_function
        )
        subtrees.append(value)
        start = end
    if use_dict:
        if spec.type is dict:
            # This is a dictionary.
            return dict(zip(spec.context, subtrees))
        if spec.type is tuple:
            return tuple(subtrees)
        if spec.type is list:
            return list(subtrees)
        if spec.type is None and not subtrees:
            return None
        if spec.context:
            # This is a custom class with attributes.
            # It is returned as a list.
            return list(subtrees)
        raise ValueError(
            f"Unable to interpret spec type {spec.type} "
            f"(type is {type(spec.type)}, context is {spec.context}), "
            f"spec={spec}, subtrees={subtrees}"
        )
    # This is a list.
    return subtrees

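
# The helper below is not part of the original module; it is a minimal sketch
# showing how _flatten_unflatten_for_dynamic_shapes restructures a nested input:
# dictionaries keep their keys, tuples and lists are preserved, and tensors can
# be replaced on the fly (here by their shape).
def _example_flatten_unflatten_for_dynamic_shapes() -> Any:
    nested = (
        (torch.ones(2, 3),),
        {"mask": [torch.zeros(2, 3), torch.zeros(2, 1)]},
    )
    # Replacing every tensor by its shape makes the returned structure easy to read;
    # under these assumptions the result is (((2, 3),), {"mask": [(2, 3), (2, 1)]}).
    return _flatten_unflatten_for_dynamic_shapes(
        nested, change_function=lambda t: tuple(t.shape)
    )
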

def _infer_dynamic_dimensions(
    shape_list: Sequence[tuple[int, ...]], set_batch_dimension: bool = False
) -> list[int]:
    """Returns the list of dynamic dimensions given a list of shapes
    corresponding to the same tensor.

    Args:
        shape_list:
            list of shapes, they must all have the same length
        set_batch_dimension:
            forces the first dimension to be treated as dynamic,
            even if all shapes have the same value for that dimension

    Returns:
        list of dynamic dimensions
    """
    unique_ranks = {len(shape) for shape in shape_list}
    torch._check(
        len(unique_ranks) == 1,
        lambda: "all shapes in shape_list must have the same rank",
    )
    rank = unique_ranks.pop()
    dynamic = []
    for i in range(rank):
        dims = [shape[i] for shape in shape_list]
        if len(set(dims)) > 1 or (i == 0 and set_batch_dimension):
            dynamic.append(i)
    return dynamic

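
# The helper below is not part of the original module; it is a minimal sketch of
# _infer_dynamic_dimensions on two observed shapes of the same tensor: only the
# last dimension changes, and set_batch_dimension forces dimension 0 as well.
def _example_infer_dynamic_dimensions() -> list[int]:
    shapes = [(1, 8, 16), (1, 8, 32)]
    # Expected result under these assumptions: [0, 2].
    return _infer_dynamic_dimensions(shapes, set_batch_dimension=True)
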

class InputCandidate:
    """Retains one set of inputs given to the forward method or any other method
    the class :class:`InputObserver` is stealing from.
    Any class is allowed as long as it can be flattened.

    Args:
        args:
            Positional arguments.
        kwargs:
            Optional arguments.
        clone:
            Clones the inputs before storing them. Some tensors may be modified
            inplace, the original value must be retained.
        cst_kwargs:
            Any optional arguments constant over multiple calls.
            int, float, str, bool values must be stored here.

    The constructor flattens the received arguments. Any necessary flattening
    function should have been registered first.
    """

    def __init__(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        clone: bool,
        cst_kwargs: dict[str, int | str | float | bool],
    ):
        self.args = args
        self.kwargs = kwargs
        self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs))
        self.n_tensors = sum(t is not None for t in self.flat_list)
        self._position_to_args_kwargs: list[int | str] | None = None
        self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None
        self.cst_kwargs = cst_kwargs.copy()
        assert "use_cache" not in self.cst_kwargs
        if clone:
            self.flat_list = [
                (None if not isinstance(t, torch.Tensor) else t.clone().detach())
                for t in self.flat_list
            ]
            self.args, self.kwargs = torch.utils._pytree.tree_unflatten(
                self.flat_list, self.spec
            )
        self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
        self.aligned_flat_list: list[torch.Tensor | None] | None = None

    def __str__(self) -> str:
        return (
            f"{self.__class__.__name__}({len(self.args)} args, "
            f"{len(self.kwargs)} kwargs, {len(self.flat_list)} tensors, "
            f"{len(self.aligned_flat_list or [])} aligned tensors)"
        )

    def __len__(self) -> int:
        """Returns the number of flattened tensors, None tensors are included."""
        return len(self.flat_list)
    def str_obs(self) -> str:
        """Returns a string describing the observations."""
        return (
            f"InputCandidate(args={string_type(self.args, with_shape=True)}, "
            f"kwargs={string_type(self.kwargs, with_shape=True)}, "
            f"cst_kwargs={self.cst_kwargs})"
        )
    def build_mappings(self) -> list[int | str]:
        if self._position_to_args_kwargs is not None:
            return self._position_to_args_kwargs
        self._n_tensors_for_args_kwargs = {}
        flat_index_to_args: list[int | str] = []
        for index_args, a in enumerate(self.args):
            size = len(torch.utils._pytree.tree_flatten(a)[0])
            self._n_tensors_for_args_kwargs[index_args] = size
            flat_index_to_args.extend([index_args] * size)
        for k, v in self.kwargs.items():
            size = len(torch.utils._pytree.tree_flatten(v)[0])
            self._n_tensors_for_args_kwargs[k] = size
            flat_index_to_args.extend([k] * size)
        self._position_to_args_kwargs = flat_index_to_args
        return self._position_to_args_kwargs

    @property
    def position_to_args_kwargs(self) -> list[int | str]:
        """Returns the corresponding args or kwargs for every tensor
        in the flattened inputs.
        """
        if self._position_to_args_kwargs is None:
            self.build_mappings()
        # type checking is missing it
        assert self._position_to_args_kwargs is not None
        return self._position_to_args_kwargs

    @property
    def n_tensors_for_args_kwargs(self) -> dict[int | str, int]:
        """Returns the number of flat tensors in every args or kwargs."""
        if self._n_tensors_for_args_kwargs is None:
            self.build_mappings()
        # type checking is missing it
        assert self._n_tensors_for_args_kwargs is not None
        return self._n_tensors_for_args_kwargs

    def _set_aligned_flat_list(
        self,
        aligned_flat_list: list[torch.Tensor | None],
        aligned_spec: torch.utils._pytree.PyTreeSpec,
    ):
        self.aligned_flat_list = aligned_flat_list
        self.aligned_spec = aligned_spec
    def align_with(
        self,
        best_candidate: "InputCandidate",
        captured_inputs: dict[int | str, int],
        signature_names: list[str],
    ):
        """Two candidates are considered aligned if, after being flattened,
        they have the same number of tensors (None allowed)."""
        if self.cst_kwargs != best_candidate.cst_kwargs:
            raise RuntimeError(
                f"Two calls were made with different constant values, "
                f"{self.cst_kwargs} != {best_candidate.cst_kwargs}"
            )
        args = self.args
        if len(self.args) > len(best_candidate.args):
            # We need to move some args to kwargs as the best_candidate does.
            new_kwargs = {}
            for i in range(len(best_candidate.args), len(self.args)):
                new_kwargs[signature_names[i]] = args[i]
            args = args[: len(best_candidate.args)]
            kwargs = {**new_kwargs, **self.kwargs}
        else:
            kwargs = self.kwargs

        flat = []
        for i in range(len(best_candidate.args)):
            if i < len(args) and (isinstance(args[i], torch.Tensor) or args[i]):
                ts = torch.utils._pytree.tree_flatten(self.args[i])[0]
                if i in captured_inputs and captured_inputs[i] != len(ts):
                    raise RuntimeError(
                        f"Positional argument {i} has {len(ts)} tensors "
                        f"but previously got {captured_inputs[i]} tensors. "
                        f"Inference is impossible in that case."
                    )
                captured_inputs[i] = len(ts)
                flat.extend(ts)
                continue
            # If the argument i is not specified or is None or an empty container.
            flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[i])])

        for k in best_candidate.kwargs:
            if k in kwargs and (isinstance(kwargs[k], torch.Tensor) or kwargs[k]):
                ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
                if k in captured_inputs and captured_inputs[k] != len(ts):
                    raise RuntimeError(
                        f"Named argument {k!r} has {len(ts)} tensors "
                        f"but previously got {captured_inputs[k]} tensors in "
                        f"kwargs={list(kwargs)}. "
                        f"Inference is impossible in that case."
                    )
                captured_inputs[k] = len(ts)
                flat.extend(ts)
                continue
            # If the argument k is not specified or is None or an empty container.
            flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[k])])

        self._set_aligned_flat_list(flat, best_candidate.spec)
    @property
    def n_aligned_tensors(self) -> int:
        if self.aligned_flat_list is None:
            raise RuntimeError("This input was not aligned with the others.")
        return len(self.aligned_flat_list)
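
# The helper below is not part of the original module; it is a small sketch of
# how an InputCandidate flattens its arguments, with a hypothetical named
# argument "cache". position_to_args_kwargs maps every flattened tensor back
# to the positional index or keyword it came from.
def _example_input_candidate() -> InputCandidate:
    cand = InputCandidate(
        args=(torch.ones(2, 4),),
        kwargs={"cache": [torch.zeros(2, 4), torch.zeros(2, 4)]},
        clone=True,
        cst_kwargs={},
    )
    # One tensor comes from the positional argument, two from "cache".
    assert cand.n_tensors == 3
    assert cand.position_to_args_kwargs == [0, "cache", "cache"]
    return cand
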
class InputObserverInfo:
    """Contains all the necessary information to infer dynamic shapes
    and the arguments to send to :func:`torch.export.export`.

    Args:
        signature_names:
            Names of the arguments of the method the collected tensors come from.
            They are used if it becomes necessary to move positional arguments to
            named ones. They are used a second time because
            :func:`torch.export.export` cares about the order in kwargs and
            dynamic shapes, it needs to be the same in the ordered dictionaries
            `add_inputs` receives.
        default_values:
            Default values defined by the signature of the function, any value
            equal to its default is ignored to simplify the export.
        missing:
            If a named argument (in kwargs) is missing, a default value will be
            taken from this dictionary. This is used when, after the prefill step,
            an argument disappears (such as `pixel_values`) and another one is
            added (such as `past_key_values`). The values are only used to infer
            dynamic shapes and arguments, not to run the model.
    """

    def __init__(
        self,
        signature_names: list[str],
        default_values: dict[str, int | bool | str | float],
        missing: dict[str, Any],
    ):
        self.default_values = default_values
        self.missing = missing
        self.inputs: list[InputCandidate] = []
        self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
        self.flat_outputs: list[list[torch.Tensor | None]] = []
        self.latencies: list[float] = []
        self.signature_names = signature_names
        self._best_candidate: InputCandidate | None = None
        self._captured_inputs: dict[int | str, int] | None = None

    def __len__(self) -> int:
        """Returns the number of collected sets of inputs/outputs."""
        return len(self.inputs)
    def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
        """Stores one set of inputs. They are deepcopied.

        Args:
            args:
                Positional arguments.
            kwargs:
                Named arguments.
        """
        cst_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k in self.signature_names
            and isinstance(v, (int, float, bool, str))
            and v != self.default_values.get(k, None)
            and self.default_values.get(k, None) is not None
        }
        kwargs = {
            k: v
            for k, v in kwargs.items()
            if v is not None and not isinstance(v, (int, float, bool))
        }
        # adds missing attributes
        for k, v in self.missing.items():
            if k not in kwargs:
                kwargs[k] = v
        # kwargs may come in a different order at each call.
        # Dictionaries are ordered and torch.export.export expects
        # dynamic shapes and kwargs to follow the same order.
        ordered_kwargs = {k: kwargs[k] for k in self.signature_names if k in kwargs}
        for k, v in kwargs.items():
            if k not in ordered_kwargs:
                ordered_kwargs[k] = v
        candidate = InputCandidate(args, ordered_kwargs, clone=True, cst_kwargs=cst_kwargs)
        self.inputs.append(candidate)
        if self._best_candidate is None or len(self._best_candidate) < len(candidate):
            self._best_candidate = candidate
    def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...], latency: float):
        """Stores outputs. They are deepcopied."""
        flat_res, spec = torch.utils._pytree.tree_flatten(res)
        self.outputs_specs.append(spec)
        self.flat_outputs.append(
            [(None if t is None else t.clone().detach()) for t in flat_res]
        )
        self.latencies.append(latency)
    def align_inputs_none_values(self):
        """Once the best candidate is chosen, this method aligns every set of
        inputs on the best candidate; it inserts None at the right position when
        optional inputs are not specified.
        A set of inputs is considered aligned if this method does not change the
        original flattened inputs.
        """
        if not self.inputs or self._best_candidate is None:
            raise RuntimeError("No inputs were captured.")
        if all(candidate.aligned_flat_list is not None for candidate in self.inputs):
            # No new inputs, no alignment is necessary.
            return
        # Let's reprocess everything.
        self._captured_inputs = {}
        for candidate in self.inputs:
            if len(set(candidate.kwargs) | set(self._best_candidate.kwargs)) > len(
                self._best_candidate.kwargs
            ):
                raise RuntimeError(
                    f"At least one call to the observed model "
                    f"must contain all the named arguments. "
                    f"candidate kwargs={list(candidate.kwargs)}, "
                    f"best candidate kwargs={list(self._best_candidate.kwargs)}, "
                    f"all candidate kwargs={EOL}"
                    f"{EOL.join(string_type(c.kwargs, with_shape=True) for c in self.inputs)}"
                )
            candidate.align_with(
                self._best_candidate, self._captured_inputs, self.signature_names
            )
    def infer_dynamic_shapes(
        self,
        set_batch_dimension_for: set[int | str] | bool | None = None,
        return_flat: bool = False,
    ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
        """Infers dynamic shapes based on the collected tensors.

        Most of the time, models do support a batch dimension but this batch
        dimension has the same value for every input sample. Instead of running
        inference on new samples, argument `set_batch_dimension_for` can be used
        to tell the first dimension is a dynamic dimension for a particular set
        of inputs referenced by their name (str) or their position (int).

        Args:
            set_batch_dimension_for (set[int | str] | None):
                Set of input identifiers, by name (``str``) or position (``int``),
                for which the first dimension should be treated as a dynamic batch
                dimension. If ``None`` or empty, no additional batch dimensions
                are marked as dynamic.
            return_flat:
                Tells the function to return a flat tuple instead of nested
                structures.
        """
        self.align_inputs_none_values()
        # type checking
        assert self._best_candidate is not None
        assert self._best_candidate.flat_list is not None
        assert self._best_candidate.aligned_flat_list is not None

        def _set_batch_dimension(name_or_position):
            if not set_batch_dimension_for:
                return False
            if (
                isinstance(set_batch_dimension_for, bool) and set_batch_dimension_for
            ) or name_or_position in set_batch_dimension_for:
                return True
            if isinstance(name_or_position, int):
                torch._check(
                    name_or_position < len(self.signature_names),
                    lambda: f"argument at position {name_or_position} is out of boundary",
                )
                if self.signature_names[name_or_position] in set_batch_dimension_for:
                    return True
            return False

        def _set_batch_dimension_for_flat_index(index):
            # type checking
            assert self._best_candidate is not None
            return _set_batch_dimension(self._best_candidate.position_to_args_kwargs[index])

        if len(self._best_candidate.flat_list) != len(self._best_candidate.aligned_flat_list):
            raise NotImplementedError(
                "infer_dynamic_shapes is not implemented "
                "when the best candidate is not 'aligned'. "
                "This happens when there is no stored set of inputs where "
                "all optional inputs appearing in other sets are defined."
            )
        if len({inputs.n_aligned_tensors for inputs in self.inputs}) != 1:
            raise NotImplementedError(
                f"infer_dynamic_shapes is not implemented "
                f"when the number of input tensors is not the same in "
                f"every set of inputs "
                f"{[inputs.n_aligned_tensors for inputs in self.inputs]}."
            )

        shape_lists = [
            [(None if t is None else t.shape) for t in candidate.aligned_flat_list]
            for candidate in self.inputs
            if candidate.aligned_flat_list is not None
        ]
        n_tensors = len(shape_lists[0])
        dynamic_shapes = [
            _infer_dynamic_dimensions(
                [s for s in [shapes[index] for shapes in shape_lists] if s is not None],
                set_batch_dimension=_set_batch_dimension_for_flat_index(index),
            )
            for index in range(n_tensors)
        ]
        cst = torch.export.Dim.DYNAMIC
        flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
        if return_flat:
            return tuple(flat_dynamic_shapes)

        if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
            self._best_candidate.kwargs
        ):
            # It means the forward method is called with tensors only.
            if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
                # only positional arguments
                return tuple(flat_dynamic_shapes)
            if not self._best_candidate.args:
                # only named arguments
                ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes))
                return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
            # positional arguments need to be moved to the named arguments
            n_args = len(self._best_candidate.args)
            pos_names = self.signature_names[:n_args]
            return {
                **dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
                **dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])),
                **dict.fromkeys(self._best_candidate.cst_kwargs, None),
            }

        # Nested types: here comes the fun part because the shapes cannot be
        # unflattened, custom classes must appear in their flattened shape.
        # This does not work in all cases but it does every time every available
        # argument is flattened with the same number of tensors. The function
        # does not check if that assumption is true.
        flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
            (self._best_candidate.args, self._best_candidate.kwargs)
        )
        torch._check(
            len(flat_inputs) == len(flat_dynamic_shapes),
            (
                f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
                f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
            ),
        )
        index = 0

        def change_function(t):
            nonlocal index
            if index >= len(flat_dynamic_shapes):
                raise RuntimeError(
                    f"Flattened {index} tensors when there are only "
                    f"{len(flat_dynamic_shapes)}."
                )
            res = flat_dynamic_shapes[index]
            index += 1
            return res

        ds_args, ds_kwargs = _flatten_unflatten_for_dynamic_shapes(
            (self._best_candidate.args, self._best_candidate.kwargs),
            change_function=change_function,
        )
        if self._best_candidate.cst_kwargs:
            ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
        if not ds_kwargs:
            return tuple(ds_args)
        if not ds_args:
            return ds_kwargs
        pos_names = self.signature_names[: len(ds_args)]
        return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
    def infer_arguments(
        self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
        """Infers arguments based on the collected tensors."""
        # This is already checked by _build_inputs_completed_with_none_values
        # but this is not always well captured by tools checking types.
        self.align_inputs_none_values()
        torch._check(self._best_candidate is not None, lambda: "No input was captured.")
        # type checking
        assert self._best_candidate is not None

        candidate = None
        if index_or_candidate is None:
            for cand in self.inputs:
                args, kwargs = cand.args, cand.kwargs
                if len(args) == len(self._best_candidate.args) and len(kwargs) == len(
                    self._best_candidate.kwargs
                ):
                    candidate = cand
                    break
        elif isinstance(index_or_candidate, int):
            torch._check(
                index_or_candidate < len(self.inputs),
                lambda: (
                    f"No stored input set for index="
                    f"{index_or_candidate}<{len(self.inputs)}."
                ),
            )
            candidate = self.inputs[index_or_candidate]
        else:
            candidate = index_or_candidate
        torch._check(candidate is not None, "No input was captured.")
        # type checking
        assert candidate is not None

        if candidate.aligned_flat_list is None:
            raise RuntimeError(
                f"Candidate {candidate} has no aligned flat list of tensors, "
                f"index_or_candidate={index_or_candidate}. You should call "
                f"method 'align_with'."
            )
        aligned_flat_list = candidate.aligned_flat_list
        if any(t is None for t in aligned_flat_list):
            dynamic_shapes = self.infer_dynamic_shapes(return_flat=True)
            # type checking
            assert isinstance(dynamic_shapes, tuple)
            aligned_flat_list = list(aligned_flat_list)
            for index in range(len(aligned_flat_list)):
                if aligned_flat_list[index] is not None:
                    continue
                shape = dynamic_shapes[index]
                all_non_empty_tensors = [
                    c.aligned_flat_list[index]
                    for c in self.inputs
                    if c.aligned_flat_list is not None
                ]
                all_non_empty_tensors_not_none = [
                    t for t in all_non_empty_tensors if t is not None
                ]
                if not all_non_empty_tensors_not_none:
                    raise RuntimeError(
                        f"There is no tensor at position {index} in any flattened inputs."
                    )
                tensor = all_non_empty_tensors_not_none.pop()
                if tensor.numel() == 0:
                    aligned_flat_list[index] = tensor
                    continue
                if not shape:
                    aligned_flat_list[index] = torch.zeros(
                        tensor.shape, dtype=tensor.dtype, device=tensor.device
                    )
                    continue
                dim = max(shape)
                torch._check(
                    dim < tensor.ndim,
                    lambda index=index, shape=shape, tshape=tensor.shape: (
                        f"Tensor shape {tshape} does not match the "
                        f"dynamic shape {shape} at position {index}."
                    ),
                )
                new_shape = list(tensor.shape)
                new_shape[dim] = 0
                aligned_flat_list[index] = torch.empty(
                    tuple(new_shape), dtype=tensor.dtype, device=tensor.device
                )

        if flat:
            # type checking
            assert all(t is not None for t in aligned_flat_list)
            # pyrefly: ignore[bad-return]
            return aligned_flat_list

        # type checking
        assert candidate is not None
        assert candidate.aligned_spec is not None
        args, kwargs = torch.utils._pytree.tree_unflatten(
            aligned_flat_list, candidate.aligned_spec
        )
        if self._best_candidate.cst_kwargs:
            kwargs = {**kwargs, **self._best_candidate.cst_kwargs}
        if not kwargs:
            return args
        if not args:
            return kwargs
        # We need to move args to kwargs
        pos_names = self.signature_names[: len(args)]
        return {**dict(zip(pos_names, args)), **kwargs}
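
# The helper below is not part of the original module; it is a minimal sketch
# (with hypothetical input names "input_ids" and "attention_mask") showing how
# InputObserverInfo can be fed manually: two calls with different sequence
# lengths are enough for infer_dynamic_shapes to mark dimension 1 as dynamic
# and for infer_arguments to return a representative set of kwargs.
def _example_observer_info() -> tuple[Any, Any]:
    info = InputObserverInfo(
        signature_names=["input_ids", "attention_mask"],
        default_values={},
        missing={},
    )
    info.add_inputs(
        (),
        {"input_ids": torch.ones(1, 3, dtype=torch.int64), "attention_mask": torch.ones(1, 3)},
    )
    info.add_inputs(
        (),
        {"input_ids": torch.ones(1, 5, dtype=torch.int64), "attention_mask": torch.ones(1, 5)},
    )
    ds = info.infer_dynamic_shapes()
    kwargs = info.infer_arguments()
    # ds should be {"input_ids": {1: DYNAMIC}, "attention_mask": {1: DYNAMIC}}.
    return ds, kwargs
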
class InputObserver:
    """Steals the forward method to collect inputs and outputs.
    This information is used to infer dynamic shapes and export arguments.

    Args:
        missing:
            If a named argument (in kwargs) is missing, a default value will be
            taken from this dictionary. This is used when, after the prefill step,
            an argument disappears (such as `pixel_values`) and another one is
            added (such as `past_key_values`). The values are only used to infer
            dynamic shapes and arguments, not to run the model.

    Examples
    --------

    >>> input_observer = InputObserver()
    >>> with input_observer(model):
    >>>     model(x1, y1)
    >>>     model(x2, y2)
    >>> ep = torch.export.export(  # or torch.onnx.export
    >>>     model,
    >>>     input_observer.infer_arguments(),
    >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
    >>> )

    With an LLM:

    >>> input_observer = InputObserver()
    >>> with input_observer(model):
    >>>     model.generate(input_ids)
    >>> ep = torch.export.export(  # or torch.onnx.export
    >>>     model,
    >>>     (),
    >>>     kwargs=input_observer.infer_arguments(),
    >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
    >>> )

    Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
    :ref:`l-plot-whisper-tiny-export-input-observer`,
    :ref:`l-plot-gemma3-tiny-export-input-observer`.
    """

    def __init__(self, missing: dict[str, Any] | None = None):
        self.info: InputObserverInfo | None = None  # type: ignore[annotation-unchecked]
        self.missing = missing or {}

    def _replaced_method(
        self,
        *args,
        _captured_method: Callable | None = None,
        _store_n_calls: int = 3,
        **kwargs,
    ):
        assert _captured_method is not None, "_captured_method cannot be None"
        assert self.info is not None, "info cannot be None"
        n_stored = len(self.info)
        if n_stored < _store_n_calls:
            self.info.add_inputs(args, kwargs)
        begin = time.perf_counter()
        res = _captured_method(*args, **kwargs)
        duration = time.perf_counter() - begin
        if n_stored < _store_n_calls:
            self.info.add_outputs(res, latency=duration)
        return res
    def num_obs(self) -> int:
        """Returns the number of stored sets of inputs."""
        return 0 if not self.info else len(self.info)
    @contextlib.contextmanager
    def __call__(
        self,
        model: torch.nn.Module,
        store_n_calls: int = 3,
        method_name: str = "forward",
    ):
        """Starts collecting inputs and outputs of a specific method.
        The model method is replaced by a new one collecting tensors before and
        after the inner one is called. The original method is restored after the
        collection.

        Args:
            model:
                Model.
            store_n_calls:
                The collection stops after this many calls to avoid taking
                too much memory.
            method_name:
                Method name to spy on.
        """
        if not hasattr(model, method_name):
            raise ValueError(f"Model type {model} does not have a method {method_name!r}.")
        captured_method = getattr(model, method_name)
        sig = inspect.signature(captured_method)
        if self.info is None:
            self.info = InputObserverInfo(
                signature_names=list(sig.parameters),
                default_values={
                    p.name: p.default
                    for p in sig.parameters.values()
                    if p.default != inspect.Parameter.empty
                    and isinstance(p.default, (int, bool, str, float))
                },
                missing=self.missing,
            )
        n_already_stored = len(self.info)
        lambda_method = lambda *args, _cm=captured_method, _snc=(  # noqa: E731
            store_n_calls + n_already_stored
        ), **kwargs: self._replaced_method(
            *args, _captured_method=_cm, _store_n_calls=_snc, **kwargs
        )
        # It may happen that the signature of the forward method is used to trigger
        # a preprocessing. This is used in GenerationMixin (transformers):
        #   position_ids_key = "decoder_position_ids" if ... else "position_ids"
        #   if position_ids_key in set(inspect.signature(self.forward).parameters.keys()):
        lambda_method.__signature__ = sig  # type: ignore[attr-defined]
        setattr(model, method_name, lambda_method)
        try:
            yield self
        finally:
            setattr(model, method_name, captured_method)

    def _check_captured(self):
        if self.info is None:
            raise RuntimeError("No inputs were captured.")
    def infer_dynamic_shapes(
        self, set_batch_dimension_for: set[int | str] | bool | None = None
    ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
        """Infers dynamic shapes.

        Most of the time, models do support a batch dimension but this batch
        dimension has the same value for every input sample. Instead of running
        inference on new samples, argument `set_batch_dimension_for` can be used
        to tell the first dimension is a dynamic dimension for a particular set
        of inputs referenced by their name (str) or their position (int).

        Args:
            set_batch_dimension_for (set[int | str] | None):
                A set of input identifiers (by position as ``int`` or by name
                as ``str``) for which the first dimension should be treated as
                a dynamic batch dimension. If ``None``, no dimensions are
                explicitly marked as dynamic.
        """
        self._check_captured()
        assert self.info is not None  # missed by type checking
        return self.info.infer_dynamic_shapes(set_batch_dimension_for=set_batch_dimension_for)
    def infer_arguments(
        self,
        index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
        flat: bool = False,
    ) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
        """Infers arguments based on the collected tensors.

        Args:
            index_or_args_or_kwargs:
                If missing, the method selects one set of inputs among the
                available ones, usually the stored set with the highest number
                of tensors. It then replaces None values and missing tensors
                with empty tensors. If not missing, it can be an integer to
                fetch one of the stored sets, or some inputs.
            flat:
                If True, it returns a flattened list of tensors; if False,
                it returns a tuple or a dictionary preserving the nested
                structures.

        Returns:
            Inferred arguments, every optional tensor is replaced by an empty tensor.
        """
        self._check_captured()
        assert self.info is not None  # missed by type checking
        index_or_candidate: int | InputCandidate | None = None
        if index_or_args_or_kwargs is None or isinstance(index_or_args_or_kwargs, int):
            index_or_candidate = index_or_args_or_kwargs
        else:
            if isinstance(index_or_args_or_kwargs, tuple):
                index_or_candidate = InputCandidate(
                    args=index_or_args_or_kwargs, kwargs={}, clone=False, cst_kwargs={}
                )
            elif isinstance(index_or_args_or_kwargs, dict):
                index_or_candidate = InputCandidate(
                    args=(),
                    kwargs={
                        k: v
                        for k, v in index_or_args_or_kwargs.items()
                        if k not in self.info.default_values
                    },
                    clone=False,
                    cst_kwargs={
                        k: v
                        for k, v in index_or_args_or_kwargs.items()
                        if k in self.info.default_values
                    },
                )
            else:
                raise ValueError(
                    f"Unexpected type {type(index_or_args_or_kwargs)} "
                    f"for index_or_args_or_kwargs."
                )
            self.info.align_inputs_none_values()
            # type checking
            assert self.info._best_candidate is not None
            assert self.info._captured_inputs is not None
            index_or_candidate.align_with(
                self.info._best_candidate,
                self.info._captured_inputs,
                self.info.signature_names,
            )
        return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat)
    def check_discrepancies(
        self,
        onnx_model: str | onnx.ModelProto,
        atol: float = 1e-4,
        rtol: float = 0.1,
        hist=(0.1, 0.01),
        progress_bar: bool = False,
        include_io: bool = True,
    ) -> list[dict[str, str | int | float | bool]]:
        """Computes the discrepancies between the saved inputs/outputs
        and the saved ONNX model.

        Args:
            onnx_model:
                ONNX model to verify.
            atol:
                Absolute tolerance, recommended values: 1e-4 for float32,
                1e-2 for float16.
            rtol:
                Relative tolerance.
            hist:
                Thresholds, the function determines the number of discrepancies
                above these thresholds.
            progress_bar:
                Shows a progress bar (requires :epkg:`tqdm`).
            include_io:
                Shows inputs/outputs shapes in the summary returned by this
                function.

        Returns:
            A list of dictionaries, ready to be consumed by a dataframe.
            The function catches exceptions, it shows the error in the
            returned summary.
        """
        sess = OnnxruntimeEvaluator(onnx_model, whole=True)
        input_names = sess.input_names
        self._check_captured()
        # type checking
        assert self.info is not None
        assert self.info.inputs is not None
        assert self.info.flat_outputs is not None
        assert self.info.latencies is not None

        io_sets = list(zip(self.info.inputs, self.info.flat_outputs, self.info.latencies))
        if progress_bar:
            from tqdm import tqdm

            loop = tqdm(io_sets)
        else:
            loop = io_sets

        lhist = list(hist)
        data: list[dict[str, Any]] = []
        for inputs, outputs, latency in loop:
            # type checking
            assert inputs.aligned_flat_list is not None
            if len(input_names) != len(inputs.aligned_flat_list):
                raise RuntimeError(
                    f"There are ({len(inputs.aligned_flat_list)}) "
                    f"tensors but the model expects {len(input_names)}."
                )
            n_none = sum([t is None for t in inputs.aligned_flat_list])
            n_empty = sum([t is None or t.numel() == 0 for t in inputs.aligned_flat_list])
            feeds = dict(zip(input_names, self.info.infer_arguments(inputs, flat=True)))
            begin = time.perf_counter()
            try:
                ort_outputs = sess.run(None, feeds)
                error = None
            except Exception as e:
                error = str(e)
                ort_outputs = None
            duration = time.perf_counter() - begin
            if error:
                diff: dict[str, str | int | float | bool] = dict(error=error, SUCCESS=False)
            else:
                # The last output may be empty and torch could skip it.
                if isinstance(outputs, list) and isinstance(ort_outputs, list):
                    while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0:
                        ort_outputs.pop()
                diff = max_diff(outputs, ort_outputs, hist=lhist)  # type: ignore[assignment]
                if "rep" in diff and isinstance(diff["rep"], dict):
                    diff.update(diff["rep"])
                    del diff["rep"]
                diff["SUCCESS"] = (
                    isinstance(diff["abs"], float)
                    and isinstance(diff["rel"], float)
                    and diff["abs"] < atol
                    and diff["rel"] < rtol
                )
            diff.update(
                dict(
                    index=len(diff),
                    duration_torch=latency,
                    ort_duration=duration,
                    n_inputs=len(input_names),
                    n_none=n_none,
                    n_empty=n_empty,
                )
            )
            if include_io:
                diff["inputs"] = string_type(feeds, with_shape=True)
                diff["outputs_torch"] = string_type(outputs, with_shape=True)
                diff["outputs_ort"] = string_type(ort_outputs, with_shape=True)
            data.append(diff)
        return data
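
# The helper below is not part of the original module; it is a hypothetical
# end-to-end sketch assuming a recent PyTorch where torch.onnx.export accepts
# kwargs, dynamic_shapes and dynamo=True, and assuming the exported model keeps
# one ONNX input per observed tensor (required by check_discrepancies).
def _example_check_discrepancies(
    model: torch.nn.Module, x1: torch.Tensor, x2: torch.Tensor
) -> list[dict[str, Any]]:
    observer = InputObserver()
    with observer(model):
        # Two calls with different shapes so dynamic dimensions can be inferred.
        model(x1)
        model(x2)
    args = observer.infer_arguments()
    torch.onnx.export(
        model,
        args if isinstance(args, tuple) else (),
        f="model.onnx",
        kwargs=args if isinstance(args, dict) else None,
        dynamic_shapes=observer.infer_dynamic_shapes(),
        dynamo=True,
    )
    # Replays every recorded set of inputs with onnxruntime and summarizes
    # the differences with the recorded torch outputs.
    return observer.check_discrepancies("model.onnx")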