import datetime
import inspect
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time
import numpy as np
import onnx
import onnxscript
import onnxscript.rewriter.ort_fusions as ort_fusions
import torch
from ..export import CoupleInputsDynamicShapes
from ..helpers import max_diff, string_type, string_diff
from ..helpers.helper import flatten_object
from ..helpers.rt_helper import make_feeds
from ..helpers.torch_helper import to_any, torch_deepcopy
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from ..tasks import random_input_kwargs
from ..torch_export_patches import torch_export_patches
from ..torch_export_patches.patch_inputs import use_dyn_not_str
from .hghub import get_untrained_model_with_inputs
def empty(value: Any) -> bool:
    """Tells if the value is empty."""
    if isinstance(value, (str, list, dict, tuple, set)):
        return not bool(value)
    if value is None:
        return True
    return False 
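
# A minimal sketch of ``empty`` semantics (the sample values are illustrative):
#
#     empty(None)   # True
#     empty("")     # True, empty strings and containers are empty
#     empty([])     # True
#     empty(0)      # False, non-container values are never empty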
def split_args_kwargs(inputs: Any) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Splits into args, kwargs."""
    if isinstance(inputs, dict):
        return (), inputs
    if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[1], dict):
        return inputs
    assert isinstance(inputs, tuple), f"Unexpected inputs {string_type(inputs)}"
    return inputs, {} 
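
# A minimal sketch of the three input layouts ``split_args_kwargs`` accepts,
# with illustrative tensor names:
#
#     split_args_kwargs({"input_ids": ids})        # -> (), {"input_ids": ids}
#     split_args_kwargs(((ids,), {"mask": mask}))  # -> (ids,), {"mask": mask}
#     split_args_kwargs((ids, mask))               # -> (ids, mask), {}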
def _make_folder_name(
    model_id: str,
    exporter: Optional[str],
    optimization: Optional[str] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    subfolder: Optional[str] = None,
) -> str:
    "Creates a filename unique based on the given options."
    els = [model_id.replace("/", "_")]
    if subfolder:
        els.append(subfolder.replace("/", "_"))
    if exporter:
        els.append(exporter)
    if optimization:
        els.append(optimization)
    if dtype is not None and dtype:
        stype = dtype if isinstance(dtype, str) else str(dtype)
        stype = stype.replace("float", "f").replace("uint", "u").replace("int", "i")
        els.append(stype)
    if device is not None and device:
        sdev = device if isinstance(device, str) else str(device)
        sdev = sdev.lower()
        if "cpu" in sdev:
            sdev = "cpu"
        elif "cuda" in sdev:
            sdev = "cuda"
        else:
            raise AssertionError(f"unexpected value for device={device}, sdev={sdev!r}")
        els.append(sdev)
    return "-".join(els)
def version_summary() -> Dict[str, Union[int, float, str]]:
    """
    Example:
    .. runpython::
        :showcode:
        import pprint
        from onnx_diagnostic.torch_models.validate import version_summary
        pprint.pprint(version_summary())
    """
    import numpy
    summary: Dict[str, Union[int, float, str]] = {
        "version_torch": torch.__version__,
        "version_numpy": numpy.__version__,
    }
    try:
        import scipy
        summary["version_scipy"] = getattr(scipy, "__version__", "?")
    except ImportError:
        pass
    try:
        import transformers
        summary["version_transformers"] = getattr(transformers, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnx
        summary["version_onnx"] = getattr(onnx, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnxscript
        summary["version_onnxscript"] = getattr(onnxscript, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnxruntime
        summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?")
    except ImportError:
        pass
    try:
        import onnx_ir
        summary["version_onnx_ir"] = getattr(onnx_ir, "__version__", "?")
    except ImportError:
        pass
    import onnx_diagnostic
    summary["version_onnx_diagnostic"] = onnx_diagnostic.__version__
    summary["version_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    return summary 
def _quiet_or_not_quiet(
    quiet: bool,
    suffix: str,
    summary: Dict[str, Any],
    data: Optional[Dict[str, Any]],
    fct: Callable,
    repeat: int = 1,
    warmup: int = 0,
) -> Any:
    begin = time.perf_counter()
    if quiet:
        try:
            res = fct()
            summary[f"time_{suffix}"] = time.perf_counter() - begin
            if warmup + repeat == 1:
                return res
        except Exception as e:
            summary[f"ERR_{suffix}"] = str(e)
            summary[f"time_{suffix}"] = time.perf_counter() - begin
            if data is None:
                return {f"ERR_{suffix}": e}
            data[f"ERR_{suffix}"] = e
            return None
    else:
        res = fct()
    summary[f"time_{suffix}"] = time.perf_counter() - begin
    if warmup + repeat > 1:
        if suffix == "run":
            res = torch_deepcopy(res)
        summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
        summary[f"{suffix}_warmup"] = warmup
        summary[f"{suffix}_repeat"] = repeat
        for _w in range(max(0, warmup - 1)):
            t = fct()
            summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
        summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
        times = []
        for _r in range(repeat):
            begin = time.perf_counter()
            t = fct()
            times.append(time.perf_counter() - begin)
        a = np.array(times)
        summary[f"time_{suffix}_latency"] = a.mean()
        summary[f"time_{suffix}_latency_std"] = a.std()
        summary[f"time_{suffix}_latency_min"] = a.min()
        summary[f"time_{suffix}_latency_min"] = a.max()
    return res
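
# A minimal sketch of the quiet behaviour, assuming a failing callable:
#
#     summary, data = {}, {}
#     res = _quiet_or_not_quiet(True, "step", summary, data, lambda: 1 / 0)
#     # res is None, summary contains "ERR_step" and "time_step",
#     # data["ERR_step"] holds the exception instance.
#     # With quiet=False the exception would propagate.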
def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Shrinks the configuration before it gets added to the information to log."""
    new_cfg = {}
    for k, v in cfg.items():
        new_cfg[k] = (
            v
            if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
            else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
        )
    return new_cfg 
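
# A minimal sketch: containers with 50 elements or more are replaced by a
# short placeholder, everything else is kept as-is:
#
#     shrink_config({"vocab": list(range(1000)), "heads": 12})
#     # -> {"vocab": ['.', '.', '.'], "heads": 12}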
def validate_model(
    model_id: str,
    task: Optional[str] = None,
    do_run: bool = False,
    exporter: Optional[str] = None,
    do_same: bool = False,
    verbose: int = 0,
    dtype: Optional[Union[str, torch.dtype]] = None,
    device: Optional[Union[str, torch.device]] = None,
    same_as_pretrained: bool = False,
    use_pretrained: bool = False,
    optimization: Optional[str] = None,
    quiet: bool = False,
    patch: Union[bool, str, Dict[str, bool]] = False,
    rewrite: bool = False,
    stop_if_static: int = 1,
    dump_folder: Optional[str] = None,
    drop_inputs: Optional[List[str]] = None,
    ortfusiontype: Optional[str] = None,
    input_options: Optional[Dict[str, Any]] = None,
    model_options: Optional[Dict[str, Any]] = None,
    subfolder: Optional[str] = None,
    opset: Optional[int] = None,
    runtime: str = "onnxruntime",
    repeat: int = 1,
    warmup: int = 0,
    inputs2: int = 1,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
    """
    Validates a model.
    The function can also be called through the command line
    :ref:`l-cmd-validate`.
    :param model_id: model id to validate
    :param task: task used to generate the necessary inputs,
        can be left empty to use the default task for this model
        if it can be determined
    :param do_run: checks that the model works with the defined inputs
    :param exporter: exports the model using this exporter,
        available list: ``export-strict``, ``export-nostrict``, ...
        see below
    :param do_same: checks the discrepancies of the exported model
    :param verbose: verbosity level
    :param dtype: uses this dtype to check the model
    :param device: do the verification on this device
    :param same_as_pretrained: use a model equivalent to the trained one,
        this is not always possible
    :param use_pretrained: use the trained model, not the untrained one
    :param optimization: optimization to apply to the exported model,
        depends on the exporter
    :param quiet: if true, catches exceptions instead of letting them propagate
    :param patch: applies patches (``patch_transformers=True, patch_diffusers=True``)
        if True before exporting
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`,
        a string can be used to specify only one of them
    :param rewrite: applies known rewriting (``patch_transformers=True``) before exporting,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param stop_if_static: stops if a dynamic dimension becomes static,
        see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
    :param dump_folder: dumps everything in a subfolder of this one
    :param drop_inputs: drops this list of inputs (given their names)
    :param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
        it accepts multiple values separated by ``|``,
        see :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`
    :param input_options: additional options to define the dummy inputs
        used to export
    :param model_options: additional options when creating the model such as
        ``num_hidden_layers`` or ``attn_implementation``
    :param subfolder: version or subfolder to use when retrieving a model id
    :param opset: onnx opset to use for the conversion
    :param runtime: onnx runtime used to check for discrepancies,
        only if `do_run` is true
    :param repeat: number of times the model is measured
    :param warmup: number of warmup runs before measuring
    :param inputs2: checks that the second set of inputs runs as well,
        this ensures that the model does support dynamism, the value is used
        as an increment over the first set of values (added to dimensions)
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    The following environment variables can be used to print out some
    information:
    * ``PRINT_CONFIG``: prints the model configuration
    The following exporters are available:
    * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
    * ``onnx-dynamo``: run :func:`torch.onnx.export` (..., dynamo=True),
      models can be optimized with ``optimization`` in ``("ir", "os_ort")``
    * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
    * ``custom``: custom exporter (see :epkg:`experimental-experiment`),
      models can be optimized with ``optimization`` in
      ``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
    The default runtime, :epkg:`onnxruntime`, is used to validate a model and check
    that the exported model returns the same outputs as the original one; otherwise,
    :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
    """
    if isinstance(patch, bool):
        patch_kwargs = (
            dict(patch_transformers=True, patch_diffusers=True, patch=True)
            if patch
            else dict(patch=False)
        )
    elif isinstance(patch, str):
        patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}}  # noqa: C420
    else:
        assert isinstance(patch, dict), f"Unable to interpret patch={patch!r}"
        patch_kwargs = patch.copy()
        if "patch" not in patch_kwargs:
            if any(patch_kwargs.values()):
                patch_kwargs["patch"] = True
    assert not rewrite or patch_kwargs.get("patch", False), (
        f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
        f"patch must be True to enable rewriting, "
        f"if --no-patch was specified on the command line, --no-rewrite must be added."
    )
    summary = version_summary()
    summary.update(
        dict(
            version_model_id=model_id,
            version_do_run=str(do_run),
            version_dtype=str(dtype or ""),
            version_device=str(device or ""),
            version_same_as_pretrained=str(same_as_pretrained),
            version_use_pretrained=str(use_pretrained),
            version_optimization=optimization or "",
            version_quiet=str(quiet),
            version_patch=str(patch),
            version_patch_kwargs=str(patch_kwargs).replace(" ", ""),
            version_rewrite=str(rewrite),
            version_dump_folder=dump_folder or "",
            version_drop_inputs=str(list(drop_inputs or "")),
            version_ortfusiontype=ortfusiontype or "",
            version_stop_if_static=str(stop_if_static),
            version_exporter=exporter or "",
            version_runtime=runtime,
            version_inputs2=inputs2,
        )
    )
    if opset:
        summary["version_opset"] = opset
    folder_name = None
    if dump_folder:
        folder_name = _make_folder_name(
            model_id, exporter, optimization, dtype=dtype, device=device, subfolder=subfolder
        )
        dump_folder = os.path.join(dump_folder, folder_name)
        if not os.path.exists(dump_folder):
            os.makedirs(dump_folder)
        summary["dump_folder"] = dump_folder
        summary["dump_folder_name"] = folder_name
        if verbose:
            print(f"[validate_model] dump into {folder_name!r}")
    if verbose:
        if subfolder:
            print(f"[validate_model] validate model id {model_id!r}, subfolder={subfolder!r}")
        else:
            print(f"[validate_model] validate model id {model_id!r}")
        if model_options:
            print(f"[validate_model] model_options={model_options!r}")
        print(f"[validate_model] get dummy inputs with input_options={input_options}...")
        print(
            f"[validate_model] rewrite={rewrite}, patch_kwargs={patch_kwargs}, "
            f"stop_if_static={stop_if_static}"
        )
        print(f"[validate_model] exporter={exporter!r}, optimization={optimization!r}")
        print(f"[validate_model] dump_folder={dump_folder!r}")
        summary["model_id"] = model_id
        summary["model_subfolder"] = subfolder or ""
    iop = input_options or {}
    mop = model_options or {}
    data = _quiet_or_not_quiet(
        quiet,
        "create",
        summary,
        None,
        (
            lambda mid=model_id, v=verbose, task=task, uptr=use_pretrained, tr=same_as_pretrained, iop=iop, sub=subfolder, i2=inputs2: (  # noqa: E501
                get_untrained_model_with_inputs(
                    mid,
                    verbose=v,
                    task=task,
                    use_pretrained=uptr,
                    same_as_pretrained=tr,
                    inputs_kwargs=iop,
                    model_kwargs=mop,
                    subfolder=sub,
                    add_second_input=i2,
                )
            )
        ),
    )
    assert not inputs2 or "inputs2" in data, (
        f"inputs2 is True but second set is missing in data for "
        f"model id {model_id!r}: {sorted(data)}"
    )
    if exporter == "modelbuilder":
        # Models used with ModelBuilder do not like batch size > 1.
        # Let's change that.
        for k in ["inputs", "inputs2"]:
            if k not in data:
                continue
            if verbose:
                print(f"[validate_model] set batch=1 for data[{k!r}]")
                print(f"[validate_model] batch=1 === {string_type(data[k], with_shape=True)}")
            cpl = CoupleInputsDynamicShapes(
                tuple(), data[k], dynamic_shapes=data["dynamic_shapes"]
            )
            data[k] = cpl.change_dynamic_dimensions(
                desired_values=dict(batch=1), only_desired=True
            )
            if verbose:
                print(f"[validate_model] batch=1 --> {string_type(data[k], with_shape=True)}")
    data["input_options"] = iop
    data["model_options"] = mop
    data["model_dump_folder"] = dump_folder
    if dtype:
        data["model_dtype"] = dtype if isinstance(dtype, str) else str(dtype)
    if device:
        data["model_device"] = str(device)
    if opset:
        data["model_opset"] = opset
    if "rewrite" in data:
        if rewrite:
            summary["model_rewrite"] = str(data["rewrite"])
            if verbose:
                print(f"[validate_model] model_rewrite={summary['model_rewrite']}")
        else:
            del data["rewrite"]
            if verbose:
                print("[validate_model] no rewrite")
    if os.environ.get("PRINT_CONFIG", "0") in (1, "1"):
        print("[validate_model] -- PRINT CONFIG")
        print("-- type(config)", type(data["configuration"]))
        print(data["configuration"])
        print("[validate_model] -- END PRINT CONFIG")
    if iop:
        summary["input_options"] = str(iop)
    if mop:
        summary["model_options"] = str(mop)
    if "ERR_create" in summary:
        return summary, data
    if drop_inputs:
        if verbose:
            print(f"[validate_model] -- drop inputs: {drop_inputs!r}")
            print(f"[validate_model] current inputs: {string_type(data['inputs'])}")
            print(
                f"[validate_model] current dynnamic_shapes: "
                f"{string_type(data['dynamic_shapes'])}"
            )
        data["inputs"], data["dynamic_shapes"] = filter_inputs(
            data["inputs"],
            drop_names=drop_inputs,
            model=data["model"],
            dynamic_shapes=data["dynamic_shapes"],
        )
        if verbose:
            print(f"[validate_model] new inputs: {string_type(data['inputs'])}")
            print(f"[validate_model] new dynamic_hapes: {string_type(data['dynamic_shapes'])}")
        if inputs2:
            assert (
                "inputs2" in data
            ), "Cannot test a second set of inputs as it was not defined."
            data["inputs2"], _ = filter_inputs(
                data["inputs2"],
                drop_names=drop_inputs,
                model=data["model"],
                dynamic_shapes=data["dynamic_shapes"],
            )
    if not empty(dtype):
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
        if verbose:
            print(f"[validate_model] dtype conversion to {dtype}")
        data["model"] = to_any(data["model"], dtype)  # type: ignore
        data["inputs"] = to_any(data["inputs"], dtype)  # type: ignore
        summary["model_dtype"] = str(dtype)
        if "inputs2" in data:
            data["inputs2"] = to_any(data["inputs2"], dtype)  # type: ignore
    if not empty(device):
        if verbose:
            print(f"[validate_model] device conversion to {device}")
        data["model"] = to_any(data["model"], device)  # type: ignore
        data["inputs"] = to_any(data["inputs"], device)  # type: ignore
        summary["model_device"] = str(device)
        if "inputs2" in data:
            data["inputs2"] = to_any(data["inputs2"], device)  # type: ignore
    for k in ["task", "size", "n_weights"]:
        summary[f"model_{k.replace('_','')}"] = data[k]
    summary["model_inputs_options"] = str(input_options or "")
    summary["model_inputs"] = string_type(data["inputs"], with_shape=True)
    summary["model_shapes"] = string_type(data["dynamic_shapes"])
    summary["model_class"] = data["model"].__class__.__name__
    summary["model_module"] = str(data["model"].__class__.__module__)
    if summary["model_module"] in sys.modules:
        summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
    summary["model_config_class"] = data["configuration"].__class__.__name__
    summary["model_config"] = str(
        shrink_config(
            data["configuration"]
            if type(data["configuration"]) is dict
            else data["configuration"].to_dict()
        )
    ).replace(" ", "")
    summary["model_id"] = model_id
    if verbose:
        print("[validate_model] --")
        print(f"[validate_model] task={data['task']}")
        print(f"[validate_model] size={data['size'] / 2**20} Mb")
        print(f"[validate_model] n_weights={data['n_weights'] / 1e6} millions parameters")
        for k, v in data["inputs"].items():
            print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}")
        for k, v in data["dynamic_shapes"].items():
            print(f"[validate_model] +SHAPE {k}={string_type(v)}")
        print("[validate_model] --")
    if do_run:
        _validate_do_run_model(
            data, summary, "inputs", "run", "run_expected", verbose, repeat, warmup, quiet
        )
        if inputs2:
            _validate_do_run_model(
                data, summary, "inputs2", "run2", "run_expected2", verbose, 1, 0, quiet
            )
    if exporter:
        if verbose:
            print(
                f"[validate_model] -- export the model with {exporter!r}, "
                f"optimization={optimization!r}"
            )
        if patch_kwargs:
            if verbose:
                print(
                    f"[validate_model] applies patches before exporting "
                    f"stop_if_static={stop_if_static}"
                )
            with torch_export_patches(  # type: ignore
                stop_if_static=stop_if_static,
                verbose=max(0, verbose - 1),
                rewrite=data.get("rewrite", None),
                dump_rewriting=(os.path.join(dump_folder, "rewrite") if dump_folder else None),
                **patch_kwargs,  # type: ignore[arg-type]
            ) as modificator:
                data["inputs_export"] = modificator(data["inputs"])  # type: ignore
                if do_run:
                    _validate_do_run_exported_program(data, summary, verbose, quiet)
                # data is modified inplace
                summary_export, data = call_exporter(
                    exporter=exporter,
                    data=data,
                    quiet=quiet,
                    verbose=verbose,
                    optimization=optimization,
                    do_run=do_run,
                    dump_folder=dump_folder,
                )
        else:
            data["inputs_export"] = data["inputs"]
            # data is modified inplace
            summary_export, data = call_exporter(
                exporter=exporter,
                data=data,
                quiet=quiet,
                verbose=verbose,
                optimization=optimization,
                do_run=do_run,
                dump_folder=dump_folder,
            )
        summary.update(summary_export)
    dump_stats = None
    if dump_folder:
        if "exported_program" in data:
            ep = data["exported_program"]
            if verbose:
                print(f"[validate_model] -- dumps exported program in {dump_folder!r}...")
            with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f:
                f.write(str(ep))
            torch.export.save(ep, os.path.join(dump_folder, f"{folder_name}.pt2"))
            with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f:
                f.write(str(ep.graph))
            if verbose:
                print("[validate_model] done (dump ep)")
        if "onnx_program" in data:
            epo = data["onnx_program"]
            if verbose:
                print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
            onnx_filename = os.path.join(dump_folder, f"{folder_name}.onnx")
            begin = time.perf_counter()
            if isinstance(epo, onnx.model_container.ModelContainer):
                epo.save(onnx_filename, all_tensors_to_one_file=True)
            elif isinstance(epo, onnx.ModelProto):
                if os.path.exists(f"{onnx_filename}.data"):
                    os.remove(f"{onnx_filename}.data")
                onnx.save(
                    epo,
                    onnx_filename,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    location=f"{os.path.split(onnx_filename)[-1]}.data",
                )
            else:
                epo.save(onnx_filename, external_data=True)
            duration = time.perf_counter() - begin
            if verbose:
                print(f"[validate_model] done (dump onnx) in {duration}")
            data["onnx_filename"] = onnx_filename
            summary["time_onnx_save"] = duration
        if verbose:
            print(f"[validate_model] dumps statistics in {dump_folder!r}...")
        dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
        with open(dump_stats, "w") as f:
            for k, v in sorted(summary.items()):
                f.write(f":{k}:{v};\n")
        if verbose:
            print("[validate_model] done (dump)")
    if not exporter or (
        not exporter.startswith(("onnx-", "custom-"))
        and exporter not in ("custom", "modelbuilder")
    ):
        if verbose:
            print("[validate_model] -- done (final)")
        if dump_stats:
            with open(dump_stats, "w") as f:
                for k, v in sorted(summary.items()):
                    f.write(f":{k}:{v};\n")
        return summary, data
    if do_run:
        summary_valid, data = validate_onnx_model(
            data=data,
            quiet=quiet,
            verbose=verbose,
            runtime=runtime,
            repeat=repeat,
            warmup=warmup,
            inputs2=inputs2,
        )
        summary.update(summary_valid)
    if ortfusiontype and "onnx_filename" in data:
        assert (
            "configuration" in data
        ), f"missing configuration in data, cannot run ort fusion for model_id={model_id}"
        config = data["configuration"]
        assert hasattr(
            config, "hidden_size"
        ), f"Missing attribute hidden_size in configuration {config}"
        hidden_size = config.hidden_size
        assert hasattr(
            config, "num_attention_heads"
        ), f"Missing attribute num_attention_heads in configuration {config}"
        num_attention_heads = config.num_attention_heads
        if ortfusiontype == "ALL":
            from onnxruntime.transformers.optimizer import MODEL_TYPES
            model_types = sorted(MODEL_TYPES)
        else:
            model_types = ortfusiontype.split("|")
        for model_type in model_types:
            flavour = f"ort{model_type}"
            summary[f"version_{flavour}_hidden_size"] = hidden_size
            summary[f"version_{flavour}_num_attention_heads"] = num_attention_heads
            begin = time.perf_counter()
            if verbose:
                print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
            input_filename = data["onnx_filename"]
            output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
            ort_sum, ort_data = run_ort_fusion(
                input_filename,
                output_path,
                model_type=model_type,
                num_attention_heads=num_attention_heads,
                hidden_size=hidden_size,
            )
            summary.update(ort_sum)
            data.update(ort_data)
            data[f"onnx_filename_{flavour}"] = output_path
            duration = time.perf_counter() - begin
            summary[f"time_ortfusion_{flavour}"] = duration
            if verbose:
                print(
                    f"[validate_model] done {model_type!r} in {duration}, "
                    f"saved into {output_path!r}"
                )
            if do_run:
                summary_valid, data = validate_onnx_model(
                    data=data,
                    quiet=quiet,
                    verbose=verbose,
                    flavour=flavour,
                    runtime=runtime,
                    repeat=repeat,
                    warmup=warmup,
                    inputs2=inputs2,
                )
                summary.update(summary_valid)
    if verbose:
        print("[validate_model] -- done (final)")
    if dump_stats:
        with open(dump_stats, "w") as f:
            for k, v in sorted(summary.items()):
                f.write(f":{k}:{v};\n")
    return summary, data 
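
# A hedged usage sketch of ``validate_model`` (the model id, folder and
# options are illustrative, any supported model id works the same way):
#
#     summary, data = validate_model(
#         "arnir0/Tiny-LLM",
#         do_run=True,
#         exporter="export-nostrict",
#         patch=True,
#         dump_folder="dump_test",
#         verbose=1,
#     )
#     print(summary["model_inputs"])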
def _validate_do_run_model(
    data, summary, key, tag, expected_tag, verbose, repeat, warmup, quiet
):
    if verbose:
        print(f"[validate_model] -- run the model inputs={key!r}...")
        print(f"[validate_model] {key}={string_type(data[key], with_shape=True)}")
    # We make a copy of the input just in case the model modifies them inplace
    hash_inputs = string_type(data[key], with_shape=True)
    inputs = torch_deepcopy(data[key])
    model = data["model"]
    expected = _quiet_or_not_quiet(
        quiet,
        tag,
        summary,
        data,
        (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
        repeat=repeat,
        warmup=warmup,
    )
    if f"ERR_{tag}" in summary:
        return summary, data
    summary[expected_tag] = string_type(expected, with_shape=True)
    if verbose:
        print(f"[validate_model] done ([{tag}])")
    data[expected_tag] = expected
    assert hash_inputs == string_type(data[key], with_shape=True), (
        f"The model execution did modified the inputs:\n"
        f"before: {hash_inputs}\n"
        f" after: {string_type(data[key], with_shape=True)}"
    )
def _validate_do_run_exported_program(data, summary, verbose, quiet):
    # We run a second time the model to check the patch did not
    # introduce any discrepancies
    if verbose:
        print("[validate_model] run patched model...")
        print(
            f"[validate_model] patched inputs="
            f"{string_type(data['inputs_export'], with_shape=True)}"
        )
    hash_inputs = string_type(data["inputs_export"], with_shape=True)
    # We make a copy of the input just in case the model modifies them inplace
    inputs = torch_deepcopy(data["inputs_export"])
    model = data["model"]
    expected = _quiet_or_not_quiet(
        quiet,
        "run_patched",
        summary,
        data,
        (lambda m=model, inp=inputs: m(**inp)),
    )
    if "ERR_run_patched" in summary:
        return summary, data
    disc = max_diff(data["run_expected"], expected)
    for k, v in disc.items():
        summary[f"disc_patched_{k}"] = str(v)
    if verbose:
        print("[validate_model] done (patched run)")
        print(f"[validate_model] patched discrepancies={string_diff(disc)}")
    assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
        f"The model execution did modified the inputs:\n"
        f"before: {hash_inputs}\n"
        f" after: {string_type(data['inputs_export'], with_shape=True)}"
    )
def call_exporter(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    do_run: bool = False,
    dump_folder: Optional[str] = None,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
    """
    Calls an exporter on a model.
    If a patch must be applied, it should be applied before this function is called.
    :param data: dictionary with all the necessary inputs
    :param exporter: exporter to call
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do
    :param do_run: runs the model and computes discrepancies
    :param dump_folder: to dump additional information
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    if exporter == "export" or exporter.startswith("export-"):
        # torch export
        summary, data = call_torch_export_export(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            do_run=do_run,
        )
        return summary, data
    if exporter.startswith("onnx-"):
        # torch.onnx.export
        summary, data = call_torch_export_onnx(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
        )
        return summary, data
    if exporter == "custom" or exporter.startswith("custom"):
        # torch export
        summary, data = call_torch_export_custom(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
            dump_folder=dump_folder,
        )
        return summary, data
    if exporter == "modelbuilder":
        # ModelBuilder
        summary, data = call_torch_export_model_builder(
            exporter=exporter,
            data=data,
            quiet=quiet,
            verbose=verbose,
            optimization=optimization,
        )
        return summary, data
    raise NotImplementedError(
        f"export with {exporter!r} and optimization={optimization!r} not implemented yet, "
        f"exporter must start with 'onnx-', 'custom', 'export', 'modelbuilder' "
        f"(onnx-dynamo, custom, export), optimization can be 'ir', 'os_ort', "
        f"'default', 'default+onnxruntime', 'default+onnxruntime+os_ort'"
    )
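
# ``call_exporter`` only dispatches on the exporter name; for example
# (``data`` must already contain ``model``, ``inputs_export``, ...):
#
#     summary, data = call_exporter(data, "export-nostrict", verbose=1)
#     # dispatches to call_torch_export_export(...)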
def call_torch_export_export(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    do_run: bool = False,
):
    """
    Exports a model with :func:`torch.export.export`.
    If a patch must be applied, it should be applied before this function is called.
    :param data: dictionary with all the necessary inputs, the dictionary must
        contain the keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do
    :param do_run: runs the model and computes discrepancies
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    assert exporter in {
        "export",
        "export-strict",
        "export-nostrict",
    }, f"Unexpected value for exporter={exporter!r}"
    assert not optimization, f"No optimization is implemented for exporter={exporter!r}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    strict = "-strict" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_strict"] = strict
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
    summary["export_dynamic_shapes"] = string_type(ds)
    # There is an issue with dynamic shapes: [[], []] becomes [].
    dse = use_dyn_not_str(ds)
    # dse = CoupleInputsDynamicShapes(args, kwargs, ds).replace_string_by()
    summary["export_dynamic_shapes_export_export"] = string_type(dse)
    if verbose:
        print(
            f"[call_torch_export_export] exporter={exporter!r}, "
            f"strict={strict}, optimization={optimization!r}"
        )
        print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_export] dynamic_shapes={string_type(ds)}")
        print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}")
        print("[call_torch_export_export] export...")
    model = data["model"]
    ep = _quiet_or_not_quiet(
        quiet,
        "export_export",
        summary,
        data,
        (
            lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: (
                torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s)
            )
        ),
    )
    if "ERR_export_export" in summary:
        return summary, data
    summary["export_graph_nodes"] = len(ep.graph.nodes)
    if verbose:
        print(
            f"[call_torch_export_export] done (export) "
            f"with {summary['export_graph_nodes']} nodes"
        )
    data["exported_program"] = ep
    if verbose > 1:
        print("[call_torch_export_export] -- ExportedProgram")
        print(ep)
        print("[call_torch_export_export] -- End of ExportedProgram")
    if do_run:
        # We check for discrepancies.
        if verbose:
            print("[validate_model] run exported model...")
            print(
                f"[validate_model] patched inputs="
                f"{string_type(data['inputs_export'], with_shape=True)}"
            )
        hash_inputs = string_type(data["inputs_export"], with_shape=True)
        # We make a copy of the input just in case the model modifies them inplace
        inputs = torch_deepcopy(data["inputs_export"])
        model = ep.module()
        expected = _quiet_or_not_quiet(
            quiet,
            "run_exported",
            summary,
            data,
            (lambda m=model, inputs=inputs: m(**inputs)),
        )
        if "ERR_export_export" in summary:
            return summary, data
        disc = max_diff(data["run_expected"], expected)
        for k, v in disc.items():
            summary[f"disc_exported_{k}"] = str(v)
        if verbose:
            print("[validate_model] done (exported run)")
            print(f"[validate_model] exported discrepancies={string_diff(disc)}")
        assert hash_inputs == string_type(data["inputs_export"], with_shape=True), (
            f"The exported model execution did modified the inputs:\n"
            f"before: {hash_inputs}\n"
            f" after: {string_type(data['inputs_export'], with_shape=True)}"
        )
    return summary, data 
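
# A minimal sketch, assuming ``model``, ``inputs`` and ``ds`` were built
# beforehand (for instance by validate_model):
#
#     data = {"model": model, "inputs_export": inputs, "dynamic_shapes": ds}
#     summary, data = call_torch_export_export(data, "export-nostrict")
#     ep = data["exported_program"]  # torch.export.ExportedProgram on success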
def validate_onnx_model(
    data: Dict[str, Any],
    quiet: bool = False,
    verbose: int = 0,
    flavour: Optional[str] = None,
    runtime: str = "onnxruntime",
    repeat: int = 1,
    warmup: int = 0,
    inputs2: int = 1,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Verifies that an onnx model produces the same
    expected outputs. It uses ``data["onnx_filename"]`` as the input
    onnx filename or ``data["onnx_filename_{flavour}"]`` if *flavour*
    is specified.
    :param data: dictionary with all the necessary inputs, the dictionary must
        contain the keys ``model`` and ``inputs_export``
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param flavour: use a different version of the inputs
    :param runtime: onnx runtime to use, onnxruntime or torch
    :param repeat: number of times to run the model
    :param warmup: number of warmup runs
    :param inputs2: to validate the model on the second input set
        to make sure the exported model supports dynamism, the value is
        used as an increment added to the first set of inputs (added to dimensions)
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    import onnxruntime
    def _mk(key):
        return f"{key}_{flavour}" if flavour else key
    summary: Dict[str, Any] = {}
    flat_inputs = flatten_object(data["inputs"], drop_keys=True)
    d = flat_inputs[0].get_device()
    providers = (
        ["CPUExecutionProvider"]
        if d < 0
        else ["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    input_data_key = f"onnx_filename_{flavour}" if flavour else "onnx_filename"
    if input_data_key in data:
        source = data[input_data_key]
        if not os.path.exists(source):
            if verbose:
                print(f"[validate_onnx_model] missing {source!r}")
            summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})"
            return summary, data
        summary[input_data_key] = source
        summary[_mk("onnx_size")] = os.stat(source).st_size
    else:
        assert not flavour, f"flavour={flavour!r}, the filename must be saved."
        assert (
            "onnx_program" in data
        ), f"onnx_program is missing from data which has {sorted(data)}"
        source = data["onnx_program"].model_proto.SerializeToString()
        assert len(source) < 2**31, f"The model is bigger than 2 Gb: {len(source) / 2**30} Gb"
        summary[_mk("onnx_size")] = len(source)
    if verbose:
        print(
            f"[validate_onnx_model] verify onnx model with providers "
            f"{providers}..., flavour={flavour!r}"
        )
    if runtime != "onnxruntime":
        from ..reference import TorchOnnxEvaluator
    cls_runtime = (
        (
            lambda model, providers: onnxruntime.InferenceSession(
                (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
                providers=providers,
            )
        )
        if runtime == "onnxruntime"
        else (
            lambda model, providers, _cls_=TorchOnnxEvaluator: _cls_(  # type: ignore[misc]
                model, providers=providers, verbose=max(verbose - 1, 0)
            )
        )
    )
    sess = _quiet_or_not_quiet(
        quiet,
        _mk("onnx_ort_create"),
        summary,
        data,
        (lambda source=source, providers=providers: cls_runtime(source, providers)),
    )
    if f"ERR_{_mk('onnx_ort_create')}" in summary:
        return summary, data
    data[_mk("onnx_ort_sess")] = sess
    if verbose:
        print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")
    keys = [("inputs", "run_expected", "")]
    if inputs2:
        keys.append(("inputs2", "run_expected2", "2"))
    for k_input, k_expected, suffix in keys:
        # make_feeds
        if verbose:
            print(f"[validate_onnx_model] -- make_feeds for {k_input!r}...")
            print(
                f"[validate_onnx_model] inputs={string_type(data[k_input], with_shape=True)}"
            )
        feeds = make_feeds(sess, data[k_input], use_numpy=True, check_flatten=False)
        if verbose:
            print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
        summary[_mk(f"onnx_ort_inputs{suffix}")] = string_type(feeds, with_shape=True)
        if verbose:
            print("[validate_onnx_model] done (make_feeds)")
        # run ort
        if verbose:
            print("[validate_onnx_model] run session...")
        got = _quiet_or_not_quiet(
            quiet,
            _mk(f"time_onnx_ort_run{suffix}"),
            summary,
            data,
            (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
            repeat=repeat,
            warmup=warmup,
        )
        if f"ERR_{_mk(f'time_onnx_ort_run{suffix}')}" in summary:
            return summary, data
        summary[f"run_feeds_{k_input}"] = string_type(feeds, with_shape=True, with_device=True)
        summary[f"run_output_{k_input}"] = string_type(got, with_shape=True, with_device=True)
        if verbose:
            print("[validate_onnx_model] done (run)")
            print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
        # compute discrepancies
        disc = max_diff(data[k_expected], got, flatten=True)
        if verbose:
            print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
        for k, v in disc.items():
            summary[_mk(f"disc_onnx_ort_run{suffix}_{k}")] = v
    return summary, data 
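
# A minimal sketch, assuming a previous export step filled
# ``data["onnx_filename"]``, ``data["inputs"]`` and ``data["run_expected"]``:
#
#     summary, data = validate_onnx_model(data, runtime="onnxruntime", inputs2=0)
#     # discrepancy metrics are stored under keys starting with
#     # "disc_onnx_ort_run", latencies under "time_onnx_ort_run..." keys.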
def call_torch_export_onnx(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx.
    If a patch must be applied, it should be applied before this function is called.
    :param data: dictionary with all the necessary inputs, the dictionary must
        contain the keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    available = {None, "", "ir", "os_ort"}
    assert (
        optimization in available
    ), f"unexpected value for optimization={optimization}, available={available}"
    assert exporter in {
        "onnx-dynamo",
        "onnx-script",
    }, f"Unexpected value for exporter={exporter!r}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    dynamo = "dynamo" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    if verbose:
        print(
            f"[call_torch_export_onnx] exporter={exporter!r}, "
            f"optimization={optimization!r}"
        )
        print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_onnx] dynamic_shapes={string_type(ds)}")
        print("[call_torch_export_onnx] export...")
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_dynamo"] = dynamo
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
    opset = data.get("model_opset", None)
    if opset:
        summary["export_opset"] = opset
    if dynamo:
        export_export_kwargs = dict(dynamo=True, dynamic_shapes=ds)
    else:
        export_export_kwargs = dict(
            dynamo=False,
            dynamic_axes={
                k: v
                for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)  # type: ignore[arg-type]
                .replace_by_string()
                .items()
                if isinstance(v, dict)
            },
        )
        args = tuple(flatten_unflatten_for_dynamic_shapes(a) for a in args)
        kwargs = {k: flatten_unflatten_for_dynamic_shapes(v) for k, v in kwargs.items()}
        if verbose:
            print("[call_torch_export_onnx] dynamo=False so...")
            print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
            print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
    if opset:
        export_export_kwargs["opset_version"] = opset
    if verbose:
        print(
            f"[call_torch_export_onnx] export_export_kwargs="
            f"{string_type(export_export_kwargs, with_shape=True)}"
        )
    model = data["model"]
    epo = _quiet_or_not_quiet(
        quiet,
        "export_onnx",
        summary,
        data,
        (
            lambda m=model, args=args, kws=kwargs, ekws=export_export_kwargs: (
                torch.onnx.export(
                    m,
                    args,
                    kwargs=kws,
                    **ekws,
                )
            )
        ),
    )
    if "ERR_export_onnx" in summary:
        return summary, data
    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_onnx] done (export)")
    data["onnx_program"] = epo
    if verbose > 5:
        print("[call_torch_export_onnx] -- ONNXProgram")
        print(epo)
        print("[call_torch_export_onnx] -- End of ONNXProgram")
    if optimization in {"ir", "os_ort"}:
        if verbose:
            print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
        if optimization == "ir":
            label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
        else:
            def _os_ort_optim(epo):
                onnxscript.optimizer.optimize_ir(epo.model)
                optimized = ort_fusions.optimize_for_ort(epo.model)
                if isinstance(optimized, tuple):
                    for k, v in optimized[1].items():
                        summary[f"op_opt_fused_{k}"] = v
                    epo.model = optimized[0]
                else:
                    epo.model = optimized
            label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
        _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
        if "ERR_export_onnx_opt_ir" in summary:
            return summary, data
        if verbose:
            print("[call_torch_export_onnx] done (optimization)")
    return summary, data 
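
# A minimal sketch, assuming ``data`` was prepared as for
# call_torch_export_export:
#
#     summary, data = call_torch_export_onnx(data, "onnx-dynamo", optimization="ir")
#     data["onnx_program"]  # torch.onnx.ONNXProgram, optimized with "ir"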
def call_torch_export_model_builder(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx with :epkg:`ModelBuilder`.
    :param data: dictionary with all the necessary inputs, the dictionary must
        contain the keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    from ..helpers.model_builder_helper import create_model_builder, save_model_builder
    assert optimization in (
        None,
        "",
    ), f"unexpected value for optimization={optimization}, none is available"
    precision = data.get("model_dtype", "fp32")
    provider = data.get("model_device", "cpu")
    dump_folder = data.get("model_dump_folder", "")
    assert dump_folder, "dump_folder cannot be empty with ModelBuilder"
    cache_dir = os.path.join(dump_folder, "cache_mb")
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    summary: Dict[str, Any] = {}
    epo = _quiet_or_not_quiet(
        quiet,
        "export_model_builder",
        summary,
        data,
        (
            lambda m=data["model"], c=data[
                "configuration"
            ], p=precision, pr=provider, cd=cache_dir: (
                save_model_builder(
                    create_model_builder(
                        c, m, precision=p, execution_provider=pr, cache_dir=cd
                    )
                )
            )
        ),
    )
    if "ERR_export_model_builder" in summary:
        return summary, data
    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_model_builder] done (export)")
    data["onnx_program"] = epo
    return summary, data 
def call_torch_export_custom(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
    dump_folder: Optional[str] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Exports a model into onnx.
    If a patch must be applied, it should be applied before this function is called.
    :param data: dictionary with all the necessary inputs, the dictionary must
        contain the keys ``model`` and ``inputs_export``
    :param exporter: exporter to call
    :param quiet: catch exception or not
    :param verbose: verbosity
    :param optimization: optimization to do
    :param dump_folder: to store additional information
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    available = {
        "",
        "default",
        "default+onnxruntime",
        "default+os_ort",
        "default+onnxruntime+os_ort",
        None,
    }
    assert (
        optimization in available
    ), f"unexpected value for optimization={optimization}, available={available}"
    available = {
        "custom",
        "custom-strict",
        "custom-strict-default",
        "custom-strict-all",
        "custom-nostrict",
        "custom-nostrict-default",
        "custom-nostrict-all",
        "custom-noinline",
        "custom-strict-noinline",
        "custom-strict-default-noinline",
        "custom-strict-all-noinline",
        "custom-nostrict-noinline",
        "custom-nostrict-default-noinline",
        "custom-nostrict-all-noinline",
    }
    assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
    summary: Dict[str, Union[str, int, float]] = {}
    strict = "-strict" in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    opset = data.get("model_opset", None)
    if opset:
        summary["export_opset"] = opset
    if verbose:
        print(
            f"[call_torch_export_custom] exporter={exporter!r}, "
            f"optimization={optimization!r}"
        )
        print(f"[call_torch_export_custom] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_custom] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_custom] dynamic_shapes={string_type(ds)}")
        print("[call_torch_export_custom] export...")
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_strict"] = strict
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)
    from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
    from experimental_experiment.xbuilder import OptimizationOptions
    spl = optimization.split("+") if optimization else []
    os_ort = "os_ort" in spl
    optimization = "+".join(_ for _ in spl if _ != "os_ort")
    export_options = ExportOptions(
        strict=strict,
        decomposition_table=(
            "default" if "-default" in exporter else ("all" if "-all" in exporter else None)
        ),
        save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
    )
    inline = "-noinline" not in exporter
    options = OptimizationOptions(patterns=optimization) if optimization else None
    model = data["model"]
    kws = dict(
        dynamic_shapes=ds,
        export_options=export_options,
        options=options,
        optimize=bool(optimization),
        large_model=True,
        return_optimize_report=True,
        verbose=max(verbose - 2, 0),
        inline=inline,
    )
    if opset:
        kws["target_opset"] = opset
    res = _quiet_or_not_quiet(
        quiet,
        "export_onnx_c",
        summary,
        data,
        (
            lambda m=model, args=args, kwargs=kwargs, kws=kws: (
                to_onnx(
                    m,
                    args,
                    kwargs=kwargs,
                    **kws,
                )
            )
        ),
    )
    if "ERR_export_onnx_c" in summary:
        # the export failed, _quiet_or_not_quiet returned None
        return summary, data
    epo, opt_stats = res
    new_stat = {}
    if "optimization" in opt_stats:
        added, removed, time_in = 0, 0, 0.0
        max_iter = 0
        applied = {}
        matched = set()
        n_applied = 0
        by_pattern = {}
        by_pattern_n = {}
        by_iter = {}
        cst_added, cst_removed, cst_time_in = 0, 0, 0.0
        for obs in opt_stats["optimization"]:
            pattern = obs["pattern"]
            if pattern == "constant_folding":
                cst_added += obs.get("added", 0)
                cst_removed += obs.get("removed", 0)
                cst_time_in += obs.get("time_in", 0)
            if pattern not in by_pattern:
                by_pattern[pattern] = 0
                by_pattern_n[pattern] = 0
                by_iter[pattern] = 0
            time_in += obs.get("time_in", 0)
            added += obs.get("added", 0)
            removed += obs.get("removed", 0)
            max_iter = max(max_iter, obs.get("iteration", 0))
            by_pattern[pattern] += obs.get("time_in", 0)
            by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0)
            if not pattern.startswith("match"):
                by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0))
            p = obs["pattern"]
            if p.startswith("match_"):
                matched.add(p)
            elif p.startswith("apply_"):
                key = f"op_opt_{p}"
                key2 = f"op_opt_maxiter_{p}"
                if key not in applied:
                    applied[key] = 1
                    applied[key2] = obs["iteration"]
                else:
                    applied[key] += 1
                    applied[key2] = max(obs["iteration"], applied[key2])
                n_applied += 1
        new_stat.update(
            dict(
                onnx_opt_optimized=1,
                op_opt_all_time_in=time_in,
                op_opt_all_added=added,
                op_opt_all_removed=removed,
                op_opt_max_iter=max_iter,
                op_opt_unique_matched=len(matched),
                op_opt_unique_applied=len(applied),
                op_opt_n_applied=n_applied,
                time_export_optimization=time_in,
                op_opt_export_optimization=time_in,
                op_opt_cst_time_in=cst_time_in,
                op_opt_cst_added=cst_added,
                op_opt_cst_removed=cst_removed,
            )
        )
    summary.update(new_stat)
    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_custom] done (export)")
    if os_ort:
        if verbose:
            print("[call_torch_export_custom] conversion to IR...")
        begin = time.perf_counter()
        ir_model = epo.to_ir()
        duration = time.perf_counter() - begin
        summary["time_optim_to_ir"] = duration
        if verbose:
            print(f"[call_torch_export_custom] done in {duration}")
            print("[call_torch_export_custom] start optimization...")
        begin = time.perf_counter()
        onnxscript.optimizer.optimize_ir(ir_model)
        ir_optimized = ort_fusions.optimize_for_ort(ir_model)
        if isinstance(ir_optimized, tuple):
            report = ir_optimized[1]
            for k, v in report.items():
                summary[f"op_opt_fused_{k}"] = v
            ir_optimized = ir_optimized[0]
            epo.model = ir_optimized
        duration = time.perf_counter() - begin
        summary["time_optim_os_ort"] = duration
        if verbose:
            print(f"[call_torch_export_custom] done in {duration}")
    data["onnx_program"] = epo
    return summary, data 
def run_ort_fusion(
    model_or_path: Union[str, onnx.ModelProto],
    output_path: str,
    num_attention_heads: int,
    hidden_size: int,
    model_type: str = "bert",
    verbose: int = 0,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Runs :epkg:`onnxruntime` fusion optimizer.
    :param model_or_path: path to the ModelProto or the ModelProto itself
    :param output_path: the model to save
    :param num_attention_heads: number of heads, usually ``config.num_attention_heads``
    :param hidden_size: hidden size, usually ``config.hidden_size``
    :param model_type: type of optimization, see below
    :param verbose: verbosity
    :return: two dictionaries, summary and data
    Supported values for ``model_type``:

    .. runpython::
        :showcode:

        import pprint
        from onnxruntime.transformers.optimizer import MODEL_TYPES
        pprint.pprint(sorted(MODEL_TYPES))
    """
    from onnxruntime.transformers.optimizer import optimize_by_fusion
    from onnxruntime.transformers.fusion_options import FusionOptions
    opts = FusionOptions(model_type)
    if isinstance(model_or_path, str):
        if verbose:
            print(f"[run_ort_fusion] loads {model_or_path!r}")
        onx = onnx.load(model_or_path)
    else:
        onx = model_or_path
    begin = time.perf_counter()
    n_nodes = len(onx.graph.node)
    if verbose:
        print(
            f"[run_ort_fusion] starts optimization for "
            f"model_type={model_type!r} with {n_nodes} nodes"
        )
    try:
        new_onx = optimize_by_fusion(
            onx,
            model_type=model_type,
            num_heads=num_attention_heads,
            hidden_size=hidden_size,
            optimization_options=opts,
        )
    except Exception as e:
        duration = time.perf_counter() - begin
        if verbose:
            print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}")
        return {
            f"ERR_opt_ort_{model_type}": str(e),
            f"opt_ort_{model_type}_duration": duration,
        }, {}
    duration = time.perf_counter() - begin
    n_nodes_after = len(new_onx.model.graph.node)
    if verbose:
        print(f"[run_ort_fusion] done in {duration} with {n_nodes_after} nodes")
        print(f"[run_ort_fusion] save to {output_path!r}")
    begin = time.perf_counter()
    new_onx.save_model_to_file(output_path, use_external_data_format=True)
    d = time.perf_counter() - begin
    if verbose:
        print(f"[run_ort_fusion] done in {d}")
    return {
        f"opt_ort_{model_type}_n_nodes1": n_nodes,
        f"opt_ort_{model_type}_n_nodes2": n_nodes_after,
        f"opt_ort_{model_type}_delta_node": n_nodes_after - n_nodes,
        f"opt_ort_{model_type}_duration": duration,
        f"opt_ort_{model_type}_duration_save": d,
    }, {f"opt_ort_{model_type}": output_path}