import inspect
import os
import pprint
import time
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from onnx import ModelProto, save_model
from onnx.defs import onnx_opset_version
from onnx.model_container import ModelContainer
from ..helpers import string_type
from ..xbuilder.graph_builder import GraphBuilder, OptimizationOptions, FunctionOptions
from .export_options import ExportOptions
def _retrieve(
    name: str,
    value: Any,
    weights: Dict[str, "torch.Tensor"],  # noqa: F821
    buffers: Dict[str, "torch.Tensor"],  # noqa: F821
    constants: Dict[str, "torch.Tensor"],  # noqa: F821
    mapping: Dict[str, Tuple[str, bool]],
    graph_builder: "GraphBuilder",  # noqa: F821
    debug: Optional[Any] = None,
    exc: bool = True,
) -> "torch.Tensor":  # noqa: F821
    """
    Sent to the :class:`DynamoInterpreter
    <experimental_experiment.torch_interpreter.interpreter.DynamoInterpreter>`.
    It retrieves the weights.

    :param name: name to retrieve
    :param value: value
    :param weights: mapping name, weights
    :param buffers: mapping name, buffer
    :param constants: mapping name, constants
    :param mapping: mapping name, (new_name, is_weight)
    :param graph_builder: graph builder, only used to enrich error messages
    :param debug: any debug information when an issue is raised
        (not used by this function)
    :param exc: raises an exception if not found
    :return: the retrieved tensor; ``value`` itself if ``name`` is not in
        ``mapping`` and ``value`` is a real tensor (not a FakeTensor);
        None if there are no weights, buffers or constants at all (the name
        is then an input) or if nothing was found and ``exc`` is False
    :raises RuntimeError: ``name`` is not in ``mapping`` and ``exc`` is True
    :raises ValueError: ``name`` maps to an unknown buffer or constant
        and ``exc`` is True
    """
    import torch

    self_prefix = "L__self___"
    const_prefix = "c_"

    def _as_tensor(tensor, role, kind, target):
        # Every stored weight, buffer or constant is expected to be a tensor.
        assert isinstance(tensor, torch.Tensor), (
            f"Unexpected type {type(tensor)} for {role} "
            f"{name!r} mapped to {kind} {target!r}."
        )
        return tensor

    if name not in mapping:
        # This is not a weight but a constant.
        if isinstance(value, torch.Tensor) and "FakeTensor" not in str(type(value)):
            return value
        if len(weights) == 0 and len(buffers) == 0 and len(constants) == 0:
            # It has to be an input.
            return None
        if exc:
            # getattr: value is not necessarily a tensor (it may be None),
            # building the message must not fail with an AttributeError.
            raise RuntimeError(
                f"Unable to find {name!r}."
                f"\nAvailable weights: {list(sorted(weights))}. "
                f"\nAvailable buffers: {list(sorted(buffers))}. "
                f"\nAvailable constants: {list(sorted(constants))}. "
                f"\nmapping={mapping}"
                f"{graph_builder.get_debug_msg() if graph_builder else ''}"
                f"\nvalue={getattr(value, 'dtype', type(value))}:"
                f"{getattr(value, 'shape', None)}\n{value}"
            )
        return None

    new_name, is_weight = mapping[name]
    if is_weight:
        # The mapping may keep the dynamo prefix the stored weights do not have.
        if (
            new_name not in weights
            and new_name.startswith(self_prefix)
            and new_name[len(self_prefix) :] in weights
        ):
            new_name = new_name[len(self_prefix) :]
        assert new_name in weights, (
            f"Unexpected name {name!r} for input "
            f"{name!r} mapped to weight {new_name!r}, "
            f"cannot be found in {', '.join(sorted(weights))}."
        )
        return _as_tensor(weights[new_name], "input", "weight", new_name)

    # buffers, constants or lifted tensors
    if new_name in buffers:
        return _as_tensor(buffers[new_name], "buffer", "buffer", new_name)
    if new_name in constants:
        return _as_tensor(constants[new_name], "constant", "constant", new_name)
    if new_name.startswith(self_prefix) and new_name[len(self_prefix) :] in buffers:
        new_name = new_name[len(self_prefix) :]
        return _as_tensor(buffers[new_name], "buffer", "buffer", new_name)
    if new_name.startswith(const_prefix) and new_name[len(const_prefix) :] in constants:
        new_name = new_name[len(const_prefix) :]
        return _as_tensor(constants[new_name], "constant", "constant", new_name)
    if exc:
        raise ValueError(
            f"Unexpected name {name!r} for input "
            f"{name!r} mapped to buffer or constant {new_name!r}, "
            f"cannot be found in {', '.join(sorted(buffers))} or "
            f"{', '.join(sorted(constants))}"
        )
    return None
[docs]
class SubModuleNaming:
    """
    Gives every submodule class a short name, unique across the model,
    so that it can be used to name the local functions built from submodules.

    A class always receives the same name; every submodule name seen for a
    class is recorded in ``_names``.

    :param mod: the model owning the submodules
    """

    def __init__(self, mod: "torch.nn.Module"):  # noqa: F821
        self.mod = mod
        # class -> short name, and short name -> class
        self._memo = {}
        # class -> list of submodule names sharing that class
        self._names = {}

    def __call__(self, name: str, submod: "torch.nn.Module") -> str:  # noqa: F821
        cls = type(submod)
        if cls in self._memo:
            self._names[cls].append(name)
            return self._memo[cls]

        rendered = str(cls)
        short_name = submod.__class__.__name__
        # A class defined inside a function gets the name "<locals>.X",
        # a class defined in __main__ gets the name "__main__.X".
        for scope in ("<locals>.", "__main__."):
            marker = f"{scope}{short_name}'>"
            if rendered.endswith(marker):
                short_name = marker[:-2]

        if short_name in self._memo:
            # Two different classes would receive the same short name.
            raise NotImplementedError(
                f"Unable to give a unique name to submodule {name!r}, "
                f"module={submod.__module__!r}, "
                f"type_name={short_name!r}, type(submod)={cls}, "
                f"already given:\n{pprint.pformat(self._memo)}"
                f"\nnames:\n{pprint.pformat(self._names)}"
            )
        self._memo[cls] = short_name
        self._memo[short_name] = cls
        self._names[cls] = [name]
        return short_name
[docs]
class ParameterNaming:
    """
    A class which maps parameter names in the original module to the names
    they have in the fx graph.
    The exported program and the original model may have different parameter names.

    :param mod: original module
    :param exported_program: exported program obtained from ``mod`` if available,
        its parameters are used when they differ from the module's parameters
    """

    @classmethod
    def _local_rendering(cls, diff):
        # One line per difference, comma separated, used in error messages.
        return "\n".join([", ".join(map(str, _)) for _ in diff])

    def __init__(
        self,
        mod: "torch.nn.Module",  # noqa: F821
        exported_program: Optional["torch.export.ExportedProgram"] = None,  # noqa: F821
    ):
        self.mod = mod
        # parameter name -> parameter, or underscored/lower-case name -> original name
        self._idmap = {}
        # module name -> submodule, or underscored name -> original name
        self._id_modules = {}
        self.display = {}
        # names __call__ was not able to map
        self._unable_to_map = set()
        use_mod = mod
        add_names = {}
        if exported_program is not None:
            mod_names = dict(mod.named_parameters())
            exp_names = dict(exported_program.named_parameters())
            if mod_names != exp_names:
                mod_ptr = {m.data_ptr() for m in mod_names.values()}
                exp_ptr = {m.data_ptr() for m in exp_names.values()}
                if mod_ptr != exp_ptr and None in exp_ptr:
                    # If names and pointers are different, it is more difficult
                    # to map parameters after exporting the model.
                    union = mod_names | exp_names
                    diff = []
                    for k in sorted(union):
                        if k in mod_names and k in exp_names:
                            continue
                        diff.append(
                            (
                                1 if k in mod_names else 0,
                                1 if k in exp_names else 0,
                                k,
                                string_type(mod_names.get(k, None), with_shape=True),
                                mod_names[k].data_ptr() if k in mod_names else None,
                                string_type(exp_names.get(k, None), with_shape=True),
                                exp_names[k].data_ptr() if k in exp_names else None,
                            )
                        )
                    assert all(_[1] == 1 for _ in diff), (
                        f"ExportedProgram and module do not have the same parameters\n"
                        f"{self._local_rendering(diff)}\n----\n{exported_program.graph}"
                    )
                    use_mod = exported_program
                else:
                    # It means we probably need to add an alias.
                    # We assume all used parameters in the exported program
                    # are part of the exported program.
                    # Maybe the conversion had to add a name.
                    for m in exp_names:
                        if m not in mod_names:
                            add_names[m] = exp_names[m]

        # parameter names
        def _chain(use_mod, add_names):
            yield from use_mod.named_parameters()
            yield from add_names.items()

        for name, p in _chain(use_mod, add_names):
            self._idmap[name] = p
            self.display[name] = name
            new_key = name.replace(".", "_")
            if new_key != name:
                # Checks the parameter map (the module map is still empty here):
                # a parameter "a_b" and a parameter "a.b" would otherwise be
                # silently mapped to the same name.
                assert (
                    new_key not in self._idmap
                ), f"Two parameters have similar names {name!r} mapped into {new_key!r}"
                self._idmap[new_key] = name
                self.display[new_key] = name
            else:
                self._idmap[new_key] = name

        # modules names
        for name, submod in mod.named_modules():
            if not name:
                continue
            self._id_modules[name] = submod
            new_key = name.replace(".", "_")
            if new_key != name:
                assert (
                    new_key not in self._id_modules
                ), f"Two modules have similar names {name!r} mapped into {new_key!r}"
                self._id_modules[new_key] = name
            else:
                self._id_modules[new_key] = name

        # final step: lower-case aliases
        updates = {}
        for k in self._idmap:
            kl = k.lower()
            if kl == k:
                continue
            assert (
                kl not in self._idmap
            ), f"Ambiguities for weights {kl!r} and {k!r} in {pprint.pformat(self._idmap)}"
            updates[kl] = self._idmap[k]
        self._idmap.update(updates)

    def __call__(
        self,
        name: str,
        value: "torch.nn.Parameter",  # noqa: F821
        node: "torch.fx.Node",  # noqa: F821
        prefix: Optional[str] = None,
        msg: Optional[Callable] = None,
    ) -> str:
        """
        Returns the name of a parameter as it appears in the original module.

        :param name: name of the parameter in the fx graph
        :param value: parameter value (not used)
        :param node: fx node, ``node.meta["from_node"]`` is used when present
        :param prefix: used to build ``f"{prefix}.{name}"`` when no other rule applies
        :param msg: callable returning additional information for error messages
        :return: the original name, the key itself when it points directly
            to a parameter, or ``name`` if it cannot be mapped (the name is then
            stored in ``_unable_to_map``)
        """
        # The class has no get_debug_msg method, the message only uses the name.
        assert isinstance(name, str), f"Unexpected type {type(name)} for name {name!r}"
        from_node = None
        if "from_node" in node.meta:
            from_node = node.meta["from_node"]
            assert (
                len(from_node) == 1
            ), f"Parameter {name!r} seems shared accross multiple objects{from_node}"
        key = None
        if name.startswith("p_") and name[2:] in self._idmap:
            key = name[2:]
        elif name.startswith("p_fn_"):
            key = name[len("p_fn_") :]
        elif from_node is not None:
            # Only valid for pytorch >= 2.6
            parent_name = from_node[0].node_info.target
            key = f"{parent_name}_{name}"
            if key.startswith("L__self___"):
                key = key[len("L__self___") :]
        elif name in self._idmap:
            key = name
        if key is None:
            assert prefix is not None, (
                f"Unable to find parameter {name!r} from node {node!r} "
                f"with a null prefix, "
                f"with node.meta={pprint.pformat(node.meta)}\n--display--\n"
                f"{pprint.pformat(self.display)}{msg() if msg else ''}"
            )
            key = f"{prefix}.{name}"
        if key not in self._idmap:
            # There may be unknown name if the module dynamically creates name.
            self._unable_to_map.add(name)
            return name
        res = self._idmap[key]
        if not isinstance(res, str):
            # The key directly points to a parameter.
            return key
        return res
def _make_builder_interpreter(
    mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
    kwargs: Optional[Dict[str, "torch.Tensor"]] = None,  # noqa: F821
    input_names: Optional[Sequence[str]] = None,
    target_opset: Union[int, Dict[str, int]] = 18,
    as_function: bool = False,
    optimization_options: Optional[OptimizationOptions] = None,
    verbose: int = 0,
    raise_list: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    tracing_mode: str = "symbolic",
    same_signature: bool = True,
    dispatcher: Optional["Dispatcher"] = None,  # noqa: F821
    export_options: Optional[Union[str, ExportOptions]] = None,
    optimize_submodules: bool = False,
    function_options: Optional[FunctionOptions] = None,
    local_domain: str = "local_functions",
    submodule_naming: Optional[Callable] = None,
    parameter_naming: Optional[Callable] = None,
    module_name: Optional[str] = None,
) -> Tuple[
    Union["torch.export.ExportedProgram", "torch.fx.GraphModule"],  # noqa: F821
    GraphBuilder,
    "DynamoInterpreter",  # noqa: F821
    Optional[List[bool]],
]:
    """
    Exports a torch model into ONNX using
    `dynamo export
    <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`_.

    :param mod: torch module
    :param args: input arguments
    :param kwargs: keyword attributes
    :param input_names: input names
    :param target_opset: targeted opset or targeted opsets as a dictionary
    :param as_function: export as a ModelProto or a FunctionProto
    :param optimization_options: optimization options
    :param verbose: verbosity level
    :param raise_list: the builder stops any time a name falls into that list,
        this is a debugging tool
    :param dynamic_shapes: see :epkg:`torch.export.export` or ``torch._dynamo.export``
    :param same_signature: same signature
    :param tracing_mode: tracing model
    :param dispatcher: see :class:`experimental_experiment.torch_interpreter.Dispatcher`
    :param export_options: options used to export the model,
        a default :class:`ExportOptions` is created if None
    :param optimize_submodules: optimizes submodules, this is done while building the model,
        and not at the end
    :param function_options: how to deal with local functions
    :param local_domain: domain name to use for local functions if not specified
    :param submodule_naming: a function which returns a submodule name in the onnx graph
    :param parameter_naming: a function which returns a parameter name in the onnx graph
    :param module_name: name of the module, to help retrieve the parameter name
    :return: exported program (or the graph module when the model was not exported),
        graph builder, interpreter, mask_outputs (one boolean per output telling
        if it is a user output, None when the signature is unknown)
    """

    def _get(x, att=None):
        # Renders an object (or one of its attributes) as a list for error messages.
        if att is None:
            if isinstance(x, dict):
                return list(sorted(x))
            if isinstance(x, list):
                return x
            return [x]
        if hasattr(x, att):
            return _get(getattr(x, att))
        return ["?"]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        import torch
        import torch.export

    if export_options is None:
        export_options = ExportOptions()

    mask_outputs = None
    is_graph_module = isinstance(mod, torch.fx.GraphModule)
    if is_graph_module or (
        isinstance(mod, torch.nn.Module) and mod.__class__.__name__ == "InterpreterModule"
    ):
        # The model is already a graph (an InterpreterModule comes from
        # the unflatten function), no export is needed.
        if verbose > 0:
            if is_graph_module:
                print(f"[_make_builder_interpreter] use existing {type(mod)}")
            else:
                print(f"[_make_builder_interpreter] use existing submodule {type(mod)}")
        graph_module = mod
        weights = dict(graph_module.named_parameters())
        buffers = dict(graph_module.named_buffers())
        constants = mod.state_dict()
        mapping = {}
        if os.environ.get("PRINT_GRAPH_MODULE", "0") in (1, "1"):
            print("-- GIVEN GRAPH MODULE")
            print(graph_module.graph)
        exported_program = None
    else:
        if verbose > 0:
            print(f"[_make_builder_interpreter] export_options={export_options!r}")
            print(f"[_make_builder_interpreter] input args={string_type(args)}")
            print(f"[_make_builder_interpreter] input kwargs={string_type(kwargs)}")
            print(f"[_make_builder_interpreter] dynamic_shapes={dynamic_shapes}")
            print(
                f"[_make_builder_interpreter] same_signature={same_signature}, "
                f"tracing_mode={tracing_mode}"
            )
        # If this step fails, try bypass_export_some_errors.
        exported_program = export_options.export(
            mod,
            args if isinstance(args, tuple) else (tuple() if args is None else args),
            kwargs,
            tracing_mode=tracing_mode,
            dynamic_shapes=dynamic_shapes,
            same_signature=same_signature,
            input_names=input_names,
            verbose=verbose,
        )
        if os.environ.get("PRINT_EXPORTED_PROGRAM", "0") in (1, "1"):
            print("-- EXPORTED PROGRAM")
            print(exported_program)
        graph_module = (
            exported_program
            if isinstance(exported_program, torch.fx.GraphModule)
            else exported_program.graph_module
        )
        if os.environ.get("PRINT_GRAPH_MODULE", "0") in (1, "1"):
            print("-- EXPORTED GRAPH MODULE")
            print(graph_module.graph)
        # The export may return a GraphModule without these methods.
        try:
            weights = dict(exported_program.named_parameters())
        except AttributeError:
            weights = dict(mod.named_parameters())
        try:
            buffers = dict(exported_program.named_buffers())
        except AttributeError:
            buffers = dict(mod.named_buffers())
        if hasattr(exported_program, "tensor_constants"):
            constants = exported_program.tensor_constants or {}
        else:
            # A bug may appear later.
            constants = {}
        if hasattr(exported_program, "graph_signature"):
            sig_mismatch_constants = set(k.replace(".", "_") for k in constants)
            signature = exported_program.graph_signature
            mapping = {}
            for k, v in signature.inputs_to_parameters.items():
                mapping[k] = v, True
            for k, v in signature.inputs_to_buffers.items():
                mapping[k] = v, False
            for k, v in signature.inputs_to_lifted_tensor_constants.items():
                mapping[k] = v, False
                assert (
                    k in constants
                    or k[2:] in constants
                    or k[2:] in sig_mismatch_constants
                    or k[2:].replace("getattr_l__self", "getattr_L__self") in constants
                ), (
                    f"export_options={export_options!r}"
                    f"A constant {k!r}, k[2:]={k[2:]!r}, v={v!r} was detected "
                    f"in the signature was not retrieved from the model. "
                    f"k in constants={k in constants}, "
                    f"k[2:] in constants={k[2:] in constants}, "
                    f"type(constants)={type(constants)}, "
                    f"\nlist(constants)={pprint.pformat(list(sorted(constants)))}"
                    f"\nexported_mod.tensor_constants="
                    f"{pprint.pformat(_get(exported_program, 'tensor_constants'))}"
                    f"\nexported_mod._constants="
                    f"{pprint.pformat(_get(exported_program, '_constants'))}"
                    f"\nsig_mismatch_constants="
                    f"{pprint.pformat(_get(sig_mismatch_constants))}"
                    f"\ndir(export_mod)={dir(exported_program)}"
                    f"\ndir(mod)={dir(mod)}"
                )
            mask_outputs = [
                spec.kind == torch.export.graph_signature.OutputKind.USER_OUTPUT
                for spec in exported_program.graph_signature.output_specs
            ]
        else:
            # No signature: every known name maps to itself.
            mapping = {}
            for k in weights:
                mapping[k] = k, True
            for k in buffers:
                mapping[k] = k, False
            for k in constants:
                mapping[k] = k, False

    builder = GraphBuilder(
        target_opset,
        input_names=input_names,
        as_function=as_function,
        optimization_options=optimization_options,
        args=args,
        kwargs=kwargs,
        verbose=verbose,
        raise_list=raise_list,
        dynamic_shapes=dynamic_shapes,
        local_domain=local_domain,
        signature=inspect.signature(mod.forward),
        check_empty_source=True,
        graph_module=graph_module,
    )

    def retrieve(
        name,
        value,
        debug=None,
        weights=weights,
        buffers=buffers,
        mapping=mapping,
        constants=constants,
        builder=builder,
        exc=True,
    ):
        return _retrieve(
            name, value, weights, buffers, constants, mapping, builder, debug, exc=exc
        )

    from .interpreter import DynamoInterpreter

    if not submodule_naming:
        submodule_naming = SubModuleNaming(mod)
    if not parameter_naming:
        parameter_naming = ParameterNaming(mod, exported_program=exported_program)
    interpreter = DynamoInterpreter(
        builder,
        retrieve,
        dispatcher=dispatcher,
        example_inputs=args,
        export_options=export_options,
        optimize_submodules=optimize_submodules,
        function_options=function_options,
        submodule_naming=submodule_naming,
        parameter_naming=parameter_naming,
        module_name=module_name,
        default_values=_default_values_from_sig(mod),
    )
    attr = getattr(export_options, "_last_working", None)
    if attr:
        # Tweak to retrieve that information
        interpreter._working_export_options = attr
    return (exported_program or graph_module), builder, interpreter, mask_outputs
def _model_signature(
    model: Union["torch.nn.Module", Callable],  # noqa: F821
) -> inspect.Signature:
    """
    Returns the signature of ``model.forward`` for a :class:`torch.nn.Module`,
    the signature of ``model`` itself for any other callable.
    """
    import torch

    if isinstance(model, torch.nn.Module):
        return inspect.signature(model.forward)
    return inspect.signature(model)
def _default_values_from_sig(mod: "torch.nn.Module") -> Dict[str, Any]:  # noqa: F821
    """
    Returns the default values of the parameters of the model signature
    (``mod.forward`` for a module), skipping parameters without a default
    and parameters whose default is None.
    """
    sig = _model_signature(mod)
    # Identity checks: ``default in (None, empty)`` calls ``==``, which is
    # ambiguous (and raises) for array-like defaults such as numpy arrays.
    return {
        name: p.default
        for name, p in sig.parameters.items()
        if p.default is not None and p.default is not inspect.Parameter.empty
    }
def _replacements_dynamic_shapes(
    mod: Any,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dict_dynamic_shapes: Optional[Dict[str, Any]] = None,
    input_names: Optional[List[str]] = None,
    verbose: int = 0,
):
    """
    Renames the keys of ``dict_dynamic_shapes`` after the parameter names
    of the model signature.

    ``input_names`` are matched in order with the positional parameters and
    the keyword-only parameters without a default (or with a None default).
    If the signature has a ``*args`` parameter, all the dynamic shapes are
    packed into a tuple stored under that parameter name.

    :param mod: model or callable
    :param args: positional arguments
    :param kwargs: keyword arguments (only displayed)
    :param dict_dynamic_shapes: dynamic shapes indexed by input names
    :param input_names: names used as keys in ``dict_dynamic_shapes``
    :param verbose: verbosity
    :return: dynamic shapes indexed by the parameter names of the signature
    """
    assert dict_dynamic_shapes is not None, "dict_dynamic_shapes is missing"
    if verbose > 2:
        print(f"[_replacements_dynamic_shapes] type(mod)={type(mod)}")
        print(f"[_replacements_dynamic_shapes] args={string_type(args)}")
        print(f"[_replacements_dynamic_shapes] kwargs={string_type(kwargs)}")
        print(f"[_replacements_dynamic_shapes] dict_dynamic_shapes={dict_dynamic_shapes}")
        print(f"[_replacements_dynamic_shapes] input_names={input_names}")
    sig = _model_signature(mod)
    true_input_names = []
    # (name of *args, number of positional arguments, position in true_input_names)
    has_args = None
    # decreased by len(args) once *args is met, the loop stops when it reaches 0
    n_args = None if input_names is None else len(input_names)
    for name, p in sig.parameters.items():
        if verbose > 3:
            print(
                f"[_replacements_dynamic_shapes] -- {name}: {p.kind} - "
                f"has_args={has_args} - n_args={n_args}"
            )
        if n_args is not None and n_args <= 0:
            break
        if p.kind in (
            p.VAR_POSITIONAL,
            p.VAR_KEYWORD,
            p.POSITIONAL_OR_KEYWORD,
            p.POSITIONAL_ONLY,
        ):
            assert not has_args, (
                f"has_args={has_args} is already specified, "
                f"input_names={input_names}, dynamic_shapes="
                f"{dict_dynamic_shapes}"
            )
            assert input_names is None or len(input_names) == len(
                args
            ), f"Mismatch number between args={string_type(args)}, input_names={input_names}"
            true_input_names.append(p.name)
            if p.kind == p.VAR_POSITIONAL:
                if verbose > 3:
                    print(f"[_replacements_dynamic_shapes] + {p.name}, has_args={has_args}")
                has_args = (p.name, len(args), len(true_input_names))
                if n_args is not None:
                    n_args -= len(args)
                    if verbose > 3:
                        print(f"[_replacements_dynamic_shapes] + {p.name}, n_args={n_args}")
        elif p.default is None or p.default is inspect.Parameter.empty:
            # Identity checks: ``in (None, empty)`` calls ``==`` which is
            # ambiguous for array-like defaults.
            true_input_names.append(name)
            if verbose > 3:
                print(f"[_replacements_dynamic_shapes] + {p.name}")

    if has_args is None:
        replacements = {} if input_names is None else dict(zip(input_names, true_input_names))
        new_dynamic_shapes = {replacements.get(k, k): v for k, v in dict_dynamic_shapes.items()}
        if verbose > 2:
            print(f"[_replacements_dynamic_shapes] 1> input_names={input_names}")
            print(f"[_replacements_dynamic_shapes] 1> true_input_names={true_input_names}")
            print(f"[_replacements_dynamic_shapes] 1> new_dynamic_shapes={new_dynamic_shapes}")
        return new_dynamic_shapes

    # has_args is a tuple here: the model takes *args.
    assert input_names is not None, (
        f"Not implemented for has_args={has_args}, dynamic_shapes={dict_dynamic_shapes}"
        f", input_names={input_names}"
    )
    assert len(dict_dynamic_shapes) == len(input_names) == has_args[1], (
        f"Mismatch for has_args={has_args}, dynamic_shapes={dict_dynamic_shapes}"
        f", input_names={input_names}"
    )
    new_dynamic_shapes = {has_args[0]: tuple(dict_dynamic_shapes[n] for n in input_names)}
    if verbose > 2:
        print(f"[_replacements_dynamic_shapes] 2> has_args={has_args}")
        print(f"[_replacements_dynamic_shapes] 2> input_names={input_names}")
        print(
            f"[_replacements_dynamic_shapes] 2> dict_dynamic_shapes={dict_dynamic_shapes}"
        )
        print(f"[_replacements_dynamic_shapes] 2> new_dynamic_shapes={new_dynamic_shapes}")
    return new_dynamic_shapes
[docs]
def is_wrapped(model: Any, dynamic_shapes: Optional[Any] = None) -> bool:
    """
    Tells if a model is wrapped.

    Returns False unless ``dynamic_shapes`` holds exactly one entry and that
    entry is a tuple (presumably the dynamic shapes of a single ``*args``
    input). That case cannot be decided yet and raises an AssertionError.

    :param model: model the dynamic shapes belong to
    :param dynamic_shapes: dynamic shapes (tuple, list or dict), may be None
    :return: False
    :raises AssertionError: if the answer cannot be determined
    """
    # None or empty dynamic shapes used to crash on len(None).
    if not dynamic_shapes or len(dynamic_shapes) != 1:
        return False
    # A dict is indexed by input names, not by position.
    first = (
        next(iter(dynamic_shapes.values()))
        if isinstance(dynamic_shapes, dict)
        else dynamic_shapes[0]
    )
    if not isinstance(first, tuple):
        return False
    raise AssertionError(f"Unable to tell for type {type(model)}")
[docs]
def to_onnx(
    mod: Union["torch.nn.Module", "torch.fx.GraphModule"],  # noqa: F821
    args: Optional[Sequence["torch.Tensor"]] = None,  # noqa: F821
    kwargs: Optional[Dict[str, "torch.Tensor"]] = None,  # noqa: F821
    input_names: Optional[Sequence[str]] = None,
    target_opset: Optional[Union[int, Dict[str, int]]] = None,
    as_function: bool = False,
    options: Optional[OptimizationOptions] = None,
    verbose: int = 0,
    return_builder: bool = False,
    raise_list: Optional[Set[str]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    optimize: bool = True,
    dispatcher: Optional["Dispatcher"] = None,  # noqa: F821
    large_model: bool = False,
    external_threshold: int = 1024,
    export_options: Optional[Union[str, ExportOptions]] = None,
    return_optimize_report: bool = False,
    filename: Optional[str] = None,
    inline: bool = False,
    export_modules_as_functions: Union[
        bool, Set[type["torch.nn.Module"]]  # noqa: F821
    ] = False,
    function_options: Optional[FunctionOptions] = None,
) -> Union[
    Union[ModelProto, ModelContainer],
    Tuple[Union[ModelProto, ModelContainer], GraphBuilder],
    Tuple[Union[ModelProto, ModelContainer], Dict[str, Any]],
    Tuple[Union[ModelProto, ModelContainer], GraphBuilder, Dict[str, Any]],
]:
    """
    Exports a torch model into ONNX using
    `dynamo export
    <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html>`_.

    :param mod: torch module
    :param args: input arguments
    :param kwargs: keyword attributes
    :param input_names: input names
    :param target_opset: targeted opset or targeted opsets as a dictionary,
        ``min(18, onnx_opset_version() - 1)`` if None
    :param as_function: export as a ModelProto or a FunctionProto
    :param options: optimization options
    :param verbose: verbosity level
    :param return_builder: returns the builder as well
    :param raise_list: the builder stops any time a name falls into that list,
        this is a debugging tool
    :param dynamic_shapes: see :epkg:`torch.export.export`
    :param optimize: optimize the model before exporting into onnx
    :param dispatcher: see :class:`experimental_experiment.torch_interpreter.Dispatcher`
    :param large_model: if True returns a :class:`onnx.model_container.ModelContainer`,
        it lets the user to decide later if the weights should be part of the model
        or saved as external weights
    :param external_threshold: if large_model is True, every tensor above this limit
        is stored as external
    :param return_optimize_report: returns statistics on the optimization as well
    :param filename: if specified, stores the model into that file
    :param inline: inline the model before converting to onnx, this is done before
        any optimization takes place
    :param export_options: to apply different options before getting the exported program
    :param export_modules_as_functions: export submodules as local functions,
        this parameter can be filled with a set of classes to preserve,
        all the others will be exported as usual
    :param function_options: to specify what to do with the initializers in local functions,
        add them as constants or inputs
    :return: onnx model (a ModelContainer if large_model is True), followed by
        the builder if return_builder is True, followed by a dictionary of
        statistics (timings, optimization report) if return_optimize_report is True

    If environment variable ``PRINT_GRAPH_MODULE`` is set to one,
    information about the graph module is printed out.
    Environment variable ``ONNXVERBOSE=1`` can be used to
    increase verbosity in this function.
    Environment variable ``ONNX_BUILDER_PROGRESS=1`` can be used to show
    a progress bar on big models.
    """
    if target_opset is None:
        target_opset = min(18, onnx_opset_version() - 1)
    if options is None:
        options = OptimizationOptions()
    begin = time.perf_counter()
    verbose = max(verbose, int(os.environ.get("ONNXVERBOSE", verbose)))
    if verbose:
        print(f"[to_onnx] build the graph module from {type(mod)}, type(args)={type(args)}")
        if input_names:
            print(f"[to_onnx] build the graph module with input_names={input_names}")
        if dynamic_shapes:
            print(f"[to_onnx] dynamic_shapes={dynamic_shapes}")

    graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
        mod=mod,
        args=args,
        kwargs=kwargs,
        input_names=input_names,
        target_opset=target_opset,
        as_function=as_function,
        optimization_options=options,
        verbose=verbose,
        raise_list=raise_list,
        dynamic_shapes=dynamic_shapes,
        dispatcher=dispatcher,
        export_options=export_options,
        optimize_submodules=optimize,
        function_options=function_options,
        module_name="",
    )

    add_stats = {}
    t = time.perf_counter()
    add_stats["time_export_graph_module"] = t - begin
    winning_opt = getattr(interpreter, "_working_export_options", None)
    if winning_opt:
        add_stats["onnx_export_options_strict"] = 1 if winning_opt.strict else 0
        if winning_opt.decomposition_table:
            add_stats["onnx_export_options_decomp"] = winning_opt.decomposition_table
        if winning_opt.tracing:
            add_stats["onnx_export_options_tracing"] = 1
        if winning_opt.jit:
            add_stats["onnx_export_options_jit"] = 1
        if winning_opt.dynamo:
            add_stats["onnx_export_options_dynamo"] = 1
    if verbose:
        print(f"[to_onnx] graph module done in {t - begin} s")

    if export_modules_as_functions:
        import torch.export

        assert isinstance(
            graph_module, torch.export.ExportedProgram
        ), f"Unexpected type {type(graph_module)} for graph_module"
        if export_modules_as_functions is True:
            export_modules_as_functions = set(type(m) for m in mod.modules())
        interpreter.register_named_modules(
            None, export_modules_as_functions, dict(mod.named_modules())
        )
        if verbose > 1:
            print(
                f"[to_onnx] unflatten the graph_module, "
                f"preserve {sorted(c.__name__ for c in export_modules_as_functions)}"
            )
        a = time.perf_counter()
        graph_module = torch.export.unflatten(graph_module)
        # t is refreshed after unflattening: the unflatten duration is measured
        # from a (it was t - a, always negative, t being taken before a),
        # and the builder timing below no longer includes it.
        t = time.perf_counter()
        add_stats["time_export_unflatten"] = t - a

    if verbose > 4:
        print(f"[to_onnx] -- fx graph --\n{graph_module.graph}")
    if verbose:
        print("[to_onnx] start creating the onnx nodes")
        print(f"[to_onnx] interpreter.function_options={interpreter.function_options!r}")

    begin = t
    builder.process(graph_module, interpreter)
    t = time.perf_counter()
    add_stats["time_export_builder_process"] = t - begin
    if verbose:
        print(f"[to_onnx] {len(builder.nodes)} onnx nodes done in {t - begin} s")
        print(
            f"[to_onnx] start conversion to onnx (before optimization) "
            f"mask_outputs={mask_outputs}"
        )

    begin = t
    onx, stats = builder.to_onnx(
        optimize=optimize,
        large_model=large_model,
        external_threshold=external_threshold,
        return_optimize_report=True,
        inline=inline,
        function_options=function_options,
        mask_outputs=mask_outputs,
    )
    all_stats = dict(builder=builder.statistics_)
    if stats:
        add_stats["optimization"] = stats
    t = time.perf_counter()
    add_stats["time_export_to_onnx"] = t - begin
    if verbose:
        proto = onx if isinstance(onx, ModelProto) else onx.model_proto
        print(
            f"[to_onnx] to_onnx done in {t - begin}s "
            f"and {len(proto.graph.node)} nodes, "
            f"{len(proto.graph.initializer)} initializers, "
            f"{len(proto.graph.input)} inputs, "
            f"{len(proto.graph.output)} outputs"
        )
        if verbose >= 10:
            print(builder.get_debug_msg())

    if filename:
        if isinstance(onx, ModelProto):
            save_model(onx, filename)
        else:
            onx.save(filename, all_tensors_to_one_file=True)
    all_stats.update(add_stats)
    if return_builder:
        return (onx, builder, all_stats) if return_optimize_report else (onx, builder)
    return (onx, all_stats) if return_optimize_report else onx