Infer dynamic shapes before exporting
Dynamic shapes must be specified for the exported model to handle inputs of different dimensions. The rank of each input is expected to stay the same, but the size of a dimension may change. The user can set the dynamic shapes manually or call a function that infers them from two sets of inputs with different values for the dynamic dimensions.
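The idea of inferring dynamic dimensions from two sets of inputs can be sketched in plain Python. This is only an illustration under simple assumptions, not the library's implementation: `infer_dynamic_dims` is a hypothetical helper, and the string `"DYNAMIC"` stands in for the marker the real function returns.

```python
def infer_dynamic_dims(shape_a, shape_b):
    """Return {axis: "DYNAMIC"} for every axis whose size differs."""
    # The rank is expected to be identical for both sets of inputs.
    if len(shape_a) != len(shape_b):
        raise ValueError(f"rank mismatch: {shape_a} vs {shape_b}")
    return {
        axis: "DYNAMIC"
        for axis, (a, b) in enumerate(zip(shape_a, shape_b))
        if a != b
    }


# Shapes of the two inputs used on this page: only axis 0 changes.
print(infer_dynamic_dims((5, 6), (6, 6)))
```

An axis that has the same size in both sets of inputs is treated as static, which is why the two sets must differ on every dimension meant to be dynamic.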
Infer dynamic shapes
import onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.piece_by_piece import (
    trace_execution_piece_by_piece,
)


class MA(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class MM(torch.nn.Module):
    def forward(self, x, y):
        return x * y


class MASMM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.mm = MM()

    def forward(self, x, y, z):
        return self.ma(x, y) - self.mm(y, z)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ma = MA()
        self.masmm = MASMM()

    def forward(self, x):
        return self.ma(x, self.masmm(x, x, x))
The model.
model = Model()
Two sets of inputs.
inputs = [
    ((torch.randn((5, 6)),), {}),
    ((torch.randn((6, 6)),), {}),
]
Then we run the model on both sets of inputs, storing the intermediate inputs and outputs of every submodule, and finally guess the dynamic shapes.
diag = trace_execution_piece_by_piece(model, inputs, verbose=0)
pretty = diag.pretty_text(with_dynamic_shape=True)
print(pretty)
>>> __main__: Model
DS=(({0: <_DimHint.DYNAMIC: 3>},), {})
> ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],),{})
> ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],),{})
>>> ma: MA
DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
> ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-6.960428714752197,0.9987935423851013:A-1.0601868790884812]),{})
> ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-10.313776016235352,0.9942804574966431:A-0.8953918929522237]),{})
< (CT1s5x6[-8.781851768493652,2.249394655227661:A-1.0655590564012527],)
< (CT1s6x6[-12.677371978759766,2.248688220977783:A-0.8722537219938304],)
<<<
>>> masmm: MASMM
DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
> ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
> ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
>>> ma: MA
DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
> ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
> ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
< (CT1s5x6[-3.6428463459014893,6.616884231567383:A-0.010744354128837586],)
< (CT1s6x6[-4.72719144821167,3.6642446517944336:A0.04627636964950296],)
<<<
>>> mm: MM
DS=(({0: <_DimHint.DYNAMIC: 3>}, {0: <_DimHint.DYNAMIC: 3>}), {})
> ((CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793],CT1s5x6[-1.8214231729507446,3.3084421157836914:A-0.005372177064418793]),{})
> ((CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148],CT1s6x6[-2.363595724105835,1.8321223258972168:A0.02313818482475148]),{})
< (CT1s5x6[0.0032380588818341494,10.945789337158203:A1.0494425284443423],)
< (CT1s6x6[0.0002757608890533447,5.586584568023682:A0.9416682639453534],)
<<<
< (CT1s5x6[-6.960428714752197,0.9987935423851013:A-1.0601868790884812],)
< (CT1s6x6[-10.313776016235352,0.9942804574966431:A-0.8953918929522237],)
<<<
< (CT1s5x6[-8.781851768493652,2.249394655227661:A-1.0655590564012527],)
< (CT1s6x6[-12.677371978759766,2.248688220977783:A-0.8722537219938304],)
<<<
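The trace above records the inputs seen by every submodule. A rough way to picture that step is with standard forward hooks that collect tensor shapes. This sketch is an assumption about the general idea, not how `trace_execution_piece_by_piece` is implemented; `record_shapes`, `Add`, and `Outer` are hypothetical names.

```python
import torch


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.add = Add()

    def forward(self, x):
        return self.add(x, x)


def record_shapes(model, list_of_args):
    """Run the model on each set of inputs and record, per submodule,
    the shapes of the positional tensor arguments it received."""
    records = {}

    def make_hook(name):
        def hook(module, args, output):
            shapes = [tuple(a.shape) for a in args if isinstance(a, torch.Tensor)]
            records.setdefault(name, []).append(shapes)

        return hook

    # named_modules() yields the root module under the empty name.
    handles = [
        m.register_forward_hook(make_hook(name or "<root>"))
        for name, m in model.named_modules()
    ]
    try:
        for args in list_of_args:
            model(*args)
    finally:
        for h in handles:
            h.remove()
    return records


records = record_shapes(Outer(), [(torch.randn(5, 6),), (torch.randn(6, 6),)])
print(records)
```

Comparing the shapes recorded for the first and second run of each submodule then shows which axes change.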
The dynamic shapes are obtained with:
ds = diag.guess_dynamic_shapes()
print(ds)
(({0: <_DimHint.DYNAMIC: 3>},), {})
Export
We use these dynamic shapes to export.
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, 6]"):
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:28 in forward, code: return x + y
add: "f32[s0, 6]" = torch.ops.aten.add.Tensor(x, x)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:33 in forward, code: return x * y
mul: "f32[s0, 6]" = torch.ops.aten.mul.Tensor(x, x)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:43 in forward, code: return self.ma(x, y) - self.mm(y, z)
sub: "f32[s0, 6]" = torch.ops.aten.sub.Tensor(add, mul); add = mul = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_infer_ds.py:28 in forward, code: return x + y
add_1: "f32[s0, 6]" = torch.ops.aten.add.Tensor(x, sub); x = sub = None
return (add_1,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {s0: VR[2, int_oo]}
We can then convert that exported graph into an ONNX model.
dyn---: s0 -> 's0'
opset: : 18
input:: x |T1: s0 x 6
Add: x, x -> add |T1: s0 x 6 - add_Tensor
Mul: x, x -> mul |T1: s0 x 6 - mul_Tensor
Sub: add, mul -> sub |T1: s0 x 6 - sub_Tensor
Add: x, sub -> output_0 |T1: s0 x 6 - add_Tensor2
output:: output_0 |T1: s0 x 6
And visually.
