-m onnx_diagnostic compare … compares two models
Description
The command line compares two models, assuming they represent the same model and that most parts of both are identical. Typically, the two models were exported with different options or went through different optimizations. The command highlights the differences.
usage: compare [-h] model1 model2
Compares two onnx models by aligning the nodes between both models. This is done through an edit distance.
positional arguments:
model1 first model to compare
model2 second model to compare
options:
-h, --help show this help message and exit
Each element of the model (initializer, input, node, output) is converted into an observation. The command then defines a distance between two observations and finally finds the best alignment between the two sequences of observations with an edit distance.
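The alignment step can be illustrated with a minimal sketch. It is not the implementation used by onnx_diagnostic: the `Obs` class, the `obs_distance` costs and the `GAP` penalty are all hypothetical and chosen only to keep the example small. It shows how a dynamic-programming edit distance pairs the elements of two sequences and leaves some of them unmatched.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Obs:
    """Hypothetical observation describing one element of a model."""

    kind: str  # "INITIA", "INPUT", "NODE" or "OUTPUT"
    op: str  # operator type, empty for anything that is not a node
    name: str  # name of the tensor or result


GAP = 3.0  # assumed cost of leaving an element unmatched


def obs_distance(a: Obs, b: Obs) -> float:
    """Illustrative cost of pairing a with b, 0 means identical."""
    if a.kind != b.kind:
        return 10.0
    return (2.0 if a.op != b.op else 0.0) + (1.0 if a.name != b.name else 0.0)


def align(seq1, seq2):
    """Returns (total_cost, pairs); None marks an unmatched side."""
    n, m = len(seq1), len(seq2)
    # dist[i][j] is the cheapest alignment of seq1[:i] with seq2[:j].
    dist = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dist[i][0] = i * GAP
    for j in range(1, m + 1):
        dist[0][j] = j * GAP
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dist[i][j] = min(
                dist[i - 1][j - 1] + obs_distance(seq1[i - 1], seq2[j - 1]),
                dist[i - 1][j] + GAP,  # seq1[i - 1] stays unmatched
                dist[i][j - 1] + GAP,  # seq2[j - 1] stays unmatched
            )
    # Walks back from the end to recover which elements were paired.
    pairs = []
    i, j = n, m
    while i > 0 or j > 0:
        if (
            i > 0
            and j > 0
            and dist[i][j]
            == dist[i - 1][j - 1] + obs_distance(seq1[i - 1], seq2[j - 1])
        ):
            pairs.append((seq1[i - 1], seq2[j - 1]))
            i, j = i - 1, j - 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + GAP:
            pairs.append((seq1[i - 1], None))
            i -= 1
        else:
            pairs.append((None, seq2[j - 1]))
            j -= 1
    pairs.reverse()
    return dist[n][m], pairs


seq1 = [
    Obs("INPUT", "", "x"),
    Obs("NODE", "Conv", "conv2d"),
    Obs("NODE", "Relu", "relu"),
    Obs("NODE", "MaxPool", "max_pool2d"),
]
seq2 = [
    Obs("INPUT", "", "x"),
    Obs("NODE", "FusedConv", "relu"),
    Obs("NODE", "MaxPool", "max_pool2d"),
]
total, pairs = align(seq1, seq2)
for left, right in pairs:
    print(left.name if left else "-", "|", right.name if right else "-")
```

The real command uses its own observations and costs, so the pairs it reports for a fused node may differ from the ones this sketch produces.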
Example
python -m onnx_diagnostic compare <model1.onnx> <model2.onnx>
This example is written in Python, but it produces the same output as the command line.
<<<
import torch
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 5)
        self.fc1 = torch.nn.Linear(144, 64)
        self.fc2 = torch.nn.Linear(64, 128)
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.nn.functional.max_pool2d(
            torch.nn.functional.relu(self.conv1(x)), (4, 4)
        )
        # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        y = self.fc3(x)
        return y


model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
dynamic_shapes = ({0: "batch", 3: "dim"},)

# Exports the same model twice, with and without optimization.
onnx_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).model_proto
onnx_not_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).model_proto

# Converts both models into sequences of observations and aligns them.
seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
print(text)
>>>
0.0000e+00 | 0000 INITIA INT64 2 init7_s2_0_-1 | 0000 INITIA INT64 2 init7_s2_0_-1
0.0000e+00 | 0001 INITIA FLOAT 64x144 GemmTransposePattern--p_fc1_weight: | 0001 INITIA FLOAT 64x144 GemmTransposePattern--p_fc1_weight:
0.0000e+00 | 0002 INITIA FLOAT 128x64 GemmTransposePattern--p_fc2_weight: | 0002 INITIA FLOAT 128x64 GemmTransposePattern--p_fc2_weight:
0.0000e+00 | 0003 INITIA FLOAT 10x128 GemmTransposePattern--p_fc3_weight: | 0003 INITIA FLOAT 10x128 GemmTransposePattern--p_fc3_weight:
0.0000e+00 | 0004 INITIA FLOAT 16x3x5x5 conv1.weight | 0004 INITIA FLOAT 16x3x5x5 conv1.weight
0.0000e+00 | 0005 INITIA FLOAT 16 conv1.bias | 0005 INITIA FLOAT 16 conv1.bias
0.0000e+00 | 0006 INITIA FLOAT 64 fc1.bias | 0006 INITIA FLOAT 64 fc1.bias
0.0000e+00 | 0007 INITIA FLOAT 128 fc2.bias | 0007 INITIA FLOAT 128 fc2.bias
0.0000e+00 | 0008 INITIA FLOAT 10 fc3.bias | 0008 INITIA FLOAT 10 fc3.bias
0.0000e+00 | 0009 INPUT FLOAT batchx3x16xdim x | 0009 INPUT FLOAT batchx3x16xdim x
1.9700e+02 | 0010 NODE FLOAT batchx16x12xconv_f Conv conv2d | 0010 NODE FLOAT batchx16x12xconv_f FusedConv relu
1.1980e+03 | 0011 NODE FLOAT batchx16x12xconv_f Relu relu |
1.1980e+03 | 0012 NODE FLOAT batchx16x3xconv_f3 MaxPool max_pool2d | 0011 NODE FLOAT batchx16x3xconv_f3 MaxPool max_pool2d
1.1980e+03 | 0013 NODE FLOAT batchx48*dim//4-48 Reshape flatten | 0012 NODE FLOAT batchx48*dim//4-48 Reshape flatten
1.1980e+03 | 0014 NODE FLOAT batchx64 Gemm linear | 0013 NODE FLOAT batchx64 Gemm linear
1.1980e+03 | 0015 NODE FLOAT batchx64 Relu relu_1 | 0014 NODE FLOAT batchx64 Relu relu_1
1.1980e+03 | 0016 NODE FLOAT batchx128 Gemm linear_1 | 0015 NODE FLOAT batchx128 Gemm linear_1
1.1980e+03 | 0017 NODE FLOAT batchx128 Relu relu_2 | 0016 NODE FLOAT batchx128 Relu relu_2
1.1980e+03 | 0018 NODE FLOAT batchx10 Gemm output_0 | 0017 NODE FLOAT batchx10 Gemm output_0
1.1980e+03 | 0019 OUTPUT FLOAT batchx10 output_0 | 0018 OUTPUT FLOAT batchx10 output_0