from typing import Any, Dict
import numpy as np
from onnx import NodeProto
from onnx.reference.op_run import OpRun
from .cpu.c_op_conv_ import ConvDouble, ConvFloat
[docs]class Conv(OpRun):
def __init__(
self, onnx_node: NodeProto, run_params: Dict[str, Any], schema: Any = None
):
OpRun.__init__(self, onnx_node, run_params, schema)
self.cache_: Dict[type, Any] = {}
def _run(
self,
X,
W,
B=None,
auto_pad=None,
dilations=None,
group=None,
kernel_shape=None,
pads=None,
strides=None,
):
if X.dtype not in self.cache_:
if X.dtype == np.float32:
rt = ConvFloat()
elif X.dtype == np.float64:
rt = ConvDouble()
else:
raise TypeError(
f"No C implementation C for operator 'Conv' and dtype={X.dtype}."
)
self.cache_[X.dtype] = rt
rt.init(
auto_pad,
np.array(dilations or [], dtype=np.int64),
group,
np.array(kernel_shape or [], dtype=np.int64),
np.array(pads or [], dtype=np.int64),
np.array(strides or [], dtype=np.int64),
)
rt = self.cache_[X.dtype]
if X is None:
raise ValueError(f"X cannot be None for operator {type(self)}.")
if min(X.shape) == 0:
raise RuntimeError(
f"Unable to run operator Conv on an empty matrix. X.shape={X.shape!r}."
)
if min(W.shape) == 0:
raise RuntimeError(
f"Unable to run operator Conv on an empty matrix. W.shape={W.shape!r}."
)
if B is not None and min(B.shape) == 0:
raise RuntimeError(
f"Unable to run operator Conv on an empty matrix. B.shape={B.shape!r}."
)
cv = rt.compute(X, W, B)
return (cv,)