onnx_diagnostic.torch_onnx.compare
- class onnx_diagnostic.torch_onnx.compare.ObsCompare(position: int, kind: ObsType, name_or_outputs: Tuple[str], itype: int = 0, index: int = 0, shape: Tuple[Tuple[int | str, ...]] | None = None, op_type: str = '', comment: str = '')
Describes one observation: a node, an input, an output or an initializer.
- Parameters:
position – index of this observation in the original model
kind – type of observation, see ObsType
name_or_outputs – name of an initializer or the outputs of a node
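itype – ONNX element type
index – index of an input or output
shape – shape of the result; dimensions are integers or symbolic names
op_type – operator type of a node (for example Conv or Gemm)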
comment – comment, unused
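To make these fields concrete, here is a minimal stand-in written as a plain dataclass. `Obs`, `Kind` and `shape_str` are hypothetical names, not part of onnx_diagnostic. The `Kind` members only echo the labels printed in the example output at the end of this section (INITIA, INPUT, NODE, OUTPUT) and may not match the real ObsType, and storing one shape per name is an assumption. The two instances reproduce rows 0009 and 0014 of that output.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Union

Shape = Tuple[Union[int, str], ...]


class Kind(Enum):
    # Hypothetical stand-in for ObsType; the real enum's member names may differ.
    INITIALIZER = "INITIA"
    INPUT = "INPUT"
    NODE = "NODE"
    OUTPUT = "OUTPUT"


@dataclass(frozen=True)
class Obs:
    # Mirrors the documented ObsCompare parameters, for illustration only.
    position: int
    kind: Kind
    name_or_outputs: Tuple[str, ...]
    itype: int = 0  # ONNX element type, e.g. 1 for FLOAT, 7 for INT64
    index: int = 0
    shape: Optional[Tuple[Shape, ...]] = None  # assumption: one shape per name
    op_type: str = ""
    comment: str = ""


def shape_str(shape: Shape) -> str:
    # Renders a shape the way the example output does: dimensions joined by "x".
    return "x".join(str(d) for d in shape)


# Row "0009 INPUT FLOAT batchx3x16xdim x" of the example output.
x = Obs(9, Kind.INPUT, ("x",), itype=1, shape=(("batch", 3, 16, "dim"),))
# Row "0014 NODE FLOAT batchx64 Gemm linear" of the example output.
gemm = Obs(14, Kind.NODE, ("linear",), itype=1, shape=(("batch", 64),), op_type="Gemm")

print(x.kind.value, shape_str(x.shape[0]), x.name_or_outputs[0])
```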
- distance(obs: ObsCompare) → float
Computes a cost between this observation and obs.
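The exact cost formula is not documented here. As a sketch under stated assumptions, a cost can compare the fields an observation carries (kind, element type, operator type, shape) and add a penalty for each mismatch. `Obs` and `obs_cost` are hypothetical, the dimension name "w" is a placeholder, and the weights are arbitrary, so they will not reproduce the distances printed in the example below.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Shape = Tuple[Union[int, str], ...]


@dataclass(frozen=True)
class Obs:
    # Hypothetical reduced stand-in for ObsCompare: only the fields the cost reads.
    kind: str
    itype: int = 0
    shape: Optional[Tuple[Shape, ...]] = None
    op_type: str = ""


def obs_cost(a: Obs, b: Obs) -> float:
    """Illustrative cost between two observations (weights are assumptions)."""
    if a.kind != b.kind:
        return 1000.0  # e.g. a node against an initializer: never a good match
    cost = 0.0
    if a.itype != b.itype:
        cost += 100.0  # different element type
    if a.op_type != b.op_type:
        cost += 100.0  # e.g. Conv replaced by FusedConv
    if a.shape != b.shape:
        cost += 10.0  # same kind of result, different shape
    return cost


conv = Obs("NODE", 1, (("batch", 16, 12, "w"),), "Conv")
fused = Obs("NODE", 1, (("batch", 16, 12, "w"),), "FusedConv")
print(obs_cost(conv, conv), obs_cost(conv, fused))
```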
- classmethod obs_sequence_from_model(model: ModelProto | GraphProto) → List[ObsCompare]
Creates a sequence of observations based on a model.
- Parameters:
model – ONNX model (ModelProto) or graph (GraphProto)
- Returns:
sequence of observations
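Judging from the example output in this section, the sequence lists initializers first, then graph inputs, nodes and graph outputs. The sketch below reproduces only that ordering on a toy dictionary whose keys mirror the `initializer`, `input`, `node` and `output` fields of an ONNX GraphProto. `toy_graph` and `obs_sequence` are hypothetical; the real method builds ObsCompare objects carrying types and shapes rather than bare tuples.

```python
from typing import Dict, List, Tuple

# Hypothetical toy graph standing in for an ONNX GraphProto.
toy_graph: Dict[str, list] = {
    "initializer": ["fc1.weight", "fc1.bias"],
    "input": ["x"],
    "node": [("Gemm", ("linear",)), ("Relu", ("relu",))],  # (op_type, outputs)
    "output": ["relu"],
}


def obs_sequence(graph: Dict[str, list]) -> List[Tuple[int, str, str]]:
    """Flatten a graph into (position, kind, label) tuples.

    Order follows the example output: initializers, inputs, nodes, outputs.
    """
    seq: List[Tuple[int, str, str]] = []
    for name in graph["initializer"]:
        seq.append((len(seq), "INITIA", name))
    for name in graph["input"]:
        seq.append((len(seq), "INPUT", name))
    for op_type, outputs in graph["node"]:
        seq.append((len(seq), "NODE", f"{op_type}:{','.join(outputs)}"))
    for name in graph["output"]:
        seq.append((len(seq), "OUTPUT", name))
    return seq


for row in obs_sequence(toy_graph):
    print(row)
```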
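- class onnx_diagnostic.torch_onnx.compare.ObsComparePair(side1: ObsCompare | None, side2: ObsCompare | None, distance: float)
Defines a pair of compared observations, one from each sequence. Either side may be None when an observation has no counterpart.
- Parameters:
side1 – observation from the first side, or None
side2 – observation from the second side, or None
distance – distance associated with this pair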
- classmethod distance_sequence(s1: List[ObsCompare], s2: List[ObsCompare]) → Tuple[float, List[Tuple[int, int]], List[ObsComparePair]]
Computes the distance between two sequences of observations and aligns them.
- Parameters:
s1 – first sequence
s2 – second sequence
- Returns:
a tuple (distance, alignment path as a list of index pairs, list of aligned ObsComparePair)
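The return type suggests an edit-distance style alignment: matching two observations costs their distance, and leaving one unmatched (a pair with None on one side) costs a gap penalty. The sketch below implements that technique (Needleman-Wunsch style dynamic programming) on plain operator names. `align`, `op_cost` and the gap value 1000 are assumptions, not the library's implementation.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def align(
    s1: Sequence[T],
    s2: Sequence[T],
    cost: Callable[[T, T], float],
    gap: float = 1000.0,
) -> Tuple[float, List[Tuple[int, int]], List[Tuple[Optional[T], Optional[T], float]]]:
    """Edit-distance alignment of two sequences.

    Returns the total cost, the DP cells (i, j) on the optimal path (origin
    excluded) and the aligned pairs; None marks an element with no
    counterpart. Each pair stores the cumulative cost reached at that step.
    """
    n, m = len(s1), len(s2)
    d = [[float("inf")] * (m + 1) for _ in range(n + 1)]
    d[0][0] = 0.0
    for i in range(n + 1):
        for j in range(m + 1):
            if i > 0:  # s1[i - 1] left unmatched
                d[i][j] = min(d[i][j], d[i - 1][j] + gap)
            if j > 0:  # s2[j - 1] left unmatched
                d[i][j] = min(d[i][j], d[i][j - 1] + gap)
            if i > 0 and j > 0:  # s1[i - 1] matched with s2[j - 1]
                d[i][j] = min(d[i][j], d[i - 1][j - 1] + cost(s1[i - 1], s2[j - 1]))

    # Walk back from the bottom-right cell, preferring matches on ties.
    i, j = n, m
    path: List[Tuple[int, int]] = []
    pairs: List[Tuple[Optional[T], Optional[T], float]] = []
    while i > 0 or j > 0:
        path.append((i, j))
        if i > 0 and j > 0 and d[i][j] == d[i - 1][j - 1] + cost(s1[i - 1], s2[j - 1]):
            pairs.append((s1[i - 1], s2[j - 1], d[i][j]))
            i, j = i - 1, j - 1
        elif i > 0 and d[i][j] == d[i - 1][j] + gap:
            pairs.append((s1[i - 1], None, d[i][j]))
            i -= 1
        else:
            pairs.append((None, s2[j - 1], d[i][j]))
            j -= 1
    path.reverse()
    pairs.reverse()
    return d[n][m], path, pairs


def op_cost(a: str, b: str) -> float:
    # Hypothetical cost: 0 if equal, 100 if one name contains the other, else 300.
    if a == b:
        return 0.0
    return 100.0 if a in b or b in a else 300.0


ops1 = ["Conv", "Relu", "MaxPool", "Reshape", "Gemm"]  # before optimization
ops2 = ["FusedConv", "MaxPool", "Reshape", "Gemm"]  # after optimization
total, path, pairs = align(ops1, ops2, op_cost)
for a, b, c in pairs:
    print(f"{c:.4e} | {a or '':10} | {b or ''}")
```

With these costs, Conv is paired with FusedConv and Relu is left unmatched, much like the rows the example below prints for the optimized model.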
An example:
<<<
import torch
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, 5)
        self.fc1 = torch.nn.Linear(144, 64)
        self.fc2 = torch.nn.Linear(64, 128)
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.nn.functional.max_pool2d(
            torch.nn.functional.relu(self.conv1(x)),
            (4, 4),
        )
        # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        y = self.fc3(x)
        return y


model = Model()
x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
dynamic_shapes = ({0: "batch", 3: "dim"},)

# Export the same model twice, with and without optimization.
onnx_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
).model_proto
onnx_not_optimized = to_onnx(
    model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
).model_proto

# Build one observation sequence per model, align them and print both side by side.
seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
_dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
text = ObsComparePair.to_str(pair_cmp)
print(text)
>>>
0.0000e+00 | 0000 INITIA INT64 2 init7_s2_0_-1 | 0000 INITIA INT64 2 init7_s2_0_-1
0.0000e+00 | 0001 INITIA FLOAT 64x144 GemmTransposePattern--p_fc1_weight: | 0001 INITIA FLOAT 64x144 GemmTransposePattern--p_fc1_weight:
0.0000e+00 | 0002 INITIA FLOAT 128x64 GemmTransposePattern--p_fc2_weight: | 0002 INITIA FLOAT 128x64 GemmTransposePattern--p_fc2_weight:
0.0000e+00 | 0003 INITIA FLOAT 10x128 GemmTransposePattern--p_fc3_weight: | 0003 INITIA FLOAT 10x128 GemmTransposePattern--p_fc3_weight:
0.0000e+00 | 0004 INITIA FLOAT 16x3x5x5 conv1.weight | 0004 INITIA FLOAT 16x3x5x5 conv1.weight
0.0000e+00 | 0005 INITIA FLOAT 16 conv1.bias | 0005 INITIA FLOAT 16 conv1.bias
0.0000e+00 | 0006 INITIA FLOAT 64 fc1.bias | 0006 INITIA FLOAT 64 fc1.bias
0.0000e+00 | 0007 INITIA FLOAT 128 fc2.bias | 0007 INITIA FLOAT 128 fc2.bias
0.0000e+00 | 0008 INITIA FLOAT 10 fc3.bias | 0008 INITIA FLOAT 10 fc3.bias
0.0000e+00 | 0009 INPUT FLOAT batchx3x16xdim x | 0009 INPUT FLOAT batchx3x16xdim x
1.9700e+02 | 0010 NODE FLOAT batchx16x12xconv_f Conv conv2d | 0010 NODE FLOAT batchx16x12xconv_f FusedConv relu
1.1980e+03 | 0011 NODE FLOAT batchx16x12xconv_f Relu relu |
1.1980e+03 | 0012 NODE FLOAT batchx16x3xconv_f3 MaxPool max_pool2d | 0011 NODE FLOAT batchx16x3xconv_f3 MaxPool max_pool2d
1.1980e+03 | 0013 NODE FLOAT batchx48*dim//4-48 Reshape flatten | 0012 NODE FLOAT batchx48*dim//4-48 Reshape flatten
1.1980e+03 | 0014 NODE FLOAT batchx64 Gemm linear | 0013 NODE FLOAT batchx64 Gemm linear
1.1980e+03 | 0015 NODE FLOAT batchx64 Relu relu_1 | 0014 NODE FLOAT batchx64 Relu relu_1
1.1980e+03 | 0016 NODE FLOAT batchx128 Gemm linear_1 | 0015 NODE FLOAT batchx128 Gemm linear_1
1.1980e+03 | 0017 NODE FLOAT batchx128 Relu relu_2 | 0016 NODE FLOAT batchx128 Relu relu_2
1.1980e+03 | 0018 NODE FLOAT batchx10 Gemm output_0 | 0017 NODE FLOAT batchx10 Gemm output_0
1.1980e+03 | 0019 OUTPUT FLOAT batchx10 output_0 | 0018 OUTPUT FLOAT batchx10 output_0
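In the output above, the first column stays at 0.0000e+00 while both sides describe the same result and first becomes non-zero on the row where Conv meets FusedConv. A small hypothetical helper, `first_divergence`, can therefore locate the first mismatch in the printed text. It is not part of onnx_diagnostic and relies only on each row starting with the distance followed by `|`; the excerpt is copied from the output above.

```python
from typing import Optional

# Three rows copied from the example output.
excerpt = """\
0.0000e+00 | 0009 INPUT FLOAT batchx3x16xdim x | 0009 INPUT FLOAT batchx3x16xdim x
1.9700e+02 | 0010 NODE FLOAT batchx16x12xconv_f Conv conv2d | 0010 NODE FLOAT batchx16x12xconv_f FusedConv relu
1.1980e+03 | 0011 NODE FLOAT batchx16x12xconv_f Relu relu |
"""


def first_divergence(text: str) -> Optional[str]:
    """Return the first row whose distance column is non-zero, or None."""
    for line in text.splitlines():
        head = line.split("|", 1)[0].strip()
        if head and float(head) > 0:
            return line
    return None


print(first_divergence(excerpt))
```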