Source code for onnx_diagnostic.reference.torch_ops.shape_ops

from typing import Optional, Tuple
import onnx
import torch
from . import OpRun, OpRunTensor


class ConstantOfShape_9(OpRun):
    "ConstantOfShape"

    @classmethod
    def device_dependent(cls) -> bool:
        """
        Returns True if the kernel needs a device to be efficiently initialized.
        """
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        device: Optional[torch.device] = None,
    ):
        """
        Reads the ``value`` attribute of the node.

        :param node: ONNX node being evaluated
        :param version: opset version, forwarded to :class:`OpRun`
        :param device: device on which the output tensor is created
        """
        super().__init__(node, version)
        value = self.get_attribute_tensor(node, "value")
        if value is None:
            # No 'value' attribute: fall back to a float32 zero.
            value = torch.tensor([0], dtype=torch.float32)
        # The output keeps the dtype of the attribute tensor.
        self.dtype = value.dtype
        self.device = device
        # The attribute holds a single element. Flatten before indexing so that
        # a 0-d (scalar) tensor works as well as a tensor of shape [1];
        # ``value[0]`` alone fails on a 0-d tensor.
        self.value = value.reshape(-1)[0]

    def run(self, shape: OpRunTensor) -> OpRunTensor:
        """
        Creates a tensor of the requested shape filled with the attribute value.

        :param shape: 1-D tensor holding the output dimensions
        :return: tensor of shape ``shape``, dtype ``self.dtype``, on ``self.device``
        """
        # The device is unknown as shapes usually take place on CPU.
        return OpRunTensor(
            torch.full(
                shape.as_tuple_int, fill_value=self.value, dtype=self.dtype, device=self.device
            )
        )
class Expand_8(OpRun):
    "Expand"

    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
        """
        Broadcasts ``data`` against ``shape``.

        :param data: input tensor
        :param shape: 1-D tensor holding the requested shape
        :return: ``data`` expanded to the broadcast of its shape and ``shape``
        """
        # ONNX Expand uses bidirectional (numpy-style) broadcasting:
        # - ``shape`` may have fewer dimensions than ``data``;
        # - a 1 in ``shape`` keeps the matching input dimension.
        # torch.Tensor.expand only broadcasts one way. It also rejects -1 for
        # a new leading dimension, so mapping 1 -> -1 fails e.g. for
        # shape=[1, 3] on an input of shape (3,).
        # Compute the broadcast shape first; it is always a valid target
        # for expand.
        target = torch.broadcast_shapes(tuple(data.tensor.shape), shape.as_tuple_int)
        return OpRunTensor(data.tensor.expand(target))
class Reshape_14(OpRun):
    "Reshape"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        # When falsy, a 0 in the target shape copies the matching input
        # dimension (see run); otherwise 0 is passed through unchanged.
        self.allowzero = self.get_attribute_int(node, "allowzero", 0)

    def run(self, data: OpRunTensor, shape: OpRunTensor) -> OpRunTensor:
        """
        Reshapes ``data`` to the dimensions stored in ``shape``.

        :param data: input tensor
        :param shape: 1-D tensor holding the target dimensions
        :return: reshaped tensor
        """
        target = shape.as_tuple_int
        assert target is not None, f"Unexpected return for shape={shape!r}"
        source = data.tensor
        if self.allowzero or 0 not in target:
            # Nothing to substitute: torch handles -1 itself.
            return OpRunTensor(source.reshape(target))
        # Replace every 0 with the input dimension at the same position.
        in_dims = source.shape
        resolved = [in_dims[pos] if dim == 0 else dim for pos, dim in enumerate(target)]
        return OpRunTensor(source.reshape(resolved))
class Shape_15(OpRun):
    "Shape"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        # Optional slice bounds applied to the input shape; end=None means
        # "up to the last dimension".
        self.start = self.get_attribute_int(node, "start", 0)
        self.end = self.get_attribute_int(node, "end", None)

    def run(self, data: OpRunTensor) -> OpRunTensor:
        """
        Returns the (optionally sliced) shape of ``data`` as an int64 tensor.

        :param data: input tensor
        :return: constant 1-D int64 tensor
        """
        # A slice whose stop is None runs to the end, so one expression
        # covers both the bounded and the unbounded case.
        window = slice(self.start, self.end)
        dims = data.shape[window]
        return OpRunTensor(torch.tensor(dims, dtype=torch.int64), is_constant=True)
class Split_18(OpRun):
    "Split"

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        self.axis = self.get_attribute_int(node, "axis", 0)
        # Used only when the optional 'split' input is missing.
        self.num_outputs = self.get_attribute_int(node, "num_outputs", None)

    def run(
        self, data: OpRunTensor, split: Optional[OpRunTensor] = None
    ) -> Tuple[OpRunTensor, ...]:
        """
        Splits ``data`` along ``self.axis``.

        :param data: input tensor
        :param split: optional 1-D tensor with the size of every chunk
        :return: tuple of chunks
        """
        if split is not None:
            # Explicit chunk sizes.
            pieces = torch.split(data.tensor, split.as_tuple_int, dim=self.axis)
        else:
            assert isinstance(
                self.num_outputs, int
            ), f"Incompatibilities: split is None and num_outputs={self.num_outputs}"
            length = data.tensor.shape[self.axis]
            # Ceiling division: equal chunks, the last one possibly smaller.
            # NOTE(review): torch.split may then return fewer than num_outputs
            # chunks (e.g. length=5, num_outputs=4 -> sizes 2, 2, 1).
            chunk = -(-length // self.num_outputs)
            pieces = torch.split(data.tensor, chunk, dim=self.axis)
        return tuple(map(OpRunTensor, pieces))
class Squeeze_13(OpRun):
    "Squeeze"

    def run(self, data: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """
        Removes size-1 dimensions from ``data``.

        :param data: input tensor
        :param axes: optional 1-D tensor listing the dimensions to remove;
            when missing, every size-1 dimension is removed
        :return: squeezed tensor
        """
        source = data.tensor
        squeezed = source.squeeze() if axes is None else source.squeeze(axes.as_tuple_int)
        return OpRunTensor(squeezed)
class Unsqueeze_13(OpRun):
    "Unsqueeze"

    def run(self, data: OpRunTensor, axes: OpRunTensor) -> OpRunTensor:
        """
        Inserts size-1 dimensions into ``data``.

        :param data: input tensor
        :param axes: 1-D tensor listing the positions of the new dimensions,
            expressed in the output rank (negative values count from the end)
        :return: tensor with one extra dimension per axis
        """
        t = data.tensor
        requested = axes.as_tuple_int
        # Axes refer to positions in the *output*, whose rank is the input
        # rank plus the number of inserted dimensions.
        out_rank = t.dim() + len(requested)
        # Normalize negative axes against the output rank. Insert in
        # ascending order so that each insertion lands at its final position.
        # Inserting in the given order is wrong for unsorted axes,
        # e.g. [2, 0] on (3, 4) must give (1, 3, 1, 4).
        positions = sorted(a + out_rank if a < 0 else a for a in requested)
        for pos in positions:
            t = t.unsqueeze(pos)
        return OpRunTensor(t)