import inspect
import itertools
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from ..helpers import string_type, max_diff, string_diff
from ..helpers.torch_helper import torch_deepcopy
from .dynamic_shapes import CoupleInputsDynamicShapes
[docs]
def compare_modules(
    modep: torch.nn.Module,
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> Dict[str, Any]:
    """
    Compares two torch modules, usually one coming from an exported program,
    the other being the origin model.
    :param model: first module
    :param mod: second module (it produces the expected values)
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (they may modify them inplace)
    :param exc: raise exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: dictionary with inputs, outputs and tolerance
    Example:
    .. runpython::
        :showcode:
        import torch
        from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y
        model = Model()
        x = torch.randn((5, 6))
        y = torch.randn((1, 6))
        model(x, y)  # to make it is running
        ds = ({0: "a", 1: "b"}, {1: "b"})
        cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
        ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
        validate_ep(
            ep,
            model,
            args=(x, y),
            verbose=2,
            copy=True,
            dynamic_shapes=ds,
            values_to_try={"a": [5, 10], "b": [10, 20]},
        )
    """
    args = args or ()
    kwargs = kwargs or {}
    def _get(a):
        return torch_deepcopy(a) if copy else a
    if verbose:
        begin = time.perf_counter()
        print(
            f"[compare_modules] check ep with "
            f"args={string_type(args, with_shape=True, with_device=True)}, "
            f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..."
        )
    got = modep(*_get(args), **_get(kwargs))
    if verbose:
        d = time.perf_counter() - begin
        print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
    if mod:
        if verbose:
            begin = time.perf_counter()
            print("[compare_modules] run torch module...")
        expected = mod(*_get(args), **_get(kwargs))
        diff = max_diff(expected, got)
        if verbose:
            d = time.perf_counter() - begin
            print(
                f"[compare_modules] done in {d} with "
                f"output={string_type(expected, with_shape=True)}"
            )
            print(f"[compare_modules] discrepancies={string_diff(diff)}")
        assert not exc or (
            isinstance(diff["abs"], float)
            and isinstance(diff["rel"], float)
            and diff["abs"] <= atol
            and diff["rel"] <= rtol
        ), f"Discrepancies={string_diff(diff)} higher than expected."
        return dict(args=args, kwargs=kwargs, expected=expected, got=got, diff=diff)
    return dict(args=args, kwargs=kwargs, got=got) 
[docs]
def validate_ep(
    ep: Union[torch.nn.Module, torch.export.ExportedProgram],
    mod: Optional[torch.nn.Module] = None,
    args: Optional[Tuple[Any, ...]] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    dynamic_shapes: Optional[Any] = None,
    values_to_try: Optional[Dict[str, List[int]]] = None,
    exc: bool = True,
    verbose: int = 0,
    atol: float = 1e-2,
    rtol: float = 1e-1,
) -> List[Dict[str, Any]]:
    """
    Validates an exported program.
    :param model: first module
    :param mod: second module (it produces the expected values)
    :param args: positional arguments
    :param kwargs: named arguments
    :param copy: copy the inputs before executing the model (they may modify them inplace)
    :param dynamic_shapes: dynamic shapes, string should be used not ``torch.export.Dim``
    :param values_to_try: dictionary with the values to try for every dynamic dimension
    :param exc: raise exception if discrepancies are too high
    :param verbose: verbosity level
    :param atol: absolute tolerance
    :param rtol: relative tolerance
    :return: dictionary with inputs, outputs and tolerance
    """
    modep = ep.module() if isinstance(ep, torch.export.ExportedProgram) else ep
    results = [
        compare_modules(
            modep, mod, args, kwargs, copy=copy, verbose=verbose, atol=atol, rtol=rtol
        )
    ]
    assert (dynamic_shapes and values_to_try) or (
        not dynamic_shapes and not values_to_try
    ), "Either both dynamic_shapes and values_to_try are specified, either none."
    if not dynamic_shapes or not values_to_try:
        return results
    items = list(values_to_try.items())
    keys = [_[0] for _ in items]
    values = [_[1] for _ in items]
    all_vals = list(itertools.product(*values))
    cpl = CoupleInputsDynamicShapes(
        args or (),
        kwargs or {},
        dynamic_shapes,
        args_names=(
            list(inspect.signature(modep.forward).parameters) if args and kwargs else None
        ),
    )
    for i, vals in enumerate(all_vals):
        change_dims = dict(zip(keys, vals))
        if verbose:
            print(f"[validate_ep] try {i}/{len(all_vals)}: {change_dims}")
        new_params = cpl.change_dynamic_dimensions(change_dims, args_kwargs=True)
        na, nkw = new_params
        c = compare_modules(
            modep, mod, na, nkw, copy=copy, verbose=max(verbose - 1, 0), atol=atol, rtol=rtol
        )
        results.append(c)
    return results