Source code for onnx_extended.reference

from typing import Union
import numpy as np
from onnx import SparseTensorProto, TensorProto
from onnx.reference.op_run import to_array_extended as onnx_to_array_extended
from .c_reference_evaluator import CReferenceEvaluator, from_array_extended


def to_array_extended(
    tensor: Union[SparseTensorProto, TensorProto]
) -> Union[np.ndarray, "scipy.sparse.coo_matrix"]:  # noqa: F821
    """
    Overwrites function `onnx.reference.op_run.to_array_extended`
    to support sparse tensors.

    Dense tensors are delegated to onnx's implementation. Sparse tensors
    are converted into a scipy sparse matrix:

    * 1-D indices (linearized positions in the flattened dense tensor)
      are stored in a one-row CSR matrix, then reshaped into the dense
      shape. scipy sparse matrices are always 2-D, so a rank-1 sparse
      tensor of shape ``(n,)`` is returned as a ``(1, n)`` COO matrix.
    * 2-D indices are read as ``(row, column)`` coordinates and give a
      COO matrix with the tensor's shape.

    :param tensor: a dense `TensorProto` or a `SparseTensorProto`
    :return: a numpy array for a `TensorProto`, a scipy sparse matrix
        for a `SparseTensorProto`
    :raises RuntimeError: if the sparse indices are neither 1-D nor 2-D
    :raises TypeError: if *tensor* is neither a `TensorProto`
        nor a `SparseTensorProto`
    """
    if isinstance(tensor, TensorProto):
        return onnx_to_array_extended(tensor)
    if isinstance(tensor, SparseTensorProto):
        # scipy is only needed for sparse tensors; keep it an optional import.
        import scipy.sparse as sp

        shape = tuple(tensor.dims)
        indices = onnx_to_array_extended(tensor.indices)
        values = onnx_to_array_extended(tensor.values)
        if len(indices.shape) == 1:
            # Force an integer product: np.prod(()) is the float 1.0,
            # which scipy refuses as a dimension.
            size = int(np.prod(shape, dtype=np.int64))
            # All non-zeros sit in a single row, so indptr is [0, nnz].
            t = sp.csr_matrix(
                (values, indices, np.array([0, len(indices)], dtype=np.int64)),
                shape=(1, size),
            )
            if len(shape) == 1:
                # scipy sparse matrices cannot be 1-D: reshape((n,)) would
                # raise, so keep the (1, n) row representation.
                return t.tocoo()
            return t.reshape(shape)
        if len(indices.shape) == 2:
            t = sp.coo_matrix((values, (indices[:, 0], indices[:, 1])), shape=shape)
            return t
        raise RuntimeError(f"Unexpected indices shape: {indices.shape}.")
    raise TypeError(f"Unexpected type {type(tensor)}.")