201: Use torch to export a scikit-learn model into ONNX

When sklearn-onnx is missing a converter, torch can be used to write it. We use sklearn.impute.KNNImputer as an example. The first step is to rewrite the scikit-learn model with torch functions. The code is then refactored and split into submodules to be able to bypass some pieces torch.export.export() cannot process.

torch implementation of nan_euclidean

Let’s start with a simple case, a pairwise distance. See sklearn.metrics.nan_euclidean_distances().

Module

import contextlib
import io
import logging
import math
import numbers
import warnings
from typing import Any, Dict, List, Optional
import numpy as np
import onnx
from onnx.reference.ops.op_topk import TopK_11 as TopK
import sklearn
import torch
import onnxruntime
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.xbuilder import GraphBuilder
from experimental_experiment.helpers import max_diff, pretty_onnx
from experimental_experiment.skl.helpers import flatnonzero, _get_weights
from experimental_experiment.torch_interpreter import make_undefined_dimension, Dispatcher
from experimental_experiment.torch_interpreter.onnx_export_errors import (
    bypass_export_some_errors,
)
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
    CustomOpStrategy,
)


class NanEuclidean(torch.nn.Module):
    """Pairwise euclidean distances that tolerate missing (NaN) coordinates.

    Torch counterpart of :func:`sklearn.metrics.nan_euclidean_distances`.
    For every pair of rows, only the coordinates present in both rows are
    used; the squared distance is then rescaled by
    ``n_features / n_present_coordinates``. Pairs sharing no coordinate
    get a NaN distance.

    :param squared: if True, returns squared distances
    :param copy: stored for API compatibility, not read by :meth:`forward`
        (the inputs are never modified)
    """

    def __init__(self, squared=False, copy=True):
        super().__init__()
        self.squared = squared
        self.copy = copy

    def forward(self, X, Y):
        """Returns the ``(X.shape[0], Y.shape[0])`` distance matrix."""
        Y = Y.to(X.dtype)
        dtype = X.dtype

        nan_x = torch.isnan(X)
        nan_y = torch.isnan(Y)

        # missing coordinates are replaced by zero so that they vanish
        # from the dot products below
        X = torch.where(nan_x, torch.zeros_like(X), X)
        Y = torch.where(nan_y, torch.zeros_like(Y), Y)

        sq_x = X * X
        sq_y = Y * Y

        # ||x||^2 - 2 x.y + ||y||^2
        cross = (-2 * X) @ Y.T
        distances = cross + sq_x.sum(1, keepdim=True) + sq_y.sum(1, keepdim=True).T

        # remove the squared terms of coordinates missing in the other operand
        distances = distances - sq_x @ nan_y.to(dtype).T
        distances = distances - nan_x.to(dtype) @ sq_y.T

        # rounding errors may produce tiny negative values
        distances = torch.clamp(distances, min=0)

        # number of coordinates present in both rows of every pair
        shared = (~nan_x).to(dtype) @ (~nan_y).to(dtype).T
        distances = torch.where(
            shared == 0, torch.full_like(distances, torch.nan), distances
        )

        # avoid divide by zero, the corresponding distances are already NaN
        shared = shared.clamp(min=1)
        distances = distances / shared * X.shape[1]

        return distances if self.squared else distances.sqrt()

Validation

# Compares the torch implementation with scikit-learn on random data.
model = NanEuclidean()
X = torch.randn((5, 2))
Y = torch.randn((5, 2))
# one missing value per row of X, alternating between the two columns
for i in range(5):
    X[i, i % 2] = torch.nan
# rows 1 to 4 of Y get one missing value each, row 0 stays complete
for i in range(4):
    Y[i + 1, i % 2] = torch.nan

# d1 is the scikit-learn reference, d2 the torch result
d1 = sklearn.metrics.nan_euclidean_distances(X.numpy(), Y.numpy())
d2 = model(X, Y)
print(f"discrepancies: {max_diff(d1, d2)}")
discrepancies: {'abs': 2.384185791015625e-07, 'rel': 2.0627543198679525e-07, 'sum': 4.172325134277344e-07, 'n': 25.0, 'dnan': 0.0}

torch implementation of KNNImputer

See sklearn.impute.KNNImputer. The code is split into several torch.nn.Module and refactored to avoid control flow.

def _get_mask(X, value_to_mask):
    """Returns a boolean tensor flagging the entries of *X* equal to *value_to_mask*.

    NaN never compares equal to itself, so a NaN marker is detected
    with :func:`torch.isnan` instead of ``==``.
    """
    # same test as sklearn.utils._missing.is_scalar_nan(value_to_mask)
    marker_is_nan = (
        isinstance(value_to_mask, numbers.Real)
        and not isinstance(value_to_mask, numbers.Integral)
        and math.isnan(value_to_mask)
    )
    if marker_is_nan:
        return torch.isnan(X)
    return value_to_mask == X


class SubTopKIndices(torch.nn.Module):
    """Returns the indices of the *k* smallest values of every row of *x*."""

    def forward(self, x, k):
        # torch does not like nans: they are replaced by a huge value
        # so that they are ranked last when looking for the smallest values
        filled = x.nan_to_num(nan=1.0e10)
        _values, indices = torch.topk(filled, k, dim=1, largest=False, sorted=True)
        return indices


class SubWeightMatrix(torch.nn.Module):
    """Builds the weights applied to the donors, NaN entries get a null weight.

    :param weights: weighting option forwarded to ``_get_weights``
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, donors_dist):
        computed = _get_weights(donors_dist, self.weights)
        if computed is None:
            # no weights were returned: every donor weighs one,
            # except the donors whose distance is undefined
            result = torch.ones_like(donors_dist)
            undefined = torch.isnan(donors_dist)
        else:
            # work on a copy to leave the returned tensor untouched
            result = computed.clone()
            undefined = torch.isnan(result)
        result[undefined] = 0.0
        return result


class SubDonorsIdx(torch.nn.Module):
    """Selects the closest donors of every receiver.

    Returns the indices of the ``n_neighbors`` smallest distances of every
    row and the corresponding distances.
    """

    def __init__(self):
        super().__init__()
        self._topk = SubTopKIndices()

    def forward(self, dist_pot_donors, n_neighbors):
        donors_idx = self._topk(dist_pot_donors, n_neighbors)
        # one row index per receiver, shaped as a column to broadcast
        # against the selected column indices
        row_idx = torch.arange(donors_idx.shape[0]).unsqueeze(1)
        donors_dist = dist_pot_donors[row_idx, donors_idx]
        return donors_idx, donors_dist


class MakeNewWeights(torch.nn.Module):
    """Cancels the weights of the masked donors.

    Both operands are cast to the dtype of *donors* before being multiplied.
    """

    def forward(self, donors_mask, donors, weight_matrix):
        target_dtype = donors.dtype
        mask = donors_mask.to(target_dtype)
        weights = weight_matrix.to(target_dtype)
        return mask * weights


class CalcImpute(torch.nn.Module):
    """Torch version of :meth:`sklearn.impute.KNNImputer._calc_impute`.

    Computes, for every receiver, the weighted average of the values of
    its closest donors.

    :param weights: weighting option given to :class:`SubWeightMatrix`
    """

    def __init__(self, weights):
        super().__init__()
        self._weights = SubWeightMatrix(weights)
        self._donors_idx = SubDonorsIdx()
        self._make_new_neights = MakeNewWeights()

    def _calc_impute(self, dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col):
        nearest_idx, nearest_dist = self._donors_idx(dist_pot_donors, n_neighbors)
        base_weights = self._weights(nearest_dist)

        # values of the selected donors and a 0/1 flag telling which ones
        # are not masked in the training data
        idx_dtype = nearest_idx.dtype
        donors = torch.take(fit_X_col, nearest_idx)
        donor_missing = torch.take(mask_fit_X_col, nearest_idx).to(idx_dtype)
        donor_present = torch.tensor([1], dtype=idx_dtype) - donor_missing

        new_weights = self._make_new_neights(donor_present, donors, base_weights)

        # weighted average, a null sum of weights is replaced by one
        # to avoid a division by zero
        total = new_weights.sum(dim=1, keepdim=True)
        one = torch.tensor([1], dtype=total.dtype)
        safe_total = torch.where(total == 0, one, total)
        weighted_sum = (donors * new_weights).sum(dim=1, keepdim=True)
        imputed = weighted_sum / safe_total
        return imputed.squeeze(dim=1).to(dist_pot_donors.dtype)

    def forward(self, dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col):
        return self._calc_impute(dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col)


class ColProcessor(torch.nn.Module):
    """Processes one column (= one feature).

    The module imputes the missing values of column ``col`` of a matrix,
    one instance is created per feature so that the loop over the features
    can be unrolled by the exporter.

    :param col: index of the column processed by this instance
    :param n_neighbors: maximum number of donors used to impute a value
    :param weights: weighting option given to :class:`CalcImpute`
    """

    def __init__(self, col, n_neighbors, weights):
        super().__init__()
        self._calc_impute = CalcImpute(weights)
        self.col = col
        self.n_neighbors = n_neighbors

    def process_one_col(
        self,
        X,
        dist_chunk,
        non_missing_fix_X,
        mask_fit_X,
        dist_idx_map,
        mask,
        row_missing_idx,
        _fit_X,
    ):
        """Imputes column ``self.col`` and returns a modified copy of *X*.

        :param X: data to impute, it is cloned and left untouched
        :param dist_chunk: distances between the rows of *X* to impute and
            the training rows, it is indexed by ``dist_idx_map`` on rows
            and by training row on columns
        :param non_missing_fix_X: mask, the nonzero entries of column
            ``self.col`` select the potential donors
        :param mask_fit_X: mask of the training data, its negation selects
            the training values averaged for the fallback mean
        :param dist_idx_map: maps a row index of *X* to a row index
            of *dist_chunk*
        :param mask: mask of *X*, it tells which rows of
            *row_missing_idx* need a value in this column
        :param row_missing_idx: indices of the rows of *X* to look into
        :param _fit_X: training data, donors values are taken from it
        :return: a copy of *X* with column ``self.col`` updated
        """
        col = self.col
        X = X.clone()
        row_missing_chunk = row_missing_idx
        col_mask = mask[row_missing_chunk, col]

        # training rows whose entry is nonzero in non_missing_fix_X for this column
        potential_donors_idx = torch.nonzero(non_missing_fix_X[:, col], as_tuple=True)[0]

        # receivers_idx are indices in X
        # NOTE(review): flatnonzero is a project helper, presumably the torch
        # equivalent of numpy.flatnonzero -- confirm
        receivers_idx = row_missing_chunk[flatnonzero(col_mask)]

        # distances for samples that needed imputation for column
        dist_subset = dist_chunk[dist_idx_map[receivers_idx]][:, potential_donors_idx]

        # receivers with all nan distances impute with mean
        all_nan_dist_mask = torch.isnan(dist_subset).all(axis=1)
        all_nan_receivers_idx = receivers_idx[all_nan_dist_mask]

        # The mean is always computed, even when all_nan_receivers_idx is
        # empty, to avoid a test on the data: mask_ is 1 where the training
        # value is not masked, 0 otherwise.
        mask_ = (~mask_fit_X[:, col]).to(_fit_X.dtype)
        mask_sum = mask_.to(X.dtype).sum()

        col_sum = (_fit_X[mask_ == 1, col]).sum().to(X.dtype)
        # a null count is replaced by one to avoid a division by zero
        div = torch.where(mask_sum > 0, mask_sum, torch.tensor([1], dtype=mask_sum.dtype))
        X[all_nan_receivers_idx, col] = col_sum / div

        # receivers with at least one defined distance
        receivers_idx = receivers_idx[~all_nan_dist_mask]
        dist_subset = dist_chunk[dist_idx_map[receivers_idx]][:, potential_donors_idx]

        # n_neighbors cannot exceed the number of potential donors,
        # it is computed with torch.where to keep it a tensor
        tn = torch.tensor(self.n_neighbors)
        n_neighbors = torch.where(
            tn < potential_donors_idx.shape[0], tn, potential_donors_idx.shape[0]
        )
        # to make sure n_neighbors > 0
        n_neighbors = torch.where(
            n_neighbors <= 0, torch.tensor([1], dtype=n_neighbors.dtype), n_neighbors
        )
        # weighted average of the closest donors for the remaining receivers
        value = self._calc_impute(
            dist_subset,
            n_neighbors,
            _fit_X[potential_donors_idx, col],
            mask_fit_X[potential_donors_idx, col],
        )
        X[receivers_idx, col] = value.to(X.dtype)
        return X

    def forward(
        self,
        X,
        dist_chunk,
        non_missing_fix_X,
        mask_fit_X,
        dist_idx_map,
        mask,
        row_missing_idx,
        _fit_X,
    ):
        """Calls :meth:`process_one_col` with the same arguments."""
        return self.process_one_col(
            X,
            dist_chunk,
            non_missing_fix_X,
            mask_fit_X,
            dist_idx_map,
            mask,
            row_missing_idx,
            _fit_X,
        )


class MakeDictIdxMap(torch.nn.Module):
    """Maps the row indices of *X* to the row indices of the distance matrix.

    The i-th index found in *row_missing_idx* is mapped to ``i``,
    every other row is mapped to 0.
    """

    def forward(self, X, row_missing_idx):
        n_rows = X.shape[0]
        n_missing = row_missing_idx.shape[0]
        dist_idx_map = torch.zeros(n_rows, dtype=int)
        dist_idx_map[row_missing_idx] = torch.arange(n_missing)
        return dist_idx_map


class TorchKNNImputer(torch.nn.Module):
    """Torch implementation of the transform step of a fitted
    :class:`sklearn.impute.KNNImputer`.

    The hyperparameters are copied from the fitted estimator. The training
    results (``_mask_fit_X``, ``_valid_mask``, ``_fit_X``) are not stored
    in the module, they are given as inputs to :meth:`forward`.

    :param knn_imputer: fitted estimator, only ``metric="nan_euclidean"``
        is supported
    """

    def __init__(self, knn_imputer):
        super().__init__()
        assert (
            knn_imputer.metric == "nan_euclidean"
        ), f"Not implemented for metric={knn_imputer.metric!r}"
        self.dist = NanEuclidean()
        # one submodule per feature
        self.columns = torch.nn.ModuleList(
            [
                ColProcessor(col, knn_imputer.n_neighbors, knn_imputer.weights)
                for col in range(knn_imputer._fit_X.shape[1])
            ]
        )
        # refactoring
        self._make_dict_idx_map = MakeDictIdxMap()
        # knn imputer
        self.missing_values = knn_imputer.missing_values
        self.n_neighbors = knn_imputer.n_neighbors
        self.weights = knn_imputer.weights
        self.metric = knn_imputer.metric
        self.keep_empty_features = knn_imputer.keep_empty_features
        self.add_indicator = knn_imputer.add_indicator
        # results of fitting
        self.indicator_ = knn_imputer.indicator_
        # The training results (knn_imputer._fit_X, knn_imputer._mask_fit_X,
        # knn_imputer._valid_mask) are deliberately not stored as attributes,
        # they are inputs of forward.

    def _transform_indicator(self, X):
        """Returns the missing indicator of *X* or None if it is disabled.

        :raises ValueError: if ``add_indicator`` is set and there is no
            attribute ``indicator_``
        :raises NotImplementedError: if ``add_indicator`` is set, the
            indicator is not implemented with torch
        """
        if self.add_indicator:
            if not hasattr(self, "indicator_"):
                raise ValueError(
                    "Make sure to call _fit_indicator before _transform_indicator"
                )
            # scikit-learn returns self.indicator_.transform(X)
            raise NotImplementedError(type(self.indicator_))
        return None

    def _concatenate_indicator(self, X_imputed, X_indicator):
        """Appends the columns of the missing indicator to the imputed data.

        :param X_imputed: imputed data
        :param X_indicator: missing indicator with as many rows as
            *X_imputed*, ignored if ``add_indicator`` is False
        :return: *X_imputed* itself if ``add_indicator`` is False, otherwise
            both tensors concatenated along the columns
        :raises ValueError: if ``add_indicator`` is set and *X_indicator*
            is None
        """
        if not self.add_indicator:
            return X_imputed
        if X_indicator is None:
            raise ValueError(
                "Data from the missing indicator are not provided. Call "
                "_fit_indicator and _transform_indicator in the imputer "
                "implementation."
            )
        # scikit-learn stacks the indicator horizontally (hstack):
        # the rows are kept, the indicator adds new columns
        return torch.cat([X_imputed, X_indicator], dim=1)

    def transform(self, mask_fit_X, _valid_mask, _fit_X, X):
        """Imputes the missing values of *X*.

        :param mask_fit_X: mask of the missing values in the training data
        :param _valid_mask: boolean mask over the features, used to select
            the columns taken into account and returned
        :param _fit_X: training data
        :param X: data to impute, it is not modified
        :return: imputed data, concatenated with the missing indicator if
            ``add_indicator`` is set
        """
        X = X.clone()
        mask = _get_mask(X, self.missing_values)

        X_indicator = self._transform_indicator(mask)

        # rows of X with at least one missing value among the valid features
        row_missing_idx = flatnonzero(mask[:, _valid_mask].any(axis=1))
        non_missing_fix_X = torch.logical_not(mask_fit_X)

        # Maps from indices from X to indices in dist matrix
        dist_idx_map = self._make_dict_idx_map(X, row_missing_idx)

        # distances between the rows to impute and the training data,
        # computed in one go
        pairwise_distances = self.dist(X[row_missing_idx, :], _fit_X)

        # The export unfolds the loop as it depends on the number of features.
        # Fixed in this case.
        for col_processor in self.columns:
            X = col_processor(
                X,
                pairwise_distances,
                non_missing_fix_X,
                mask_fit_X,
                dist_idx_map,
                mask,
                row_missing_idx,
                _fit_X,
            )

        if self.keep_empty_features:
            # invalid features are kept and filled with zeros
            Xc = X.clone()
            Xc[:, ~_valid_mask] = 0
        else:
            # invalid features are dropped
            Xc = X[:, _valid_mask]

        return self._concatenate_indicator(Xc, X_indicator)

    def forward(self, _mask_fit_X, _valid_mask, _fit_X, X):
        """Calls :meth:`transform` with the same arguments."""
        return self.transform(_mask_fit_X, _valid_mask, _fit_X, X)

Validation

We need to do that with different sizes of training set.

def validate(size, sizey):
    """Compares :class:`TorchKNNImputer` with scikit-learn on random data.

    A ``KNNImputer(n_neighbors=3)`` is fitted on ``size`` random rows, then
    ``sizey`` random rows are imputed with both implementations, first all
    of them, then a single one. Missing values are written in rows 0 to 4
    of both matrices, so both sizes must be at least 5.

    :param size: number of rows of the training data
    :param sizey: number of rows of the data to impute
    :return: the fitted scikit-learn imputer and the data to impute
    :raises AssertionError: if the absolute discrepancy reaches ``1e-5``
    """
    X = torch.randn((size, 2))
    Y = torch.randn((sizey, 2))
    # one missing value in each of the first five rows of X,
    # and in rows 1 to 4 of Y
    x_rows = torch.arange(5)
    X[x_rows, x_rows % 2] = torch.nan
    y_steps = torch.arange(4)
    Y[y_steps + 1, y_steps % 2] = torch.nan

    knn_imputer = sklearn.impute.KNNImputer(n_neighbors=3)
    knn_imputer.fit(X)

    model = TorchKNNImputer(knn_imputer)

    # the whole matrix first, then only one row
    for batch in (Y, Y[1:2]):
        expected = knn_imputer.transform(batch)
        got = model.transform(
            torch.from_numpy(knn_imputer._mask_fit_X),
            torch.from_numpy(knn_imputer._valid_mask),
            torch.from_numpy(knn_imputer._fit_X),
            batch,
        )
        d = max_diff(expected, got)
        assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
        print(f"knn discrepancies for size={size}: {d}")
    return knn_imputer, Y


# Two training sizes: 5 rows (fewer potential donors) and 50 rows.
# The fitted imputers and the data to impute are kept for the export.
knn5, Y10 = validate(5, 10)
knn50, Y40 = validate(50, 40)
knn discrepancies for size=5: {'abs': 4.967053740534411e-09, 'rel': 3.8580606309014394e-08, 'sum': 9.934107481068821e-09, 'n': 20.0, 'dnan': 0.0}
knn discrepancies for size=5: {'abs': 0.0, 'rel': 0.0, 'sum': 0.0, 'n': 2.0, 'dnan': 0.0}
knn discrepancies for size=50: {'abs': 1.986821485111534e-08, 'rel': 2.3018476420929695e-08, 'sum': 6.705522526129215e-08, 'n': 80.0, 'dnan': 0.0}
knn discrepancies for size=50: {'abs': 1.986821485111534e-08, 'rel': 1.8620979822721472e-08, 'sum': 1.986821485111534e-08, 'n': 2.0, 'dnan': 0.0}

Export to ONNX

The module cannot be exported as is because the operator torch.topk() expects a fixed number of neighbours but the model makes it variable. This case is not supported by torch.export.export(). We need to isolate that part before exporting the model. It is done by replacing it with a custom op. This is automatically done by function trace_execution_piece_by_piece().

First step, we create two sets of inputs. A function will use this to infer the dynamic shapes.

Then we trace the execution to capture every input and output of every submodule. The model implementation was refactored to introduce many tiny ones and get a fine-grained evaluation of the exportability.

__main__                  TorchKNNImputer   <OK-2i>
..dist                    NanEuclidean      <OK-2i>
..columns[0]              ColProcessor      <OK-2i>
...._calc_impute          CalcImpute        <OK-2i>
......_weights            SubWeightMatrix   <OK-2i>
......_donors_idx         SubDonorsIdx      <OK-2i>
........_topk             SubTopKIndices    <OK-2i>
......_make_new_neights   MakeNewWeights    <OK-2i>
..columns[1]              ColProcessor      <OK-2i>
...._calc_impute          CalcImpute        <OK-2i>
......_weights            SubWeightMatrix   <OK-2i>
......_donors_idx         SubDonorsIdx      <OK-2i>
........_topk             SubTopKIndices    <OK-2i>
......_make_new_neights   MakeNewWeights    <OK-2i>
.._make_dict_idx_map      MakeDictIdxMap    <OK-2i>

The dynamic shapes for the whole model:

# `trace` is the result of the tracing step described above
# (its creation is not part of this listing).
print("dynamic shapes:")
print(trace.guess_dynamic_shapes())
dynamic shapes:
(({0: <_DimHint.DYNAMIC: 3>}, {}, {0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})

The method try_export cannot infer all links between input shapes and output shapes for every submodule. The following dictionary of functions fills this gap.

# Maps a submodule class name to one function per output index.
# Every function receives the inputs of the submodule and returns an empty
# tensor carrying the expected shape and dtype of that output.
# The lambdas read the global knn5 when they are called (n_neighbors=3 here).
shape_functions = {
    # one distance per pair (row of X, row of Y)
    "NanEuclidean": {
        0: lambda *args, **kwargs: torch.empty(
            (args[0].shape[0], args[1].shape[0]), dtype=args[0].dtype
        )
    },
    # one imputed value per receiver
    "CalcImpute": {
        0: lambda *args, **kwargs: torch.empty((args[0].shape[0],), dtype=args[0].dtype)
    },
    # NOTE(review): this output holds indices (int64 in the traces) but it is
    # declared with the dtype of the distances -- confirm the dtype is ignored
    "SubTopKIndices": {
        0: lambda *args, **kwargs: torch.empty(
            (
                args[0].shape[0],
                make_undefined_dimension(min(args[0].shape[1], knn5.n_neighbors)),
            ),
            dtype=args[0].dtype,
        )
    },
    # output 0: indices of the donors (same remark on the dtype as above),
    # output 1: distances to these donors
    "SubDonorsIdx": {
        0: lambda *args, **kwargs: torch.empty(
            (
                args[0].shape[0],
                make_undefined_dimension(min(args[0].shape[1], knn5.n_neighbors)),
            ),
            dtype=args[0].dtype,
        ),
        1: lambda *args, **kwargs: torch.empty(
            (
                args[0].shape[0],
                make_undefined_dimension(min(args[0].shape[1], knn5.n_neighbors)),
            ),
            dtype=torch.float32,
        ),
    },
    # one index per row of X, with the dtype of row_missing_idx
    "MakeDictIdxMap": {
        0: lambda *args, **kwargs: torch.empty((args[0].shape[0],), dtype=args[1].dtype),
    },
}

Then we try to export piece by piece. We capture the standard error to avoid being overwhelmed and we use function bypass_export_some_errors() to skip some errors with shape checking made by torch.

# silences every logging call up to level CRITICAL
logging.disable(logging.CRITICAL)

# stderr is redirected to an in-memory buffer which is then discarded
with contextlib.redirect_stderr(io.StringIO()), bypass_export_some_errors():
    ep = trace.try_export(
        exporter="fx",
        use_dynamic_shapes=True,
        exporter_kwargs=dict(strict=False),
        replace_by_custom_op=CustomOpStrategy.LOCAL,
        verbose=0,
        shape_functions=shape_functions,
    )

# the export of the whole model must succeed, directly (OK) or once its
# submodules are replaced by custom ops (OK_CHILDC)
assert ep.status in (
    ep.status.OK,
    ep.status.OK_CHILDC,
), f"FAIL: {ep}\n-- report --\n{trace.get_export_report()}"
print(trace.get_export_report())
__main__                  TorchKNNImputer   OK_CHILDC -- ExportedProgram
..dist                    NanEuclidean      OK -- ExportedProgram
..columns[0]              ColProcessor      OK_CHILDC -- ExportedProgram
...._calc_impute          CalcImpute        OK_CHILDC -- ExportedProgram
......_weights            SubWeightMatrix   OK -- ExportedProgram
......_donors_idx         SubDonorsIdx      OK_CHILDC -- ExportedProgram
........_topk             SubTopKIndices    OK -- ExportedProgram
......_make_new_neights   MakeNewWeights    OK -- ExportedProgram
..columns[1]              ColProcessor      OK_CHILDC -- ExportedProgram
...._calc_impute          CalcImpute        OK_CHILDC -- ExportedProgram
......_weights            SubWeightMatrix   OK -- ExportedProgram
......_donors_idx         SubDonorsIdx      OK_CHILDC -- ExportedProgram
........_topk             SubTopKIndices    OK -- ExportedProgram
......_make_new_neights   MakeNewWeights    OK -- ExportedProgram
.._make_dict_idx_map      MakeDictIdxMap    OK -- ExportedProgram

OK means the module is exportable. OK_CHILDC means the module can be exported after its submodules are replaced by custom ops. It works except for the topk function. FAIL means the submodule cannot be exported at all but that module is simple enough and its ONNX conversion can be provided.

Final step

We first start by running the decompositions on every exported program.

# warnings raised while decomposing are ignored
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for t in trace:
        # nothing to decompose if the submodule was not exported
        if t.exporter_status.exported is None:
            print(f"[run_decompositions] {t.dot_name} - skipped")
            continue
        print(f"[run_decompositions] {t.dot_name}")
        # the exported program is replaced by its decomposed version
        t.exporter_status.exported = t.exporter_status.exported.run_decompositions({})
[run_decompositions]  M:__main__-TorchKNNImputer
[run_decompositions] .. M:dist-NanEuclidean
[run_decompositions] .. M:columns[0]-ColProcessor
[run_decompositions] .... M:_calc_impute-CalcImpute
[run_decompositions] ...... M:_weights-SubWeightMatrix
[run_decompositions] ...... M:_donors_idx-SubDonorsIdx
[run_decompositions] ........ M:_topk-SubTopKIndices
[run_decompositions] ...... M:_make_new_neights-MakeNewWeights
[run_decompositions] .. M:columns[1]-ColProcessor
[run_decompositions] .... M:_calc_impute-CalcImpute
[run_decompositions] ...... M:_weights-SubWeightMatrix
[run_decompositions] ...... M:_donors_idx-SubDonorsIdx
[run_decompositions] ........ M:_topk-SubTopKIndices
[run_decompositions] ...... M:_make_new_neights-MakeNewWeights
[run_decompositions] .. M:_make_dict_idx_map-MakeDictIdxMap

Let’s export everything. Every submodule is exported as a local function except topk for which we must provide an ONNX conversion.

T = str


def onnx_topk_indices(
    g: GraphBuilder,
    sts: Optional[Dict[str, Any]],
    outputs: List[str],
    x: T,
    k: T,
    name: str = "topk",
):
    """Converts the custom op wrapping ``SubTopKIndices`` into ONNX.

    Adds a node ``TopK`` looking for the smallest values. The values
    returned by the node go to an unused result, only the indices are
    bound to the expected output.

    :param g: graph builder receiving the new node
    :param sts: not used by this converter
    :param outputs: names of the outputs, exactly one is expected
    :param x: name of the input holding the values
    :param k: name of the input holding the number of values to keep
    :param name: name given to the node
    :return: name of the output holding the indices
    """
    assert len(outputs) == 1, f"Only one output is expected but outputs={outputs}"
    indices_name = outputs[0]
    # TopK always produces two outputs, the first one is not needed
    values_name = g.unique_name("unused_topk_values")
    g.op.TopK(
        x,
        k,
        name=name,
        outputs=[values_name, indices_name],
        largest=False,
        sorted=True,
    )
    return indices_name

Let’s check it is working somehow.

# Both calls look for the two largest values of every row (default behaviour)
# and print the indices only.
x = torch.tensor([[0, 1, 2], [6, 5, 4]], dtype=torch.float32)
print("torch.topk", torch.topk(x, k=2).indices)
print("onnx.topk", TopK.eval(x.numpy(), np.array([2], dtype=np.int64))[1])
torch.topk tensor([[2, 1],
        [0, 1]])
onnx.topk [[2 1]
 [0 1]]

And with nan values

# torch receives the nan values replaced by a very low value so that they
# are never selected, the ONNX reference operator receives the raw nan values.
x = torch.tensor([[0, np.nan, 2], [6, np.nan, 4]], dtype=torch.float32)
print("torch.topk", torch.topk(torch.nan_to_num(x, nan=-1.0e10), k=2).indices)
print("onnx.topk", TopK.eval(x.numpy(), np.array([2], dtype=np.int64))[1])
torch.topk tensor([[2, 0],
        [0, 2]])
onnx.topk [[2 0]
 [0 2]]

That works. Then the dispatcher maps the custom ops calling topk to the previous converter functions.

# One custom op per column processor: both wrap a SubTopKIndices submodule
# and are converted by the same function.
dispatcher = Dispatcher(
    {
        (
            "diag_lib::C_TorchKNNImputer_columns_0___calc_impute__donors_idx__topk"
        ): onnx_topk_indices,
        (
            "diag_lib::C_TorchKNNImputer_columns_1___calc_impute__donors_idx__topk"
        ): onnx_topk_indices,
    }
)

Let’s run the conversion. We also check the conversion into ONNX is accurate. It is doable because every intermediate result was previously traced.

# The custom ops are converted through the dispatcher and every converted
# piece is checked with ExtendedReferenceEvaluator, given the tolerances
# atol and rtol.
onx = trace.to_onnx_local(
    verbose=1,
    dispatcher=dispatcher,
    check_conversion_cls=dict(cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5),
    inline=False,
)
[to_onnx_local]  M:__main__-TorchKNNImputer - to_onnx_local
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer_dist'
[to_onnx_local] .. M:dist-NanEuclidean - to_onnx_local
[to_onnx_local] .. M:dist-NanEuclidean - export starts C_TorchKNNImputer_dist
[to_onnx_local] .. M:dist-NanEuclidean - export done
[to_onnx_local] .. M:dist-NanEuclidean - run validation
[onnx_run_disc] .. M:dist-NanEuclidean run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s4x2,T11s50x2),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s4x2[nan,nan:AnanN4nans],T11s50x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s4x50[nan,nan:AnanN10nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s4x50[0.02562400884926319,6.643804550170898:A1.8692522174923827N10nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s4x2,T11s5x2),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s4x2[nan,nan:AnanN4nans],T11s5x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s4x5[nan,nan:AnanN10nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s4x5[0.27489498257637024,4.837934494018555:A2.1912896662950514N10nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:dist-NanEuclidean validation done
[to_onnx_local] .. M:dist-NanEuclidean - done
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer_columns_0_'
[to_onnx_local] .. M:columns[0]-ColProcessor - to_onnx_local
[to_onnx_local] .. M:columns[0]-ColProcessor - export child 'C_TorchKNNImputer_columns_0___calc_impute'
[to_onnx_local] .... M:_calc_impute-CalcImpute - to_onnx_local
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__weights'
[to_onnx_local] ...... M:_weights-SubWeightMatrix - to_onnx_local
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export starts C_TorchKNNImputer_columns_0___calc_impute__weights
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - run validation
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s2x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s2x3[0.02562400884926319,0.5225339531898499:A0.19498917118956646],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s2x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s2x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0.0, rel=0.0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x2,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x2[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x2[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x2[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix validation done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__donors_idx'
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - to_onnx_local
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - skip child 'C_TorchKNNImputer_columns_0___calc_impute__donors_idx__topk'
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export starts C_TorchKNNImputer_columns_0___calc_impute__donors_idx
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - run validation
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s2x47,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s2x47[nan,nan:AnanN4nans],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s2x3[8,25:A17.833333333333332],T1s2x3[0.02562400884926319,0.5225339531898499:A0.19498917118956646])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s2x3[8,25:A17.833333333333332],A1s2x3[0.02562400884926319,0.5225339531898499:A0.19498917118956646])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x2,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x2[empty],T7s1[2,2:A2.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x2[empty],T1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x2[empty],A1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx validation done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__make_new_neights'
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - to_onnx_local
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export starts C_TorchKNNImputer_columns_0___calc_impute__make_new_neights
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - run validation
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s2x3,T11s2x3,T1s2x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s2x3[1,1:A1.0],T11s2x3[-2.6722490787506104,2.6974563598632812:A0.22883254289627075],T1s2x3[1.0,1.0:A1.0]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T11s2x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A11s2x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0.0, rel=0.0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x2,T11s0x2,T1s0x2),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x2[empty],T11s0x2[empty],T1s0x2[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T11s0x2[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A11s0x2[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights validation done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export starts C_TorchKNNImputer_columns_0___calc_impute
[to_onnx_local] .... M:_calc_impute-CalcImpute - export done
[to_onnx_local] .... M:_calc_impute-CalcImpute - run validation
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s2x47,T7s1,T11s47,T9s47),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s2x47[nan,nan:AnanN4nans],T7s1[3,3:A3.0],T11s47[-2.6722490787506104,2.6974563598632812:A0.15366982125696985],T9s47[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s2[-1.065980076789856,1.5236451625823975:A0.22883254289627075],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s2[-1.065980076789856,1.5236451625823975:A0.22883254289627075]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0.0, rel=0.0
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x2,T7s1,T11s2,T9s2),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x2[empty],T7s1[2,2:A2.0],T11s2[-0.6274242401123047,2.5991315841674805:A0.9858536720275879],T9s2[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0
[onnx_run_disc] .... M:_calc_impute-CalcImpute validation done
[to_onnx_local] .... M:_calc_impute-CalcImpute - done
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0
[to_onnx_local] .. M:columns[0]-ColProcessor - export starts C_TorchKNNImputer_columns_0_
[to_onnx_local] .. M:columns[0]-ColProcessor - export done
[to_onnx_local] .. M:columns[0]-ColProcessor - run validation
[onnx_run_disc] .. M:columns[0]-ColProcessor run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .. M:columns[0]-ColProcessor run with ((T1s40x2,T1s4x50,T9s50x2,T9s50x2,T7s40,T9s40x2,T7s4,T11s50x2),{})
[onnx_run_disc] .. M:columns[0]-ColProcessor flattened into ((T1s40x2[nan,nan:AnanN4nans],T1s4x50[nan,nan:AnanN10nans],T9s50x2[False,True:A0.95],T9s50x2[False,True:A0.05],T7s40[0,3:A0.15],T9s40x2[False,True:A0.05],T7s4[1,4:A2.5],T11s50x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:columns[0]-ColProcessor expecting (T1s40x2[nan,nan:AnanN2nans],)
[onnx_run_disc] .. M:columns[0]-ColProcessor computing A1s40x2[-2.2343990802764893,2.7258105278015137:A-0.10593246023293035N2nans]
[onnx_run_disc] .. M:columns[0]-ColProcessor diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:columns[0]-ColProcessor run with ((T1s10x2,T1s4x5,T9s5x2,T9s5x2,T7s10,T9s10x2,T7s4,T11s5x2),{})
[onnx_run_disc] .. M:columns[0]-ColProcessor flattened into ((T1s10x2[nan,nan:AnanN4nans],T1s4x5[nan,nan:AnanN10nans],T9s5x2[False,True:A0.5],T9s5x2[False,True:A0.5],T7s10[0,3:A0.6],T9s10x2[False,True:A0.2],T7s4[1,4:A2.5],T11s5x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:columns[0]-ColProcessor expecting (T1s10x2[nan,nan:AnanN2nans],)
[onnx_run_disc] .. M:columns[0]-ColProcessor computing A1s10x2[-2.5295796394348145,3.1395082473754883:A0.46641130508699763N2nans]
[onnx_run_disc] .. M:columns[0]-ColProcessor diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:columns[0]-ColProcessor validation done
[to_onnx_local] .. M:columns[0]-ColProcessor - done
[to_onnx_local] .. M:columns[0]-ColProcessor - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .. M:columns[0]-ColProcessor - discrepancies: abs=0.0, rel=0.0
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer_columns_1_'
[to_onnx_local] .. M:columns[1]-ColProcessor - to_onnx_local
[to_onnx_local] .. M:columns[1]-ColProcessor - export child 'C_TorchKNNImputer_columns_1___calc_impute'
[to_onnx_local] .... M:_calc_impute-CalcImpute - to_onnx_local
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_1___calc_impute__weights'
[to_onnx_local] ...... M:_weights-SubWeightMatrix - to_onnx_local
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export starts C_TorchKNNImputer_columns_1___calc_impute__weights
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - run validation
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s2x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s2x3[0.07385001331567764,1.1403498649597168:A0.518671645472447],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s2x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s2x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0.0, rel=0.0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix validation done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_1___calc_impute__donors_idx'
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - to_onnx_local
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - skip child 'C_TorchKNNImputer_columns_1___calc_impute__donors_idx__topk'
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export starts C_TorchKNNImputer_columns_1___calc_impute__donors_idx
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - run validation
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s2x48,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s2x48[nan,nan:AnanN6nans],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s2x3[3,42:A18.666666666666668],T1s2x3[0.07385001331567764,1.1403498649597168:A0.518671645472447])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s2x3[3,42:A18.666666666666668],A1s2x3[0.07385001331567764,1.1403498649597168:A0.518671645472447])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x3,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x3[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx validation done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_1___calc_impute__make_new_neights'
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - to_onnx_local
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export starts C_TorchKNNImputer_columns_1___calc_impute__make_new_neights
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - run validation
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s2x3,T11s2x3,T1s2x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s2x3[1,1:A1.0],T11s2x3[-1.5542010068893433,1.4427423477172852:A-0.13992430393894514],T1s2x3[1.0,1.0:A1.0]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T11s2x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A11s2x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0.0, rel=0.0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T11s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T11s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T11s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A11s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights validation done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export starts C_TorchKNNImputer_columns_1___calc_impute
[to_onnx_local] .... M:_calc_impute-CalcImpute - export done
[to_onnx_local] .... M:_calc_impute-CalcImpute - run validation
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s2x48,T7s1,T11s48,T9s48),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s2x48[nan,nan:AnanN6nans],T7s1[3,3:A3.0],T11s48[-2.0484697818756104,1.8861173391342163:A-0.06413272099598544],T9s48[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s2[-1.0340979099273682,0.7542492747306824:A-0.1399243175983429],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s2[-1.0340979099273682,0.7542492747306824:A-0.1399243175983429]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0.0, rel=0.0
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x3,T7s1,T11s3,T9s3),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x3[empty],T7s1[3,3:A3.0],T11s3[-1.1547163724899292,1.8039337396621704:A-0.12774483362833658],T9s3[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0
[onnx_run_disc] .... M:_calc_impute-CalcImpute validation done
[to_onnx_local] .... M:_calc_impute-CalcImpute - done
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0
[to_onnx_local] .. M:columns[1]-ColProcessor - export starts C_TorchKNNImputer_columns_1_
[to_onnx_local] .. M:columns[1]-ColProcessor - export done
[to_onnx_local] .. M:columns[1]-ColProcessor - run validation
[onnx_run_disc] .. M:columns[1]-ColProcessor run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .. M:columns[1]-ColProcessor run with ((T1s40x2,T1s4x50,T9s50x2,T9s50x2,T7s40,T9s40x2,T7s4,T11s50x2),{})
[onnx_run_disc] .. M:columns[1]-ColProcessor flattened into ((T1s40x2[nan,nan:AnanN2nans],T1s4x50[nan,nan:AnanN10nans],T9s50x2[False,True:A0.95],T9s50x2[False,True:A0.05],T7s40[0,3:A0.15],T9s40x2[False,True:A0.05],T7s4[1,4:A2.5],T11s50x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:columns[1]-ColProcessor expecting (T1s40x2[-2.2343990802764893,2.7258105278015137:A-0.10678225666706567],)
[onnx_run_disc] .. M:columns[1]-ColProcessor computing A1s40x2[-2.2343990802764893,2.7258105278015137:A-0.10678225666706567]
[onnx_run_disc] .. M:columns[1]-ColProcessor diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:columns[1]-ColProcessor run with ((T1s10x2,T1s4x5,T9s5x2,T9s5x2,T7s10,T9s10x2,T7s4,T11s5x2),{})
[onnx_run_disc] .. M:columns[1]-ColProcessor flattened into ((T1s10x2[nan,nan:AnanN2nans],T1s4x5[nan,nan:AnanN10nans],T9s5x2[False,True:A0.5],T9s5x2[False,True:A0.5],T7s10[0,3:A0.6],T9s10x2[False,True:A0.2],T7s4[1,4:A2.5],T11s5x2[nan,nan:AnanN5nans]),{})
[onnx_run_disc] .. M:columns[1]-ColProcessor expecting (T1s10x2[-2.5295796394348145,3.1395082473754883:A0.40699569071875885],)
[onnx_run_disc] .. M:columns[1]-ColProcessor computing A1s10x2[-2.5295796394348145,3.1395082473754883:A0.40699569071875885]
[onnx_run_disc] .. M:columns[1]-ColProcessor diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:columns[1]-ColProcessor validation done
[to_onnx_local] .. M:columns[1]-ColProcessor - done
[to_onnx_local] .. M:columns[1]-ColProcessor - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .. M:columns[1]-ColProcessor - discrepancies: abs=0.0, rel=0.0
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer__make_dict_idx_map'
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - to_onnx_local
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - export starts C_TorchKNNImputer__make_dict_idx_map
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - export done
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - run validation
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap run with ((T1s40x2,T7s4),{})
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap flattened into ((T1s40x2[nan,nan:AnanN4nans],T7s4[1,4:A2.5]),{})
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap expecting (T7s40[0,3:A0.15],)
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap computing A7s40[0,3:A0.15]
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap run with ((T1s10x2,T7s4),{})
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap flattened into ((T1s10x2[nan,nan:AnanN4nans],T7s4[1,4:A2.5]),{})
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap expecting (T7s10[0,3:A0.6],)
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap computing A7s10[0,3:A0.6]
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap diff=abs=0.0, rel=0.0
[onnx_run_disc] .. M:_make_dict_idx_map-MakeDictIdxMap validation done
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - done
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - discrepancies: abs=0.0, rel=0.0
[to_onnx_local] .. M:_make_dict_idx_map-MakeDictIdxMap - discrepancies: abs=0.0, rel=0.0
[to_onnx_local]  M:__main__-TorchKNNImputer - export starts C_TorchKNNImputer
[to_onnx_local]  M:__main__-TorchKNNImputer - export done
[to_onnx_local]  M:__main__-TorchKNNImputer - run validation
[onnx_run_disc]  M:__main__-TorchKNNImputer run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc]  M:__main__-TorchKNNImputer run with ((T9s50x2,T9s2,T11s50x2,T1s40x2),{})
[onnx_run_disc]  M:__main__-TorchKNNImputer flattened into ((T9s50x2[False,True:A0.05],T9s2[True,True:A1.0],T11s50x2[nan,nan:AnanN5nans],T1s40x2[nan,nan:AnanN4nans]),{})
[onnx_run_disc]  M:__main__-TorchKNNImputer expecting (T1s40x2[-2.2343990802764893,2.7258105278015137:A-0.10678225666706567],)
[onnx_run_disc]  M:__main__-TorchKNNImputer computing A1s40x2[-2.2343990802764893,2.7258105278015137:A-0.10678225666706567]
[onnx_run_disc]  M:__main__-TorchKNNImputer diff=abs=0.0, rel=0.0
[onnx_run_disc]  M:__main__-TorchKNNImputer run with ((T9s5x2,T9s2,T11s5x2,T1s10x2),{})
[onnx_run_disc]  M:__main__-TorchKNNImputer flattened into ((T9s5x2[False,True:A0.5],T9s2[True,True:A1.0],T11s5x2[nan,nan:AnanN5nans],T1s10x2[nan,nan:AnanN4nans]),{})
[onnx_run_disc]  M:__main__-TorchKNNImputer expecting (T1s10x2[-2.5295796394348145,3.1395082473754883:A0.40699569071875885],)
[onnx_run_disc]  M:__main__-TorchKNNImputer computing A1s10x2[-2.5295796394348145,3.1395082473754883:A0.40699569071875885]
[onnx_run_disc]  M:__main__-TorchKNNImputer diff=abs=0.0, rel=0.0
[onnx_run_disc]  M:__main__-TorchKNNImputer validation done
[to_onnx_local]  M:__main__-TorchKNNImputer - done
[to_onnx_local]  M:__main__-TorchKNNImputer - discrepancies: abs=0.0, rel=0.0
[to_onnx_local]  M:__main__-TorchKNNImputer - discrepancies: abs=0.0, rel=0.0

Let’s save the exported ONNX model to disk.

onnx.save(onx, "plot_torch_sklearn_201.onnx")

We can also print the model in a readable text form.

print(pretty_onnx(onx))
opset: domain='' version=18
opset: domain='local_domain' version=1
input: name='_mask_fit_x' type=dtype('bool') shape=['s0', 2]
input: name='_valid_mask' type=dtype('bool') shape=[2]
input: name='_fit_x' type=dtype('float64') shape=['s1', 2]
input: name='x' type=dtype('float32') shape=['s2', 2]
init: name='init7_s1_1' type=int64 shape=(1,) -- array([1])           -- Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape##Opset.make_node.1/Shape
init: name='init7_s1_-1' type=int64 shape=(1,) -- array([-1])         -- Opset.make_node.1/Shape##Opset.make_node.1/Shape
IsNaN(x) -> isnan
  Compress(isnan, _valid_mask, axis=1) -> _onx_compress_isnan0
    Cast(_onx_compress_isnan0, to=6) -> _onx_cast_index0
      ReduceMax(_onx_cast_index0, init7_s1_1, keepdims=0) -> _onx_reducemax_cast_index00
        Cast(_onx_reducemax_cast_index00, to=9) -> any_1
          Reshape(any_1, init7_s1_-1) -> view
            NonZero(view) -> _onx_nonzero_view0
              Reshape(_onx_nonzero_view0, init7_s1_-1) -> nonzero_numpy#0
                diag_lib_C_TorchKNNImputer__make_dict_idx_map_default[local_domain](x, nonzero_numpy#0) -> c_torch_knnimputer__make_dict_idx_map
Not(_mask_fit_x) -> logical_not
Gather(x, nonzero_numpy#0, axis=0) -> index_1
  diag_lib_C_TorchKNNImputer_dist_default[local_domain](index_1, _fit_x) -> c_torch_knnimputer_dist
  diag_lib_C_TorchKNNImputer_columns_0__default[local_domain](x, c_torch_knnimputer_dist, logical_not, _mask_fit_x, c_torch_knnimputer__make_dict_idx_map, isnan, nonzero_numpy#0, _fit_x) -> c_torch_knnimputer_columns_0_
  diag_lib_C_TorchKNNImputer_columns_1__default[local_domain](c_torch_knnimputer_columns_0_, c_torch_knnimputer_dist, logical_not, _mask_fit_x, c_torch_knnimputer__make_dict_idx_map, isnan, nonzero_numpy#0, _fit_x) -> c_torch_knnimputer_columns_1_
    Compress(c_torch_knnimputer_columns_1_, _valid_mask, axis=1) -> output_0
output: name='output_0' type=dtype('float32') shape=['s2', 'u4']
----- function name=diag_lib_C_TorchKNNImputer__make_dict_idx_map_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'x'
input: 'row_missing_idx'
Constant(value=0) -> init7_s_0
Constant(value=4) -> init7_s_4
Constant(value=1) -> init7_s_1
  Range(init7_s_0, init7_s_4, init7_s_1) -> arange
Constant(value=[-1]) -> init7_s1_-1
  Unsqueeze(row_missing_idx, init7_s1_-1) -> _onx_unsqueeze_row_missing_idx0
Shape(x, end=1, start=0) -> _shape_x0
  ConstantOfShape(_shape_x0, value=[0]) -> zeros
    ScatterND(zeros, _onx_unsqueeze_row_missing_idx0, arange) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_dist_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'x'
input: 'y'
Cast(y, to=1) -> _to_copy
  IsNaN(_to_copy) -> isnan_1
    Cast(isnan_1, to=1) -> _to_copy_1
Constant(value=0.0) -> c_lifted_tensor_0
Constant(value=0.0) -> c_lifted_tensor_1
  Where(isnan_1, c_lifted_tensor_1, _to_copy) -> index_put_1
    Mul(index_put_1, index_put_1) -> mul_11
Constant(value=nan) -> c_lifted_tensor_2
Constant(value=[1.0]) -> c_lifted_tensor_3
Constant(value=-2.0) -> init1_s_
Constant(value=[1]) -> init7_s1_1
  Reshape(init1_s_, init7_s1_1) -> _reshape_init1_s_0
Constant(value=[0.0]) -> init1_s1_
Constant(value=1.0) -> init1_s_2
  Reshape(init1_s_2, init7_s1_1) -> _reshape_init1_s_20
Constant(value=0.0) -> init1_s_3
  Reshape(init1_s_3, init7_s1_1) -> _reshape_init1_s_30
Constant(value=2.0) -> init1_s_4
  Reshape(init1_s_4, init7_s1_1) -> _reshape_init1_s_40
Constant(value=[1, -1]) -> init7_s2_1_-1
IsNaN(x) -> isnan
  Cast(isnan, to=1) -> _to_copy_2
    Gemm(_to_copy_2, mul_11, transA=0, transB=1) -> matmul_2
  Where(isnan, c_lifted_tensor_0, x) -> index_put
    Mul(index_put, index_put) -> mul_10
  ReduceSum(mul_10, init7_s1_1, keepdims=1) -> sum_1
Mul(index_put, _reshape_init1_s_0) -> _onx_mul_index_put0
  Gemm(_onx_mul_index_put0, index_put_1, transA=0, transB=1) -> matmul
    Add(matmul, sum_1) -> add_26
  ReduceSum(mul_11, init7_s1_1, keepdims=1) -> sum_2
  Reshape(sum_2, init7_s2_1_-1) -> permute_2
    Add(add_26, permute_2) -> add_35
Gemm(mul_10, _to_copy_1, transA=0, transB=1) -> matmul_1
  Sub(add_35, matmul_1) -> sub_18
    Sub(sub_18, matmul_2) -> sub_24
  Clip(sub_24, init1_s1_) -> clip
Sub(_reshape_init1_s_20, _to_copy_2) -> rsub
Not(isnan_1) -> bitwise_not
  Cast(bitwise_not, to=1) -> _to_copy_4
  Gemm(rsub, _to_copy_4, transA=0, transB=1) -> matmul_3
    Equal(matmul_3, _reshape_init1_s_30) -> eq_33
  Where(eq_33, c_lifted_tensor_2, clip) -> index_put_2
Max(c_lifted_tensor_3, matmul_3) -> maximum
  Div(index_put_2, maximum) -> div
    Mul(div, _reshape_init1_s_40) -> _onx_mul_div0
      Sqrt(_onx_mul_div0) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_0___calc_impute__donors_idx_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'dist_pot_donors'
input: 'n_neighbors'
Constant(value=0) -> init7_s_0
Constant(value=[1]) -> init7_s1_1
Constant(value=1) -> init7_s_1
Shape(dist_pot_donors, end=1, start=0) -> _shape_dist_pot_donors0
  Squeeze(_shape_dist_pot_donors0) -> sym_size_int_4
  Range(init7_s_0, sym_size_int_4, init7_s_1) -> arange
  Unsqueeze(arange, init7_s1_1) -> unsqueeze
    GatherND(dist_pot_donors, unsqueeze, batch_dims=0) -> _onx_gathernd_dist_pot_donors0
TopK(dist_pot_donors, n_neighbors, largest=0, sorted=1) -> unused_topk_values, output_0
  GatherElements(_onx_gathernd_dist_pot_donors0, output_0, axis=1) -> output_1
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_0___calc_impute__weights_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'donors_dist'
Constant(value=0.0) -> c_lifted_tensor_0
Shape(donors_dist) -> _shape_donors_dist0
  ConstantOfShape(_shape_donors_dist0, value=[1.0]) -> ones_like
IsNaN(donors_dist) -> isnan
  Where(isnan, c_lifted_tensor_0, ones_like) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_0___calc_impute__make_new_neights_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'donors_mask'
input: 'donors'
input: 'weight_matrix'
Cast(donors_mask, to=11) -> _to_copy
Cast(weight_matrix, to=11) -> _to_copy_1
  Mul(_to_copy, _to_copy_1) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_0___calc_impute_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='local_domain' version=1
input: 'dist_pot_donors'
input: 'n_neighbors'
input: 'fit_x_col'
input: 'mask_fit_x_col'
Constant(value=[1]) -> c_lifted_tensor_0
Constant(value=[1.0]) -> c_lifted_tensor_1
Constant(value=[-1]) -> init7_s1_-1
  Reshape(fit_x_col, init7_s1_-1) -> _reshape_fit_x_col0
Constant(value=0.0) -> init11_s_
  Reshape(init11_s_, c_lifted_tensor_0) -> _reshape_init11_s_0
diag_lib_C_TorchKNNImputer_columns_0___calc_impute__donors_idx_default[local_domain](dist_pot_donors, n_neighbors) -> c_torch_knnimputer_columns_0___calc_impute__donors_idx#0, c_torch_knnimputer_columns_0___calc_impute__donors_idx#1
  diag_lib_C_TorchKNNImputer_columns_0___calc_impute__weights_default[local_domain](c_torch_knnimputer_columns_0___calc_impute__donors_idx#1) -> c_torch_knnimputer_columns_0___calc_impute__weights
Gather(_reshape_fit_x_col0, c_torch_knnimputer_columns_0___calc_impute__donors_idx#0) -> take
Reshape(mask_fit_x_col, init7_s1_-1) -> _reshape_mask_fit_x_col0
  Gather(_reshape_mask_fit_x_col0, c_torch_knnimputer_columns_0___calc_impute__donors_idx#0) -> take_1
    Cast(take_1, to=7) -> _to_copy
  Sub(c_lifted_tensor_0, _to_copy) -> sub_12
  diag_lib_C_TorchKNNImputer_columns_0___calc_impute__make_new_neights_default[local_domain](sub_12, take, c_torch_knnimputer_columns_0___calc_impute__weights) -> c_torch_knnimputer_columns_0___calc_impute__make_new_neights
  ReduceSum(c_torch_knnimputer_columns_0___calc_impute__make_new_neights, c_lifted_tensor_0, keepdims=1) -> sum_1
    Equal(sum_1, _reshape_init11_s_0) -> eq_17
  Where(eq_17, c_lifted_tensor_1, sum_1) -> where
Mul(take, c_torch_knnimputer_columns_0___calc_impute__make_new_neights) -> mul_17
  ReduceSum(mul_17, c_lifted_tensor_0, keepdims=1) -> sum_2
    Div(sum_2, where) -> div
  Squeeze(div, c_lifted_tensor_0) -> _onx_squeeze_div0
    Cast(_onx_squeeze_div0, to=1) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_0__default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='local_domain' version=1
input: 'x'
input: 'dist_chunk'
input: 'non_missing_fix_x'
input: 'mask_fit_x'
input: 'dist_idx_map'
input: 'mask'
input: 'row_missing_idx'
input: '_fit_x'
Constant(value=[1.0]) -> c_lifted_tensor_0
Constant(value=3) -> c_lifted_tensor_1
Constant(value=[1]) -> c_lifted_tensor_2
Constant(value=0) -> init7_s_0
  Gather(mask, init7_s_0, axis=1) -> select
    Gather(select, row_missing_idx, axis=0) -> index
Constant(value=[-1]) -> init7_s1_-1
  Reshape(index, init7_s1_-1) -> view
    NonZero(view) -> _onx_nonzero_view0
  Reshape(_onx_nonzero_view0, init7_s1_-1) -> nonzero_numpy_1#0
    Gather(row_missing_idx, nonzero_numpy_1#0, axis=0) -> index_1
      Gather(dist_idx_map, index_1, axis=0) -> index_2
        Gather(dist_chunk, index_2, axis=0) -> index_3
Constant(value=1.0) -> init11_s_
  Reshape(init11_s_, c_lifted_tensor_2) -> _reshape_init11_s_0
Constant(value=0.0) -> init1_s_
Constant(value=[0]) -> init7_s1_0
Constant(value=1) -> init7_s_1
Gather(non_missing_fix_x, init7_s_0, axis=1) -> select_1
  NonZero(select_1) -> _onx_nonzero_select_10
  Reshape(_onx_nonzero_select_10, init7_s1_-1) -> nonzero_numpy#0
    Shape(nonzero_numpy#0, end=1, start=0) -> _shape_getitem_20
      Squeeze(_shape_getitem_20) -> sym_size_int_20
  Less(c_lifted_tensor_1, sym_size_int_20) -> lt
  Where(lt, c_lifted_tensor_1, sym_size_int_20) -> where_1
  LessOrEqual(where_1, init7_s_0) -> le_3
  Where(le_3, c_lifted_tensor_2, where_1) -> where_2
Gather(index_3, nonzero_numpy#0, axis=1) -> _onx_gather_index_30
  IsNaN(_onx_gather_index_30) -> isnan
    Cast(isnan, to=6) -> _onx_cast_isnan0
  ReduceMin(_onx_cast_isnan0, c_lifted_tensor_2, keepdims=0) -> _onx_reducemin_cast_isnan00
    Cast(_onx_reducemin_cast_isnan00, to=9) -> all_1
      Compress(index_1, all_1, axis=0) -> index_5
  Unsqueeze(index_5, init7_s1_-1) -> _onx_unsqueeze_index_50
Gather(mask_fit_x, init7_s_0, axis=1) -> select_2
  Not(select_2) -> bitwise_not
    Cast(bitwise_not, to=11) -> _to_copy
      Cast(_to_copy, to=1) -> _to_copy_1
        ReduceSum(_to_copy_1, keepdims=0) -> sum_1
  Greater(sum_1, init1_s_) -> gt
  Where(gt, sum_1, c_lifted_tensor_0) -> where
Equal(_to_copy, _reshape_init11_s_0) -> eq_23
Gather(_fit_x, init7_s_0, axis=1) -> select_3
  Compress(select_3, eq_23, axis=0) -> index_6
    ReduceSum(index_6, keepdims=0) -> sum_2
      Cast(sum_2, to=1) -> _to_copy_2
  Reshape(_to_copy_2, c_lifted_tensor_2) -> _reshape__to_copy_20
    Div(_reshape__to_copy_20, where) -> div
  Squeeze(div, init7_s1_0) -> view_1
Gather(x, init7_s_0, axis=1) -> select_4
Shape(index_5) -> _shape_index_502
  Expand(view_1, _shape_index_502) -> _onx_expand_view_10
  ScatterND(select_4, _onx_unsqueeze_index_50, _onx_expand_view_10) -> index_put
  Unsqueeze(index_put, init7_s_1) -> _onx_unsqueeze_index_put0
    Shape(_onx_unsqueeze_index_put0) -> _shape_unsqueeze_index_put00
  Expand(init7_s1_0, _shape_unsqueeze_index_put00) -> _onx_expand_init7_s1_00
    ScatterElements(x, _onx_expand_init7_s1_00, _onx_unsqueeze_index_put0, axis=1, reduction=b'none') -> select_scatter
  Gather(select_scatter, init7_s_0, axis=1) -> select_9
Not(all_1) -> bitwise_not_1
  Compress(index_1, bitwise_not_1, axis=0) -> index_7
    Gather(dist_idx_map, index_7, axis=0) -> index_8
      Gather(dist_chunk, index_8, axis=0) -> index_9
    Gather(index_9, nonzero_numpy#0, axis=1) -> _onx_gather_index_90
  Gather(_fit_x, init7_s_0, axis=1) -> select_6
    Gather(select_6, nonzero_numpy#0, axis=0) -> index_11
  Gather(mask_fit_x, init7_s_0, axis=1) -> select_7
    Gather(select_7, nonzero_numpy#0, axis=0) -> index_12
    diag_lib_C_TorchKNNImputer_columns_0___calc_impute_default[local_domain](_onx_gather_index_90, where_2, index_11, index_12) -> c_torch_knnimputer_columns_0___calc_impute
  Unsqueeze(index_7, init7_s1_-1) -> _onx_unsqueeze_index_70
    ScatterND(select_9, _onx_unsqueeze_index_70, c_torch_knnimputer_columns_0___calc_impute) -> index_put_1
  Unsqueeze(index_put_1, init7_s_1) -> _onx_unsqueeze_index_put_10
    Shape(_onx_unsqueeze_index_put_10) -> _shape_unsqueeze_index_put_100
  Expand(init7_s1_0, _shape_unsqueeze_index_put_100) -> _onx_expand_init7_s1_002
    ScatterElements(select_scatter, _onx_expand_init7_s1_002, _onx_unsqueeze_index_put_10, axis=1, reduction=b'none') -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_1___calc_impute__donors_idx_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'dist_pot_donors'
input: 'n_neighbors'
Constant(value=0) -> init7_s_0
Constant(value=[1]) -> init7_s1_1
Constant(value=1) -> init7_s_1
Shape(dist_pot_donors, end=1, start=0) -> _shape_dist_pot_donors0
  Squeeze(_shape_dist_pot_donors0) -> sym_size_int_4
  Range(init7_s_0, sym_size_int_4, init7_s_1) -> arange
  Unsqueeze(arange, init7_s1_1) -> unsqueeze
    GatherND(dist_pot_donors, unsqueeze, batch_dims=0) -> _onx_gathernd_dist_pot_donors0
TopK(dist_pot_donors, n_neighbors, largest=0, sorted=1) -> unused_topk_values, output_0
  GatherElements(_onx_gathernd_dist_pot_donors0, output_0, axis=1) -> output_1
output: name='output_0' type=? shape=?
output: name='output_1' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_1___calc_impute__weights_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'donors_dist'
Constant(value=0.0) -> c_lifted_tensor_0
Shape(donors_dist) -> _shape_donors_dist0
  ConstantOfShape(_shape_donors_dist0, value=[1.0]) -> ones_like
IsNaN(donors_dist) -> isnan
  Where(isnan, c_lifted_tensor_0, ones_like) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_1___calc_impute__make_new_neights_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
input: 'donors_mask'
input: 'donors'
input: 'weight_matrix'
Cast(donors_mask, to=11) -> _to_copy
Cast(weight_matrix, to=11) -> _to_copy_1
  Mul(_to_copy, _to_copy_1) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_1___calc_impute_default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='local_domain' version=1
input: 'dist_pot_donors'
input: 'n_neighbors'
input: 'fit_x_col'
input: 'mask_fit_x_col'
Constant(value=[1]) -> c_lifted_tensor_0
Constant(value=[1.0]) -> c_lifted_tensor_1
Constant(value=[-1]) -> init7_s1_-1
  Reshape(fit_x_col, init7_s1_-1) -> _reshape_fit_x_col0
Constant(value=0.0) -> init11_s_
  Reshape(init11_s_, c_lifted_tensor_0) -> _reshape_init11_s_0
diag_lib_C_TorchKNNImputer_columns_1___calc_impute__donors_idx_default[local_domain](dist_pot_donors, n_neighbors) -> c_torch_knnimputer_columns_1___calc_impute__donors_idx#0, c_torch_knnimputer_columns_1___calc_impute__donors_idx#1
  diag_lib_C_TorchKNNImputer_columns_1___calc_impute__weights_default[local_domain](c_torch_knnimputer_columns_1___calc_impute__donors_idx#1) -> c_torch_knnimputer_columns_1___calc_impute__weights
Gather(_reshape_fit_x_col0, c_torch_knnimputer_columns_1___calc_impute__donors_idx#0) -> take
Reshape(mask_fit_x_col, init7_s1_-1) -> _reshape_mask_fit_x_col0
  Gather(_reshape_mask_fit_x_col0, c_torch_knnimputer_columns_1___calc_impute__donors_idx#0) -> take_1
    Cast(take_1, to=7) -> _to_copy
  Sub(c_lifted_tensor_0, _to_copy) -> sub_12
  diag_lib_C_TorchKNNImputer_columns_1___calc_impute__make_new_neights_default[local_domain](sub_12, take, c_torch_knnimputer_columns_1___calc_impute__weights) -> c_torch_knnimputer_columns_1___calc_impute__make_new_neights
  ReduceSum(c_torch_knnimputer_columns_1___calc_impute__make_new_neights, c_lifted_tensor_0, keepdims=1) -> sum_1
    Equal(sum_1, _reshape_init11_s_0) -> eq_17
  Where(eq_17, c_lifted_tensor_1, sum_1) -> where
Mul(take, c_torch_knnimputer_columns_1___calc_impute__make_new_neights) -> mul_17
  ReduceSum(mul_17, c_lifted_tensor_0, keepdims=1) -> sum_2
    Div(sum_2, where) -> div
  Squeeze(div, c_lifted_tensor_0) -> _onx_squeeze_div0
    Cast(_onx_squeeze_div0, to=1) -> output_0
output: name='output_0' type=? shape=?
----- function name=diag_lib_C_TorchKNNImputer_columns_1__default domain=local_domain
----- doc_string: -- function_options=FunctionOptions(export_as_function=...
opset: domain='' version=18
opset: domain='local_domain' version=1
input: 'x'
input: 'dist_chunk'
input: 'non_missing_fix_x'
input: 'mask_fit_x'
input: 'dist_idx_map'
input: 'mask'
input: 'row_missing_idx'
input: '_fit_x'
Constant(value=[1.0]) -> c_lifted_tensor_0
Constant(value=3) -> c_lifted_tensor_1
Constant(value=[1]) -> c_lifted_tensor_2
Constant(value=1) -> init7_s_1
  Gather(mask, init7_s_1, axis=1) -> select
    Gather(select, row_missing_idx, axis=0) -> index
Constant(value=[-1]) -> init7_s1_-1
  Reshape(index, init7_s1_-1) -> view
    NonZero(view) -> _onx_nonzero_view0
  Reshape(_onx_nonzero_view0, init7_s1_-1) -> nonzero_numpy_1#0
    Gather(row_missing_idx, nonzero_numpy_1#0, axis=0) -> index_1
      Gather(dist_idx_map, index_1, axis=0) -> index_2
        Gather(dist_chunk, index_2, axis=0) -> index_3
Constant(value=0) -> init7_s_0
Constant(value=1.0) -> init11_s_
  Reshape(init11_s_, c_lifted_tensor_2) -> _reshape_init11_s_0
Constant(value=0.0) -> init1_s_
Constant(value=[0]) -> init7_s1_0
Gather(non_missing_fix_x, init7_s_1, axis=1) -> select_1
  NonZero(select_1) -> _onx_nonzero_select_10
  Reshape(_onx_nonzero_select_10, init7_s1_-1) -> nonzero_numpy#0
    Shape(nonzero_numpy#0, end=1, start=0) -> _shape_getitem_20
      Squeeze(_shape_getitem_20) -> sym_size_int_20
  Less(c_lifted_tensor_1, sym_size_int_20) -> lt
  Where(lt, c_lifted_tensor_1, sym_size_int_20) -> where_1
  LessOrEqual(where_1, init7_s_0) -> le_3
  Where(le_3, c_lifted_tensor_2, where_1) -> where_2
Gather(index_3, nonzero_numpy#0, axis=1) -> _onx_gather_index_30
  IsNaN(_onx_gather_index_30) -> isnan
    Cast(isnan, to=6) -> _onx_cast_isnan0
  ReduceMin(_onx_cast_isnan0, c_lifted_tensor_2, keepdims=0) -> _onx_reducemin_cast_isnan00
    Cast(_onx_reducemin_cast_isnan00, to=9) -> all_1
      Compress(index_1, all_1, axis=0) -> index_5
  Unsqueeze(index_5, init7_s1_-1) -> _onx_unsqueeze_index_50
Gather(mask_fit_x, init7_s_1, axis=1) -> select_2
  Not(select_2) -> bitwise_not
    Cast(bitwise_not, to=11) -> _to_copy
      Cast(_to_copy, to=1) -> _to_copy_1
        ReduceSum(_to_copy_1, keepdims=0) -> sum_1
  Greater(sum_1, init1_s_) -> gt
  Where(gt, sum_1, c_lifted_tensor_0) -> where
Equal(_to_copy, _reshape_init11_s_0) -> eq_23
Gather(_fit_x, init7_s_1, axis=1) -> select_3
  Compress(select_3, eq_23, axis=0) -> index_6
    ReduceSum(index_6, keepdims=0) -> sum_2
      Cast(sum_2, to=1) -> _to_copy_2
  Reshape(_to_copy_2, c_lifted_tensor_2) -> _reshape__to_copy_20
    Div(_reshape__to_copy_20, where) -> div
  Squeeze(div, init7_s1_0) -> view_1
Gather(x, init7_s_1, axis=1) -> select_4
Shape(index_5) -> _shape_index_502
  Expand(view_1, _shape_index_502) -> _onx_expand_view_10
  ScatterND(select_4, _onx_unsqueeze_index_50, _onx_expand_view_10) -> index_put
  Unsqueeze(index_put, init7_s_1) -> _onx_unsqueeze_index_put0
    Shape(_onx_unsqueeze_index_put0) -> _shape_unsqueeze_index_put00
  Expand(c_lifted_tensor_2, _shape_unsqueeze_index_put00) -> _onx_expand_c_lifted_tensor_20
    ScatterElements(x, _onx_expand_c_lifted_tensor_20, _onx_unsqueeze_index_put0, axis=1, reduction=b'none') -> select_scatter
  Gather(select_scatter, init7_s_1, axis=1) -> select_9
Not(all_1) -> bitwise_not_1
  Compress(index_1, bitwise_not_1, axis=0) -> index_7
    Gather(dist_idx_map, index_7, axis=0) -> index_8
      Gather(dist_chunk, index_8, axis=0) -> index_9
    Gather(index_9, nonzero_numpy#0, axis=1) -> _onx_gather_index_90
  Gather(_fit_x, init7_s_1, axis=1) -> select_6
    Gather(select_6, nonzero_numpy#0, axis=0) -> index_11
  Gather(mask_fit_x, init7_s_1, axis=1) -> select_7
    Gather(select_7, nonzero_numpy#0, axis=0) -> index_12
    diag_lib_C_TorchKNNImputer_columns_1___calc_impute_default[local_domain](_onx_gather_index_90, where_2, index_11, index_12) -> c_torch_knnimputer_columns_1___calc_impute
  Unsqueeze(index_7, init7_s1_-1) -> _onx_unsqueeze_index_70
    ScatterND(select_9, _onx_unsqueeze_index_70, c_torch_knnimputer_columns_1___calc_impute) -> index_put_1
  Unsqueeze(index_put_1, init7_s_1) -> _onx_unsqueeze_index_put_10
    Shape(_onx_unsqueeze_index_put_10) -> _shape_unsqueeze_index_put_100
  Expand(c_lifted_tensor_2, _shape_unsqueeze_index_put_100) -> _onx_expand_c_lifted_tensor_202
    ScatterElements(select_scatter, _onx_expand_c_lifted_tensor_202, _onx_unsqueeze_index_put_10, axis=1, reduction=b'none') -> output_0
output: name='output_0' type=? shape=?

Validation

def validate_onnx(size, sizey, onx, verbose: int = 1, use_ort: bool = False):
    """Checks the exported ONNX model against scikit-learn on random data.

    A :class:`sklearn.impute.KNNImputer` (``n_neighbors=3``) is fitted on a random
    ``(size, 2)`` matrix containing missing values, then a random ``(sizey, 2)``
    matrix is imputed with scikit-learn (the reference), with the torch rewrite
    ``TorchKNNImputer`` and with the ONNX model *onx*. The same comparison is then
    repeated on a single row (``Y[1:2]``) with the last runtime that was loaded.

    :param size: number of rows of the training matrix, rows 0 to 4 receive a
        missing value so it must be at least 5
    :param sizey: number of rows of the matrix to impute, rows 1 to 4 receive a
        missing value so it must be at least 5
    :param onx: ONNX model to validate, its graph inputs are fed, in order, with
        ``_mask_fit_X``, ``_valid_mask``, ``_fit_X`` and the matrix to impute
    :param verbose: a non-zero value prints progress and discrepancies
    :param use_ort: also runs the model with onnxruntime (CPU provider),
        this writes the optimized model to ``plot_torch_sklearn_201.ort.onnx``
        in the current directory
    :raises AssertionError: if a maximum absolute difference with scikit-learn
        reaches ``1e-5``
    """
    X = torch.randn((size, 2))
    Y = torch.randn((sizey, 2))
    # Missing values alternate between the two columns.
    for i in range(5):
        X[i, i % 2] = torch.nan
    for i in range(4):
        Y[i + 1, i % 2] = torch.nan

    knn_imputer = sklearn.impute.KNNImputer(n_neighbors=3)
    knn_imputer.fit(X)

    model = TorchKNNImputer(knn_imputer)
    input_names = [i.name for i in onx.graph.input]

    def make_inputs(data):
        # Fitted state followed by the matrix to impute, in the order used
        # to feed both the torch model and the ONNX graph inputs.
        return (
            torch.from_numpy(knn_imputer._mask_fit_X),
            torch.from_numpy(knn_imputer._valid_mask),
            torch.from_numpy(knn_imputer._fit_X),
            data,
        )

    def make_feeds(inputs):
        return dict(zip(input_names, [t.numpy() for t in inputs]))

    # First pass: the whole matrix Y.
    p1 = knn_imputer.transform(Y)

    model_inputs = make_inputs(Y)
    p2 = model.transform(*model_inputs)
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print(f"Torch Discrepancies for size={size} and sizey={sizey}, d={d}")

    feeds = make_feeds(model_inputs)

    # Name of the runtime currently held by ``sess``, used in progress messages.
    runtime = "python"
    if verbose:
        print("python: loading the model...")
    sess = ExtendedReferenceEvaluator(onx, verbose=0)
    if verbose:
        print("python: running the model...")
    got = sess.run(None, feeds)
    d = max_diff(p1, got[0])
    assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print(f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}")

    if use_ort:
        runtime = "onnxruntime"
        if verbose:
            print("onnxruntime: loading the model...")
        opts = onnxruntime.SessionOptions()
        opts.optimized_model_filepath = "plot_torch_sklearn_201.ort.onnx"
        opts.log_severity_level = 0
        opts.log_verbosity_level = 0
        sess = onnxruntime.InferenceSession(
            onx.SerializeToString(), opts, providers=["CPUExecutionProvider"]
        )
        if verbose:
            print("onnxruntime: running the model...")
        got = sess.run(None, feeds)
        d = max_diff(p1, got[0])
        assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
        if verbose:
            print(f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}")

    # Second pass: a single row, run with whichever runtime ``sess`` holds
    # (onnxruntime when use_ort is true, the python evaluator otherwise).
    model_inputs = make_inputs(Y[1:2])
    p1 = knn_imputer.transform(Y[1:2])
    p2 = model.transform(*model_inputs)
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    feeds = make_feeds(model_inputs)
    if verbose:
        # Report the runtime actually used instead of always claiming onnxruntime.
        print(f"{runtime}: running the model...")
    got = sess.run(None, feeds)
    d = max_diff(p1, got[0])
    assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print("done")


# NOTE(review): this block was labelled "This does not work yet", but the output
# recorded below reports both runs passing with the python evaluator; presumably
# the remark concerned another configuration (e.g. ``use_ort=True``) - confirm.
for n_fit_rows, n_query_rows in ((5, 10), (50, 40)):
    validate_onnx(n_fit_rows, n_query_rows, onx)
Torch Discrepancies for size=5 and sizey=10, d={'abs': 2.9802322387695312e-08, 'rel': 4.6411738120561695e-08, 'sum': 7.947285962650597e-08, 'n': 20.0, 'dnan': 0.0}
python: loading the model...
python: running the model...
ONNX Discrepancies for size=5 and sizey=10, d={'abs': 2.9802322387695312e-08, 'rel': 4.6411738120561695e-08, 'sum': 7.947285962650597e-08, 'n': 20.0, 'dnan': 0.0}
onnxruntime: running the model...
done
Torch Discrepancies for size=50 and sizey=40, d={'abs': 1.9868214962137642e-08, 'rel': 2.5721566741681128e-08, 'sum': 3.725290298461914e-08, 'n': 80.0, 'dnan': 0.0}
python: loading the model...
python: running the model...
ONNX Discrepancies for size=50 and sizey=40, d={'abs': 1.9868214962137642e-08, 'rel': 2.5721566741681128e-08, 'sum': 3.725290298461914e-08, 'n': 80.0, 'dnan': 0.0}
onnxruntime: running the model...
done

Total running time of the script: (0 minutes 18.881 seconds)

Related examples

102: Convolution and Matrix Multiplication

102: Convolution and Matrix Multiplication

102: Fuse kernels in a small Llama Model

102: Fuse kernels in a small Llama Model

301: Compares LLAMA exporters

301: Compares LLAMA exporters

101: A custom backend for torch

101: A custom backend for torch

101: Linear Regression and export to ONNX

101: Linear Regression and export to ONNX

Gallery generated by Sphinx-Gallery