Dynamic Shapes for *args, **kwargs

A quick tour of dynamic shapes. We first look at examples involving positional and named parameters to understand how torch.export.export() handles them.

args

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to check it works

ep = torch.export.export(model, (x, y))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[5, 6]", y: "f32[1, 6]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:24 in forward, code: return x + y
            add: "f32[5, 6]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {}

As expected, there are no dynamic shapes. We use onnx_diagnostic.export.ModelInputs and its method guess_dynamic_shapes() to infer them from two sets of valid inputs. These inputs must have different values for every dimension meant to be dynamic.

(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True),
   1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)},
  {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})
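The code calling ModelInputs is not reproduced here. As a rough illustration of the rule stated above, here is a minimal pure-Python sketch that compares the shapes of two input sets and marks every axis whose size changes as dynamic. guess_dynamic_dims is a hypothetical helper, not the implementation of ModelInputs.guess_dynamic_shapes(). The second set of shapes, (7, 8) and (1, 8), is an assumption. The string DYNAMIC stands in for torch.export.Dim.DYNAMIC, which is presumably what the _DimHint(...DYNAMIC...) entries above print.

```python
# Stand-in for torch.export.Dim.DYNAMIC so the sketch runs without torch.
DYNAMIC = "DYNAMIC"


def guess_dynamic_dims(shapes1, shapes2):
    """Returns one {axis: DYNAMIC} dict per positional input.

    An axis is marked dynamic when its size differs between the two sets.
    """
    if len(shapes1) != len(shapes2):
        raise ValueError("both sets must contain the same number of inputs")
    result = []
    for s1, s2 in zip(shapes1, shapes2):
        if len(s1) != len(s2):
            raise ValueError(f"rank mismatch: {s1} vs {s2}")
        result.append(
            {axis: DYNAMIC for axis, (d1, d2) in enumerate(zip(s1, s2)) if d1 != d2}
        )
    return tuple(result)


# x: (5, 6) then (7, 8); y: (1, 6) then (1, 8) (second set assumed)
print(guess_dynamic_dims([(5, 6), (1, 6)], [(7, 8), (1, 8)]))
# ({0: 'DYNAMIC', 1: 'DYNAMIC'}, {1: 'DYNAMIC'})
```

The first dimension of y stays equal to 1 in both sets, so it is left static, which matches the structure printed above.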

The result is a tuple of two objects: the first one holds the dynamic shapes of the positional arguments, the second one those of the named arguments. There are no named arguments here, so we export with the first one.

ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[1, s16]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:24 in forward, code: return x + y
            add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}

kwargs

We do the same with named arguments.

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)  # to check it works
tensor([[-1.2552, -2.0038,  0.1856,  0.3060, -2.0445, -0.6507],
        [-0.4602, -1.8862, -1.1222,  0.8848,  1.8047,  0.7038],
        [-0.5866, -1.2445, -0.1098,  1.0760,  0.1222,  0.0717],
        [ 0.8913, -3.6796,  0.6239,  0.3740,  0.9686, -0.8384],
        [-2.2781, -2.1093, -1.6199,  1.4248,  1.5668, -0.5913]])

We guess the dynamic shapes from two sets of valid inputs, this time given as named arguments.

((),
 {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True),
        1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True)},
  'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True)}})
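Keeping the same pure-Python approach (all names hypothetical, shapes of the second set assumed), the named case only changes the container: shapes are keyed by argument name and the positional part of the pair stays empty, matching the ((), {...}) layout printed above. This is a sketch of the idea, not onnx_diagnostic code.

```python
DYNAMIC = "DYNAMIC"  # stand-in for torch.export.Dim.DYNAMIC


def guess_named_dynamic_dims(shapes1, shapes2):
    """Named variant: returns ((), {name: {axis: DYNAMIC}})."""
    if shapes1.keys() != shapes2.keys():
        raise ValueError("both sets must use the same argument names")
    named = {}
    for name, s1 in shapes1.items():
        s2 = shapes2[name]
        if len(s1) != len(s2):
            raise ValueError(f"rank mismatch for {name!r}: {s1} vs {s2}")
        named[name] = {
            axis: DYNAMIC for axis, (d1, d2) in enumerate(zip(s1, s2)) if d1 != d2
        }
    # empty tuple: there are no positional arguments
    return (), named


print(guess_named_dynamic_dims({"x": (5, 6), "y": (1, 6)}, {"x": (7, 8), "y": (1, 8)}))
# ((), {'x': {0: 'DYNAMIC', 1: 'DYNAMIC'}, 'y': {1: 'DYNAMIC'}})
```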

And we export.

ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s32, s17]", y: "f32[1, s17]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:65 in forward, code: return x + y
            add: "f32[s32, s17]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s32: VR[2, int_oo], s17: VR[2, int_oo]}

args and kwargs

torch.export.export() does not accept dynamic shapes given separately for args and kwargs: the dynamic_shapes argument must describe all the inputs with a single mechanism.
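The constraint can be expressed as a small hypothetical helper. Given the (positional, named) pair returned by guess_dynamic_shapes(), it returns ds[0] or ds[1] when only one side carries dynamic shapes, as in the two previous sections, and refuses the mixed case this section deals with. It is a sketch of the reasoning, not part of onnx_diagnostic, and the string "DYNAMIC" again stands in for torch.export.Dim.DYNAMIC.

```python
def pick_dynamic_shapes(ds):
    """Returns the part of a (positional, named) pair to pass as dynamic_shapes.

    Raises when both parts are non-empty, since export needs a single mechanism.
    """
    args_ds, kwargs_ds = ds
    if args_ds and kwargs_ds:
        raise ValueError(
            "dynamic shapes on both args and kwargs: move them to one side first"
        )
    return kwargs_ds if kwargs_ds else args_ds


DYN = "DYNAMIC"  # stand-in for torch.export.Dim.DYNAMIC
print(pick_dynamic_shapes((({0: DYN, 1: DYN}, {1: DYN}), {})))  # ds[0], args example
print(pick_dynamic_shapes(((), {"x": {0: DYN}, "y": {1: DYN}})))  # ds[1], kwargs example
```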

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works
tensor([[-0.3231, -1.2757,  0.5792, -2.7636, -1.8815,  0.2091],
        [ 2.1975, -2.5229, -0.4569, -1.4287,  0.1854,  0.9493],
        [ 2.0366,  0.8558, -1.3106, -3.6543,  1.0809,  0.9447],
        [ 0.8163, -0.1944, -1.0441, -4.7692, -0.5105,  1.3996],
        [ 1.9996, -0.8009, -2.5270, -4.0354, -0.4597,  1.3452]])

Two sets of valid inputs with positional and named arguments.

inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True),
   1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)},),
 {'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True)}})

This does not work with torch.export.export(), so we use a method that moves the positional dynamic shapes to named ones. The method relies on the signature of the forward method.

((),
 {'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True),
        1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True)},
  'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                    min=None,
                    max=None,
                    _factory=True)}})
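The call producing new_args, new_kwargs and new_ds is not shown. The following pure-Python sketch illustrates how a signature-based move can work with inspect.signature: each positional argument and its dynamic shape are renamed after the parameter they bind to. move_positional_to_named is a hypothetical helper, not the onnx_diagnostic method; shape tuples stand in for tensors and "DYNAMIC" for torch.export.Dim.DYNAMIC. It refuses positional-only and *args parameters, which cannot be passed by name.

```python
import inspect


def move_positional_to_named(forward, args, kwargs, ds):
    """Turns positional inputs and their dynamic shapes into named ones.

    ds is a (positional, named) pair as returned by guess_dynamic_shapes().
    Returns (new_args, new_kwargs, new_ds) where everything is named.
    """
    args_ds, kwargs_ds = ds
    params = list(inspect.signature(forward).parameters.values())
    if len(args) > len(params):
        raise TypeError("more positional arguments than parameters")
    names = []
    for p in params[: len(args)]:
        # positional-only and *args parameters cannot be passed by name
        if p.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise TypeError(f"parameter {p.name!r} cannot be passed by name")
        if p.name in kwargs:
            raise TypeError(f"argument {p.name!r} given twice")
        names.append(p.name)
    # positional entries first, in signature order, then the existing named ones
    new_kwargs = {**dict(zip(names, args)), **kwargs}
    new_ds = {**dict(zip(names, args_ds)), **kwargs_ds}
    return (), new_kwargs, ((), new_ds)


class Model:
    def forward(self, x, y):
        return x + y


DYN = "DYNAMIC"  # stand-in for torch.export.Dim.DYNAMIC
new_args, new_kwargs, new_ds = move_positional_to_named(
    Model().forward,  # bound method: inspect.signature drops self
    ((5, 6),),  # shape tuples stand in for the tensors x and y
    {"y": (1, 6)},
    (({0: DYN, 1: DYN},), {"y": {1: DYN}}),
)
print(new_args, new_kwargs)  # () {'x': (5, 6), 'y': (1, 6)}
print(new_ds)  # ((), {'x': {0: 'DYNAMIC', 1: 'DYNAMIC'}, 'y': {1: 'DYNAMIC'}})
```

In this sketch, new_ds[1] is a dictionary keyed by argument name, the same form given to dynamic_shapes in the export call that follows.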

And we export.

ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s35, s16]", y: "f32[1, s16]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_args_kwargs.py:95 in forward, code: return x + y
            add: "f32[s35, s16]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {s35: VR[2, int_oo], s16: VR[2, int_oo]}
doc.plot_legend("dynamic shapes\n*args, **kwargs", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 5.552 seconds)

Related examples

Export with DynamicCache and dynamic shapes

Untrained microsoft/phi-2

Steel method forward to guess the dynamic shapes (with Tiny-LLM)