-m onnx_diagnostic compare … compares two models

Description

The command line compares two models, assuming both represent the same model and that most of their parts are identical, for example when they were exported with different options or optimized differently. It highlights the differences between them.

    usage: compare [-h] model1 model2
    
    Compares two onnx models by aligning the nodes between both models. This is done through an edit distance.
    
    positional arguments:
      model1      first model to compare
      model2      second model to compare
    
    options:
      -h, --help  show this help message and exit
    
    Each element (initializer, input, node, output) of the model is converted into an observation. Then it defines a distance between two elements. And finally, it finds the best
    alignment with an edit distance.
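
The three steps described in the help text (turn each element into an observation, define a distance between two observations, find the best alignment with an edit distance) can be illustrated with a minimal, self-contained sketch. It does not use onnx_diagnostic: the `Obs` class, the `cost` function, the `GAP` penalty and the toy sequences are all assumptions made for illustration, and the library's actual observations and distance may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Obs:
    """Hypothetical observation describing one element of a model."""

    kind: str  # "INITIA", "INPUT", "NODE" or "OUTPUT"
    op_type: str  # operator type for a node, empty otherwise
    name: str  # name of the result or the initializer


GAP = 3.0  # assumed cost of leaving an observation unmatched


def cost(a: Obs, b: Obs) -> float:
    """Illustrative distance between two observations (not the library's)."""
    if a.kind != b.kind:
        return 10.0
    return 2.0 * (a.op_type != b.op_type) + 1.0 * (a.name != b.name)


def align(seq1, seq2):
    """Returns the edit distance and the aligned pairs of indices."""
    n, m = len(seq1), len(seq2)
    # dist[i][j] is the cheapest alignment of seq1[:i] with seq2[:j]
    dist = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dist[i][0] = i * GAP
    for j in range(1, m + 1):
        dist[0][j] = j * GAP
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dist[i][j] = min(
                dist[i - 1][j - 1] + cost(seq1[i - 1], seq2[j - 1]),  # match
                dist[i - 1][j] + GAP,  # seq1[i - 1] has no counterpart
                dist[i][j - 1] + GAP,  # seq2[j - 1] has no counterpart
            )
    # Walk back from the bottom-right corner to recover the alignment.
    pairs = []
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            match = dist[i - 1][j - 1] + cost(seq1[i - 1], seq2[j - 1])
        else:
            match = None
        if match is not None and dist[i][j] == match:
            pairs.append((i - 1, j - 1))
            i, j = i - 1, j - 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + GAP:
            pairs.append((i - 1, None))
            i -= 1
        else:
            pairs.append((None, j - 1))
            j -= 1
    pairs.reverse()
    return dist[n][m], pairs


# Toy data: the second sequence fuses the first two nodes into one.
seq1 = [
    Obs("NODE", "Conv", "conv"),
    Obs("NODE", "Relu", "relu"),
    Obs("NODE", "MaxPool", "pool"),
]
seq2 = [Obs("NODE", "FusedConv", "conv"), Obs("NODE", "MaxPool", "pool")]
total, pairs = align(seq1, seq2)
print(total, pairs)
```

With these assumed costs, `Conv` is aligned with `FusedConv`, `MaxPool` with `MaxPool`, and `Relu` is left without a counterpart. Such an unmatched element corresponds to a row whose other side is empty in the comparison report.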

Example

python -m onnx_diagnostic compare <model1.onnx> <model2.onnx>

The following example uses the Python API, but it produces the same output as the command line.

<<<

import torch
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 5)
        self.fc1 = torch.nn.Linear(144, 64)
        self.fc2 = torch.nn.Linear(64, 128)
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.nn.functional.max_pool2d(
            torch.nn.functional.relu(self.conv1(x)), (4, 4)
        )
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        y = self.fc3(x)
        return y


model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
dynamic_shapes = ({0: "batch", 3: "dim"},)
onnx_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).model_proto
onnx_not_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).model_proto
seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
print(text)

>>>

    0.0000e+00 | 0000 INITIA INT64    2                                  init7_s2_0_-1                       | 0000 INITIA INT64    2                                  init7_s2_0_-1                      
    0.0000e+00 | 0001 INITIA FLOAT    64x144                             GemmTransposePattern--p_fc1_weight: | 0001 INITIA FLOAT    64x144                             GemmTransposePattern--p_fc1_weight:
    0.0000e+00 | 0002 INITIA FLOAT    128x64                             GemmTransposePattern--p_fc2_weight: | 0002 INITIA FLOAT    128x64                             GemmTransposePattern--p_fc2_weight:
    0.0000e+00 | 0003 INITIA FLOAT    10x128                             GemmTransposePattern--p_fc3_weight: | 0003 INITIA FLOAT    10x128                             GemmTransposePattern--p_fc3_weight:
    0.0000e+00 | 0004 INITIA FLOAT    16x3x5x5                           conv1.weight                        | 0004 INITIA FLOAT    16x3x5x5                           conv1.weight                       
    0.0000e+00 | 0005 INITIA FLOAT    16                                 conv1.bias                          | 0005 INITIA FLOAT    16                                 conv1.bias                         
    0.0000e+00 | 0006 INITIA FLOAT    64                                 fc1.bias                            | 0006 INITIA FLOAT    64                                 fc1.bias                           
    0.0000e+00 | 0007 INITIA FLOAT    128                                fc2.bias                            | 0007 INITIA FLOAT    128                                fc2.bias                           
    0.0000e+00 | 0008 INITIA FLOAT    10                                 fc3.bias                            | 0008 INITIA FLOAT    10                                 fc3.bias                           
    0.0000e+00 | 0009 INPUT  FLOAT    batchx3x16xdim                     x                                   | 0009 INPUT  FLOAT    batchx3x16xdim                     x                                  
    1.9700e+02 | 0010 NODE   FLOAT    batchx16x12xconv_f Conv            conv2d                              | 0010 NODE   FLOAT    batchx16x12xconv_f FusedConv       relu                               
    1.1980e+03 | 0011 NODE   FLOAT    batchx16x12xconv_f Relu            relu                                |                                                                                            
    1.1980e+03 | 0012 NODE   FLOAT    batchx16x3xconv_f3 MaxPool         max_pool2d                          | 0011 NODE   FLOAT    batchx16x3xconv_f3 MaxPool         max_pool2d                         
    1.1980e+03 | 0013 NODE   FLOAT    batchx48*dim//4-48 Reshape         flatten                             | 0012 NODE   FLOAT    batchx48*dim//4-48 Reshape         flatten                            
    1.1980e+03 | 0014 NODE   FLOAT    batchx64           Gemm            linear                              | 0013 NODE   FLOAT    batchx64           Gemm            linear                             
    1.1980e+03 | 0015 NODE   FLOAT    batchx64           Relu            relu_1                              | 0014 NODE   FLOAT    batchx64           Relu            relu_1                             
    1.1980e+03 | 0016 NODE   FLOAT    batchx128          Gemm            linear_1                            | 0015 NODE   FLOAT    batchx128          Gemm            linear_1                           
    1.1980e+03 | 0017 NODE   FLOAT    batchx128          Relu            relu_2                              | 0016 NODE   FLOAT    batchx128          Relu            relu_2                             
    1.1980e+03 | 0018 NODE   FLOAT    batchx10           Gemm            output_0                            | 0017 NODE   FLOAT    batchx10           Gemm            output_0                           
    1.1980e+03 | 0019 OUTPUT FLOAT    batchx10                           output_0                            | 0018 OUTPUT FLOAT    batchx10                           output_0