import inspect
import os
import textwrap
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
from .dynamic_shapes import ModelInputs
from .onnx_plug import EagerDirectReplacementWithOnnx
from ..helpers import string_type
[docs]
def get_main_dispatcher(
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
) -> Any:  # Dispatcher
    """
    Builds the dispatcher given to the custom exporter.

    :param use_control_flow_dispatcher: if True, the dispatcher returned by
        ``create_global_dispatcher`` is chained in front of the new one
    :param onnx_plugs: plugs whose ``custom_converter()`` is registered
        under ``plug.target_name``
    :return: an instance of a local subclass of ``Dispatcher``
    """
    from experimental_experiment.torch_interpreter import Dispatcher

    chained = None
    if use_control_flow_dispatcher:
        from .control_flow_onnx import create_global_dispatcher

        chained = create_global_dispatcher()

    class MainDispatcher(Dispatcher):
        # Dispatcher asking an optional previous dispatcher first and
        # falling back on its own registry when nothing was found.

        def __init__(self, previous_dispatcher=None):
            super().__init__({})
            self.previous_dispatcher = previous_dispatcher

        @property
        def supported(self):
            own = set(self.registered_functions)
            if self.previous_dispatcher:
                return own | self.previous_dispatcher.supported
            return own

        def find_function(self, name: Any):
            previous = self.previous_dispatcher
            found = previous.find_function(name) if previous else None
            return found or Dispatcher.find_function(self, name)

        def find_method(self, name: Any):
            previous = self.previous_dispatcher
            found = previous.find_method(name) if previous else None
            return found or Dispatcher.find_method(self, name)

    dispatcher = MainDispatcher(chained)
    for plug in onnx_plugs or ():
        converter = plug.custom_converter()
        dispatcher.registered_functions[plug.target_name] = converter
    return dispatcher
[docs]
def _write_optimization_reports(opt_stats: Any, filename: str) -> None:
    """
    Dumps the optimization statistics returned by the custom exporter
    into two Excel files stored next to the onnx model.

    ``<name>.opt.xlsx`` contains one row per observation,
    ``<name>.opt.agg.xlsx`` contains the same data aggregated
    by ``level`` and ``pattern``.

    :param opt_stats: statistics, a mapping holding timing entries
        (every key containing ``"time"``) and a list of observations
        stored under key ``"optimization"``
    :param filename: name of the onnx file, only used to build the report names
    """
    import pandas

    prefix = os.path.splitext(filename)[0]
    pattern_stats = [
        dict(level="main", pattern=k, time_in=v) for k, v in opt_stats.items() if "time" in k
    ]
    pattern_stats.extend({**obs, "level": "detailed"} for obs in opt_stats["optimization"])
    df = pandas.DataFrame(pattern_stats)
    df.to_excel(f"{prefix}.opt.xlsx", index=False)

    # Only the columns present in the data are kept.
    cols = [
        c
        for c in [
            "level",
            "pattern",
            "time_in",
            "iteration",
            "inlined",
            "removed",
            "added",
            "instances",
            "changed",
            "scale",
        ]
        if c in df.columns
    ]
    # Every metric is summed except iteration (max) and instances (mean).
    agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
    agg.update(dict(iteration="max", instances="mean"))
    agg = {k: v for k, v in agg.items() if k in df.columns}
    df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(f"{prefix}.opt.agg.xlsx")


def to_onnx(
    mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
    kwargs: Optional[Dict[str, "torch.Tensor"]] = None,  # noqa: F821
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
    exporter_kwargs: Optional[Dict[str, Any]] = None,
    save_ep: Optional[str] = None,
    optimize: bool = True,
    optimizer_for_ort: bool = True,
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
    inline: bool = True,
) -> Any:
    """
    Exports one model into ONNX.
    Common API for exporters. By default, the models are optimized to use the
    most efficient kernels implemented in :epkg:`onnxruntime`.

    :param mod: torch model
    :param args: unnamed arguments
    :param kwargs: named arguments
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level
    :param dynamic_shapes: dynamic shapes, usually a nested structure
        including a dictionary for each tensor
    :param filename: output filename
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :param exporter_kwargs: additional parameters sent to the exporter,
        the dictionary is copied, the caller's instance is never modified
    :param save_ep: saves the exported program
    :param optimize: optimizes the model
    :param optimizer_for_ort: optimizes the model for onnxruntime
    :param use_control_flow_dispatcher: use the dispatcher created to support
        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
    :param onnx_plugs: the code was modified to replace some parts with onnx translation
    :param inline: inline local functions
    :return: the output of the selected exporter, usually a structure including
        an onnx model
    :raises ValueError: if *exporter* is not one of the supported values
    :raises AssertionError: if the arguments are not supported by
        the selected exporter or contradict each other

    A simple example:

    .. code-block:: python

        to_onnx(
            model,
            kwargs=inputs,
            dynamic_shapes=ds,
            exporter=exporter,
            filename=filename,
        )

    Some examples using control flows are available in
    :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx` or
    :class:`onnx_diagnostic.export.onnx_plug.EagerDirectReplacementWithOnnx`.
    """
    # Work on a private copy: some keys are consumed (popped) below and
    # the dictionary given by the caller must remain reusable.
    exporter_kwargs = dict(exporter_kwargs) if exporter_kwargs else {}
    if "inline" in exporter_kwargs:
        assert (
            inline == exporter_kwargs["inline"]
        ), f"Mismatch between inline={inline} and exporter_kwargs={exporter_kwargs}"
        exporter_kwargs.pop("inline")

    if exporter == "custom":
        from experimental_experiment.torch_interpreter import (
            to_onnx as _to_onnx,
            ExportOptions,
        )
        from experimental_experiment.xbuilder import OptimizationOptions

        # Explicit options given through exporter_kwargs take precedence
        # over the ones derived from save_ep, optimize, optimizer_for_ort.
        options = exporter_kwargs.pop("options", None)
        export_options = exporter_kwargs.pop("export_options", None)
        if export_options is None:
            export_options = ExportOptions(save_ep=save_ep)
        if options is None and optimize:
            options = OptimizationOptions(
                patterns="default+onnxruntime" if optimizer_for_ort else "default"
            )
        main_dispatcher = (
            get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
            if onnx_plugs or use_control_flow_dispatcher
            else None
        )

        proto, opt_stats = _to_onnx(
            mod,
            args=args,
            kwargs=kwargs,
            input_names=input_names,
            output_names=output_names,
            target_opset=target_opset,
            verbose=verbose,
            filename=filename,
            dynamic_shapes=dynamic_shapes,
            large_model=True,
            output_dynamic_shapes=output_dynamic_shapes,
            export_options=export_options,
            options=options,
            inline=inline,
            dispatcher=main_dispatcher,
            optimize=optimize,
            return_optimize_report=True,
            **exporter_kwargs,
        )
        # The reports are only written when the model itself was saved.
        if opt_stats and filename and os.path.exists(filename):
            _write_optimization_reports(opt_stats, filename)
        return proto

    if exporter in ("dynamo", "onnx-dynamo"):
        from ..helpers import flatten_object
        import onnxscript.rewriter.ort_fusions as ort_fusions

        assert (
            not output_dynamic_shapes
        ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
        assert (
            optimize
        ), f"torch.onnx.export always optimizes the model but optimize={optimize}"
        custom_translation_table = {}
        if onnx_plugs:
            for plug in onnx_plugs:
                custom_translation_table[plug.torch_op] = plug.onnx_dynamo_converter()
        epo = torch.onnx.export(
            mod,
            args=args or tuple(),
            kwargs=kwargs,
            input_names=input_names,
            output_names=output_names,
            opset_version=target_opset,
            dynamic_shapes=dynamic_shapes,
            dynamo=True,
            verbose=verbose,
            dump_exported_program=bool(save_ep),
            artifacts_dir=os.path.dirname(filename) if filename else ".",
            custom_translation_table=custom_translation_table,
            **exporter_kwargs,
        )
        # Without inlining, the fusions run before the plug functions are added.
        if not inline and optimize and optimizer_for_ort:
            ort_fusions.optimize_for_ort(epo.model)

        if onnx_plugs:
            import onnx_ir as ir
            import onnx_ir.passes.common as common_passes

            opset = (
                18
                if target_opset is None
                else (target_opset if isinstance(target_opset, int) else target_opset[""])
            )
            irfunctions = [
                ir.from_proto(
                    plug.get_function_proto(
                        opset, *flatten_object((args, kwargs), drop_keys=True)
                    )
                )
                for plug in onnx_plugs
            ]
            for func in irfunctions:
                epo.model.functions[func.identifier()] = func
            if inline:
                common_passes.InlinePass()(epo.model)
                common_passes.RemoveUnusedOpsetsPass()(epo.model)

        # With inlining, the fusions run once the plug functions are inlined.
        if inline and optimize and optimizer_for_ort:
            ort_fusions.optimize_for_ort(epo.model)
        if filename:
            epo.save(filename, external_data=True)
        if save_ep:
            if isinstance(save_ep, tuple):
                save_ep = save_ep[0]
            torch.export.save(epo.exported_program, f"{save_ep}.pt2")
        return epo

    if exporter == "modelbuilder":
        from ..helpers import flatten_object, string_type
        from ..helpers.model_builder_helper import create_model_builder, save_model_builder

        assert filename, f"filename must be specified for exporter={exporter!r}"
        assert (
            not output_dynamic_shapes
        ), f"output_dynamic_shapes not supported for exporter={exporter!r}"
        assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
        assert not args, f"only kwargs can be defined with exporter={exporter!r}"
        # kwargs=None is reported by the assertion instead of a TypeError.
        given_names = list(kwargs or {})
        assert given_names == ["input_ids", "attention_mask", "past_key_values"], (
            f"Only a specified set of inputs is supported for exporter={exporter!r}, "
            f"but it is {given_names}"
        )
        assert optimizer_for_ort and optimize, (
            f"ModelBuilder only produces model optimized for onnxruntime but "
            f"optimizer_for_ort={optimizer_for_ort} and optimize={optimize}"
        )
        flat_inputs = flatten_object(kwargs, drop_keys=True)
        first = flat_inputs[0]
        first_float = [
            t
            for t in flat_inputs
            if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16}
        ]
        assert first_float, (
            f"Unable to find a float tensor in the inputs "
            f"{string_type(kwargs, with_shape=True)}"
        )
        # The precision comes from the first float input,
        # the execution provider from the device of the first input.
        onx = create_model_builder(
            mod.config,
            mod,
            precision=str(first_float[0].dtype).split(".")[-1],
            execution_provider="cuda" if first.is_cuda else "cpu",
            cache_dir=os.path.dirname(filename),
            **exporter_kwargs,
        )
        save_model_builder(onx, os.path.dirname(filename))
        return onx

    raise ValueError(f"Unknown exporter={exporter!r}")
class _WrapperToExportMethodToOnnx(torch.nn.Module):
    """
    Wraps an existing model to record the inputs one of its methods receives.
    Once enough calls were recorded, the method is exported into ONNX
    with :func:`to_onnx`, every call is forwarded to the wrapped method.
    This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
    """

    def __init__(
        self,
        mod: "torch.nn.Module",
        method_name: str = "forward",
        input_names: Optional[Sequence[str]] = None,
        target_opset: Optional[Union[int, Dict[str, int]]] = None,
        verbose: int = 0,
        filename: Optional[str] = None,
        output_names: Optional[List[str]] = None,
        output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        exporter: str = "onnx-dynamo",
        exporter_kwargs: Optional[Dict[str, Any]] = None,
        save_ep: Optional[str] = None,
        optimize: bool = True,
        optimizer_for_ort: bool = True,
        use_control_flow_dispatcher: bool = False,
        onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
        inline: bool = True,
        convert_after_n_calls: int = 2,
        patch_kwargs: Optional[Dict[str, Any]] = None,
        skip_kwargs_names: Optional[Set[str]] = None,
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ):
        super().__init__()
        self._model_to_call = mod
        self._method_name = method_name
        if method_name == "forward":
            self._method_call = self._model_to_call.forward
        else:
            self._method_call = getattr(mod, method_name)
        # every element is a pair (args, kwargs) recorded by forward
        self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
        self._convert_after_n_calls = convert_after_n_calls
        self._patch_kwargs = patch_kwargs
        # source of the generated forward, filled during the conversion
        self._method_src = None
        self.verbose = verbose
        self.skip_kwargs_names = skip_kwargs_names
        self.dynamic_shapes = dynamic_shapes
        # arguments forwarded as they are to to_onnx
        self._to_onnx_kwargs = {
            "input_names": input_names,
            "target_opset": target_opset,
            "verbose": verbose,
            "filename": filename,
            "output_names": output_names,
            "output_dynamic_shapes": output_dynamic_shapes,
            "exporter": exporter,
            "exporter_kwargs": exporter_kwargs,
            "save_ep": save_ep,
            "optimize": optimize,
            "optimizer_for_ort": optimizer_for_ort,
            "use_control_flow_dispatcher": use_control_flow_dispatcher,
            "onnx_plugs": onnx_plugs,
            "inline": inline,
        }
        self._export_done = False

    def __str__(self) -> str:
        return repr(self)

    def __repr__(self) -> str:
        wrapper_name = self.__class__.__name__
        model_name = self._model_to_call.__class__.__name__
        return f"{wrapper_name}({model_name}.{self._method_name})"

    def forward(self, *args, **kwargs):
        """
        Records the inputs until the export is done, triggers the export
        when enough calls were seen, then calls the wrapped method.
        """
        if self._export_done:
            return self._method_call(*args, **kwargs)

        # The skipped names are only dropped from the recorded inputs,
        # the wrapped method still receives every argument.
        if kwargs and self.skip_kwargs_names:
            recorded = {k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names}
        else:
            recorded = kwargs
        self._inputs.append((args, recorded))
        if self.verbose:
            print(
                f"[method_to_onnx] input[{len(self._inputs)-1}]: "
                f"{string_type(self._inputs[-1], with_shape=True)}"
            )
        if len(self._inputs) >= self._convert_after_n_calls:
            self._convert_method_to_onnx()
            self._inputs.clear()
            self._export_done = True
        return self._method_call(*args, **kwargs)

    def _convert_method_to_onnx(self):
        """
        Exports the wrapped method with :func:`to_onnx` using the last
        recorded inputs. The dynamic shapes are guessed from all the
        recorded inputs unless they were given to the constructor.
        """

        def make_method(owner):
            # Generates a function with the signature of the wrapped method
            # (annotations removed) passing every parameter by name.
            signature = inspect.signature(owner._method_call)
            stripped = inspect.Signature(
                [
                    p.replace(annotation=inspect._empty)
                    for p in signature.parameters.values()
                ],
                return_annotation=inspect._empty,
            )
            args = str(stripped)[1:-1]
            calls_args = ", ".join(f"{p}={p}" for p in stripped.parameters)
            src = textwrap.dedent(
                f"""
                def f(self, {args}):
                    return self._method_call({calls_args})
                """
            )
            owner._method_src = src
            ns = {}
            try:
                exec(src, ns)
            except NameError as e:
                raise NameError(f"Unable to compile due to {e}\n{src}") from e
            return ns["f"]

        class WrapWithExactSignature(torch.nn.Module):
            def __init__(self, parent):
                super().__init__()
                self._model_to_call = parent._model_to_call
                self._method_call = parent._method_call

            # self is the wrapper running _convert_method_to_onnx
            forward = make_method(self)

        compiled_model = WrapWithExactSignature(self)

        if self.dynamic_shapes is not None:
            a, kw = self._inputs[-1]
            nds = [self.dynamic_shapes]
        else:
            mi = ModelInputs(compiled_model, self._inputs)
            ds = mi.guess_dynamic_shapes()
            if self.verbose:
                print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
            a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)

        if self.verbose:
            print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
            print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
            print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")

        def run_export():
            to_onnx(
                compiled_model,
                args=a,
                kwargs=kw,
                dynamic_shapes=nds[-1],
                **self._to_onnx_kwargs,
            )

        if self._patch_kwargs is None:
            run_export()
            return

        from ..torch_export_patches import torch_export_patches

        with torch_export_patches(**self._patch_kwargs):
            run_export()
[docs]
def method_to_onnx(
    mod: "torch.nn.Module",
    method_name: str = "forward",
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    verbose: int = 0,
    filename: Optional[str] = None,
    output_names: Optional[List[str]] = None,
    output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    exporter: str = "onnx-dynamo",
    exporter_kwargs: Optional[Dict[str, Any]] = None,
    save_ep: Optional[str] = None,
    optimize: bool = True,
    optimizer_for_ort: bool = True,
    use_control_flow_dispatcher: bool = False,
    onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
    inline: bool = True,
    convert_after_n_calls: int = 2,
    patch_kwargs: Optional[Dict[str, Any]] = None,
    skip_kwargs_names: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Callable:
    """
    Exports one method of a module into ONNX.
    It returns a wrapper the user must call instead of the method.
    The wrapper records the inputs and forwards every call to the method.
    The conversion into ONNX is triggered by the call number
    *convert_after_n_calls*; unless *dynamic_shapes* is specified,
    the dynamic shapes are guessed from the recorded inputs.

    :param mod: module owning the method to export
    :param method_name: name of the method to export
    :param input_names: input names for the onnx model (optional)
    :param target_opset: opset to target, if not specified, each converter
        keeps its default value
    :param verbose: verbosity level
    :param filename: output filename
    :param output_names: to change the output of the onnx model
    :param output_dynamic_shapes: to overwrite the dynamic shapes names
    :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
    :param exporter_kwargs: additional parameters sent to the exporter
    :param save_ep: saves the exported program
    :param optimize: optimizes the model
    :param optimizer_for_ort: optimizes the model for onnxruntime
    :param use_control_flow_dispatcher: use the dispatcher created to support
        custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
    :param onnx_plugs: the code was modified to replace some parts with onnx translation
    :param inline: inline local functions
    :param convert_after_n_calls: converts the model after this number of calls
    :param patch_kwargs: arguments given to ``torch_export_patches``,
        the export runs without patches if None
    :param skip_kwargs_names: named arguments left out of the recorded inputs
    :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
    :return: the wrapper, a callable :class:`torch.nn.Module`

    See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
    """
    wrapper_settings = dict(
        mod=mod,
        method_name=method_name,
        input_names=input_names,
        target_opset=target_opset,
        verbose=verbose,
        filename=filename,
        output_names=output_names,
        output_dynamic_shapes=output_dynamic_shapes,
        exporter=exporter,
        exporter_kwargs=exporter_kwargs,
        save_ep=save_ep,
        optimize=optimize,
        optimizer_for_ort=optimizer_for_ort,
        use_control_flow_dispatcher=use_control_flow_dispatcher,
        onnx_plugs=onnx_plugs,
        inline=inline,
        convert_after_n_calls=convert_after_n_calls,
        patch_kwargs=patch_kwargs,
        skip_kwargs_names=skip_kwargs_names,
        dynamic_shapes=dynamic_shapes,
    )
    return _WrapperToExportMethodToOnnx(**wrapper_settings)