Source code for onnx_diagnostic.reference.torch_ops.binary_ops
import torch
from . import OpRun, OpRunTensor
[docs]
class OpRunBinary(OpRun):
    """Binary Op

    Common base for two-input operators. :meth:`run` reconciles the
    devices of both operands and then hands them to :meth:`_run`, which
    every concrete subclass overrides.
    """

    def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """Moves the operands onto a common device, then evaluates the operator.

        If ``get_device()`` differs between the operands, one of them is
        moved: ``x``'s device wins when its index is non-negative, otherwise
        ``x`` is moved to ``y``'s device. (A negative index presumably means
        CPU, following torch's convention — confirm for ``OpRunTensor``.)
        """
        x_device = x.get_device()
        if x_device == y.get_device():
            return self._run(x, y)
        if x_device < 0:
            x = x.to(y.device)
        else:
            y = y.to(x.device)
        return self._run(x, y)

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """Computes the operator on same-device operands; subclasses override it."""
        op_name = type(self).__name__
        raise NotImplementedError(f"Operator {op_name!r} is not complete.")
[docs]
class And_1(OpRunBinary):
    """And

    Element-wise bitwise and of both operands (logical and on boolean tensors).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.bitwise_and(lhs, rhs))
[docs]
class Add_1(OpRunBinary):
    """Add

    Element-wise sum ``x + y`` (torch broadcasting rules apply).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.add(lhs, rhs))
[docs]
class Div_1(OpRunBinary):
    """Div

    Element-wise division ``x / y``. Floating-point and complex inputs use
    true division. Integer inputs use truncating division (rounding toward
    zero), so the result keeps the integer dtype as ONNX Div requires
    instead of being promoted to floating point by torch's ``/``.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        if lhs.is_floating_point() or lhs.is_complex():
            return OpRunTensor(lhs / rhs)
        # ONNX integer Div is C-style integer division (truncate toward zero);
        # plain ``/`` would return a float tensor and change the output type.
        return OpRunTensor(torch.div(lhs, rhs, rounding_mode="trunc"))
[docs]
class Equal_1(OpRunBinary):
    """Equal

    Element-wise equality test; produces a boolean tensor.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.eq(lhs, rhs))
[docs]
class Greater_1(OpRunBinary):
    """Greater

    Element-wise ``x > y``; produces a boolean tensor.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.gt(lhs, rhs))
[docs]
class GreaterOrEqual_1(OpRunBinary):
    """GreaterOrEqual

    Element-wise ``x >= y``; produces a boolean tensor.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.ge(lhs, rhs))
[docs]
class Less_1(OpRunBinary):
    """Less

    Element-wise ``x < y``; produces a boolean tensor.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.lt(lhs, rhs))
[docs]
class LessOrEqual_1(OpRunBinary):
    """LessOrEqual

    Element-wise ``x <= y``; produces a boolean tensor.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.le(lhs, rhs))
[docs]
class MatMul_1(OpRunBinary):
    """MatMul

    Matrix product ``x @ y`` with ``torch.matmul`` semantics (batched
    leading dimensions broadcast).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.matmul(lhs, rhs))
[docs]
class Mul_1(OpRunBinary):
    """Mul

    Element-wise product ``x * y`` (torch broadcasting rules apply).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.mul(lhs, rhs))
[docs]
class Or_1(OpRunBinary):
    """Or

    Element-wise bitwise or of both operands (logical or on boolean tensors).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.bitwise_or(lhs, rhs))
[docs]
class Pow_12(OpRunBinary):
    """Pow

    Element-wise power ``x ** y``. The ONNX output type is the type of the
    base ``x``. ``torch.pow`` applies type promotion instead (e.g. an integer
    base with a float exponent yields a float tensor), so the result is cast
    back to the base dtype whenever promotion changed it.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        base = x.tensor
        result = torch.pow(base, y.tensor)
        if result.dtype != base.dtype:
            # Undo torch's type promotion: ONNX Pow keeps the base's type.
            result = result.to(base.dtype)
        return OpRunTensor(result)
[docs]
class Sub_1(OpRunBinary):
    """Sub

    Element-wise difference ``x - y`` (torch broadcasting rules apply).
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(torch.sub(lhs, rhs))