Infer dynamic shapes before exporting

Dynamic shapes must be specified to export a model that accepts inputs with varying dimensions. The rank of every input is expected to stay the same, but the size of some dimensions may change. The user can either set the dynamic shapes up manually or call a function that infers them from two sets of inputs with different values for the dynamic dimensions.

Infer dynamic shapes

import onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
)


class MA(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class MM(torch.nn.Module):
    def forward(self, x, y):
        return x * y


class MASMM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.mm = MM()

    def forward(self, x, y, z):
        return self.ma(x, y) - self.mm(y, z)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.masmm = MASMM()

    def forward(self, x):
        return self.ma(x, self.masmm(x, x, x))

The model.

model = Model()

Two sets of inputs.

inputs = [
    ((torch.randn((5, 6)),), {}),
    ((torch.randn((6, 6)),), {}),
]
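The two input sets differ only in their first dimension (5 versus 6). The comparison that turns such a pair into dynamic shapes can be sketched in plain Python. `guess_axes` and `guess_ds` are hypothetical helpers that work on shape tuples rather than tensors. The string `"DYNAMIC"` stands in for the `_DimHint.DYNAMIC` value that appears in the trace. This is an illustration of the idea, not the library's implementation.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]
DYNAMIC = "DYNAMIC"  # placeholder for the dynamic-dimension hint


def guess_axes(a: Shape, b: Shape) -> Dict[int, str]:
    """Marks every axis whose size differs between two observed shapes."""
    if len(a) != len(b):
        raise ValueError(f"rank mismatch: {a} vs {b}")
    return {axis: DYNAMIC for axis, (da, db) in enumerate(zip(a, b)) if da != db}


def guess_ds(
    first: Tuple[List[Shape], Dict[str, Shape]],
    second: Tuple[List[Shape], Dict[str, Shape]],
) -> Tuple[tuple, dict]:
    """Builds (positional, keyword) dynamic shapes from two (args, kwargs) sets."""
    args1, kwargs1 = first
    args2, kwargs2 = second
    if len(args1) != len(args2) or kwargs1.keys() != kwargs2.keys():
        raise ValueError("both input sets must have the same structure")
    positional = tuple(guess_axes(a, b) for a, b in zip(args1, args2))
    keyword = {name: guess_axes(kwargs1[name], kwargs2[name]) for name in kwargs1}
    return positional, keyword


# Shapes of the two input sets used in this example.
print(guess_ds(([(5, 6)], {}), ([(6, 6)], {})))  # (({0: 'DYNAMIC'},), {})
```

The result has the same `(positional, keyword)` layout as the `DS=` lines in the trace. A rank mismatch is rejected because only dimension sizes are allowed to change.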

Then we run the model, store the intermediate inputs and outputs of every submodule, and finally guess the dynamic shapes.

diag = trace_execution_piece_by_piece(model, inputs, verbose=0)
pretty = diag.pretty_text(with_dynamic_shape=True)
print(pretty)
>>> __main__: Model
  DS=(({0: <_DimHint.DYNAMIC: 3>},), {})
  > ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],),{})
  > ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],),{})
    >>> ma: MA
      DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
      > ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-6.960428714752197,0.9987935423851013:A-1.0601868790884812]),{})
      > ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-10.313776016235352,0.9942804574966431:A-0.8953918929522237]),{})
      < (CT1s5x6[-8.781851768493652,2.249394655227661:A-1.0655590564012527],)
      < (CT1s6x6[-12.677371978759766,2.248688220977783:A-0.8722537219938304],)
    <<<
    >>> masmm: MASMM
      DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
      > ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
      > ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
        >>> ma: MA
          DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
          > ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
          > ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
          < (CT1s5x6[-3.6428463459014893,6.616884231567383:A-0.010744354128837586],)
          < (CT1s6x6[-4.72719144821167,3.6642446517944336:A0.04627636964950296],)
        <<<
        >>> mm: MM
          DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
          > ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
          > ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
          < (CT1s5x6[0.0032380588818341494,10.945789337158203:A1.0494425284443423],)
          < (CT1s6x6[0.0002757608890533447,5.586584568023682:A0.9416682639453534],)
        <<<
      < (CT1s5x6[-6.960428714752197,0.9987935423851013:A-1.0601868790884812],)
      < (CT1s6x6[-10.313776016235352,0.9942804574966431:A-0.8953918929522237],)
    <<<
  < (CT1s5x6[-8.781851768493652,2.249394655227661:A-1.0655590564012527],)
  < (CT1s6x6[-12.677371978759766,2.248688220977783:A-0.8722537219938304],)
<<<
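The trace above relies on recording what every submodule receives during each run. That recording step can be sketched with standard PyTorch forward hooks. This is a simplified illustration of the idea, not how `trace_execution_piece_by_piece` is implemented. The modules `Add` and `Outer` and the `record` dictionary are hypothetical.

```python
from collections import defaultdict

import torch


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class Outer(torch.nn.Module):
    # Hypothetical two-level model: the same submodule is called twice.
    def __init__(self):
        super().__init__()
        self.add = Add()

    def forward(self, x):
        return self.add(x, self.add(x, x))


model = Outer()
# Maps a submodule name to the input shapes seen at each of its calls.
record = defaultdict(list)


def make_hook(name):
    def hook(module, args, output):
        # Store only the shapes; a real tracer would keep the tensors too.
        record[name].append(tuple(tuple(a.shape) for a in args))

    return hook


handles = [
    m.register_forward_hook(make_hook(name or "__main__"))
    for name, m in model.named_modules()
]
# Two runs with a different first dimension, as in this example.
for x in (torch.randn(5, 6), torch.randn(6, 6)):
    model(x)
for h in handles:
    h.remove()

print(dict(record))
```

Comparing the shapes recorded for the same submodule across the two runs gives its dynamic dimensions. This is why every submodule in the trace gets its own `DS=` line.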

The dynamic shapes are obtained with:

ds = diag.guess_dynamic_shapes()
print(ds)
(({0: <_DimHint.DYNAMIC: 3>},), {})

Export

We use these dynamic shapes to export. The model only takes positional inputs, so only the positional part, ds[0], is needed.

ep = torch.export.export(model, inputs[0][0], kwargs=inputs[0][1], dynamic_shapes=ds[0])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, 6]"):
             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:28 in forward, code: return x + y
            add: "f32[s0, 6]" = torch.ops.aten.add.Tensor(x, x)

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:33 in forward, code: return x * y
            mul: "f32[s0, 6]" = torch.ops.aten.mul.Tensor(x, x)

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:43 in forward, code: return self.ma(x, y) - self.mm(y, z)
            sub: "f32[s0, 6]" = torch.ops.aten.sub.Tensor(add, mul);  add = mul = None

             # File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:28 in forward, code: return x + y
            add_1: "f32[s0, 6]" = torch.ops.aten.add.Tensor(x, sub);  x = sub = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo]}

We can convert that exported program into an ONNX model.

onx, builder = to_onnx(ep, return_builder=True)
onnx.save(onx, "plot_exporter_exporter_infer_ds.onnx")
print(builder.pretty_text())
dyn---: s0 -> 's0'
opset: : 18
input:: x                                                                       |T1: s0 x 6
Add: x, x -> add                                                                |T1: s0 x 6                   - add_Tensor
Mul: x, x -> mul                                                                |T1: s0 x 6                   - mul_Tensor
Sub: add, mul -> sub                                                            |T1: s0 x 6                   - sub_Tensor
Add: x, sub -> output_0                                                         |T1: s0 x 6                   - add_Tensor2
output:: output_0                                                               |T1: s0 x 6
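Each line of the listing is one ONNX node. Evaluating the four nodes by hand with NumPy, in the same order, gives a reference for checking the output of any runtime. This sketch assumes only NumPy. It mirrors the printed graph and does not call an ONNX runtime.

```python
import numpy as np


def reference(x: np.ndarray) -> np.ndarray:
    """Evaluates the exported graph node by node, following the listing."""
    add = x + x  # Add: x, x -> add
    mul = x * x  # Mul: x, x -> mul
    sub = add - mul  # Sub: add, mul -> sub
    return x + sub  # Add: x, sub -> output_0


# The first dimension is free (s0); the second one is fixed to 6.
x = np.random.default_rng(0).standard_normal((7, 6)).astype(np.float32)
out = reference(x)
print(out.shape)  # (7, 6)
# Algebraically, the whole graph equals 3x - x^2.
print(np.allclose(out, 3 * x - x * x, atol=1e-5))
```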

And visually.

plot_dot(onx)

(Figure: the exported ONNX graph drawn by plot_dot.)

Total running time of the script: (0 minutes 0.439 seconds)

Related examples

Export Phi-3.5-mini-instruct piece by piece

Use DYNAMIC or AUTO when dynamic shapes has constraints

torch.onnx.export: Rename Dynamic Shapes

to_onnx: Rename Dynamic Shapes

to_onnx and submodules from LLMs