import time
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import numpy as np
from onnx import ModelProto, save_model
from onnx.model_container import ModelContainer
import sklearn
from ..xbuilder import GraphBuilder, FunctionOptions, OptimizationOptions
# [docs] (Sphinx link artifact from documentation scrape; kept as a comment — as code it raised NameError)
def to_onnx(
    model: sklearn.base.BaseEstimator,
    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    as_function: bool = False,
    options: Optional[OptimizationOptions] = None,
    optimize: bool = True,
    filename: Optional[str] = None,
    inline: bool = False,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[List[str]] = None,
    large_model: bool = False,
    verbose: int = 0,
    return_builder: bool = False,
    raise_list: Optional[Set[str]] = None,
    external_threshold: int = 1024,
    return_optimize_report: bool = False,
    function_options: Optional[FunctionOptions] = None,
) -> Union[
    Union[ModelProto, ModelContainer],
    Tuple[Union[ModelProto, ModelContainer], GraphBuilder],
]:
    """
    Exports a :epkg:`scikit-learn` model into ONNX.

    :param model: estimator
    :param args: input arguments, when not given, a random float32 sample of
        shape ``(2, model.n_features_in_)`` is generated
    :param target_opset: targeted opset or targeted opsets as a dictionary
    :param as_function: export as a ModelProto or a FunctionProto
    :param options: optimization options
    :param optimize: optimize the model before exporting into onnx
    :param filename: if specified, stores the model into that file
    :param inline: inline the model before converting to onnx, this is done before
        any optimization takes place
    :param input_names: input names
    :param output_names: to rename the output names
    :param large_model: if True returns a :class:`onnx.model_container.ModelContainer`,
        it lets the user decide later if the weights should be part of the model
        or saved as external weights
    :param verbose: verbosity level
    :param return_builder: returns the builder as well
    :param raise_list: the builder stops any time a name falls into that list,
        this is a debugging tool
    :param external_threshold: if large_model is True, every tensor above this limit
        is stored as external
    :param return_optimize_report: returns statistics on the optimization as well
    :param function_options: to specify what to do with the initializers in local functions,
        add them as constants or inputs
    :return: onnx model
    """
    assert isinstance(model, sklearn.base.BaseEstimator), f"Unexpected model type {type(model)}"
    # Imported locally to avoid a hard dependency on skl2onnx at module import time.
    import skl2onnx

    if output_names is None and hasattr(model, "get_feature_names_out"):
        # get_feature_names_out returns an ndarray of strings,
        # convert it to match the declared Optional[List[str]] interface.
        output_names = list(model.get_feature_names_out())
    if args is None:
        if not hasattr(model, "n_features_in_"):
            raise NotImplementedError(
                f"Unable to guess the number of input features for model type {type(model)}"
            )
        # Wrap the generated sample in a tuple so that args[0] below is the full
        # 2D sample, consistent with a user-supplied sequence of tensors.
        args = (np.random.randn(2, model.n_features_in_).astype(np.float32),)
    if isinstance(
        model,
        (
            sklearn.pipeline.Pipeline,
            sklearn.pipeline.FeatureUnion,
            sklearn.compose.ColumnTransformer,
            sklearn.compose.TransformedTargetRegressor,
        ),
    ):
        raise NotImplementedError(f"not implemented yet for {type(model)}")

    add_stats = {}
    begin = time.perf_counter()
    if verbose:
        print(f"[skl.to_onnx] convert {model.__class__.__name__}")
    # First conversion is delegated to skl2onnx; zipmap is disabled for
    # classifiers so probabilities stay a plain tensor output.
    proto = skl2onnx.to_onnx(
        model,
        args[0],
        target_opset=target_opset,
        options={"zipmap": False} if sklearn.base.is_classifier(model) else None,
        verbose=max(verbose - 1, 0),
    )
    t = time.perf_counter()
    add_stats["time_export"] = t - begin
    add_stats[f"time_export_{model.__class__.__name__}"] = t - begin
    begin = t

    if verbose:
        print(f"[skl.to_onnx] builds {model.__class__.__name__}")
    # Rebuild the proto through GraphBuilder so the optimization passes,
    # renaming and large-model handling can be applied.
    builder = GraphBuilder(
        target_opset_or_existing_proto=proto,
        as_function=as_function,
        optimization_options=options,
        args=args,
        kwargs=None,
        verbose=verbose,
        raise_list=raise_list,
        graph_module=model,
        output_names=output_names,
    )
    if input_names:
        renames = dict(zip(builder.input_names, input_names))
        if verbose:
            print(f"[skl.to_onnx] renames {renames}")
        builder.rename_names(renames)
    t = time.perf_counter()
    add_stats["time_builder"] = t - begin
    add_stats[f"time_builder_{model.__class__.__name__}"] = t - begin
    begin = t

    if verbose:
        print(f"[skl.to_onnx] make_proto for {model.__class__.__name__}")
    onx, stats = builder.to_onnx(
        optimize=optimize,
        large_model=large_model,
        external_threshold=external_threshold,
        return_optimize_report=True,
        inline=inline,
        function_options=function_options,
    )
    t = time.perf_counter()
    add_stats["time_builder_to_onnx"] = t - begin
    add_stats[f"time_builder_to_onnx_{model.__class__.__name__}"] = t - begin
    begin = time.perf_counter()

    if verbose:
        print(f"[skl.to_onnx] done {model.__class__.__name__}")
    all_stats = dict(builder=builder.statistics_)
    if stats:
        add_stats["optimization"] = stats
    t = time.perf_counter()
    add_stats["time_export_to_onnx"] = t - begin
    if verbose:
        # A ModelContainer wraps the ModelProto; unwrap it for the summary line.
        proto = onx if isinstance(onx, ModelProto) else onx.model_proto
        print(
            f"[to_onnx] to_onnx done in {t - begin}s "
            f"and {len(proto.graph.node)} nodes, "
            f"{len(proto.graph.initializer)} initializers, "
            f"{len(proto.graph.input)} inputs, "
            f"{len(proto.graph.output)} outputs"
        )
        if verbose >= 10:
            print(builder.get_debug_msg())
    if filename:
        # A ModelContainer saves itself so external weights end up in one file.
        if isinstance(onx, ModelProto):
            save_model(onx, filename)
        else:
            onx.save(filename, all_tensors_to_one_file=True)
    all_stats.update(add_stats)
    if return_builder:
        return (onx, builder, all_stats) if return_optimize_report else (onx, builder)
    return (onx, all_stats) if return_optimize_report else onx