Source code for onnx_diagnostic.reference.torch_ops.access_ops

from typing import Optional
import onnx
import torch
from . import OpRun, OpRunTensor


[docs] class Gather_1(OpRun): "Gather" def __init__(self, node: onnx.NodeProto, version: Optional[int] = None): super().__init__(node, version) axis = self.get_attribute_int(node, "axis", 0) assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}" self.axis = axis
[docs] def run(self, x, indices): if indices.tensor.numel() == 0: return torch.empty((0,), dtype=x.tensor.dtype, device=x.tensor.device) ind = [slice(0, s) for s in x.shape] ind[self.axis] = indices.tensor return OpRunTensor(x.tensor[tuple(ind)])
[docs] class ScatterND_16(OpRun): "ScatterND" def __init__(self, node: onnx.NodeProto, version: Optional[int] = None): super().__init__(node, version) self.reduction = self.get_attribute_string(node, "reduction", "none")
[docs] def run( self, data: OpRunTensor, indices: OpRunTensor, updates: OpRunTensor ) -> OpRunTensor: # This implementation is not efficient. grids = torch.meshgrid(*[torch.arange(s) for s in indices.shape[:-1]], indexing="ij") stacked = torch.stack(grids, dim=-1) index = stacked.reshape(-1, len(indices.shape) - 1) output = data.tensor.clone() for i in index: if self.reduction == "add": output[indices.tensor[i]] += updates.tensor[i] elif self.reduction == "mul": output[indices.tensor[i]] *= updates.tensor[i] elif self.reduction == "max": output[indices.tensor[i]] = torch.maximum( output[indices.tensor[i]], updates.tensor[i] ) elif self.reduction == "min": output[indices.tensor[i]] = torch.minimum( output[indices.tensor[i]], updates.tensor[i] ) else: output[indices.tensor[i]] = updates.tensor[i] return OpRunTensor(output)
[docs] class Slice_13(OpRun): "Slice"
[docs] def run( self, data: OpRunTensor, starts: OpRunTensor, ends: OpRunTensor, axes: Optional[OpRunTensor] = None, steps: Optional[OpRunTensor] = None, ) -> OpRunTensor: if axes is None: if steps is None: slices = [slice(s, e) for s, e in zip(starts.tensor, ends.tensor)] else: slices = [ slice(s, e, d) for s, e, d in zip(starts.tensor, ends.tensor, steps.tensor) ] else: if steps is None: slices = [slice(0, a) for a in data.shape] for s, e, a in zip(starts.tensor, ends.tensor, axes.tensor): slices[a] = slice(s, e) else: slices = [slice(0, a) for a in data.shape] for s, e, a, d in zip(starts.tensor, ends.tensor, axes.tensor, steps.tensor): slices[a] = slice(s, e, d) return OpRunTensor(data.tensor[tuple(slices)])