from typing import Tuple
from onnx.defs import OpSchema
from onnx.helper import make_attribute
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_conv import Conv
from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear
def _switch_dims_nchw_nhwc(dims: Tuple[int, ...], from_nchw_to_nhwc: bool):
    if len(dims) == 4:
        if from_nchw_to_nhwc:
            return (dims[0], *dims[2:], dims[1])
        return (dims[0], dims[-1], *dims[1:-1])
    if len(dims) == 3:
        if from_nchw_to_nhwc:
            return (*dims[1:], dims[0])
        return (dims[-1], *dims[:-1])
    raise NotImplementedError(f"Unable to process shape={dims}")
[docs]
class QLinearConv(OpRun):
    op_domain = "com.microsoft"
    op_schema = OpSchema(
        "QLinearConv",
        "com.microsoft",
        1,
        inputs=[
            OpSchema.FormalParameter("x", "T"),
            OpSchema.FormalParameter("x_scale", "T"),
            OpSchema.FormalParameter("x_zero_point", "T1"),
            OpSchema.FormalParameter("w", "T"),
            OpSchema.FormalParameter("w_scale", "T"),
            OpSchema.FormalParameter("w_zero_point", "T2"),
            OpSchema.FormalParameter("y_scale", "T"),
            OpSchema.FormalParameter("y_zero_point", "T3"),
            OpSchema.FormalParameter(
                "B", "T3", param_option=OpSchema.FormalParameterOption.Optional
            ),
        ],
        outputs=[OpSchema.FormalParameter("y", "T3")],
        type_constraints=[
            ("T", ["tensor(float)"], ""),
            ("T1", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T2", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T3", ["tensor(int8)", "tensor(uint8)"], ""),
        ],
        attributes=[
            OpSchema.Attribute("auto_pad", make_attribute("auto_pad", "NOTSET"), ""),
            OpSchema.Attribute("kernel_shape", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("dilations", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("strides", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("pads", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("group", make_attribute("group", 1), ""),
            OpSchema.Attribute("channels_last", make_attribute("channels_last", 0), ""),
        ],
    )
    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        B=None,
        auto_pad=None,
        channels_last=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        dqx = DequantizeLinear.eval(x, x_scale, x_zero_point)
        dqw = DequantizeLinear.eval(w, w_scale, w_zero_point)
        if channels_last:
            dqx = dqx.reshape(_switch_dims_nchw_nhwc(x.shape, False))
        dqb = (
            DequantizeLinear.eval(B, x_scale * w_scale, 0).astype(dqx.dtype)
            if B is not None
            else None
        )
        y = Conv.eval(
            dqx,
            dqw,
            dqb,
            auto_pad=auto_pad,
            dilations=dilations,
            group=group,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        if channels_last:
            y = y.reshape(_switch_dims_nchw_nhwc(y.shape, True))
        qy = QuantizeLinear.eval(y, y_scale, y_zero_point)
        return (qy,)