from typing import Any, Dict, List, Union
import numpy as np
import onnx
import torch
from .helper import string_type, flatten_object
from .onnx_helper import dtype_to_tensor_dtype
from .cache_helper import is_cache_dynamic_registered
def name_type_to_onnx_dtype(name: str) -> int:
    """
    Converts a type name such as the ``type`` attribute of the objects
    returned by ``InferenceSession.get_inputs()`` (e.g. ``"tensor(float)"``)
    into an onnx element type.

    :param name: type name, ``tensor(<element type>)``
    :return: the corresponding ``onnx.TensorProto`` element type
    :raises AssertionError: if the name is not a supported tensor type
    """
    mapping = {
        "tensor(float)": onnx.TensorProto.FLOAT,
        "tensor(float16)": onnx.TensorProto.FLOAT16,
        "tensor(bfloat16)": onnx.TensorProto.BFLOAT16,
        "tensor(double)": onnx.TensorProto.DOUBLE,
        "tensor(int8)": onnx.TensorProto.INT8,
        "tensor(int16)": onnx.TensorProto.INT16,
        "tensor(int32)": onnx.TensorProto.INT32,
        "tensor(int64)": onnx.TensorProto.INT64,
        "tensor(uint8)": onnx.TensorProto.UINT8,
        "tensor(uint16)": onnx.TensorProto.UINT16,
        "tensor(uint32)": onnx.TensorProto.UINT32,
        "tensor(uint64)": onnx.TensorProto.UINT64,
        "tensor(bool)": onnx.TensorProto.BOOL,
        "tensor(string)": onnx.TensorProto.STRING,
        "tensor(complex64)": onnx.TensorProto.COMPLEX64,
        "tensor(complex128)": onnx.TensorProto.COMPLEX128,
    }
    if name not in mapping:
        raise AssertionError(f"Unexpected value {name!r}")
    return mapping[name]
def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce feeds expected
    by :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model, list of input names, or any object exposing
        either ``get_inputs()`` (such as :class:`onnxruntime.InferenceSession`)
        or an ``input_names`` attribute
    :param inputs: any kind of inputs, flattened with ``flatten_object``
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: a copy is made, this should be the case if the inputs are ingested
        by ``OrtValue``
    :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
    :return: feeds dictionary mapping every input name to a value;
        python ``bool``, ``int`` and ``float`` values are converted
        into 0-d numpy arrays
    :raises AssertionError: if the flattened inputs cannot be matched
        with the expected input names

    When an onnx model or an object with ``get_inputs()`` declares fewer
    inputs than given, the extra inputs are dropped: each declared input is
    matched, in order, with the next given input sharing its element type.
    """
    # position_ids is a special case because ModelBuilder does not usually use it.
    # We use types to detect the best inputs.
    flat = flatten_object(inputs, drop_keys=True)
    assert (
        not check_flatten
        or not all(isinstance(obj, torch.Tensor) for obj in flat)
        or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
    )
    if use_numpy:
        flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat]

    if isinstance(proto, onnx.ModelProto):
        names = [i.name for i in proto.graph.input]
    elif hasattr(proto, "get_inputs"):
        names = [i.name for i in proto.get_inputs()]
    elif hasattr(proto, "input_names"):
        names = proto.input_names
    else:
        names = proto
    # Only a model or a get_inputs() provider exposes input types, which are
    # needed below to select a subset of the given inputs.
    can_align = isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
    assert (
        isinstance(names, list)
        and len(names) <= len(flat)
        and (len(names) == len(flat) or can_align)
    ), (
        f"Not the same number of given inputs {len(flat)} "
        f"and the number of model inputs {len(names)}, "
        f"type(names)={type(names)}, type(proto)={type(proto)}"
        f"\n-- inputs={string_type(inputs, with_shape=True)}"
        f"\n-- names={names}"
    )
    if len(names) < len(flat) and can_align:
        if isinstance(proto, onnx.ModelProto):
            typed_names = [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
        else:
            typed_names = [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
        new_flat = []
        pos = 0
        for _name, dtype in typed_names:
            assert isinstance(
                dtype, int
            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
            # Skips given inputs until one has the expected element type.
            # The bound check comes first: the previous match may have consumed
            # the last given input and flat[pos] would raise IndexError.
            while pos < len(flat) and dtype_to_tensor_dtype(flat[pos].dtype) != dtype:
                pos += 1
            if pos >= len(flat):
                break
            new_flat.append(flat[pos])
            pos += 1
        assert len(new_flat) == len(names), (
            f"Unable to align expected input {names} with the given input, "
            f"type(proto)={type(proto)}"
            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
            f"\n-- typed_names: {typed_names}"
        )
        flat = new_flat
    if copy:
        # Python scalars are immutable and have neither copy() nor clone().
        flat = [
            t.copy() if hasattr(t, "copy") else (t.clone() if hasattr(t, "clone") else t)
            for t in flat
        ]
    # onnxruntime does not accept python scalars (bool, int, float): they are
    # wrapped into 0-d numpy arrays. bool is tested before int because bool
    # is a subclass of int.
    feeds = []
    for value in flat:
        if isinstance(value, bool):
            value = np.array(value, dtype=np.bool_)
        elif isinstance(value, int):
            value = np.array(value, dtype=np.int64)
        elif isinstance(value, float):
            value = np.array(value, dtype=np.float32)
        feeds.append(value)
    return dict(zip(names, feeds))