from typing import Optional
import onnx
import torch
from . import OpRun, OpRunTensor
class Gather_1(OpRun):
    "Gather"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis
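    # Illustration (hypothetical values): with axis=0, data [[1, 2], [3, 4], [5, 6]]
    # and indices [2, 0], Gather picks rows 2 and 0 and returns [[5, 6], [1, 2]].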
    def run(self, x: OpRunTensor, indices: OpRunTensor) -> OpRunTensor:
        if indices.tensor.numel() == 0:
            # Empty indices: the gathered axis is replaced by the (empty) indices shape.
            axis = self.axis % len(x.shape)
            shape = tuple(x.shape[:axis]) + tuple(indices.shape) + tuple(x.shape[axis + 1 :])
            return OpRunTensor(torch.empty(shape, dtype=x.tensor.dtype, device=x.tensor.device))
        # Full slices on every axis except the gathered one, which receives the indices.
        ind = [slice(0, s) for s in x.shape]
        ind[self.axis] = indices.tensor
        return OpRunTensor(x.tensor[tuple(ind)])
 
class ScatterND_16(OpRun):
    "ScatterND"
    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.reduction = self.get_attribute_string(node, "reduction", "none")
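    # Illustration (hypothetical values): data [[1, 2], [3, 4], [5, 6]],
    # indices [[0], [2]], updates [[9, 9], [7, 7]] and reduction "none"
    # produce [[9, 9], [3, 4], [7, 7]]: each index tuple selects the slice to overwrite.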
    def run(
        self, data: OpRunTensor, indices: OpRunTensor, updates: OpRunTensor
    ) -> OpRunTensor:
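        # Sketch of a vectorized alternative (not used here, covers only the "none" and
        # "add" reductions): flatten `indices` to shape (N, k), split its columns into a
        # tuple of index tensors and call
        # `output.index_put_(index_columns, flat_updates, accumulate=self.reduction == "add")`.
        # The Python loop below is the straightforward, unoptimized version.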
        # This implementation is not efficient: it loops in Python over every
        # index tuple stored in `indices` instead of using a vectorized scatter.
        # Enumerate all positions into the leading dimensions of `indices`.
        grids = torch.meshgrid(*[torch.arange(s) for s in indices.shape[:-1]], indexing="ij")
        stacked = torch.stack(grids, dim=-1)
        index = stacked.reshape(-1, len(indices.shape) - 1)
        output = data.tensor.clone()
        for i in index:
            # `pos` locates one index tuple in `indices`/`updates`;
            # `target` is the corresponding location to update in `output`.
            pos = tuple(i.tolist())
            target = tuple(indices.tensor[pos].tolist())
            if self.reduction == "add":
                output[target] += updates.tensor[pos]
            elif self.reduction == "mul":
                output[target] *= updates.tensor[pos]
            elif self.reduction == "max":
                output[target] = torch.maximum(output[target], updates.tensor[pos])
            elif self.reduction == "min":
                output[target] = torch.minimum(output[target], updates.tensor[pos])
            else:
                output[target] = updates.tensor[pos]
        return OpRunTensor(output) 
 
class Slice_13(OpRun):
    "Slice"
    def run(
        self,
        data: OpRunTensor,
        starts: OpRunTensor,
        ends: OpRunTensor,
        axes: Optional[OpRunTensor] = None,
        steps: Optional[OpRunTensor] = None,
    ) -> OpRunTensor:
        if axes is None:
            # Without axes, starts/ends (and optional steps) apply to the leading axes in order.
            if steps is None:
                slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)]
            else:
                slices = [
                    slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor)
                ]
        else:
            # With axes, start from full slices and overwrite only the listed axes.
            slices = [slice(0, a) for a in data.shape]
            if steps is None:
                for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor):
                    slices[a] = slice(s, e)
            else:
                for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor):
                    slices[a] = slice(s, e, d)
        return OpRunTensor(data.tensor[tuple(slices)])