Source code for onnx_extended.reference.c_ops.c_op_svm_classifier

from typing import Any, Dict
import numpy as np
from onnx import NodeProto
from onnx.reference.op_run import OpRun
from .cpu.c_op_svm_py_ import (
    RuntimeSVMClassifierFloat,
    RuntimeSVMClassifierDouble,
)


class SVMClassifier(OpRun):
    """
    Reference implementation of the ``ai.onnx.ml`` *SVMClassifier* operator
    that delegates the computation to a compiled runtime
    (``RuntimeSVMClassifierFloat`` or ``RuntimeSVMClassifierDouble``).
    """

    op_domain = "ai.onnx.ml"

    def __init__(
        self, onnx_node: NodeProto, run_params: Dict[str, Any], schema: Any = None
    ):
        OpRun.__init__(self, onnx_node, run_params, schema=schema)
        # The compiled runtime is created lazily by the first call to _run,
        # because its precision depends on the dtype of the first input.
        self.rt_ = None

    @classmethod
    def _post_process_label_attributes(cls, classlabels_int64s, classlabels_strings):
        """
        Replaces string labels by int64 labels.

        :param classlabels_int64s: integer class labels, returned untouched
            when no string labels are given
        :param classlabels_strings: string class labels (list, array or None)
        :return: tuple ``(class_ints, class_strings)``; when string labels are
            given, *class_ints* is ``np.arange(len(classlabels_strings))`` and
            *class_strings* is *classlabels_strings*, otherwise
            *class_ints* is *classlabels_int64s* and *class_strings* is None
        """
        # An explicit None/length test is used instead of plain truthiness:
        # ``if array:`` raises ValueError for a numpy array holding more
        # than one element.
        if classlabels_strings is not None and len(classlabels_strings) > 0:
            class_ints = np.arange(len(classlabels_strings), dtype=np.int64)
            return class_ints, classlabels_strings
        return classlabels_int64s, None

    @classmethod
    def _post_process_predicted_label(cls, classlabels_int64s_string, label, scores):
        """
        Replaces int64 predicted labels by the corresponding strings.

        :param classlabels_int64s_string: sequence of string labels indexed by
            the predicted integer label, or None to leave labels unchanged
        :param label: predicted integer labels
        :param scores: scores, returned untouched
        :return: tuple ``(label, scores)``
        """
        if classlabels_int64s_string is not None:
            label = np.array([classlabels_int64s_string[i] for i in label])
        return label, scores

    def _run(
        self,
        x,
        classlabels_ints=None,
        classlabels_strings=None,
        coefficients=None,
        kernel_params=None,
        kernel_type=None,
        post_transform=None,
        prob_a=None,
        prob_b=None,
        rho=None,
        support_vectors=None,
        vectors_per_class=None,
    ):
        """
        This is a C++ implementation coming from :epkg:`onnxruntime`.
        `svm_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/svm_classifier.cc>`_.

        The compiled runtime is created and initialized on the first call
        only; later calls reuse it, so the attributes and the dtype seen by
        the first call determine the runtime.

        :param x: input array, its dtype must be float32 or float64
        :return: tuple ``(label, scores)``
        :raises NotImplementedError: if the dtype of *x* is neither float32
            nor float64 on the first call
        """
        if self.rt_ is None:
            (
                classlabels_ints,
                self.classlabels_int64s_string,
            ) = self._post_process_label_attributes(
                classlabels_ints, classlabels_strings
            )
            if x.dtype == np.float32:
                self.rt_ = RuntimeSVMClassifierFloat()
            elif x.dtype == np.float64:
                self.rt_ = RuntimeSVMClassifierDouble()
            else:
                raise NotImplementedError(f"Not implemented for dtype={x.dtype}.")

            # String labels may come as a numpy array or as a plain Python
            # list; only arrays have ``tolist``. The runtime always receives
            # a list.
            if classlabels_strings is None:
                string_labels = []
            elif hasattr(classlabels_strings, "tolist"):
                string_labels = classlabels_strings.tolist()
            else:
                string_labels = list(classlabels_strings)

            self.rt_.init(
                classlabels_ints,
                string_labels,
                coefficients,
                kernel_params,
                kernel_type,
                post_transform,
                prob_a,
                prob_b,
                rho,
                support_vectors,
                vectors_per_class,
            )

        label, scores = self.rt_.compute(x)
        # When the first dimension of scores differs from the number of
        # labels, scores is reshaped into one row per label (presumably the
        # runtime returned them flattened -- confirm against the C++ code).
        if scores.shape[0] != label.shape[0]:
            scores = scores.reshape(label.shape[0], scores.shape[0] // label.shape[0])
        return self._post_process_predicted_label(
            self.classlabels_int64s_string, label, scores
        )