import enum
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import onnx
import torch
from ..api import TensorLike
from ..helpers import string_type
[docs]
class RuntimeValueKind(enum.IntEnum):
    "Kind of result."

    RESULT = 1
    INITIALIZER = 3
    INPUT = 5
    OUTPUT = 9

    def to_str(self) -> str:
        "Returns the name of the member, ``'INPUT'`` for ``RuntimeValueKind.INPUT``."
        # ``name`` is the documented way to get the member name, no need
        # to scan the class dictionary and compare values.
        return self.name
[docs]
class RuntimeDevice(enum.IntEnum):
    "Device definition"

    UNKNOWN = 0
    NEW = 1
    CPU = 2
    CUDA = 4

    def to_str(self) -> str:
        "Returns the name of the member, ``'CPU'`` for ``RuntimeDevice.CPU``."
        # ``name`` is the documented way to get the member name, no need
        # to scan the class dictionary and compare values.
        return self.name
[docs]
class RuntimeValue:
    """
    Describes a value used during the execution of a model.

    :param name: name of the result
    :param dtype: element type, :meth:`set_value` fills it when it is not set
    :param shape: shape, :meth:`set_value` overwrites it
    :param value: value stored for this result, None if there is none
    :param first_used: position of the first use of the result
    :param last_used: position of the last use of the result
    :param created: position where the result is created
    :param is_shape: tells if the result is flagged as a shape
    :param kind: kind of result, see :class:`RuntimeValueKind`
    :param device: device, see :class:`RuntimeDevice`
    """

    def __init__(
        self,
        name: str,
        dtype: Optional[Any] = None,
        shape: Optional[Tuple[Union[str, int], ...]] = None,
        value: Optional[Any] = None,
        first_used: Optional[int] = None,
        last_used: Optional[int] = None,
        created: Optional[int] = None,
        is_shape: Optional[bool] = None,
        kind: Optional[RuntimeValueKind] = None,
        device: Optional[RuntimeDevice] = None,
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.value = value
        self.first_used = first_used
        self.last_used = last_used
        self.created = created
        self.is_shape = is_shape
        self.kind = kind
        self.device = device

    def _describe_value(self) -> str:
        "Returns a short description of the stored value, it must not be None."
        # The value may expose its own ``string_type``, a raw tensor does not,
        # the module helper handles that case.
        if hasattr(self.value, "string_type"):
            return self.value.string_type()
        return string_type(self.value, with_shape=True)

    def __repr__(self) -> str:
        "usual"
        ad = {}
        for att in [
            "name",
            "dtype",
            "shape",
            "first_used",
            "last_used",
            "is_shape",
            "kind",
            "created",
            "device",
        ]:
            v = getattr(self, att)
            if v is not None:
                # only the attributes which are set are displayed
                ad[att] = v
        if self.value is not None:
            ad["value"] = self._describe_value()
        # enumerations are displayed with their name
        msg = ", ".join(
            f"{name}={t.to_str()}" if hasattr(t, "to_str") else f"{name}={t}"
            for name, t in ad.items()
        )
        return f"{self.__class__.__name__}({msg})"

    @property
    def has_value(self) -> bool:
        "Tells if value is specified."
        return self.value is not None

    def string_type(self) -> str:
        "Returns a string describing the value."
        rows = []
        if self.shape is not None:
            rows.append(f"shape={self.shape}")
        if self.is_shape is not None:
            rows.append(f"is_shape={self.is_shape}")
        if self.device is not None:
            rows.append(f"device={self.device}")
        text = f", {', '.join(rows)}" if rows else ""
        if self.value is None:
            return (
                f"RuntimeValue(name={self.name!r}{text}"
                f", dtype={self.dtype}, kind={self.kind})"
            )
        # Same fallback as ``__repr__``: the stored value does not necessarily
        # implement ``string_type``.
        return (
            f"RuntimeValue(name={self.name!r}, "
            f"kind={self.kind}{text}, value={self._describe_value()})"
        )

    def set_value(self, value: Union[torch.Tensor, TensorLike]):
        """
        Sets the value, updates ``dtype`` if it was not set and ``shape``.

        :param value: new value, it cannot be None,
            use :meth:`clean_value` to remove a value
        """
        assert value is not None, "Use clean_value to set a value to None"
        self.value = value
        is_sequence = hasattr(value, "is_sequence") and value.is_sequence()
        if self.dtype:
            # the type of a result is not allowed to change
            assert self.dtype == value.dtype, (
                f"Unexpected dtype={value.dtype}, previous dtype was {self.dtype}, "
                f"is_sequence={is_sequence}"
            )
        else:
            self.dtype = value.dtype
        # no shape is stored for a sequence
        self.shape = None if is_sequence else tuple(map(int, value.shape))

    def clean_value(self):
        """Sets value to None."""
        self.value = None

    @property
    def is_output(self) -> bool:
        "Tells if it is an output."
        return self.kind == RuntimeValueKind.OUTPUT

    @property
    def is_input(self) -> bool:
        "Tells if it is an input."
        return self.kind == RuntimeValueKind.INPUT

    @property
    def is_initializer(self) -> bool:
        "Tells if it is an initializer."
        return self.kind == RuntimeValueKind.INITIALIZER
[docs]
def set_is_shape(
    node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
) -> List[str]:
    """
    Sets attribute ``is_shape`` for outputs of a node.

    Outputs of nodes ``Shape`` and ``Size`` (default domain) are flagged as shapes.
    For any other node with a single output, the output is flagged as a shape
    if the first considered input is a shape. If no considered input is a shape,
    every output is flagged as not being a shape.

    :param node: node to process
    :param values: stored results, values in this dictionary are updated
    :param drop: variables not to consider because they come from the graph
        holding this subgraph
    :return: list of modified results
    """
    if not node.input:
        # Constant
        return []
    drop = drop or set()
    if node.op_type in ("Shape", "Size") and node.domain == "":
        values[node.output[0]].is_shape = True
        return [node.output[0]]
    # An empty name stands for an omitted optional input, there is nothing to look up.
    is_shapes = [values[i].is_shape for i in node.input if i and i not in drop]
    if any(is_shapes):
        if is_shapes[0] and len(node.output) == 1:
            values[node.output[0]].is_shape = True
            return [node.output[0]]
        # some input is a shape but nothing can be said about the outputs
        return []
    # An empty name stands for an omitted optional output, it is not a result.
    modified = [o for o in node.output if o]
    for o in modified:
        values[o].is_shape = False
    return modified
[docs]
def first_used_last_used(
    proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    constant_as_initializer: bool = False,
) -> Dict[str, RuntimeValue]:
    """
    Builds first used, last used information for every result
    in the model.

    Initializers and inputs are created at position -1, the outputs of a node
    at the index of this node. Graph outputs are considered as used after
    the last node.

    :param proto: model, graph or function
    :param constant_as_initializer: outputs of node Constant is tagged as INITIALIZER
    :return: dictionary of RuntimeValue
    """
    values: Dict[str, RuntimeValue] = {}
    if isinstance(proto, onnx.ModelProto):
        initializer = proto.graph.initializer
        sparse_initializer = proto.graph.sparse_initializer
        _input = proto.graph.input
        output = proto.graph.output
        _node = proto.graph.node
        allow_unknown = False
    elif isinstance(proto, onnx.GraphProto):
        # a graph may be a subgraph using results defined by the graph holding it
        initializer = proto.initializer
        sparse_initializer = proto.sparse_initializer
        _input = proto.input
        output = proto.output
        _node = proto.node
        allow_unknown = True
    else:
        initializer = []
        sparse_initializer = []
        _input = proto.input
        output = proto.output
        _node = proto.node
        allow_unknown = False

    for init in initializer:
        values[init.name] = RuntimeValue(
            init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
        )
    for init in sparse_initializer:
        # SparseTensorProto has no field ``name``, the name of a sparse
        # initializer is the name of its tensor ``values``.
        sparse_name = init.values.name
        values[sparse_name] = RuntimeValue(
            sparse_name, created=-1, kind=RuntimeValueKind.INITIALIZER
        )
    for inp in _input:
        # inputs of a FunctionProto are strings
        n = inp if isinstance(inp, str) else inp.name
        values[n] = RuntimeValue(n, created=-1, kind=RuntimeValueKind.INPUT)

    drop: Set[str] = set()
    for it, node in enumerate(_node):
        for i in node.input:
            if not i:
                # An empty name stands for an omitted optional input,
                # it is dropped so that set_is_shape ignores it as well.
                drop.add(i)
                continue
            if i not in values:
                assert allow_unknown, f"Input {i!r} is unknown."
                # This input comes from a context and the model is a GraphProto
                drop.add(i)
                continue
            if values[i].first_used is None:
                values[i].first_used = it
            values[i].last_used = it
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                # NOTE(review): get_hidden_inputs is not imported in the visible
                # part of this file, confirm it is in scope.
                for n in get_hidden_inputs(att.g):
                    if not n:
                        continue
                    if n not in values:
                        # same rule as for the inputs of the node itself
                        assert allow_unknown, f"Hidden input {n!r} is unknown."
                        continue
                    if values[n].first_used is None:
                        values[n].first_used = it
                    values[n].last_used = it
        is_constant = node.op_type == "Constant" and node.domain == ""
        kind = (
            RuntimeValueKind.INITIALIZER
            if is_constant and constant_as_initializer
            else RuntimeValueKind.RESULT
        )
        for o in node.output:
            values[o] = RuntimeValue(o, created=it, kind=kind)
        set_is_shape(node, values, drop=drop)

    for out in output:
        # outputs of a FunctionProto are strings
        n = out if isinstance(out, str) else out.name
        values[n].kind = RuntimeValueKind.OUTPUT
        # an output must be kept until the end of the execution
        values[n].last_used = len(_node)
    return values