from typing import Optional
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRunKernel, OpRunTensor
[docs]
class Cast_6(OpRunKernel):
    "Cast"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        to = self.get_attribute_int(node, "to", 0)
        assert isinstance(to, int), f"Unexpected value for attribute to={to!r}"
        self.to = onnx_dtype_to_torch_dtype(to)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"
[docs]
    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(data.tensor.to(self.to)) 
 
[docs]
class CastLike_15(OpRunKernel):
    "Cast"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"
[docs]
    def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(data.tensor.to(like.tensor.dtype)) 
 
[docs]
class Concat_1(OpRunKernel):
    "Concat"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis
[docs]
    def run(self, *data: OpRunTensor) -> OpRunTensor:
        assert data, f"No tensor to concatenate in node name {self.name!r}"
        devices = [d.get_device() for d in data]
        if len(set(devices)) == 1:
            return OpRunTensor(torch.cat([t.tensor for t in data], axis=self.axis))
        if (
            data[0].dtype == torch.int64
            and self.axis == 0
            and max(d.tensor.ndim for d in data) == 1
            and max(d.tensor.numel() for d in data) <= 8
        ):
            # This is a shape
            return OpRunTensor(torch.cat([t.tensor.cpu() for t in data], axis=self.axis))
        index = devices.index(max(devices))
        device = data[index].tensor.device
        return OpRunTensor(torch.cat([t.tensor.to(device) for t in data], axis=self.axis)) 
 
[docs]
class NonZero_13(OpRunKernel):
    "NonZero"
[docs]
    def run(self, x: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.nonzero(x.tensor).T) 
 
[docs]
class Tile_6(OpRunKernel):
    "Tile"
[docs]
    def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.tile(x.tensor, repeat.as_tuple_int)) 
 
[docs]
class Transpose_1(OpRunKernel):
    "Transpose"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.perm = self.get_attribute_ints(node, "perm", None)
[docs]
    def run(self, data: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.permute(data.tensor, self.perm)) 
 
[docs]
class Trilu_14(OpRunKernel):
    "Trilu"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: int = 0):
        super().__init__(node, version, verbose=verbose)
        self.upper = self.get_attribute_int(node, "upper", 1)
[docs]
    def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
        diagonal = 0 if k is None else k.tensor.item()
        if self.upper:
            return OpRunTensor(torch.triu(data.tensor, diagonal=diagonal))
        return OpRunTensor(torch.tril(data.tensor, diagonal=diagonal)) 
 
[docs]
class Where_9(OpRunKernel):
    "Where"
[docs]
    def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        tcond, tx, ty = self.same_device(cond.tensor, x.tensor, y.tensor)
        return OpRunTensor(torch.where(tcond, tx, ty))