Source code for onnx_diagnostic.reference.torch_ops.reduce_ops

from typing import Optional, Tuple
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRun, OpRunTensor


class ReduceOp(OpRun):
    """Common base for ONNX reduce operators.

    Reads the attributes shared by every reduce node: ``keepdims``,
    ``noop_with_empty_axes`` (rejected, not implemented) and the
    out-of-spec ``stash_type``.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        # keepdims defaults to 1 per the ONNX specification.
        self.keepdims = bool(self.get_attribute_int(node, "keepdims", 1))
        self.noop_with_empty_axes = bool(
            self.get_attribute_int(node, "noop_with_empty_axes", 0)
        )
        assert isinstance(
            self.keepdims, bool
        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
        assert isinstance(self.noop_with_empty_axes, bool), (
            f"Unexpected value for attribute "
            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
        )
        assert (
            not self.noop_with_empty_axes
        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"
        # stash_type is out of spec for reduce operators but tolerated here;
        # translate the ONNX type id into the matching torch dtype when present.
        stash_type = self.get_attribute_int(node, "stash_type", None)
        if stash_type is None:
            self.stash_type = None
        else:
            self.stash_type = onnx_dtype_to_torch_dtype(stash_type)
class ReduceOpAxes(ReduceOp):
    """Base for reduce operators that carry their axes as a node attribute
    (opset < 18) rather than as a second input."""

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        super().__init__(node, version)
        attr_axes = self.get_attribute_ints(node, "axes")
        # A missing or empty attribute collapses to an empty tuple.
        self.axes: Tuple[int, ...] = attr_axes or tuple()
class ReduceMax_18(ReduceOp):
    """ReduceMax"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Reduces ``x`` with ``max`` over ``axes`` (second input, opset 18).

        :param x: tensor to reduce
        :param axes: optional 1-D tensor of axes; when absent every
            dimension is reduced
        :return: reduced tensor, dimensions kept as size 1 when
            ``keepdims`` is set
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            # No axes input: reduce over every dimension
            # (noop_with_empty_axes is rejected in the base class).
            if self.keepdims and x.tensor.ndim > 0:
                all_dims = tuple(range(x.tensor.ndim))
                return OpRunTensor(x.tensor.amax(all_dims, keepdim=True))
            return OpRunTensor(x.tensor.max())
        taxes = axes.as_tuple_int
        if not taxes:
            # Empty axes tensor: identity, matching the previous behavior.
            return OpRunTensor(x.tensor)
        # torch.amax reduces several dimensions in one call and is robust to
        # unsorted or negative axes, unlike a per-axis loop which mis-indexes
        # dimensions once earlier ones have been squeezed out.
        return OpRunTensor(x.tensor.amax(taxes, keepdim=self.keepdims))
class ReduceMean_18(ReduceOp):
    """ReduceMean"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Reduces ``x`` with ``mean`` over ``axes`` (second input, opset 18)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            # Reducing every dimension; only keepdims=0 is supported here.
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.mean(x.tensor))
        dims = axes.as_tuple_int
        if len(dims) == 1:
            reduced = x.tensor.mean(dims[0], keepdim=self.keepdims)
        else:
            # torch.mean accepts a tuple of dims directly.
            reduced = x.tensor.mean(dims, keepdim=self.keepdims)
        return OpRunTensor(reduced)
class ReduceMin_17(ReduceOpAxes):
    """ReduceMin"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Reduces ``x`` with ``min`` over the axes stored as a node
        attribute (opset < 18).

        :param x: tensor to reduce
        :return: reduced tensor, dimensions kept as size 1 when
            ``keepdims`` is set
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        axes = self.axes
        if not axes:
            # Empty axes attribute: reduce over every dimension
            # (noop_with_empty_axes is rejected in the base class).
            if self.keepdims and x.tensor.ndim > 0:
                all_dims = tuple(range(x.tensor.ndim))
                return OpRunTensor(x.tensor.amin(all_dims, keepdim=True))
            return OpRunTensor(x.tensor.min())
        # torch.amin reduces several dimensions in one call and is robust to
        # unsorted or negative axes, unlike a per-axis loop which mis-indexes
        # dimensions once earlier ones have been squeezed out.
        return OpRunTensor(x.tensor.amin(tuple(axes), keepdim=self.keepdims))
class ReduceMin_18(ReduceOp):
    """ReduceMin"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Reduces ``x`` with ``min`` over ``axes`` (second input, opset 18).

        :param x: tensor to reduce
        :param axes: optional 1-D tensor of axes; when absent every
            dimension is reduced
        :return: reduced tensor, dimensions kept as size 1 when
            ``keepdims`` is set
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            # No axes input: reduce over every dimension
            # (noop_with_empty_axes is rejected in the base class).
            if self.keepdims and x.tensor.ndim > 0:
                all_dims = tuple(range(x.tensor.ndim))
                return OpRunTensor(x.tensor.amin(all_dims, keepdim=True))
            return OpRunTensor(torch.min(x.tensor))
        taxes = axes.as_tuple_int
        if not taxes:
            # Empty axes tensor: identity, matching the previous behavior.
            return OpRunTensor(x.tensor)
        # torch.amin reduces several dimensions in one call and is robust to
        # unsorted or negative axes, unlike a per-axis loop which mis-indexes
        # dimensions once earlier ones have been squeezed out.
        return OpRunTensor(x.tensor.amin(taxes, keepdim=self.keepdims))
class ReduceSum_13(ReduceOp):
    """ReduceSum"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Reduces ``x`` with ``sum`` over ``axes`` (second input, opset 13)."""
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            # Reducing every dimension; only keepdims=0 is supported here.
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.sum(x.tensor))
        dims = axes.as_tuple_int
        if len(dims) == 1:
            summed = x.tensor.sum(dims[0], keepdim=self.keepdims)
        else:
            # torch.sum accepts a tuple of dims directly.
            summed = x.tensor.sum(dims, keepdim=self.keepdims)
        return OpRunTensor(summed)