import contextlib
import copy
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from ..helpers import string_type, max_diff
def make_copy(obj: Any) -> Any:
"""Makes a copy of the objects."""
if isinstance(obj, np.ndarray):
return obj.copy()
if isinstance(obj, tuple):
return tuple(make_copy(_) for _ in obj)
if isinstance(obj, list):
return [make_copy(_) for _ in obj]
if isinstance(obj, dict):
return {k: make_copy(v) for k, v in obj.items()}
if hasattr(obj, "clone"):
return obj.clone()
if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
cache = obj.__class__()
if hasattr(obj, "_seen_tokens"):
cache._seen_tokens = obj._seen_tokens
cache.key_cache = make_copy(obj.key_cache)
cache.value_cache = make_copy(obj.value_cache)
return cache
try:
return copy.deepcopy(obj)
except RuntimeError as e:
raise RuntimeError(
f"deepcopy did not work on type {type(obj)}: {string_type(obj)}"
) from e
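# Illustrative sketch (not from the original module): ``make_copy`` returns
# independent copies of nested containers and tensors, so mutating the copy
# leaves the original untouched.
#
#   t = torch.arange(6).reshape((2, 3))
#   c = make_copy({"x": t, "y": [t, (t,)]})
#   c["x"] += 1
#   assert (t == torch.arange(6).reshape((2, 3))).all()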
class ModelDiagnoseOutput:
"""
    Contains the inputs and outputs as well as the diagnostic results
    collected while tracing intermediate results. An instance of this
    class is produced by :func:`infer_shape_type_from_execution`.
Example :ref:`l-plot-exporter-recipes-custom-phi35` tells you
more about how to use this class.
"""
def __init__(self, name: str, model: torch.nn.Module, level: int = 0):
self.name = name
self.model = model
self.level = level
self.forward = model.forward
self.inputs = []
self.outputs = []
self.children: List[ModelDiagnoseOutput] = []
assert not isinstance(model, torch.nn.ModuleList), "ModuleList should not be traced."
def pretty_text(
self,
with_dynamic_shape: bool = False,
with_shape: bool = True,
with_min_max: bool = True,
with_device: bool = True,
with_inputs: bool = True,
) -> str:
"""
        Renders a text summary of the module, its inputs, outputs and children.
:param with_dynamic_shape: show dynamic shapes
:param with_shape: see :func:`experimental_experiment.helpers.string_type`.
:param with_min_max: see :func:`experimental_experiment.helpers.string_type`.
:param with_device: see :func:`experimental_experiment.helpers.string_type`.
        :param with_inputs: show input and output shapes
:return: text
"""
assert len(self.inputs) == len(self.outputs), (
f"Number if inputs / outputs mismatch {len(self.inputs)} != "
f"{len(self.outputs)}"
)
kws = dict(with_shape=with_shape, with_min_max=with_min_max, with_device=with_device)
indent = " " * self.level
if not self.children and not with_inputs and not any(kws.values()):
return (
(
f"{indent}>>> {self.name}: {self.model.__class__.__name__}: "
f"DS={self.guess_dynamic_shapes()} <<<"
)
if with_dynamic_shape
else f"{indent}>>> {self.name}: {self.model.__class__.__name__} <<<"
)
rows = [f"{indent}>>> {self.name}: {self.model.__class__.__name__}"]
if with_dynamic_shape:
ds = self.guess_dynamic_shapes()
rows.append(f"{indent} DS={ds}")
if with_inputs:
for i in self.inputs:
rows.append(f"{indent} > {string_type(i, **kws)}")
for child in self.children:
t = child.pretty_text(
with_dynamic_shape=with_dynamic_shape, with_inputs=with_inputs, **kws
)
rows.extend(t.split("\n"))
if with_inputs:
for i in self.outputs:
rows.append(f"{indent} < {string_type(i, **kws)}")
rows.append(f"{indent}<<<")
return "\n".join(rows)
@property
def full_name(self):
"Returns a name and class name."
return f"{self.name}:{self.model.__class__.__name__}"
@property
def dot_name(self):
"Returns a kind of indented name."
return f"{'..' * self.level}{self.name} - {self.model.__class__.__name__}"
@property
def module_name_type(self):
"Returns name and module type."
return f"type({self.name})={self.model.__class__.__name__}"
def add_outputs(self, args: Tuple[Any, ...]):
"""Stores returned outputs. Makes a copy."""
if not isinstance(args, tuple):
args = (args,)
self.outputs.append(make_copy(args))
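    def add_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
        """Stores the received inputs. Makes a copy."""
        # Reconstructed sketch (assumption): this method is not listed above but is
        # called by ``_rewrite_forward`` below. It is assumed to mirror ``add_outputs``
        # and to append ``(args, kwargs)`` pairs, the layout ``guess_dynamic_shapes``
        # relies on when it reads ``i[0]`` and ``i[1]``.
        if not kwargs:
            kwargs = {}
        self.inputs.append(make_copy((args, kwargs)))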
def add_child(self, diag: "ModelDiagnoseOutput"):
"""Adds a submodule."""
self.children.append(diag)
def guess_dynamic_dimensions(self, *tensors) -> Any:
"""Infers the dynamic dimension from multiple shapes."""
if len(tensors) == 1:
return {}
shapes = [t.shape for t in tensors]
set_length = set(len(s) for s in shapes)
assert len(set_length) == 1, (
f"Shapes can be different but not ranks possible shapes={set_length} "
f"shapes={shapes} for module {self.name!r}, "
f"class={self.model.__class__.__name__!r}"
)
dynamic = torch.export.Dim.DYNAMIC
rk = set_length.pop()
res = {}
for i in range(rk):
if len(set(s[i] for s in shapes)) > 1:
res[i] = dynamic
return res
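    # Illustrative sketch: a dimension is marked dynamic as soon as it differs
    # across the recorded executions, the others stay static.
    #
    #   diag.guess_dynamic_dimensions(torch.zeros((2, 8)), torch.zeros((5, 8)))
    #   # -> {0: torch.export.Dim.DYNAMIC}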
def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None) -> Any:
"""
Guesses the dynamic shapes for one argument.
"""
assert (
len(objs) > 1
), f"Unable to infer shapes with only one object {string_type(objs)}"
set_types = set(type(o) for o in objs)
assert (
len(set_types) == 1
), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
obj = objs[0]
if obj is None:
return None
if isinstance(obj, (bool, int, float, str)):
return None
if isinstance(obj, (torch.Tensor, np.ndarray)):
return self.guess_dynamic_dimensions(*objs)
if isinstance(obj, tuple):
kl = set(len(o) for o in objs)
assert (
len(kl) == 1
), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
shapes = []
for i in range(kl.pop()):
shapes.append(self.guess_dynamic_shape_object(*[o[i] for o in objs]))
return tuple(shapes)
if isinstance(obj, list):
kl = set(len(o) for o in objs)
assert (
len(kl) == 1
), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
shapes = []
for i in range(kl.pop()):
shapes.append(self.guess_dynamic_shape_object(*[o[i] for o in objs]))
return shapes
if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
kc = set(len(o.key_cache) for o in objs)
assert (
len(kc) == 1
), f"All attribute 'key_cache' should have the same length but found {kc}"
vc = set(len(o.value_cache) for o in objs)
assert (
len(vc) == 1
), f"All attribute 'value_cache' should have the same length but found {vc}"
key_cache = []
for i in range(kc.pop()):
key_cache.append(
self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs])
)
value_cache = []
for i in range(vc.pop()):
value_cache.append(
self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs])
)
return [key_cache, value_cache]
raise NotImplementedError(
f"Unable to build dynamic shapes for type {set_types.pop()}: "
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
)
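    # Illustrative sketch: nested containers are processed recursively, tensors are
    # compared dimension by dimension, scalars and None yield no constraint.
    #
    #   diag.guess_dynamic_shape_object(
    #       (torch.zeros((2, 8)), 1), (torch.zeros((3, 8)), 1)
    #   )
    #   # -> ({0: torch.export.Dim.DYNAMIC}, None)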
def guess_dynamic_shapes(self) -> Any:
"""
        Guesses the dynamic shapes for this module from at least two executions.
        If there is only one execution, all dimensions are considered static.
"""
if len(self.inputs) == 1:
# No dynamic shapes.
            args = tuple({} for _ in self.inputs[0][0])
            kwargs = {k: {} for k in self.inputs[0][1]}
return args, kwargs
# Otherwise.
s1 = set(len(i[0]) for i in self.inputs)
assert len(s1) == 1, f"Different numbers of unnamed arguments {s1}"
s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
assert len(s1) == 1, f"Different named arguments {s2}"
args = []
kwargs = {}
for i in range(s1.pop()):
objs = [_[0][i] for _ in self.inputs]
args.append(
self.guess_dynamic_shape_object(*objs, msg=lambda i=i: f" failing input {i}")
)
for name in s2.pop():
objs = [_[1][name] for _ in self.inputs]
kwargs[name] = self.guess_dynamic_shape_object(
*objs, msg=lambda name=name: f" failing input {name!r}"
)
return tuple(args), kwargs
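    # Illustrative sketch (hypothetical module): after tracing two executions with
    # different batch sizes, e.g. ``model(torch.zeros((2, 8)))`` then
    # ``model(torch.zeros((5, 8)))``, this method returns
    #
    #   (({0: torch.export.Dim.DYNAMIC},), {})
    #
    # i.e. one entry per positional argument and one per named argument.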
def _move_to_kwargs(self, args, kwargs, dynamic_shapes):
"""
Uses the signatures to move unnamed arguments (args) to named arguments (kwargs)
with the corresponding dynamic shapes.
*kwargs*, *dynamic_shapes* are modified inplace.
"""
sig = inspect.signature(self.forward)
arg_dyn, kw_dyn = dynamic_shapes
for i, p in enumerate(sig.parameters):
if i >= len(arg_dyn):
break
            kwargs[p] = args[i]
            kw_dyn[p] = arg_dyn[i]
return tuple(), kwargs, (tuple(), kw_dyn)
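    # Illustrative sketch (hypothetical ``forward(x, mask=None)``): positional
    # arguments and their dynamic shapes are renamed after the signature, so
    # ``args=(t,)`` with ``dynamic_shapes=(({0: DYN},), {})`` becomes
    # ``kwargs={"x": t}`` with ``dynamic_shapes=((), {"x": {0: DYN}})``.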
def _try_export_no_bypass(
self,
modificator: Optional[Callable] = None,
exporter: str = "fx",
exporter_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
quiet: bool = True,
discrepancies: bool = True,
use_dynamic_shapes: Optional[bool] = None,
atol: float = 1e-2,
rtol: float = 1e-1,
) -> Any:
"""
        Tries to export the model held by this instance.
"""
export_inputs = modificator(self.inputs[0]) if modificator else self.inputs[0]
export_inputs = make_copy(export_inputs)
if use_dynamic_shapes is None:
use_dynamic_shapes = len(self.inputs) > 1
assert (
not use_dynamic_shapes or len(self.inputs) > 1
), "Unable to use dynamic_shapes, only one set of inputs is available."
dynamic_shapes = self.guess_dynamic_shapes() if use_dynamic_shapes else None
self.status = "START"
if exporter == "fx":
args, kwargs = export_inputs
if dynamic_shapes and args and kwargs:
# The export should change dynamic shapes to have only named arguments.
if verbose > 1:
print(
f"[try_export] {self.dot_name}: exporter={exporter!r} "
f"change dynamic_shapes={dynamic_shapes}"
)
args, kwargs, dynamic_shapes = self._move_to_kwargs(
args, kwargs, dynamic_shapes
)
if verbose > 1:
print(
f"[try_export] {self.dot_name}: exporter={exporter!r} "
f"with dynamic_shapes={dynamic_shapes}"
)
if verbose > 2:
print(
f"[try_export] {self.dot_name}: exporter={exporter!r} "
f"with export_inputs={string_type(export_inputs,with_shape=True)}"
)
if quiet:
try:
ep = torch.export.export(
self.model,
args,
kwargs=kwargs,
                        dynamic_shapes=dynamic_shapes and (dynamic_shapes[0] or dynamic_shapes[1]),
**(exporter_kwargs or {}),
)
self.exporter_status = "OK"
except Exception as e:
self.last_error = e
se = str(e).split("\n")[0].replace("<_DimHint.DYNAMIC: 3>", "DYN")
self.exporter_status = f"FAIL-EXPORT: {se}"
if verbose:
print(f"[try_export] {self.dot_name} --- {self.exporter_status}")
return None
else:
ep = torch.export.export(
self.model,
args,
kwargs=kwargs,
                    dynamic_shapes=dynamic_shapes and (dynamic_shapes[0] or dynamic_shapes[1]),
**(exporter_kwargs or {}),
)
self.exporter_status = "OK"
if verbose > 1:
print(f"[try_export] {self.dot_name}: {exporter} done")
setattr(self, exporter, ep)
if discrepancies:
has_disc = False
mod = ep.module()
self.exporter_outputs = []
self.exporter_discs = []
for i, (inp, out) in enumerate(zip(self.inputs, self.outputs)):
copy_inp = make_copy(inp)
args, kwargs = copy_inp
if quiet:
try:
got = mod(*args, **kwargs)
except Exception as e:
self.last_error = e
se = str(e).split("\n")[0]
self.exporter_status = f"FAIL-EVAL: {se}"
break
else:
got = mod(*args, **kwargs)
self.exporter_outputs.append(got)
diff = max_diff(out, got)
if verbose > 1:
print(f"[try_export] {self.dot_name}: diff[{i}]={diff}")
self.exporter_discs.append(diff)
if diff["abs"] > atol or diff["rel"] > rtol:
self.exporter_status = "DISC: abs"
has_disc = True
break
                if not has_disc and not self.exporter_status.startswith("FAIL"):
                    self.exporter_status = "OK"
if verbose:
print(f"[try_export] {self.dot_name} --- {self.exporter_status}")
return ep if self.exporter_status == "OK" else None
raise NotImplementedError(f"Export not implemented yet for exporter={exporter!r}")
def _try_export(
self,
exporter: str = "fx",
exporter_kwargs: Optional[Dict[str, Any]] = None,
bypass_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
quiet: bool = True,
discrepancies: bool = True,
use_dynamic_shapes: Optional[bool] = None,
atol: float = 1e-2,
rtol: float = 1e-1,
) -> Any:
"""
        Tries to export the model held by this instance, possibly bypassing known export issues first.
"""
if bypass_kwargs:
from .onnx_export_errors import bypass_export_some_errors
with bypass_export_some_errors(
verbose=max(verbose - 1, 0), **bypass_kwargs
) as modificator:
return self._try_export_no_bypass(
modificator,
exporter,
exporter_kwargs=exporter_kwargs,
quiet=quiet,
verbose=verbose,
use_dynamic_shapes=use_dynamic_shapes,
discrepancies=discrepancies,
atol=atol,
rtol=rtol,
)
return self._try_export_no_bypass(
None,
exporter,
exporter_kwargs=exporter_kwargs,
quiet=quiet,
verbose=verbose,
use_dynamic_shapes=use_dynamic_shapes,
discrepancies=discrepancies,
atol=atol,
rtol=rtol,
)
def try_export(
self,
exporter: str = "fx",
exporter_kwargs: Optional[Dict[str, Any]] = None,
bypass_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
quiet: bool = True,
discrepancies: bool = True,
use_dynamic_shapes: Optional[bool] = None,
atol: float = 1e-2,
rtol: float = 1e-1,
) -> Any:
"""
Tries to export a model. If not possible,
tries every child until it is possible.
        The function stores the export and other results in the class itself.
        :param exporter: export way, ``'fx'`` for :func:`torch.export.export`,
            ``'onnx_dynamo'`` to call :func:`torch.onnx.export` ``(..., dynamo=True)``,
            ``'torch_script'`` to call :func:`torch.onnx.export` ``(..., dynamo=False)``,
            ``'to_onnx'`` to call :func:`experimental_experiment.torch_interpreter.to_onnx`.
:param exporter_kwargs: argument for the export function
:param bypass_kwargs: argument for function :func:`bypass_export_some_errors
<experimental_experiment.torch_interpreter.onnx_export_errors.bypass_export_some_errors>`
:param verbose: verbosity, to see what the function is doing
:param discrepancies: run the exported model to measure the discrepancies
        :param quiet: if True, catch exceptions and store the error instead of raising it
:param use_dynamic_shapes: use dynamic shapes
:param atol: absolute tolerance
:param rtol: relative tolerance
:return: result of the export function
See :ref:`l-plot-exporter-recipes-custom-phi35` for an example.
"""
allowed = {"fx", "onnx_dynamo", "torch_script", "to_onnx"}
assert (
exporter in allowed
), f"Unexpected value for exporter={exporter!r} not in {allowed}"
exported = self._try_export(
exporter=exporter,
exporter_kwargs=exporter_kwargs,
bypass_kwargs=bypass_kwargs,
verbose=verbose,
quiet=quiet,
discrepancies=discrepancies,
use_dynamic_shapes=use_dynamic_shapes,
atol=atol,
rtol=rtol,
)
if exported is not None:
return exported
        # The export failed, let's look into the children.
for child in self.children:
child.try_export(
exporter=exporter,
exporter_kwargs=exporter_kwargs,
bypass_kwargs=bypass_kwargs,
verbose=verbose,
quiet=quiet,
discrepancies=discrepancies,
use_dynamic_shapes=use_dynamic_shapes,
atol=atol,
rtol=rtol,
)
        # The export of this module failed; the children keep their own results.
return None
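    # Illustrative sketch (hypothetical usage, relying on
    # ``infer_shape_type_from_execution`` defined below): once at least two
    # executions have been traced, the export can be attempted recursively.
    #
    #   diag = infer_shape_type_from_execution(model, inputs)
    #   ep = diag.try_export(exporter="fx", use_dynamic_shapes=True, verbose=1)
    #   if ep is None:
    #       print(diag.pretty_text())  # inspect which children exported instead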
def _rewrite_forward(
*args, _diag: Optional[ModelDiagnoseOutput] = None, verbose: int = 0, **kwargs
):
assert _diag is not None, "_diag cannot be None"
if verbose:
indent = " " * _diag.level
if not args:
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}> **{string_type(kwargs)}"
)
elif kwargs:
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}> *{string_type(args)}, **{string_type(kwargs)}"
)
else:
if len(args) == 1 and isinstance(args[0], torch.Tensor):
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}> {string_type(args[0])}"
)
else:
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}> *{string_type(args)}"
)
_diag.add_inputs(args, kwargs)
res = _diag.forward(*args, **kwargs)
_diag.add_outputs(res)
if verbose:
if isinstance(res, torch.Tensor):
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}< {string_type(res)}"
)
else:
print(
f"[{_diag.name}:{_diag.model.__class__.__name__}] "
f"{indent}< *{string_type(res)}"
)
return res
def _trace_forward_execution(
model: torch.nn.Module,
name: str = "__main__",
level: int = 0,
verbose: int = 0,
):
diag = ModelDiagnoseOutput(name, model, level=level)
if verbose:
print(f"[_trace_forward_execution] {diag.dot_name}")
model.forward = lambda *args, _diag=diag, verbose=verbose, **kwargs: _rewrite_forward(
*args, _diag=_diag, verbose=verbose, **kwargs
)
for name, mod in model.named_children():
if isinstance(mod, torch.nn.ModuleList):
for i, m in enumerate(mod):
d = _trace_forward_execution(
m, f"{name}[{i}]", verbose=max(verbose - 1, 0), level=level + 1
)
diag.add_child(d)
else:
d = _trace_forward_execution(
mod, name, verbose=max(verbose - 1, 0), level=level + 1
)
diag.add_child(d)
return diag
def _untrace_forward_execution(diag: ModelDiagnoseOutput, verbose: int = 0):
if verbose:
print(f"[_untrace_forward_execution] {diag.dot_name}")
diag.model.forward = diag.forward
for child in diag.children:
_untrace_forward_execution(child, verbose=verbose)
@contextlib.contextmanager
def trace_forward_execution(model: torch.nn.Module, verbose: int = 0) -> ModelDiagnoseOutput:
"""
    Replaces the forward method of the module and of every submodule
    to store their inputs and outputs.
See :ref:`l-plot-exporter-recipes-custom-phi35` for an example.
"""
diag = _trace_forward_execution(model, verbose=verbose)
try:
yield diag
finally:
_untrace_forward_execution(diag, verbose=verbose)
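# Illustrative sketch (hypothetical usage): the context manager restores the
# original forward methods on exit, even if the traced execution raises.
#
#   with trace_forward_execution(model, verbose=1) as diag:
#       model(*args1)
#       model(*args2)
#   print(diag.pretty_text(with_dynamic_shape=True))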
def infer_shape_type_from_execution(
model: torch.nn.Module,
inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
verbose: int = 0,
) -> ModelDiagnoseOutput:
"""
    Runs a model, traces the intermediate outputs and infers dynamic shapes
    from them.
:param model: model
:param inputs: list of input sets ``[(args, kwargs), (args, kwargs), ...]``
with different shapes (at least for the dynamic dimensions)
:param verbose: verbosity
:return: see :class:`ModelDiagnoseOutput`
See :ref:`l-plot-exporter-recipes-custom-phi35` for an example.
"""
with trace_forward_execution(model, verbose=verbose) as tracer:
for i in inputs:
if isinstance(i, dict):
i = (tuple(), i)
elif isinstance(i, tuple) and (
len(i) != 2 or not isinstance(i[0], tuple) or not isinstance(i[1], dict)
):
i = (i, {})
if verbose:
print(
f"[infer_shape_type_from_execution] run with "
f"{string_type(dict(args=i[0], kwargs=i[1]), with_shape=True)}"
)
assert (
isinstance(i, tuple)
and len(i) == 2
and isinstance(i[0], tuple)
and isinstance(i[1], dict)
), (
f"Unexpected types as inputs, it should (args, kwargs) but got "
f"{string_type(i)}"
)
args, kwargs = i
model(*args, **kwargs)
if verbose:
print(
f"[trace_forward_execution] traced execution of model "
f"{model.__class__.__name__}"
)
print(tracer.pretty_text())
return tracer
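# Illustrative sketch (hypothetical two-layer model): two input sets with
# different batch sizes are enough to expose the dynamic dimensions.
#
#   model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
#   inputs = [((torch.randn(2, 8),), {}), ((torch.randn(5, 8),), {})]
#   diag = infer_shape_type_from_execution(model, inputs, verbose=1)
#   print(diag.guess_dynamic_shapes())  # roughly (({0: DYNAMIC},), {})
#   diag.try_export(exporter="fx", use_dynamic_shapes=True)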