import numpy as np
from onnx.reference.op_run import OpRun


class AddAdd(OpRun):
    """Reference implementation of the fused ``AddAdd`` custom operator.

    Produces a single output equal to ``x + y + z``, evaluated as
    ``(x + y) + z``. Inputs are presumably numpy arrays, in which case the
    usual numpy broadcasting rules apply — confirm against callers.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        partial_sum = x + y
        total = partial_sum + z
        return (total,)


class MulMul(OpRun):
    """Reference implementation of the fused ``MulMul`` custom operator.

    Produces a single output equal to ``x * y * z``, evaluated as
    ``(x * y) * z``. Inputs are presumably numpy arrays (broadcasting
    applies) — confirm against callers.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z):
        # Deliberately not ``*=``: an in-place multiply could fail on dtype
        # promotion or on broadcasting to a larger shape.
        product = x * y
        return (product * z,)


class AddMul(OpRun):
    """Reference implementation of the fused ``AddMul`` custom operator.

    Computes ``(x + y) * z``. When ``transposeMiddle`` is truthy, the
    result is permuted with axes ``(0, 2, 1, 3)``, i.e. dimensions 1 and 2
    are swapped. That permutation only fits a 4-D result; ``np.transpose``
    raises for any other rank.

    ``transposeMiddle`` is presumably the node attribute of the same name
    (an integer flag) — verify against the operator schema.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        fused = (x + y) * z
        if not transposeMiddle:
            return (fused,)
        return (np.transpose(fused, axes=(0, 2, 1, 3)),)


class MulAdd(OpRun):
    """Reference implementation of the fused ``MulAdd`` custom operator.

    Computes ``(x * y) + z``. When ``transposeMiddle`` is truthy, the
    result is permuted with axes ``(0, 2, 1, 3)``, i.e. dimensions 1 and 2
    are swapped. That permutation only fits a 4-D result; ``np.transpose``
    raises for any other rank.

    ``transposeMiddle`` is presumably the node attribute of the same name
    (an integer flag) — verify against the operator schema.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, transposeMiddle=None):
        fused = (x * y) + z
        if not transposeMiddle:
            return (fused,)
        return (np.transpose(fused, axes=(0, 2, 1, 3)),)


class SubMul(OpRun):
    """Reference implementation of the fused ``SubMul`` custom operator.

    Computes ``(x - y) * z``, or ``(y - x) * z`` when ``negative`` is
    truthy. ``negative`` is presumably the node attribute of the same name
    (an integer flag) — verify against the operator schema.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        difference = y - x if negative else x - y
        return (difference * z,)


class MulSub(OpRun):
    """Reference implementation of the fused ``MulSub`` custom operator.

    Computes ``(x * y) - z``, or ``z - (x * y)`` when ``negative`` is
    truthy. ``negative`` is presumably the node attribute of the same name
    (an integer flag) — verify against the operator schema.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, x, y, z, negative=None):
        product = x * y
        outcome = z - product if negative else product - z
        return (outcome,)