from typing import Any, Dict, Iterator, Optional, Tuple, Union
import onnx
import torch
from ..helpers import string_type, string_diff, max_diff
from ..helpers.onnx_helper import to_array_extended
def validate_fx_tensor(
    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    """
    Checks that the shape of a computed tensor agrees with an expected shape.

    The ranks must be equal. Every expected dimension which is an ``int``
    must be equal to the computed one, except that 0 and 1 are accepted
    in place of each other. Expected dimensions which are not ``int``
    (symbolic dimensions for example) are not checked.

    :param node: node which produced ``tensor``, only used in error messages
    :param tensor: computed tensor
    :param expected_shape: expected shape
    :raises AssertionError: if the rank or a static dimension differs
    """
    got = tensor.shape
    assert len(got) == len(expected_shape), (
        f"Shape mismatch, got {got} expected {expected_shape}, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
    for dim, exp in zip(got, expected_shape):
        compatible = not isinstance(exp, int) or dim == exp or {dim, exp} == {0, 1}
        assert compatible, (
            f"Dimension mismatch, got {got} expected {expected_shape}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
def _fx_node_context(node: torch.fx.Node) -> str:
    "Describes ``node`` in the error messages raised by :func:`validate_fx_outputs`."
    return (
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )


def _validate_fx_value(node: torch.fx.Node, value: Any, expected: Any) -> None:
    """
    Recursively validates a computed value against ``expected``,
    which is ``node.meta["val"]`` or one of its (nested) elements.

    :param node: node which produced the value, used in error messages
    :param value: computed value (tensor, list, tuple, int, float or None)
    :param expected: corresponding value recorded in the node metadata
    :raises AssertionError: on any mismatch
    :raises NotImplementedError: for a value of an unsupported type
    """
    if isinstance(value, torch.Tensor):
        validate_fx_tensor(node, value, expected.shape)
        return
    if isinstance(value, (tuple, list)):
        assert isinstance(expected, (list, tuple)), (
            f"Unexpected type {string_type(expected)} for node.meta['val'], "
            f"{_fx_node_context(node)}"
        )
        assert len(value) == len(expected), (
            f"Length mismatch, got {len(value)} expected {len(expected)}, "
            f"{_fx_node_context(node)}"
        )
        # A sequence may mix tensors, numbers and None: every element is
        # checked with the rule matching its own type.
        for v, e in zip(value, expected):
            _validate_fx_value(node, v, e)
        return
    if isinstance(value, (int, float)):
        # Symbolic values have no concrete value to compare with.
        kind = "Int" if isinstance(value, int) else "Float"
        assert (
            isinstance(expected, (torch.SymInt, torch.SymBool, torch.SymFloat))
            or value == expected
        ), f"{kind} mismatch, got {value} expected {expected}, {_fx_node_context(node)}"
        return
    if value is None:
        assert expected is None, (
            f"None mismatch, got {value} expected {expected}, {_fx_node_context(node)}"
        )
        return
    raise NotImplementedError(
        f"Validation for output type {type(value)} is not implemented, "
        f"{_fx_node_context(node)}"
    )


def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    """
    Validates the outputs of a node using metadata stored in the node
    (``node.meta["val"]``). Nothing is checked when that metadata is missing.

    Tensors are checked with :func:`validate_fx_tensor`, lists and tuples
    element by element (recursively), ints and floats by equality unless the
    expected value is symbolic, None must be expected as None.

    :param node: node
    :param outputs: outputs
    :raises AssertionError: on any mismatch
    :raises NotImplementedError: for an output type which cannot be validated
    """
    if "val" not in node.meta:
        return
    _validate_fx_value(node, outputs, node.meta["val"])
def run_fx_node(
    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> Tuple[Any, ...]:
    """
    Executes one fx node given its already computed inputs.

    Two kinds of nodes are supported. An ``output`` node returns ``args``
    unchanged (a tuple holding exactly one element). A ``call_function``
    node calls its target, checks the result with
    :func:`validate_fx_outputs` and returns it.

    :param node: node to run
    :param args: positional inputs of the node
    :param kwargs: keyword inputs of the node
    :return: results
    :raises NotImplementedError: for any other value of ``node.op``
    """
    kind = node.op
    if kind == "output":
        assert len(args) == 1 and not kwargs, (
            f"Unexpected inputs: args={string_type(args, limit=20)} "
            f"kwargs={string_type(kwargs, limit=20)}"
        )
        return args
    if kind != "call_function":
        raise NotImplementedError(
            f"node.op={kind!r} is not implemented, node.name={node.name!r}"
        )
    fct = node.target
    assert callable(fct), f"{fct!r} not callable in node {node!r}"
    results = fct(*args, **(kwargs if kwargs else {}))
    validate_fx_outputs(node, results)
    return results
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
    "See :func:`prepare_args_kwargs`."
    if isinstance(ref, torch.fx.Node):
        return torch_results[ref.name]
    if isinstance(ref, list):
        return [_pick_result(torch_results, n) for n in ref]
    if isinstance(ref, tuple):
        return tuple(_pick_result(torch_results, n) for n in ref)
    if isinstance(ref, dict):
        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
        return ref
    if ref is None:
        return None
    raise NotImplementedError(f"Unable to process args type {type(ref)}")
def prepare_args_kwargs(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Builds the concrete inputs of a fx node before it is executed:
    every node referenced in ``node.args`` and ``node.kwargs`` is replaced
    by the value already computed for it.

    :param torch_results: values computed so far, indexed by node name
    :param node: node about to be executed
    :return: a pair ``(args, kwargs)``
    """
    return (
        _pick_result(torch_results, node.args),
        _pick_result(torch_results, node.kwargs),
    )
def _onnx_graph_parts(
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
) -> Tuple[list, list, list, list]:
    """
    Returns the nodes, the initializers, the input names and the output names
    of a model or a function. A :class:`onnx.FunctionProto` has no initializer
    and stores its inputs and outputs as plain names.
    """
    if isinstance(onx, onnx.FunctionProto):
        return list(onx.node), [], list(onx.input), list(onx.output)
    graph = onx.graph
    return (
        list(graph.node),
        list(graph.initializer),
        [i.name for i in graph.input],
        [o.name for o in graph.output],
    )


def _run_onnx_node(
    cls: Any,
    onnx_node: onnx.NodeProto,
    i_fx: int,
    i_onnx: int,
    onnx_results: Dict[str, Any],
    torch_results: Dict[str, Any],
    mapping_onnx_to_torch: Dict[str, str],
    atol: Optional[float],
    rtol: Optional[float],
    verbose: int,
    mark: str = "",
) -> Iterator[Tuple[Any, ...]]:
    """
    Runs one onnx node with ``cls``, stores its outputs in ``onnx_results``
    and compares every output which has a torch counterpart.

    :param mark: suffix appended to the messages, ``"*"`` is used for the
        onnx nodes executed after every fx node was run
    :return: yields ``(i_fx, i_onnx, onnx_name, torch_name, max_diff)``
        for every compared output
    :raises ValueError: if ``atol`` and ``rtol`` are both given and a
        discrepancy exceeds one of them (or is NaN)
    """
    if verbose:
        print(
            f"[run_aligned] run onx.graph.node[{i_onnx}]: "
            f"{onnx_node.op_type}({', '.join(onnx_node.input)}) -> "
            f"{', '.join(onnx_node.output)}"
        )
    ref = cls(onnx_node)
    # An empty name denotes an omitted optional input, it has no value.
    feeds = {k: onnx_results[k] for k in onnx_node.input if k}
    res = ref.run(None, feeds)
    for o, r in zip(onnx_node.output, res):
        onnx_results[o] = r
        if verbose:
            print(
                f"[run_aligned] +onnx {o}="
                f"{string_type(r, with_shape=True, with_min_max=True)}"
            )
        to = mapping_onnx_to_torch.get(o, o)
        if to not in torch_results:
            continue
        d = max_diff(torch_results[to], r)
        if verbose:
            name = to if o == to else f"{to}/{o}"
            print(f"[run_aligned] =common results{mark} {name}: {string_diff(d)}")
        # The check does not depend on the verbosity. It is written as
        # "not (<= and <=)" so that NaN discrepancies are reported too.
        if not (
            atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
        ):
            skw = dict(with_shape=True, with_min_max=True)
            raise ValueError(
                f"discrepancies detected for results{mark} [{to}/{o}]: {string_diff(d)}"
                f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                f"\n-- onnx_results: {string_type(r, **skw)}"
                f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
            )
        yield (i_fx, i_onnx, o, to, d)


def run_aligned(
    ep: torch.export.ExportedProgram,
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
    args: Tuple[torch.Tensor, ...],
    check_conversion_cls: Union[Dict[str, Any], type],
    kwargs: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> Iterator[Tuple[Any, ...]]:
    """
    Runs in parallel both the exported program
    and the onnx proto and looks for discrepancies.
    The function matches results by name, so it assumes
    the exported program and the onnx model use the same names
    for equivalent results (the final outputs are matched by position).

    Every fx node is executed with torch. The onnx nodes are then executed,
    in order, up to the last one producing a result with the same name as
    the fx node. Every onnx result with a torch counterpart is compared.
    The remaining onnx nodes run once the exported program is exhausted.

    :param ep: exported program
    :param onx: model or function proto
    :param args: input args
    :param check_conversion_cls: defines the runtime to use for this task,
        either a class instantiated with one :class:`onnx.NodeProto` and
        exposing ``run(None, feeds)``, or a dictionary
        ``dict(cls=..., atol=..., rtol=...)``, tolerances are optional and
        only used when both are given
    :param kwargs: input kwargs, not supported yet
    :param verbose: verbosity level
    :return: an iterator on tuples
        ``(fx node index, onnx node index, onnx result name, torch result name,
        discrepancies)``, the last element is the dictionary returned by
        ``max_diff`` (it contains at least the keys ``"abs"`` and ``"rel"``)
    :raises ValueError: if tolerances are given and a discrepancy exceeds them

    Example:

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pprint
        import pandas
        import torch
        from onnx_diagnostic.reference import (
            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned

        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru

        def post_process(obs):
            dobs = dict(zip(["ep_id_node", "onnx_id_node", "onnx_name", "ep_name"], obs))
            dobs["err_abs"] = obs[-1]["abs"]
            dobs["err_rel"] = obs[-1]["rel"]
            return dobs

        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
        ).model_proto
        results = list(
            map(
                post_process,
                run_aligned(
                    ep,
                    onx,
                    (x,),
                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
                    verbose=1,
                ),
            ),
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        print(df)
    """
    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs, with_shape=True)}"
    if isinstance(check_conversion_cls, dict):
        cls = check_conversion_cls["cls"]
        atol = check_conversion_cls.get("atol", None)
        rtol = check_conversion_cls.get("rtol", None)
    else:
        cls, atol, rtol = check_conversion_cls, None, None

    onnx_nodes, onnx_inits, onnx_input_names, onnx_outputs_names = _onnx_graph_parts(onx)

    # retrieve the positions: result name -> {"fx": index, "onnx": index}
    positions: Dict[str, Dict[str, int]] = {}
    for i, node in enumerate(ep.graph.nodes):
        for n in [node.name] if isinstance(node.name, str) else node.name:
            positions[n] = dict(fx=i)
    for i, onnx_node in enumerate(onnx_nodes):
        for n in onnx_node.output:
            positions.setdefault(n, {})["onnx"] = i

    onnx_results: Dict[str, Any] = {}
    for init in onnx_inits:
        # An initializer is not produced by any node: it is recorded without
        # an "onnx" key so that it never triggers the execution of onnx nodes.
        positions.setdefault(init.name, {})["init"] = -1
        onnx_results[init.name] = to_array_extended(init)
        # exported programs name parameters p_<name with dots replaced>
        param_name = f"p_{init.name.replace('.', '_')}"
        if param_name == init.name:
            continue
        assert param_name not in onnx_results, (
            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
            f"and onnx_results has {sorted(onnx_results)}"
        )
        onnx_results[param_name] = onnx_results[init.name]
    torch_results: Dict[str, Any] = {
        k: torch.from_numpy(v.copy())
        for k, v in onnx_results.items()
        if not k.startswith("init")
    }

    torch_output_names = None
    for node in ep.graph.nodes:
        if node.op == "output":
            torch_output_names = [n.name for n in node.args[0]]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))
    if verbose:
        for k, v in torch_results.items():
            print(
                f"[run_aligned] +torch-cst: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
        for k, v in onnx_results.items():
            print(
                f"[run_aligned] +onnx-init: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
    for inp_name, v in zip(onnx_input_names, args):
        # detach/cpu: numpy() fails on tensors requiring grad or not on CPU
        onnx_results[inp_name] = v.detach().cpu().numpy()
        if verbose:
            print(
                f"[run_aligned] +onnx-input: {inp_name}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    last_position = 0
    for i_fx, node in enumerate(ep.graph.nodes):
        if verbose:
            if node.op == "call_function":
                print(
                    f"[run_aligned] run ep.graph.nodes[{i_fx}]: "
                    f"{node.op}[{node.target}] -> {node.name!r}"
                )
            else:
                print(f"[run_aligned] run ep.graph.nodes[{i_fx}]: {node.op} -> {node.name!r}")
        if node.op == "placeholder":
            if node.name in onnx_results:
                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
                if verbose:
                    t = torch_results[node.name]
                    print(
                        f"[run_aligned] +torch {node.name}="
                        f"{string_type(t, with_shape=True, with_min_max=True)}"
                    )
                continue
            raise AssertionError(
                f"unable to process node {node.op} -> {node.name!r} "
                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
                f"onx.graph.input={onnx_input_names}"
            )
        outputs = [node.name] if isinstance(node.name, str) else list(node.name)
        node_args, node_kwargs = prepare_args_kwargs(torch_results, node)
        new_outputs = run_fx_node(node, node_args, node_kwargs)
        if new_outputs is None:
            # Probably an assert.
            continue
        if len(outputs) == 1 and node.op != "output":
            # A node has a single name: its whole result, even a tuple returned
            # by a multi-output operator, is stored under that name so that the
            # following getitem nodes can extract the elements.
            new_outputs = (new_outputs,)
        elif isinstance(new_outputs, (torch.Tensor, int, float, list)):
            new_outputs = (new_outputs,)
        for k, v in zip(outputs, new_outputs):
            torch_results[k] = v
            if verbose:
                print(
                    f"[run_aligned] +torch {k}="
                    f"{string_type(v, with_shape=True, with_min_max=True)}"
                )
        onnx_positions = [
            positions[n]["onnx"] for n in outputs if "onnx" in positions.get(n, {})
        ]
        if not onnx_positions:
            # no onnx node produces these results, we skip
            continue
        max_pos = max(onnx_positions)
        for i_onnx in range(last_position, max_pos + 1):
            yield from _run_onnx_node(
                cls,
                onnx_nodes[i_onnx],
                i_fx,
                i_onnx,
                onnx_results,
                torch_results,
                mapping_onnx_to_torch,
                atol,
                rtol,
                verbose,
            )
        # never move backwards, onnx nodes already executed are not run twice
        last_position = max(last_position, max_pos + 1)

    # complete the execution of the onnx graph
    for i_onnx in range(last_position, len(onnx_nodes)):
        yield from _run_onnx_node(
            cls,
            onnx_nodes[i_onnx],
            i_fx,
            i_onnx,
            onnx_results,
            torch_results,
            mapping_onnx_to_torch,
            atol,
            rtol,
            verbose,
            mark="*",
        )