Dynamic Shapes for *args, **kwargs¶
Quick tour of dynamic shapes.
We first look at examples playing with positional and named parameters
to understand how torch.export.export()
works.
args¶
import pprint

import torch

from onnx_diagnostic import doc
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to check it works

ep = torch.export.export(model, (x, y))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[5, 6]", y: "f32[1, 6]"):
            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:24 in forward, code: return x + y
            add: "f32[5, 6]" = torch.ops.aten.add.Tensor(x, y); x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}
As expected, there are no dynamic shapes.
We use onnx_diagnostic.export.ModelInputs
to define them from two sets of valid inputs.
These inputs must have different values for the dynamic
dimensions.
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
{1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}),
{})
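The guessing rule illustrated by this output can be sketched in plain Python: a dimension is marked dynamic whenever the two sets of valid inputs disagree on it. The helper below (`guess_dynamic_dims`, a hypothetical name, not part of onnx_diagnostic) is only a sketch of that comparison, not the actual implementation.

```python
def guess_dynamic_dims(shape_a, shape_b):
    # A dimension is guessed dynamic when the two valid inputs disagree on it;
    # dimensions equal in both sets are assumed static.
    return {
        i: "DYNAMIC"
        for i, (a, b) in enumerate(zip(shape_a, shape_b))
        if a != b
    }


# x went from (5, 6) to (7, 8): both dimensions differ.
print(guess_dynamic_dims((5, 6), (7, 8)))  # {0: 'DYNAMIC', 1: 'DYNAMIC'}
# y went from (1, 6) to (1, 8): only dimension 1 differs.
print(guess_dynamic_dims((1, 6), (1, 8)))  # {1: 'DYNAMIC'}
```

This matches the output above: both dimensions of x are guessed dynamic, but only the second dimension of y.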
The function returns a tuple with two objects: the first one for the positional arguments, the second one for the named arguments. There are no named arguments here, so we use the first result to export.

ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[1, s16]"):
            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:24 in forward, code: return x + y
            add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y); x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
kwargs¶
We do the same with named arguments.
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)  # to check it works
tensor([[-1.5580,  0.9713,  0.1661, -0.6501, -2.6914, -0.7369],
        [-1.6740,  1.6444,  0.9112,  1.9784, -2.6826, -0.0259],
        [-3.1149, -0.0969, -0.6445, -0.7893, -2.4501,  0.7673],
        [-2.0522, -0.1814,  0.0046, -0.0168, -3.0956,  1.6139],
        [-1.3522,  0.5885,  2.8786,  1.1563, -1.2122, -0.0681]])
Two sets of valid inputs.
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
((),
{'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}})
And we export.

ep = torch.export.export(model, (), kwargs=inputs[0], dynamic_shapes=ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s32, s17]", y: "f32[1, s17]"):
            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:65 in forward, code: return x + y
            add: "f32[s32, s17]" = torch.ops.aten.add.Tensor(x, y); x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s32: VR[2, int_oo], s17: VR[2, int_oo]}
args and kwargs¶
torch.export.export()
does not like having dynamic shapes
for both args and kwargs. We need to define them using a single mechanism.
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works
tensor([[ 2.6854,  0.5077, -0.3375,  2.8778, -2.3589, -0.2462],
        [ 1.5233, -1.2987, -1.6942,  2.4016, -0.7991, -0.4830],
        [ 3.0727, -0.4697, -0.6273,  0.7927, -1.8176, -0.1301],
        [ 1.5259,  2.2313, -0.2053,  2.4364, -2.3864, -1.0114],
        [ 1.4216,  0.0266,  0.7536,  1.4343, -1.7989, -1.1391]])
Two sets of valid inputs with positional and named arguments.
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},),
{'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}})
This does not work with torch.export.export(),
so we use a method to move the positional dynamic shapes to
named ones. The method relies on the signature of the
forward method.
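Conceptually, the move pairs each positional entry with its parameter name taken from the forward signature, then merges it with the named entries. The sketch below uses inspect.signature to illustrate that pairing; it is a simplified illustration, not the actual move_to_kwargs implementation.

```python
import inspect


class Model:
    def forward(self, x, y):
        return x + y


# Parameter names of forward, in order, dropping self: ['x', 'y'].
names = list(inspect.signature(Model.forward).parameters)[1:]

# Guessed dynamic shapes, split between positional and named arguments
# (placeholders stand in for the _DimHint objects above).
args_ds = ({0: "DYN", 1: "DYN"},)  # positional part
kwargs_ds = {"y": {1: "DYN"}}      # named part

# Rename each positional entry after its parameter, then merge.
merged = {names[i]: d for i, d in enumerate(args_ds)}
merged.update(kwargs_ds)
print(merged)  # {'x': {0: 'DYN', 1: 'DYN'}, 'y': {1: 'DYN'}}
```

Everything ends up keyed by parameter name, which is the single mechanism torch.export.export() accepts.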
new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
pprint.pprint(new_ds)
((),
{'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}})
And we export.
ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[1, s16]"):
            # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:95 in forward, code: return x + y
            add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y); x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
doc.plot_legend("dynamic shapes\n*args, **kwargs", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 7.157 seconds)
Related examples

Steel method forward to guess the dynamic shapes (with Tiny-LLM)