import contextlib
import os
import pprint
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx.external_data_helper import uses_external_data
from onnx.reference import ReferenceEvaluator
from onnx_diagnostic.helpers import string_type
from ._shape_helper import DYNAMIC_SHAPE, is_static_shape
from ._builder_runtime import _BuilderRuntime
from ._shape_runtime import _ShapeRuntime
from ._inference_runtime import _InferenceRuntime
from .rename_expressions import parse_expression_tokens
from .simplify_expressions import simplify_expression
from ._onnx_helper import str_tensor_proto_type
from .shape_builder import ShapeBuilder
@contextlib.contextmanager
def _maybe_disable_fake_tensor_mode() -> Generator:
    """Does nothing; placeholder kept so that callers have a hook to
    disable torch's fake tensor mode when needed."""
    try:
        yield
    finally:
        pass


class BasicShapeBuilder(ShapeBuilder, _BuilderRuntime, _ShapeRuntime, _InferenceRuntime):
    """
    Implements a basic class doing shape inference in an ONNX model.

    A couple of environment variables can be set to help debugging any issue.

    * ``ONNXSTOPSHAPE=<name>``: raises an exception when ``name`` receives a shape
    * ``ONNXSTOPTYPE=<name>``: raises an exception when ``name`` receives a type
    * ``ONNXDYNDIM=<name>``: raises an exception when dimension ``name`` is used
    * ``ONNXCST=1``: shows which constant is requested
    * ``ONNXSHAPECOMPUTE=1``: raises an exception when a shape is missing
    * ``ONNXSTOPVALUESHAPE=<name>``: prints more information in functions dealing
      with shape values
    """

    def __init__(self, verbose: int = 0, opset: Optional[int] = None):
        self.verbose = verbose
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        self._known_shapes: Dict[str, DYNAMIC_SHAPE] = {}
        self._known_ranks: Dict[str, int] = {}
        self._known_types: Dict[str, int] = {}
        self.constraints_: Dict[str, set] = {}
        self.dynamic_dimensions_: Dict[str, set] = {}
        self.constants_: Dict[str, Union[onnx.TensorProto, onnx.NodeProto]] = {}
        #
        self._known_value_shape: Dict[str, Any] = {}
        self.constants_computed_: Dict[str, Any] = {}
        self._calls: List[Any] = []
        # self.dynamic_dimensions_source = {}
        # self.dynamic_dimensions_source_flat = {}
        # self._dynamic_examples = {}
        self._debug_stop_shape = os.environ.get("ONNXSTOPSHAPE", "#?#")
        self._debug_stop_type = os.environ.get("ONNXSTOPTYPE", "#?#")
        # Drops empty strings so the set is empty when ONNXDYNDIM is not set.
        self._debug_dyn_dim = {
            d for d in os.environ.get("ONNXDYNDIM", "").split(",") if d
        }
        self._debug_get_constant = int(os.environ.get("ONNXCST", "0"))
        self._debug_shape_missing = int(os.environ.get("ONNXSHAPECOMPUTE", "0"))
        self._debug_value_shape = os.environ.get("ONNXSTOPVALUESHAPE", "")
        self._debug_constant_folding = 0
        # ``set_type`` stores warnings under a key, so this must be a dict.
        self._debug_msg: Dict[str, Any] = {}
        self.maybe_disable_fake_tensor_mode = _maybe_disable_fake_tensor_mode
        self.main_opset = opset or 18
        self.time_evaluation_constants_ = 0

    @property
    def input_names(self) -> List[str]:
        return self._input_names

    @property
    def output_names(self) -> List[str]:
        return self._output_names

    def is_constant(self, name: str) -> bool:
        """Tells if a result is a constant."""
        return name in self.constants_

    def get_constant(
        self,
        name: str,
        exc: bool = True,
        computed_value: bool = False,
        as_shape: bool = False,
        multiple_outputs: bool = False,
    ) -> Union[np.ndarray, onnx.NodeProto]:
        """
        Returns the constant *name*. It is a tensor (numpy array)
        or a NodeProto which must be evaluated.
        If *computed_value* is True, the NodeProto is evaluated with the
        ReferenceEvaluator.

        :param name: constant name
        :param exc: raise an exception if anything is impossible to do
        :param computed_value: compute the value if not a constant
        :param as_shape: returns a tuple for a shape
        :param multiple_outputs: allow multiple outputs
        :return: value
        """
        assert self.is_constant(name), f"{name!r} is not a constant{self.get_debug_msg()}"
        if as_shape:
            assert not multiple_outputs, "multiple outputs not allowed with as_shape=True"
            res = self.get_constant(name, exc, computed_value=computed_value, as_shape=False)
            if res is None:
                assert not exc, (
                    f"No constant for name={name!r}, exc={exc}, "
                    f"computed_value={computed_value}, as_shape={as_shape}, "
                    f"multiple_outputs={multiple_outputs}{self.get_debug_msg()}"
                )
                if self._debug_get_constant:
                    print(f"[ShapeBuilder-{self._hash()}.get_constant] FAIL(1) name={name!r}")
                return None
            assert multiple_outputs or not isinstance(
                res, tuple
            ), f"Multiple outputs are not allowed but type is {type(res)} for name={name!r}"
            return tuple(i if isinstance(i, str) else int(i) for i in res)
        if name in self.constants_computed_:
            value = self.constants_computed_[name]
            assert value is not None, f"Constant is empty for name={name!r}"
            assert multiple_outputs or not isinstance(
                value, tuple
            ), f"Multiple outputs are not allowed but type is {type(value)} for name={name!r}"
            return value
        possible_value = self.constants_[name]
        if computed_value and isinstance(possible_value, onnx.NodeProto):
            assert len(possible_value.output) == 1, (
                f"Not implemented for node {self.pretty_node(possible_value)}"
                f"{self.get_debug_msg()}"
            )
            value, _ = self.compute_constant(name, exc=exc)
            if value is not None:
                self.constants_computed_[name] = value
            return value
        if isinstance(possible_value, onnx.TensorProto):
            if uses_external_data(possible_value):
                assert not exc, (
                    f"Tensor is using external data, data_type={possible_value.data_type}, "
                    f"dims={possible_value.dims}"
                )
                return None
            assert not multiple_outputs, f"Multiple outputs are not allowed for name={name!r}"
            v = onh.to_array(possible_value)
            self.constants_computed_[name] = v
            return v
        # At this point the constant can only be a node (TensorProto was
        # handled above).
        assert isinstance(
            possible_value, onnx.NodeProto
        ), f"Unexpected type {type(possible_value)} for a constant{self.get_debug_msg()}"
        res, _ = self.compute_constant(name, exc=exc)
        if res is None:
            # The constant is too big to be computed.
            if self._debug_get_constant:
                print(f"[ShapeBuilder-{self._hash()}.get_constant] FAIL(2) name={name!r}")
            return None
        assert (
            not multiple_outputs
        ), f"get_constant not implemented when multiple_outputs=True, name={name!r}"
        if not isinstance(res, tuple):
            return res
        if len(res) == 1:
            assert (
                not exc or res[0] is not None
            ), f"Unable to compute value {name!r}{self.get_debug_msg()}"
            return res[0]
        index = list(possible_value.output).index(name)
        value = res[index]
        assert (
            value is not None
        ), f"Unable to compute value {name!r}{self.get_debug_msg()}"
        return value
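
    # Example (illustrative sketch): once an initializer is registered with
    # ``set_constant``, its value can be retrieved as an array or as a shape
    # tuple. The name "shape0" is hypothetical.
    #
    #   import numpy as np
    #   import onnx.numpy_helper as onh
    #   builder = BasicShapeBuilder()
    #   t = onh.from_array(np.array([2, 3], dtype=np.int64), name="shape0")
    #   builder.set_constant("shape0", t)
    #   builder.get_constant("shape0")                 # array([2, 3])
    #   builder.get_constant("shape0", as_shape=True)  # (2, 3)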

    def set_constant(self, name: str, value: Union[onnx.TensorProto, onnx.NodeProto]) -> None:
        """Registers a constant result and propagates its type and shape."""
        assert (
            name not in self.constants_
        ), f"Constant {name!r} is already defined{self.get_debug_msg()}"
        self.constants_[name] = value
        if isinstance(value, onnx.TensorProto):
            if not self.has_type(name):
                self.set_type(name, value.data_type)
            if not self.has_shape(name):
                self.set_shape(name, tuple(value.dims))
        elif isinstance(value, onnx.NodeProto):
            for att in value.attribute:
                if att.name == "value" and att.t:
                    self.constants_[name] = att.t
                    if not self.has_type(name):
                        self.set_type(name, att.t.data_type)
                    if not self.has_shape(name):
                        self.set_shape(name, tuple(att.t.dims))
                    return
            # Let's execute the node otherwise.
            ref = ReferenceEvaluator(value)
            val = ref.run(None, {})[0]
            self.constants_computed_[name] = val
            self.set_type(name, oh.np_dtype_to_tensor_dtype(val.dtype))
            self.set_shape(name, tuple(map(int, val.shape)))
        else:
            raise TypeError(f"Unexpected type {type(value)} for value.")
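
    # Example (illustrative sketch): a Constant node is unwrapped to its
    # tensor attribute, so type and shape become available immediately.
    # The name "cst" is hypothetical.
    #
    #   import onnx.helper as oh
    #   node = oh.make_node(
    #       "Constant", [], ["cst"],
    #       value=onh.from_array(np.array([1.5], dtype=np.float32)),
    #   )
    #   builder.set_constant("cst", node)
    #   builder.get_type("cst")   # 1, i.e. onnx.TensorProto.FLOAT
    #   builder.get_shape("cst")  # (1,)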

    def set_value_shape(
        self, name: str, value: Any, equal_to: Optional[Tuple[str, str]] = None
    ):
        """
        Sets the value for a shape result.

        :param name: name
        :param value: it cannot be empty
        :param equal_to: if specified, the value is also equal to this value

        A value can be a string (an unknown shape), a tuple (a shape),
        or an integer (a single scalar).
        """
        if self._debug_value_shape and name == self._debug_value_shape:
            raise AssertionError(
                f"Requested stop, name={name!r}, value={value!r}, equal_to={equal_to!r}"
            )
        assert isinstance(
            name, str
        ), f"Unexpected type {type(name)} for name={name!r}{self.get_debug_msg()}"
        assert not isinstance(value, tuple) or all(isinstance(d, (str, int)) for d in value), (
            f"Unexpected value for shape {name!r}, value={value!r}, "
            f"types={string_type(value)}{self.get_debug_msg()}"
        )
        if not self.has_rank(name):
            self.set_shape(name, (len(value),) if isinstance(value, tuple) else tuple())
        assert self.has_rank(name), (
            f"name={name!r} has no rank but it should, value={value!r}"
            f"{self.get_debug_msg()}"
        )
        assert self.get_rank(name) in (0, 1), (
            f"name={name!r} is not a shape, its rank is {self.get_rank(name)}"
            f"{self.get_debug_msg()}"
        )
        assert not isinstance(value, (int, float)) or self.get_rank(name) == 0, (
            f"Mismatch between value={value!r} and rank="
            f"{self.get_rank(name)} for name={name!r}"
            f"{self.get_debug_msg()}"
        )
        if equal_to is None:
            if name in self._known_value_shape:
                existing = self._known_value_shape[name]
                if (
                    isinstance(existing, tuple)
                    and isinstance(value, tuple)
                    and len(existing) == len(value) == 1
                    and isinstance(existing[0], str)
                ):
                    # The existing shape is a single symbolic dimension:
                    # constrain it to the new value.
                    self.register_constraint_dimension(existing[0], value[0])
                    return
            assert (
                name not in self._known_value_shape or self._known_value_shape[name] == value
            ), (
                f"Shape value for {name!r} (value={value!r}) is already "
                f"registered and is different from the new one "
                f"(equal_to={equal_to!r}), existing value is "
                f"{self._known_value_shape.get(name, None)!r}"
                f"{self.get_debug_msg()}"
            )
            if self.verbose > 2:
                print(f"[GraphBuilder-{self._hash()}.set_value_shape] {name}:{value}")
            self._known_value_shape[name] = value
            return
        assert (
            name in equal_to
        ), f"Unexpected name={name!r}, it should be in equal_to={equal_to!r}."
        values = (
            self._known_value_shape.get(equal_to[0], None),
            self._known_value_shape.get(equal_to[1], None),
        )
        assert value in values, (
            f"Unexpected value={value} for name={name!r}, equal_to={equal_to}, "
            f"values={values}{self.get_debug_msg()}"
        )
        assert equal_to[0] in self._known_value_shape, (
            f"{equal_to[0]!r} should already be registered, name={name!r}, "
            f"value={value!r}, equal_to={equal_to!r}{self.get_debug_msg()}"
        )
        # The logic is to get rid of one value instead of keeping
        # a mapping between equivalent values.
        new_value = self._known_value_shape[equal_to[0]]
        for n in equal_to:
            if n not in self._known_value_shape:
                self._known_value_shape[n] = new_value
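
    # Example (illustrative sketch): recording that result "shape_a" holds
    # the shape ("batch", 8). The name is hypothetical.
    #
    #   builder.set_value_shape("shape_a", ("batch", 8))
    #   builder.value_as_shape("shape_a")  # ("batch", 8)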

    def has_type(self, name: str) -> Union[bool, int]:
        """Tells if a result has a type. This should always be true."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        if name not in self._known_types:
            return False
        # A stored type of 0 (undefined) is falsy: the result has no type yet.
        return self._known_types[name]

    def get_type(self, name: str) -> int:
        """Returns the type of a result."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        assert name in self._known_types, (
            f"Type is unknown for result {name!r}, "
            f"known_types={self._known_types}{self.get_debug_msg()}."
        )
        return self._known_types[name]

    def set_type(self, name: str, dtype: int, exc: bool = True) -> bool:
        """
        Sets the type for a result. If it exists, it checks the new type
        is equal to the existing one.

        :param name: name
        :param dtype: element type (an ONNX integer), 0 (unknown) is a possible value
        :param exc: raises an exception in case of conflict
        :return: True if there is no type conflict
        """
        assert (
            not name or name != self._debug_stop_type
        ), f"Requested stop, name={name!r}, dtype={dtype}{self.get_debug_msg()}"
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        assert isinstance(dtype, int), f"Unexpected type {type(dtype)} for dtype."
        int_type = dtype
        if name in self._known_types:
            # 0 means undefined.
            if self._known_types[name] != 0 and int_type != self._known_types[name]:
                if exc:
                    raise RuntimeError(
                        f"Type for name {name!r} already exists and it is different, "
                        f"known is {self._known_types[name]} != {int_type} (new) - "
                        f"(mapping={str_tensor_proto_type()}){self.get_debug_msg()}"
                    )
                if "warnings" not in self._debug_msg:
                    self._debug_msg["warnings"] = []
                self._debug_msg["warnings"].append(
                    f"Type for name {name!r} already exists and it is different, "
                    f"known is {self._known_types[name]} != {int_type} (new)"
                )
                if self.verbose:
                    print(
                        f"Type for name {name!r} already exists and it is different, "
                        f"known is {self._known_types[name]} != {int_type} (new)"
                    )
                return False
        if self.verbose > 5:
            print(f"[ShapeBuilder-{self._hash()}.set_type] {name}:{int_type}")
        self._known_types[name] = int_type
        return True
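
    # Example (illustrative sketch): the first two calls agree and return
    # True, the conflicting one returns False with exc=False (and would
    # raise with the default exc=True).
    #
    #   builder.set_type("x", onnx.TensorProto.FLOAT)              # True
    #   builder.set_type("x", onnx.TensorProto.FLOAT)              # True
    #   builder.set_type("x", onnx.TensorProto.INT64, exc=False)   # False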

    def has_rank(self, name: str) -> bool:
        """Tells if a result has a rank."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        return name in self._known_ranks

    def get_rank(self, name: str) -> int:
        """Returns the rank of a result."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        assert name in self._known_ranks, (
            f"rank is unknown for result {name!r}, has_shape={self.has_shape(name)}, "
            f"has_rank={self.has_rank(name)}, "
            f"known_ranks={self._known_ranks}{self.get_debug_msg()}"
        )
        return self._known_ranks[name]

    def set_rank(self, name: str, value: int) -> bool:
        """
        Sets the rank for a result.

        :param name: result name
        :param value: rank
        :return: True if there is no rank conflict
        """
        assert (
            not self._debug_stop_shape or name != self._debug_stop_shape
        ), f"Requested stop, name={name!r}, rank={value}"
        assert isinstance(value, int) and not isinstance(
            value, bool
        ), f"Unexpected rank type {type(value)} for {name!r}"
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        if name in self._known_ranks:
            assert value == self._known_ranks[name], (
                f"Inconsistent ranks for {name!r}, previous value is "
                f"{self._known_ranks[name]}, new value is {value}{self.get_debug_msg()}"
            )
            if self.verbose > 5:
                print(f"[ShapeBuilder-{self._hash()}.set_rank] (again) {name}:{value}")
            return True
        self._known_ranks[name] = value
        if self.verbose > 5:
            print(f"[ShapeBuilder-{self._hash()}.set_rank] {name}:{value}")
        return True

    def has_shape(self, name: str, full=False) -> bool:
        """
        Tells if a result has a shape.
        If *full* is True, it returns True only if the shape exists and is
        a static shape with all dimensions >= 0.
        """
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        if name not in self._known_shapes:
            return False
        if full:
            shape = self._known_shapes[name]
            # ``all`` also handles the empty shape of a scalar.
            return is_static_shape(shape) and all(d >= 0 for d in shape)
        return True

    def get_shape(self, name: str) -> DYNAMIC_SHAPE:
        """Returns the shape of a result."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        assert name in self._known_shapes, (
            f"Shape is unknown for result {name!r}, "
            f"known_shapes={self._known_shapes}{self.get_debug_msg()}"
        )
        return self._known_shapes[name]

    def register_dynamic_objects_from_dim(self, dim: str):
        """Registers all the dynamic objects required in a dimension."""
        assert isinstance(dim, str) and " " not in dim and dim.count("(") == dim.count(")"), (
            f"dim must be a str without any space and with balanced "
            f"parentheses, type(dim)={type(dim)}, dim={dim!r}"
            f"{self.get_debug_msg()}"
        )
        for token in parse_expression_tokens(dim):
            if token not in self.dynamic_dimensions_:
                self.add_dynamic_dimension(token)
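
    # Example (illustrative sketch): a composite dimension registers every
    # symbol it mentions, assuming ``parse_expression_tokens`` yields
    # "batch" and "seq" for the expression below.
    #
    #   builder.register_dynamic_objects_from_dim("batch*seq")
    #   "batch" in builder.dynamic_dimensions_  # True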

    def add_dynamic_dimension(self, name: str):
        """Adds a dynamic dimension."""
        assert (
            name not in self.dynamic_dimensions_
        ), f"Dynamic dimension {name!r} is already registered{self.get_debug_msg()}"
        self.dynamic_dimensions_[name] = {name}

    def set_shape(self, name: str, shape: DYNAMIC_SHAPE, exc: bool = False, **_kwargs):
        """
        Sets the shape for a result. If it exists, it checks the new shape
        is equal to the existing one.

        :param name: result name
        :param shape: shape
        :param exc: raises an exception in case of inconsistency
        """
        assert isinstance(name, str), f"Unexpected type {type(name)} for name."
        assert isinstance(shape, tuple), f"Unexpected shape type {type(shape)}"
        assert (
            not name or name != self._debug_stop_shape
        ), f"Requested stop, name={name!r}, shape={shape}{self.get_debug_msg()}"
        assert not shape or not isinstance(shape[0], tuple), f"Unexpected shape {shape}"
        shape = tuple(simplify_expression(s) for s in shape)
        for sdim in shape:
            if not isinstance(sdim, str):
                continue
            self.register_dynamic_objects_from_dim(sdim)
        if name in self._known_shapes:
            old_shape = self._known_shapes[name]
            if self._debug_dyn_dim and self._debug_dyn_dim & (set(shape) | set(old_shape)):
                print(
                    f"[ShapeBuilder.set_shape] set_shape({name!r}, {shape}), "
                    f"old_shape={old_shape}"
                )
            if shape != old_shape:
                if exc:
                    raise RuntimeError(
                        f"Name {name!r} already exists and its shape is different, "
                        f"{old_shape} (old) != {shape} (new){self.get_debug_msg()}"
                    )
                return False
            return True
        if self._debug_dyn_dim and set(shape) & self._debug_dyn_dim:
            print(f"[ShapeBuilder.set_shape] set_shape({name!r}, {shape})")
        if self.verbose > 5:
            print(f"[ShapeBuilder-{self._hash()}.set_shape] {name}:{shape}")
        self._known_shapes[name] = shape
        if not self.has_rank(name):
            self.set_rank(name, len(shape))
        return True
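
    # Example (illustrative sketch): setting a mixed static/dynamic shape
    # also sets the rank and registers the dynamic dimension "batch".
    #
    #   builder.set_shape("y", ("batch", 3, 224, 224))
    #   builder.get_rank("y")              # 4
    #   builder.has_shape("y", full=True)  # False, "batch" is not static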

    def value_as_shape(self, name: str) -> Optional[Union[str, int, Tuple[Union[int, str], ...]]]:
        """Returns the value of a result if it is a shape, None otherwise."""
        if name in self._known_value_shape:
            return self._known_value_shape[name]
        if not self.has_type(name) or self.get_type(name) != onnx.TensorProto.INT64:
            return None
        if self.is_constant(name):
            # It is probably a shape because the user requested it as a shape.
            cst = self.get_constant(name, exc=False, computed_value=True)
            if cst is not None and len(cst.shape) == 1 and cst.dtype == np.int64:
                value = tuple(map(int, cst))
                self._known_value_shape[name] = value
                return value
        return None
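
    # Example (illustrative sketch): an INT64 constant of rank 1 is
    # interpreted as a shape and cached. The name "dims" is hypothetical.
    #
    #   t = onh.from_array(np.array([4, 5], dtype=np.int64), name="dims")
    #   builder.set_constant("dims", t)
    #   builder.value_as_shape("dims")  # (4, 5)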

    def get_debug_msg(self, limit: int = 1000) -> str:
        """
        Returns a string providing as much information as possible
        to help the developer understand why a conversion failed.

        :param limit: truncates the string if the model is big
        :return: many pieces of information about the ongoing conversion
        """

        def assert_sorted(inputs):
            try:
                return sorted(inputs)
            except TypeError:
                return list(inputs)

        def _align(s, length):
            if len(s) < length:
                s += " " * (length - len(s))
            return s

        def _dtype(t):
            if hasattr(t, "dtype"):
                return t.dtype
            if hasattr(t, "data_type"):
                return t.data_type
            raise RuntimeError(f"dtype unknown for type {type(t)}-{t}.")

        def _shape(t):
            if hasattr(t, "shape"):
                return t.shape
            if hasattr(t, "dims"):
                return tuple(t.dims)
            raise RuntimeError(f"shape unknown for type {type(t)}-{t}.")

        def _size(t):
            if hasattr(t, "numel"):
                return t.numel()
            if hasattr(t, "size"):
                return t.size
            if hasattr(t, "dims"):
                return np.prod(tuple(t.dims))
            raise RuntimeError(f"Size unknown for type {type(t)}-{t}.")

        def _values(t):
            if hasattr(t, "detach"):

                def is_fake_mode_enabled():
                    from torch._guards import detect_fake_mode

                    return detect_fake_mode(t)

                if is_fake_mode_enabled():
                    return "FakeTensorMode enabled"
                return t.detach().cpu().flatten().tolist()
            if hasattr(t, "size"):
                return t.ravel().tolist()
            if hasattr(t, "dims"):
                a = onh.to_array(t)
                return a.ravel().tolist()
            raise RuntimeError(f"Values unknown for type {type(t)}-{t}.")

        rows = ["", "--DEBUG--"]
        hs = self._hash()
        rows.append(f"[ShapeBuilder-{hs}] Message starts")
        # if self._implicit_decisions:
        #     rows.append("--IMPLICIT DECISIONS--")
        #     rows.extend(map(str, self._implicit_decisions))
        if self.constraints_:
            rows.append("--CONSTRAINTS--")
            for a, b in assert_sorted(self.constraints_.items()):
                rows.append(f"   {a} = {b}")
        else:
            rows.append("--NOCONSTRAINTS--")
        rows.append("--SHAPE--")
        rows.append(f"_known_shapes={pprint.pformat(self._known_shapes)[:10000]}")
        rows.append(f"_known_types={pprint.pformat(self._known_types)[:10000]}")
        short_sh = {
            k: (v if (isinstance(v, tuple) and len(v) < 10) else string_type(v, with_shape=True))
            for k, v in self._known_value_shape.items()
        }
        rows.append(f"_known_value_shape={pprint.pformat(short_sh)[:10000]}")
        rows.append(
            f"_known_constants={pprint.pformat(list(assert_sorted(self.constants_))[:10000])}"
        )
        remaining_ranks = {
            k: v for k, v in self._known_ranks.items() if k not in self._known_shapes
        }
        rows.append(f"_known_ranks (with no shape)={pprint.pformat(remaining_ranks)[:10000]}")
        if self._calls:
            rows.append("--CALLS--")
            rows.extend(str(s) for s in self._calls)
        else:
            rows.append("--NOCALLS--")
        return "\n".join(rows)

    def run_node(self, node: onnx.NodeProto, exc: bool = False):
        """
        Uses the shapes available in the ShapeBuilder to infer the output
        shapes and types.
        """
        if node.op_type == "Constant" and node.domain == "":
            self.set_constant(node.output[0], node)
            self.simple_update_value_shape_with_node(node)
            if self.verbose:
                print(
                    f"[BasicShapeBuilder.run_node] {self.pretty_node(node)} - "
                    f"{self.get_type(node.output[0])}:{self.get_shape(node.output[0])}"
                )
        else:
            r = self._make_node_set_type_shape(node, exc=exc)
            self.simple_update_value_shape_with_node(node)
            if all(self.is_constant(i) for i in node.input):
                for o in node.output:
                    if not self.is_constant(o):
                        self.set_constant(o, node)
            if self.verbose:
                print(f"[BasicShapeBuilder.run_node] {self.pretty_node(node)}: {r}")

    def run_value_info(self, info: onnx.ValueInfoProto, is_input: bool):
        """Fills the ShapeBuilder with information coming from an input or output."""
        assert info.type.tensor_type, f"info is not a tensor type: {info}"
        if is_input:
            self._input_names.append(info.name)
        else:
            self._output_names.append(info.name)
        self.set_type(info.name, info.type.tensor_type.elem_type)
        shape = info.type.tensor_type.shape
        value = tuple(d.dim_param or d.dim_value for d in shape.dim)
        self.set_shape(info.name, value)

    def run_model(
        self,
        model: Union[onnx.ModelProto, onnx.GraphProto],
        functions: Optional[Dict[Tuple[str, str], onnx.FunctionProto]] = None,
        exc: bool = False,
    ):
        """Runs inference over a model or a graph."""
        self.time_evaluation_constants_ = 0
        if isinstance(model, onnx.ModelProto):
            # Picks the main opset before recursing so the recursive call
            # does not overwrite it.
            self.main_opset = 18
            for opset in model.opset_import:
                if opset.domain == "":
                    self.main_opset = opset.version
                    break
            return self.run_model(
                model.graph,
                functions={(f.domain, f.name): f for f in model.functions},
                exc=exc,
            )
        assert isinstance(model, onnx.GraphProto), f"Unexpected type {type(model)} for model"
        graph = model
        for i in graph.initializer:
            self.set_constant(i.name, i)
        for i in graph.sparse_initializer:
            self.set_constant(i.name, i)
        for i in graph.input:
            self.run_value_info(i, True)
        for node in graph.node:
            self.run_node(node, exc=exc)
        for i in graph.output:
            self.run_value_info(i, False)
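
    # End-to-end example (illustrative sketch): infer shapes for a tiny
    # model with one dynamic dimension, assuming the inference mixins
    # support MatMul.
    #
    #   import onnx.helper as oh
    #   model = oh.make_model(
    #       oh.make_graph(
    #           [oh.make_node("MatMul", ["X", "W"], ["Y"])],
    #           "g",
    #           [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["batch", 4])],
    #           [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, None)],
    #           [onh.from_array(np.zeros((4, 8), dtype=np.float32), name="W")],
    #       ),
    #       opset_imports=[oh.make_opsetid("", 18)],
    #   )
    #   builder = BasicShapeBuilder()
    #   builder.run_model(model)
    #   builder.get_shape("Y")  # expected ("batch", 8)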

    def register_constraint_dimension(self, dim_name: str, value: Any):
        """
        Registers a constraint on a dimension.

        :param dim_name: dimension name
        :param value: value to register
        """
        if self._debug_dyn_dim and dim_name in self._debug_dyn_dim:
            print(
                f"[GraphBuilder.register_constraint_dimension] "
                f"dim_name={dim_name!r}, value={value!r}"
            )
        if dim_name not in self.constraints_:
            self.constraints_[dim_name] = set()
        if isinstance(value, set):
            self.constraints_[dim_name] |= value
        else:
            self.constraints_[dim_name].add(value)
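
    # Example (illustrative sketch): constraints accumulate as a set per
    # dimension; a set argument is merged, anything else is added.
    #
    #   builder.register_constraint_dimension("batch", "seq")
    #   builder.register_constraint_dimension("batch", {"s1", "s2"})
    #   builder.constraints_["batch"]  # {"seq", "s1", "s2"}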