import enum
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
from ..helpers.onnx_helper import onnx_dtype_name
_NOT_SO_FAR_OPS = [
{"MatMul", "Gemm", "FusedMatMul"},
{"Conv", "FusedConv"},
{"MaxPool"},
]
def _sum_sets(sets):
t = set()
for s in sets:
t |= s
return t
# Flat set of every operator type listed in any family of _NOT_SO_FAR_OPS,
# used for a quick membership test before looking at the individual families.
_ALL_NOT_SO_FAR_OPS = _sum_sets(_NOT_SO_FAR_OPS)
def _align(res: str, limit: int) -> str:
if len(res) == limit:
return res
if len(res) > limit:
return res[:limit]
return res + " " * (limit - len(res))
[docs]
class ObsType(enum.IntEnum):
    """
    Observation kind.

    Every member is an integer, each value is a distinct power of two.
    """

    RESULT = 1
    INITIALIZER = 2
    SPARSE_INITIALIZER = 4
    INPUT = 8
    OUTPUT = 16
    NODE = 32

    def __repr__(self):
        # Short form ``ObsType.NAME``, without the integer value.
        return f"{type(self).__name__}.{self.name}"
[docs]
@dataclass
class ObsCompare:
    """
    The description of an observation, a node, an input, an output, an initializer.

    :param position: index of this observation in the original model
    :param kind: node type, see :class:`ObsType`
    :param name_or_outputs: name of an initializer or the outputs of a node
    :param itype: onnx type
    :param index: index of an input or output, for a node, its index
        in the list of nodes of the graph
    :param shape: shape
    :param op_type: node op_type
    :param comment: comment, unused
    """

    position: int
    kind: ObsType
    name_or_outputs: Tuple[str, ...]
    itype: int = 0
    index: int = 0
    shape: Optional[Tuple[Union[int, str], ...]] = None
    op_type: str = ""
    comment: str = ""

    # Width of every column printed by __str__ (position, kind, type, shape,
    # op_type, names). Not annotated on purpose: it is a plain class attribute,
    # not a dataclass field.
    _WIDTHS = (4, 6, 8, 18, 15, 35)

    def __str__(self) -> str:
        "usual"
        w_pos, w_kind, w_type, w_shape, w_op, w_names = self._WIDTHS
        els = [
            _align(f"{self.position:04d}", w_pos),
            _align(self.kind._name_, w_kind),
            _align(onnx_dtype_name(self.itype) if self.itype else "?", w_type),
            _align("?" if self.shape is None else "x".join(map(str, self.shape)), w_shape),
            _align(self.op_type or "", w_op),
            _align(", ".join(self.name_or_outputs), w_names),
        ]
        return " ".join(els)

    @classmethod
    def to_str(cls, obs: Optional["ObsCompare"]) -> str:
        """
        Returns ``str(obs)`` or, if *obs* is None, a blank string
        of the same width so that columns stay aligned.
        """
        assert not obs or isinstance(obs, ObsCompare), f"unexpected type {type(obs)}"
        if obs:
            return str(obs)
        # all the columns plus one space between two consecutive columns
        return " " * (sum(cls._WIDTHS) + len(cls._WIDTHS) - 1)

    @staticmethod
    def _is_exporter_name(name: str) -> bool:
        """
        Tells if a result name looks like a name given by the exporter
        and not inspired from the model itself.
        """
        return name.startswith(("val_", "_onx_")) or "::" in name or "--" in name

    def _node_distance(self, obs: "ObsCompare") -> float:
        """Computes the cost between two observations of kind ``NODE``."""
        # The element type is not part of the cost for nodes,
        # only the operator types and the output names are.
        cost = 9997
        d: float = 0
        if self.op_type != obs.op_type:
            if self.op_type in _ALL_NOT_SO_FAR_OPS or obs.op_type in _ALL_NOT_SO_FAR_OPS:
                d += 1e2
                for aset in _NOT_SO_FAR_OPS:
                    if self.op_type in aset and obs.op_type in aset:
                        # same family: a difference in names matters much less
                        cost = 97
                    elif self.op_type in aset or obs.op_type in aset:
                        d += 5e4
            else:
                d += 9e2

        if len(self.name_or_outputs) == 1 and len(obs.name_or_outputs) == 1:
            n1 = self.name_or_outputs[0]
            n2 = obs.name_or_outputs[0]
            if n1 != n2:
                if n1.replace("_", "") == n2.replace("_", ""):
                    # the names only differ by underscores
                    d += 1
                elif self._is_exporter_name(n1) and self._is_exporter_name(n2):
                    # The prefixes contain underscores, they must be searched
                    # in the original names, not in the names without underscores.
                    d += cost / 100
                else:
                    d += cost
        else:
            a = set(self.name_or_outputs) & set(obs.name_or_outputs)
            b = set(self.name_or_outputs) | set(obs.name_or_outputs)
            # one unit of cost for every name which is not shared
            d += cost * (len(b) - len(a))
        return d

    def distance(self, obs: "ObsCompare") -> float:
        """
        Computes a cost between two observations.

        :param obs: the other observation
        :return: ``1e6`` if the kinds are different, otherwise a cost depending
            on the kind, 0 means the observations are considered identical
        """
        if self.kind != obs.kind:
            return 1e6
        if self.kind == ObsType.NODE:
            return self._node_distance(obs)

        differs = self.itype != obs.itype or self.shape != obs.shape
        if self.kind == ObsType.INPUT:
            return 999.7 if differs or self.index != obs.index else 0
        if self.kind in (ObsType.INITIALIZER, ObsType.SPARSE_INITIALIZER):
            return 1e3 if differs else 0
        if self.kind == ObsType.OUTPUT:
            return 999.1 if differs or self.index != obs.index else 0
        # remaining kinds (ObsType.RESULT)
        return 1e8

    @staticmethod
    def _shape_of(tensor_type) -> Tuple[Union[int, str], ...]:
        """Returns the shape of a tensor type, ``dim_param`` if set, else ``dim_value``."""
        return tuple((d.dim_param or d.dim_value) for d in tensor_type.shape.dim)

    @classmethod
    def obs_sequence_from_model(
        cls,
        model: Union[onnx.ModelProto, onnx.GraphProto],
    ) -> List["ObsCompare"]:
        """
        Creates a sequence of observations based on a model.
        Observations are added in this order: initializers, inputs,
        nodes, outputs.

        :param model: model
        :return: sequence of observations
        """
        graph = model if isinstance(model, onnx.GraphProto) else model.graph

        # known shapes and types of the results, used to describe
        # the first output of every node
        shapes = {}
        types = {}
        for info in [*graph.value_info, *graph.input, *graph.output]:
            if info.type.tensor_type:
                t = info.type.tensor_type
                shapes[info.name] = cls._shape_of(t)
                types[info.name] = t.elem_type

        seq: List[ObsCompare] = []
        for init in graph.initializer:
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.INITIALIZER,
                    itype=init.data_type,
                    shape=tuple(init.dims),
                    name_or_outputs=(init.name,),
                )
            )
        for i, inp in enumerate(graph.input):
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.INPUT,
                    itype=inp.type.tensor_type.elem_type,
                    index=i,
                    shape=cls._shape_of(inp.type.tensor_type),
                    name_or_outputs=(inp.name,),
                )
            )
        # The index must come from this loop: reusing the variable of the
        # previous loop fails when the graph has no input.
        for i, node in enumerate(graph.node):
            # a node without any output has no type and no shape
            first_output = node.output[0] if node.output else ""
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.NODE,
                    itype=types.get(first_output, 0),
                    index=i,
                    shape=shapes.get(first_output, None),
                    name_or_outputs=tuple(node.output),
                    op_type=node.op_type,
                )
            )
        for i, out in enumerate(graph.output):
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.OUTPUT,
                    itype=out.type.tensor_type.elem_type,
                    index=i,
                    shape=cls._shape_of(out.type.tensor_type),
                    name_or_outputs=(out.name,),
                )
            )
        return seq
[docs]
@dataclass
class ObsComparePair:
    """
    Defines a pair of comparison objects

    :param side1: object from first side, None if the object from
        the second side has no counterpart
    :param side2: object from second side, None if the object from
        the first side has no counterpart
    :param distance: distance
    """

    side1: Optional["ObsCompare"]
    side2: Optional["ObsCompare"]
    distance: float

    def __str__(self) -> str:
        "nice display"
        return (
            f"{self.distance:.4e} | "
            f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}"
        )

    @classmethod
    def to_str(cls, seq: List["ObsComparePair"]) -> str:
        """Displays every pair in text."""
        return "\n".join(str(pair) for pair in seq)

    @classmethod
    def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[
        float,
        List[Tuple[int, int]],
        List["ObsComparePair"],
    ]:
        """
        Computes the distance between two sequences of results.

        :param s1: first sequence
        :param s2: second sequence
        :return: distance and alignment, the alignment is returned twice,
            as a list of indices ``(i, j)`` (the last elements of *s1* and *s2*
            consumed so far, ``-1`` means none yet) and as a list of
            :class:`ObsComparePair`, the distance stored in every pair is
            the cumulated distance up to this pair

        Any sequence can be empty, every element of the other one is then
        paired with None.

        An example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.api import to_onnx
            from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


            class Model(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv1 = torch.nn.Conv2d(3, 16, 5)
                    self.fc1 = torch.nn.Linear(144, 64)
                    self.fc2 = torch.nn.Linear(64, 128)
                    self.fc3 = torch.nn.Linear(128, 10)

                def forward(self, x):
                    x = torch.nn.functional.max_pool2d(
                        torch.nn.functional.relu(self.conv1(x)),
                        (4, 4),
                    )
                    # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
                    x = torch.flatten(x, 1)
                    x = torch.nn.functional.relu(self.fc1(x))
                    x = torch.nn.functional.relu(self.fc2(x))
                    y = self.fc3(x)
                    return y


            model = Model()
            x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
            dynamic_shapes = ({0: "batch", 3: "dim"},)
            onnx_optimized = to_onnx(
                model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
            ).model_proto
            onnx_not_optimized = to_onnx(
                model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
            ).model_proto
            seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
            seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
            _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
            text = ObsComparePair.to_str(pair_cmp)
            print(text)
        """
        n1, n2 = len(s1), len(s2)
        # Only the cells (i, j) with j close to i are computed,
        # the band is wide enough to include the final cell.
        delay = max(50, abs(n2 - n1) + 1)
        insert_cost = 1e3

        distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0}
        predecessor: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {(-1, -1): None}

        # Boundaries: index -1 means nothing was consumed yet on that side.
        # Without them, an alignment cannot start with an unmatched element
        # and the final cell does not exist if one sequence is empty.
        for i in range(n1):
            distance[i, -1] = distance[i - 1, -1] + insert_cost + 1
            predecessor[i, -1] = (i - 1, -1)
        for j in range(n2):
            distance[-1, j] = distance[-1, j - 1] + insert_cost + 0.1
            predecessor[-1, j] = (-1, j - 1)

        for i in range(n1):
            for j in range(max(0, i - delay), min(n2, i + delay)):
                best: float = 1e100
                pred = None

                # s1[i] is paired with s2[j]
                prev = (i - 1, j - 1)
                if prev in distance:
                    d = distance[prev] + s1[i].distance(s2[j])
                    if d < best:
                        best, pred = d, prev

                # s1[i] has no counterpart, then s2[j] has no counterpart,
                # the two costs are slightly different to break ties
                for prev, extra in (((i - 1, j), 1), ((i, j - 1), 0.1)):
                    if prev in distance:
                        d = distance[prev] + insert_cost + extra
                        if d < best:
                            best, pred = d, prev

                distance[i, j] = best
                predecessor[i, j] = pred

        # walks back from the final cell to (-1, -1)
        way = []
        last: Optional[Tuple[int, int]] = n1 - 1, n2 - 1
        while last is not None:
            way.append(last)
            last = predecessor[last]
        # the first element is (-1, -1), it does not describe any pair
        indices = list(reversed(way))[1:]

        obs_path: List[ObsComparePair] = []
        last = -1, -1
        for i, j in indices:
            di = i - last[0]
            dj = j - last[1]
            cost = distance.get((i, j), np.nan)
            if di == dj == 1:
                obs_path.append(ObsComparePair(s1[i], s2[j], distance=cost))
            elif di == 0:
                obs_path.append(ObsComparePair(None, s2[j], distance=cost))
            elif dj == 0:
                obs_path.append(ObsComparePair(s1[i], None, distance=cost))
            else:
                raise RuntimeError(f"issue with di={di}, dj={dj}")
            last = i, j
        return distance[n1 - 1, n2 - 1], indices, obs_path