import pprint
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from ..helpers import string_type, string_sig
from ._doc_ import TorchOpOverload
class ExportOptions:
    """
    Gathers all the options defining the way to export a model into a graph
    (not onnx).

    :param strict: strict export or not
    :param fallback: tries several export options in sequence until one works,
        see :meth:`get_fallback_options`
    :param jit: use jit to get a graph, then convert it into a fx graph
    :param decomposition_table: decomposition table, or a string such as ``"default"``
        to use the default decomposition table returned by
        :func:`get_decomposition_table
        <experimental_experiment.torch_dynamo.get_decomposition_table>`
    :param strategy: overwrites the previous parameters with just one value,
        it must be one of the keys of ``ExportOptions._allowed``
    :param dynamo: to use ``torch._dynamo.export`` instead of :func:`torch.export.export`

    ``jit`` and ``dynamo`` cannot be combined and both require ``strict=True``.

    The fallback strategy tries the following in order:

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.torch_interpreter import ExportOptions

        pprint.pprint(ExportOptions().get_fallback_options())
        pprint.pprint(
            ExportOptions(decomposition_table="default").get_fallback_options()
        )
    """

    # Accepted values for ``strategy`` and the attributes each one overrides.
    _allowed = {
        None: {},
        "none": {},
        "strict": {"strict": True},
        "nostrict": {"strict": False},
        "jit": {"jit": True},
        "fallback": {"fallback": True},
        "fallback-default": {"fallback": True, "decomposition_table": "default"},
        "default": {"decomposition_table": "default"},
    }

    def __init__(
        self,
        strict: bool = True,
        fallback: bool = False,
        jit: bool = False,
        decomposition_table: Optional[
            Union[str, Dict["TorchOpOverload", Callable[..., Any]]]  # noqa: F821
        ] = None,
        strategy: Optional[str] = None,
        dynamo: bool = False,
    ):
        self.strict = strict
        self.fallback = fallback
        # "none" is an alias for "no decomposition table".
        self.decomposition_table = (
            None if decomposition_table in ("none", None) else decomposition_table
        )
        self.dynamo = dynamo
        self.strategy = strategy
        self.jit = jit
        if strategy is not None:
            assert strategy in self._allowed, (
                f"Unexpected value for strategy={strategy!r}, "
                f"it should be in {sorted(k for k in self._allowed if k is not None)}"
            )
            # The strategy takes precedence over the explicit arguments.
            for k, v in self._allowed[strategy].items():
                setattr(self, k, v)
        assert (
            not self.dynamo or not self.jit
        ), "jit and dynamo cannot be true at the same time"
        # jit and dynamo are only accepted together with strict=True.
        assert self.strict or not self.jit, "jit cannot be used with strict=False"
        assert self.strict or not self.dynamo, "dynamo cannot be used with strict=False"

    def __repr__(self) -> str:
        return string_sig(self)

    def get_decomposition_table(
        self,
    ) -> Optional[Dict["TorchOpOverload", Callable[..., Any]]]:  # noqa: F821
        """
        Returns the decomposition table.

        :return: None if there is no table, the table returned by
            ``get_decomposition_table_by_name`` if it was given as a string,
            the dictionary itself otherwise
        """
        if self.decomposition_table is None:
            return None
        if isinstance(self.decomposition_table, str):
            from ..torch_dynamo import get_decomposition_table_by_name

            return get_decomposition_table_by_name(self.decomposition_table)
        assert isinstance(
            self.decomposition_table, dict
        ), f"Unexpected type {type(self.decomposition_table)} for decomposition_table"
        return self.decomposition_table

    def get_fallback_options(self, kind: Optional[str] = None) -> List["ExportOptions"]:
        """
        Returns the fallback scenario, the list of options to try in order.

        :param kind: None, ``"none"``, ``"fallback"``, ``"fallback-default"``,
            ``"default"``, ``"strict"``, ``"nostrict"`` or ``"jit"``
        :return: list of :class:`ExportOptions`
        :raises AssertionError: for any other value of ``kind``
        """
        if kind in (None, "none", "fallback", "fallback-default", "default"):
            return [
                ExportOptions(strict=True, decomposition_table=self.decomposition_table),
                ExportOptions(strict=False, decomposition_table=self.decomposition_table),
                ExportOptions(dynamo=True, decomposition_table=self.decomposition_table),
                ExportOptions(strict=True),
                ExportOptions(strict=False),
                ExportOptions(jit=True, decomposition_table=self.decomposition_table),
            ]
        if kind == "strict":
            return [ExportOptions(strict=True), ExportOptions(strict=False)]
        if kind == "nostrict":
            return [ExportOptions(strict=False), ExportOptions(strict=True)]
        # Exact match: ``kind in ("jit")`` would be a substring test on "jit".
        if kind == "jit":
            return [
                ExportOptions(strict=True),
                ExportOptions(jit=True, decomposition_table=self.decomposition_table),
            ]
        raise AssertionError(f"Unable to return fallback strategy with kind={kind!r}")

    def export(
        self,
        mod: Any,
        args: Optional[Tuple[Any, ...]],
        kwargs: Optional[Dict[str, Any]],
        tracing_mode: bool,
        dynamic_shapes: Optional[Dict],
        same_signature: bool,
        input_names: Optional[List[str]] = None,
        exc: bool = True,
        verbose: int = 0,
    ):
        """
        Exports the model into an exported program.

        :param mod: model to export
        :param args: positional inputs
        :param kwargs: named inputs
        :param tracing_mode: forwarded to ``torch._dynamo.export`` (dynamo only)
        :param dynamic_shapes: dynamic shapes, None means static shapes
        :param same_signature: forwarded to ``torch._dynamo.export`` (dynamo only)
        :param input_names: input names, only used in error messages
        :param exc: in fallback mode, True raises a RuntimeError when every
            option failed and False returns None; otherwise, True lets
            :func:`torch.export.export` raise as it is and False tries a
            workaround first and wraps user errors into a RuntimeError
        :param verbose: verbosity level
        :return: the exported program (decomposed if a decomposition table is
            set); with ``dynamo=True``, the object returned by
            ``torch._dynamo.export``
        """
        import torch

        if self.fallback or self.strategy == "fallback":
            if verbose:
                print("[ExportOptions.export] fallback")
            tries = self.get_fallback_options(self.strategy)
            excs = []
            for ion, opt in enumerate(tries):
                if verbose:
                    print(f"[ExportOptions.export] tries {ion+1}/{len(tries)}: {opt}")
                try:
                    return opt.export(
                        mod,
                        args,
                        kwargs,
                        tracing_mode=tracing_mode,
                        dynamic_shapes=dynamic_shapes,
                        same_signature=same_signature,
                        input_names=input_names,
                        exc=False,
                        verbose=verbose,
                    )
                except Exception as e:
                    # Every failure is kept and reported if no option works.
                    excs.append(e)
            if exc:
                raise RuntimeError(
                    f"None of the following options {tries} worked, "
                    f"args={string_type(args)}, kwargs={string_type(kwargs)}, "
                    f"exception=\n-----\n{pprint.pformat(excs)}"
                )
            return None

        if verbose:
            print(
                f"[ExportOptions.export] {self!r} - torch.export.export {type(mod).__name__!r}"
            )
        begin = time.perf_counter()

        if self.dynamo:
            if verbose:
                print("[ExportOptions.export] torch._dynamo.export")
            # The decomposition table is handed to torch._dynamo.export,
            # there is no extra decomposition step afterwards.
            res = torch._dynamo.export(
                mod,
                aten_graph=True,
                tracing_mode=tracing_mode,
                dynamic_shapes=dynamic_shapes,
                same_signature=same_signature,
                decomposition_table=self.get_decomposition_table(),
                assume_static_by_default=dynamic_shapes is None,
            )(*(args or tuple()), **(kwargs or {}))
            if verbose:
                print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
            return res

        if self.jit:
            if verbose:
                print("[ExportOptions.export] torch.jit.trace")
            from torch._export.converter import TS2EPConverter

            jit_model = torch.jit.trace(
                mod, example_inputs=args, check_trace=False, strict=False
            )
            res = TS2EPConverter(jit_model, args, kwargs).convert()
            dec = apply_decompositions(res, self.decomposition_table)
            if verbose:
                print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
            return dec

        if verbose:
            print("[ExportOptions.export] torch.export.export")
            print(f"[ExportOptions.export] dynamic_shapes={dynamic_shapes}")
            print(f"[ExportOptions.export] strict={self.strict}")
            print(f"[ExportOptions.export] args={string_type(args)}")
            print(f"[ExportOptions.export] kwargs={string_type(kwargs)}")
            print(f"[ExportOptions.export] verbose={verbose}")
        if exc:
            exported_program = torch.export.export(
                mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=self.strict
            )
        else:
            try:
                exported_program = torch.export.export(
                    mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=self.strict
                )
            except torch._export.verifier.SpecViolationError:
                # see https://github.com/pytorch/pytorch/issues/128394
                if verbose:
                    print("[ExportOptions.export] torch.export._trace._export")
                exported_program = torch.export._trace._export(
                    mod,
                    args,
                    kwargs,
                    dynamic_shapes=dynamic_shapes,
                    pre_dispatch=False,
                    strict=self.strict,
                )
            except torch._dynamo.exc.UserError as e:
                eee = None
                if verbose:
                    print("[ExportOptions.export] torch.export.export")
                # Second attempt without dynamic shapes, only to enrich the
                # error message: whatever it raises must not hide ``e``.
                try:
                    exported_program = torch.export.export(
                        mod, args, kwargs, strict=self.strict
                    ).graph
                except Exception as ee:
                    exported_program = None
                    eee = ee
                raise RuntimeError(
                    f"Unable to convert model {type(mod)}, "
                    f"type(args)={type(args)}, type(args[0])="
                    f"{type(args[0]) if isinstance(args, tuple) and args else '?'}, "
                    f"strict={self.strict}, input_names={input_names}\n--\n"
                    f"dynamic_shapes={dynamic_shapes}\n--\ne={e}\n--\neee={eee}"
                    f"\n---exported-program---\n{exported_program}"
                ) from e

        if exported_program is not None and self.decomposition_table:
            exported_program = apply_decompositions(
                exported_program, self.decomposition_table
            )
        if verbose:
            print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
        return exported_program
def apply_decompositions(
    exported_program: "torch.export.ExportedProgram", decomposition_table  # noqa: F821
) -> "torch.export.ExportedProgram":  # noqa: F821
    """
    Runs the decompositions on an exported program.

    :param exported_program: program to decompose
    :param decomposition_table: ``None`` (nothing is done), ``"all"``
        (``run_decompositions`` is called without an explicit table), any
        other string (resolved with ``get_decomposition_table_by_name``) or a
        table given as is to ``run_decompositions``
    :return: the decomposed program, or the input itself when no table applies
    """
    run_default = decomposition_table == "all"
    table = decomposition_table
    if not run_default and isinstance(table, str):
        from ..torch_dynamo import get_decomposition_table_by_name

        table = get_decomposition_table_by_name(table)
    if not run_default and table is None:
        return exported_program
    # transpose -> view pairs receive a contiguous node before decomposing
    patched = insert_contiguous_between_transpose_and_view(exported_program)
    if run_default:
        return patched.run_decompositions()
    return patched.run_decompositions(table)
def _is_transpose_call(node) -> bool:
    """Tells if a fx node calls ``transpose`` (method) or ``aten::transpose.int``."""
    if node.op == "call_method":
        return node.target == "transpose"
    return (
        node.op == "call_function"
        and hasattr(node.target, "name")
        and node.target.name() == "aten::transpose.int"
    )


def _is_view_call(node) -> bool:
    """Tells if a fx node calls ``view`` (method) or ``aten::view``."""
    if node.op == "call_method":
        return node.target == "view"
    # The attribute check must be done on this node's own target: targets such
    # as plain Python functions have no ``name`` method.
    return (
        node.op == "call_function"
        and hasattr(node.target, "name")
        and node.target.name() == "aten::view"
    )


def insert_contiguous_between_transpose_and_view(
    exported_program: "torch.export.ExportedProgram",  # noqa: F821
) -> "torch.export.ExportedProgram":  # noqa: F821
    """
    Modifies the module inplace to insert a node 'contiguous' between a node
    'transpose' followed by a node 'view'.
    The modification takes place inplace: when a transpose has at least one
    view among its users, every user of the transpose is rewired to consume
    the new 'contiguous' node instead.
    See issue https://github.com/pytorch/pytorch/issues/136543.

    :param exported_program: any object exposing ``graph_module`` (a
        ``torch.fx.GraphModule``), typically an ExportedProgram
    :return: the same object
    """
    gm = exported_program.graph_module
    graph = gm.graph
    modified = False
    for node in graph.nodes:
        if not _is_transpose_call(node):
            continue
        if not any(_is_view_call(user) for user in node.users):
            continue
        modified = True
        with graph.inserting_after(node):
            new_node = graph.call_method("contiguous", args=(node,))
        # Redirect every consumer of the transpose to the contiguous node,
        # except the contiguous node itself which must keep reading the transpose.
        node.replace_all_uses_with(
            new_node, delete_user_cb=lambda user, keep=new_node: user is not keep
        )
    if not modified:
        # no rewrite was done.
        return exported_program
    graph.lint()
    # The module executes generated code: graph edits are ignored until
    # the code is regenerated.
    gm.recompile()
    return exported_program