101: Linear Regression and export to ONNX

This example uses scikit-learn and torch to train a linear regression model, then exports the torch model to ONNX.

Data

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
import torch
from onnxruntime import InferenceSession
from experimental_experiment.helpers import pretty_onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
from experimental_experiment.torch_interpreter import to_onnx


X, y = make_regression(1000, n_features=5, noise=10.0, n_informative=2)
print(X.shape, y.shape)

X_train, X_test, y_train, y_test = train_test_split(X, y)
(1000, 5) (1000,)

scikit-learn: the simple regression

The ordinary least squares estimator has a closed form:

A^* = (X'X)^{-1} X' Y

clr = LinearRegression()
clr.fit(X_train, y_train)

print(f"coefficients: {clr.coef_}, {clr.intercept_}")
coefficients: [ 0.11838142 81.71093483 14.49377687  0.40767291  0.5306029 ], 0.21586205794585567
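As a sanity check (a sketch added here, not part of the original script), the closed form above can be evaluated directly with numpy; Xb augments X_train with a column of ones so that the last coefficient plays the role of the intercept:

# normal equations: solve (Xb'Xb) a = Xb'y on the augmented design matrix
Xb = np.hstack([X_train, np.ones((X_train.shape[0], 1))])
coef = np.linalg.solve(Xb.T @ Xb, Xb.T @ y_train)
print(coef[:-1], coef[-1])  # should match clr.coef_ and clr.intercept_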

Evaluation

y_pred = clr.predict(X_test)
l2 = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"LinearRegression: l2={l2}, r2={r2}")
LinearRegression: l2=87.17789525686865, r2=0.9876824124703169

scikit-learn: SGD algorithm

SGD stands for Stochastic Gradient Descent: instead of solving the normal equations, the coefficients are updated iteratively from one sample (or mini-batch) at a time. One update step is sketched below.
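For the squared loss on a single sample (x_i, y_i), the update is w <- w - eta * (w.x_i + b - y_i) * x_i, and similarly for the bias b. A minimal numpy sketch (illustrative only; eta is a hypothetical learning rate, not a parameter of the original script):

# one SGD step on the squared loss for a single training sample
eta = 0.01
w, b = np.zeros(X_train.shape[1]), 0.0
xi, yi = X_train[0], y_train[0]
err = w @ xi + b - yi   # prediction error on this sample
w -= eta * err * xi     # gradient of 0.5 * err**2 with respect to w
b -= eta * err          # gradient with respect to b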

clr = SGDRegressor(max_iter=5, verbose=1)
clr.fit(X_train, y_train)

print(f"coefficients: {clr.coef_}, {clr.intercept_}")
-- Epoch 1
Norm: 71.47, NNZs: 5, Bias: -1.469467, T: 750, Avg. loss: 707.486947
Total training time: 0.00 seconds.
-- Epoch 2
Norm: 79.92, NNZs: 5, Bias: -0.559689, T: 1500, Avg. loss: 72.682249
Total training time: 0.00 seconds.
-- Epoch 3
Norm: 81.85, NNZs: 5, Bias: -0.097504, T: 2250, Avg. loss: 51.499575
Total training time: 0.00 seconds.
-- Epoch 4
Norm: 82.57, NNZs: 5, Bias: 0.097567, T: 3000, Avg. loss: 49.686440
Total training time: 0.00 seconds.
-- Epoch 5
Norm: 82.80, NNZs: 5, Bias: 0.174982, T: 3750, Avg. loss: 49.403598
Total training time: 0.00 seconds.
~/vv/this312/lib/python3.12/site-packages/sklearn/linear_model/_stochastic_gradient.py:1579: ConvergenceWarning: Maximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.
  warnings.warn(
coefficients: [5.61454070e-02 8.15191741e+01 1.44813626e+01 4.55558223e-01
 5.85793413e-01], [0.17498211]

Evaluation

y_pred = clr.predict(X_test)
sl2 = mean_squared_error(y_test, y_pred)
sr2 = r2_score(y_test, y_pred)
print(f"SGDRegressor: sl2={sl2}, sr2={sr2}")
SGDRegressor: sl2=87.58380464235132, sr2=0.9876250604962863

Linear Regression with pytorch

class TorchLinearRegression(torch.nn.Module):
    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        return self.linear(x)


def train_loop(dataloader, model, loss_fn, optimizer):
    total_loss = 0.0

    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for X, y in dataloader:
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred.ravel(), y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # accumulate the training loss as a plain float so the computation
        # graph of each batch is not kept alive in memory
        total_loss += float(loss)

    return total_loss


model = TorchLinearRegression(X_train.shape[1], 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

device = "cpu"
model = model.to(device)
dataset = torch.utils.data.TensorDataset(
    torch.Tensor(X_train).to(device), torch.Tensor(y_train).to(device)
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)


for i in range(5):
    loss = train_loop(dataloader, model, loss_fn, optimizer)
    print(f"iteration {i}, loss={loss}")
iteration 0, loss=1711194.875
iteration 1, loss=151732.578125
iteration 2, loss=77832.0390625
iteration 3, loss=74372.796875
iteration 4, loss=74270.6953125
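train_loop returns the sum of the per-batch losses; with batch_size=1 there are len(dataloader) batches, so dividing gives the average squared error per sample (a quick check, not part of the original script):

print(float(loss) / len(dataloader))  # roughly comparable to the MSE computed below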

Let’s check the error

y_pred = model(torch.Tensor(X_test)).detach().numpy()
tl2 = mean_squared_error(y_test, y_pred)
tr2 = r2_score(y_test, y_pred)
print(f"TorchLinearRegression: tl2={tl2}, tr2={tr2}")
TorchLinearRegression: tl2=86.30657765848127, tr2=0.9878055231597027

And the coefficients.

print("coefficients:")
for p in model.parameters():
    print(p)
coefficients:
Parameter containing:
tensor([[ 0.2490, 81.8158, 14.3490,  0.2321,  0.7074]], requires_grad=True)
Parameter containing:
tensor([0.4751], requires_grad=True)
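These weights are close to the ones estimated by scikit-learn. A small comparison sketch (clr currently holds the SGDRegressor fitted above):

w_torch = model.linear.weight.detach().numpy().ravel()
print(np.abs(w_torch - clr.coef_).max())  # largest absolute difference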

Conversion to ONNX

Let’s convert it to ONNX.

onx = to_onnx(model, (torch.Tensor(X_test[:2]),), input_names=["x"])

Let’s check it works.

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
res = sess.run(None, {"x": X_test.astype(np.float32)[:2]})
print(res)
[array([[-137.21176 ],
       [  22.883425]], dtype=float32)]
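The ONNX output can be compared to the torch model to confirm both agree numerically (a verification sketch):

expected = model(torch.Tensor(X_test[:2])).detach().numpy()
print(np.abs(res[0] - expected).max())  # should be ~0 up to float32 precision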

And the model.
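It can be rendered with plot_dot, imported at the top of the script (a plausible call matching that otherwise unused import):

plot_dot(onx)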

[figure: graph of the exported ONNX model]

With dynamic shapes

The dynamic shapes are used by torch.export.export() and must follow the convention described in its documentation. A dynamic dimension accepts any value, so the exported model remains valid for many different input shapes. That’s usually what users need.

onx = to_onnx(
    model,
    (torch.Tensor(X_test[:2]),),
    input_names=["x"],
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)

print(pretty_onnx(onx))
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=['batch', 5]
init: name='GemmTransposePattern--p_linear_weight::T10' type=float32 shape=(1, 5)-- GraphBuilder.constant_folding.from/fold(init7_s2_1_-1,p_linear_weight::T10)##p_linear_weight::T10/GraphBuilder.constant_folding.from/fold(p_linear_weight)##p_linear_weight/DynamoInterpret.placeholder.1/P(linear.weight)##init7_s2_1_-1/TransposeEqualReshapePattern.apply.new_shape
init: name='linear.bias' type=float32 shape=(1,) -- array([0.4750792], dtype=float32)-- DynamoInterpret.placeholder.1/P(linear.bias)
Gemm(x, GemmTransposePattern--p_linear_weight::T10, linear.bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=['batch', 1]
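Because the first dimension is now dynamic, the same model accepts any batch size. A quick check (with the default 25% test split, X_test has 250 rows):

sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.run(None, {"x": X_test.astype(np.float32)})[0].shape)  # (250, 1)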

For simplicity, it is possible to use torch.export.Dim.DYNAMIC or torch.export.Dim.AUTO.

onx = to_onnx(
    model,
    (torch.Tensor(X_test[:2]),),
    input_names=["x"],
    dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},
)

print(pretty_onnx(onx))
opset: domain='' version=18
input: name='x' type=dtype('float32') shape=['batch', 5]
init: name='GemmTransposePattern--p_linear_weight::T10' type=float32 shape=(1, 5)-- GraphBuilder.constant_folding.from/fold(init7_s2_1_-1,p_linear_weight::T10)##p_linear_weight::T10/GraphBuilder.constant_folding.from/fold(p_linear_weight)##p_linear_weight/DynamoInterpret.placeholder.1/P(linear.weight)##init7_s2_1_-1/TransposeEqualReshapePattern.apply.new_shape
init: name='linear.bias' type=float32 shape=(1,) -- array([0.4750792], dtype=float32)-- DynamoInterpret.placeholder.1/P(linear.bias)
Gemm(x, GemmTransposePattern--p_linear_weight::T10, linear.bias, transB=1) -> output_0
output: name='output_0' type=dtype('float32') shape=['batch', 1]
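The same export with torch.export.Dim.AUTO, which lets the exporter infer whether the dimension must be dynamic, would look like this (a sketch; this variant is not run in the original script):

onx_auto = to_onnx(
    model,
    (torch.Tensor(X_test[:2]),),
    input_names=["x"],
    dynamic_shapes={"x": {0: torch.export.Dim.AUTO}},
)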

Total running time of the script: (0 minutes 2.504 seconds)

Related examples

101: A custom backend for torch

201: Use torch to export a scikit-learn model into ONNX

201: Evaluate different ways to export a torch model to ONNX

201: Better shape inference

101: Onnx Model Optimization based on Pattern Rewriting