to_onnx and infer dynamic shapes

We try to make exporting with dynamic shapes easier. To do that, we run the model at least twice, each time with a different set of inputs, and guess the dynamic shapes from the shapes observed along the way.

Infer dynamic shapes

import onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
)


class MA(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class MM(torch.nn.Module):
    def forward(self, x, y):
        return x * y


class MASMM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.mm = MM()

    def forward(self, x, y, z):
        return self.ma(x, y) - self.mm(y, z)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.masmm = MASMM()

    def forward(self, x):
        return self.ma(x, self.masmm(x, x, x))
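The nested submodules collapse to four elementwise operations, which is why the exported graph printed further down contains exactly one Add, one Mul, one Sub and a final Add. Writing \odot for elementwise multiplication:

```latex
\mathrm{Model}(x)
  = \mathrm{MA}\big(x,\ \mathrm{MASMM}(x, x, x)\big)
  = x + \big(\mathrm{MA}(x, x) - \mathrm{MM}(x, x)\big)
  = x + \big((x + x) - x \odot x\big)
```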

The model.

model = Model()

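Two sets of inputs, each given as a pair (args, kwargs). Only the first dimension differs: 5 in the first set, 6 in the second.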

inputs = [
    ((torch.randn((5, 6)),), {}),
    ((torch.randn((6, 6)),), {}),
]

Then we run the model, store the intermediate inputs and outputs of every submodule, and finally guess the dynamic shapes.
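The guessing step can be illustrated with a small pure-Python sketch. It assumes the principle is "an axis whose size differs between runs is dynamic"; it is not the code of trace_execution_piece_by_piece. The helpers guess_dynamic_dims and guess_for_inputs, and the string DYNAMIC standing in for torch.export.Dim.DYNAMIC, are hypothetical. It also shows why one run is not enough: with a single observation every axis looks static.

```python
DYNAMIC = "DYNAMIC"  # placeholder for torch.export.Dim.DYNAMIC in this sketch


def guess_dynamic_dims(shapes):
    """Return {axis: DYNAMIC} for every axis whose size varies across runs.

    shapes: the shape of one input as observed in each run, e.g. [(5, 6), (6, 6)].
    """
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError("the rank changes between runs, cannot infer dimensions")
    (rank,) = ranks
    return {
        axis: DYNAMIC
        for axis in range(rank)
        if len({s[axis] for s in shapes}) > 1  # size differs in at least one run
    }


def guess_for_inputs(runs):
    """runs: one tuple of input shapes per run; returns one dict per positional input."""
    n_inputs = len(runs[0])
    return tuple(guess_dynamic_dims([run[i] for run in runs]) for i in range(n_inputs))


# The two input sets of this example: x has shape (5, 6), then (6, 6).
print(guess_for_inputs([((5, 6),), ((6, 6),)]))  # ({0: 'DYNAMIC'},)
# A single run gives nothing to compare, so no axis is marked dynamic.
print(guess_for_inputs([((5, 6),)]))  # ({},)
```

Applied to the three inputs seen by masmm, the same rule marks axis 0 of each one, consistent with the DS lines in the trace.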

diag = trace_execution_piece_by_piece(model, inputs, verbose=0)
pretty = diag.pretty_text(with_dynamic_shape=True)
print(pretty)
>>> __main__: Model
  DS=(({0: <_DimHint.DYNAMIC: 3>},), {})
  > ((CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],),{})
  > ((CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],),{})
    >>> ma: MA
      DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
      > ((CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],CT1s5x6[-5.855304718017578,0.9999058246612549:A-0.7025150872766972]),{})
      > ((CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],CT1s6x6[-11.211071014404297,0.9910044074058533:A-1.6185989735854998]),{})
      < (CT1s5x6[-7.473568439483643,2.216357707977295:A-0.68682375450929],)
      < (CT1s6x6[-13.70550537109375,2.2497808933258057:A-1.797695554792881],)
    <<<
    >>> masmm: MASMM
      DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
      > ((CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496]),{})
      > ((CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367]),{})
        >>> ma: MA
          DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
          > ((CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496]),{})
          > ((CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367]),{})
          < (CT1s5x6[-3.236527442932129,3.3668365478515625:A0.03138269620637099],)
          < (CT1s6x6[-4.988868713378906,6.797675609588623:A-0.3581931156416734],)
        <<<
        >>> mm: MM
          DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
          > ((CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496],CT1s5x6[-1.6182637214660645,1.6834182739257812:A0.015691348103185496]),{})
          > ((CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367],CT1s6x6[-2.494434356689453,3.3988378047943115:A-0.1790965578208367]),{})
          < (CT1s5x6[0.0006233706953935325,2.833897113800049:A0.7338977855455596],)
          < (CT1s6x6[0.0059312377125024796,11.552098274230957:A1.2604058814597212],)
        <<<
      < (CT1s5x6[-5.855304718017578,0.9999058246612549:A-0.7025150872766972],)
      < (CT1s6x6[-11.211071014404297,0.9910044074058533:A-1.6185989735854998],)
    <<<
  < (CT1s5x6[-7.473568439483643,2.216357707977295:A-0.68682375450929],)
  < (CT1s6x6[-13.70550537109375,2.2497808933258057:A-1.797695554792881],)
<<<

The dynamic shapes are obtained with:

ds = diag.guess_dynamic_shapes()
print(ds)
(({0: <_DimHint.DYNAMIC: 3>},), {})
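The result is a pair: the dynamic shapes of the positional arguments, then those of the keyword arguments, which is why the export call passes ds[0]. torch.export.export also accepts dynamic_shapes as a dictionary keyed by argument name; whether to_onnx accepts that form is an assumption not checked here. The hypothetical helper to_named_dynamic_shapes below sketches the conversion with inspect.signature, using a plain stand-in class and the string DYNAMIC in place of torch.export.Dim.DYNAMIC (which is what prints as <_DimHint.DYNAMIC: 3>).

```python
import inspect

DYNAMIC = "DYNAMIC"  # placeholder for torch.export.Dim.DYNAMIC in this sketch


def to_named_dynamic_shapes(forward, ds):
    """Hypothetical helper: turn an (args_ds, kwargs_ds) pair into {param_name: dims}."""
    args_ds, kwargs_ds = ds
    positional = [
        p.name
        for p in inspect.signature(forward).parameters.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
    ]
    named = dict(zip(positional, args_ds))  # positional entries follow the signature order
    named.update(kwargs_ds)  # keyword entries are already keyed by name
    return named


class Model:
    # Stand-in with the same forward signature as the torch model above.
    def forward(self, x):
        return x


ds = (({0: DYNAMIC},), {})
# A bound method's signature excludes self, so the first entry maps to "x".
print(to_named_dynamic_shapes(Model().forward, ds))  # {'x': {0: 'DYNAMIC'}}
```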

Export

We use these dynamic shapes to export.

onx, builder = to_onnx(
    model, inputs[0][0], kwargs=inputs[0][1], dynamic_shapes=ds[0], return_builder=True
)
onnx.save(onx, "plot_exporter_recipes_c_ds.onnx")
print(builder.pretty_text())
 __call__ (
    x                             |'<_ParameterKind.POSITIONAL_OR_KEYWORD: 1>'
)
dyn---: DYN0 -> WrapSym(DYN0)
dyn---: s0 -> WrapSym(s0)
dynrev: DYN0 -> [('DYN0', SymInt(DYN0))]
dynsrc: DYN0 -> [{DYN0:('input_name', 'x'), DYN0:('axis', 0)}]
dynals: s0 -> 'DYN0'
t-dynshp: 0 -> {0:(0, _Dim(DYN0))}
opset: : 18
input:: x                                                                       |T1: DYN0 x 6
Add: x, x -> add                                                                |T1: DYN0 x 6                 - add_Tensor
Mul: x, x -> mul                                                                |T1: DYN0 x 6                 - mul_Tensor
Sub: add, mul -> sub                                                            |T1: DYN0 x 6                 - sub_Tensor
Add: x, sub -> output_0                                                         |T1: DYN0 x 6                 - add_Tensor2
output:: output_0                                                               |T1: DYN0 x 6
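Nothing in this graph fixes the first dimension: the input is typed DYN0 x 6 and every node is elementwise. A toy pure-Python evaluator makes this concrete. It interprets the four printed nodes on nested lists; it is not onnx or onnxruntime, and NODES, OPS and run_graph are hypothetical names.

```python
# Node list copied from the pretty_text output: (op_type, inputs, output)
NODES = [
    ("Add", ["x", "x"], "add"),
    ("Mul", ["x", "x"], "mul"),
    ("Sub", ["add", "mul"], "sub"),
    ("Add", ["x", "sub"], "output_0"),
]

# Elementwise semantics of the three operator types used above
OPS = {
    "Add": lambda a, b: a + b,
    "Mul": lambda a, b: a * b,
    "Sub": lambda a, b: a - b,
}


def run_graph(x):
    """Evaluate NODES row by row on a list of rows (toy stand-in for an ONNX runtime)."""
    env = {"x": x}
    for op, (a, b), out in NODES:
        f = OPS[op]
        env[out] = [[f(u, v) for u, v in zip(ra, rb)] for ra, rb in zip(env[a], env[b])]
    return env["output_0"]


# The same graph accepts any number of rows, e.g. 5 and 6 as in the two input sets.
for n_rows in (5, 6):
    x = [[float(i - j) for j in range(6)] for i in range(n_rows)]
    print(n_rows, len(run_graph(x)))
```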

And visually.

[Figure: plot exporter recipes c ds, the graph of the exported ONNX model]

Total running time of the script: (0 minutes 0.152 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

A few tricks about dynamic shapes

to_onnx and submodules from LLMs