from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
from onnx import ModelProto
from .debug_backend import onnx_debug_backend
from .fast_backend import onnx_custom_backend
from .dynger_backend import dynger_backend
[docs]
def get_decomposition_table_by_name(name: Optional[Union[str, Dict]]):
    """
    Returns a predefined decomposition table.

    :param name: name of the table, see below; ``None`` or a dictionary
        is returned unchanged
    :return: decomposition table (``None`` means no decomposition)
    :raises AssertionError: if *name* is not one of the known names

    * ``'none'``: do not decompose
    * ``'default'``: :func:`get_decomposition_table`
    * ``'onnxscript'``: :func:`get_decomposition_table_onnxscript`
    * ``'onnxscript2'``: :func:`get_decomposition_table_onnxscript`
      filtered with :func:`filter_decomposition_table`, it drops every
      operator whose name contains ``view`` unless it also contains ``copy``
    * ``'dynamo'``: :func:`get_decomposition_table_dynamo`
    """
    if name is None or isinstance(name, dict):
        # Already a table (or explicitly no table): nothing to look up.
        return name
    # Tables are built lazily, only the requested one is computed.
    mapping = {
        "none": lambda: None,
        "default": get_decomposition_table,
        "onnxscript": get_decomposition_table_onnxscript,
        "dynamo": get_decomposition_table_dynamo,
        "onnxscript2": lambda: (
            filter_decomposition_table(
                get_decomposition_table_onnxscript(),
                lambda op: "view" not in op.name() or "copy" in op.name(),
            )
        ),
    }
    if name in mapping:
        return mapping[name]()
    raise AssertionError(f"Unknown decomposition table name={name!r} among {sorted(mapping)}")
[docs]
def get_decomposition_table():
    """
    Returns the decomposition table needed to translate a backward
    graph into ONNX. It should be used as follows:

    ::

        import torch
        from torch._dynamo.backends.common import aot_autograd
        from experimental_experiment.torch_dynamo import get_decomposition_table

        aot_compiler = aot_autograd(
            fw_compiler=backend_debug,
            decompositions=get_decomposition_table()
        )

        compiled_model = torch.compile(
            model,
            backend=aot_compiler,
            dynamic=dynamic,
            fullgraph=fullgraph,
        )

    Only the entries of ``torch._decomp.decomposition_table`` for
    ``aten::embedding_dense_backward``, ``aten::rrelu_with_noise`` and
    ``aten::native_layer_norm_backward`` are kept.

    The value is:

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.torch_dynamo import get_decomposition_table

        pprint.pprint(get_decomposition_table())
    """
    import torch

    kept_ops = frozenset(
        (
            "aten::embedding_dense_backward",
            "aten::rrelu_with_noise",
            "aten::native_layer_norm_backward",
        )
    )
    return {
        op: decomposition
        for op, decomposition in torch._decomp.decomposition_table.items()
        if op.name() in kept_ops
    }
[docs]
def get_decomposition_table_onnxscript():
"""
Returns the decomposition table used by :func:`torch.onnx.export`.
The value is:
.. runpython::
:showcode:
import pprint
from experimental_experiment.torch_dynamo import get_decomposition_table_onnxscript
pprint.pprint(get_decomposition_table_onnxscript())
"""
from torch.onnx._internal.exporter._registration import ONNXRegistry
from torch.onnx._internal.exporter._decomp import (
get_onnx_implemented_overloads,
create_onnx_friendly_decomposition_table,
)
registry = ONNXRegistry.from_torchlib()
onnx_registered_ops = set(get_onnx_implemented_overloads(registry))
decomposition_table = create_onnx_friendly_decomposition_table(onnx_registered_ops)
return decomposition_table
[docs]
def get_decomposition_table_dynamo(onnx_registry=None):
"""
Returns the decomposition table needed for the dynamo exporter.
:param onnx_registry: instance of class
``torch.onnx._internal.exporter.OnnxRegistry``
The value is:
.. runpython::
:showcode:
import pprint
from experimental_experiment.torch_dynamo import get_decomposition_table_dynamo
pprint.pprint(get_decomposition_table_dynamo())
"""
from torch.onnx._internal.fx.decomposition_table import (
create_onnx_friendly_decomposition_table,
)
try:
from torch.onnx._internal._exporter_legacy import OnnxRegistry
except ImportError:
from torch.onnx._internal.exporter import OnnxRegistry
return create_onnx_friendly_decomposition_table(onnx_registry or OnnxRegistry())
[docs]
def filter_decomposition_table(
    existing_table: Optional[Dict] = None,
    filter_fct: Optional[Callable[[Any], bool]] = None,
) -> Dict:
    """
    Returns a copy of a decomposition table from which some decompositions
    are removed, usually because their translation into ONNX is less
    efficient than the translation of the operator they decompose.

    :param existing_table: dictionary of decompositions (an iterable of
        ``(operator, decomposition)`` pairs is also accepted), by default,
        it is ``torch._decomp.decomposition_table``.
    :param filter_fct: if specified, a decomposition is removed
        when the function returns a falsy value for its operator;
        if not specified, the decompositions of ``aten::slice_backward``,
        ``aten::select_backward.out`` and ``aten::slice.Tensor`` are removed
    :return: new table

    ::

        import torch
        from torch._dynamo.backends.common import aot_autograd
        from experimental_experiment.torch_dynamo import filter_decomposition_table

        aot_compiler = aot_autograd(
            fw_compiler=backend_debug,
            decompositions=filter_decomposition_table()
        )

        compiled_model = torch.compile(
            model,
            backend=aot_compiler,
            dynamic=dynamic,
            fullgraph=fullgraph,
        )

    The value is:

    .. runpython::
        :showcode:

        import pprint
        from experimental_experiment.torch_dynamo import filter_decomposition_table

        pprint.pprint(filter_decomposition_table())
    """
    if existing_table is None:
        import torch

        existing_table = torch._decomp.decomposition_table.items()
    elif isinstance(existing_table, dict):
        existing_table = existing_table.items()

    # Default exclusion list, only used when filter_fct is not given.
    default_excluded = {
        "aten::slice_backward",
        "aten::select_backward.out",
        "aten::slice.Tensor",
    }
    new_table = {}
    for k, v in existing_table:
        keep = filter_fct(k) if filter_fct else k.name() not in default_excluded
        if keep:
            new_table[k] = v
    return new_table
def _single_print(v):
    """
    Returns a short one-line description of a leaf value.

    * ``None`` and Python scalars (int, bool, str, float) go through :func:`str`,
    * numpy arrays become ``array:<dtype>:<shape>:<mean>``,
    * objects with a ``numpy`` attribute are detached, moved to CPU,
      converted with ``.numpy()`` and described as arrays,
    * ``ModelProto`` (without spaces or new lines) and objects whose type
      name contains ``GraphModule``, ``GraphBuilder`` or
      ``ExtendedReferenceEvaluator`` (without new lines) are reduced to
      the first and last 20 characters of their string representation.

    :raises TypeError: for any other type
    """
    if v is None:
        return "None"
    if isinstance(v, (int, bool, str, float)):
        return str(v)
    if isinstance(v, np.ndarray):
        return f"array:{v.dtype}:{v.shape}:{v.mean()}"
    if hasattr(v, "numpy"):
        as_array = v.detach().cpu().numpy()
        return _single_print(as_array)

    if isinstance(v, ModelProto):
        label = "ModelProto"
        text = str(v).replace("\n", "").replace(" ", "")
    else:
        type_name = str(type(v))
        label = next(
            (
                candidate
                for candidate in (
                    "GraphModule",
                    "GraphBuilder",
                    "ExtendedReferenceEvaluator",
                )
                if candidate in type_name
            ),
            None,
        )
        if label is None:
            raise TypeError(f"Unexpected type {type(v)}.")
        text = str(v).replace("\n", "")
    return f"{label}:{text[:20]}...{text[-20:]}"
[docs]
def pprint_storage(
    storage: Any, indent: int = 0, as_list: bool = False
) -> Union[List[str], str]:
    """
    Pretty print for the storage.

    Dictionaries, lists and tuples holding at most 10 values, all of them
    ``int``, ``float``, ``str``, ``bool`` or ``None``, are printed on one
    line with :func:`str`. Other containers are printed one element per
    row, every nesting level being indented by one more space. Leaves
    (scalars, numpy arrays, objects with a ``numpy`` attribute, models,
    graphs, evaluators) are printed with :func:`_single_print`.

    :param storage: any object
    :param indent: indentation, number of spaces prefixed to the rows
    :param as_list: return list or string
    :return: list of rows if *as_list* is True, otherwise the rows joined
        with new lines
    :raises RuntimeError: if *storage* or a nested value has an
        unexpected type
    """
    sind = " " * indent
    if isinstance(storage, (np.ndarray, int, float, str, bool, type(None))):
        # None is covered by type(None), no other branch is needed for it.
        rows = [sind + _single_print(storage)]
    elif isinstance(storage, dict):
        if _pprint_is_compact(storage.values()):
            rows = [sind + str(storage)]
        else:
            rows = [sind + "{"]
            for k, v in storage.items():
                r = pprint_storage(v, indent=indent + 1, as_list=True)
                if len(r) > 1:
                    # Multi-row children are containers whose first row is
                    # their indentation followed by an opening bracket:
                    # the key is inserted right before that bracket.
                    r[0] = f"{r[0][:-1]}{k!r}: {r[0][-1]}"
                    r[-1] += ","
                    rows.extend(r)
                else:
                    rows.append(f" {sind}{k!r}: {r[0].lstrip(' ')},")
            rows.append(sind + "}")
    elif isinstance(storage, (list, tuple)):
        rows = _pprint_sequence(storage, indent)
    elif hasattr(storage, "numpy") or any(
        name in str(type(storage))
        for name in (
            "GraphModuleImpl",
            "ModelProto",
            "GraphBuilder",
            "ExtendedReferenceEvaluator",
        )
    ):
        rows = [sind + _single_print(storage)]
    else:
        raise RuntimeError(f"Unexpected type {type(storage)}")
    if as_list:
        return rows
    return "\n".join(rows)


def _pprint_is_compact(values) -> bool:
    """Tells if a container of *values* can be printed on a single row:
    at most 10 values, all int, float, str, bool or None."""
    return len(values) <= 10 and all(
        isinstance(v, (int, float, str, bool, type(None))) for v in values
    )


def _pprint_sequence(storage, indent: int) -> List[str]:
    """Returns the rows printing a list or a tuple for :func:`pprint_storage`."""
    sind = " " * indent
    if _pprint_is_compact(storage):
        return [sind + str(storage)]
    opening, closing = ("[", "]") if isinstance(storage, list) else ("(", ")")
    rows = [sind + opening]
    for v in storage:
        r = pprint_storage(v, indent=indent + 1, as_list=True)
        if len(r) > 1:
            r[-1] += ","
            rows.extend(r)
        else:
            rows.append(f" {sind}{r[0].lstrip(' ')},")
    rows.append(sind + closing)
    return rows