Source code for experimental_experiment.reference.ops.op_qlinear_conv

from typing import Tuple
from onnx.defs import OpSchema
from onnx.helper import make_attribute
from onnx.reference.op_run import OpRun
from onnx.reference.ops.op_conv import Conv
from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear


def _switch_dims_nchw_nhwc(dims: Tuple[int, ...], from_nchw_to_nhwc: bool):
    if len(dims) == 4:
        if from_nchw_to_nhwc:
            return (dims[0], *dims[2:], dims[1])
        return (dims[0], dims[-1], *dims[1:-1])
    if len(dims) == 3:
        if from_nchw_to_nhwc:
            return (*dims[1:], dims[0])
        return (dims[-1], *dims[:-1])
    raise NotImplementedError(f"Unable to process shape={dims}")


class QLinearConv(OpRun):
    """Reference implementation of ``QLinearConv`` in domain ``com.microsoft``.

    The computation dequantizes ``x`` and ``w`` (and ``B`` when given),
    runs a regular ``Conv`` on the dequantized tensors and quantizes the
    result with ``y_scale`` / ``y_zero_point``.

    When attribute ``channels_last`` is set, the input is permuted with
    :func:`_switch_dims_nchw_nhwc` (channel-last to channel-first) before
    the convolution and the output is permuted back afterwards.
    """

    op_domain = "com.microsoft"

    op_schema = OpSchema(
        "QLinearConv",
        "com.microsoft",
        1,
        inputs=[
            OpSchema.FormalParameter("x", "T"),
            OpSchema.FormalParameter("x_scale", "T"),
            OpSchema.FormalParameter("x_zero_point", "T1"),
            OpSchema.FormalParameter("w", "T"),
            OpSchema.FormalParameter("w_scale", "T"),
            OpSchema.FormalParameter("w_zero_point", "T2"),
            OpSchema.FormalParameter("y_scale", "T"),
            OpSchema.FormalParameter("y_zero_point", "T3"),
            OpSchema.FormalParameter(
                "B", "T3", param_option=OpSchema.FormalParameterOption.Optional
            ),
        ],
        outputs=[OpSchema.FormalParameter("y", "T3")],
        type_constraints=[
            ("T", ["tensor(float)"], ""),
            ("T1", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T2", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T3", ["tensor(int8)", "tensor(uint8)"], ""),
        ],
        attributes=[
            OpSchema.Attribute("auto_pad", make_attribute("auto_pad", "NOTSET"), ""),
            OpSchema.Attribute("kernel_shape", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("dilations", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("strides", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("pads", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("group", make_attribute("group", 1), ""),
            OpSchema.Attribute("channels_last", make_attribute("channels_last", 0), ""),
        ],
    )

    @staticmethod
    def _switch_layout(tensor, from_nchw_to_nhwc):
        """Permute the axes of *tensor* following ``_switch_dims_nchw_nhwc``."""
        # Called on the real shape first so that an unsupported rank is
        # reported with the tensor's shape in the error message.
        _switch_dims_nchw_nhwc(tensor.shape, from_nchw_to_nhwc)
        # Applying the helper to the axis indices yields the permutation
        # that produces the shape the helper returns for ``tensor.shape``.
        perm = _switch_dims_nchw_nhwc(tuple(range(tensor.ndim)), from_nchw_to_nhwc)
        return tensor.transpose(perm)

    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        B=None,
        auto_pad=None,
        channels_last=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        """Dequantize the inputs, convolve, then quantize the output.

        :param x: quantized input tensor
        :param x_scale: scale used to dequantize *x*
        :param x_zero_point: zero point used to dequantize *x*
        :param w: quantized weights
        :param w_scale: scale used to dequantize *w*
        :param w_zero_point: zero point used to dequantize *w*
        :param y_scale: scale used to quantize the output
        :param y_zero_point: zero point used to quantize the output
        :param B: optional quantized bias, dequantized with scale
            ``x_scale * w_scale`` and zero point 0
        :param auto_pad: forwarded to ``Conv``
        :param channels_last: if truthy, *x* and the result are stored
            channel-last and are permuted around the convolution
        :param dilations: forwarded to ``Conv``
        :param group: forwarded to ``Conv``
        :param kernel_shape: forwarded to ``Conv``
        :param pads: forwarded to ``Conv``
        :param strides: forwarded to ``Conv``
        :return: a one-element tuple holding the quantized output
        """
        dqx = DequantizeLinear.eval(x, x_scale, x_zero_point)
        dqw = DequantizeLinear.eval(w, w_scale, w_zero_point)
        if channels_last:
            # A layout change must move the data: ``reshape`` would only
            # relabel the dimensions and leave the values scrambled.
            dqx = self._switch_layout(dqx, False)
        dqb = (
            DequantizeLinear.eval(B, x_scale * w_scale, 0).astype(dqx.dtype)
            if B is not None
            else None
        )
        y = Conv.eval(
            dqx,
            dqw,
            dqb,
            auto_pad=auto_pad,
            dilations=dilations,
            group=group,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        if channels_last:
            # Bring the result back to the channel-last layout of the input.
            y = self._switch_layout(y, True)
        qy = QuantizeLinear.eval(y, y_scale, y_zero_point)
        return (qy,)