yobx.sklearn.statsmodels.glm

Converter for statsmodels.genmod.generalized_linear_model.GLMResultsWrapper.

class yobx.sklearn.statsmodels.glm.StatsmodelsGLMWrapper(glm_result)

Wraps a fitted statsmodels.genmod.generalized_linear_model.GLMResultsWrapper as a scikit-learn-compatible estimator so it can be converted to ONNX via yobx.sklearn.to_onnx().

The class extracts the fitted parameters, intercept, and link function from the statsmodels result at construction time. The predict() method replicates statsmodels’ prediction logic without requiring the full design matrix (i.e. users pass the raw feature matrix X, without any constant/intercept column).

Supported link functions (formulas below are the inverse link, i.e. the transformation applied to the linear predictor eta = X @ coef + intercept to obtain the response-scale prediction):

  • Identity (Gaussian default) – eta (pass-through)

  • Log (Poisson / NegativeBinomial / Tweedie default) – exp(eta)

  • Logit (Binomial default) – sigmoid(eta) = 1 / (1 + exp(-eta))

  • Power (general, link: g(μ) = μ^p) – eta ** (1 / p)

    • InversePower (p = -1, Gamma default) – 1 / eta

    • Sqrt (p = 0.5) – eta ** 2

    • InverseSquared (p = -2, InverseGaussian default) – eta ** (-0.5)
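
The inverse links listed above can be checked numerically. The following is a minimal sketch in plain Python, not the wrapper's actual code: the helper names (`inverse_log`, `inverse_logit`, `inverse_power`) are hypothetical and exist only for this illustration. It applies each link to the output of its inverse and expects to recover the original eta.

```python
import math

# Hypothetical helpers for illustration only; not part of yobx or statsmodels.
def inverse_log(eta):
    """Inverse of the Log link g(mu) = log(mu)."""
    return math.exp(eta)

def inverse_logit(eta):
    """Inverse of the Logit link g(mu) = log(mu / (1 - mu))."""
    return 1.0 / (1.0 + math.exp(-eta))

def inverse_power(eta, p):
    """Inverse of the Power link g(mu) = mu ** p, defined only for p != 0."""
    if p == 0:
        raise NotImplementedError("Power(0) is constant and has no inverse")
    return eta ** (1.0 / p)

eta = 0.8  # a positive linear predictor, valid for every link used below

# Round trip: g(g^-1(eta)) should give back eta.
roundtrip_log = math.log(inverse_log(eta))
mu_logit = inverse_logit(eta)
roundtrip_logit = math.log(mu_logit / (1.0 - mu_logit))
# p = -2 (InverseSquared), -1 (InversePower), 0.5 (Sqrt), 1 (Identity)
roundtrip_power = {p: inverse_power(eta, p) ** p for p in (-2, -1, 0.5, 1)}
```

A positive eta is used because several Power links (for example p = -2) are only defined for positive values of the linear predictor.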

Note

Power(0) is a degenerate case (μ^0 = 1 for all μ), which is not invertible and therefore not supported. This should not be confused with the Log link, which is a distinct link class in statsmodels.

Example usage:

import statsmodels.api as sm
import statsmodels.genmod.families as families
import numpy as np
from yobx.sklearn import to_onnx
from yobx.sklearn.statsmodels import StatsmodelsGLMWrapper

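# Design matrix with an explicit constant column (statsmodels does not add one automatically)
X_train = np.column_stack([np.ones(100), np.random.randn(100, 3)])
y_train = np.random.poisson(lam=2.0, size=100)
# Fit a Poisson GLM; its default link is Log
result = sm.GLM(y_train, X_train, family=families.Poisson()).fit()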

wrapper = StatsmodelsGLMWrapper(result)
# X_raw is the feature matrix WITHOUT the constant column
X_raw = X_train[:, 1:]
onx = to_onnx(wrapper, (X_raw.astype(np.float32),))

Parameters:

glm_result – a fitted GLMResultsWrapper

fit(X=None, y=None, **fit_params)

No-op placeholder required by the scikit-learn estimator API.

StatsmodelsGLMWrapper wraps an already-fitted statsmodels result, so this method is intentionally empty and returns self unchanged.

predict(X)

Predict using the GLM model.

Computes link.inverse(X @ coef.T + intercept), which matches the output of statsmodels’ predict() on the design matrix that includes the constant column.

Parameters:

X – raw feature matrix of shape (n_samples, n_features) without any constant column

Returns:

predicted values of shape (n_samples,)
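
As a rough illustration of why the constant column can be dropped, the sketch below uses NumPy only. It assumes the constant is the first column of the design matrix, so the first parameter acts as the intercept and the remaining ones as coefficients. The parameter values and variable names are made up for the example, and the wrapper's real attribute names may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
X_raw = rng.normal(size=(5, 3))                    # features without a constant column
X_design = np.column_stack([np.ones(5), X_raw])    # design matrix, constant column first

# Made-up parameter vector standing in for the parameters of a fitted result.
params = np.array([0.5, 0.1, -0.2, 0.3])
intercept, coef = params[0], params[1:]

eta_design = X_design @ params       # linear predictor from the full design matrix
eta_raw = X_raw @ coef + intercept   # same linear predictor from the raw features

mu = np.exp(eta_raw)                 # inverse of the Log link (Poisson default)
```

Both expressions for the linear predictor agree because multiplying the column of ones by the first parameter is the same as adding that parameter to every row.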

yobx.sklearn.statsmodels.glm.statsmodels_glm_converter(g: GraphBuilderExtendedProtocol, sts: Dict, outputs: List[str], estimator: StatsmodelsGLMWrapper, X: str, name: str = 'statsmodels_glm') -> str

Converts a StatsmodelsGLMWrapper into ONNX.

The converter implements the GLM prediction formula:

eta = X @ coef.T + intercept
mu  = link⁻¹(eta)

where link is the link function stored on the fitted statsmodels model.

Graph structure:

X  ──Gemm(coef, intercept, transB=1)──► eta
                                          │
                              link⁻¹(·) ──►  mu  (output)
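
To make the Gemm step concrete, here is a small NumPy reference for what the ONNX Gemm operator computes with transB=1. It is a sketch under assumptions: `gemm_reference` is a hypothetical helper, and the coefficient tensor is assumed to have shape (1, n_features), which is consistent with the `coef.T` in the formula above but is not confirmed by the source.

```python
import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=1.0, transB=0):
    """NumPy reference for ONNX Gemm without transA: alpha * A @ B' + beta * C."""
    B_eff = B.T if transB else B
    return alpha * (A @ B_eff) + beta * C

X = np.array([[1.0, 2.0, 3.0],
              [0.0, -1.0, 0.5]], dtype=np.float32)
coef = np.array([[0.1, -0.2, 0.3]], dtype=np.float32)  # assumed shape (1, n_features)
intercept = np.array([0.5], dtype=np.float32)          # broadcast over all rows

eta = gemm_reference(X, coef, intercept, transB=1)     # shape (n_samples, 1)
mu = np.exp(eta)                                       # Log link: what an Exp node computes
```

Under this layout, transB=1 lets the coefficients be stored row-wise, as scikit-learn does for linear models, without an explicit Transpose node in the graph.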

Supported link functions (ONNX node used for the inverse link):

  • Identity – Identity node (pass-through)

  • Log – Exp (inverse of log is exp)

  • Logit – Sigmoid (inverse of logit is sigmoid)

  • Power(p) – Pow(eta, 1/p); special cases:

    • p = 1 (Identity): Pow(eta, 1) = pass-through

    • p = -1 (InversePower): Pow(eta, -1) = reciprocal

    • p = 0.5 (Sqrt): Pow(eta, 2) = square

    • p = -2 (InverseSquared): Pow(eta, -0.5)

Note

Power(0) is not invertible (μ^0 = 1 is constant) and raises NotImplementedError. Use the statsmodels Log link class for log-link models.
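
The mapping from link to ONNX node can be summarised as a small lookup. The function below is a hypothetical sketch written for this page, not the converter's actual code: the name `inverse_link_node`, its string-based interface, and its return value are all assumptions chosen to mirror the list and the note above.

```python
# Hypothetical dispatch helper for illustration; the real converter may be organised differently.
def inverse_link_node(link_name, power=None):
    """Return (onnx_op_type, exponent) for the inverse of a named link."""
    simple = {"Identity": "Identity", "Log": "Exp", "Logit": "Sigmoid"}
    if link_name in simple:
        return simple[link_name], None
    if link_name == "Power":
        if power == 0:
            # mu ** 0 is constant, so there is nothing to invert
            raise NotImplementedError("Power(0) is not invertible; use the Log link")
        return "Pow", 1.0 / power
    raise NotImplementedError(f"unsupported link: {link_name}")

sqrt_node = inverse_link_node("Power", power=0.5)   # Sqrt link: square eta
gamma_node = inverse_link_node("Power", power=-1)   # InversePower link: reciprocal of eta
```

Treating InversePower, Sqrt, and InverseSquared as Power with a fixed p keeps the table short, since each of them reduces to a single Pow node with exponent 1/p.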

Parameters:
  • g – the graph builder to add nodes to

  • sts – shapes defined by scikit-learn

  • outputs – desired output names

  • estimator – a StatsmodelsGLMWrapper wrapping a fitted GLM result

  • X – input tensor name

  • name – prefix for added node names

Returns:

output tensor name

Raises:

NotImplementedError – when an unsupported link function is encountered