from typing import Any, Dict, List, Tuple, Union
ReportKeyNameType = Union[str, Tuple[str, int, str]]
ReportKeyValueType = Tuple[int, Tuple[int, ...]]
[docs]
class ReportResultComparison:
"""
Holds tensors a runtime can use as a reference to compare
intermediate results.
See :meth:`onnx_diagnostic.reference.TorchOnnxEvaluator.run`.
:param tensors: tensor
"""
def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]): # noqa: F821
from ..helpers.onnx_helper import dtype_to_tensor_dtype
from ..helpers import max_diff, string_type
assert all(
hasattr(v, "shape") and hasattr(v, "dtype") for v in tensors.values()
), f"One of the tensors is not: {string_type(tensors, with_shape=True)}"
self.dtype_to_tensor_dtype = dtype_to_tensor_dtype
self.max_diff = max_diff
self.tensors = tensors
self._build_mapping()
[docs]
def key(self, tensor: "torch.Tensor") -> ReportKeyValueType: # noqa: F821
"Returns a key for a tensor, (onnx dtype, shape)."
return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
def _build_mapping(self):
mapping = {}
for k, v in self.tensors.items():
key = self.key(v)
if key not in mapping:
mapping[key] = []
mapping[key].append(k)
self.mapping = mapping
self.clear()
[docs]
def clear(self):
"""Clears the last report."""
self.report_cmp = {}
self.unique_run_names = set()
@property
def value(
self,
) -> Dict[Tuple[Tuple[int, str], ReportKeyNameType], Dict[str, Union[float, str]]]:
"Returns the report."
return self.report_cmp
@property
def data(self) -> List[Dict[str, Any]]:
"Returns data which can be consumed by a dataframe."
rows = []
for k, v in self.value.items():
(i_run, run_name), ref_name = k
d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
d.update(v)
rows.append(d)
return rows
[docs]
def report(
self, outputs: Dict[str, "torch.Tensor"] # noqa: F821
) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
"""
For every tensor in outputs, compares it to every tensor held by
this class if it shares the same type and shape. The function returns
the results of the comparison. The function also collects the results
into a dictionary the user can retrieve later.
"""
res: List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]] = []
for name, tensor in outputs.items():
i_run = len(self.unique_run_names)
self.unique_run_names.add(name)
key = self.key(tensor)
if key not in self.mapping:
continue
cache: Dict["torch.device", "torch.Tensor"] = {} # noqa: F821, UP037
for held_key in self.mapping[key]:
t2 = self.tensors[held_key]
if hasattr(t2, "device") and hasattr(tensor, "device"):
if t2.device in cache:
t = cache[t2.device]
else:
cache[t2.device] = t = tensor.to(t2.device)
diff = self.max_diff(t, t2)
else:
diff = self.max_diff(tensor, t2)
res.append((i_run, name, held_key, diff)) # type: ignore[arg-type]
self.report_cmp[(i_run, name), held_key] = diff
return res