yobx.sklearn.mixture.gaussian_mixture
- yobx.sklearn.mixture.gaussian_mixture.sklearn_gaussian_mixture(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: GaussianMixture, X: str, name: str = 'gaussian_mixture') → Tuple[str, str]
Converts a sklearn.mixture.GaussianMixture into ONNX.

The converter supports all four covariance types supported by GaussianMixture: 'full', 'tied', 'diag', and 'spherical'. Below, N is the number of samples, K the number of components and F the number of features. In each case the weighted log-probability for sample n under component k is computed as:

    log_p[n, k] = log(weight_k) + log_det_k - 0.5 * n_features * log(2π) - 0.5 * quad[n, k]

where log_det_k is the log-determinant of the Cholesky factor of component k's precision matrix and quad[n, k] is the squared Mahalanobis distance between sample n and the mean of component k, computed differently depending on covariance_type. The predicted label is the ArgMax of log_p over the component axis, and proba is its Softmax over that same axis.

'full': per-component Cholesky factor of the precision matrix, L_k (shape (K, F, F)):

    L_2d = L.transpose(1, 0, 2).reshape(F, K*F)   # (F, K*F) constant
    b = einsum('ki,kij->kj', means_, L)            # (K, F) constant
    XL = MatMul(X, L_2d)                           # (N, K*F)
    Y = Reshape(XL, [-1, K, F]) - b                # (N, K, F)
    quad = ReduceSum(Y * Y, axis=2)                # (N, K)

'tied': a single shared Cholesky factor L (shape (F, F)):

    means_L = means_ @ L                           # (K, F) constant
    XL = MatMul(X, L)                              # (N, F)
    Y = Reshape(XL, [-1, 1, F]) - means_L          # (N, K, F)
    quad = ReduceSum(Y * Y, axis=2)                # (N, K)
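Both reshaping tricks rely on (x - μ_k) L_k = x L_k - μ_k L_k, which turns the mean term into a precomputed constant. The NumPy sketch below is not the converter's code; it checks the 'full' and 'tied' formulations against a per-component loop and then assembles log_p, label and proba. The function names are hypothetical, and prec_chol is assumed to follow the layout of GaussianMixture.precisions_cholesky_ (upper-triangular factors, shape (K, F, F) for 'full' and (F, F) for 'tied').

```python
import numpy as np


def quad_full_naive(X, means, prec_chol):
    """Reference: one squared Mahalanobis distance per component, in a loop."""
    n_components = means.shape[0]
    quad = np.empty((X.shape[0], n_components))
    for k in range(n_components):
        y = (X - means[k]) @ prec_chol[k]  # (N, F)
        quad[:, k] = np.sum(y * y, axis=1)
    return quad


def quad_full_batched(X, means, prec_chol):
    """'full': one MatMul against all K Cholesky factors laid side by side."""
    K, F, _ = prec_chol.shape
    L_2d = prec_chol.transpose(1, 0, 2).reshape(F, K * F)  # (F, K*F)
    b = np.einsum("ki,kij->kj", means, prec_chol)  # (K, F), b[k] = means[k] @ L_k
    XL = X @ L_2d  # (N, K*F), column block k holds X @ L_k
    Y = XL.reshape(-1, K, F) - b  # (N, K, F)
    return np.sum(Y * Y, axis=2)  # (N, K)


def quad_tied_batched(X, means, prec_chol):
    """'tied': one shared (F, F) factor, broadcast over the K components."""
    F = prec_chol.shape[0]
    means_L = means @ prec_chol  # (K, F)
    XL = X @ prec_chol  # (N, F)
    Y = XL.reshape(-1, 1, F) - means_L  # (N, K, F)
    return np.sum(Y * Y, axis=2)  # (N, K)


def scores_from_quad(quad, weights, log_det, n_features):
    """Assemble log_p, then ArgMax and a numerically stable Softmax over K."""
    log_p = (np.log(weights) + log_det
             - 0.5 * n_features * np.log(2 * np.pi) - 0.5 * quad)  # (N, K)
    label = np.argmax(log_p, axis=1)  # (N,)
    e = np.exp(log_p - log_p.max(axis=1, keepdims=True))  # shift before exp
    proba = e / e.sum(axis=1, keepdims=True)  # (N, K), rows sum to 1
    return log_p, label, proba
```

For 'full', log_det_k is the sum of the logs of the diagonal of L_k; shifting log_p by its row maximum before exponentiating does not change the Softmax but avoids overflow.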
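'diag': per-component diagonal precision A = prec_chol**2 (shape (K, F)). Expanding quad[n, k] = Σ_f A[k, f] * (X[n, f] - means_[k, f])² gives:

    B = means_ * A                                          # (K, F) constant
    log_p = -0.5 * MatMul(X², Aᵀ) + MatMul(X, Bᵀ) + c      # (N, K)

where c (shape (K,)) is a constant that folds log(weight_k), log_det_k, -0.5 * n_features * log(2π) and -0.5 * Σ_f means_[k, f]² * A[k, f].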
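A minimal NumPy sketch of the 'diag' expansion, assuming prec_chol has the layout of GaussianMixture.precisions_cholesky_ for 'diag' (shape (K, F)), so that log_det_k = Σ_f log(prec_chol[k, f]). It builds the folded constant c explicitly and compares the two-MatMul form with a direct evaluation of the squared Mahalanobis distance. The function names are hypothetical and this is not the converter's own code.

```python
import numpy as np


def log_p_diag(X, means, prec_chol, weights):
    """Expanded form: two MatMuls plus a per-component constant c."""
    n_features = X.shape[1]
    A = prec_chol ** 2  # (K, F) diagonal precisions
    B = means * A  # (K, F)
    log_det = np.sum(np.log(prec_chol), axis=1)  # (K,)
    # c gathers every term that does not depend on X
    c = (np.log(weights) + log_det - 0.5 * n_features * np.log(2 * np.pi)
         - 0.5 * np.sum(means ** 2 * A, axis=1))  # (K,)
    return -0.5 * (X ** 2) @ A.T + X @ B.T + c  # (N, K)


def log_p_diag_naive(X, means, prec_chol, weights):
    """Direct form: materialize (N, K, F) differences, then reduce."""
    n_features = X.shape[1]
    A = prec_chol ** 2
    diff = X[:, None, :] - means[None, :, :]  # (N, K, F)
    quad = np.sum(diff ** 2 * A[None, :, :], axis=2)  # (N, K)
    log_det = np.sum(np.log(prec_chol), axis=1)
    return (np.log(weights) + log_det
            - 0.5 * n_features * np.log(2 * np.pi) - 0.5 * quad)
```

The expanded form never builds the (N, K, F) difference tensor, which is the main cost of the direct form.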
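'spherical': scalar precision prec = prec_chol**2 per component (shape (K,)):

    x_sq = ReduceSum(X * X, axis=1, keepdims=1)             # (N, 1)
    cross = MatMul(X, means_ᵀ)                              # (N, K)
    log_p = prec * cross - 0.5 * prec * x_sq + c            # (N, K)

where c (shape (K,)) folds log(weight_k), log_det_k, -0.5 * n_features * log(2π) and -0.5 * prec_k * ||means_k||².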
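The same check for 'spherical', as a NumPy sketch under the assumption that prec_chol has the layout of GaussianMixture.precisions_cholesky_ for 'spherical' (shape (K,)), in which case log_det_k = n_features * log(prec_chol[k]). Function names are hypothetical; the point is that ||x - μ_k||² = ||x||² - 2 x·μ_k + ||μ_k||² reduces the distance to one MatMul plus broadcasts.

```python
import numpy as np


def log_p_spherical(X, means, prec_chol, weights):
    """Expanded form: one MatMul, one row-norm reduction, one constant."""
    n_features = X.shape[1]
    prec = prec_chol ** 2  # (K,) scalar precision per component
    x_sq = np.sum(X * X, axis=1, keepdims=True)  # (N, 1)
    cross = X @ means.T  # (N, K)
    log_det = n_features * np.log(prec_chol)  # (K,)
    c = (np.log(weights) + log_det - 0.5 * n_features * np.log(2 * np.pi)
         - 0.5 * prec * np.sum(means ** 2, axis=1))  # (K,)
    # (K,) broadcasts against (N, K) and (N, 1) to give (N, K)
    return prec * cross - 0.5 * prec * x_sq + c


def log_p_spherical_naive(X, means, prec_chol, weights):
    """Direct form: prec_k * squared Euclidean distance to each mean."""
    n_features = X.shape[1]
    prec = prec_chol ** 2
    sq_dist = np.sum((X[:, None, :] - means[None, :, :]) ** 2, axis=2)  # (N, K)
    log_det = n_features * np.log(prec_chol)
    return (np.log(weights) + log_det
            - 0.5 * n_features * np.log(2 * np.pi) - 0.5 * prec * sq_dist)
```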
- Parameters:
g – the graph builder to add nodes to
sts – shapes defined by scikit-learn
outputs – desired output names; outputs[0] receives the predicted component labels and outputs[1] receives the posterior probabilities
estimator – a fitted GaussianMixture
X – input tensor name
name – prefix for added node names
- Returns:
tuple (label_result_name, proba_result_name)
- Raises:
NotImplementedError – for unsupported covariance_type values