from typing import Optional, Tuple
import onnx
import torch
from ...helpers.torch_helper import onnx_dtype_to_torch_dtype
from . import OpRun, OpRunTensor
[docs]
class ReduceOp(OpRun):
    """Common base for the reduction kernels defined in this module.

    The constructor reads the attributes shared by the reduce operators and
    stores them on the instance as ``keepdims``, ``noop_with_empty_axes``
    and ``stash_type``.
    """

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        """Read the reduce attributes from ``node``.

        :param node: node whose attributes are read
        :param version: forwarded unchanged to the parent constructor
        """
        super().__init__(node, version)

        # keepdims falls back to 1 when the attribute is not set on the node.
        keepdims_flag = self.get_attribute_int(node, "keepdims", 1)
        self.keepdims = bool(keepdims_flag)

        # noop_with_empty_axes falls back to 0 when the attribute is not set.
        noop_flag = self.get_attribute_int(node, "noop_with_empty_axes", 0)
        self.noop_with_empty_axes = bool(noop_flag)

        assert isinstance(
            self.keepdims, bool
        ), f"Unexpected value for attribute keepdims={self.keepdims!r}"
        assert isinstance(self.noop_with_empty_axes, bool), (
            f"Unexpected value for attribute "
            f"noop_with_empty_axes={self.noop_with_empty_axes!r}"
        )
        # Only noop_with_empty_axes=0 is handled by the kernels.
        assert (
            not self.noop_with_empty_axes
        ), f"Not implemented with noop_with_empty_axes={self.noop_with_empty_axes}"

        # this is out of spec
        raw_stash_type = self.get_attribute_int(node, "stash_type", None)
        if raw_stash_type is None:
            self.stash_type = None
        else:
            self.stash_type = onnx_dtype_to_torch_dtype(raw_stash_type)
[docs]
class ReduceOpAxes(ReduceOp):
    """Reduce kernel base that reads ``axes`` from the node attributes."""

    def __init__(self, node: onnx.NodeProto, version: Optional[int] = None):
        """Read the ``axes`` attribute on top of the parent's attributes.

        :param node: node whose attributes are read
        :param version: forwarded unchanged to the parent constructor
        """
        super().__init__(node, version)
        declared_axes = self.get_attribute_ints(node, "axes")
        # A missing or empty attribute is stored as an empty tuple.
        self.axes: Tuple[int, ...] = declared_axes if declared_axes else ()
[docs]
class ReduceMax_18(ReduceOp):
    """ReduceMax"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Compute the maximum of ``x`` along ``axes``.

        :param x: input tensor
        :param axes: optional tensor holding the axes to reduce, when it is
            missing or empty every dimension is reduced
        :return: the reduced tensor
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.max())
        taxes = tuple(axes.as_tuple_int)
        if not taxes:
            # noop_with_empty_axes is rejected by the constructor, so an empty
            # list of axes means "reduce everything", not "return the input".
            taxes = tuple(range(x.tensor.ndim))
        # All axes are reduced in a single call: reducing them one at a time
        # with keepdim=False renumbers the remaining dimensions, which picks
        # the wrong dimension for negative or unsorted axes.
        return OpRunTensor(torch.amax(x.tensor, dim=taxes, keepdim=self.keepdims))
[docs]
class ReduceMean_18(ReduceOp):
    """ReduceMean"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Compute the mean of ``x`` along ``axes``.

        :param x: input tensor
        :param axes: optional tensor holding the axes to reduce, when it is
            missing the mean of the whole tensor is returned and
            ``keepdims`` must be off
        :return: the reduced tensor
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is not None:
            dims = axes.as_tuple_int
            # A lone axis is handed over as a plain int, several as a sequence.
            dim = dims[0] if len(dims) == 1 else dims
            return OpRunTensor(torch.mean(x.tensor, dim, keepdim=self.keepdims))
        assert (
            not self.keepdims
        ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
        return OpRunTensor(x.tensor.mean())
[docs]
class ReduceMin_17(ReduceOpAxes):
    """ReduceMin"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Compute the minimum of ``x`` along the axes stored in ``self.axes``.

        :param x: input tensor
        :return: the reduced tensor, when ``self.axes`` is empty the minimum
            of the whole tensor is returned and ``keepdims`` must be off
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        axes = self.axes
        if not axes:
            assert (
                not self.keepdims
            ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(x.tensor.min())
        # All axes are reduced in a single call: reducing them one at a time
        # with keepdim=False renumbers the remaining dimensions, which picks
        # the wrong dimension for negative or unsorted axes.
        return OpRunTensor(torch.amin(x.tensor, dim=tuple(axes), keepdim=self.keepdims))
[docs]
class ReduceMin_18(ReduceOp):
    """ReduceMin"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Compute the minimum of ``x`` along ``axes``.

        :param x: input tensor
        :param axes: optional tensor holding the axes to reduce, when it is
            missing or empty every dimension is reduced
        :return: the reduced tensor
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is None:
            assert (
                not self.keepdims
            ), f"axes is empty, keepdims={self.keepdims} for {self.__class__.__name__}"
            return OpRunTensor(torch.min(x.tensor))
        taxes = tuple(axes.as_tuple_int)
        if not taxes:
            # noop_with_empty_axes is rejected by the constructor, so an empty
            # list of axes means "reduce everything", not "return the input".
            taxes = tuple(range(x.tensor.ndim))
        # All axes are reduced in a single call: reducing them one at a time
        # with keepdim=False renumbers the remaining dimensions, which picks
        # the wrong dimension for negative or unsorted axes.
        return OpRunTensor(torch.amin(x.tensor, dim=taxes, keepdim=self.keepdims))
[docs]
class ReduceSum_13(ReduceOp):
    """ReduceSum"""

    def run(self, x: OpRunTensor, axes: Optional[OpRunTensor] = None) -> OpRunTensor:
        """Compute the sum of ``x`` along ``axes``.

        :param x: input tensor
        :param axes: optional tensor holding the axes to reduce, when it is
            missing the sum of the whole tensor is returned and
            ``keepdims`` must be off
        :return: the reduced tensor
        """
        assert self.stash_type is None, f"Not implemented with stash_type={self.stash_type}"
        if axes is not None:
            dims = axes.as_tuple_int
            # A lone axis is handed over as a plain int, several as a sequence.
            dim = dims[0] if len(dims) == 1 else dims
            return OpRunTensor(torch.sum(x.tensor, dim, keepdim=self.keepdims))
        assert (
            not self.keepdims
        ), f"axes is Empty, keepdims={self.keepdims} for {self.__class__.__name__}"
        return OpRunTensor(x.tensor.sum())