Source code for experimental_experiment.torch_interpreter.export_options

import inspect
import pprint
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from ..helpers import string_type, string_sig
from ._doc_ import TorchOpOverload


class ExportOptions:
    """
    Gathers altogether all the options defining the way to export a model into a graph
    (not onnx).

    :param strict: strict export or not
    :param fallback: try several export strategies in sequence until one works,
        see :meth:`get_fallback_options`
    :param decomposition_table: decomposition_table, a string as well such as default
        to use the default decomposition table returned by
        :func:`get_decomposition_table
        <experimental_experiment.torch_dynamo.get_decomposition_table>`,
        it can be ``'all'``, ``'default'`` or a decomposition list
    :param dynamo: to use ``torch._dynamo.export`` instead of :func:`torch.export.export`
    :param tracing: use symbolic tracing
    :param jit: use jit to get a graph then converts it into a fx graph
    :param strategy: to overwrite all the previous parameters with just a value
    :param remove_inplace: remove inplace nodes
    :param aten_as_function: keeps aten function as local function to keep a faithful
        translation of the fx graph.

    The fallback strategy tries the following in order:

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.torch_interpreter import ExportOptions

        print("-- default fallback")
        pprint.pprint(ExportOptions().get_fallback_options())
        print("-- default fallback with decomposition")
        pprint.pprint(
            ExportOptions(decomposition_table="default").get_fallback_options()
        )

    Most models work with strict=True or False and no decompositions.
    But if a model contains control flow (tests or loops) or inplace modifications,
    it may be useful to try different values for strict and to apply decompositions
    ``decomposition_table='default'``. The decompositions remove unused results
    coming from inplace modifications.

    A graph is considered as invalid if decompositions were not run and
    there is one node with no user. This usually indicates one inplace operation
    is still part of the graph.
    """

    # Maps every accepted ``strategy`` value to the attributes it overwrites.
    _allowed = {
        None: {},
        "none": {},
        "strict": {"strict": True},
        "strict-dec": {"strict": True, "decomposition_table": "default"},
        "strict-decall": {"strict": True, "decomposition_table": "all"},
        "tracing": {"tracing": True},
        "nostrict": {"strict": False},
        "nostrict-dec": {"strict": False, "decomposition_table": "default"},
        "nostrict-decall": {"strict": False, "decomposition_table": "all"},
        "jit": {"jit": True},
        "jit-dec": {"jit": True, "decomposition_table": "default"},
        "jit-decall": {"jit": True, "decomposition_table": "all"},
        "fallback": {"fallback": True},
        "fallback-dec": {"fallback": True, "decomposition_table": "default"},
        "fallback-decall": {"fallback": True, "decomposition_table": "all"},
        "dec": {"decomposition_table": "default"},
        "decall": {"decomposition_table": "all"},
    }

    def __init__(
        self,
        strict: bool = True,
        fallback: bool = False,
        tracing: bool = False,
        jit: bool = False,
        decomposition_table: Optional[
            Union[str, Dict[TorchOpOverload, Callable[..., Any]]]  # noqa: F821
        ] = None,
        strategy: Optional[str] = None,
        dynamo: bool = False,
        aten_as_function: bool = False,
        remove_inplace: bool = True,
    ):
        self.strict = strict
        self.fallback = fallback
        self.tracing = tracing
        # "none" is an alias for None.
        self.decomposition_table = (
            None if decomposition_table in ("none", None) else decomposition_table
        )
        self.dynamo = dynamo
        self.strategy = strategy
        self.jit = jit
        self.aten_as_function = aten_as_function
        self.remove_inplace = remove_inplace

        if strategy is not None:
            assert strategy in self._allowed, (
                f"Unexpected value for strategy={strategy!r}, "
                f"it should be in {sorted(k for k in self._allowed if k is not None)}"
            )
            # The strategy overrides the individual keyword arguments.
            for k, v in self._allowed[strategy].items():
                setattr(self, k, v)

        # Consistency checks run on the final attributes, i.e. after the strategy
        # has been applied, so that every check sees the same effective values.
        assert (
            not self.dynamo or not self.jit
        ), "jit and dynamo cannot be true at the same time"
        assert self.strict or not self.jit, "jit=True and strict=False are incompatible"
        assert (
            self.strict or not self.dynamo
        ), "dynamo=True and strict=False are incompatible"
        assert (
            not self.tracing or not self.dynamo
        ), f"Both tracing and dynamo are incompatible options in {self!r}"
        assert (
            not self.tracing or self.strict
        ), f"Both tracing and strict=False are incompatible options in {self!r}"

    def __repr__(self) -> str:
        return string_sig(self)

    def get_decomposition_table(
        self,
    ) -> Optional[Dict[TorchOpOverload, Callable[..., Any]]]:  # noqa: F821
        """
        Returns the decompositions table.

        :return: None if no decomposition is requested. If the option is a string,
            the table returned by ``get_decomposition_table_by_name``. Otherwise the
            dictionary given to the constructor.
        """
        if self.decomposition_table is None:
            return None
        if isinstance(self.decomposition_table, str):
            from ..torch_dynamo import get_decomposition_table_by_name

            return get_decomposition_table_by_name(self.decomposition_table)
        assert isinstance(
            self.decomposition_table, dict
        ), f"Unexpected type {type(self.decomposition_table)} for decomposition_table"
        return self.decomposition_table

    def get_fallback_options(self, kind: Optional[str] = None) -> List["ExportOptions"]:
        """
        Returns the fallback scenario, the list of options to try in order.

        :param kind: None or one of ``'fallback'``, ``'fallback-dec'``,
            ``'fallback-decall'``, ``'strict'``, ``'nostrict'``, ``'jit'``
        :return: list of :class:`ExportOptions`
        :raises AssertionError: if *kind* is not supported
        """
        if kind is None or kind in ("fallback", "fallback-dec", "fallback-decall"):
            # Try the current decomposition setting first, then the opposite one.
            other_dec = None if self.decomposition_table else "default"
            return [
                ExportOptions(strict=True, decomposition_table=self.decomposition_table),
                ExportOptions(strict=False, decomposition_table=self.decomposition_table),
                ExportOptions(strict=True, decomposition_table=other_dec),
                ExportOptions(strict=False, decomposition_table=other_dec),
                ExportOptions(dynamo=True, decomposition_table=self.decomposition_table),
                ExportOptions(dynamo=True, decomposition_table=other_dec),
                ExportOptions(jit=True, decomposition_table=self.decomposition_table),
            ]
        if kind == "strict":
            return [ExportOptions(strict=True), ExportOptions(strict=False)]
        if kind == "nostrict":
            return [ExportOptions(strict=False), ExportOptions(strict=True)]
        # Exact comparison: ``kind in ("jit")`` was a substring test on a string.
        if kind == "jit":
            return [
                ExportOptions(strict=True),
                ExportOptions(jit=True, decomposition_table=self.decomposition_table),
            ]
        raise AssertionError(f"Unable to return fallback strategy with kind={kind!r}")

    def export(
        self,
        mod: Any,
        args: Optional[Tuple[Any, ...]],
        kwargs: Optional[Dict[str, Any]],
        tracing_mode: bool,
        dynamic_shapes: Optional[Dict],
        same_signature: bool,
        input_names: Optional[List[str]] = None,
        exc: bool = True,
        verbose: int = 0,
    ) -> Union["torch.export.ExportedProgram", "torch.fx.GraphModule"]:  # noqa: F821
        """
        Exports the model into an exported program.

        :param mod: model to export
        :param args: positional inputs
        :param kwargs: named inputs
        :param tracing_mode: forwarded to ``torch._dynamo.export`` when ``dynamo=True``
        :param dynamic_shapes: dynamic shapes forwarded to the exporter, None means
            static shapes
        :param same_signature: forwarded to ``torch._dynamo.export``
        :param input_names: only used to build error messages
        :param exc: if False, a ``SpecViolationError`` raised by
            :func:`torch.export.export` triggers a retry with
            ``torch.export._trace._export``, and a ``UserError`` is turned into a
            detailed :class:`RuntimeError`. In fallback mode, None is returned
            instead of raising when every option fails.
        :param verbose: verbosity
        :return: an ExportedProgram, a GraphModule (tracing), or the result of
            ``torch._dynamo.export`` (dynamo)
        """
        import torch
        from .tracing import CustomTracer

        if self.fallback or self.strategy in {
            "fallback",
            "fallback-dec",
            "fallback-decall",
        }:
            self._last_working = None
            if verbose:
                print("[ExportOptions.export] fallback")
            tries = self.get_fallback_options(self.strategy)
            excs = []
            for ion, opt in enumerate(tries):
                if verbose:
                    print(f"[ExportOptions.export] tries {ion+1}/{len(tries)}: {opt}")
                try:
                    res = opt.export(
                        mod,
                        args,
                        kwargs,
                        tracing_mode=tracing_mode,
                        dynamic_shapes=dynamic_shapes,
                        same_signature=same_signature,
                        input_names=input_names,
                        exc=False,
                        verbose=max(verbose - 1, 0),
                    )
                except Exception as e:
                    excs.append((opt, e))
                    if verbose:
                        se = str(e).split("\n", maxsplit=1)[0]
                        print(f"[ExportOptions.export] fails due to {se}")
                    continue

                if isinstance(res, torch.export.ExportedProgram):
                    inplace_nodes = CustomTracer._inplace_nodes(res.graph)
                    if inplace_nodes:
                        # One node has no users, this usually
                        # indicates an inplace modification.
                        # This is rejected.
                        excs.append(
                            (
                                opt,
                                f"Probable inplace modifications, "
                                f"there are nodes with no users: {inplace_nodes}.",
                            )
                        )
                        if verbose:
                            print(f"[ExportOptions.export] fails due to {excs[-1][-1]}")
                        if opt.decomposition_table:
                            # Decompositions already ran, nothing more to try here.
                            continue
                        # Try again with decompositions, cheaper than a new export.
                        if verbose:
                            print(
                                f"[ExportOptions.export] current decomposition_table="
                                f"{opt.decomposition_table}, let's try with 'default'"
                            )
                        try:
                            res = apply_decompositions(res, "default")
                        except Exception as e:
                            # Record the failure and move on to the next option
                            # like any other failure in this loop.
                            excs.append((opt, e))
                            if verbose:
                                se = str(e).split("\n", maxsplit=1)[0]
                                print(f"[ExportOptions.export] fails due to {se}")
                            continue
                        inplace_nodes = CustomTracer._inplace_nodes(res.graph)
                        if inplace_nodes:
                            # it fails
                            excs.append(
                                (
                                    opt,
                                    f"Probable inplace modifications, "
                                    f"even after decomposition. "
                                    f"there are nodes with no users: {inplace_nodes}.",
                                )
                            )
                            if verbose:
                                print(
                                    f"[ExportOptions.export] fails again with "
                                    f"{excs[-1][-1]}"
                                )
                            continue
                        # Reports the options which actually produced the result.
                        opt.decomposition_table = "default"

                if verbose:
                    print(f"[ExportOptions.export] winning options {opt}")
                self._last_working = opt
                return res

            if exc:
                raise RuntimeError(
                    f"None of the following options {tries} worked, "
                    f"args={string_type(args)}, kwargs={string_type(kwargs)}, "
                    f"exception=\n-----\n{pprint.pformat(excs)}"
                )
            return None

        if verbose:
            print(
                f"[ExportOptions.export] {self!r} - torch.export.export {type(mod).__name__!r}"
            )
        begin = time.perf_counter()

        if self.tracing:
            # Positional inputs are bound to the names of forward's parameters;
            # explicit kwargs take precedence.
            concrete_args = kwargs.copy() if kwargs else {}
            if args:
                sig = inspect.signature(mod.forward)
                for p, a in zip(sig.parameters, args):
                    if a is not None and p not in concrete_args:
                        concrete_args[p] = a
            graph = CustomTracer().trace(mod, concrete_args=concrete_args)
            return torch.fx.GraphModule(mod, graph)

        if self.dynamo:
            if verbose:
                print("[ExportOptions.export] torch._dynamo.export")
            res = torch._dynamo.export(
                mod,
                aten_graph=True,
                tracing_mode=tracing_mode,
                dynamic_shapes=dynamic_shapes,
                same_signature=same_signature,
                decomposition_table=self.get_decomposition_table(),
                assume_static_by_default=dynamic_shapes is None,
            )(*(args or tuple()), **(kwargs or {}))
            if verbose:
                print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
            return res

        if self.jit:
            if verbose:
                print("[ExportOptions.export] torch.jit.trace")
            from torch._export.converter import TS2EPConverter

            jit_model = torch.jit.trace(
                mod, example_inputs=args, check_trace=False, strict=False
            )
            res = TS2EPConverter(jit_model, args, kwargs).convert()
            dec = apply_decompositions(res, self.decomposition_table)
            if verbose:
                print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
            return dec

        if verbose:
            print(
                f"[ExportOptions.export] torch.export.export "
                f"strict={self.strict}, verbose={verbose}"
            )
            print(f"[ExportOptions.export] dynamic_shapes={dynamic_shapes}")
            print(f"[ExportOptions.export] args={string_type(args)}")
            print(f"[ExportOptions.export] kwargs={string_type(kwargs)}")

        if exc:
            exported_program = torch.export.export(
                mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=self.strict
            )
        else:
            try:
                exported_program = torch.export.export(
                    mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=self.strict
                )
            except torch._export.verifier.SpecViolationError:
                # see https://github.com/pytorch/pytorch/issues/128394
                if verbose:
                    print("[ExportOptions.export] torch.export._trace._export")
                exported_program = torch.export._trace._export(
                    mod,
                    args,
                    kwargs,
                    dynamic_shapes=dynamic_shapes,
                    pre_dispatch=False,
                    strict=self.strict,
                )
            except torch._dynamo.exc.UserError as e:
                # Export again without dynamic shapes, only to enrich the error
                # message with a graph if possible.
                eee = None
                if verbose:
                    print("[ExportOptions.export] torch.export.export")
                try:
                    exported_program = torch.export.export(
                        mod, args, kwargs, strict=self.strict
                    ).graph
                except torch._export.verifier.SpecViolationError as ee:
                    exported_program = None
                    eee = ee
                raise RuntimeError(
                    f"Unable to convert model {type(mod)}, "
                    f"type(args)={type(args)}, type(args[0])="
                    f"{type(args[0]) if isinstance(args, tuple) and args else '?'}, "
                    f"strict={self.strict}, input_names={input_names}\n--\n"
                    f"dynamic_shapes={dynamic_shapes}\n--\ne={e}\n--\neee={eee}"
                    f"\n---exported-program---\n{exported_program}"
                ) from e

        if exported_program is None:
            if verbose:
                print(f"[ExportOptions.export] done in {time.perf_counter() - begin}")
            return exported_program

        if self.decomposition_table:
            dec = apply_decompositions(exported_program, self.decomposition_table)
            if verbose:
                print(
                    f"[ExportOptions.export] done after decomposition "
                    f"in {time.perf_counter() - begin}"
                )
            return dec

        if self.remove_inplace:
            removed = CustomTracer.remove_unnecessary_slices(exported_program.graph)
            if removed:
                if verbose:
                    print(
                        f"[ExportOptions.export] slices: {removed} slices nodes were removed"
                    )
                exported_program.graph.lint()
            modified = CustomTracer.remove_inplace(
                exported_program.graph, exported_program=exported_program
            )
            if modified:
                if verbose:
                    print(
                        f"[ExportOptions.export] inplaces: "
                        f"{modified} inplaced nodes were removed"
                    )
                exported_program.graph.lint()

        if verbose:
            print(
                f"[ExportOptions.export] done with no decomposition "
                f"in {time.perf_counter() - begin}"
            )
        return exported_program
def apply_decompositions(
    exported_program: "torch.export.ExportedProgram", decomposition_table  # noqa: F821
) -> "torch.export.ExportedProgram":  # noqa: F821
    """
    Runs decompositions on an exported program.

    :param exported_program: program to decompose
    :param decomposition_table: one of the following:

        * ``'all'``: calls ``run_decompositions()`` with no table argument.
        * any other string: resolved through ``get_decomposition_table_by_name``.
        * a table: used as is.
        * None: leaves the program untouched.

    :return: the decomposed program, or *exported_program* itself when there is
        nothing to decompose

    Before decomposing, a ``contiguous`` node is inserted between every
    ``transpose`` directly followed by a ``view``.
    """
    if decomposition_table == "all":
        prepared = insert_contiguous_between_transpose_and_view(exported_program)
        return prepared.run_decompositions()

    table = decomposition_table
    if isinstance(table, str):
        from ..torch_dynamo import get_decomposition_table_by_name

        table = get_decomposition_table_by_name(table)

    if table is None:
        # Nothing to decompose.
        return exported_program

    prepared = insert_contiguous_between_transpose_and_view(exported_program)
    return prepared.run_decompositions(table)
def _is_transpose_node(node) -> bool:
    "Tells if *node* is a ``transpose`` (method call or ``aten::transpose.int``)."
    if node.op == "call_method":
        return node.target == "transpose"
    if node.op == "call_function":
        return hasattr(node.target, "name") and node.target.name() == "aten::transpose.int"
    return False


def _is_view_node(node) -> bool:
    "Tells if *node* is a ``view`` (method call or ``aten::view``)."
    if node.op == "call_method":
        return node.target == "view"
    if node.op == "call_function":
        # Checks the user's own target: many call_function targets
        # (operator.getitem, ...) have no ``name`` method.
        return hasattr(node.target, "name") and node.target.name() == "aten::view"
    return False


def insert_contiguous_between_transpose_and_view(
    exported_program: "torch.export.ExportedProgram",  # noqa: F821
) -> "torch.export.ExportedProgram":  # noqa: F821
    """
    Modifies the module inplace to insert a node 'contiguous' between a node
    'transpose' followed by a node 'view'.

    :param exported_program: program whose ``graph_module`` is modified inplace
    :return: the same *exported_program*

    When a transpose has at least one ``view`` user, every user of the transpose
    is rewired to consume the new ``contiguous`` node. The graph module is then
    recompiled so that its generated code reflects the change.
    See issue https://github.com/pytorch/pytorch/issues/136543.
    """
    graph_module = exported_program.graph_module
    graph = graph_module.graph
    modified = False
    for node in graph.nodes:
        if not _is_transpose_node(node):
            continue
        if not any(_is_view_node(user) for user in node.users):
            continue
        modified = True
        with graph.inserting_after(node):
            new_node = graph.call_method("contiguous", args=(node,))
        # Rewire every user of the transpose except the new node itself, which must
        # keep consuming the transpose. The callback is called immediately.
        node.replace_all_uses_with(
            new_node, delete_user_cb=lambda user: user is not new_node
        )

    if not modified:
        # no rewrite was done.
        return exported_program

    graph.lint()
    # In-place graph edits are not reflected in the generated forward code
    # until the module is recompiled.
    graph_module.recompile()
    return exported_program