Source code for onnx_diagnostic.export.dynamic_shapes
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from ..helpers import string_type
[docs]
class ModelInputs:
"""
Wraps a model and a couple of sets of valid inputs.
Based on that information, the class is able to infer the dynamic shapes
for :func:`torch.export.export`.
:param model: model to export
:param inputs: list of valid sets of inputs
:param level: if this module is a submodule, it is the level of submodule
:param method_name: by default, the forward method is processed but it
could be another one
:param name: a name, mostly for debugging purposes
Examples:
**args**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y) # to check it works
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
**kwargs**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y) # to check it works
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
**args and kwargs**
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y) # to check it works
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
:func:`torch.export.export` does not like dynamic shapes defined both as args and kwargs.
kwargs must be used. ``move_to_kwargs`` modifies the inputs and the dynamic shapes
to make the model and the given inputs exportable.
.. runpython::
:showcode:
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.helpers import string_type
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y) # to check it works
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
print("moved args:", string_type(a, with_shape=True))
print("moved kwargs:", string_type(kw, with_shape=True))
print("dynamic shapes:")
pprint.pprint(nds)
"""
def __init__(
    self,
    model: torch.nn.Module,
    inputs: Union[
        List[Tuple[Any, ...]],
        List[Dict[str, Any]],
        List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
    ],
    level: int = 0,
    method_name: str = "forward",
    name: str = "main",
):
    """
    Stores the model, inspects the signature of the processed method
    and normalizes the sets of inputs into a list of ``(args, kwargs)``.

    :param model: a :class:`torch.nn.Module` or a python module
    :param inputs: list of args, list of kwargs or list of ``(args, kwargs)``
    :param level: stored as is in ``self.level``
    :param method_name: name of the attribute of *model* whose signature is inspected
    :param name: non empty name used in error messages
    """
    is_supported = isinstance(model, torch.nn.Module) or inspect.ismodule(model)
    assert (
        is_supported
    ), f"unexpected type for model={type(model)}, it must be a torch.nn.Module"
    assert name, (
        f"name={name!r} cannot be empty this string is used to "
        f"display meaningful error messages"
    )
    self.name = name
    self.model = model
    self.level = level
    self.method_name = method_name
    self.forward = getattr(model, method_name)
    self.signature = inspect.signature(self.forward)

    # Everything below is derived from the parameters of the signature.
    params = list(self.signature.parameters.values())
    kinds = inspect.Parameter
    variadic = {kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD}
    positional = (kinds.VAR_POSITIONAL, kinds.POSITIONAL_ONLY, kinds.POSITIONAL_OR_KEYWORD)

    self.forward_parameter_names = {p.name for p in params if p.kind not in variadic}
    self.forward_ordered_parameter_names = [p.name for p in params]
    self.forward_positioned_parameter_names = [
        p.name for p in params if p.kind in positional
    ]
    # Name of ``*args`` and ``**kwargs`` if the signature declares them, None otherwise.
    self.forward_args = next(
        (p.name for p in params if p.kind == kinds.VAR_POSITIONAL), None
    )
    self.forward_kwargs = next(
        (p.name for p in params if p.kind == kinds.VAR_KEYWORD), None
    )
    self.forward_custom_op_schema = None
    self.forward_need_serialization = False
    self.forward_fill_kwargs = bool(self.forward_kwargs)

    is_container = isinstance(model, (torch.nn.ModuleList, torch.nn.ModuleDict))
    assert not is_container, f"ModuleList or ModuleDict should not be traced: {type(model)}"

    # process the inputs
    self.inputs = self.process_inputs(inputs)
[docs]
def process_inputs(
    self,
    inputs: Union[
        List[Tuple[Any, ...]],
        List[Dict[str, Any]],
        List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
    ],
) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
    """
    Normalizes every set of inputs into a pair ``(args, kwargs)``.

    A dictionary becomes ``((), dict)``, a tuple ``(tuple, dict)`` of length two
    is kept as is and any other tuple becomes ``(tuple, {})``.

    :param inputs: list of args, list of kwargs or list of ``(args, kwargs)``
    :return: list of ``(args, kwargs)``
    :raises ValueError: if *inputs* is not a list or one element
        is neither a tuple nor a dictionary
    """
    if not isinstance(inputs, list):
        raise ValueError(
            f"inputs should be specified as a list of sets of "
            f"inputs but type(inputs) is {type(inputs)}"
        )
    normalized = []
    for position, candidate in enumerate(inputs):
        if isinstance(candidate, dict):
            # only named arguments
            normalized.append(((), candidate))
        elif not isinstance(candidate, tuple):
            raise ValueError(
                f"Unable to interpret inputs {position}: {string_type(candidate)}"
            )
        elif (
            len(candidate) == 2
            and isinstance(candidate[0], tuple)
            and isinstance(candidate[1], dict)
        ):
            # already a pair (args, kwargs)
            normalized.append(candidate)
        else:
            # only positional arguments
            normalized.append((candidate, {}))
    return normalized
@property
def true_model_name(self):
    "Returns the class name of a torch module, the ``__name__`` of anything else."
    model = self.model
    if isinstance(model, torch.nn.Module):
        return model.__class__.__name__
    return model.__name__
@property
def full_name(self):
    "Returns ``name:class``, followed by ``.method`` when the method is not forward."
    label = f"{self.name}:{self.true_model_name}"
    if self.method_name != "forward":
        label = f"{label}.{self.method_name}"
    return label
@property
def module_name_type(self):
    "Returns ``type(name)=class``, followed by ``.method`` when the method is not forward."
    label = f"type({self.name})={self.true_model_name}"
    if self.method_name != "forward":
        label = f"{label}.{self.method_name}"
    return label
[docs]
def guess_dynamic_dimensions(self, *tensors) -> Dict[int, Any]:
"""Infers the dynamic dimension from multiple shapes."""
if len(tensors) == 1:
return {}
shapes = [t.shape for t in tensors]
set_length = set(len(s) for s in shapes)
assert len(set_length) == 1, (
f"Shapes can be different but not ranks possible shapes={set_length} "
f"shapes={shapes} for module {self.name!r}, "
f"class={self.true_model_name!r}"
)
dynamic: Any = torch.export.Dim.DYNAMIC # type: ignore
rk = set_length.pop()
res = {}
for i in range(rk):
set_dim = set(s[i] for s in shapes)
if len(set_dim) > 1:
res[i] = dynamic
continue
if set_dim == {0}:
# It is unexpected to find a null dimension. Let's replace it by a dynamic one.
res[i] = dynamic
continue
return res
[docs]
def guess_dynamic_shape_object(self, *objs: Any, msg: Optional[Callable] = None) -> Any:
    """
    Guesses the dynamic shapes for one argument.

    :param objs: the values taken by the same argument, one for every set of inputs,
        they must all have the same type
    :param msg: function returning a string appended to the error messages,
        it is forwarded to the nested calls so that a failure
        inside a container still tells which input is failing
    :return: None for None, bool, int, float, str or when *objs* is empty,
        the result of ``guess_dynamic_dimensions`` for tensors and numpy arrays,
        a structure of the same kind for tuples, lists and dictionaries,
        ``[key_cache, value_cache]`` for a class named ``DynamicCache``
    :raises AssertionError: if types, lengths or dictionary keys are not the same
        among *objs*
    :raises NotImplementedError: if the type is not supported
    """
    if len(objs) == 0:
        return None
    set_types = set(type(o) for o in objs)
    assert (
        len(set_types) == 1
    ), f"Unexpected variety of input type {set_types}{msg() if msg else ''}"
    obj = objs[0]
    if obj is None:
        return None
    if isinstance(obj, (bool, int, float, str)):
        return None
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return self.guess_dynamic_dimensions(*objs)

    if isinstance(obj, (tuple, list)):
        is_tuple = isinstance(obj, tuple)
        kind = "tuple" if is_tuple else "list"
        kl = set(len(o) for o in objs)
        assert (
            len(kl) == 1
        ), f"Unexpected variety of {kind} lengths {kl}{msg() if msg else ''}"
        shapes: Any = [
            self.guess_dynamic_shape_object(*[o[i] for o in objs], msg=msg)
            for i in range(kl.pop())
        ]
        return tuple(shapes) if is_tuple else shapes

    if isinstance(obj, dict):
        kl = set(len(o) for o in objs)
        assert (
            len(kl) == 1
        ), f"Unexpected variety of dict lengths {kl}{msg() if msg else ''}"
        # Same length does not mean same keys, checking it here
        # gives a meaningful error instead of a KeyError below.
        keys = set(obj)
        assert all(set(o) == keys for o in objs), (
            f"Unexpected variety of dict keys {[list(o) for o in objs]}"
            f"{msg() if msg else ''}"
        )
        return {
            k: self.guess_dynamic_shape_object(*[o[k] for o in objs], msg=msg)
            for k in obj
        }

    if obj.__class__.__name__ == "DynamicCache":
        # The class is recognized by its name, it is expected to expose
        # two sequences of tensors, key_cache and value_cache.
        kc = set(len(o.key_cache) for o in objs)
        assert (
            len(kc) == 1
        ), f"All attribute 'key_cache' should have the same length but found {kc}"
        vc = set(len(o.value_cache) for o in objs)
        assert (
            len(vc) == 1
        ), f"All attribute 'value_cache' should have the same length but found {vc}"
        key_cache = [
            self.guess_dynamic_dimensions(*[o.key_cache[i] for o in objs])
            for i in range(kc.pop())
        ]
        value_cache = [
            self.guess_dynamic_dimensions(*[o.value_cache[i] for o in objs])
            for i in range(vc.pop())
        ]
        return [key_cache, value_cache]

    raise NotImplementedError(
        f"Unable to build dynamic shapes for type {set_types.pop()}: "
        f"{string_type(objs)}{msg() if msg else ''} in {self.module_name_type}"
    )
[docs]
def guess_dynamic_shapes(
    self,
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Guesses the dynamic shapes for that module from at least two executions.
    With a single execution, every argument is guessed alone,
    with no other set of inputs to compare with.

    :return: a pair ``(shapes for the positional arguments,
        shapes for the named arguments)``, ``((), {})`` if there is no input
    :raises AssertionError: if the sets of inputs do not share the same number
        of positional arguments or the same named arguments
    """
    all_inputs = self.inputs
    guess = self.guess_dynamic_shape_object
    if not all_inputs:
        # No inputs, unable to guess.
        return (tuple(), {})
    if len(all_inputs) == 1:
        # Only one set of inputs, every argument is processed alone.
        single = all_inputs[0]
        return (
            tuple(guess(a) for a in single[0]),
            {k: guess(v) for k, v in single[1].items()},
        )

    # Several sets of inputs, they must be consistent.
    positional_counts = {len(entry[0]) for entry in all_inputs}
    assert (
        len(positional_counts) == 1
    ), f"Different numbers of positional arguments {positional_counts} for {self.full_name}"
    keyword_names = {tuple(sorted(set(entry[1]))) for entry in all_inputs}
    assert (
        len(keyword_names) == 1
    ), f"Different named arguments {keyword_names} for {self.full_name}"

    # The index is bound as a default value, the message is built only on failure.
    args = tuple(
        guess(
            *[entry[0][i] for entry in all_inputs],
            msg=lambda i=i: f" failing input {i}",
        )
        for i in range(positional_counts.pop())
    )

    names = keyword_names.pop()
    kwargs = {}
    for name in names:
        assert name not in {"_diag", "verbose"}, (
            f"{self.full_name}: unexpected parameter {name!r}, names={names}"
            f"\ninputs[0]={string_type(self.inputs[0], with_shape=True)}"
            f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
        )
        kwargs[name] = guess(
            *[entry[1][name] for entry in all_inputs],
            msg=lambda name=name: f" failing input {name!r}",
        )
    return args, kwargs
[docs]
def move_to_kwargs(
self,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
"""
Uses the signature to give a name to the dynamic shapes of the positional
arguments so that all dynamic shapes are expressed as named arguments.

:param args: positional arguments, returned as they were received
:param kwargs: named arguments, this dictionary is not modified inplace,
    a reordered copy may be returned instead
:param dynamic_shapes: pair ``(shapes for args, shapes for kwargs)``,
    the dictionary of named shapes is modified inplace
:return: ``(args, kwargs, (tuple(), named dynamic shapes))``,
    the positional part of the returned dynamic shapes is always empty
:raises AssertionError: if the reordering loses a named argument or a named shape
"""
sig = self.signature
arg_dyn, kw_dyn = dynamic_shapes
# The i-th positional dynamic shape is stored under the name
# of the i-th parameter of the signature.
for i, p in enumerate(sig.parameters):
if i >= len(arg_dyn):
break
kw_dyn[p] = arg_dyn[i]
# Only when the signature declares **kwargs: the shapes whose name is not
# an explicit parameter are grouped in a dictionary stored under that name.
if self.forward_kwargs:
kdw = {}
for k, v in kw_dyn.items():
if k not in self.forward_parameter_names:
kdw[k] = v
if kdw:
for k in kdw:
del kw_dyn[k]
kw_dyn[self.forward_kwargs] = kdw
# NOTE(review): the indentation is lost in this listing, it cannot be told from
# here whether the reordering below is nested in ``if self.forward_kwargs:``
# or always runs, to be confirmed against the original file.
# Let's reorder as it seems to matter later
# in the shape inference algorithm.
_kwargs = kwargs
kwargs = {}
_kw_dyn = kw_dyn
kw_dyn = {}
# New dictionaries are filled following the order of the parameters in the signature.
for name in self.forward_ordered_parameter_names:
if name in _kwargs:
kwargs[name] = _kwargs[name]
if name in _kw_dyn:
kw_dyn[name] = _kw_dyn[name]
for k in _kwargs:
if k not in kwargs:
# Then it is part of **kwargs.
kwargs[k] = _kwargs[k]
# Nothing should be lost by the reordering.
assert len(kw_dyn) == len(_kw_dyn), (
f"{self.full_name}: unexpected mismatch between _kw_dyn={set(_kw_dyn)} "
f"and kw_dyn={set(kw_dyn)}, "
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
)
assert len(kwargs) == len(_kwargs), (
f"{self.full_name}: unexpected mismatch between _kwargs={set(_kwargs)} "
f"and kwargs={set(kwargs)}, "
f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
)
return args, kwargs, (tuple(), kw_dyn)