Source code for onnx_diagnostic.investigate.input_observer
import contextlib
import inspect
import time
from typing import Any, Callable, Sequence
import onnx
import torch
from ..helpers import max_diff, string_type
from ..reference import OnnxruntimeEvaluator
EOL = "\n"
def _flatten_unflatten_for_dynamic_shapes(
obj: Any,
use_dict: bool = True,
change_function: Callable[[torch.Tensor], Any] | None = None,
) -> Any:
"""Returns the object in a different structure similar to what
the definition of the dynamic shapes should use.
Args:
obj:
object from a custom class
        use_dict:
            keeps dictionaries, which is closer to the original result, but
            :func:`torch.export.export` only considers the values;
            the pytree context gives the dictionary keys but they are not expressed
            in the dynamic shapes. These specifications seem to differ
            between strict and non-strict mode. It also preserves tuples.
        change_function:
            a function applied to every tensor in the structure,
            for example to replace it by its shape
Returns:
the flattened object
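
    Example, a minimal sketch assuming a plain nested structure of tensors:

    >>> import torch
    >>> obj = {"a": torch.zeros((2, 3)), "b": [torch.zeros((4,))]}
    >>> _flatten_unflatten_for_dynamic_shapes(
    ...     obj, change_function=lambda t: tuple(t.shape)
    ... )
    {'a': (2, 3), 'b': [(4,)]}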
"""
if isinstance(obj, torch.Tensor):
return change_function(obj) if change_function else obj
flat, spec = torch.utils._pytree.tree_flatten(obj)
start = 0
end = 0
subtrees = []
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = _flatten_unflatten_for_dynamic_shapes(
value, use_dict=use_dict, change_function=change_function
)
subtrees.append(value)
start = end
if use_dict:
if spec.type is dict:
# This is a dictionary.
return dict(zip(spec.context, subtrees))
if spec.type is tuple:
return tuple(subtrees)
if spec.type is list:
return list(subtrees)
if spec.type is None and not subtrees:
return None
if spec.context:
# This is a custom class with attributes.
# It is returned as a list.
return list(subtrees)
raise ValueError(
f"Unable to interpret spec type {spec.type} "
f"(type is {type(spec.type)}, context is {spec.context}), "
f"spec={spec}, subtrees={subtrees}"
)
# This is a list.
return subtrees
def _infer_dynamic_dimensions(
shape_list: Sequence[tuple[int, ...]], set_batch_dimension: bool = False
) -> list[int]:
"""Returns the list of dynamic dimensions given a list of shapes
corresponding to the same tensor.
Args:
shape_list:
list of shapes, they must all have the same length
set_batch_dimension:
forces the first dimension to be treated as dynamic,
even if all shapes have the same value for that dimension
Returns:
list of dynamic dimensions
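
    Example, a minimal sketch with two observed shapes of the same tensor:

    >>> _infer_dynamic_dimensions([(2, 8, 16), (2, 12, 16)])
    [1]
    >>> _infer_dynamic_dimensions([(2, 8, 16), (2, 12, 16)], set_batch_dimension=True)
    [0, 1]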
"""
unique_ranks = {len(shape) for shape in shape_list}
torch._check(
len(unique_ranks) == 1,
lambda: "all shapes in shape_list must have the same rank",
)
rank = unique_ranks.pop()
dynamic = []
for i in range(rank):
dims = [shape[i] for shape in shape_list]
if len(set(dims)) > 1 or (i == 0 and set_batch_dimension):
dynamic.append(i)
return dynamic
class InputCandidate:
"""Retains one set of inputs given to the forward method or any
other method the class :class:`InputObserver` is stealing from.
Any class is allowed as long as it can be flattened.
Args:
args:
Positional arguments.
kwargs:
Optional arguments.
clone:
Clones the inputs before storing them. Some tensors
may be modified inplace, the original value must be retained.
cst_kwargs:
Any optional arguments constant over multiple calls.
int, float, str, bool values must be stored here.
The constructor flattens the received arguments.
Any necessary flattening function should have been registered first.
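
    Example, a minimal sketch with fabricated tensors:

    >>> import torch
    >>> candidate = InputCandidate(
    ...     args=(torch.zeros((2, 3)),),
    ...     kwargs={"mask": torch.ones((2, 3))},
    ...     clone=True,
    ...     cst_kwargs={},
    ... )
    >>> len(candidate)
    2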
"""
def __init__(
self,
args: tuple[Any, ...],
kwargs: dict[str, Any],
clone: bool,
cst_kwargs: dict[str, int | str | float | bool],
):
self.args = args
self.kwargs = kwargs
self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs))
self.n_tensors = sum(t is not None for t in self.flat_list)
self._position_to_args_kwargs: list[int | str] | None = None
self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None
self.cst_kwargs = cst_kwargs.copy()
assert "use_cache" not in self.cst_kwargs
if clone:
self.flat_list = [
(None if not isinstance(t, torch.Tensor) else t.clone().detach())
for t in self.flat_list
]
self.args, self.kwargs = torch.utils._pytree.tree_unflatten(
self.flat_list, self.spec
)
self.aligned_spec: torch.utils._pytree.PyTreeSpec | None = None
self.aligned_flat_list: list[torch.Tensor | None] | None = None
def __str__(self) -> str:
return (
f"{self.__class__.__name__}({len(self.args)} args, "
f"{len(self.kwargs)} kwargs, {len(self.flat_list)} tensors, "
f"{len(self.aligned_flat_list or [])} aligned tensors)"
)
def __len__(self) -> int:
"""Returns the number of flattended tensors, None tensors are included."""
return len(self.flat_list)
def str_obs(self) -> str:
"""Prints out some information about the osbervations."""
return (
f"InputCandidate(args={string_type(self.args, with_shape=True)}, "
f"kwargs={string_type(self.kwargs, with_shape=True)}, "
f"cst_kwargs={self.cst_kwargs})"
)
def build_mappings(self) -> list[int | str]:
if self._position_to_args_kwargs is not None:
return self._position_to_args_kwargs
self._n_tensors_for_args_kwargs = {}
flat_index_to_args: list[int | str] = []
for index_args, a in enumerate(self.args):
size = len(torch.utils._pytree.tree_flatten(a)[0])
self._n_tensors_for_args_kwargs[index_args] = size
flat_index_to_args.extend([index_args] * size)
for k, v in self.kwargs.items():
size = len(torch.utils._pytree.tree_flatten(v)[0])
self._n_tensors_for_args_kwargs[k] = size
flat_index_to_args.extend([k] * size)
self._position_to_args_kwargs = flat_index_to_args
return self._position_to_args_kwargs
@property
def position_to_args_kwargs(self) -> list[int | str]:
"""Returns the corresponding args or kwargs
for every tensor in the flattened inputs.
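
        Example, a minimal sketch where a named argument holds two tensors:

        >>> import torch
        >>> c = InputCandidate(
        ...     args=(torch.zeros(2),),
        ...     kwargs={"cache": [torch.zeros(2), torch.zeros(2)]},
        ...     clone=False,
        ...     cst_kwargs={},
        ... )
        >>> c.position_to_args_kwargs
        [0, 'cache', 'cache']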
"""
if self._position_to_args_kwargs is None:
self.build_mappings()
# type checking is missing it
assert self._position_to_args_kwargs is not None
return self._position_to_args_kwargs
@property
def n_tensors_for_args_kwargs(self) -> dict[int | str, int]:
"""Returns the number of flat tensors in every args or kwargs."""
if self._n_tensors_for_args_kwargs is None:
self.build_mappings()
# type checking is missing it
assert self._n_tensors_for_args_kwargs is not None
return self._n_tensors_for_args_kwargs
def _set_aligned_flat_list(
self,
aligned_flat_list: list[torch.Tensor | None],
aligned_spec: torch.utils._pytree.PyTreeSpec,
):
self.aligned_flat_list = aligned_flat_list
self.aligned_spec = aligned_spec
def align_with(
self,
best_candidate: "InputCandidate",
captured_inputs: dict[int | str, int],
signature_names: list[str],
):
"""Two candidates are considered as aligned if after being flattened
if they have the same number of tensors (None allowed)."""
if self.cst_kwargs != best_candidate.cst_kwargs:
raise RuntimeError(
f"Two calls were made with different constant values, "
f"{self.cst_kwargs} != {best_candidate.cst_kwargs}"
)
args = self.args
if len(self.args) > len(best_candidate.args):
# We need to move some args to kwargs as the best_candidate does.
new_kwargs = {}
for i in range(len(best_candidate.args), len(self.args)):
new_kwargs[signature_names[i]] = args[i]
args = args[: len(best_candidate.args)]
kwargs = {**new_kwargs, **self.kwargs}
else:
kwargs = self.kwargs
flat = []
for i in range(len(best_candidate.args)):
if i < len(args) and (isinstance(args[i], torch.Tensor) or args[i]):
ts = torch.utils._pytree.tree_flatten(self.args[i])[0]
if i in captured_inputs and captured_inputs[i] != len(ts):
raise RuntimeError(
f"Positional argument {i} has {len(ts)} tensors "
f"but previously got {captured_inputs[i]} tensors. "
f"Inference is impossible in that case."
)
captured_inputs[i] = len(ts)
flat.extend(ts)
continue
# If the argument i is not specified or is None or an empty container.
flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[i])])
for k in best_candidate.kwargs:
if k in kwargs and (isinstance(kwargs[k], torch.Tensor) or kwargs[k]):
ts = torch.utils._pytree.tree_flatten(kwargs[k])[0]
if k in captured_inputs and captured_inputs[k] != len(ts):
raise RuntimeError(
f"Named argument {k!r} has {len(ts)} tensors "
f"but previously got {captured_inputs[k]} tensors in "
f"kwargs={list(kwargs)}. "
f"Inference is impossible in that case."
)
captured_inputs[k] = len(ts)
flat.extend(ts)
continue
# If the argument k is not specified or is None or an empty container.
flat.extend([None for _ in range(best_candidate.n_tensors_for_args_kwargs[k])])
self._set_aligned_flat_list(flat, best_candidate.spec)
@property
def n_aligned_tensors(self) -> int:
if self.aligned_flat_list is None:
raise RuntimeError("This input was not aligned with the others.")
return len(self.aligned_flat_list)
class InputObserverInfo:
"""Contains all the necessary information to infer dynamic shapes
and the arguments to send to :func:`torch.export.export`.
Args:
signature_names: Names of the arguments of the method
            the collected tensors come from. They are used if it becomes
            necessary to move positional arguments to named ones.
            They are used a second time because :func:`torch.export.export`
            cares about the order of kwargs and dynamic shapes: it needs
            to be the same as in the ordered dictionaries `add_inputs` receives.
        default_values: Default values defined by the signature of the function;
            any value equal to its default is ignored to simplify the export.
missing: If a named argument (in kwargs) is missing,
            a default value is taken from this dictionary.
            This is used when, after the prefill step, an argument
disappears (such as `pixel_values`) and another one
is added (such as `past_key_values`).
The values are only to infer dynamic shapes and arguments,
not to run the model.
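
    Example, a minimal sketch bypassing :class:`InputObserver`:

    >>> import torch
    >>> info = InputObserverInfo(signature_names=["x"], default_values={}, missing={})
    >>> info.add_inputs((torch.zeros((2, 4)),), {})
    >>> info.add_inputs((torch.zeros((2, 6)),), {})
    >>> info.infer_dynamic_shapes(return_flat=True) == ({1: torch.export.Dim.DYNAMIC},)
    True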
"""
def __init__(
self,
signature_names: list[str],
default_values: dict[str, int | bool | str | float],
missing: dict[str, Any],
):
self.default_values = default_values
self.missing = missing
self.inputs: list[InputCandidate] = []
self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = []
self.flat_outputs: list[list[torch.Tensor | None]] = []
self.latencies: list[float] = []
self.signature_names = signature_names
self._best_candidate: InputCandidate | None = None
self._captured_inputs: dict[int | str, int] | None = None
def __len__(self) -> int:
"""Returns the number of collected set of inputs/outputs."""
return len(self.inputs)
def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]):
"""Stores one set of inputs. They are deepcopied.
Args:
args: Positional arguments.
kwargs: Named arguments.
"""
cst_kwargs = {
k: v
for k, v in kwargs.items()
if k in self.signature_names
and isinstance(v, (int, float, bool, str))
and v != self.default_values.get(k, None)
and self.default_values.get(k, None) is not None
}
kwargs = {
k: v
for k, v in kwargs.items()
if v is not None and not isinstance(v, (int, float, bool))
}
# adds missing attributes
for k, v in self.missing.items():
if k not in kwargs:
kwargs[k] = v
        # kwargs may come in a different order at each call.
        # Dictionaries are ordered and torch.export.export expects
        # dynamic shapes and kwargs to follow the same order.
ordered_kwargs = {k: kwargs[k] for k in self.signature_names if k in kwargs}
for k, v in kwargs.items():
if k not in ordered_kwargs:
ordered_kwargs[k] = v
candidate = InputCandidate(args, ordered_kwargs, clone=True, cst_kwargs=cst_kwargs)
self.inputs.append(candidate)
if self._best_candidate is None or len(self._best_candidate) < len(candidate):
self._best_candidate = candidate
def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...], latency: float):
"""Stores outputs. They are deepcopied."""
flat_res, spec = torch.utils._pytree.tree_flatten(res)
self.outputs_specs.append(spec)
self.flat_outputs.append(
[(None if t is None else t.clone().detach()) for t in flat_res]
)
self.latencies.append(latency)
def align_inputs_none_values(self):
"""Once the best candidate is chosen, this method aligns every set of inputs
on the best candidate, it inserts None at the right position when
optional inputs are not specified. We consider a set of inputs is aligned
if this method does not change the original flattened inputs.
"""
if not self.inputs or self._best_candidate is None:
raise RuntimeError("No inputs were captured.")
if all(candidate.aligned_flat_list is not None for candidate in self.inputs):
# No new inputs, no alignment is necessary.
return
# Let's reprocess everything.
self._captured_inputs = {}
for candidate in self.inputs:
if len(set(candidate.kwargs) | set(self._best_candidate.kwargs)) > len(
self._best_candidate.kwargs
):
raise RuntimeError(
f"At least one call to the observed model "
f"must contain all the named arguments. "
f"candidate kwargs={list(candidate.kwargs)}, "
f"best candidate kwargs={list(self._best_candidate.kwargs)}, "
f"all candidate kwargs={EOL}"
f"{EOL.join(string_type(c.kwargs, with_shape=True) for c in self.inputs)}"
)
candidate.align_with(
self._best_candidate, self._captured_inputs, self.signature_names
)
def infer_dynamic_shapes(
self,
set_batch_dimension_for: set[int | str] | bool | None = None,
return_flat: bool = False,
) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
"""Infers dynamic shapes based on the collected tensors.
Most of the time, models do support a batch dimension
but this batch dimension has the same value for every input sample.
Instead of running inference on new samples, argument `set_batch_dimension_for`
        can be used to declare the first dimension as dynamic for particular
        inputs referenced by their name (str) or their position (int).
Args:
set_batch_dimension_for (set[int | str] | None): Set of input identifiers,
by name (``str``) or position (``int``), for which the first dimension
should be treated as a dynamic batch dimension. If ``None`` or empty,
no additional batch dimensions are marked as dynamic.
return_flat: Tells the function to return a flat tuple instead of
                nested structures.
"""
self.align_inputs_none_values()
# type checking
assert self._best_candidate is not None
assert self._best_candidate.flat_list is not None
assert self._best_candidate.aligned_flat_list is not None
def _set_batch_dimension(name_or_position):
if not set_batch_dimension_for:
return False
if (
isinstance(set_batch_dimension_for, bool) and set_batch_dimension_for
) or name_or_position in set_batch_dimension_for:
return True
if isinstance(name_or_position, int):
torch._check(
name_or_position < len(self.signature_names),
lambda: f"argument at position {name_or_position} is out of boundary",
)
if self.signature_names[name_or_position] in set_batch_dimension_for:
return True
return False
def _set_batch_dimension_for_flat_index(index):
# type checking
assert self._best_candidate is not None
return _set_batch_dimension(self._best_candidate.position_to_args_kwargs[index])
if len(self._best_candidate.flat_list) != len(self._best_candidate.aligned_flat_list):
raise NotImplementedError(
"infer_dynamic_shapes is not implemented "
"when the best candidate is not 'aligned'."
"This happens when there is not stored set inputs where "
"all optional inputs showing in other sets are defined."
)
if len({inputs.n_aligned_tensors for inputs in self.inputs}) != 1:
raise NotImplementedError(
f"infer_dynamic_shapes is not implemented "
f"when the number of input tensors are not the same in "
f"every set of inputs "
f"{[inputs.n_aligned_tensors for inputs in self.inputs]}."
)
shape_lists = [
[(None if t is None else t.shape) for t in candidate.aligned_flat_list]
for candidate in self.inputs
if candidate.aligned_flat_list is not None
]
n_tensors = len(shape_lists[0])
dynamic_shapes = [
_infer_dynamic_dimensions(
[s for s in [shapes[index] for shapes in shape_lists] if s is not None],
set_batch_dimension=_set_batch_dimension_for_flat_index(index),
)
for index in range(n_tensors)
]
cst = torch.export.Dim.DYNAMIC
flat_dynamic_shapes = [dict.fromkeys(dims, cst) for dims in dynamic_shapes]
if return_flat:
return tuple(flat_dynamic_shapes)
if len(flat_dynamic_shapes) == len(self._best_candidate.args) + len(
self._best_candidate.kwargs
):
            # It means the forward method is called with tensors only.
if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs:
# only positional arguments
return tuple(flat_dynamic_shapes)
if not self._best_candidate.args:
# only named arguments
ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes))
return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
            # positional arguments need to be moved to the named arguments
n_args = len(self._best_candidate.args)
pos_names = self.signature_names[:n_args]
return {
**dict(zip(pos_names, flat_dynamic_shapes[:n_args])),
**dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])),
**dict.fromkeys(self._best_candidate.cst_kwargs, None),
}
        # Nested types: here comes the fun part because the shapes cannot be unflattened;
        # custom classes must appear in their flattened form.
        # This does not work in all cases, only when every available argument is flattened
        # into the same number of tensors each time. The function does not check
        # whether that assumption holds.
flat_inputs, _max_spec = torch.utils._pytree.tree_flatten(
(self._best_candidate.args, self._best_candidate.kwargs)
)
torch._check(
len(flat_inputs) == len(flat_dynamic_shapes),
            lambda: (
                f"Length mismatch len(flat_inputs)={len(flat_inputs)}, "
                f"len(flat_dynamic_shapes)={len(flat_dynamic_shapes)}"
            ),
)
index = 0
def change_function(t):
nonlocal index
if index >= len(flat_dynamic_shapes):
raise RuntimeError(
f"Flattened {index} tensors when there are only "
f"{len(flat_dynamic_shapes)}."
)
res = flat_dynamic_shapes[index]
index += 1
return res
ds_args, ds_kwargs = _flatten_unflatten_for_dynamic_shapes(
(self._best_candidate.args, self._best_candidate.kwargs),
change_function=change_function,
)
if self._best_candidate.cst_kwargs:
ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)}
if not ds_kwargs:
return tuple(ds_args)
if not ds_args:
return ds_kwargs
pos_names = self.signature_names[: len(ds_args)]
return {**dict(zip(pos_names, ds_args)), **ds_kwargs}
def infer_arguments(
self, index_or_candidate: InputCandidate | int | None = None, flat: bool = False
) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
"""Infers arguments based on the collected tensors."""
        # This is already checked by align_inputs_none_values
        # but this is not always well captured by tools checking types.
self.align_inputs_none_values()
torch._check(self._best_candidate is not None, lambda: "No input was captured.")
# type checking
assert self._best_candidate is not None
candidate = None
if index_or_candidate is None:
for cand in self.inputs:
args, kwargs = cand.args, cand.kwargs
if len(args) == len(self._best_candidate.args) and len(kwargs) == len(
self._best_candidate.kwargs
):
candidate = cand
break
elif isinstance(index_or_candidate, int):
torch._check(
index_or_candidate < len(self.inputs),
lambda: (
f"No stored input set for index="
f"{index_or_candidate}<{len(self.inputs)}."
),
)
candidate = self.inputs[index_or_candidate]
else:
candidate = index_or_candidate
        torch._check(candidate is not None, lambda: "No input was captured.")
# type checking
assert candidate is not None
if candidate.aligned_flat_list is None:
raise RuntimeError(
f"Candidate {candidate} has no aligned flat list of tensors, "
f"index_or_candidate={index_or_candidate}. You should call "
f"method 'align_with'."
)
aligned_flat_list = candidate.aligned_flat_list
if any(t is None for t in aligned_flat_list):
dynamic_shapes = self.infer_dynamic_shapes(return_flat=True)
# type checking
assert isinstance(dynamic_shapes, tuple)
aligned_flat_list = list(aligned_flat_list)
for index in range(len(aligned_flat_list)):
if aligned_flat_list[index] is not None:
continue
shape = dynamic_shapes[index]
all_non_empty_tensors = [
c.aligned_flat_list[index]
for c in self.inputs
if c.aligned_flat_list is not None
]
all_non_empty_tensors_not_none = [
t for t in all_non_empty_tensors if t is not None
]
if not all_non_empty_tensors_not_none:
raise RuntimeError(
f"There is no tensor at position {index} in any flattened inputs."
)
tensor = all_non_empty_tensors_not_none.pop()
if tensor.numel() == 0:
aligned_flat_list[index] = tensor
continue
if not shape:
aligned_flat_list[index] = torch.zeros(
tensor.shape, dtype=tensor.dtype, device=tensor.device
)
continue
dim = max(shape)
torch._check(
dim < tensor.ndim,
lambda index=index, shape=shape, tshape=tensor.shape: (
f"Tensor shape {tshape} does not match the "
f"dynamic shape {shape} at position {index}."
),
)
new_shape = list(tensor.shape)
new_shape[dim] = 0
aligned_flat_list[index] = torch.empty(
tuple(new_shape), dtype=tensor.dtype, device=tensor.device
)
if flat:
# type checking
assert all(t is not None for t in aligned_flat_list)
# pyrefly: ignore[bad-return]
return aligned_flat_list
# type checking
assert candidate is not None
assert candidate.aligned_spec is not None
args, kwargs = torch.utils._pytree.tree_unflatten(
aligned_flat_list, candidate.aligned_spec
)
if self._best_candidate.cst_kwargs:
kwargs = {**kwargs, **self._best_candidate.cst_kwargs}
if not kwargs:
return args
if not args:
return kwargs
# We need to move args to kwargs
pos_names = self.signature_names[: len(args)]
return {**dict(zip(pos_names, args)), **kwargs}
class InputObserver:
"""Steals forward method to collect inputs and outputs.
This information is used to infer dynamic shapes and
export arguments.
Args:
missing: If a named argument (in kwargs) is missing,
            a default value is taken from this dictionary.
            This is used when, after the prefill step, an argument
disappears (such as `pixel_values`) and another one
is added (such as `past_key_values`).
The values are only to infer dynamic shapes and arguments,
not to run the model.
Examples
--------

    >>> input_observer = InputObserver()
    >>> with input_observer(model):
    ...     model(x1, y1)
    ...     model(x2, y2)
    >>> ep = torch.export.export(  # or torch.onnx.export
    ...     model,
    ...     input_observer.infer_arguments(),
    ...     dynamic_shapes=input_observer.infer_dynamic_shapes(),
    ... )
    With an LLM:

    >>> input_observer = InputObserver()
    >>> with input_observer(model):
    ...     model.generate(input_ids)
    >>> ep = torch.export.export(  # or torch.onnx.export
    ...     model,
    ...     (),
    ...     kwargs=input_observer.infer_arguments(),
    ...     dynamic_shapes=input_observer.infer_dynamic_shapes(),
    ... )
Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
:ref:`l-plot-whisper-tiny-export-input-observer`,
:ref:`l-plot-gemma3-tiny-export-input-observer`.
"""
def __init__(self, missing: dict[str, Any] | None = None):
self.info: InputObserverInfo | None = None # type: ignore[annotation-unchecked]
self.missing = missing or {}
def _replaced_method(
self,
*args,
_captured_method: Callable | None = None,
_store_n_calls: int = 3,
**kwargs,
):
        assert _captured_method is not None, "_captured_method cannot be None"
assert self.info is not None, "info cannot be None"
n_stored = len(self.info)
if n_stored < _store_n_calls:
self.info.add_inputs(args, kwargs)
begin = time.perf_counter()
res = _captured_method(*args, **kwargs)
duration = time.perf_counter() - begin
if n_stored < _store_n_calls:
self.info.add_outputs(res, latency=duration)
return res
def num_obs(self) -> int:
"""Returns the number of stored set if inputs."""
return 0 if not self.info else len(self.info)
@contextlib.contextmanager
def __call__(
self,
model: torch.nn.Module,
store_n_calls: int = 3,
method_name: str = "forward",
):
"""Starts collecting inputs and outputs of a specific method.
The model method is replaced by a new one collecting tensors
before and after the inner one is called.
The original method is restored after the collection.
Args:
model: Model
store_n_calls: The collection stops after this many calls
to avoid taking too much memory.
method_name: Method name to spy on.
"""
if not hasattr(model, method_name):
raise ValueError(f"Model type {model} does not have a method {method_name!r}.")
captured_method = getattr(model, method_name)
sig = inspect.signature(captured_method)
if self.info is None:
self.info = InputObserverInfo(
signature_names=list(sig.parameters),
default_values={
p.name: p.default
for p in sig.parameters.values()
if p.default != inspect.Parameter.empty
and isinstance(p.default, (int, bool, str, float))
},
missing=self.missing,
)
n_already_stored = len(self.info)
lambda_method = lambda *args, _cm=captured_method, _snc=( # noqa: E731
store_n_calls + n_already_stored
), **kwargs: self._replaced_method(
*args, _captured_method=_cm, _store_n_calls=_snc, **kwargs
)
        # It may happen that the signature of the forward method is used to trigger some preprocessing.
# This is used in GenerationMixin (transformers):
# position_ids_key = "decoder_position_ids" if ... else "position_ids"
# if position_ids_key in set(inspect.signature(self.forward).parameters.keys()):
lambda_method.__signature__ = sig # type: ignore[attr-defined]
setattr(model, method_name, lambda_method)
try:
yield self
finally:
setattr(model, method_name, captured_method)
def _check_captured(self):
if self.info is None:
raise RuntimeError("No inputs were captured.")
def infer_dynamic_shapes(
self, set_batch_dimension_for: set[int | str] | bool | None = None
) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]:
"""
Infers dynamic shapes. Most of the time, models do support a batch dimension
but this batch dimension has the same value for every input sample.
Instead of running inference on new samples, argument `set_batch_dimension_for`
        can be used to declare the first dimension as dynamic for particular
        inputs referenced by their name (str) or their position (int).
Args:
set_batch_dimension_for (set[int | str] | None): A set of input
identifiers (by position as ``int`` or by name as ``str``) for
which the first dimension should be treated as a dynamic batch
dimension. If ``None``, no dimensions are explicitly marked as
dynamic.
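
        Example, a minimal sketch on a toy model (``_Toy`` is only illustrative):

        >>> import torch
        >>> class _Toy(torch.nn.Module):
        ...     def forward(self, x, mask=None):
        ...         return x if mask is None else x * mask
        >>> model = _Toy()
        >>> observer = InputObserver()
        >>> with observer(model):
        ...     _ = model(torch.zeros((2, 4)), mask=torch.ones((2, 4)))
        ...     _ = model(torch.zeros((2, 6)), mask=torch.ones((2, 6)))
        >>> ds = observer.infer_dynamic_shapes()
        >>> sorted(ds)  # dimension 1 was detected as dynamic for both inputs
        ['mask', 'x']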
"""
self._check_captured()
assert self.info is not None # missed by type checking
return self.info.infer_dynamic_shapes(set_batch_dimension_for=set_batch_dimension_for)
def infer_arguments(
self,
index_or_args_or_kwargs: tuple[Any] | dict[str, Any] | int | None = None,
flat: bool = False,
) -> list[torch.Tensor] | tuple[torch.Tensor, ...] | dict[str, torch.Tensor]:
"""Infers arguments based on the collected tensors.
Args:
            index_or_args_or_kwargs: If missing, the method selects one set of inputs
                among the available ones, usually the stored set of inputs
                with the highest number of tensors.
                It then replaces None values and missing tensors with empty tensors.
                If not missing, it can be an integer to fetch one of the stored sets,
                or some inputs given as a tuple (args) or a dictionary (kwargs).
flat: If True, it returns a flattened list of tensors,
if False, it returns a tuple or a dictionary preserving
the nested structures.
Returns:
            Inferred arguments, every optional tensor is replaced by an empty tensor.
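
        Example, a minimal sketch on a toy model (``_Toy`` is only illustrative):

        >>> import torch
        >>> class _Toy(torch.nn.Module):
        ...     def forward(self, x, mask=None):
        ...         return x if mask is None else x * mask
        >>> model = _Toy()
        >>> observer = InputObserver()
        >>> with observer(model):
        ...     _ = model(torch.zeros((2, 4)), mask=torch.ones((2, 4)))
        >>> list(observer.infer_arguments())  # positional arguments moved to kwargs
        ['x', 'mask']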
"""
self._check_captured()
assert self.info is not None # missed by type checking
index_or_candidate: int | InputCandidate | None = None
if index_or_args_or_kwargs is None or isinstance(index_or_args_or_kwargs, int):
index_or_candidate = index_or_args_or_kwargs
else:
if isinstance(index_or_args_or_kwargs, tuple):
index_or_candidate = InputCandidate(
args=index_or_args_or_kwargs, kwargs={}, clone=False, cst_kwargs={}
)
elif isinstance(index_or_args_or_kwargs, dict):
index_or_candidate = InputCandidate(
args=(),
kwargs={
k: v
for k, v in index_or_args_or_kwargs.items()
if k not in self.info.default_values
},
clone=False,
cst_kwargs={
k: v
for k, v in index_or_args_or_kwargs.items()
if k in self.info.default_values
},
)
else:
raise ValueError(
f"Unexpected type {type(index_or_args_or_kwargs)} "
f"for index_or_args_or_kwargs."
)
self.info.align_inputs_none_values()
# type checking
assert self.info._best_candidate is not None
assert self.info._captured_inputs is not None
index_or_candidate.align_with(
self.info._best_candidate,
self.info._captured_inputs,
self.info.signature_names,
)
return self.info.infer_arguments(index_or_candidate=index_or_candidate, flat=flat)
def check_discrepancies(
self,
onnx_model: str | onnx.ModelProto,
atol: float = 1e-4,
rtol: float = 0.1,
hist=(0.1, 0.01),
progress_bar: bool = False,
include_io: bool = True,
) -> list[dict[str, str | int | float | bool]]:
"""Computes the discrepancies between the saved inputs and outputs
with the saved onnx model.
Args:
onnx_model:
ONNX Model to verify.
atol:
                Absolute tolerance; recommended values: 1e-4 for float32, 1e-2 for float16.
rtol:
Relative tolerance.
hist:
Thresholds, the function determines the number of discrepancies
above these thresholds.
progress_bar:
Shows a progress bar (requires :epkg:`tqdm`).
include_io:
Shows inputs/outputs shapes in the summary
returned by this function.
Returns:
A list of dictionaries, ready to be consumed by a dataframe.
            The function catches exceptions and reports the error in the
            returned summary.
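
        Example, a minimal sketch assuming ``observer`` already collected inputs and outputs
        (see :class:`InputObserver`) and ``model.onnx`` was exported from the observed model;
        the file name and the use of :epkg:`pandas` are only illustrative:

        >>> summary = observer.check_discrepancies("model.onnx")  # doctest: +SKIP
        >>> import pandas  # doctest: +SKIP
        >>> print(pandas.DataFrame(summary))  # doctest: +SKIP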
"""
sess = OnnxruntimeEvaluator(onnx_model, whole=True)
input_names = sess.input_names
self._check_captured()
# type checking
assert self.info is not None
assert self.info.inputs is not None
assert self.info.flat_outputs is not None
assert self.info.latencies is not None
io_sets = list(zip(self.info.inputs, self.info.flat_outputs, self.info.latencies))
if progress_bar:
from tqdm import tqdm
loop = tqdm(io_sets)
else:
loop = io_sets
lhist = list(hist)
data: list[dict[str, Any]] = []
for inputs, outputs, latency in loop:
# type checking
assert inputs.aligned_flat_list is not None
if len(input_names) != len(inputs.aligned_flat_list):
raise RuntimeError(
f"There are ({len(inputs.aligned_flat_list)}) "
f"tensors but the model expects {len(input_names)}."
)
n_none = sum([t is None for t in inputs.aligned_flat_list])
n_empty = sum([t is None or t.numel() == 0 for t in inputs.aligned_flat_list])
feeds = dict(zip(input_names, self.info.infer_arguments(inputs, flat=True)))
begin = time.perf_counter()
try:
ort_outputs = sess.run(None, feeds)
error = None
except Exception as e:
error = str(e)
ort_outputs = None
duration = time.perf_counter() - begin
if error:
diff: dict[str, str | int | float | bool] = dict(error=error, SUCCESS=False)
else:
# The last output may be empty and torch could skip it.
if isinstance(outputs, list) and isinstance(ort_outputs, list):
while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0:
ort_outputs.pop()
diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[assignment]
if "rep" in diff and isinstance(diff["rep"], dict):
diff.update(diff["rep"])
del diff["rep"]
diff["SUCCESS"] = (
isinstance(diff["abs"], float)
and isinstance(diff["rel"], float)
and diff["abs"] < atol
and diff["rel"] < rtol
)
diff.update(
dict(
                    index=len(data),
duration_torch=latency,
ort_duration=duration,
n_inputs=len(input_names),
n_none=n_none,
n_empty=n_empty,
)
)
if include_io:
diff["inputs"] = string_type(feeds, with_shape=True)
diff["outputs_torch"] = string_type(outputs, with_shape=True)
diff["outputs_ort"] = string_type(ort_outputs, with_shape=True)
data.append(diff)
return data