# Source code for onnx_diagnostic.reference.ops.op_constant_of_shape
import numpy as np
from onnx.reference.op_run import OpRun
try:
    import ml_dtypes
except ImportError:
    ml_dtypes = None  # type: ignore
# [docs]
class ConstantOfShape(OpRun):
    """Reference implementation of the ONNX ``ConstantOfShape`` operator.

    The output is a tensor whose shape is given by the input ``data`` and
    whose elements all equal the scalar extracted from the ``value``
    attribute (``float32`` zero when ``value`` is not set).
    """

    @staticmethod
    def _process(value):
        """Convert the ``value`` attribute into the scalar used as fill value.

        :param value: ``None``, a Python ``bool``/``int``/``float``, a numpy
            scalar, or a numpy array of any rank (including 0-d) holding at
            least one element; only the first element (in flattened order)
            is used. A ``uint16`` array tagged with a ``bfloat16`` field is
            reinterpreted as ``ml_dtypes.bfloat16`` when that package is
            available.
        :return: a numpy scalar, or an ``ml_dtypes.bfloat16`` scalar
        :raises ValueError: if ``value`` is an empty array
        :raises TypeError: if the extracted scalar is not one of the
            supported real/boolean numpy types
        """
        # Only numpy objects carry a ``dtype``; Python scalars must not reach
        # this attribute access (they are normalized further below).
        if (
            isinstance(value, (np.ndarray, np.generic))
            and ml_dtypes is not None
            and value.dtype == (np.uint16, [("bfloat16", "<u2")])
        ):
            # bfloat16 payloads may arrive as uint16 storage with a named
            # field; reinterpret the raw bits as real bfloat16 values.
            value = value.view(ml_dtypes.bfloat16)

        cst = value
        if isinstance(value, np.ndarray):
            if value.size == 0:
                raise ValueError(f"Unexpected fill_value={value!r}")
            # ``ravel()[0]`` works for every rank, including 0-d arrays, and
            # yields a numpy scalar instead of an array.
            cst = value.ravel()[0]

        # Normalize Python scalars to numpy scalars. ``bool`` is tested
        # before ``int`` because ``bool`` is a subclass of ``int``.
        if isinstance(cst, bool):
            cst = np.bool_(cst)
        elif isinstance(cst, int):
            cst = np.int64(cst)
        elif isinstance(cst, float):
            cst = np.float64(cst)
        elif cst is None:
            # No attribute given: fill with float32 zero.
            cst = np.float32(0)

        if ml_dtypes is not None and isinstance(cst, ml_dtypes.bfloat16):
            return cst
        if not isinstance(
            cst,
            (
                np.float16,
                np.float32,
                np.float64,
                np.int64,
                np.int32,
                np.int16,
                np.int8,
                np.uint64,
                np.uint32,
                np.uint16,
                np.uint8,
                np.bool_,
            ),
        ):
            raise TypeError(f"value must be a real not {type(cst)}")
        return cst

    def _run(self, data, value=None):
        """Build the output tensor.

        :param data: iterable of dimensions (presumably the operator's 1-D
            integer shape input — confirm against the caller); an empty
            sequence produces a 0-d result
        :param value: the raw ``value`` attribute, see :meth:`_process`
        :return: a one-element tuple holding the filled array, whose dtype
            follows the dtype of the processed fill value
        :raises RuntimeError: if ``np.full`` rejects the shape/value
            combination with a ``TypeError``
        """
        cst = self._process(value)
        try:
            res = np.full(tuple(data), cst)
        except TypeError as e:
            raise RuntimeError(
                f"Unable to create a constant of shape "
                f"{data!r} with value {cst!r} "
                f"(raw value={value!r})."
            ) from e
        return (res,)