Source code for onnx_diagnostic.export.dynamic_shapes
import inspect
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from ..helpers import string_type
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]]
def flatten_dynamic_shapes(ds: Any) -> Any:
"""Flattens the dynamic shapes."""
if isinstance(ds, list):
return _flat_list([flatten_dynamic_shapes(t) for t in ds])
if isinstance(ds, tuple):
return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds]))
if isinstance(ds, dict):
if all(isinstance(i, int) for i in ds):
# That's a dynamic shape
return ds
return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()])
raise AssertionError(f"Not implemented for {type(ds)}: {ds}")
def _flat_list(li: List[Any]) -> List[Dict[int, str]]:
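    # Merges flattened pieces: a dictionary describes one tensor and is kept
    # as is, a nested list is spliced into the result.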
res = []
for t in li:
if isinstance(t, dict):
res.append(t)
else:
res.extend(t)
return res
class CoupleInputsDynamicShapes:
"""
Pair inputs / dynamic shapes.
:param args: positional arguments
:param kwargs: named arguments
:param dynamic_shapes: dynamic shapes
:param args_names: if both args and kwargs are not empty, then
dynamic shapes must be a dictionary, and positional must be added
to the named arguments. Arguments names or a module must be given
in that case.
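
    A minimal usage sketch, replacing dimension names by
    ``torch.export.Dim.DYNAMIC``:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes

        kwargs = {"A": torch.rand((3, 4))}
        ds = {"A": {0: "batch"}}
        print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())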
"""
def __init__(
self,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: DYNAMIC_SHAPES,
args_names: Optional[Union[torch.nn.Module, List[str]]] = None,
):
self.args = args
self.kwargs = kwargs
self.dynamic_shapes = dynamic_shapes
self.args_names = args_names
def __str__(self) -> str:
return "\n".join(
[
f"{self.__class__.__name__}(",
f" args={string_type(self.args, with_shape=True)},"
f" kwargs={string_type(self.kwargs, with_shape=True)},"
f" dynamic_shapes={string_type(self.dynamic_shapes, with_shape=True)},"
f")",
]
)
def replace_string_by(self, value: Any = None):
"""
        Replaces every string in the dynamic shapes by ``torch.export.Dim.DYNAMIC``
        (default) or any other value specified by *value*.
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
T3x1 = torch.rand((3, 1))
T3x4 = torch.rand((3, 4))
ds_batch = {0: "batch"}
ds_batch_seq = {0: "batch", 1: "seq"}
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by())
"""
return self._generic_walker(
lambda inputs, ds, value=value: self._replace_string_dim_tensor(
inputs, ds, value=value
),
flatten_unflatten=True,
)
@classmethod
def _replace_string_dim_tensor(cls, inputs, ds, value=None):
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension"
)
if value is None:
value = torch.export.Dim.DYNAMIC
new_ds = ds.copy()
for i, v in ds.items():
if isinstance(v, str):
new_ds[i] = value
return new_ds
def replace_by_string(self):
"""
        Replaces dimension objects (``torch.export.Dim``, ``DYNAMIC``, ``AUTO``) by strings.
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
Dim = torch.export.Dim
T3x1 = torch.rand((3, 1))
T3x4 = torch.rand((3, 4))
ds_batch = {0: Dim("batch")}
ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")}
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string())
"""
unique = set()
return self._generic_walker(
lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string(
inputs, ds, unique=unique
),
flatten_unflatten=True,
)
@classmethod
def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]):
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension"
)
new_ds = ds.copy()
for i, v in ds.items():
if isinstance(v, str):
unique.add(v)
new_ds[i] = v
elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO):
name = f"Dim{len(unique)}"
new_ds[i] = name
unique.add(name)
else:
name = v.__name__
unique.add(name)
new_ds[i] = name
return new_ds
def invalid_dimensions_for_export(self):
"""
Tells if the inputs are valid based on the dynamic shapes definition.
The method assumes that all custom classes can be serialized.
        If some patches were applied before exporting, they should be enabled
        while calling this method if the inputs contain such classes.
The function checks that a dynamic dimension does not receive a value
of 0 or 1. It returns the unexpected values in the same structure as
the given dynamic shapes.
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
T3x1 = torch.rand((3, 1))
T3x4 = torch.rand((3, 4))
ds_batch = {0: "batch"}
ds_batch_seq = {0: "batch", 1: "seq"}
kwargs = {"A": T3x4, "B": (T3x1, T3x1)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
        When every dimension is valid, the method returns ``None``:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
T3x2 = torch.rand((3, 2))
T3x4 = torch.rand((3, 4))
ds_batch = {0: "batch"}
ds_batch_seq = {0: "batch", 1: "seq"}
kwargs = {"A": T3x4, "B": (T3x2, T3x2)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export())
"""
return self._generic_walker(self._valid_shapes_tensor, flatten_unflatten=True)
@classmethod
def _valid_shapes_tensor(cls, inputs, ds):
assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension dimension"
)
issues = {}
for i, d in enumerate(inputs.shape):
if i in ds and not isinstance(ds[i], int):
# dynamic then
if d in {0, 1}:
# export issues for sure
issues[i] = f"d=[{d}]"
return issues if issues else None
def _generic_walker(
self, processor: Callable, args_kwargs: bool = False, flatten_unflatten: bool = False
):
"""
        Generic walker going through the inputs and the dynamic shapes simultaneously.
The function returns a result with the same structure as the dynamic shapes.
"""
if not self.args:
assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
res = self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)
return (tuple(), res) if args_kwargs else res
if not self.kwargs:
assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), (
f"Type mismatch, args={string_type(self.args)} and "
f"dynamic_shapes={self.dynamic_shapes} should have the same type."
)
res = self._generic_walker_step(
processor, self.args, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
)
return (res, {}) if args_kwargs else res
assert isinstance(self.dynamic_shapes, dict), (
f"Both positional and named arguments (args and kwargs) are filled. "
f"dynamic shapes must a dictionary not {type(self.dynamic_shapes)}"
)
if not self.args_names and set(self.dynamic_shapes) & set(self.kwargs) == set(
self.dynamic_shapes
):
# No dynamic shapes for the positional arguments.
return self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)
if isinstance(self.args_names, list):
if not set(self.args_names) & set(self.dynamic_shapes):
# No dynamic shapes for the positional arguments.
return self._generic_walker_step(
processor,
self.kwargs,
self.dynamic_shapes,
flatten_unflatten=flatten_unflatten,
)
assert self.args_names, (
"args and kwargs are filled, then args_names must be specified in "
"the constructor to move positional arguments to named arguments."
)
assert len(self.args) <= len(self.args_names), (
f"There are {len(self.args)} positional arguments "
f"but only {len(self.args_names)} names. "
f"args={string_type(self.args, with_shape=True)}, args_name={self.args_names}"
)
kwargs = dict(zip(self.args_names, self.args))
kwargs.update(self.kwargs)
res = self._generic_walker_step(
processor, kwargs, self.dynamic_shapes, flatten_unflatten=flatten_unflatten
)
if args_kwargs:
pgs = [None for _ in range(len(self.args))]
kws = {}
for k, v in res.items():
if k not in self.kwargs:
pgs[self.args_names.index(k)] = v
else:
kws[k] = v
return pgs, kws
return res
raise NotImplementedError(
f"Not yet implemented when args is filled, "
f"kwargs as well but args_names is {type(self.args_names)}"
)
@classmethod
def _generic_walker_step(
cls, processor: Callable, inputs, ds, flatten_unflatten: bool = False
):
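        # Walks inputs and dynamic shapes together, applies ``processor`` on
        # every (tensor, shape) pair, and rebuilds the same nested structure
        # from the results.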
if isinstance(inputs, torch.Tensor):
return processor(inputs, ds)
if isinstance(inputs, (int, float, str)):
return None
if type(inputs) in (tuple, list, dict):
# Type must be strict, some custom classes can inherit from those.
assert type(inputs) is type(ds), (
f"Input type and dynamic shape type mush match but "
f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
)
assert len(ds) == len(inputs), (
f"Length mismatch between inputs {len(inputs)} "
f"and ds={len(ds)}\n"
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
)
if type(inputs) in (tuple, list):
value = []
for i, d in zip(inputs, ds):
value.append(
cls._generic_walker_step(
processor, i, d, flatten_unflatten=flatten_unflatten
)
)
return (
(value if isinstance(ds, list) else tuple(value))
if any(v is not None for v in value)
else None
)
assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
assert set(inputs) == set(ds), (
f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
)
dvalue = {}
for k, v in inputs.items():
t = cls._generic_walker_step(
processor, v, ds[k], flatten_unflatten=flatten_unflatten
)
if t is not None:
dvalue[k] = t
return dvalue if dvalue else None
# A custom class.
assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
f"Class {inputs.__class__.__name__!r} was not registered using "
f"torch.utils._pytree.register_pytree_node, it is not possible to "
f"map this class with the given dynamic shapes."
)
if flatten_unflatten:
flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
res = cls._generic_walker_step(
processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
)
# Should we restore the original class?
return res
flat, spec = torch.utils._pytree.tree_flatten(inputs)
if all(isinstance(t, torch.Tensor) for t in flat):
# We need to flatten dynamic shapes as well
ds = flatten_dynamic_shapes(ds)
res = cls._generic_walker_step(
processor, flat, ds, flatten_unflatten=flatten_unflatten
)
# Then we restore the original class.
return torch.utils._pytree.tree_unflatten(res, spec)
class ChangeDimensionProcessor:
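        """
        Builds new inputs where every dynamic dimension receives a new value,
        either taken from ``desired_values`` or derived from the current one.
        """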
def __init__(self, desired_values):
self.mapping = desired_values or {}
def _build_new_shape(
self, shape: Tuple[int, ...], ds: Dict[int, Any]
) -> Tuple[int, ...]:
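            # Maps every dynamic dimension to its desired value when provided,
            # otherwise to the current value plus one; the mapping is cached so
            # that dimensions sharing the same name receive the same new value.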
new_shape = list(shape)
for i in range(len(shape)):
if i in ds:
if isinstance(ds[i], str):
d = ds[i]
elif isinstance(
ds[i],
(
torch.export.dynamic_shapes._DerivedDim,
torch.export.dynamic_shapes._Dim,
),
):
d = str(ds[i])
                    elif isinstance(ds[i], int):
                        # A static dimension, nothing to change.
                        continue
                    else:
                        raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
if d in self.mapping:
new_dim = self.mapping[d]
else:
new_dim = shape[i] + 1
self.mapping[d] = new_dim
new_shape[i] = new_dim
return tuple(new_shape)
def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]):
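            # Resizes the tensor one dimension at a time: the overlapping part
            # is copied and, when a dimension grows, the extra slices are
            # filled by cycling over the existing ones.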
rank = len(tensor.shape)
for i in range(len(tensor.shape)):
d0 = tensor.shape[i]
d1 = new_shape[i]
if d0 == d1:
continue
alt_shape = list(tensor.shape)
alt_shape[i] = d1
new_tensor = torch.zeros(
tuple(alt_shape), dtype=tensor.dtype, device=tensor.device
)
mind = min(d0, d1)
indices: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices[i] = slice(0, mind)
ind = tuple(indices)
new_tensor[ind] = tensor[ind]
if d1 > mind:
for k in range(d1 - mind):
indices0: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices1: List[Union[slice, int]] = [slice(None) for _ in range(rank)]
indices1[i] = mind + k
indices0[i] = k % mind
new_tensor[tuple(indices1)] = tensor[tuple(indices0)]
tensor = new_tensor
return tensor
def __call__(self, inputs, ds):
assert isinstance(
inputs, torch.Tensor
), f"unexpected type for inputs {type(inputs)}"
assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), (
f"Unexpected types, inputs is a Tensor but ds is {ds}, "
f"a dictionary is expected to specify a dimension dimension"
)
new_shape = self._build_new_shape(inputs.shape, ds)
return self._build_new_tensor(inputs, new_shape)
def change_dynamic_dimensions(
self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
):
"""
A model exported with dynamic shapes is not necessarily dynamic
        just because the user specified dynamic shapes. The export algorithm
        may discover that a dimension cannot be dynamic and then continue
        the export assuming it is static. That may lead to a wrong model.
        This function produces a new set of inputs with different values
        for the dynamic dimensions than the original ones, assuming those
        were used to export the model.
        :param desired_values: fixes named dimensions to the desired values
:param args_kwargs: return both args, kwargs even if empty
:return: new inputs
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes
T3x15 = torch.rand((3, 15))
T3x20 = torch.rand((3, 20))
T3x4 = torch.rand((3, 4))
ds_batch = {0: "batch"}
ds_batch_seq = {0: "batch", 1: "seq"}
kwargs = {"A": T3x4, "B": (T3x15, T3x20)}
ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}
new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions()
print("before:", string_type(kwargs, with_shape=True))
print("-after:", string_type(new_kwargs, with_shape=True))
"""
return self._generic_walker(
self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
)
class ModelInputs:
"""
Wraps a model and a couple of sets of valid inputs.
Based on that information, the class is able to infer the dynamic shapes
for :func:`torch.export.export`.
:param model: model to export
:param inputs: list of valid set of inputs
:param level: if this module is a submodule, it is the level of submodule
:param method_name: by default, the forward method is processed but it
could be another one
:param name: a name, mostly for debugging purposes
Examples:
**args**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y) # to check it works
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
**kwargs**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y) # to check it works
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
**args and kwargs**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y) # to check it works
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
    :func:`torch.export.export` does not handle dynamic shapes well when they
    are defined both for args and kwargs; kwargs must be used.
    ``move_to_kwargs`` modifies the inputs and the dynamic shapes
    to make the model and the given inputs exportable.
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.helpers import string_type
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y) # to check it works
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
print("moved args:", string_type(a, with_shape=True))
print("moved kwargs:", string_type(kw, with_shape=True))
print("dynamic shapes:")
pprint.pprint(nds)
"""
def __init__(
self,
model: torch.nn.Module,
inputs: Union[
List[Tuple[Any, ...]],
List[Dict[str, Any]],
List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
],
level: int = 0,
method_name: str = "forward",
name: str = "main",
):
assert isinstance(model, torch.nn.Module) or inspect.ismodule(
model
), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
assert name, (
f"name={name!r} cannot be empty this string is used to "
f"display meaningful error messages"
)
self.name = name
self.model = model
self.level = level
self.method_name = method_name
self.forward = getattr(model, method_name)
self.signature = inspect.signature(self.forward)
# information about the signature
self.forward_parameter_names = set(
p.name
for p in self.signature.parameters.values()
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
)
self.forward_ordered_parameter_names = list(self.signature.parameters)
self.forward_positioned_parameter_names = [
p.name
for p in self.signature.parameters.values()
if p.kind in (p.VAR_POSITIONAL, p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
]
names = [
p.name for p in self.signature.parameters.values() if p.kind == p.VAR_POSITIONAL
]
self.forward_args = names[0] if names else None
names = [p.name for p in self.signature.parameters.values() if p.kind == p.VAR_KEYWORD]
self.forward_kwargs = names[0] if names else None
self.forward_custom_op_schema = None
self.forward_need_serialization = False
self.forward_fill_kwargs = bool(self.forward_kwargs)
assert not isinstance(
model, (torch.nn.ModuleList, torch.nn.ModuleDict)
), f"ModuleList or ModuleDict should not be traced: {type(model)}"
# process the inputs
self.inputs = self.process_inputs(inputs)
def process_inputs(
self,
inputs: Union[
List[Tuple[Any, ...]],
List[Dict[str, Any]],
List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
],
) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
"""
        Transforms a list of valid inputs, given as a list of args, a list of
        kwargs, or a list of both, into a list of ``(args, kwargs)`` pairs.
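
        A minimal sketch showing the normalization, assuming a single set of
        positional inputs:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export import ModelInputs
            from onnx_diagnostic.helpers import string_type

            class Model(torch.nn.Module):
                def forward(self, x):
                    return x

            mi = ModelInputs(Model(), [(torch.rand((2, 3)),)])
            print(string_type(mi.inputs, with_shape=True))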
"""
if not isinstance(inputs, list):
raise ValueError(
f"inputs should be specified as a list of sets of "
f"inputs but type(inputs) is {type(inputs)}"
)
new_inputs = []
for i, inp in enumerate(inputs):
if (
isinstance(inp, tuple)
and len(inp) == 2
and isinstance(inp[0], tuple)
and isinstance(inp[1], dict)
):
new_inputs.append(inp)
continue
if isinstance(inp, tuple):
new_inputs.append((inp, {}))
continue
if isinstance(inp, dict):
new_inputs.append(((), inp))
continue
raise ValueError(f"Unable to interpret inputs {i}: {string_type(inp)}")
return new_inputs
@property
def true_model_name(self) -> str:
"Returns class name or module name."
return (
self.model.__class__.__name__
if isinstance(self.model, torch.nn.Module)
else self.model.__name__
)
@property
def full_name(self) -> str:
"Returns a name and class name."
if self.method_name == "forward":
return f"{self.name}:{self.true_model_name}"
return f"{self.name}:{self.true_model_name}.{self.method_name}"
@property
def module_name_type(self):
"Returns name and module type."
if self.method_name == "forward":
return f"type({self.name})={self.true_model_name}"
return f"type({self.name})={self.true_model_name}.{self.method_name}"
def guess_dynamic_dimensions(
self, *tensors, auto: bool = False
) -> Optional[Dict[int, Any]]:
"""
        Infers the dynamic dimensions from multiple shapes.
        If *auto* is True, it returns ``torch.export.Dim.AUTO`` for every dimension
        which cannot be guessed. Guessing requires at least two tensors: a dimension
        taking different values across them is dynamic; with a single tensor,
        nothing can be inferred.
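
        A minimal sketch, assuming two executions whose first dimension differs:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export import ModelInputs

            class Model(torch.nn.Module):
                def forward(self, x):
                    return x

            mi = ModelInputs(Model(), [(torch.rand((2, 3)),)])
            print(mi.guess_dynamic_dimensions(torch.rand((2, 3)), torch.rand((4, 3))))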
"""
if len(tensors) == 1:
if isinstance(tensors[0], (int, float)):
return None
assert isinstance(tensors[0], torch.Tensor), (
f"Unexpected type for tensors {string_type(tensors, with_shape=True)}, "
f"Only tensors are allowed."
)
return (
{i: torch.export.Dim.AUTO for i in range(len(tensors[0].shape))} # noqa: C420
if auto
else {}
)
shapes = [t.shape for t in tensors]
set_length = set(len(s) for s in shapes)
assert len(set_length) == 1, (
f"Shapes can be different but not ranks possible shapes={set_length} "
f"shapes={shapes} for module {self.name!r}, "
f"class={self.true_model_name!r}"
)
dynamic: Any = torch.export.Dim.DYNAMIC # type: ignore
rk = set_length.pop()
res = {}
for i in range(rk):
set_dim = set(s[i] for s in shapes)
if len(set_dim) > 1:
res[i] = dynamic
continue
if set_dim == {0}:
# It is unexpected to find a null dimension. Let's replace it by a dynamic one.
res[i] = dynamic
continue
return res
def guess_dynamic_shape_object(
self, *objs: Any, auto: bool = False, msg: Optional[Callable] = None
) -> Any:
"""Guesses the dynamic shapes for one argument."""
if len(objs) == 0:
return None
set_types = set(type(o) for o in objs)
assert (
len(set_types) == 1
), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
obj = objs[0]
if obj is None:
return None
if isinstance(obj, (bool, int, float, str)):
return None
if isinstance(obj, (torch.Tensor, np.ndarray)):
return self.guess_dynamic_dimensions(*objs, auto=auto)
if isinstance(obj, tuple):
kl = set(len(o) for o in objs)
assert (
len(kl) == 1
), f"Unexpected variety of tuple lengths {kl}{msg() if msg else ''}"
shapes: Any = []
for i in range(kl.pop()):
shapes.append(
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
)
return tuple(shapes)
if isinstance(obj, list):
kl = set(len(o) for o in objs)
assert (
len(kl) == 1
), f"Unexpected variety of list lengths {kl}{msg() if msg else ''}"
shapes = []
for i in range(kl.pop()):
shapes.append(
self.guess_dynamic_shape_object(*[o[i] for o in objs], auto=auto, msg=msg)
)
return shapes
if isinstance(obj, dict):
kl = set(len(o) for o in objs)
assert (
len(kl) == 1
), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
shapes = {}
for i in obj:
shapes[i] = self.guess_dynamic_shape_object(
*[o[i] for o in objs], auto=auto, msg=msg
)
return shapes
if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
kcl = set(o.__class__ for o in objs)
assert len(kcl) == 1, (
f"All instances of argument {i} are not of the same class but {kcl}, "
f"types should be the same."
)
col_args = [flatten_unflatten_for_dynamic_shapes(o) for o in objs]
kc = set(len(o) for o in col_args)
assert len(kc) == 1, (
f"All instances of type {kcl.pop()} are not serialized into the same number "
f"of arguments, it should be the same."
)
values = []
for i in range(kc.pop()):
values.append(
self.guess_dynamic_shape_object(
*[ca[i] for ca in col_args], auto=auto, msg=msg
)
)
return values
# In case DynamicCache is not registered.
if obj.__class__.__name__ == "DynamicCache":
kc = set(len(o.key_cache) for o in objs)
assert (
len(kc) == 1
), f"All attribute 'key_cache' should have the same length but found {kc}"
vc = set(len(o.value_cache) for o in objs)
assert (
len(vc) == 1
), f"All attribute 'value_cache' should have the same length but found {vc}"
key_cache = []
for i in range(kc.pop()):
key_cache.append(
self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs], auto=auto)
)
value_cache = []
for i in range(vc.pop()):
value_cache.append(
self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs], auto=auto)
)
return [key_cache, value_cache]
raise NotImplementedError(
f"Unable to build dynamic shapes for type {set_types.pop()}: "
f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}, "
f"this object needs serialization function to be registered."
)
def guess_dynamic_shapes(self, auto: bool = False) -> DYNAMIC_SHAPES:
"""
        Guesses the dynamic shapes for that module from at least two executions.
        If there is only one execution, all dimensions are considered static.
        :param auto: if True, uses ``torch.export.Dim.AUTO`` for every
            dimension when only one set of inputs is given
"""
if len(self.inputs) == 0:
# No inputs, unable to guess.
return (tuple(), {})
if len(self.inputs) == 1:
# No dynamic shapes.
return tuple(
self.guess_dynamic_shape_object(a, auto=auto) for a in self.inputs[0][0]
), {
k: self.guess_dynamic_shape_object(v, auto=auto)
for k, v in self.inputs[0][1].items()
}
# Otherwise.
s1 = set(len(i[0]) for i in self.inputs)
assert (
len(s1) == 1
), f"Different numbers of positional arguments {s1} for {self.full_name}"
s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
args = []
kwargs = {}
for i in range(s1.pop()):
objs = [_[0][i] for _ in self.inputs]
args.append(
self.guess_dynamic_shape_object(
*objs, auto=auto, msg=lambda i=i: f" failing input {i}"
)
)
names = s2.pop()
for name in names:
assert name not in {"_diag", "verbose"}, (
f"{self.full_name}: unexpected parameter {name!r}, names={names}"
f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
)
objs = [_[1][name] for _ in self.inputs]
kwargs[name] = self.guess_dynamic_shape_object(
*objs, auto=auto, msg=lambda name=name: f" failing input {name!r}"
)
return tuple(args), kwargs
def move_to_kwargs(
self,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
) -> Tuple[Tuple[Any, ...], Dict[str, Any], DYNAMIC_SHAPES]:
"""
Uses the signatures to move positional arguments (args) to named arguments (kwargs)
with the corresponding dynamic shapes.
        *kwargs* and *dynamic_shapes* are modified in place.
"""
sig = self.signature
arg_dyn, kw_dyn = dynamic_shapes
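        # Assigns every positional dynamic shape to the parameter name holding
        # the same position in the signature.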
for i, p in enumerate(sig.parameters):
if i >= len(arg_dyn):
break
kw_dyn[p] = arg_dyn[i]
if self.forward_kwargs:
kdw = {}
for k, v in kw_dyn.items():
if k not in self.forward_parameter_names:
kdw[k] = v
if kdw:
for k in kdw:
del kw_dyn[k]
kw_dyn[self.forward_kwargs] = kdw
# Let's reorder as it seems to matter later
# in the shape inference algorithm.
_kwargs = kwargs
kwargs = {}
_kw_dyn = kw_dyn
kw_dyn = {}
for name in self.forward_ordered_parameter_names:
if name in _kwargs:
kwargs[name] = _kwargs[name]
if name in _kw_dyn:
kw_dyn[name] = _kw_dyn[name]
for k in _kwargs:
if k not in kwargs:
# Then it is part of **kwargs.
kwargs[k] = _kwargs[k]
assert len(kw_dyn) == len(_kw_dyn), (
f"{self.full_name}: unexpected mismatch between _kw_dyn={set(_kw_dyn)} "
f"and kw_dyn={set(kw_dyn)}, "
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
)
assert len(kwargs) == len(_kwargs), (
f"{self.full_name}: unexpected mismatch between _kwargs={set(_kwargs)} "
f"and kwargs={set(kwargs)}, "
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
)
return args, kwargs, (tuple(), kw_dyn)
def validate_inputs_for_export(
self, dynamic_shapes: Optional[DYNAMIC_SHAPES] = None
) -> List[List[Union[int, str]]]:
"""
Validates the inputs the class contains for the given dynamic shapes.
If not specified, the dynamic_shapes are guessed.
:param dynamic_shapes: dynamic shapes to validate
        :return: a list of lists, every list contains the path to the invalid dimension
"""
        if dynamic_shapes is None:
            if len(self.inputs) == 1:
                # A single set of inputs: nothing dynamic can be validated.
                return []
            dynamic_shapes = self.guess_dynamic_shapes()
        return [
            CoupleInputsDynamicShapes(*i, dynamic_shapes).invalid_dimensions_for_export()
            for i in self.inputs
        ]