from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
import onnx.helper as oh
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.onnx_helper import dtype_to_tensor_dtype
from ..helpers import make_hash
from ._shape_helper import DYNAMIC_SHAPE
from .evaluate_expressions import evaluate_expression
from ._onnx_helper import (
element_wise_binary_op_types,
element_wise_op_cmp_types,
unary_like_op_types,
)
class ShapeBuilder:
"""API for a class computing shapes in an ONNX model."""
_op_type_element_wise_types = element_wise_binary_op_types()
_op_type_element_wise_cmp_types = element_wise_op_cmp_types()
_op_type_unary_like = unary_like_op_types()
@property
def input_names(self) -> List[str]:
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
@property
def output_names(self) -> List[str]:
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def get_shape(self, name: str) -> DYNAMIC_SHAPE:
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def set_shape(self, name: str, shape: DYNAMIC_SHAPE):
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def get_type(self, name: str) -> int:
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def set_type(self, name: str, itype: int):
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def get_rank(self, name: str) -> int:
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def set_rank(self, name: str, rank: int):
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def register_constraint_dimension(self, dim_name: str, value: Any):
raise NotImplementedError(f"not overloaded in {self.__class__.__name__!r}")
def _hash(self) -> str:
return make_hash(self)
def update_shapes(self, model: onnx.ModelProto):
"""Updates model shapes with the value stored inside this graph."""
self._update_shapes_graph(model.graph)
    def _update_shapes_graph(self, graph: onnx.GraphProto):
        """Adds a ``ValueInfoProto`` for every intermediate result with a known shape."""
exclude = (
set(i.name for i in graph.input)
| set(i.name for i in graph.output)
| set(i.name for i in graph.initializer)
| set(i.name for i in graph.sparse_initializer)
)
        # intermediate results: node outputs which are not graph inputs, outputs,
        # initializers or results already described in value_info
        include = set()
for node in graph.node:
include |= set(node.output)
include -= exclude
include -= set(i.name for i in graph.value_info)
        # keeps a deterministic order, following the order of the nodes
        ordered_include = []
for node in graph.node:
for o in node.output:
if o in include:
ordered_include.append(o)
infos = []
for k in ordered_include:
if not self.has_shape(k):
continue
infos.append(oh.make_tensor_value_info(k, self.get_type(k), list(self.get_shape(k))))
graph.value_info.extend(infos)
def get_attribute(
self, node: onnx.NodeProto, att_name: str, exc: bool = True
) -> Optional[onnx.AttributeProto]:
"""Returns an attribute for a node."""
for att in node.attribute:
if att.name == att_name:
return att
assert not exc, (
f"Unable to find attribute {att_name!r} for node "
f"type {node.op_type!r} in node {node}"
)
return None
def get_attribute_with_default(
self, node: onnx.NodeProto, name: str, default_value: Any
) -> Any:
"""
        Returns an attribute or its default value if missing.

:param node: node
:param name: attribute name
:param default_value: default value
:return: value
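
        Example, a minimal sketch (``builder`` stands for a concrete
        :class:`ShapeBuilder` instance)::

            import onnx.helper as oh

            node = oh.make_node("Gemm", ["a", "b"], ["y"])
            # 'alpha' is not set on the node, the default value is returned
            alpha = builder.get_attribute_with_default(node, "alpha", 1.0)
            assert alpha == 1.0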
"""
for att in node.attribute:
if att.name == name:
if att.type == onnx.AttributeProto.INT:
return att.i
if att.type == onnx.AttributeProto.INTS:
return list(att.ints)
if att.type == onnx.AttributeProto.FLOAT:
return att.f
if att.type == onnx.AttributeProto.FLOATS:
return list(att.floats)
if att.type == onnx.AttributeProto.STRING:
return att.s
raise TypeError(
f"Not implemented for attribute name {att.name!r}, attribute={att}"
)
return default_value
def get_attributes_with_default(
self, node: onnx.NodeProto, **default_values
) -> Dict[str, Any]:
"""
        Returns int or float attributes. If missing, the default value is
        returned if it is not None.

:param node: node
:param default_values: default values
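
        Example, a minimal sketch (``builder`` stands for a concrete
        :class:`ShapeBuilder` instance)::

            import onnx.helper as oh

            node = oh.make_node("Flatten", ["x"], ["y"], axis=2)
            kwargs = builder.get_attributes_with_default(node, axis=1)
            # {'axis': 2}: the attribute stored on the node wins over the default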
"""
res = {}
for att in node.attribute:
if att.name in default_values:
if att.type == onnx.AttributeProto.INT:
res[att.name] = att.i
elif att.type == onnx.AttributeProto.INTS:
res[att.name] = list(att.ints)
elif att.type == onnx.AttributeProto.FLOAT:
res[att.name] = att.f
elif att.type == onnx.AttributeProto.FLOATS:
res[att.name] = list(att.floats)
elif att.type == onnx.AttributeProto.STRING:
res[att.name] = att.s
else:
raise TypeError(
f"Not implemented for attribute name {att.name!r}, attribute={att}"
)
for k, v in default_values.items():
if k not in res and v is not None:
res[k] = v
res = {k: v for k, v in res.items() if v is not None}
return res
def pretty_node(
self,
node: Optional[onnx.NodeProto],
limit: int = 80,
short: bool = True,
shape: bool = False,
) -> str:
"""
        Pretty rendering for a node.

        :param node: node to render
        :param limit: column at which the trailing output type and shape
            information starts when ``short`` is False
        :param short: if True, only returns the node rendering, without the
            trailing type/shape columns and the node name
        :param shape: if True, shows the input and output types and shapes
            after ``##``
        :return: string
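
        Example, a minimal sketch (``builder`` stands for a concrete
        :class:`ShapeBuilder` instance)::

            import onnx.helper as oh

            node = oh.make_node("MatMul", ["A", "B"], ["C"])
            print(builder.pretty_node(node))
            # MatMul: A, B -> C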
"""
if node is None:
return "None"
if shape:
st = []
for i in node.input:
dt = self.get_type(i) if self.has_type(i) else "-"
sh = (
"x".join(str(_).replace(" ", "") for _ in self.get_shape(i))
if self.has_shape(i)
else (f"rk={self.get_rank(i)}" if self.has_rank(i) else "?")
)
st.append(f"{i}:{dt}|{sh}")
st.append("->")
for i in node.output:
dt = self.get_type(i) if self.has_type(i) else "-"
sh = (
"x".join(str(_).replace(" ", "") for _ in self.get_shape(i))
if self.has_shape(i)
else (f"rk={self.get_rank(i)}" if self.has_rank(i) else "?")
)
st.append(f"{i}:{dt}|{sh}")
shape_info = " ".join(st)
else:
shape_info = ""
text = (
(
f"{node.op_type}[{node.domain}]: "
f"{', '.join(node.input)} -> {', '.join(node.output)}"
)
if node.domain
else f"{node.op_type}: {', '.join(node.input)} -> {', '.join(node.output)}"
)
if shape_info:
text = f"{text} ## {shape_info}"
if short:
return text
add = " " * abs(80 - len(text))
text += add
info = []
for o in node.output:
t = f"T{self.get_type(o)}" if self.has_type(o) else ""
s = " x ".join(map(str, self.get_shape(o))) if self.has_shape(o) else ""
info.append(": ".join([t, s]))
if node.name:
s = f"{text}|{' '.join(info)}"
return f"{s}{' ' * (110 - len(s))}- {node.name}"
return f"{text}|{' '.join(info)}"
    def map_value_info_dimension_with_true_values(self, name: str, tensor: np.ndarray):
        """
        Maps the dynamic dimensions of result ``name`` to the concrete values
        taken from ``tensor``.

        :param name: result name
        :param tensor: tensor holding a concrete value for ``name``
        :return: dictionary mapping a dimension name to an integer
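
        A minimal sketch, assuming the stored shape of ``"X"`` is ``("batch", 8)``::

            import numpy as np

            builder.map_value_info_dimension_with_true_values(
                "X", np.zeros((4, 8), dtype=np.float32)
            )
            # expected: {"batch": 4}
        """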
assert self.has_type(name), f"Missing type for {name!r}."
assert self.has_shape(name), f"Missing shape for {name!r}."
dtype = dtype_to_tensor_dtype(tensor.dtype)
assert dtype == self.get_type(name), (
f"Type mismatch for {name!r}, expecting "
f"{self.get_type(name)}, got {dtype} in "
f"{string_type(tensor, with_shapes=True)}"
)
res = {}
shape = self.get_shape(name)
for i, (value, dim) in enumerate(zip(tensor.shape, shape)):
if isinstance(dim, str):
if dim in res:
assert res[dim] == value, (
f"Shape mismatch for {name!r} for dimension {i}, "
f"known dimensions are {shape}, got "
f"{string_type(tensor, with_shapes=True)}"
)
res[dim] = value
else:
assert dim == value, (
f"Shape mismatch for {name!r} for dimension {i}, "
f"expecting {dim}, got {string_type(tensor, with_shapes=True)}"
)
return res
    def evaluate_shape(self, name: str, context: Dict[str, int]) -> Tuple[int, ...]:
        """
        Evaluates the shape of result ``name``, replacing every dynamic
        dimension by its value in ``context``.

        :param name: result name
        :param context: maps a dimension name to an integer value
        :return: a tuple of integers
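
        A minimal sketch, assuming the stored shape of ``"X"`` is ``("batch", 8)``::

            builder.evaluate_shape("X", {"batch": 4})
            # expected: (4, 8)
        """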
shape = self.get_shape(name)
return tuple(evaluate_expression(s, context) for s in shape)
    def compare_computed_shape_with_tensor(
        self, name: str, tensor: np.ndarray, context: Dict[str, int]
    ) -> Tuple[Tuple[str, int, int], ...]:
        """
        Compares the stored shape of result ``name``, the actual shape of
        ``tensor`` and the shape evaluated with ``context``.

        :param name: result name
        :param tensor: tensor holding a concrete value for ``name``
        :param context: maps a dimension name to an integer value
        :return: a tuple of tuples ``(stored_dim, tensor_dim, computed_dim)``
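
        A minimal sketch, assuming the stored shape of ``"X"`` is ``("batch", 8)``::

            import numpy as np

            builder.compare_computed_shape_with_tensor(
                "X", np.zeros((4, 8), dtype=np.float32), {"batch": 4}
            )
            # expected: (("batch", 4, 4), (8, 8, 8))
        """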
assert self.has_type(name), f"Missing type for {name!r}."
assert self.has_shape(name), f"Missing shape for {name!r}."
dtype = dtype_to_tensor_dtype(tensor.dtype)
assert dtype == self.get_type(name), (
f"Type mismatch for {name!r}, expecting "
f"{self.get_type(name)}, got {dtype} in "
f"{string_type(tensor, with_shapes=True)}"
)
computed = self.evaluate_shape(name, context=context)
return tuple(zip(self.get_shape(name), tensor.shape, computed))