onnx_diagnostic.export.dynamic_shapes
- class onnx_diagnostic.export.dynamic_shapes.ModelInputs(model: Module, inputs: List[Tuple[Any, ...]] | List[Dict[str, Any]] | List[Tuple[Tuple[Any, ...], Dict[str, Any]]], level: int = 0, method_name: str = 'forward', name: str = 'main')
Wraps a model and several sets of valid inputs. From these inputs, the class infers the dynamic shapes to pass to torch.export.export().
- Parameters:
model – model to export
inputs – list of valid input sets; each set is a tuple of positional arguments, a dictionary of named arguments, or a pair (args, kwargs)
level – nesting level of the module when it is a submodule
method_name – name of the method to process; forward by default, but another method can be used
name – a name, used mostly for debugging
Examples:
args
<<<
import pprint
import torch
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to check it works

inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
>>>
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>), 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}), {})
kwargs
<<<
import pprint
import torch
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)  # to check it works

inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
>>>
((), {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>), 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
args and kwargs
<<<
import pprint
import torch
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works

inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
>>>
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>), 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},), {'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
torch.export.export() does not handle dynamic shapes well when they are defined for both args and kwargs, so kwargs must be used. move_to_kwargs modifies the inputs and the dynamic shapes to make the model and the given inputs exportable.
<<<
import pprint
import torch
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.helpers import string_type


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works

inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
print("moved args:", string_type(a, with_shape=True))
print("moved kwargs:", string_type(kw, with_shape=True))
print("dynamic shapes:")
pprint.pprint(nds)
>>>
moved args: (T1s5x6,)
moved kwargs: dict(y:T1s1x6)
dynamic shapes:
((), {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>), 1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}, 'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
- property full_name
Returns the name and the class name.
- guess_dynamic_dimensions(*tensors) → Dict[int, Any]
Infers the dynamic dimensions by comparing the shapes of several tensors.
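The comparison behind guess_dynamic_dimensions can be shown with a minimal pure-Python sketch. This is not the library's code. It takes plain shape tuples instead of tensors, it uses the string "DYNAMIC" as a stand-in for the _DimHint values printed in the examples, and the name sketch_guess_dynamic_dimensions is hypothetical. An axis is marked dynamic when its size differs between executions. This matches the outputs above, where x changes on both axes and y changes only on axis 1.

```python
from typing import Dict, Sequence

# Stand-in for the DYNAMIC hint that the real class returns.
DYNAMIC = "DYNAMIC"


def sketch_guess_dynamic_dimensions(*shapes: Sequence[int]) -> Dict[int, str]:
    """Hypothetical helper: marks every axis whose size varies across shapes."""
    if not shapes:
        raise ValueError("at least one shape is required")
    ranks = {len(shape) for shape in shapes}
    if len(ranks) != 1:
        # This sketch refuses tensors whose rank changes between executions.
        raise ValueError(f"inconsistent ranks: {sorted(ranks)}")
    dynamic: Dict[int, str] = {}
    for axis in range(len(shapes[0])):
        sizes = {shape[axis] for shape in shapes}
        if len(sizes) > 1:  # the size changed, so the axis is dynamic
            dynamic[axis] = DYNAMIC
    return dynamic


# Shapes of x and y in the two input sets used in the examples.
print(sketch_guess_dynamic_dimensions((5, 6), (7, 8)))  # {0: 'DYNAMIC', 1: 'DYNAMIC'}
print(sketch_guess_dynamic_dimensions((1, 6), (1, 8)))  # {1: 'DYNAMIC'}
# A single shape leaves every axis static.
print(sketch_guess_dynamic_dimensions((5, 6)))  # {}
```

Rejecting a change of rank is a choice made for this sketch. This page does not describe how the real method handles that case.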
- guess_dynamic_shape_object(*objs: Any, msg: Callable | None = None) → Any
Guesses the dynamic shapes for one argument.
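An argument may be a list or a dictionary of tensors rather than a single tensor. The sketch below assumes that guess_dynamic_shape_object receives the values one argument takes in each execution, one value per entry of objs. Under that assumption, it walks nested lists, tuples and dicts in parallel and applies the per-axis rule to each tensor. Several names here are made up. FakeTensor is a placeholder that only carries a shape, "DYNAMIC" stands in for the real dimension hint, and sketch_guess_dynamic_shape_object is hypothetical. The msg parameter of the real method is omitted.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

DYNAMIC = "DYNAMIC"  # stand-in for the DYNAMIC hint returned by the real class


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical placeholder that only carries a shape."""
    shape: Tuple[int, ...]


def _dims(*shapes: Tuple[int, ...]) -> Dict[int, str]:
    # An axis is dynamic when its size varies between executions.
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("inconsistent ranks")
    return {a: DYNAMIC for a in range(len(shapes[0])) if len({s[a] for s in shapes}) > 1}


def sketch_guess_dynamic_shape_object(*objs: Any) -> Any:
    """Hypothetical helper: objs holds one argument's value per execution."""
    first = objs[0]
    if any(type(o) is not type(first) for o in objs):
        raise TypeError("every execution must pass the same kind of object")
    if isinstance(first, FakeTensor):
        return _dims(*(o.shape for o in objs))
    if isinstance(first, (list, tuple)):
        if len({len(o) for o in objs}) != 1:
            raise ValueError("containers have different lengths")
        # Recurse position by position and keep the container type.
        return type(first)(sketch_guess_dynamic_shape_object(*items) for items in zip(*objs))
    if isinstance(first, dict):
        if any(o.keys() != first.keys() for o in objs):
            raise ValueError("dictionaries have different keys")
        return {k: sketch_guess_dynamic_shape_object(*(o[k] for o in objs)) for k in first}
    # Non-tensor values (int, float, None, ...) get no dynamic shape.
    return None


# Made-up values of one argument in two executions.
run1 = {"ids": FakeTensor((2, 16)), "masks": [FakeTensor((2, 16)), FakeTensor((1, 16))]}
run2 = {"ids": FakeTensor((3, 20)), "masks": [FakeTensor((3, 20)), FakeTensor((1, 20))]}
print(sketch_guess_dynamic_shape_object(run1, run2))
# {'ids': {0: 'DYNAMIC', 1: 'DYNAMIC'}, 'masks': [{0: 'DYNAMIC', 1: 'DYNAMIC'}, {1: 'DYNAMIC'}]}
```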
- guess_dynamic_shapes() → Tuple[Tuple[Any, ...], Dict[str, Any]]
Guesses the dynamic shapes of the module from two or more executions. If there is only one execution, all dimensions are static.
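At the module level, the per-axis rule is applied to every positional and named argument. The sketch below assumes each input set has already been normalized to an (args, kwargs) pair whose values are shape tuples. The function sketch_guess_dynamic_shapes and the "DYNAMIC" placeholder are hypothetical. The sketch also shows why a single execution yields static dimensions: with only one size per axis, nothing varies.

```python
from typing import Dict, List, Tuple

DYNAMIC = "DYNAMIC"  # stand-in for the DYNAMIC hint returned by the real class
Shape = Tuple[int, ...]


def _dims(*shapes: Shape) -> Dict[int, str]:
    # An axis is dynamic when its size differs between executions.
    # With a single execution every axis has one size, so nothing is dynamic.
    return {a: DYNAMIC for a in range(len(shapes[0])) if len({s[a] for s in shapes}) > 1}


def sketch_guess_dynamic_shapes(
    inputs: List[Tuple[Tuple[Shape, ...], Dict[str, Shape]]],
) -> Tuple[Tuple[Dict[int, str], ...], Dict[str, Dict[int, str]]]:
    """Hypothetical helper: each input set is (args, kwargs) of tensor shapes."""
    if len({len(args) for args, _ in inputs}) != 1:
        raise ValueError("every execution must pass the same number of args")
    if len({frozenset(kwargs) for _, kwargs in inputs}) != 1:
        raise ValueError("every execution must pass the same kwargs names")
    # zip(*...) groups the i-th positional argument of every execution.
    args_ds = tuple(_dims(*shapes) for shapes in zip(*(args for args, _ in inputs)))
    kwargs_ds = {name: _dims(*(kw[name] for _, kw in inputs)) for name in inputs[0][1]}
    return args_ds, kwargs_ds


# Shapes from the mixed example above: x is positional, y is named.
two_runs = [(((5, 6),), {"y": (1, 6)}), (((7, 8),), {"y": (1, 8)})]
print(sketch_guess_dynamic_shapes(two_runs))
# (({0: 'DYNAMIC', 1: 'DYNAMIC'},), {'y': {1: 'DYNAMIC'}})
print(sketch_guess_dynamic_shapes(two_runs[:1]))
# (({},), {'y': {}})
```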
- property module_name_type
Returns the name and the module type.
- move_to_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]]) → Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]
Uses the method signature to move positional arguments (args) to named arguments (kwargs), together with their dynamic shapes. kwargs and dynamic_shapes are modified in place.
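Positional arguments can be moved to named ones with the standard inspect module. The sketch below is an assumption about the approach, not the library's code. The name sketch_move_to_kwargs is hypothetical, and the dynamic shapes use the same (args, kwargs) layout that guess_dynamic_shapes returns. Unlike the method described above, the sketch returns new objects instead of modifying kwargs and dynamic_shapes in place.

```python
import inspect
from typing import Any, Callable, Dict, Tuple

# (args dynamic shapes, kwargs dynamic shapes), the layout returned by guess_dynamic_shapes
DynamicShapes = Tuple[Tuple[Any, ...], Dict[str, Any]]


def sketch_move_to_kwargs(
    fn: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    dynamic_shapes: DynamicShapes,
) -> Tuple[Tuple[Any, ...], Dict[str, Any], DynamicShapes]:
    """Hypothetical helper: passes every positional argument by name instead."""
    signature = inspect.signature(fn)
    signature.bind(*args, **kwargs)  # raises TypeError if fn cannot accept this call
    args_ds, kwargs_ds = dynamic_shapes
    if len(args_ds) != len(args):
        raise ValueError("expected one dynamic shape per positional argument")
    new_kwargs: Dict[str, Any] = {}
    new_ds: Dict[str, Any] = {}
    # Positional values fill the parameters in declaration order.
    for param, value, ds in zip(signature.parameters.values(), args, args_ds):
        # Positional-only parameters and *args cannot be passed by name.
        if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise ValueError(f"{param.name!r} cannot be passed by name")
        new_kwargs[param.name] = value
        new_ds[param.name] = ds
    # Arguments that were already named keep their values and dynamic shapes.
    new_kwargs.update(kwargs)
    new_ds.update(kwargs_ds)
    return (), new_kwargs, ((), new_ds)


class Model:  # plain class standing in for a torch.nn.Module
    def forward(self, x, y):
        return x + y


moved = sketch_move_to_kwargs(
    Model().forward,
    ("x-value",),
    {"y": "y-value"},
    (({0: "DYNAMIC", 1: "DYNAMIC"},), {"y": {1: "DYNAMIC"}}),
)
print(moved)
# ((), {'x': 'x-value', 'y': 'y-value'}, ((), {'x': {0: 'DYNAMIC', 1: 'DYNAMIC'}, 'y': {1: 'DYNAMIC'}}))
```

Calling inspect.signature on a bound method such as Model().forward leaves out self. As a result, the first positional value maps to x.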
- process_inputs(inputs: List[Tuple[Any, ...]] | List[Dict[str, Any]] | List[Tuple[Tuple[Any, ...], Dict[str, Any]]]) → List[Tuple[Tuple[Any, ...], Dict[str, Any]]]
Transforms a list of valid inputs (a list of args, a list of kwargs, or a list of (args, kwargs) pairs) into a list of (args, kwargs) pairs.
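One possible normalization rule is sketched below in plain Python. The name sketch_process_inputs is hypothetical, and the way each set is classified is an assumption:

- a dict means named arguments only,
- a 2-tuple made of a tuple and a dict means (args, kwargs),
- any other tuple means positional arguments only.

```python
from typing import Any, Dict, List, Tuple


def sketch_process_inputs(inputs: List[Any]) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
    """Hypothetical helper: normalizes every input set into an (args, kwargs) pair."""
    result: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
    for item in inputs:
        if isinstance(item, dict):
            # Only named arguments.
            result.append(((), item))
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[0], tuple)
            and isinstance(item[1], dict)
        ):
            # Assumption: a (tuple, dict) pair is read as (args, kwargs).
            result.append((item[0], item[1]))
        elif isinstance(item, tuple):
            # Only positional arguments.
            result.append((item, {}))
        else:
            raise TypeError(f"unexpected input set of type {type(item).__name__}")
    return result


print(sketch_process_inputs([(1, 2), (3, 4)]))  # [((1, 2), {}), ((3, 4), {})]
print(sketch_process_inputs([{"x": 1, "y": 2}]))  # [((), {'x': 1, 'y': 2})]
print(sketch_process_inputs([((1,), {"y": 2})]))  # [((1,), {'y': 2})]
```

This rule is ambiguous for a model whose forward takes exactly two positional arguments, a tuple and a dict. The sketch would read such a set as (args, kwargs).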
- property true_model_name
Returns the class name or the module name.