onnx_diagnostic.torch_onnx.compare

class onnx_diagnostic.torch_onnx.compare.ObsCompare(position: int, kind: ObsType, name_or_outputs: Tuple[str], itype: int = 0, index: int = 0, shape: Tuple[Tuple[int | str, ...]] | None = None, op_type: str = '', comment: str = '')

Describes one observation in a model: a node, an input, an output, or an initializer.

Parameters:
  • position – index of this observation in the original model

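  • kind – kind of observation (initializer, input, node, or output), see ObsType

  • name_or_outputs – name of an initializer, input, or output, or the output names of a node

  • itype – ONNX element type, as an integer (for example onnx.TensorProto.FLOAT)

  • index – index of an input or output

  • shape – shapes of the observed results; each dimension is an integer or a symbolic name

  • op_type – operator type of a node (for example Conv), empty for other kinds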

  • comment – comment, unused

distance(obs: ObsCompare) → float

Computes the cost of matching this observation with obs. Identical observations have a cost of 0.
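
The weights used by ObsCompare.distance are not documented here. As a rough illustration of what such a cost can look like, here is a minimal sketch, assuming a hypothetical Obs dataclass (a stand-in for ObsCompare, not the library class) and arbitrary penalties for a mismatched kind, op_type, itype, or shape. The real method may weigh these fields quite differently.

```python
from dataclasses import dataclass
from typing import Tuple, Union

Dim = Union[int, str]  # an integer or a symbolic dimension such as "batch"


@dataclass(frozen=True)
class Obs:
    """Hypothetical stand-in for ObsCompare, keeping only the fields the cost uses."""

    kind: str  # "INITIALIZER", "INPUT", "NODE" or "OUTPUT" (assumed labels)
    op_type: str = ""
    itype: int = 0  # ONNX element type, e.g. 1 for FLOAT
    shape: Tuple[Tuple[Dim, ...], ...] = ()  # one shape per result


def obs_distance(a: Obs, b: Obs) -> float:
    """Illustrative matching cost: 0 for identical observations.

    The weights are arbitrary assumptions, not onnx_diagnostic's values.
    """
    cost = 0.0
    if a.kind != b.kind:
        cost += 1000.0  # matching an input with a node should almost never win
    if a.op_type != b.op_type:
        cost += 100.0
    if a.itype != b.itype:
        cost += 10.0
    if a.shape != b.shape:
        cost += 1.0
    return cost


conv = Obs("NODE", "Conv", 1, (("batch", 16, 12, "w"),))
fused = Obs("NODE", "FusedConv", 1, (("batch", 16, 12, "w"),))
print(obs_distance(conv, conv))   # 0.0
print(obs_distance(conv, fused))  # 100.0
```

Ordering the penalties by magnitude makes a kind mismatch dominate everything else, so an alignment prefers leaving an observation unmatched over pairing, say, an input with a node.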

classmethod obs_sequence_from_model(model: ModelProto | GraphProto) → List[ObsCompare]

Creates the sequence of observations (initializers, inputs, nodes, outputs) of a model or a graph.

Parameters:

model – ONNX model or graph

Returns:

sequence of observations
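
In the example further down, the sequence lists initializers first, then inputs, nodes, and outputs, each numbered by its position. The sketch below reproduces only that ordering on a plain-dict stand-in for a graph. The dict layout and the obs_sequence function are hypothetical; the real method reads a ModelProto or GraphProto and builds ObsCompare objects.

```python
from typing import Dict, List, Tuple


def obs_sequence(graph: Dict[str, list]) -> List[Tuple[int, str, str]]:
    """Flatten a toy graph into (position, kind, name) tuples.

    `graph` is a hypothetical plain-dict stand-in for an ONNX graph:
    lists of names under "initializers", "inputs", "outputs", and
    (op_type, output_name) tuples under "nodes".
    """
    seq: List[Tuple[int, str, str]] = []
    for name in graph.get("initializers", []):
        seq.append((len(seq), "INITIALIZER", name))
    for name in graph.get("inputs", []):
        seq.append((len(seq), "INPUT", name))
    for op_type, output in graph.get("nodes", []):
        seq.append((len(seq), "NODE", f"{op_type}:{output}"))
    for name in graph.get("outputs", []):
        seq.append((len(seq), "OUTPUT", name))
    return seq


toy = {
    "initializers": ["fc3.weight", "fc3.bias"],
    "inputs": ["x"],
    "nodes": [("Gemm", "output_0")],
    "outputs": ["output_0"],
}
for obs in obs_sequence(toy):
    print(obs)
```

Because nodes in an ONNX graph are already stored in topological order, a flat list like this keeps enough structure for a sequence alignment to pair up corresponding steps.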

class onnx_diagnostic.torch_onnx.compare.ObsComparePair(side1: ObsCompare | None, side2: ObsCompare | None, distance: float)

Pairs an observation from the first sequence with one from the second. Either side may be None when an observation has no counterpart.

Parameters:
  • side1 – observation from the first sequence, or None

  • side2 – observation from the second sequence, or None

  • distance – distance associated with this pair

classmethod distance_sequence(s1: List[ObsCompare], s2: List[ObsCompare]) → Tuple[float, List[Tuple[int, int]], List[ObsComparePair]]

Aligns two sequences of observations and computes the distance between them.

Parameters:
  • s1 – first sequence

  • s2 – second sequence

Returns:

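a tuple (distance, path, pairs): the total distance, the alignment path as a list of (int, int) tuples, and the aligned ObsComparePair objects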

An example:

<<<

import torch
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 5)
        self.fc1 = torch.nn.Linear(144, 64)
        self.fc2 = torch.nn.Linear(64, 128)
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.nn.functional.max_pool2d(
            torch.nn.functional.relu(self.conv1(x)),
            (4, 4),
        )
        # flatten all but the batch dimension: 16 x 3 x 3 = 144 features for a 16x17 input
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        y = self.fc3(x)
        return y


model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
dynamic_shapes = ({0: "batch", 3: "dim"},)
# Export the same model twice, with and without graph optimizations.
onnx_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).model_proto
onnx_not_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).model_proto
# Turn each model into a sequence of observations, align them, and print the pairs.
seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
print(text)

>>>

    0.0000e+00 | 0000 INITIA INT64    2                                  init7_s2_0_-1                       | 0000 INITIA INT64    2                                  init7_s2_0_-1                      
    0.0000e+00 | 0001 INITIA FLOAT    64x144                             GemmTransposePattern--p_fc1_weight: | 0001 INITIA FLOAT    64x144                             GemmTransposePattern--p_fc1_weight:
    0.0000e+00 | 0002 INITIA FLOAT    128x64                             GemmTransposePattern--p_fc2_weight: | 0002 INITIA FLOAT    128x64                             GemmTransposePattern--p_fc2_weight:
    0.0000e+00 | 0003 INITIA FLOAT    10x128                             GemmTransposePattern--p_fc3_weight: | 0003 INITIA FLOAT    10x128                             GemmTransposePattern--p_fc3_weight:
    0.0000e+00 | 0004 INITIA FLOAT    16x3x5x5                           conv1.weight                        | 0004 INITIA FLOAT    16x3x5x5                           conv1.weight                       
    0.0000e+00 | 0005 INITIA FLOAT    16                                 conv1.bias                          | 0005 INITIA FLOAT    16                                 conv1.bias                         
    0.0000e+00 | 0006 INITIA FLOAT    64                                 fc1.bias                            | 0006 INITIA FLOAT    64                                 fc1.bias                           
    0.0000e+00 | 0007 INITIA FLOAT    128                                fc2.bias                            | 0007 INITIA FLOAT    128                                fc2.bias                           
    0.0000e+00 | 0008 INITIA FLOAT    10                                 fc3.bias                            | 0008 INITIA FLOAT    10                                 fc3.bias                           
    0.0000e+00 | 0009 INPUT  FLOAT    batchx3x16xdim                     x                                   | 0009 INPUT  FLOAT    batchx3x16xdim                     x                                  
    1.9700e+02 | 0010 NODE   FLOAT    batchx16x12xconv_f Conv            conv2d                              | 0010 NODE   FLOAT    batchx16x12xconv_f FusedConv       relu                               
    1.1980e+03 | 0011 NODE   FLOAT    batchx16x12xconv_f Relu            relu                                |                                                                                            
    1.1980e+03 | 0012 NODE   FLOAT    batchx16x3xconv_f3 MaxPool         max_pool2d                          | 0011 NODE   FLOAT    batchx16x3xconv_f3 MaxPool         max_pool2d                         
    1.1980e+03 | 0013 NODE   FLOAT    batchx48*dim//4-48 Reshape         flatten                             | 0012 NODE   FLOAT    batchx48*dim//4-48 Reshape         flatten                            
    1.1980e+03 | 0014 NODE   FLOAT    batchx64           Gemm            linear                              | 0013 NODE   FLOAT    batchx64           Gemm            linear                             
    1.1980e+03 | 0015 NODE   FLOAT    batchx64           Relu            relu_1                              | 0014 NODE   FLOAT    batchx64           Relu            relu_1                             
    1.1980e+03 | 0016 NODE   FLOAT    batchx128          Gemm            linear_1                            | 0015 NODE   FLOAT    batchx128          Gemm            linear_1                           
    1.1980e+03 | 0017 NODE   FLOAT    batchx128          Relu            relu_2                              | 0016 NODE   FLOAT    batchx128          Relu            relu_2                             
    1.1980e+03 | 0018 NODE   FLOAT    batchx10           Gemm            output_0                            | 0017 NODE   FLOAT    batchx10           Gemm            output_0                           
    1.1980e+03 | 0019 OUTPUT FLOAT    batchx10                           output_0                            | 0018 OUTPUT FLOAT    batchx10                           output_0
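
distance_sequence returns both a total distance and an alignment. In the output above, the Relu of the non-optimized model has no counterpart in the optimized one and faces an empty column, and the first column appears to accumulate the distance along the alignment. The textbook technique for this kind of alignment is an edit-distance dynamic program. What follows is a minimal sketch of that generic technique, not a transcription of onnx_diagnostic's implementation; the gap cost of 1000, the toy op_cost function, and the names align and op_cost are assumptions chosen for illustration.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")
Pair = Tuple[Optional[T], Optional[T]]


def align(
    s1: Sequence[T],
    s2: Sequence[T],
    cost: Callable[[T, T], float],
    gap: float = 1000.0,  # arbitrary cost of leaving an element unmatched
) -> Tuple[float, List[Pair]]:
    """Edit-distance alignment: returns the total cost and the aligned pairs.

    None on either side of a pair marks an element with no counterpart.
    """
    n, m = len(s1), len(s2)
    # d[i][j] is the cheapest alignment of s1[:i] with s2[:j].
    d = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        d[i][0] = d[i - 1][0] + gap
    for j in range(1, m + 1):
        d[0][j] = d[0][j - 1] + gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d[i][j] = min(
                d[i - 1][j - 1] + cost(s1[i - 1], s2[j - 1]),  # pair them
                d[i - 1][j] + gap,  # s1[i - 1] stays unmatched
                d[i][j - 1] + gap,  # s2[j - 1] stays unmatched
            )
    # Walk back from the bottom-right corner to recover the pairs.
    pairs: List[Pair] = []
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] + cost(s1[i - 1], s2[j - 1]):
            pairs.append((s1[i - 1], s2[j - 1]))
            i, j = i - 1, j - 1
        elif i > 0 and d[i][j] == d[i - 1][j] + gap:
            pairs.append((s1[i - 1], None))
            i -= 1
        else:
            pairs.append((None, s2[j - 1]))
            j -= 1
    pairs.reverse()
    return d[n][m], pairs


def op_cost(a: str, b: str) -> float:
    """Toy cost on op types: equal is free, a partial name match is cheap."""
    if a == b:
        return 0.0
    return 10.0 if a in b or b in a else 100.0


total, pairs = align(
    ["Conv", "Relu", "MaxPool", "Gemm"], ["FusedConv", "MaxPool", "Gemm"], op_cost
)
print(total)  # 1010.0
print(pairs)  # [('Conv', 'FusedConv'), ('Relu', None), ('MaxPool', 'MaxPool'), ('Gemm', 'Gemm')]
```

Unlike distance_sequence, this sketch returns only the pairs; the (i, j) index path could be collected in the same backtracking loop.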

classmethod to_str(seq: List[ObsComparePair]) → str

Renders every pair as one line of text and returns the result as a string.
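
The printed output above puts the distance first in scientific notation, then the two sides separated by a vertical bar, with blank padding where a side is missing. A minimal sketch of that layout follows. The pairs_to_str function and its (left, right, distance) rows are hypothetical; the real to_str takes ObsComparePair objects and prints more columns.

```python
from typing import List, Optional, Tuple

Row = Tuple[Optional[str], Optional[str], float]


def pairs_to_str(rows: List[Row], width: int = 20) -> str:
    """Render (left, right, distance) rows side by side.

    A side set to None (a gap in the alignment) is printed as blank padding.
    """
    lines = []
    for left, right, dist in rows:
        cell = (left or "").ljust(width)
        lines.append(f"{dist:.4e} | {cell} | {right or ''}".rstrip())
    return "\n".join(lines)


text = pairs_to_str([
    ("0010 Conv", "0010 FusedConv", 197.0),
    ("0011 Relu", None, 1198.0),
])
print(text)
```

Padding the left column to a fixed width keeps the right-hand side aligned, which is what makes a missing counterpart easy to spot.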

class onnx_diagnostic.torch_onnx.compare.ObsType(*values)

Enumeration of observation kinds: initializer, input, node, or output.
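
The kind column in the distance_sequence example shows six-character labels such as INITIA. The sketch below is a hypothetical stand-in for ObsType (the member names and values are assumptions, not the library's) that shows how an integer enumeration yields both such labels and the initializer, input, node, output ordering.

```python
import enum


class Kind(enum.IntEnum):
    """Hypothetical stand-in for ObsType; member names and values are assumptions."""

    INITIALIZER = 1
    INPUT = 2
    NODE = 3
    OUTPUT = 4


# Six-character labels, similar to the kind column printed by to_str.
labels = [k.name[:6] for k in Kind]
print(labels)  # ['INITIA', 'INPUT', 'NODE', 'OUTPUT']
```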