
class onnx_diagnostic.export.dynamic_shapes.ModelInputs(model: Module, inputs: List[Tuple[Any, ...]] | List[Dict[str, Any]] | List[Tuple[Tuple[Any, ...], Dict[str, Any]]], level: int = 0, method_name: str = 'forward', name: str = 'main')

Wraps a model and several sets of valid inputs. From these inputs, the class infers the dynamic shapes to give to torch.export.export().

  • model – model to export

  • inputs – list of valid sets of inputs; each set is a tuple of positional arguments, a dictionary of keyword arguments, or a pair (args, kwargs)

  • level – nesting level of this module when it is a submodule (defaults to 0)

  • method_name – method to process; forward by default, but another method can be given

  • name – a name, mostly used for debugging




import pprint
import torch
from onnx_diagnostic.export import ModelInputs

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to check it works

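# two sets of inputs with different shapes, so the varying dimensions can be detected
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)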


    (({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
       1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
      {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}),



import pprint
import torch
from onnx_diagnostic.export import ModelInputs

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)  # to check it works

# the same two sets of inputs, given as keyword arguments
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)


     {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
            1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
      'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})

args and kwargs


import pprint
import torch
from onnx_diagnostic.export import ModelInputs

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works

# each set of inputs is a pair (args, kwargs)
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)


    (({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
       1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},),
     {'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})

torch.export.export() does not handle dynamic shapes defined partly for positional arguments (args) and partly for keyword arguments (kwargs); they must be defined for kwargs only. move_to_kwargs modifies the inputs and the dynamic shapes accordingly so that the model can be exported with the given inputs.


import pprint
import torch
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.helpers import string_type

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works

inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()

# mi.inputs stores every set of inputs as an (args, kwargs) pair
a, kw, nds = mi.move_to_kwargs(*mi.inputs[0], ds)
print("moved args:", string_type(a, with_shape=True))
print("moved kwargs:", string_type(kw, with_shape=True))
print("dynamic shapes:")
pprint.pprint(nds)


    moved args: (T1s5x6,)
    moved kwargs: dict(y:T1s1x6)
    dynamic shapes:
     {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
            1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
      'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
property full_name

Returns the name and the class name.

guess_dynamic_dimensions(*tensors) → Dict[int, Any]

Infers the dynamic dimensions by comparing the shapes of several tensors.
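The inference rule can be illustrated with a minimal sketch, assuming an axis is dynamic whenever its size differs between the observed tensors. The helper guess_dims_from_shapes is hypothetical, works on plain shape tuples rather than tensors, and is not the library's implementation; the string "DYNAMIC" is a placeholder for the dimension hint printed in the outputs above (torch.export.Dim.DYNAMIC in recent PyTorch releases).

```python
from typing import Dict, Tuple


def guess_dims_from_shapes(*shapes: Tuple[int, ...]) -> Dict[int, str]:
    """Hypothetical helper: marks an axis dynamic when its size varies across shapes."""
    if not shapes:
        return {}
    rank = len(shapes[0])
    # All observed tensors for one argument are expected to share the same rank.
    if any(len(s) != rank for s in shapes):
        raise ValueError(f"rank mismatch across shapes: {shapes}")
    dynamic = {}
    for axis in range(rank):
        sizes = {s[axis] for s in shapes}
        if len(sizes) > 1:
            # The size changed between executions, so this axis is dynamic.
            dynamic[axis] = "DYNAMIC"
    return dynamic


print(guess_dims_from_shapes((5, 6), (7, 8)))  # {0: 'DYNAMIC', 1: 'DYNAMIC'}
print(guess_dims_from_shapes((1, 6), (1, 8)))  # {1: 'DYNAMIC'}
```

This rule explains why y only gets axis 1 in the examples above: its first dimension is 1 in both sets of inputs, so it cannot be told apart from a static dimension.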

guess_dynamic_shape_object(*objs: Any, msg: Callable | None = None) → Any

Guesses the dynamic shapes for one argument.

guess_dynamic_shapes() → Tuple[Tuple[Any, ...], Dict[str, Any]]

Guesses the dynamic shapes for this module by comparing two or more executions. If only one execution is available, every dimension is considered static.
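The way per-tensor guesses combine into a structure with the same layout as the inputs can be sketched as follows. This is a hypothetical simplification, not the library's code: guess_shapes and FakeTensor are invented for illustration, every execution is assumed to pass the same number of positional arguments and the same keyword names, all tensors of one argument are assumed to share a rank, and an empty dictionary stands for "all axes static".

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in for a tensor: only .shape is read,
# so torch.Tensor objects could be passed the same way.
FakeTensor = namedtuple("FakeTensor", ["shape"])


def _dynamic_axes(*tensors: Any) -> Dict[int, str]:
    # An axis is dynamic when its size differs between executions.
    shapes = [tuple(t.shape) for t in tensors]
    return {
        axis: "DYNAMIC"
        for axis in range(len(shapes[0]))
        if len({s[axis] for s in shapes}) > 1
    }


def guess_shapes(
    runs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]],
) -> Tuple[Tuple[Dict[int, str], ...], Dict[str, Dict[int, str]]]:
    """Hypothetical sketch: one (args, kwargs) pair per execution."""
    first_args, first_kwargs = runs[0]
    # Positional arguments are matched by position across executions.
    ds_args = tuple(
        _dynamic_axes(*(args[i] for args, _ in runs))
        for i in range(len(first_args))
    )
    # Keyword arguments are matched by name across executions.
    ds_kwargs = {
        name: _dynamic_axes(*(kwargs[name] for _, kwargs in runs))
        for name in first_kwargs
    }
    # With a single execution no size can vary, so every dict is empty (static).
    return ds_args, ds_kwargs


runs = [
    ((FakeTensor((5, 6)),), {"y": FakeTensor((1, 6))}),
    ((FakeTensor((7, 8)),), {"y": FakeTensor((1, 8))}),
]
print(guess_shapes(runs))      # (({0: 'DYNAMIC', 1: 'DYNAMIC'},), {'y': {1: 'DYNAMIC'}})
print(guess_shapes(runs[:1]))  # (({},), {'y': {}})
```

The two runs mirror the args-and-kwargs example above, and the sketch yields the same layout of dynamic axes.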

property module_name_type

Returns the name and the module type.

move_to_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any], dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]]) → Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]

Uses the signature of the processed method to move positional arguments (args) to named arguments (kwargs), together with their dynamic shapes. kwargs and dynamic_shapes are modified in place.
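The idea of moving positional arguments to named ones can be sketched with inspect.signature, assuming parameters are matched to positional arguments in declaration order. The function move_args_to_kwargs and the "DYNAMIC" placeholder are hypothetical; unlike the documented method, this sketch returns new dictionaries instead of modifying kwargs and dynamic_shapes in place, and it does not handle *args parameters.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def move_args_to_kwargs(
    fn: Callable,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
    """Hypothetical sketch: passes every argument (and its dynamic shape) by name."""
    ds_args, ds_kwargs = dynamic_shapes
    # Parameter names in declaration order; a bound method already omits 'self'.
    names = list(inspect.signature(fn).parameters)
    if len(args) > len(names):
        raise TypeError("more positional arguments than parameters")
    # Name each positional argument and its dynamic shape after its parameter.
    new_kwargs = dict(zip(names, args))
    new_ds = dict(zip(names, ds_args))
    for name in kwargs:
        if name in new_kwargs:
            raise TypeError(f"argument {name!r} given both by position and by name")
    new_kwargs.update(kwargs)
    new_ds.update(ds_kwargs)
    # No positional arguments remain, neither in the inputs nor in the shapes.
    return (), new_kwargs, ((), new_ds)


class Model:
    def forward(self, x, y):
        return x + y


args, kwargs, ds = move_args_to_kwargs(
    Model().forward, (1,), {"y": 2}, (({0: "DYNAMIC"},), {"y": {1: "DYNAMIC"}})
)
print(args, kwargs, ds)
# () {'x': 1, 'y': 2} ((), {'x': {0: 'DYNAMIC'}, 'y': {1: 'DYNAMIC'}})
```

Reading names from the signature keeps the mapping correct even when the caller mixes positional and keyword arguments, which is the situation move_to_kwargs is meant to resolve.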

process_inputs(inputs: List[Tuple[Any, ...]] | List[Dict[str, Any]] | List[Tuple[Tuple[Any, ...], Dict[str, Any]]]) → List[Tuple[Tuple[Any, ...], Dict[str, Any]]]

Transforms a list of valid inputs (a list of args, a list of kwargs, or a list of (args, kwargs) pairs) into a list of (args, kwargs) pairs.
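A minimal sketch of this normalization, assuming a dictionary means keyword arguments only, a pair (tuple, dict) means (args, kwargs), and any other tuple means positional arguments only. The function normalize_inputs is hypothetical and classifies each entry separately; it would misread a model whose two positional arguments are literally a tuple followed by a dict, an ambiguity the real method may resolve differently.

```python
from typing import Any, Dict, List, Tuple


def normalize_inputs(inputs: List[Any]) -> List[Tuple[Tuple[Any, ...], Dict[str, Any]]]:
    """Hypothetical sketch: turns each set of inputs into an (args, kwargs) pair."""
    result = []
    for item in inputs:
        if isinstance(item, dict):
            # Only keyword arguments.
            result.append(((), item))
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[0], tuple)
            and isinstance(item[1], dict)
        ):
            # Already an (args, kwargs) pair.
            result.append(item)
        elif isinstance(item, tuple):
            # Only positional arguments.
            result.append((item, {}))
        else:
            raise TypeError(f"unexpected input type {type(item)}")
    return result


print(normalize_inputs([(1, 2), {"x": 1}, ((1,), {"y": 2})]))
# [((1, 2), {}), ((), {'x': 1}), ((1,), {'y': 2})]
```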

property true_model_name

Returns the class name or the module name.