# Source code for onnx_diagnostic.reference.torch_ops.binary_ops
import torch
from . import OpRunKernel, OpRunTensor
[docs]
class OpRunBinary(OpRunKernel):
    """
    Common base for kernels taking two inputs.

    :meth:`run` first puts both operands on the same device and then
    hands them to :meth:`_run`, which every subclass must override.
    """

    def run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """
        Aligns the devices of ``x`` and ``y``, then calls :meth:`_run`.

        :param x: left operand
        :param y: right operand
        :return: whatever :meth:`_run` returns for the aligned operands
        """
        # Nothing to move when both operands already share a device.
        if x.get_device() == y.get_device():
            return self._run(x, y)

        # NOTE(review): a non-negative index presumably denotes an
        # accelerator (as with torch.Tensor.get_device) -- confirm against
        # OpRunTensor. The operand reporting a negative index is the one
        # that gets moved.
        if x.get_device() >= 0:
            y = y.to(x.device)
        else:
            x = x.to(y.device)
        return self._run(x, y)

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """
        Computes the operator on operands sharing a device.

        :raises NotImplementedError: always, subclasses provide the
            actual computation
        """
        raise NotImplementedError(f"Operator {self.__class__.__name__!r} is not complete.")
[docs]
class And_1(OpRunBinary):
    """And: applies the ``&`` operator to the two wrapped tensors."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(lhs & rhs)
[docs]
class Add_1(OpRunBinary):
    """Add: sums the two wrapped tensors with the ``+`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        total = x.tensor + y.tensor
        return OpRunTensor(total)
[docs]
class Div_1(OpRunBinary):
    """
    Div

    ONNX requires the output of Div to have the same element type as its
    inputs. The ``/`` operator of torch always performs a true division
    and returns a floating point tensor for integer operands, so integer
    inputs are divided with ``rounding_mode="trunc"`` instead, which keeps
    the integer type and rounds toward zero.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """
        Divides ``x`` by ``y`` element-wise.

        :param x: numerator
        :param y: denominator
        :return: quotient, true division for floating point or complex
            operands, truncating division otherwise
        """
        num, den = x.tensor, y.tensor
        dtypes = (num.dtype, den.dtype)
        if any(dt.is_floating_point or dt.is_complex for dt in dtypes):
            return OpRunTensor(num / den)
        # Integer operands: stay in the integer domain, round toward zero.
        return OpRunTensor(torch.div(num, den, rounding_mode="trunc"))
[docs]
class Equal_1(OpRunBinary):
    """Equal: compares the two wrapped tensors with the ``==`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(lhs == rhs)
[docs]
class Greater_1(OpRunBinary):
    """Greater: compares the two wrapped tensors with the ``>`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        mask = x.tensor > y.tensor
        return OpRunTensor(mask)
[docs]
class GreaterOrEqual_1(OpRunBinary):
    """GreaterOrEqual: compares the two wrapped tensors with ``>=``."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(lhs >= rhs)
[docs]
class Less_1(OpRunBinary):
    """Less: compares the two wrapped tensors with the ``<`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        mask = x.tensor < y.tensor
        return OpRunTensor(mask)
[docs]
class LessOrEqual_1(OpRunBinary):
    """LessOrEqual: compares the two wrapped tensors with ``<=``."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(lhs <= rhs)
[docs]
class MatMul_1(OpRunBinary):
    """MatMul: multiplies the two wrapped tensors with the ``@`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        product = x.tensor @ y.tensor
        return OpRunTensor(product)
[docs]
class Mul_1(OpRunBinary):
    """Mul: multiplies the two wrapped tensors with the ``*`` operator."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        lhs, rhs = x.tensor, y.tensor
        return OpRunTensor(lhs * rhs)
[docs]
class Or_1(OpRunBinary):
    """Or: applies the ``|`` operator to the two wrapped tensors."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        combined = x.tensor | y.tensor
        return OpRunTensor(combined)
[docs]
class Pow_12(OpRunBinary):
    """
    Pow

    Since opset 12, ONNX allows the exponent to have a different element
    type than the base and defines the output type as the type of the
    base. ``torch.pow`` promotes mixed operand types (an integer base with
    a floating point exponent produces a floating point tensor), so the
    result is cast back to the dtype of the base.
    """

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        """
        Raises ``x`` to the power ``y`` element-wise.

        :param x: base
        :param y: exponent
        :return: ``x ** y`` with the dtype of ``x``
        """
        base = x.tensor
        powered = torch.pow(base, y.tensor)
        # ``to`` returns the tensor itself when the dtype already matches.
        return OpRunTensor(powered.to(base.dtype))
[docs]
class Sub_1(OpRunBinary):
    """Sub: subtracts the second wrapped tensor from the first with ``-``."""

    def _run(self, x: OpRunTensor, y: OpRunTensor) -> OpRunTensor:
        difference = x.tensor - y.tensor
        return OpRunTensor(difference)