201: Use torch to export a scikit-learn model into ONNX

When sklearn-onnx is missing a converter, torch can be used to write it. We use sklearn.impute.KNNImputer as an example. The first step is to rewrite the scikit-learn model with torch functions. The code is then refactored and split into submodules so that the pieces torch.export.export() cannot process can be isolated and bypassed.

torch implementation of nan_euclidean_distances

Let’s start with a simple case, a pairwise distance. See sklearn.metrics.nan_euclidean_distances().
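For reference, here is a minimal, self-contained example of what this scikit-learn function computes (the input values are arbitrary and chosen only for illustration):

import numpy
from sklearn.metrics import nan_euclidean_distances

tiny_X = numpy.array([[0.0, numpy.nan, 2.0], [1.0, 1.0, numpy.nan]])
tiny_Y = numpy.array([[0.0, 0.0, 0.0]])
# distances are computed over the coordinates present in both rows and
# rescaled by n_features / n_present_coordinates
print(nan_euclidean_distances(tiny_X, tiny_Y))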

Module

import contextlib
import io
import logging
import math
import numbers
import warnings
import numpy as np
import onnx
import sklearn
import torch
import onnxruntime
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
from experimental_experiment.reference import ExtendedReferenceEvaluator
from experimental_experiment.skl.helpers import flatnonzero, _get_weights
from experimental_experiment.torch_interpreter import make_undefined_dimension
from onnx_diagnostic.torch_export_patches import torch_export_patches
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
    CustomOpStrategy,
)
from experimental_experiment.xbuilder.reverse_graph_builder import to_graph_builder_code


class NanEuclidean(torch.nn.Module):
    """Implements :func:`sklearn.metrics.nan_euclidean_distances`."""

    def __init__(self, squared=False, copy=True):
        super().__init__()
        self.squared = squared
        self.copy = copy

    def forward(self, X, Y):
        X = X.clone()
        Y = Y.to(X.dtype).clone()

        missing_X = torch.isnan(X)
        missing_Y = torch.isnan(Y)

        # set missing values to zero
        X[missing_X] = 0
        Y[missing_Y] = 0

        # Adjust distances for missing values
        XX = X * X
        YY = Y * Y

        distances = -2 * X @ Y.T + XX.sum(1, keepdim=True) + YY.sum(1, keepdim=True).T

        distances -= XX @ missing_Y.to(X.dtype).T
        distances -= missing_X.to(X.dtype) @ YY.T

        distances = torch.clip(distances, 0, None)

        present_X = 1 - missing_X.to(X.dtype)
        present_Y = ~missing_Y
        present_count = present_X @ present_Y.to(X.dtype).T
        distances[present_count == 0] = torch.nan
        # avoid divide by zero
        present_count = torch.maximum(
            torch.tensor([1], dtype=present_count.dtype), present_count
        )
        distances /= present_count
        distances *= X.shape[1]

        if not self.squared:
            distances = distances.sqrt()

        return distances

Validation

def get_xy(sizex=5, sizey=3, col=3, n_nans=None):
    X = torch.randn((sizex, col))
    Y = torch.randn((sizey, col))
    i_nans = 0
    for i in range(sizex):
        X[i, i % X.shape[1]] = torch.nan
        i_nans += 1
        if n_nans and i_nans >= n_nans:
            break
        X[i, (i + 1) % X.shape[1]] = torch.nan
        i_nans += 1
        if n_nans and i_nans >= n_nans:
            break
    i_nans = 0
    for i in range(sizey):
        Y[(i + 1) % sizey, i % Y.shape[1]] = torch.nan
        i_nans += 1
        if n_nans and i_nans >= n_nans:
            break
        Y[(i + 1) % sizey, (i + 1) % Y.shape[1]] = torch.nan
        i_nans += 1
        if n_nans and i_nans >= n_nans:
            break
    return X, Y


X, Y = get_xy()
model = NanEuclidean()


d1 = sklearn.metrics.nan_euclidean_distances(X.numpy(), Y.numpy())
d2 = model(X, Y)
print(f"discrepancies: {max_diff(d1, d2)}")
discrepancies: {'abs': 0.0, 'rel': 0.0, 'sum': 0.0, 'n': 15.0, 'dnan': 0.0, 'argm': (0, 0)}

torch implementation of KNNImputer

See sklearn.impute.KNNImputer. The code is split into several torch.nn.Module classes and refactored to avoid control flow.
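As a quick reminder of the behaviour we want to reproduce, a minimal scikit-learn usage example (the data is arbitrary, for illustration only):

imputer = sklearn.impute.KNNImputer(n_neighbors=2)
data = np.array(
    [[1.0, 2.0, np.nan], [3.0, 4.0, 3.0], [np.nan, 6.0, 5.0], [8.0, 8.0, 7.0]]
)
# each nan is imputed with the mean of the feature taken from the
# n_neighbors nearest rows, nearest according to the nan_euclidean distance
print(imputer.fit_transform(data))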

Module and sub modules

def _get_mask(X, value_to_mask):
    return (
        torch.isnan(X)
        if (  # sklearn.utils._missing.is_scalar_nan(value_to_mask)
            not isinstance(value_to_mask, numbers.Integral)
            and isinstance(value_to_mask, numbers.Real)
            and math.isnan(value_to_mask)
        )
        else (value_to_mask == X)
    )


class SubWeightMatrix(torch.nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, donors_dist):
        weight_matrix = _get_weights(donors_dist, self.weights)
        if weight_matrix is not None:
            weight_matrix = weight_matrix.clone()
            weight_matrix[torch.isnan(weight_matrix)] = 0.0
        else:
            weight_matrix = torch.ones_like(donors_dist)
            weight_matrix[torch.isnan(donors_dist)] = 0.0
        return weight_matrix


class SubDonorsIdx(torch.nn.Module):
    def forward(self, dist_pot_donors, n_neighbors):
        xn = torch.nan_to_num(dist_pot_donors, nan=1.0e10)
        tk = torch.topk(xn, n_neighbors, dim=1, largest=False, sorted=True)
        return tk.indices, tk.values


class MakeNewWeights(torch.nn.Module):
    def forward(self, donors_mask, donors, weight_matrix):
        return donors_mask.to(donors.dtype) * weight_matrix.to(donors.dtype)


class CalcImpute(torch.nn.Module):
    """Implements :meth:`sklearn.impute.KNNImputer._calc_impute`."""

    def __init__(self, weights):
        super().__init__()
        self._weights = SubWeightMatrix(weights)
        self._donors_idx = SubDonorsIdx()
        self._make_new_neights = MakeNewWeights()

    def _calc_impute(self, dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col):
        donors_idx, donors_dist = self._donors_idx(dist_pot_donors, n_neighbors)
        weight_matrix = self._weights(donors_dist)
        # Retrieve donor values and calculate kNN average
        donors = fit_X_col.take(donors_idx)
        donors_mask = torch.tensor([1], dtype=donors_idx.dtype) - (
            mask_fit_X_col.take(donors_idx)
        ).to(donors_idx.dtype)

        new_weights = self._make_new_neights(donors_mask, donors, weight_matrix)

        weights_sum = new_weights.sum(axis=1, keepdim=True)
        div = torch.where(
            weights_sum == 0, torch.tensor([1], dtype=weights_sum.dtype), weights_sum
        )
        res = (donors * new_weights).sum(axis=1, keepdim=True) / div
        return res.squeeze(dim=1).to(dist_pot_donors.dtype)

    def forward(self, dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col):
        return self._calc_impute(dist_pot_donors, n_neighbors, fit_X_col, mask_fit_X_col)


class ColProcessorAllNan(torch.nn.Module):
    def __init__(self, col: int):
        super().__init__()
        self.col = col

    def forward(
        self,
        X,
        dist_subset,
        mask_fit_X,
        _fit_X,
        receivers_idx,
        all_nan_receivers_idx,
        all_nan_dist_mask,
        dist_chunk,
        dist_idx_map,
        potential_donors_idx,
    ):
        col = self.col
        X = X.clone()
        mask_ = (~mask_fit_X[:, col]).to(_fit_X.dtype)
        mask_sum = mask_.to(X.dtype).sum()

        col_sum = (_fit_X[mask_ == 1, col]).sum().to(X.dtype)
        div = torch.where(mask_sum > 0, mask_sum, torch.tensor([1], dtype=mask_sum.dtype))
        X[all_nan_receivers_idx, col] = col_sum / div

        # receivers with at least one defined distance
        receivers_idx = receivers_idx[~all_nan_dist_mask]
        dist_subset = dist_chunk[dist_idx_map[receivers_idx]][:, potential_donors_idx]
        return X, dist_subset, receivers_idx


class ColProcessorIdentity(torch.nn.Module):
    def forward(
        self,
        X,
        dist_subset,
        mask_fit_X,
        _fit_X,
        receivers_idx,
        all_nan_receivers_idx,
        all_nan_dist_mask,
        dist_chunk,
        dist_idx_map,
        potential_donors_idx,
    ):
        # returning contiguous copies is not efficient, but torch.cond does not
        # accept simply returning the inputs unchanged
        return (
            X.contiguous(),
            dist_subset.contiguous(),
            receivers_idx.contiguous(),
        )


class ColProcessorCond(torch.nn.Module):
    def __init__(self, col: int):
        super().__init__()
        self.col = col
        self._all_nan = ColProcessorAllNan(col)
        self._identity = ColProcessorIdentity()

    def forward(
        self,
        X,
        dist_subset,
        mask_fit_X,
        _fit_X,
        receivers_idx,
        all_nan_receivers_idx,
        all_nan_dist_mask,
        dist_chunk,
        dist_idx_map,
        potential_donors_idx,
    ):
        X, dist_subset, receivers_idx = torch.cond(
            all_nan_receivers_idx.numel() > 0,
            self._all_nan,
            self._identity,
            [
                X,
                dist_subset,
                mask_fit_X,
                _fit_X,
                receivers_idx,
                all_nan_receivers_idx,
                all_nan_dist_mask,
                dist_chunk,
                dist_idx_map,
                potential_donors_idx,
            ],
        )
        return X.contiguous(), dist_subset.contiguous(), receivers_idx.contiguous()


class ColProcessor(torch.nn.Module):
    """Processes one column (= one feature)."""

    def __init__(self, col, n_neighbors, weights):
        super().__init__()
        self._calc_impute = CalcImpute(weights)
        self._col_cond = ColProcessorCond(col)
        self.col = col
        self.n_neighbors = n_neighbors

    def process_one_col(
        self,
        X,
        dist_chunk,
        non_missing_fix_X,
        mask_fit_X,
        dist_idx_map,
        mask,
        row_missing_idx,
        _fit_X,
    ):
        col = self.col
        X = X.clone()
        row_missing_chunk = row_missing_idx
        col_mask = mask[row_missing_chunk, col]

        potential_donors_idx = torch.nonzero(non_missing_fix_X[:, col], as_tuple=True)[0]

        # receivers_idx are indices in X
        receivers_idx = row_missing_chunk[flatnonzero(col_mask)]

        # distances for samples that needed imputation for column
        dist_subset = dist_chunk[dist_idx_map[receivers_idx]][:, potential_donors_idx]

        # receivers with all nan distances impute with mean
        all_nan_dist_mask = torch.isnan(dist_subset).all(axis=1)
        all_nan_receivers_idx = receivers_idx[all_nan_dist_mask]

        # when all_nan_receivers_idx is not empty (training set is small)
        # ... if all_nan_receivers_idx.size > 0:
        #    # onnxruntime does not like this part when it is empty
        #    mask_ = (~mask_fit_X[:, col]).to(_fit_X.dtype)
        #    mask_sum = mask_.to(X.dtype).sum()
        #
        #    col_sum = (_fit_X[mask_ == 1, col]).sum().to(X.dtype)
        #    div = torch.where(mask_sum > 0, mask_sum, torch.tensor([1], dtype=mask_sum.dtype))
        #    X[all_nan_receivers_idx, col] = col_sum / div
        #
        #     # receivers with at least one defined distance
        #     receivers_idx = receivers_idx[~all_nan_dist_mask]
        #     dist_subset = dist_chunk[dist_idx_map[receivers_idx]][:, potential_donors_idx]
        # else
        #     ... identity
        X, dist_subset, receivers_idx = self._col_cond(
            X,
            dist_subset,
            mask_fit_X,
            _fit_X,
            receivers_idx,
            all_nan_receivers_idx,
            all_nan_dist_mask,
            dist_chunk,
            dist_idx_map,
            potential_donors_idx,
        )

        # when all_nan_receivers_idx is not empty (training set is big)
        tn = torch.tensor(self.n_neighbors)
        n_neighbors = torch.where(
            tn < potential_donors_idx.shape[0], tn, potential_donors_idx.shape[0]
        )
        # to make sure n_neighbors > 0
        n_neighbors = torch.where(
            n_neighbors <= 0, torch.tensor([1], dtype=n_neighbors.dtype), n_neighbors
        )
        value = self._calc_impute(
            dist_subset,
            n_neighbors,
            _fit_X[potential_donors_idx, col],
            mask_fit_X[potential_donors_idx, col],
        )
        X[receivers_idx, col] = value.to(X.dtype)
        return X

    def forward(
        self,
        X,
        dist_chunk,
        non_missing_fix_X,
        mask_fit_X,
        dist_idx_map,
        mask,
        row_missing_idx,
        _fit_X,
    ):
        return self.process_one_col(
            X,
            dist_chunk,
            non_missing_fix_X,
            mask_fit_X,
            dist_idx_map,
            mask,
            row_missing_idx,
            _fit_X,
        )


class MakeDictIdxMap(torch.nn.Module):
    def forward(self, X, row_missing_idx):
        dist_idx_map = torch.zeros(X.shape[0], dtype=int)
        dist_idx_map[row_missing_idx] = torch.arange(row_missing_idx.shape[0])
        return dist_idx_map


class TorchKNNImputer(torch.nn.Module):
    def __init__(self, knn_imputer):
        super().__init__()
        assert (
            knn_imputer.metric == "nan_euclidean"
        ), f"Not implemented for metric={knn_imputer.metric!r}"
        self.dist = NanEuclidean()
        cols = []
        for col in range(knn_imputer._fit_X.shape[1]):
            cols.append(ColProcessor(col, knn_imputer.n_neighbors, knn_imputer.weights))
        self.columns = torch.nn.ModuleList(cols)
        # refactoring
        self._make_dict_idx_map = MakeDictIdxMap()
        # knn imputer
        self.missing_values = knn_imputer.missing_values
        self.n_neighbors = knn_imputer.n_neighbors
        self.weights = knn_imputer.weights
        self.metric = knn_imputer.metric
        self.keep_empty_features = knn_imputer.keep_empty_features
        self.add_indicator = knn_imputer.add_indicator
        # results of fitting
        self.indicator_ = knn_imputer.indicator_
        # The training results.
        # self._fit_X = torch.from_numpy(knn_imputer._fit_X.astype(np.float32))
        # self._mask_fit_X = torch.from_numpy(knn_imputer._mask_fit_X)
        # self._valid_mask = torch.from_numpy(knn_imputer._valid_mask)

    def _transform_indicator(self, X):
        if self.add_indicator:
            if not hasattr(self, "indicator_"):
                raise ValueError(
                    "Make sure to call _fit_indicator before _transform_indicator"
                )
            raise NotImplementedError(type(self.indicator_))
            # return self.indicator_.transform(X)
        return None

    def _concatenate_indicator(self, X_imputed, X_indicator):
        if not self.add_indicator:
            return X_imputed
        if X_indicator is None:
            raise ValueError(
                "Data from the missing indicator are not provided. Call "
                "_fit_indicator and _transform_indicator in the imputer "
                "implementation."
            )
        return torch.cat([X_imputed, X_indicator], dim=0)

    def transform(self, mask_fit_X, _valid_mask, _fit_X, X):
        X = X.clone()
        mask = _get_mask(X, self.missing_values)

        X_indicator = self._transform_indicator(mask)

        row_missing_idx = flatnonzero(mask[:, _valid_mask].any(axis=1))
        non_missing_fix_X = torch.logical_not(mask_fit_X)

        # Maps from indices from X to indices in dist matrix
        dist_idx_map = self._make_dict_idx_map(X, row_missing_idx)

        # process in fixed-memory chunks
        pairwise_distances = self.dist(X[row_missing_idx, :], _fit_X)

        # The export unfolds the loop since the number of iterations depends on
        # the number of features, which is fixed in this case.
        for col_processor in self.columns:
            X = col_processor(
                X,
                pairwise_distances,
                non_missing_fix_X,
                mask_fit_X,
                dist_idx_map,
                mask,
                row_missing_idx,
                _fit_X,
            )

        if self.keep_empty_features:
            Xc = X.clone()
            Xc[:, ~_valid_mask] = 0
        else:
            Xc = X[:, _valid_mask]

        return self._concatenate_indicator(Xc, X_indicator)

    def forward(self, _mask_fit_X, _valid_mask, _fit_X, X):
        return self.transform(_mask_fit_X, _valid_mask, _fit_X, X)

Validation

We need to validate the implementation with different training set sizes.

def validate(size, sizey, col=3, n_nans=None):
    X, Y = get_xy(size, sizey, col, n_nans)
    knn_imputer = sklearn.impute.KNNImputer(n_neighbors=3)
    knn_imputer.fit(X)

    model = TorchKNNImputer(knn_imputer)

    p1 = knn_imputer.transform(Y)
    p2 = model.transform(
        torch.from_numpy(knn_imputer._mask_fit_X),
        torch.from_numpy(knn_imputer._valid_mask),
        torch.from_numpy(knn_imputer._fit_X.astype(np.float32)),
        Y,
    )
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    print(f"knn discrepancies for size={size}: {d}")

    p1 = knn_imputer.transform(Y[1:2])
    p2 = model.transform(
        torch.from_numpy(knn_imputer._mask_fit_X),
        torch.from_numpy(knn_imputer._valid_mask),
        torch.from_numpy(knn_imputer._fit_X.astype(np.float32)),
        Y[1:2],
    )
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    print(f"knn discrepancies for size={size}: {d}")
    return knn_imputer, Y


knn5, Y10 = validate(5, 10)
knn50, Y40 = validate(50, 40)
knn1, Y1 = validate(10, 10, n_nans=1)
knn11, Y11 = validate(11, 11, n_nans=1)
knn discrepancies for size=5: {'abs': 1.4901161193847656e-08, 'rel': 3.727023293108136e-08, 'sum': 1.043081283569336e-07, 'n': 30.0, 'dnan': 0.0, 'argm': (0, 0)}
knn discrepancies for size=5: {'abs': 1.4901161193847656e-08, 'rel': 3.727023293108136e-08, 'sum': 1.4901161193847656e-08, 'n': 3.0, 'dnan': 0.0, 'argm': (0, 0)}
knn discrepancies for size=50: {'abs': 9.368009423749157e-09, 'rel': 1.379606193778086e-07, 'sum': 5.808165843348978e-07, 'n': 120.0, 'dnan': 0.0, 'argm': (2, 2)}
knn discrepancies for size=50: {'abs': 7.450580596923828e-09, 'rel': 9.230125399471292e-08, 'sum': 1.2490679233978508e-08, 'n': 3.0, 'dnan': 0.0, 'argm': (0, 1)}
knn discrepancies for size=10: {'abs': 1.9868214962137642e-08, 'rel': 2.948249269874871e-08, 'sum': 1.9868214962137642e-08, 'n': 30.0, 'dnan': 0.0, 'argm': (1, 0)}
knn discrepancies for size=10: {'abs': 1.9868214962137642e-08, 'rel': 2.948249269874871e-08, 'sum': 1.9868214962137642e-08, 'n': 3.0, 'dnan': 0.0, 'argm': (0, 0)}
knn discrepancies for size=11: {'abs': 9.934107481068821e-09, 'rel': 3.110919082332138e-08, 'sum': 9.934107481068821e-09, 'n': 33.0, 'dnan': 0.0, 'argm': (1, 0)}
knn discrepancies for size=11: {'abs': 9.934107481068821e-09, 'rel': 3.110919082332138e-08, 'sum': 9.934107481068821e-09, 'n': 3.0, 'dnan': 0.0, 'argm': (0, 0)}

Export to ONNX

The module cannot be exported as is because the operator torch.topk() expects a fixed number of neighbours while the model makes it variable. This case is not supported by torch.export.export(). We need to isolate that part before exporting the model. It is done by replacing it with a custom op, and this replacement is handled automatically by the function trace_execution_piece_by_piece().
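The problematic pattern, extracted from SubDonorsIdx and ColProcessor.process_one_col above, looks roughly like this (a sketch for illustration only, not part of the model):

def variable_topk(dist_pot_donors, n_neighbors, potential_donors_idx):
    # k is computed from the data: it cannot exceed the number of potential
    # donors and must stay strictly positive
    tn = torch.tensor(n_neighbors)
    k = torch.where(tn < potential_donors_idx.shape[0], tn, potential_donors_idx.shape[0])
    k = torch.where(k <= 0, torch.tensor([1], dtype=k.dtype), k)
    # topk with a data-dependent k works in eager mode but is not supported
    # as is by torch.export.export(), hence the custom op replacement
    return torch.topk(dist_pot_donors, k, dim=1, largest=False, sorted=True)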

As a first step, we create several sets of inputs. A function will use them to infer the dynamic shapes.

First step: tracing intermediate outputs

used = [(knn50, Y40), (knn5, Y10), (knn1, Y1), (knn11, Y11)]
inputs = [
    (
        (
            torch.from_numpy(knn._mask_fit_X),
            torch.from_numpy(knn._valid_mask),
            torch.from_numpy(knn._fit_X.astype(np.float32)),
            y,
        ),
        {},
    )
    for knn, y in used
]

Then we trace the execution to capture every input and output of every submodule. The model implementation was refactored to introduce many tiny submodules and get a fine-grained evaluation of the exportability. We track messages such as -needs-more-inputs or -no-input. If any appears, we must give the tracer more inputs to make sure every submodule receives enough data to guess the dynamic shapes and export. When the model has control flow, we need more data to make sure every branch is exercised.
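The tracing call itself does not appear on this page; a minimal sketch of what it may look like, using the function imported at the beginning (the choice of TorchKNNImputer(knn5) and the verbosity are assumptions):

traced_model = TorchKNNImputer(knn5)
# trace is the object used in the rest of this example
trace = trace_execution_piece_by_piece(traced_model, inputs, verbose=0)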

__main__                  TorchKNNImputer        <OK-4i-0>
..dist                    NanEuclidean           <OK-4i-0>
..columns[0]              ColProcessor           <OK-4i-0>
...._calc_impute          CalcImpute             <OK-4i-2>
......_weights            SubWeightMatrix        <OK-4i-2>
......_donors_idx         SubDonorsIdx           <OK-4i-2>
......_make_new_neights   MakeNewWeights         <OK-4i-2>
...._col_cond             ColProcessorCond       <OK-4i-0>
......_all_nan            ColProcessorAllNan     <OK-2i-0>
......_identity           ColProcessorIdentity   <OK-2i-may-need-more>
..columns[1]              ColProcessor           <OK-4i-0>
...._calc_impute          CalcImpute             <OK-4i-may-need-more>
......_weights            SubWeightMatrix        <OK-4i-may-need-more>
......_donors_idx         SubDonorsIdx           <OK-4i-may-need-more>
......_make_new_neights   MakeNewWeights         <OK-4i-may-need-more>
...._col_cond             ColProcessorCond       <OK-4i-0>
......_all_nan            ColProcessorAllNan     <OK-2i-0>
......_identity           ColProcessorIdentity   <OK-2i-may-need-more>
..columns[2]              ColProcessor           <OK-4i-0>
...._calc_impute          CalcImpute             <OK-4i-may-need-more>
......_weights            SubWeightMatrix        <OK-4i-may-need-more>
......_donors_idx         SubDonorsIdx           <OK-4i-may-need-more>
......_make_new_neights   MakeNewWeights         <OK-4i-may-need-more>
...._col_cond             ColProcessorCond       <OK-4i-0>
......_all_nan            ColProcessorAllNan     <OK-2i-0>
......_identity           ColProcessorIdentity   <OK-2i-may-need-more>
.._make_dict_idx_map      MakeDictIdxMap         <OK-4i-0>

Some submodules need more inputs, so let's add more by rotating the columns of the existing ones.

def rotate(inputs, col=3):
    if isinstance(inputs, torch.Tensor):
        if len(inputs.shape) == 2 and inputs.shape[1] == col:
            return torch.cat([inputs[:, 1:], inputs[:, :1]], axis=1)
        if len(inputs.shape) == 1 and inputs.shape[0] == col:
            return torch.cat([inputs[1:], inputs[:1]], axis=0)
        return inputs
    if isinstance(inputs, tuple):
        return tuple(rotate(i, col=col) for i in inputs)
    if isinstance(inputs, list):
        return [rotate(i, col=col) for i in inputs]
    if isinstance(inputs, dict):
        return {k: rotate(v, col=col) for k, v in inputs.items()}
    raise TypeError(f"Unexpected type {type(inputs)}")


inputs = [*inputs, *rotate(inputs), *rotate(rotate(inputs))]

Let’s try again.
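The tracing is run again on the enlarged list of inputs (same hedged sketch as above):

trace = trace_execution_piece_by_piece(traced_model, inputs, verbose=0)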

---------
__main__                  TorchKNNImputer        <OK-12i-0>
..dist                    NanEuclidean           <OK-12i-0>
..columns[0]              ColProcessor           <OK-12i-0>
...._calc_impute          CalcImpute             <OK-12i-2>
......_weights            SubWeightMatrix        <OK-12i-2>
......_donors_idx         SubDonorsIdx           <OK-12i-2>
......_make_new_neights   MakeNewWeights         <OK-12i-2>
...._col_cond             ColProcessorCond       <OK-12i-0>
......_all_nan            ColProcessorAllNan     <OK-6i-0>
......_identity           ColProcessorIdentity   <OK-6i-may-need-more>
..columns[1]              ColProcessor           <OK-12i-0>
...._calc_impute          CalcImpute             <OK-12i-10>
......_weights            SubWeightMatrix        <OK-12i-10>
......_donors_idx         SubDonorsIdx           <OK-12i-10>
......_make_new_neights   MakeNewWeights         <OK-12i-10>
...._col_cond             ColProcessorCond       <OK-12i-0>
......_all_nan            ColProcessorAllNan     <OK-6i-0>
......_identity           ColProcessorIdentity   <OK-6i-may-need-more>
..columns[2]              ColProcessor           <OK-12i-0>
...._calc_impute          CalcImpute             <OK-12i-6>
......_weights            SubWeightMatrix        <OK-12i-6>
......_donors_idx         SubDonorsIdx           <OK-12i-6>
......_make_new_neights   MakeNewWeights         <OK-12i-6>
...._col_cond             ColProcessorCond       <OK-12i-0>
......_all_nan            ColProcessorAllNan     <OK-6i-0>
......_identity           ColProcessorIdentity   <OK-6i-may-need-more>
.._make_dict_idx_map      MakeDictIdxMap         <OK-12i-0>

The dynamic shapes for the whole model:

print("dynamic shapes:")
print(trace.guess_dynamic_shapes())
dynamic shapes:
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)}, {}, {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)}, {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)}), {})

The method try_export cannot infer every link between input shapes and output shapes for every submodule. The following dictionary of shape functions fills this gap.

shape_functions = {
    "NanEuclidean": {
        0: lambda *args, **kwargs: torch.empty(
            (args[0].shape[0], args[1].shape[0]), dtype=args[0].dtype
        )
    },
    "CalcImpute": {
        0: lambda *args, **kwargs: torch.empty((args[0].shape[0],), dtype=args[0].dtype)
    },
    "SubDonorsIdx": {
        0: lambda *args, **kwargs: torch.empty(
            (
                make_undefined_dimension(111111),  # args[0].shape[0]),
                make_undefined_dimension(min(args[0].shape[1], knn5.n_neighbors)),
            ),
            dtype=args[0].dtype,
        ),
        1: lambda *args, **kwargs: torch.empty(
            (
                make_undefined_dimension(111111),  # args[0].shape[0]),
                make_undefined_dimension(min(args[0].shape[1], knn5.n_neighbors)),
            ),
            dtype=torch.float32,
        ),
    },
    "MakeDictIdxMap": {
        0: lambda *args, **kwargs: torch.empty((args[0].shape[0],), dtype=args[1].dtype),
    },
    "ColProcessorCond": {
        0: lambda *args, **kwargs: torch.empty(args[0], dtype=args[0].dtype),
        1: lambda *args, **kwargs: torch.empty(
            make_undefined_dimension(0), args[1].shape[1], dtype=args[0].dtype
        ),
        2: lambda *args, **kwargs: torch.empty(
            (make_undefined_dimension(0),), dtype=args[0].dtype
        ),
    },
}

Then we try to export piece by piece. We redirect the standard error stream to avoid being overwhelmed by messages, and we use the function onnx_diagnostic.torch_export_patches.torch_export_patches() to skip some errors raised by the shape checking done by torch.

logging.disable(logging.CRITICAL)

with contextlib.redirect_stderr(io.StringIO()), torch_export_patches():
    ep = trace.try_export(
        exporter="fx",
        use_dynamic_shapes=True,
        exporter_kwargs=dict(strict=False),
        replace_by_custom_op=CustomOpStrategy.LOCAL,
        verbose=0,
        shape_functions=shape_functions,
        quiet=1,
    )

assert ep.status in (
    ep.status.OK,
    ep.status.OK_CHILDC,
), f"FAIL: {ep}\n-- report --\n{trace.get_export_report()}"
print(trace.get_export_report())
__main__                  TorchKNNImputer        OK_CHILDC -- ExportedProgram
..dist                    NanEuclidean           OK -- ExportedProgram
..columns[0]              ColProcessor           OK_CHILDC -- ExportedProgram
...._calc_impute          CalcImpute             OK_CHILDC -- ExportedProgram
......_weights            SubWeightMatrix        OK -- ExportedProgram
......_donors_idx         SubDonorsIdx           OK -- ExportedProgram
......_make_new_neights   MakeNewWeights         OK -- ExportedProgram
...._col_cond             ColProcessorCond       FAIL_CHILDC -- step=EXPORT, reason='Dynamo failed to run FX node with fake tensors: call_function cond(*(s2, GraphModule(), GraphModule(...'
......_all_nan            ColProcessorAllNan     OK -- ExportedProgram
......_identity           ColProcessorIdentity   OK -- ExportedProgram
..columns[1]              ColProcessor           OK_CHILDC -- ExportedProgram
...._calc_impute          CalcImpute             OK_CHILDC -- ExportedProgram
......_weights            SubWeightMatrix        OK -- ExportedProgram
......_donors_idx         SubDonorsIdx           OK -- ExportedProgram
......_make_new_neights   MakeNewWeights         OK -- ExportedProgram
...._col_cond             ColProcessorCond       FAIL_CHILDC -- step=EXPORT, reason='Dynamo failed to run FX node with fake tensors: call_function cond(*(s2, GraphModule(), GraphModule(...'
......_all_nan            ColProcessorAllNan     OK -- ExportedProgram
......_identity           ColProcessorIdentity   OK -- ExportedProgram
..columns[2]              ColProcessor           OK_CHILDC -- ExportedProgram
...._calc_impute          CalcImpute             OK_CHILDC -- ExportedProgram
......_weights            SubWeightMatrix        OK -- ExportedProgram
......_donors_idx         SubDonorsIdx           OK -- ExportedProgram
......_make_new_neights   MakeNewWeights         OK -- ExportedProgram
...._col_cond             ColProcessorCond       FAIL_CHILDC -- step=EXPORT, reason='Dynamo failed to run FX node with fake tensors: call_function cond(*(s2, GraphModule(), GraphModule(...'
......_all_nan            ColProcessorAllNan     OK -- ExportedProgram
......_identity           ColProcessorIdentity   OK -- ExportedProgram
.._make_dict_idx_map      MakeDictIdxMap         OK -- ExportedProgram

OK means the module is exportable. OK_CHILDC means the module can be exported once its submodules are replaced by custom ops. Everything works except for the pieces calling torch.cond: the report shows FAIL_CHILDC for ColProcessorCond. FAIL means the submodule cannot be exported at all, but such a module is simple enough that its ONNX conversion can be provided.

Final step

We start by running the decompositions on every exported program.

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for t in trace:
        if t.exporter_status.exported is None:
            print(f"[run_decompositions] {t.dot_name} - skipped")
            continue
        print(f"[run_decompositions] {t.dot_name}")
        t.exporter_status.exported = t.exporter_status.exported.run_decompositions({})
[run_decompositions]  M:__main__-TorchKNNImputer
[run_decompositions] .. M:dist-NanEuclidean
[run_decompositions] .. M:columns[0]-ColProcessor
[run_decompositions] .... M:_calc_impute-CalcImpute
[run_decompositions] ...... M:_weights-SubWeightMatrix
[run_decompositions] ...... M:_donors_idx-SubDonorsIdx
[run_decompositions] ...... M:_make_new_neights-MakeNewWeights
[run_decompositions] .... M:_col_cond-ColProcessorCond - skipped
[run_decompositions] ...... M:_all_nan-ColProcessorAllNan
[run_decompositions] ...... M:_identity-ColProcessorIdentity
[run_decompositions] .. M:columns[1]-ColProcessor
[run_decompositions] .... M:_calc_impute-CalcImpute
[run_decompositions] ...... M:_weights-SubWeightMatrix
[run_decompositions] ...... M:_donors_idx-SubDonorsIdx
[run_decompositions] ...... M:_make_new_neights-MakeNewWeights
[run_decompositions] .... M:_col_cond-ColProcessorCond - skipped
[run_decompositions] ...... M:_all_nan-ColProcessorAllNan
[run_decompositions] ...... M:_identity-ColProcessorIdentity
[run_decompositions] .. M:columns[2]-ColProcessor
[run_decompositions] .... M:_calc_impute-CalcImpute
[run_decompositions] ...... M:_weights-SubWeightMatrix
[run_decompositions] ...... M:_donors_idx-SubDonorsIdx
[run_decompositions] ...... M:_make_new_neights-MakeNewWeights
[run_decompositions] .... M:_col_cond-ColProcessorCond - skipped
[run_decompositions] ...... M:_all_nan-ColProcessorAllNan
[run_decompositions] ...... M:_identity-ColProcessorIdentity
[run_decompositions] .. M:_make_dict_idx_map-MakeDictIdxMap

Let's run the conversion. We also check that the conversion into ONNX is accurate. This is possible because every intermediate result was traced earlier.

try:
    onx = trace.to_onnx_local(
        verbose=1,
        check_conversion_cls=dict(cls=ExtendedReferenceEvaluator, atol=1e-5, rtol=1e-5),
        inline=False,
    )
except Exception as e:
    print(f"The example is broken: {e}")
    onx = None
[to_onnx_local]  M:__main__-TorchKNNImputer - to_onnx_local
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer_dist'
[to_onnx_local] .. M:dist-NanEuclidean - to_onnx_local
[to_onnx_local] .. M:dist-NanEuclidean - export starts C_TorchKNNImputer_dist
[to_onnx_local] .. M:dist-NanEuclidean - export done
[to_onnx_local] .. M:dist-NanEuclidean - run validation
[onnx_run_disc] .. M:dist-NanEuclidean run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s40x3,T1s50x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s40x3[nan,nan:AnanN80nans],T1s50x3[nan,nan:AnanN100nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s40x50[nan,nan:AnanN1333nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s40x50[0.0014024811098352075,7.827565670013428:A1.572636351330908N1333nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s10x3,T1s5x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s10x3[nan,nan:AnanN20nans],T1s5x3[nan,nan:AnanN10nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s10x5[nan,nan:AnanN33nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s10x5[0.0014024811098352075,4.854142665863037:A1.9105679926750085N33nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s10x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s10x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x10[1.891329288482666,3.3306772708892822:A2.516362464427948],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x10[1.891329288482666,3.3306772708892822:A2.516362464427948]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s11x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s11x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s40x3,T1s50x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s40x3[nan,nan:AnanN80nans],T1s50x3[nan,nan:AnanN100nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s40x50[nan,nan:AnanN1333nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s40x50[0.0014024811098352075,7.827565670013428:A1.572636351330908N1333nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s10x3,T1s5x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s10x3[nan,nan:AnanN20nans],T1s5x3[nan,nan:AnanN10nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s10x5[nan,nan:AnanN33nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s10x5[0.0014024811098352075,4.854142665863037:A1.9105679926750085N33nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s10x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s10x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x10[1.891329288482666,3.3306772708892822:A2.5163624405860903],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x10[1.891329288482666,3.3306772708892822:A2.5163624405860903]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s11x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s11x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s40x3,T1s50x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s40x3[nan,nan:AnanN80nans],T1s50x3[nan,nan:AnanN100nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s40x50[nan,nan:AnanN1333nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s40x50[0.0014024811098352075,7.827565670013428:A1.572636351330908N1333nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s10x3,T1s5x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s10x3[nan,nan:AnanN20nans],T1s5x3[nan,nan:AnanN10nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s10x5[nan,nan:AnanN33nans],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s10x5[0.0014024811098352075,4.854142665863037:A1.9105679926750085N33nans]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s10x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s10x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x10[1.891329288482666,3.330677032470703:A2.5163624405860903],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x10[1.891329288482666,3.330677032470703:A2.5163624405860903]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean run with ((T1s1x3,T1s11x3),{})
[onnx_run_disc] .. M:dist-NanEuclidean flattened into ((T1s1x3[nan,nan:AnanN1nans],T1s11x3[nan,nan:AnanN1nans]),{})
[onnx_run_disc] .. M:dist-NanEuclidean expecting (T1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822],)
[onnx_run_disc] .. M:dist-NanEuclidean computing A1s1x11[0.6446343064308167,3.4043798446655273:A1.950129048390822]
[onnx_run_disc] .. M:dist-NanEuclidean diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] .. M:dist-NanEuclidean validation done
[to_onnx_local] .. M:dist-NanEuclidean - done
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] .. M:dist-NanEuclidean - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local]  M:__main__-TorchKNNImputer - export child 'C_TorchKNNImputer_columns_0_'
[to_onnx_local] .. M:columns[0]-ColProcessor - to_onnx_local
[to_onnx_local] .. M:columns[0]-ColProcessor - export child 'C_TorchKNNImputer_columns_0___calc_impute'
[to_onnx_local] .... M:_calc_impute-CalcImpute - to_onnx_local
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__weights'
[to_onnx_local] ...... M:_weights-SubWeightMatrix - to_onnx_local
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export starts C_TorchKNNImputer_columns_0___calc_impute__weights
[to_onnx_local] ...... M:_weights-SubWeightMatrix - export done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - run validation
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x2,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x2[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x2[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x2[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s1x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s1x3[1.960836410522461,2.320211172103882:A2.0816839933395386],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s1x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s1x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s1x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s1x3[0.6446343064308167,1.6473352909088135:A1.0984085599581401],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s1x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s1x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x1,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x1[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x1[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x1[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x2,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x2[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x2[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x2[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix run with ((T1s0x3,),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix flattened into ((T1s0x3[empty],),{})
[onnx_run_disc] ...... M:_weights-SubWeightMatrix expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_weights-SubWeightMatrix computing A1s0x3[empty]
[onnx_run_disc] ...... M:_weights-SubWeightMatrix diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_weights-SubWeightMatrix validation done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - done
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_weights-SubWeightMatrix - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__donors_idx'
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - to_onnx_local
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export starts C_TorchKNNImputer_columns_0___calc_impute__donors_idx
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - export done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - run validation
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x17,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x17[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x2,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x2[empty],T7s1[2,2:A2.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x2[empty],T1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x2[empty],A1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s1x9,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s1x9[1.960836410522461,3.3306772708892822:A2.585810595088535],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s1x3[0,7:A3.6666666666666665],T1s1x3[1.960836410522461,2.320211172103882:A2.0816839933395386])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s1x3[0,7:A3.6666666666666665],A1s1x3[1.960836410522461,2.320211172103882:A2.0816839933395386])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s1x10,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s1x10[0.6446343064308167,3.4043798446655273:A1.947733038663864],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s1x3[1,9:A5.333333333333333],T1s1x3[0.6446343064308167,1.6473352909088135:A1.0984085599581401])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s1x3[1,9:A5.333333333333333],A1s1x3[0.6446343064308167,1.6473352909088135:A1.0984085599581401])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x16,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x16[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x1,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x1[empty],T7s1[1,1:A1.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x1[empty],T1s0x1[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x1[empty],A1s0x1[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x10,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x10[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x11,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x11[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x17,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x17[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x2,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x2[empty],T7s1[2,2:A2.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x2[empty],T1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x2[empty],A1s0x2[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x10,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x10[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx run with ((T1s0x11,T7s1),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx flattened into ((T1s0x11[empty],T7s1[3,3:A3.0]),{})
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx expecting (T7s0x3[empty],T1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx computing (A7s0x3[empty],A1s0x3[empty])
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx diff=abs=0, rel=0
[onnx_run_disc] ...... M:_donors_idx-SubDonorsIdx validation done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - done
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] ...... M:_donors_idx-SubDonorsIdx - discrepancies: abs=0, rel=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - export child 'C_TorchKNNImputer_columns_0___calc_impute__make_new_neights'
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - to_onnx_local
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export starts C_TorchKNNImputer_columns_0___calc_impute__make_new_neights
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - export done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - run validation
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x2,T1s0x2,T1s0x2),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x2[empty],T1s0x2[empty],T1s0x2[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x2[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x2[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s1x3,T1s1x3,T1s1x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s1x3[1,1:A1.0],T1s1x3[0.2025328278541565,0.9459595680236816:A0.6728987495104471],T1s1x3[1.0,1.0:A1.0]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s1x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s1x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s1x3,T1s1x3,T1s1x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s1x3[1,1:A1.0],T1s1x3[-0.7442224621772766,0.23930585384368896:A-0.3183303078015645],T1s1x3[1.0,1.0:A1.0]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s1x3[1.0,1.0:A1.0],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s1x3[1.0,1.0:A1.0]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0.0, rel=0.0,amax=0,0
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x1,T1s0x1,T1s0x1),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x1[empty],T1s0x1[empty],T1s0x1[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x1[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x1[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x2,T1s0x2,T1s0x2),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x2[empty],T1s0x2[empty],T1s0x2[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x2[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x2[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights run with ((T7s0x3,T1s0x3,T1s0x3),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights flattened into ((T7s0x3[empty],T1s0x3[empty],T1s0x3[empty]),{})
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights expecting (T1s0x3[empty],)
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights computing A1s0x3[empty]
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights diff=abs=0, rel=0,amax=None
[onnx_run_disc] ...... M:_make_new_neights-MakeNewWeights validation done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - done
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0.0, rel=0.0,amax=0,0
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] ...... M:_make_new_neights-MakeNewWeights - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - export starts C_TorchKNNImputer_columns_0___calc_impute
[to_onnx_local] .... M:_calc_impute-CalcImpute - export done
[to_onnx_local] .... M:_calc_impute-CalcImpute - run validation
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with cls=ExtendedReferenceEvaluator on ModelProto
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x17,T7s1,T1s17,T9s17),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x17[empty],T7s1[3,3:A3.0],T1s17[-1.0490533113479614,1.4862864017486572:A0.05360487717039445],T9s17[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x2,T7s1,T1s2,T9s2),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x2[empty],T7s1[2,2:A2.0],T1s2[-1.0841875076293945,0.28655949234962463:A-0.39881400763988495],T9s2[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s1x9,T7s1,T1s9,T9s9),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s1x9[1.960836410522461,3.3306772708892822:A2.585810595088535],T7s1[3,3:A3.0],T1s9[-2.2622437477111816,0.9459595680236816:A-0.2179878850777944],T9s9[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s1[0.6728987693786621,0.6728987693786621:A0.6728987693786621],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s1[0.6728987693786621,0.6728987693786621:A0.6728987693786621]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0.0, rel=0.0,amax=0
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s1x10,T7s1,T1s10,T9s10),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s1x10[0.6446343064308167,3.4043798446655273:A1.947733038663864],T7s1[3,3:A3.0],T1s10[-1.0289477109909058,1.3091744184494019:A-0.06142359673976898],T9s10[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s1[-0.318330317735672,-0.318330317735672:A-0.318330317735672],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s1[-0.318330317735672,-0.318330317735672:A-0.318330317735672]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0.0, rel=0.0,amax=0
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x16,T7s1,T1s16,T9s16),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x16[empty],T7s1[3,3:A3.0],T1s16[-1.5512176752090454,1.0594156980514526:A-0.16428888589143753],T9s16[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x1,T7s1,T1s1,T9s1),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x1[empty],T7s1[1,1:A1.0],T1s1[0.5527158379554749,0.5527158379554749:A0.5527158379554749],T9s1[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x10,T7s1,T1s10,T9s10),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x10[empty],T7s1[3,3:A3.0],T1s10[-1.9389777183532715,1.426966905593872:A-0.5683467619121074],T9s10[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x11,T7s1,T1s11,T9s11),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x11[empty],T7s1[3,3:A3.0],T1s11[-1.745384693145752,1.8026071786880493:A0.23856408157470552],T9s11[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x17,T7s1,T1s17,T9s17),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x17[empty],T7s1[3,3:A3.0],T1s17[-1.9769785404205322,1.096831202507019:A-0.06690350366646752],T9s17[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x2,T7s1,T1s2,T9s2),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x2[empty],T7s1[2,2:A2.0],T1s2[-1.8390218019485474,0.472423791885376:A-0.6832990050315857],T9s2[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x10,T7s1,T1s10,T9s10),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x10[empty],T7s1[3,3:A3.0],T1s10[-1.918097734451294,2.5004260540008545:A-0.2677750276401639],T9s10[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute run with ((T1s0x11,T7s1,T1s11,T9s11),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute flattened into ((T1s0x11[empty],T7s1[3,3:A3.0],T1s11[-0.9313030242919922,0.7051165103912354:A-0.03593648805029013],T9s11[False,False:A0.0]),{})
[onnx_run_disc] .... M:_calc_impute-CalcImpute expecting (T1s0[empty],)
[onnx_run_disc] .... M:_calc_impute-CalcImpute computing A1s0[empty]
[onnx_run_disc] .... M:_calc_impute-CalcImpute diff=abs=0, rel=0,amax=None
[onnx_run_disc] .... M:_calc_impute-CalcImpute validation done
[to_onnx_local] .... M:_calc_impute-CalcImpute - done
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0.0, rel=0.0,amax=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0.0, rel=0.0,amax=0
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .... M:_calc_impute-CalcImpute - discrepancies: abs=0, rel=0,amax=None
[to_onnx_local] .. M:columns[0]-ColProcessor - export child 'C_TorchKNNImputer_columns_0___col_cond'
The example is broken: _col_cond:ColProcessorCond: exporter failed, status=<StatusExportCode.FAIL_CHILDC: 6>, reason='Dynamo failed to run FX node with fake tensors: call_function cond(*(s2, GraphModule(), GraphModule(), (FakeTensor(..., size=(s85, 3)), FakeTensor(..., size=(s57, s30)), FakeTensor(..., size=(s84, 3), dtype=torch.bool), FakeTensor(..., size=(s52, 3)), FakeTensor(..., size=(s94,), dtype=torch.int64), FakeTensor(..., size=(s29,), dtype=torch.int64), FakeTensor(..., size=(s72,), dtype=torch.bool), FakeTensor(..., size=(s79, s54)), FakeTensor(..., size=(s31,), dtype=torch.int64), FakeTensor(..., size=(s96,), dtype=torch.int64))), **{}): got AssertionError((0, 1)) ---  --- from user code: ---    File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 184, in _cond_op_wrapper ---     return cond_op(*args, **kwargs) ---  --- Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you\'re reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" --- [\'Traceback (most recent call last):\\n\', \'  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/piece_by_piece.py", line 1587, in _try_export_no_bypass_export\\n    ep = torch.export.export(\\n         ^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export\\n    raise e\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export\\n    return _export(\\n           ^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper\\n    raise e\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper\\n    ep = fn(*args, **kwargs)\\n         ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export\\n    ep = _export_for_training(\\n         ^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper\\n    raise e\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper\\n    ep = fn(*args, **kwargs)\\n         ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training\\n    export_artifact = export_func(\\n                      ^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1975, in _non_strict_export\\n    aten_export_artifact = _to_aten_func(  # type: ignore[operator]\\n                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx\\n    gm, graph_signature = transform(_make_fx_helper)(\\n                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1901, in _aot_export_non_strict\\n    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, 
**flags)\\n              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1679, in _make_fx_helper\\n    gm = make_fx(\\n         ^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped\\n    return make_fx_tracer.trace(f, *args)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2228, in trace\\n    return self._trace_inner(f, *args)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner\\n    t = dispatch_trace(\\n        ^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner\\n    return disable_fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace\\n    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]\\n            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1787, in trace\\n    res = super().trace(root, concrete_args)\\n          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace\\n    (self.create_arg(fn(*args)),),\\n                     ^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped\\n    out = f(*tensors)  # type:ignore[call-arg]\\n          ^^^^^^^^^^^\\n\', \'  File "<string>", line 1, in <lambda>\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1583, in wrapped_fn\\n    return tuple(flat_fn(*args))\\n                 ^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn\\n    tree_out = fn(*args, **kwargs)\\n               ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call\\n    out = mod(*args[params_len:], **kwargs)\\n          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\\n    return self.call_module(mod, forward, args, kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\\n    return Tracer.call_module(self, m, forward, args, kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\\n    ret_val = forward(*args, **kwargs)\\n              ^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\\n    return _orig_module_call(mod, *args, **kwargs)\\n           
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\\n    return self._call_impl(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\\n    return forward_call(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1885, in forward\\n    tree_out = mod(*args, **kwargs)\\n               ^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper\\n    return self.call_module(mod, forward, args, kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1857, in call_module\\n    return Tracer.call_module(self, m, forward, args, kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module\\n    ret_val = forward(*args, **kwargs)\\n              ^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward\\n    return _orig_module_call(mod, *args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl\\n    return self._call_impl(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl\\n    return forward_call(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/github/experimental-experiment/_doc/examples/plot_torch_sklearn_201.py", line 293, in forward\\n    X, dist_subset, receivers_idx = torch.cond(\\n                                    ^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 192, in cond\\n    return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)(\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 699, in _fn\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1463, in __call__\\n    return self._torchdynamo_orig_callable(\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 624, in __call__\\n    return _compile(\\n           ^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1087, in _compile\\n    guarded_code = compile_inner(code, one_graph, hooks, transform)\\n                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function\\n    return function(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 778, in compile_inner\\n    return _compile_inner(code, 
one_graph, hooks, transform)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile_inner\\n    out_code = transform_code_object(code, transform)\\n               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1423, in transform_code_object\\n    transformations(instructions, code_options)\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 264, in _fn\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 742, in transform\\n    tracer.run()\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3508, in run\\n    super().run()\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1345, in run\\n    while self.step():\\n          ^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1253, in step\\n    self.dispatch_table[inst.opcode](self, inst)\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 828, in wrapper\\n    return inner_fn(self, inst)\\n           ^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2254, in CALL_FUNCTION_EX\\n    self.call_function(fn, argsvars.items, kwargsvars)\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1179, in call_function\\n    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]\\n              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward\\n    return getattr(self.realize(), name)(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 75, in graph_break_as_hard_error\\n    return fn(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1129, in call_function\\n    return _call_function_and_unflatten_output(\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 210, in _call_function_and_unflatten_output\\n    flat_variable = wrap_fx_proxy(\\n                    ^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2490, in wrap_fx_proxy\\n    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2556, in wrap_fx_proxy_cls\\n    return _wrap_fx_proxy(\\n           ^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2654, in _wrap_fx_proxy\\n    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)\\n                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File 
"~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3302, in get_fake_value\\n    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3200, in get_fake_value\\n    ret_val = wrap_fake_exception(\\n              ^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 2700, in wrap_fake_exception\\n    return fn()\\n           ^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3201, in <lambda>\\n    lambda: run_node(tx.output, node, args, kwargs, nnmodule)\\n            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3409, in run_node\\n    raise RuntimeError(make_error_message(e)).with_traceback(\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 3368, in run_node\\n    return node.target(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 59, in __call__\\n    return super().__call__(pred, true_fn, false_fn, operands)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__\\n    return wrapper()\\n           ^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 497, in wrapper\\n    return self.dispatch(\\n           ^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 485, in dispatch\\n    return kernel(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 320, in maybe_run_autograd\\n    return self(*args, **kwargs)\\n           ^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 59, in __call__\\n    return super().__call__(pred, true_fn, false_fn, operands)\\n           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__\\n    return wrapper()\\n           ^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 497, in wrapper\\n    return self.dispatch(\\n           ^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 393, in dispatch\\n    result = handler(mode, *args, **kwargs)\\n             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 432, in cond_fake_tensor_mode\\n    merged_outs.append(_merge_tensors(true_out, false_out, mode))\\n                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 656, in _merge_tensors\\n    merged_stride: list[Union[int, torch.SymInt]] = _bound_stride(\\n                                                    ^^^^^^^^^^^^^^\\n\', \'  File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 624, in _bound_stride\\n    assert b_val == 0, (a_val, b_val)\\n           ^^^^^^^^^^\\n\', \'torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function cond(*(s2, GraphModule(), 
GraphModule(), (FakeTensor(..., size=(s85, 3)), FakeTensor(..., size=(s57, s30)), FakeTensor(..., size=(s84, 3), dtype=torch.bool), FakeTensor(..., size=(s52, 3)), FakeTensor(..., size=(s94,), dtype=torch.int64), FakeTensor(..., size=(s29,), dtype=torch.int64), FakeTensor(..., size=(s72,), dtype=torch.bool), FakeTensor(..., size=(s79, s54)), FakeTensor(..., size=(s31,), dtype=torch.int64), FakeTensor(..., size=(s96,), dtype=torch.int64))), **{}): got AssertionError((0, 1))\\n\\nfrom user code:\\n   File "~/vv/this312/lib/python3.12/site-packages/torch/_higher_order_ops/cond.py", line 184, in _cond_op_wrapper\\n    return cond_op(*args, **kwargs)\\n\\nSet TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you\\\'re reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"\\n\\n\']', a custom onnx converter must be provided for 'diag_lib::C_TorchKNNImputer_columns_0___col_cond', args=(T1s40x3,T1s27x17,T9s50x3,T1s50x3,T7s27,T7s27,T9s27,T1s40x50,T7s40,T7s17), kwargs={}, outputs=(T1s40x3,T1s0x17,T7s0)
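
The export stops inside torch.cond: when the exporter merges the fake tensors returned by the two branches, their shapes and strides cannot be reconciled (the AssertionError in _bound_stride above), so a custom ONNX converter must be provided for 'diag_lib::C_TorchKNNImputer_columns_0___col_cond'. The snippet below is a minimal, hypothetical sketch of the kind of pattern involved, not the code of this example: one branch returns an empty slice while the other returns the full input, so the two branches produce outputs with different first dimensions. The class name CondOnEmptySlice and the shapes are made up for illustration.

class CondOnEmptySlice(torch.nn.Module):
    # Hypothetical reproduction: the two branches of torch.cond return
    # tensors whose first dimension differs (possibly zero rows).
    def forward(self, x):
        def true_fn(x):
            return x[:0]  # empty slice, shape (0, C)

        def false_fn(x):
            return x.clone()  # full input, shape (N, C)

        return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))


try:
    ep = torch.export.export(
        CondOnEmptySlice(),
        (torch.randn(5, 3),),
        dynamic_shapes=({0: torch.export.Dim("batch")},),
    )
    print(ep)
except Exception as e:
    print("export failed:", e)

As the message above suggests, the fallback here is to keep this submodule as a local function and register a custom converter for it.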

Let’s save it.

if onx:
    onnx.save(onx, "plot_torch_sklearn_201.onnx")

We can also print it.

if onx:
    print(pretty_onnx(onx))

Validation again

def validate_onnx(size, sizey, onx, verbose: int = 1, use_ort: bool = False, col: int = 3):
    X, Y = get_xy(size, sizey, col=col)

    knn_imputer = sklearn.impute.KNNImputer(n_neighbors=3)
    knn_imputer.fit(X)

    model = TorchKNNImputer(knn_imputer)

    expected = p1 = knn_imputer.transform(Y)

    model_inputs = (
        torch.from_numpy(knn_imputer._mask_fit_X),
        torch.from_numpy(knn_imputer._valid_mask),
        torch.from_numpy(knn_imputer._fit_X.astype(np.float32)),
        Y,
    )
    p2 = model.transform(*model_inputs)
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print(f"Torch Discrepancies for size={size} and sizey={sizey}, d={d}")

    input_names = [i.name for i in onx.graph.input]
    feeds = feeds0 = dict(zip(input_names, [t.numpy() for t in model_inputs]))

    if verbose:
        print("python: loading the model...")
    sess = ExtendedReferenceEvaluator(onx, verbose=0)
    if verbose:
        print("python: running the model...")
    got = sess.run(None, feeds)
    d = max_diff(p1, got[0])
    assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print(f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}")

    if use_ort:
        if verbose:
            print("onnxruntime: loading the model...")
        opts = onnxruntime.SessionOptions()
        opts.optimized_model_filepath = "plot_torch_sklearn_201.ort.onnx"
        opts.log_severity_level = 0
        opts.log_verbosity_level = 0
        sess = onnxruntime.InferenceSession(
            onx.SerializeToString(), opts, providers=["CPUExecutionProvider"]
        )
        if verbose:
            print("onnxruntime: running the model...")
        got = sess.run(None, feeds)
        d = max_diff(p1, got[0])
        assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
        if verbose:
            print(f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}")

    model_inputs = (
        torch.from_numpy(knn_imputer._mask_fit_X),
        torch.from_numpy(knn_imputer._valid_mask),
        torch.from_numpy(knn_imputer._fit_X.astype(np.float32)),
        Y[1:2],
    )
    p1 = knn_imputer.transform(Y[1:2])
    p2 = model.transform(*model_inputs)
    d = max_diff(p1, p2)
    assert d["abs"] < 1e-5, f"Discrepancies for size={size} and sizey={sizey}, d={d}"
    feeds = dict(zip(input_names, [t.numpy() for t in model_inputs]))
    if verbose:
        print("ReferenceEvaluator: running the model...")
    got = sess.run(None, feeds)
    d = max_diff(p1, got[0])
    assert d["abs"] < 1e-5, f"ONNX Discrepancies for size={size} and sizey={sizey}, d={d}"
    if verbose:
        print("done")
    return feeds0, expected

This does not work yet: the export failed above, so the validation below only runs if an ONNX model was actually produced.

if onx:
    feeds, expected = validate_onnx(5, 10, onx)
    validate_onnx(50, 40, onx)

ModelProto to Python code

We finally call the function to_graph_builder_code to convert the ONNX model into pseudo-code; this may help move the implementation into a converter library such as sklearn-onnx.

if onx:
    code = to_graph_builder_code(onx)
    addition = (
        f"""

    feeds = {feeds!r}
    expected = {expected!r}
    ref = ExtendedReferenceEvaluator(model)
    got = ref.run(None, feeds)
    print("disrepancies:", max_diff(expected, got[0]))
    """.replace(
            "nan", "np.nan"
        )
        .replace("array", "np.array")
        .replace("float32", "np.float32")
    )
    code = f"""
    from experimental_experiment.reference import ExtendedReferenceEvaluator
    from experimental_experiment.helpers import max_diff
    {code}
    {addition}
    """
    print(code)

Let’s finally check it produces the same results.

if onx:
    with open("_plot_torch_sklearn_201_knnpy.py", "w") as f:
        f.write(code)

Let’s run it. It can be executed this way, assuming the generated file was written in the previous step.

import subprocess  # not imported above; needed to launch the generated script
import sys

if onx:
    subprocess.run([sys.executable, "_plot_torch_sklearn_201_knnpy.py"])

Total running time of the script: (0 minutes 24.578 seconds)

Related examples

102: Convolution and Matrix Multiplication

102: Fuse kernels in a small Llama Model

301: Compares LLAMA exporters

Playground for big optimization pattern

101: A custom backend for torch

Gallery generated by Sphinx-Gallery