Source code for onnx_diagnostic.reference.torch_ops.other_ops

from typing import Optional
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRun, OpRunTensor


class Cast_6(OpRun):
    """Cast: convert the input tensor to the dtype named by the ``to`` attribute.

    Only ``saturate=1`` (the default read here) is supported; any other value
    is rejected at construction time.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        onnx_type = self.get_attribute_int(node, "to", 0)
        assert isinstance(onnx_type, int), f"Unexpected value for attribute to={onnx_type!r}"
        # Resolve the ONNX element type to a torch dtype once, not per call.
        self.to = onnx_dtype_to_torch_dtype(onnx_type)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for Cast"

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Return ``data`` converted to the dtype resolved in ``__init__``."""
        converted = data.tensor.to(dtype=self.to)
        return OpRunTensor(converted)
class CastLike_15(OpRun):
    """CastLike: convert the first input to the dtype of the second input.

    Only ``saturate=1`` (the default read here) is supported; any other value
    is rejected at construction time.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.saturate = self.get_attribute_int(node, "saturate", 1)
        assert self.saturate == 1, f"saturate={self.saturate} not implemented for CastLike"

    def run(self, data: OpRunTensor, like: OpRunTensor) -> OpRunTensor:
        """Return ``data`` converted to ``like``'s dtype.

        Only the dtype of ``like`` is read; its values, shape and device are ignored.
        """
        return OpRunTensor(data.tensor.to(like.tensor.dtype))
class Concat_1(OpRun):
    """Concat: join the input tensors along the ``axis`` attribute (default 0).

    Inputs may live on different devices. In that case, small 1-D int64
    inputs concatenated on axis 0 (treated as a shape) are gathered on CPU.
    Otherwise everything is moved to the device of the input whose
    ``get_device()`` value is largest.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        axis = self.get_attribute_int(node, "axis", 0)
        assert isinstance(axis, int), f"Unexpected value for attribute axis={axis!r}"
        self.axis = axis

    def run(self, *data: OpRunTensor) -> OpRunTensor:
        """Concatenate ``data`` along ``self.axis``, reconciling devices if needed."""
        assert data, f"No tensor to concatenate in node name {self.name!r}"
        devices = [item.get_device() for item in data]
        tensors = [item.tensor for item in data]

        # Fast path: every input already shares a device.
        if len(set(devices)) == 1:
            return OpRunTensor(torch.cat(tensors, dim=self.axis))

        # Heuristic: a handful of short 1-D int64 vectors joined on axis 0
        # is most likely a shape computation, which is kept on CPU.
        looks_like_shape = (
            data[0].dtype == torch.int64
            and self.axis == 0
            and max(t.ndim for t in tensors) == 1
            and max(t.numel() for t in tensors) <= 8
        )
        if looks_like_shape:
            return OpRunTensor(torch.cat([t.cpu() for t in tensors], dim=self.axis))

        # NOTE(review): assumes get_device() follows torch's convention
        # (-1 for CPU, CUDA ordinal otherwise), so max() prefers an accelerator.
        target = tensors[devices.index(max(devices))].device
        return OpRunTensor(torch.cat([t.to(target) for t in tensors], dim=self.axis))
class NonZero_13(OpRun):
    """NonZero: return the indices of the non-zero elements, one row per input dimension."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Return the index matrix laid out as (rank, number_of_nonzeros)."""
        # torch.nonzero yields one row per non-zero element, i.e. (nnz, rank);
        # ONNX wants the transposed layout.
        indices = torch.nonzero(x.tensor)
        return OpRunTensor(indices.transpose(0, 1))
class Tile_6(OpRun):
    """Tile: repeat the input along each axis by the counts given in ``repeat``."""

    def run(self, x: OpRunTensor, repeat: OpRunTensor) -> OpRunTensor:
        """Return ``x`` tiled by the integer repetitions held in ``repeat``."""
        source = x.tensor
        repetitions = repeat.as_tuple_int
        return OpRunTensor(source.tile(repetitions))
class Transpose_1(OpRun):
    """Transpose: permute the axes of the input according to ``perm``.

    ``perm`` is optional in ONNX; when it is absent (or empty) the axes are
    reversed, as the ONNX specification and reference implementation define.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.perm = self.get_attribute_ints(node, "perm", None)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """Return ``data`` with its axes permuted by ``perm`` (reversed when unset)."""
        perm = self.perm
        if not perm:
            # Missing perm previously reached torch.permute as None and crashed;
            # ONNX defines the default as reversing all dimensions.
            perm = tuple(reversed(range(data.tensor.ndim)))
        return OpRunTensor(torch.permute(data.tensor, perm))
class Trilu_14(OpRun):
    """Trilu: keep the upper (``upper`` truthy, default) or lower triangle of the input."""

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.upper = self.get_attribute_int(node, "upper", 1)

    def run(self, data: OpRunTensor, k: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Return the selected triangle; ``k`` shifts the diagonal (0 when omitted)."""
        diagonal = k.tensor.item() if k is not None else 0
        triangle = torch.triu if self.upper else torch.tril
        return OpRunTensor(triangle(data.tensor, diagonal=diagonal))
class Where_9(OpRun):
    """Where: pick elements from ``x`` where ``cond`` holds, otherwise from ``y``."""

    def run(self, cond: OpRunTensor, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """Return the element-wise selection after aligning the three inputs' devices."""
        condition, when_true, when_false = self.same_device(cond.tensor, x.tensor, y.tensor)
        selected = torch.where(condition, when_true, when_false)
        return OpRunTensor(selected)