from typing import Optional, Tuple
import onnx
import torch
from . import OpRunKernel, OpRunTensor
[docs]
class ConstantOfShape_9(OpRunKernel):
    "ConstantOfShape"
[docs]
    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return True 
    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        value = self.get_attribute_tensor(node, "value")
        if value is None:
            value = torch.tensor([0], dtype=torch.float32)
        self.dtype = value.dtype
        self.device = device
        self.value = value[0]
[docs]
    def run(self, shape: OpRunTensor) -> OpRunTensor:
        # The device is unknown as shapes usually take place on CPU.
        return OpRunTensor(
            torch.full(
                shape.as_tuple_int, fill_value=self.value, dtype=self.dtype, device=self.device
            )
        ) 
 
[docs]
class Expand_8(OpRunKernel):
    "Expand"
[docs]
    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
        ishape = tuple(-1 if i == 1 else i for i in shape.as_tuple_int)
        return OpRunTensor(data.tensor.expand(ishape)) 
 
[docs]
class Reshape_14(OpRunKernel):
    "Reshape"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.allowzero = self.get_attribute_int(node, "allowzero", 0)
[docs]
    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
        ishape = shape.as_tuple_int
        assert ishape is not None, f"Unexpected return for shape={shape!r}"
        if not self.allowzero and 0 in ishape:
            xshape = data.tensor.shape
            new_shape = []
            for i, s in enumerate(ishape):
                new_shape.append(xshape[i] if s == 0 else s)
            return OpRunTensor(data.tensor.reshape(new_shape))
        return OpRunTensor(data.tensor.reshape(ishape)) 
 
[docs]
class Shape_15(OpRunKernel):
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.start = self.get_attribute_int(node, "start", 0)
        self.end = self.get_attribute_int(node, "end", None)
[docs]
    def run(self, data: OpRunTensor) -> OpRunTensor:
        shape = data.shape
        sh = shape[self.start :] if self.end is None else shape[self.start : self.end]
        return OpRunTensor(torch.tensor(sh, dtype=torch.int64), is_constant=True) 
 
[docs]
class Split_18(OpRunKernel):
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", 0)
        self.num_outputs = self.get_attribute_int(node, "num_outputs", None)
[docs]
    def run(
        self, data: OpRunTensor, split: Optional[OpRunTensor] = None
    ) -> Tuple[OpRunTensor, ...]:
        if split is None:
            assert isinstance(
                self.num_outputs, int
            ), f"Incompatibilities: split is None and num_outputs={self.num_outputs}"
            size = data.tensor.shape[self.axis]
            split_size = (
                size // self.num_outputs
                if size % self.num_outputs == 0
                else size // self.num_outputs + 1
            )
            spl = torch.split(data.tensor, split_size, dim=self.axis)
        else:
            spl = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
        return tuple(OpRunTensor(t) for t in spl) 
 
[docs]
class Squeeze_13(OpRunKernel):
    "Squeeze"
[docs]
    def run(self, data: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        if axes is None:
            return OpRunTensor(data.tensor.squeeze())
        return OpRunTensor(data.tensor.squeeze(axes.as_tuple_int)) 
 
[docs]
class Unsqueeze_13(OpRunKernel):
    "Unsqueeze"
[docs]
    def run(self, data: OpRunTensor, axes: OpRunTensor) -> OpRunTensor:
        t = data.tensor
        for i in axes.as_tuple_int:
            t = t.unsqueeze(i)
        return OpRunTensor(t)