import contextlib
import pprint
import time
import os
import sys
from collections import Counter
from typing import (
Any,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import numpy as np
from onnx import (
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
NodeProto,
TensorProto,
TypeProto,
ValueInfoProto,
)
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx.external_data_helper import uses_external_data
from onnx.model_container import make_large_tensor_proto
from onnx.shape_inference import infer_shapes as onnx_infer_shapes
from ..helpers import make_hash, string_sig, pretty_onnx, string_signature
from ..reference import ExtendedReferenceEvaluator
from ._shape_helper import (
DYNAMIC_SHAPE,
STATIC_SHAPE,
_reshape_shape,
all_int,
all_int_or_str,
is_static_dimension,
is_static_shape,
)
from .shape_type_compute import set_shape_type_op_any, set_shape_type_custom
from ._onnx_helper import (
_default_OPSET_TO_IR_VERSION,
_nice_shape,
choose_consistent_domain_opset,
compatible_opsets,
element_wise_binary_op_types,
element_wise_op_cmp_types,
same_function_proto,
unary_like_op_types,
)
from .model_container import TorchModelContainer, proto_from_array, _get_type
from ._dtype_helper import (
dtype_to_tensor_dtype,
onnx_dtype_to_torch_dtype,
torch_dtype_to_onnx_dtype,
)
from .optimization_options import OptimizationOptions
from .expression_dimension import Expression, parse_expression, parse_expression_tokens
from .graph_builder_opset import Opset
from ._graph_builder_runtime import _GraphBuilderRuntime
@contextlib.contextmanager
def _unset_fake_temporarily() -> Generator:
import torch
old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
try:
yield old
finally:
if old is not None:
torch._C._set_dispatch_mode(old)
[docs]
class FunctionOptions:
"""
Defines how local functions must behave.
:param name: function name
:param domain: function domain
:param export_as_function: export the onnx as functions or keep local function
:param external_threshold: whether or not keep initializer as input for the function
or move them as constant of the function
:param move_initializer_to_constant: move initializers as constant first before
creating the function proto, that depends on the size defined by
external_threshold
:param return_initializer: return the remaining initializer and add them as input
to the function
:param inline: inline functions
:param rename_allowed: allow to rename the function if a duplicate is detected
:param merge_allowed: allow to merge a function in case the same code is detected
"""
empty_names = (None, "", "*")
def __init__(
self,
export_as_function: bool = False,
name: str = "",
domain: str = "",
external_threshold: int = 2**25,
move_initializer_to_constant: bool = False,
return_initializer: bool = False,
inline: bool = False,
merge_allowed: bool = False,
rename_allowed: bool = False,
):
if name:
export_as_function = True
assert not export_as_function or name, (
f"to be removed help track bugs, name={name!r}, domain={domain!r}, "
f"export_as_function={export_as_function!r}"
)
assert export_as_function or (not name and not domain), (
f"to be removed help track bugs, name={name!r}, domain={domain!r}, "
f"export_as_function={export_as_function!r}"
)
assert isinstance(
external_threshold, int
), f"Unexpected type {type(external_threshold)} for external_threshold"
self.export_as_function = export_as_function
self.external_threshold = external_threshold
self.move_initializer_to_constant = move_initializer_to_constant
self.name = name
self.domain = domain
self.return_initializer = return_initializer
self.inline = inline
self.rename_allowed = rename_allowed
self.merge_allowed = merge_allowed
def __repr__(self) -> str:
return string_sig(self)
[docs]
class GraphBuilder(_GraphBuilderRuntime):
"""
Simplifies the creation of a model.
:param target_opset_or_existing_proto: a ModelProto, an integer,
a dictionary of domain, version
:param input_names: input names
:param as_function: export as a function or a model
there are less assert when as_function is True
:param optimization_options: optimizations options,
see :class:`OptimizationOptions`
:param args: example of inputs
:param ir_version: ir version when exporting
:param verbose: verbosity
:param infer_shapes: run shape inference, if the value is `'new'`,
existing shapes are ignored
:param raise_list: raise an exception if a new operator belongs to that list
:param dynamic_shapes: dynamic shapes
:param local_domain: domain name to use for local functions if not specified
:param signature: the signature is unused but helps for debugging purposes
Important attributes:
- `input_names: List[str]`: list of input names
- `as_function: bool`: the model must be exported as a function or as a model,
there are less assert when as_function is True
- `optimization_options: OptimizationOptions`:
- `nodes: List[NodeProto]`: list of nodes
- `initializers_dict: Dict[str, Any]`: initializers
- `inputs: List[ValueInfoTensorProto]`: inputs
- `outputs: List[ValueInfoTensorProto]`: outputs
- `ir_version: int`: ir version
- `opsets: Dict[str, int]`: declared opsets
- `input_args: List[T]`: input tensors when
the class is used to convert an existing model
- `functions: Dict[Tuple[str,str], FunctionProto]`:
dictionary of functions to add to the model
- `value_info: List[ValueInfoProto]`: value info of the original model
- `dynamic_shapes: Union[Dict[str, Any], Tuple[Any]]]`: dynamic_shapes informations
- `_parameter_renaming: Dict[str, str]`: to rename parameter and give them
a name which can be found in ``module.named_parameter``
Computed attributes:
- `_unique_names`: used to create unused result names
- `_unique_node_names`: used to create unused node names
- `_known_names`: set of existing results names
- `_known_shapes: Dict[str, DYNAMIC_SHAPE]`: declared shapes
- `_known_types: Dict[str, int]`: declared element types
- `_known_value_shape: Dict[str, Any]`: if a result is a shape or not
(for example the output of operator Shape)
- `_known_ranks: Dict[str, int]`: declared ranks
- `_known_sequences: Dict[str, Dict[str, Any]]`: known sequences
- `_dynamic_examples: Dict[str, Set[Union[int,float]]]`: example of dynamic dimensions
- `constants_node_: Dict[bytes, NodeProto]`: constant node
- `constants_alias_: Dict[str, str]`: alias for constant
- `constants_: Dict[str, Any]`: constant values
- `constants_computed_: Dict[str, Any]`: computed constant values
- `dynamic_objects: Dict[str, torch.SymInt]`: list of dynamic dimension
- `dynamic_objects_rev: Dict[str, str]`: reverse dictionary to fasten lookups
- `_cache_shape: Dict[key,str]`: cache concatenation of shapes
- `_values: Dict[key,str]`: cache initializer value to merge those which are equal
- `_dynamic_alias: Dict[str,str]`: used when the user gives a different
name to the dynamic shapes
- `constraints_: Dict[str, Set[Any]]`:
if a broadcast implies a constraints on a dynamic shape,
it is stored here
- `_events`: is used ot retrieve any information useful to debug
Debugging attributes:
- `_raise_list: Set[str]`: the builder stop if a result falls in that list
(debugging tool)
You can setup environment variable ``ONNXSTOP``, ``ONNXSTOPSHAPE``, ``ONNXSTOPTYPE``
to raise an exception when the type or shape
of a variable is set. Example: ``ONNXSTOP=attn_output python ...``.
``ONNXCST=1`` shows which constant is computed,
``NULLSHAPE=1`` raises an exception as soon as a null shape occurs.
"""
[docs]
class ShapeConstant:
"""
Wraps a constant shape even if the input producing the shape is not.
"""
def __init__(self, name: str, shape: Tuple[int, ...], node: NodeProto):
self.name = name
self.shape = shape
self.node = node
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.name!r}, shape={self.shape!r}, "
f"node=<{self.node.op_type})"
)
[docs]
class WrapSym:
"""
Wraps a symbolic int (a dimension for example).
"""
def __init__(self, sym: Union["torch.SymInt", "torch.SymFloat"]): # noqa: F821
self.sym = sym
def __repr__(self) -> str:
return f"WrapSym({self._dynamic_to_str(self.sym)})"
def _dynamic_to_str(self, obj: Any) -> Optional[str]:
if isinstance(obj, str):
return obj
import torch
if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
return obj.__name__
if isinstance(obj, torch.export.dynamic_shapes._Dim):
return obj.__name__
if isinstance(obj, torch.SymInt):
if isinstance(obj.node, str):
return obj.node
i = obj.node._expr
if "sympy" in str(type(i)):
return str(i)
return None
raise AssertionError(f"Unexpected type {type(obj)} to convert into string")
# Size of a tensor kept in the onnx file and not stored as exrternal weight.
SMALL_TENSOR = 1024
_op_type_element_wise_types = element_wise_binary_op_types()
_op_type_element_wise_cmp_types = element_wise_op_cmp_types()
_op_type_unary_like = unary_like_op_types()
def __init__(
self,
target_opset_or_existing_proto: Union[int, Dict[str, int], ModelProto, FunctionProto],
input_names: Optional[Sequence[str]] = None,
as_function: bool = False,
optimization_options: Optional[OptimizationOptions] = None,
args: Optional[List[Any]] = None,
ir_version: Optional[int] = None,
verbose: int = 0,
infer_shapes: Union[bool, str] = False,
raise_list: Optional[Set[str]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
local_domain: str = "local_function",
signature: Optional[Any] = None,
):
import torch
self.torch = torch
self.maybe_disable_fake_tensor_mode = _unset_fake_temporarily
self.optimization_options = optimization_options or OptimizationOptions(
verbose=verbose
)
self.local_domain = local_domain
self.as_function = as_function
self.input_args = args
self.verbose = verbose
self.ir_version = ir_version
self._debug_msg = {}
self.dynamic_dimensions_source = {}
self.dynamic_shapes = dynamic_shapes
self.dynamic_objects = {}
self.dynamic_objects_rev = {}
self.functions = {}
self.functions_builder = {}
self.value_info = []
self.raise_list = raise_list
self._raise_list = raise_list or set()
self.constants_computed_ = {}
self._cache_shape = {}
self._values = {}
self._dynamic_alias = {}
self.constants_node_ = {}
self.constants_alias_ = {}
self._events = {}
self.signature = signature
self.nodes = []
self.initializers_dict = {}
self.inputs = []
self.outputs = []
self._known_shapes = {}
self._known_types = {}
self._known_ranks = {}
self._known_torch_value = {}
self._known_names = set()
self._known_sequences = {}
self._unique_names = set()
self._unique_node_names = set()
self._known_value_shape = {}
self._dynamic_examples = {}
self._parameter_renaming = {}
self._parameter_norename = set()
self.constants_ = {}
self.op = Opset(self)
self.anyop = Opset(self, allow_unknown=True)
self._debug_null_shape = int(os.environ.get("NULLSHAPE", "0"))
self._debug_stop = os.environ.get("ONNXSTOP", "#?#")
self._debug_stop_shape = os.environ.get("ONNXSTOPSHAPE", "#?#")
self._debug_stop_type = os.environ.get("ONNXSTOPTYPE", "#?#")
self._debug_get_constant = int(os.environ.get("ONNXCST", "0"))
self.time_evaluation_constants_ = 0
self.statistics_ = {}
self.constraints_ = {}
self._registered_users = {}
self.was_inputs_renamed = input_names is not None and input_names
self.update_dynamic_shape_when_input_name_is_defined = False
assert dynamic_shapes is None or isinstance(dynamic_shapes, (dict, tuple)), (
f"dynamic_shapes is expected to be empty or a dictionary or a tuple "
f"not {type(dynamic_shapes)}, dynamic_shapes={dynamic_shapes}"
)
if self.dynamic_shapes:
self._register_dynamic_object_from_dynamic_shapes()
if isinstance(target_opset_or_existing_proto, (int, dict)):
# starts a model from nothing
assert not infer_shapes, "infer_shapes is used if an existing model is loaded"
self.opsets = (
{"": target_opset_or_existing_proto}
if isinstance(target_opset_or_existing_proto, int)
else target_opset_or_existing_proto
)
self.input_names = input_names or []
self.current_input = 0
self._unique_names = set(self.input_names)
elif isinstance(target_opset_or_existing_proto, ModelProto):
# loads a model from nothing
if input_names:
raise ValueError(
"input_names must be empty if the input is an existing model."
)
self.current_input = len(self.inputs)
self._update_structures_with_proto(target_opset_or_existing_proto, infer_shapes)
self.constant_folding(convert_into_initializer=False)
self._update_shape_types_with_proto(target_opset_or_existing_proto, infer_shapes)
else:
raise NotImplementedError(
f"{type(target_opset_or_existing_proto)} is not supported."
)
def _register_dynamic_object_from_dynamic_shapes_dict(self, pos, pos_vv, vv):
# example:
# args_0 {0: <class '._bash_bench_model_runner.batch'>}
for _k, _v in vv.items():
if isinstance(_v, self.torch.SymInt):
self.make_dynamic_object(
_v.__name__,
_v,
axis=_k,
input_name=pos,
)
elif isinstance(_v, self.torch.export.dynamic_shapes._DerivedDim):
self.make_dynamic_object(
_v.__name__,
self.torch.SymInt(_v.__name__),
axis=_k,
input_name=pos,
)
# It should be recursive.
self.make_dynamic_object(
_v.root.__name__,
self.torch.SymInt(_v.root.__name__),
axis=None,
input_name=None,
)
elif isinstance(_v, self.torch.export.dynamic_shapes._Dim):
self.make_dynamic_object(
_v.__name__,
self.torch.SymInt(_v.__name__),
axis=_k,
input_name=pos,
)
else:
raise AssertionError(
f"Unexpected type {type(_v)} in {vv} for dynamic "
f"dimension {pos!r}, pos_vv={pos_vv!r}, "
f"self.dynamic_shapes={self.dynamic_shapes}"
)
def _register_dynamic_object_from_dynamic_shapes(self):
assert (
self.dynamic_shapes is not None
), "Call this method if self.dynamic_shapes is not None"
self.update_dynamic_shape_when_input_name_is_defined = not isinstance(
self.dynamic_shapes, dict
)
seq_dynamic_shapes = (
list(self.dynamic_shapes.items())
if isinstance(self.dynamic_shapes, dict)
else list(enumerate(self.dynamic_shapes))
)
for input_name_or_position, v in seq_dynamic_shapes:
if isinstance(v, dict):
pos_vv = list(v.items())
elif isinstance(v, (list, tuple)):
if isinstance(input_name_or_position, str):
pos_vv = [(f"{input_name_or_position}_{i}", v[i]) for i in range(len(v))]
else:
pos_vv = [((input_name_or_position, i), v[i]) for i in range(len(v))]
else:
raise AssertionError(
f"Unexpected value for input_name={input_name_or_position!r} and "
f"v={v}, dynamic_shapes={self.dynamic_shapes}"
)
for pos, vv in pos_vv:
if vv is None:
continue
if isinstance(vv, list):
for vvv in vv:
if vvv is None:
continue
assert isinstance(
vvv, dict
), f"Unexpected type {type(vvv)} at pos={pos} and {vv}"
self._register_dynamic_object_from_dynamic_shapes_dict(
pos, pos_vv, vvv
)
elif isinstance(vv, dict):
self._register_dynamic_object_from_dynamic_shapes_dict(pos, pos_vv, vv)
elif isinstance(vv, self.torch.SymInt):
self.make_dynamic_object(
vv.__name__, vv, axis=pos, input_name=input_name_or_position
)
elif isinstance(vv, self.torch.export.dynamic_shapes._DerivedDim):
# Used to specify a dimension as a multiple of something
# We register the root.
self.make_dynamic_object(
vv.__name__,
self.torch.SymInt(vv.__name__),
axis=pos,
input_name=input_name_or_position,
)
# It should be recursive.
self.make_dynamic_object(
vv.root.__name__,
self.torch.SymInt(vv.root.__name__),
axis=None,
input_name=None,
)
elif isinstance(vv, self.torch.export.dynamic_shapes._Dim):
self.make_dynamic_object(
vv.__name__,
self.torch.SymInt(vv.__name__),
axis=pos,
input_name=input_name_or_position,
)
else:
raise AssertionError(
f"Unexpected type {type(vv)}, vv={vv} for dynamic "
f"dimension {pos!r}, pos_vv={pos_vv!r}, "
f"self.dynamic_shapes={self.dynamic_shapes}"
)
[docs]
def add_stat(self, kind: str, name: str):
"""
Increments a counter.
"""
if kind not in self.statistics_:
self.statistics_[kind] = {}
stat = self.statistics_[kind]
if name not in stat:
stat[name] = 1
else:
stat[name] += 1
@property
def output_names(self) -> List[str]:
return [o.name for o in self.outputs]
[docs]
def empty_copy(
self, as_function: bool = False, constant_size: int = 2**24
) -> "GraphBuilder":
"""
Creates an empty copy but with the same opsets.
"""
assert isinstance(as_function, bool), f"wrong type {type(as_function)} for as_function"
opt = OptimizationOptions(
constant_size=constant_size,
constant_fusing=False,
remove_identity=False,
patterns=None,
)
g = GraphBuilder(
self.opsets.copy(),
verbose=self.verbose,
ir_version=self.ir_version,
as_function=as_function,
optimization_options=opt,
)
return g
[docs]
def make_key(self, value: Any) -> Optional[Tuple[Union[str, int], ...]]:
"""
Builds a key identifying a value.
Returns None if it is not possible.
"""
if isinstance(value, TensorProto):
if uses_external_data(value):
key = None
for info in value.external_data:
if info.key == "location":
key = info.value
break
assert key is not None, (
f"External tensor {value.name!r} has no location, "
f"external_data is {value.external_data}"
)
return None
return self.make_key(onh.to_array(value))
if isinstance(value, self.torch.Tensor):
if value.dtype == self.torch.int64 and value.numel() < 8:
return self.make_key(value.detach().cpu().numpy())
return None
if isinstance(value, int):
return int, value
if isinstance(value, np.ndarray):
if value.size < self.SMALL_TENSOR:
return (value.dtype, value.shape, tuple(value.ravel().tolist()))
return None
def pretty_tensor(self, tensor: Any) -> str:
if hasattr(tensor, "shape"):
return f"{tensor.dtype}:{tuple(tensor.shape)}"
return f"no pretty: {type(tensor)}"
def pretty_node(self, node: NodeProto, limit: int = 80):
text = (
(
f"{node.op_type}[{node.domain}]: "
f"{', '.join(node.input)} -> {', '.join(node.output)}"
)
if node.domain
else f"{node.op_type}: {', '.join(node.input)} -> {', '.join(node.output)}"
)
add = " " * abs(80 - len(text))
text += add
info = []
for o in node.output:
t = f"T{self.get_type(o)}" if self.has_type(o) else ""
s = " x ".join(map(str, self.get_shape(o))) if self.has_shape(o) else ""
info.append(": ".join([t, s]))
if node.name:
return f"{text}|{' '.join(info)} - {node.name}"
return f"{text}|{' '.join(info)}"
def pretty_text(self, add_fx_graph: bool = False, recursive: bool = True) -> str:
def _d(d1):
if isinstance(d1, self.torch.SymInt):
return f"SymInt({self._torch_sym_int_to_str(d1)})"
if isinstance(d1, self.WrapSym):
return repr(d1)
if isinstance(d1, self.torch.export.dynamic_shapes._DerivedDim):
return f"_DerivedDim({d1.__name__})"
if isinstance(d1, self.torch.export.dynamic_shapes._Dim):
return f"_Dim({d1.__name__})"
if isinstance(d1, list):
s = ", ".join(map(_d, d1))
return f"[{s}]"
if isinstance(d1, tuple):
s = ", ".join(map(_d, d1))
return f"({s})"
if isinstance(d1, dict):
s = ", ".join(f"{k}:{_d(d)}" for d in d1.items())
return f"{{{s}}}"
if isinstance(d1, (str, int)):
return f"{d1!r}"
raise AssertionError(f"Unexpected type for {type(d1)}")
def _v(v):
if hasattr(v, "shape"):
shape = " x ".join(map(str, v.shape))
dtype = str(v.dtype)
else:
shape = "?"
dtype = "?"
return f"{dtype}: {shape}"
def _io(o, prefix):
text = f"{prefix}: {o}"
add = " " * abs(80 - len(text))
text += add
t = f"T{self.get_type(o)}" if self.has_type(o) else ""
s = " x ".join(map(str, self.get_shape(o))) if self.has_shape(o) else ""
return f"{text}|{': '.join([t,s])}"
rows = [""]
# signature
if self.signature:
rows.append(string_signature(self.signature))
# dynamic shapes
for k, v in sorted(self.dynamic_objects.items()):
rows.append(f"dyn---: {k} -> {_d(v)}")
for k, v in sorted(self.dynamic_objects_rev.items()):
rows.append(f"dynrev: {k} -> {_d(v)}")
for k, v in sorted(self.dynamic_dimensions_source.items()):
rows.append(f"dynsrc: {k} -> {_d(v)}")
for k, v in sorted(self._dynamic_alias.items()):
rows.append(f"dynals: {k} -> {_d(v)}")
if self.dynamic_shapes:
if isinstance(self.dynamic_shapes, dict):
for k, v in sorted(self.dynamic_shapes.items()):
rows.append(f"d-dynshp: {k} -> {_d(v)}")
else:
for k, v in enumerate(self.dynamic_shapes):
rows.append(f"t-dynshp: {k} -> {_d(v)}")
# the rest
for k, v in self.opsets.items():
rows.append(f"opset: {k}: {v}")
for k, v in self.initializers_dict.items():
rows.append(f"init: {k}: {_v(v)}")
for i in self.input_names:
rows.append(_io(i, "input:"))
for node in self.nodes:
rows.append(self.pretty_node(node))
for i in self.output_names:
rows.append(_io(i, "output:"))
for k, f in self.functions.items():
rows.append("")
rows.append(f"FUNCKEY: {k}")
rows.append(f"FUNC {f.name}[{f.domain}]: {f.input} -> {f.output}")
for op in f.opset_import:
n = op.domain if op.domain else "''"
rows.append(f" opset: {n}: {op.version}")
for node in f.node:
rows.append(f" {self.pretty_node(node)}")
if k in self.functions_builder:
rows.append(
self.functions_builder[k].pretty_text(
add_fx_graph=add_fx_graph, recursive=False
)
)
if add_fx_graph:
fx = self._debug_msg.get("process.graph_module")
if fx:
rows.append("-- FX.GRAPH-- ")
rows.append(str(fx))
return "\n".join(rows)
@property
def main_opset(self):
"Returns the opset for the main domain (assuming it is used)."
return self.opsets[""]
[docs]
def get_opset(self, domain: str) -> int:
"""
Returns the opset version for a specific domain.
:param domain: domain name
:return: version
"""
assert (
domain in self.opset
), f"Domain {domain!r} is not registered{self.get_debug_msg()}."
return self.opset[domain]
[docs]
def add_domain(self, domain: str, version: int = 1):
"""
Adds a domain to the list of supported ones.
Checks the version is the same if it exists.
"""
if domain in self.opsets:
assert version == self.opsets[domain], (
f"Version mismatch for domain={domain!r}, current is "
f"{self.opsets[domain]}, new is {version}{self.get_debug_msg()}"
)
return
self.opsets[domain] = version
def _hash(self) -> str:
return make_hash(self)
def _get_tensor_shape(self, proto: Union[NodeProto, TensorProto]) -> STATIC_SHAPE:
if isinstance(proto, TensorProto):
return tuple(proto.dims)
if isinstance(proto, NodeProto):
for att in proto.attribute:
assert att.type != AttributeProto.GRAPH, (
f"Unexpected attribute type for attribute {att.name!r} "
f"attribute list is {[a.name for a in proto.attribute]} "
f"in node type {proto.op_type!r}, name={proto.name!r}, "
f"doc_string={proto.doc_string!r}"
)
if att.name in ("value_float", "value_int"):
return tuple()
if att.name == "value_floats":
return (len(att.floats),)
if att.name == "value_ints":
return (len(att.ints),)
if att.name == "value":
t = onh.to_array(att.t)
return tuple(t.shape)
raise TypeError(f"Unexpected or unsupported scenario type {type(proto)}: {proto}.")
def _get_tensor_type(self, proto: Union[NodeProto, TensorProto]) -> int:
if isinstance(proto, TensorProto):
return proto.data_type
if isinstance(proto, NodeProto):
for att in proto.attribute:
if att.name == "value_float":
return TensorProto.FLOAT
if att.name == "value_int":
return TensorProto.INT64
if att.name == "value_floats":
return TensorProto.FLOAT
if att.name == "value_ints":
return TensorProto.INT64
if att.name == "value":
return att.t.data_type
raise ValueError(f"Unexpected type or value {type(proto)}: {proto}.")
[docs]
def is_constant(self, name: str) -> bool:
"""Tells if a result is a constant."""
return name in self.constants_
[docs]
def get_constant(
self,
name: str,
exc: bool = True,
computed_value: bool = False,
as_shape: bool = False,
multiple_outputs: bool = False,
) -> Union[np.ndarray, NodeProto]:
"""
The method returns the constant *name*. It is a tensor (numpy array)
or a NodeProto which must be evaluated.
If *computed_value* is True, the NodeProto is evaluated wuth the
ReferenceEvaluator.
:param name: constant name
:param exc: raise an exception if anything is impossible to do
:param computed_value: compute the value if not a constant
:param as_shape: returns a tuple for a shape
:param multiple_outputs: allow multiple outputs
:return: value
"""
if self._debug_get_constant:
print(
f"[GraphBuilder.get_constant] name={name}, "
f"computed_value={computed_value}, as_shape={as_shape}, "
f"exc={exc}"
)
if as_shape:
assert not multiple_outputs, "multiple outputs not allowed with as_shape=True"
res = self.get_constant(name, exc, computed_value=computed_value, as_shape=False)
if res is None:
assert not exc, (
f"No constant for name={name!r}, exc={exc}, "
f"computed_value={computed_value}, as_shape={as_shape}, "
f"multiple_outputs={multiple_outputs}{self.get_debug_msg()}"
)
if self._debug_get_constant:
print("[GraphBuilder.get_constant] A: None")
return None
assert multiple_outputs or not isinstance(
res, tuple
), f"Multiple output is not allowed but type is {type(res)} for name={name!r}"
new_res = []
for i in res:
new_res.append(i if isinstance(i, str) else int(i))
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] SHAPE: {tuple(new_res)}")
return tuple(new_res)
if not self.is_constant(name):
if exc:
raise ValueError(f"Result {name!r} is not a constant{self.get_debug_msg()}")
if self._debug_get_constant:
print("[GraphBuilder.get_constant] C: None")
return None
possible_value = self.constants_[name]
if name in self.constants_computed_:
value = self.constants_computed_[name]
assert value is not None, f"Constant is empty for name={name!r}"
assert multiple_outputs or not isinstance(
value, tuple
), f"Multiple output is not allowed but type is {type(value)} for name={name!r}"
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] D: value: {type(value)}")
return value
if possible_value is not None:
assert isinstance(possible_value, (np.ndarray, self.torch.Tensor, NodeProto)), (
f"Unexpected type {type(possible_value)} for a "
f"constant{self.get_debug_msg()}"
)
if computed_value and isinstance(possible_value, NodeProto):
res, _ = self.compute_constant(name, exc=exc)
if res is None:
# The constant is too big to be computed.
if self._debug_get_constant:
print("[GraphBuilder.get_constant] E: None")
return None
assert multiple_outputs or not isinstance(res, tuple), (
f"Multiple output is not allowed but type is {type(res)} "
f"for name={name!r}"
)
assert not multiple_outputs, (
f"get_constants not implemented when multiple_outputs=True, "
f"name={name!r}"
)
if not isinstance(res, tuple):
if self._debug_get_constant:
print("[GraphBuilder.get_constant] F: tuple")
return res
assert isinstance(res, tuple), (
f"Expecting multiple outputs for name={name!r}, "
f"op_type={possible_value.op_type!r}"
)
if len(res) == 1:
assert multiple_outputs or not isinstance(value, tuple), (
f"Multiple output is not allowed but type is {type(value)} "
f"for name={name!r}"
)
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] G: {type(res[0])}")
return res[0]
index = list(possible_value.output).index(name)
value = res[index]
assert value is not None, f"Constant is empty for name={name!r}"
assert multiple_outputs or not isinstance(value, tuple), (
f"Multiple output is not allowed but type is {type(value)} "
f"for name={name!r}"
)
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] H: {type(value)}")
return value
assert possible_value is not None, f"Constant is empty for name={name!r}"
assert multiple_outputs or not isinstance(possible_value, tuple), (
f"Multiple output is not allowed but type is {type(possible_value)} "
f"for name={name!r}"
)
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] I: {type(possible_value)}")
return possible_value
if name not in self.initializers_dict:
if exc:
raise ValueError(
f"Result {name!r} was never evaluated within method 'constant_folding'."
)
if self._debug_get_constant:
print("[GraphBuilder.get_constant] J: None")
return None
value = self.initializers_dict[name]
if isinstance(value, np.ndarray):
if self._debug_get_constant:
if value.size < 10:
print(
f"[GraphBuilder.get_constant] K: np.ndarray {value.shape}, "
f"{value.dtype}, {value}"
)
else:
print(
f"[GraphBuilder.get_constant] K: np.ndarray {value.shape}, "
f"{value.dtype}"
)
return value
if isinstance(value, self.torch.Tensor):
v = value.detach().cpu()
self.constants_computed_[name] = v
assert not multiple_outputs, f"Multiple output is not allowed for name={name!r}"
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] L: nptorch.Tensor {v.shape}, {v.dtype}")
return v
if isinstance(value, TensorProto):
if uses_external_data(value):
if exc:
raise TypeError(
f"Tensor is using external data, data_type={value.data_type}, "
f"dims={value.dims}"
)
if self._debug_get_constant:
print("[GraphBuilder.get_constant] M: None")
return None
v = onh.to_array(value)
assert not multiple_outputs, f"Multiple output is not allowed for name={name!r}"
self.constants_computed_[name] = v
if self._debug_get_constant:
print("[GraphBuilder.get_constant] O: TensorProto")
return v
if isinstance(value, np.float32):
# This should not be needed.
if self._debug_get_constant:
print(f"[GraphBuilder.get_constant] P: np.float32 {value}")
return np.array(value, dtype=np.float32)
raise TypeError(f"Unable to convert type {type(value)} into numpy array.")
[docs]
def is_sequence(self, name: str) -> bool:
"""Tells if a result is a sequence."""
if name in self._known_sequences:
return True
assert self.has_name(name), f"Unknown name={name!r}{self.get_debug_msg()}"
return False
[docs]
def get_sequence(self, name: str) -> Dict[str, Any]:
"""Returns sequence information"""
assert (
name in self._known_sequences
), f"Sequence {name!r} is not known{self.get_debug_msg()}"
return self._known_sequences[name]
[docs]
def set_sequence(
self,
name: str,
dtype: Union[int, Tuple[int, ...]],
shapes: Optional[DYNAMIC_SHAPE] = None,
ranks: Optional[Tuple[int, ...]] = None,
unknown: bool = False,
):
"""
Defines a result as a sequence.
"""
assert (
shapes is not None or ranks is not None or unknown
), f"shapes or ranks must be defines for name={name!r}{self.get_debug_msg()}"
assert self.has_name(name), f"No result name={name!r}{self.get_debug_msg()}"
assert isinstance(dtype, (int, tuple)), (
f"Only one type is allowed in onnx sequences but dtype={dtype!r}, "
f"the interpret allows multiple types for simplicity"
f"{self.get_debug_msg()}"
)
d = dict(dtype=dtype, shapes=shapes, ranks=ranks)
if shapes is not None and ranks is None:
d["ranks"] = tuple(len(s) for s in shapes)
if name not in self._known_sequences:
self._known_sequences[name] = d
else:
assert self._known_sequences[name] == d, (
f"Sequence {name!r} was already declared with a different type "
f"or shape or rank, declared={self._known_sequences[name]}, "
f"new={d}{self.get_debug_msg()}"
)
[docs]
def set_name(self, name: str, marker: str):
"""Adds a name to the list of known names."""
assert name != "", (
f"Empty name {name!r} cannot be registered, "
f"marker={marker!r}{self.get_debug_msg()}"
)
assert len(name) == len(name.strip()), (
f"No space should be added at the extremities of the name {name!r}"
f"{self.get_debug_msg()}"
)
assert name not in self._raise_list, (
f"Name {name!r} is one of the name declared in "
f"the stop list, marker={marker!r}{self.get_debug_msg()}"
)
assert isinstance(name, str), (
f"Unexpected type {type(name)} for name, "
f"marker={marker!r}, existing marker={self._events[name, 'set_name']}"
f"{self.get_debug_msg()}"
)
assert name not in self._known_names, (
f"Name {name!r} already exists, marker={marker!r}, "
f"existing marker is {self._events[name, 'set_name']}"
f"{self.get_debug_msg()}"
)
self._known_names.add(name)
self._unique_names.add(name)
self._events[name, "set_name"] = marker
[docs]
def set_rank(self, name: str, value: int):
"""
Sets the rank for a result.
:param name: result name
:param value: rank
"""
if name == self._debug_stop or name == self._debug_stop_shape:
# Set ONNXSTOP or ONNXSTOPSHAPE to stop here.
raise AssertionError(f"Requested stop, name={name!r}, rank={value}")
assert isinstance(value, int), f"Unexpected rank type {type(value)} for {name!r}"
assert not isinstance(value, bool), f"Unexpected rank type {type(value)} for {name!r}"
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
if name in self._known_ranks:
assert value == self._known_ranks[name], (
f"Inconsistent ranks for {name!r}, previous value is "
f"{self._known_ranks[name]}, new value is {value}{self.get_debug_msg()}"
)
if self.verbose > 5:
print(f"[GraphBuilder-{self._hash()}.set_rank] (again) {name}:{value}")
return
assert (
name not in self._known_ranks
), f"Name {name!r} already exists{self.get_debug_msg()}"
self._known_ranks[name] = value
if self.verbose > 5:
print(f"[GraphBuilder-{self._hash()}.set_rank] {name}:{value}")
def is_more_precise(self, shape: STATIC_SHAPE, base: STATIC_SHAPE) -> bool:
assert len(shape) == len(
base
), f"Cannot compare shapes with different ranks {shape} and {base}"
for a, b in zip(shape, base):
if isinstance(a, int) and isinstance(b, int):
if a != b:
return False
if isinstance(a, str) and isinstance(b, str):
if a != b:
return False
return True
[docs]
def get_is_dimension(
self,
name: str,
elem_type: Optional[int] = None,
shape: Optional[STATIC_SHAPE] = None,
n_outputs: Optional[int] = None,
) -> bool:
"""
Tells if a result is a dynamic dimension or not.
"""
if name in self.dynamic_objects:
res = True
elif name in self._known_torch_value:
value = self._known_torch_value[name]
if value[0] == "run_node":
val1 = value[1]
_exa, val = val1
if val is not None and len(val) == 3:
el_type, size = val[1:]
if el_type in (
self.torch.float32,
self.torch.float64,
self.torch.float16,
self.torch.bfloat16,
self.torch.bool,
):
return False
if len(size) >= 2:
return False
if el_type in (self.torch.int64, self.torch.int32) and len(size) == 0:
# A single integer with no shape, it looks like a dimension.
# Let's assume it is. It is more efficient to consider it as
# a dimension.
return True
# In another case, let's assume it is not.
return False
else:
if elem_type is not None and elem_type in (
self.torch.float32,
self.torch.float64,
self.torch.float16,
self.torch.bfloat16,
self.torch.bool,
):
return False
if shape is not None and len(shape) >= 2:
return False
dtype = self.get_type(name)
if dtype in {
TensorProto.FLOAT16,
TensorProto.FLOAT,
TensorProto.DOUBLE,
TensorProto.BFLOAT16,
TensorProto.BOOL,
}:
return False
if self.has_shape(name):
shape = self.get_shape(name)
if dtype == TensorProto.INT64 and shape == (1,):
return True
elif self.has_rank(name):
if self.get_rank(name) > 1:
return False
if isinstance(val1[0], tuple) and len(val1[0]) >= 1:
v = val1[0]
if (
isinstance(v, tuple)
and len(v) == 3
and v[0] == "example_value"
and len(self.dynamic_objects) == 0
):
# No dynamic shape as input, so there
# shoud not be any dynamic shape as output.
return False
if val1 == ("", ""):
# Another case where it seems False.
return False
raise RuntimeError(
f"Not implemented for name={name!r}, value={value!r} ({type(value)}), "
f"val1={val1}, elem_type={elem_type}, shape={shape}, n_outputs={n_outputs}"
f"{self.get_debug_msg()}"
)
elif value[0] == "call_module":
if isinstance(value[1], tuple) and len(value[1]) == 2:
el_type, size = value[1]
if el_type in (
self.torch.float32,
self.torch.float64,
self.torch.float16,
self.torch.bfloat16,
):
return False
if len(size) >= 2:
return False
if n_outputs == 1:
# We may assume a model would not output just one dimension.
return False
raise RuntimeError(
f"Not implemented for name={name!r}, value={value!r} ({type(value)}), "
f"elem_type={elem_type}, shape={shape}, n_outputs={n_outputs}"
f"{self.get_debug_msg()}"
)
elif "_INT_" in name:
# This is most likely a dimension but not marked as such for the time being.
return False
else:
if elem_type in {
TensorProto.FLOAT16,
TensorProto.FLOAT,
TensorProto.DOUBLE,
TensorProto.BFLOAT16,
TensorProto.BOOL,
}:
return False
raise RuntimeError(
f"Unable to gues if {name!r}, elem_type={elem_type}, "
f"shape={shape} is a dimension{self.get_debug_msg()}"
)
assert not res or (
(
elem_type is None
or elem_type
in {
TensorProto.INT64,
TensorProto.INT32,
TensorProto.UINT64,
TensorProto.UINT32,
# not a dimension but a result of a computation involving a dimension
TensorProto.FLOAT,
}
)
and (shape is None or (isinstance(shape, tuple) and len(shape) == 1))
), (
f"Inconsistent result type for name={name!r}, is_dimension={res}, "
f"elem_type={elem_type}, shape={shape}{self.get_debug_msg()}"
)
return res
def set_shapes_types(
self,
name: Union[str, "torch.fx.Node"], # noqa: F821
where: str,
value: Any,
):
if hasattr(name, "name"):
name = name.name
self._known_torch_value[name] = (where, value)
def _torch_sym_int_to_str(self, value: "torch.SymInt") -> Union[int, str]: # noqa: F821
if isinstance(value.node, str):
return f"{value.node}"
from torch.fx.experimental.sym_node import SymNode
if isinstance(value.node, SymNode):
# '_expr' is safer than expr
return str(value.node._expr)
try:
val_int = int(value)
return val_int
except (
TypeError,
ValueError,
AttributeError,
self.torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
):
pass
raise AssertionError(f"Unable to convert {value!r} into string")
def _check_two_shapes_are_compatible(
self,
old_shape: Tuple[Any, ...],
shape: Tuple[Any, ...],
register_int: bool = True,
name: Optional[str] = None,
):
"""
Raises an exception if two shapes are not compatabible.
"""
import torch
assert len(old_shape) == len(shape), (
f"Rank mismatch between {old_shape} and {shape} (name={name!r}"
f"{self.get_debug_msg()}"
)
for d1, d2 in zip(old_shape, shape):
if isinstance(d1, int) and isinstance(d2, int):
assert d1 == d2, (
f"Shape {name!r} already exists and one dimension "
f"is not compatible existing {old_shape} != {shape} "
f"(new) {self.get_debug_msg()}"
)
continue
d1_, d2_ = d1, d2
if isinstance(d1, torch.SymInt):
d1 = self._torch_sym_int_to_str(d1)
elif isinstance(d1, torch.export.dynamic_shapes._Dim):
d1 = self._torch_sym_int_to_str(d1)
if isinstance(d2, torch.SymInt):
d2 = self._torch_sym_int_to_str(d2)
elif isinstance(d2, torch.export.dynamic_shapes._Dim):
d2 = self._torch_sym_int_to_str(d2)
if isinstance(d1, (int, str)) and isinstance(d2, (int, str)):
if d1 == d2:
continue
if isinstance(d1, str) and isinstance(d2, str):
self.register_constraint_dimension(d1, d2)
self.register_constraint_dimension(d2, d1)
continue
raise RuntimeError(
f"Shape {name!r} already exists "
f"and it is not compatible "
f"existing {old_shape} != {shape} (new) "
f"d1={d1_!r}, d2={d2_!r}, dim types="
f"({type(d1_)}, {type(d2_)}), "
f"d1={d1!r}, d2={d2!r}"
f"{self.get_debug_msg()}"
)
[docs]
def register_dynamic_objects_from_shape(self, shape: DYNAMIC_SHAPE):
"""
Registers all the dynamic objects required in this shape.
"""
for dim in shape:
if isinstance(dim, str):
self.register_dynamic_objects_from_dim(dim)
[docs]
def register_dynamic_objects_from_dim(self, dim: str):
"""
Registers all the dynamic objects required in a dimension.
"""
assert isinstance(
dim, str
), f"type(dim)={type(dim)} must be a str{self.get_debug_msg()}"
for token in parse_expression_tokens(dim):
if token not in self.dynamic_objects:
self.add_dynamic_object(token, token)
if dim not in self.dynamic_objects:
self.add_dynamic_object(dim, dim)
[docs]
def set_shape(
self,
name: str,
shape: DYNAMIC_SHAPE,
set_rank: bool = True,
set_if_more_precise: bool = False,
exc: bool = False,
):
"""
Sets the shape for a result. It is exists, it checks the new shape
is equal to the existing one.
:param name: result name
:param shape: shape
:param set_rank: set the rank as well
:param set_if_more_precise: change the shape if it is more precise
:param exc: raise an exception if inconsistency
"""
if (self._debug_stop or self._debug_stop_shape) and name in (
self._debug_stop,
self._debug_stop_shape,
):
# Set ONNXSTOP or ONNXSTOPSHAPE to stop here.
raise AssertionError(f"Requested stop, name={name!r}, shape={shape}")
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
assert "torch.Size" not in str(
shape
), f"Unexpected type {type(shape)} for a shape={shape}{self.get_debug_msg()}"
assert isinstance(shape, tuple), f"Unexpected shape type {type(shape)}"
for sdim in shape:
if not isinstance(sdim, str):
continue
self.register_dynamic_objects_from_dim(sdim)
shape = self.verify_shape(shape, 0, name=name)
assert all(not isinstance(t, self.torch.SymInt) for t in shape), (
f"Unexpected type for a shape, shape={shape}, types={[type(_) for _ in shape]}"
f"{self.get_debug_msg()}"
)
shape_int = [d for d in shape if isinstance(d, int)]
assert (
len(shape) == 0 or not shape_int or min(shape_int) >= 0
), f"Negative value in shape {shape} for {name!r}{self.get_debug_msg()}"
assert (
not self._debug_null_shape
or len(shape) == 0
or not shape_int
or min(shape_int) > 0
), f"Zero value in shape {shape} for {name!r}{self.get_debug_msg()}"
if name in self._known_shapes:
old_shape = self._known_shapes[name]
if len(shape) == len(old_shape) and set_if_more_precise:
if not self.is_more_precise(shape=shape, base=old_shape):
if old_shape == (0,) or shape == (0,):
if "warnings" not in self._debug_msg:
self._debug_msg["warnings"] = []
self._debug_msg["warnings"].append(
f"Name {name!r} already exists and it is not compatible "
f"existing {old_shape} != {shape} (new)"
)
else:
self._check_two_shapes_are_compatible(
old_shape, shape, name=name, register_int=False
)
elif shape != old_shape:
if exc:
raise RuntimeError(
f"Name {name!r} already exists and its shape different "
f"{old_shape} (old) != {shape}{self.get_debug_msg()}"
)
return
else:
return
if self.verbose > 5:
print(f"[GraphBuilder-{self._hash()}.set_shape] {name}:{shape}")
for s in shape:
if isinstance(s, str) and s not in self.dynamic_objects:
self.add_dynamic_object(s, s)
self._known_shapes[name] = shape
if set_rank and not self.has_rank(name):
self.set_rank(name, len(shape))
[docs]
def set_type_shape_or_rank(self, name: str, like: str):
"""
Sets the type and the shape of *name* like *like*.
"""
if self.has_type(like):
self.set_type(name, self.get_type(like))
if self.has_shape(like):
self.set_shape(name, self.get_shape(like))
elif self.has_rank(like):
self.set_rank(name, self.get_rank(like))
[docs]
def set_type(self, name: str, dtype: int, exc: bool = True):
"""
Sets the shape for a result. It is exists, it checks the new shape
is equal to the existing one.
:param name: name
:param dtype: element type (an integer, ONNX)
:param exc: raises an exception
"""
if name in (self._debug_stop, self._debug_stop_type):
raise AssertionError(f"Requested stop, name={name!r}, dtype={dtype}")
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
if isinstance(dtype, int):
int_type = dtype
else:
int_type = _get_type(dtype)
if name in self._known_types:
# 0 is undefined
if self._known_types[name] != 0 and int_type != self._known_types[name]:
if exc:
from . import str_tensor_proto_type
raise RuntimeError(
f"Type for name {name!r} already exists and it is different, "
f"known is {self._known_types[name]} != {int_type} (new) - "
f"(mapping={str_tensor_proto_type()}){self.get_debug_msg()}"
)
if "warnings" not in self._debug_msg:
self._debug_msg["warnings"] = []
self._debug_msg["warnings"].append(
f"Type for name {name!r} already exists and it is different, "
f"known is {self._known_types[name]} != {int_type} (new) - "
)
if self.verbose > 5:
print(
f"Type for name {name!r} already exists and it is different, "
f"known is {self._known_types[name]} != {int_type} (new) - "
)
return
if self.verbose > 5:
print(f"[GraphBuilder-{self._hash()}.set_type] {name}:{int_type}")
self._known_types[name] = int_type
[docs]
def rank(self, name: str) -> int:
"""Shortcut to :meth:`get_rank`."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
return self.get_rank(name)
[docs]
def has_name(self, name: str) -> bool:
"""Tells if a result exists."""
assert isinstance(
name, str
), f"Unexpected type {type(name)} for name (name={name!r}){self.get_debug_msg()}"
return name in self._known_names
[docs]
def has_rank(self, name: str) -> bool:
"""Tells if a result has a rank."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
return name in self._known_ranks
[docs]
def has_shape(self, name: str, full=False) -> bool:
"""Tells if a result has a shape."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
if name not in self._known_shapes:
return False
if full:
shape = self._known_shapes[name]
return is_static_shape(shape) and min(shape) >= 0
return True
[docs]
def has_type(self, name: str) -> bool:
"""Tells if a result has a type. This should be always true."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
return name in self._known_types
[docs]
def get_rank(self, name: str) -> int:
"""Returns the rank of a result."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
assert name in self._known_ranks, (
f"rank is unknown for result {name!r}, has_shape={self.has_shape(name)}, "
f"has_rank={self.has_rank(name)}, "
f"known_ranks={self._known_ranks}{self.get_debug_msg()}"
)
return self._known_ranks[name]
[docs]
def get_shape(self, name: str) -> int:
"""Returns the shape of a result."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
assert name in self._known_shapes, (
f"Shape is unknown for result {name!r}, "
f"known_shapes={self._known_shapes}{self.get_debug_msg()}"
)
return self._known_shapes[name]
[docs]
def get_type_known(self, name: str) -> Optional[int]:
"""Returns the type known by torch to help solve mismatches."""
if name in self._known_torch_value:
value = self._known_torch_value[name]
# something like (
# 'run_node',
# (
# '',
# ('val', torch.float16, torch.Size([2, 12, 2048, 2048]))
# )
# )
assert (
isinstance(value, tuple)
and len(value) == 2
and isinstance(value[1], tuple)
and len(value[1][1]) == 3
), f"Unexpected value {value} for {name!r}{self.get_debug_msg()}"
dtype = value[1][1][1]
itype = torch_dtype_to_onnx_dtype(dtype)
return itype
return None
[docs]
def get_type(self, name: str) -> int:
"""Returns the type of a result."""
assert isinstance(name, str), f"Unexpected type {type(name)} for name."
assert name in self._known_types, (
f"Type is unknown for result {name!r}, "
f"known_types={self._known_types}{self.get_debug_msg()}."
)
return self._known_types[name]
[docs]
def value_as_shape(self, name: str) -> bool:
"""Returns the value of a result if it is a shape."""
if name in self._known_value_shape:
return self._known_value_shape[name]
return None
[docs]
def set_value_shape(
self, name: str, value: Any, equal_to: Optional[Tuple[str, str]] = None
):
"""
Sets the value for a shape result.
:param name: name
:param value: it cannot be empty
:param equal_to: if specified, the value is also
equal to this value
"""
assert isinstance(
name, str
), f"Unexpected type {type(name)} for name={name!r}{self.get_debug_msg()}"
assert value not in {
tuple()
}, f"Unexpected value for shape {name!r}, value={value}{self.get_debug_msg()}"
if equal_to is None:
assert (
name not in self._known_value_shape or self._known_value_shape[name] == value
), (
f"Shape value for {name!r} (value={value!r}) is already "
f"registered and is different from the existing "
f"value={value!r} (equal_to={equal_to!r}), "
f"existing value is {self._known_value_shape.get(name, None)!r}"
f"{self.get_debug_msg()}"
)
if self.verbose > 2:
print(f"[GraphBuilder-{self._hash()}.set_value_shape] {name}[{value}]")
self._known_value_shape[name] = value
return
assert (
name in equal_to
), f"Unexpected name={name!r}, it should be in equal_to={equal_to!r}."
values = (
self._known_value_shape.get(equal_to[0], None),
self._known_value_shape.get(equal_to[1], None),
)
assert value in values, (
f"Unexpected value={value} for name={name!r}, equal_to={equal_to}, "
f"values={values}{self.get_debug_msg()}"
)
assert equal_to[0] in self._known_value_shape, (
f"{equal_to[0]!r} should already registered, name={name!r}, "
f"value={value!r}, equal_to={equal_to!r}{self.get_debug_msg()}"
)
# The logic is to get rid of one value instead of keeping
# a mapping between equivalent values.
new_value = self._known_value_shape[equal_to[0]]
for n in equal_to:
if n not in self._known_value_shape:
self._known_value_shape[n] = new_value
def unique_name(self, prefix: str) -> str:
if prefix in self._unique_names:
i = 2
sug = f"{prefix}2"
while sug in self._unique_names:
i += 1
sug = f"{prefix}{i}"
self._unique_names.add(sug)
return sug
self._unique_names.add(prefix)
return prefix
def unique_node_name(self, name: str) -> str:
if name in self._unique_node_names:
i = 2
sug = f"{name}2"
while sug in self._unique_node_names:
i += 1
sug = f"{name}{i}"
self._unique_node_names.add(sug)
return sug
self._unique_node_names.add(name)
return name
[docs]
def elem_size(self, elem_type: int) -> int:
"Returns the size in byte of the an element of this size."
if elem_type in {TensorProto.FLOAT, TensorProto.INT32, TensorProto.UINT32}:
return 4
if elem_type in {
TensorProto.DOUBLE,
TensorProto.INT64,
TensorProto.UINT64,
TensorProto.COMPLEX64,
}:
return 8
if elem_type in {TensorProto.COMPLEX128}:
return 16
if elem_type in {
TensorProto.INT16,
TensorProto.UINT16,
TensorProto.FLOAT16,
TensorProto.BFLOAT16,
}:
return 2
if elem_type in {TensorProto.BOOL, TensorProto.UINT8, TensorProto.INT8}:
return 1
from . import str_tensor_proto_type
raise AssertionError(
f"elem_size not implemented for elem_type={elem_type}, "
f"among {str_tensor_proto_type()}"
)
[docs]
def make_dynamic_object(
self,
name: str,
value: Any,
shape_as_input: bool = False,
input_name: Optional[str] = None,
axis: Optional[int] = None,
) -> str:
"""
Creates a dynamic shapes.
:param name: name
:param value: value
:param shape_as_input: adds the name to the list of the inputs
of the onnx model
:param input_name: the dimension comes from this input
:param axis: the dimension comes this axis
:return: the name
"""
def _append_to_source(name, input_name, axis, value):
if input_name is not None and isinstance(value, self.torch.SymInt):
assert axis is not None, (
f"input_name={input_name!r} but axis is None for "
f"dynamic shape {name!r}, value type is {type(value)!r} "
f"{self.get_debug_msg}"
)
assert name != input_name, (
f"Name {name!r} cannot be defined from itself (axis={axis}), "
f"value type is {type(value)}{self.get_debug_msg()}"
)
source = dict(input_name=input_name, axis=axis)
if name in self.dynamic_dimensions_source:
self.dynamic_dimensions_source[name].append(source)
else:
self.dynamic_dimensions_source[name] = [source]
if name in self.dynamic_objects:
# The dimension is already registered but it is used for another input.
_append_to_source(name, input_name, axis, value)
return None
assert name not in self.dynamic_objects, (
f"Dynamic object {name!r}, value={value!r} "
f"is already there{self.get_debug_msg()}"
)
if isinstance(value, self.WrapSym):
value = value.sym
assert isinstance(
value, (self.torch.SymInt, self.torch.SymFloat, self.torch.SymBool)
), f"Unexpected type {type(value)} for value{self.get_debug_msg()}"
_append_to_source(name, input_name, axis, value)
self.add_dynamic_object(name, value, parse=True)
if (
shape_as_input
and isinstance(value, self.torch.SymInt)
and value.node.maybe_as_int() is None
):
# Then an input is a shape.
self.add_dynamic_object(str(value), value)
if name not in self._known_value_shape:
self._known_value_shape[name] = name
key = None
if isinstance(value, (self.torch.SymInt, self.torch.SymFloat)):
if hasattr(value, "node"):
from torch.fx.experimental.sym_node import SymNode
if isinstance(value.node, str):
key = f"{value.node}"
elif isinstance(value.node, SymNode):
key = f"{value.node}"
key2 = value.node.maybe_as_int()
if key2 is not None:
self._add_dynamic_example(key, key2)
else:
raise AssertionError(
f"Unexpected type {type(value.node)} for value.node={value.node}"
)
# key = str(value)
if key is None:
key = str(value)
self.add_dynamic_objects_rev(key, (name, value))
if shape_as_input:
assert isinstance(value, (self.torch.SymInt, self.torch.SymFloat)), (
f"shape_as_input={shape_as_input}, unexpected type "
f"{type(value)} for value{self.get_debug_msg()}"
)
# torch.compile adds input for dynamic shapes
return self.make_tensor_input(
self._known_value_shape[name],
(
TensorProto.INT64
if isinstance(value, self.torch.SymInt)
else TensorProto.FLOAT
),
(1,),
is_dimension=True,
marker="make_dynamic_object",
)
return self._known_value_shape[name]
def get_dimension_as_result(self, name: str) -> str:
if self.has_name(name):
return name
assert (
name in self.dynamic_dimensions_source
and len(self.dynamic_dimensions_source[name]) > 0
), (
f"Dimension {name!r} has no registered source "
f"it cannot be created as a result{self.get_debug_msg()}"
)
source = self.dynamic_dimensions_source[name][0]
axis = source["axis"]
input_name = source["input_name"]
shape_name = self.unique_name(f"_onx_shape_{name}")
self.make_node("Shape", [input_name], [shape_name], name="_get_dimension_as_result")
axis_name = self.make_initializer("", np.array([axis], dtype=np.int64))
self.make_node(
"Gather", [shape_name, axis_name], [name], name="_get_dimension_as_result"
)
return name
[docs]
def make_shape_from_results(self, shape: DYNAMIC_SHAPE, name="") -> str:
"""
Creates a shape coming from intermediate results.
"""
assert isinstance(
shape, (list, tuple)
), f"Unexpected type {type(shape)} for shape{self.get_debug_msg()}"
if all_int(shape):
return self.make_initializer("", np.array(shape, dtype=np.int64))
key = []
for d in shape:
if isinstance(d, int):
key.append(d)
elif isinstance(d, (str, self.torch.SymInt)):
value = self._torch_sym_int(d)
key.append(value)
else:
raise RuntimeError(
f"Unexpected type {type(d)} for a dimension in "
f"{shape}{self.get_debug_msg()}"
)
assert all_int_or_str(key), (
f"Unexpected key {key} type are {[type(_) for _ in key]}, "
f"shape={shape}{self.get_debug_msg()}"
)
key = ("Concat", *key)
if key in self._cache_shape:
# The same shape was already requested.
return self._cache_shape[key]
conc = []
for d in shape:
if isinstance(d, int):
conc.append(self.make_initializer("", np.array([d], dtype=np.int64)))
elif isinstance(d, str):
value = d
if value in self.dynamic_objects_rev:
assert len(self.dynamic_objects_rev[value]) >= 1
name = self.dynamic_objects_rev[value][0][0]
assert not isinstance(name, tuple), (
f"Unexpected type {type(name)}, name={name!r}, value={value!r}"
f"{self.get_debug_msg()}"
)
name = self.get_dimension_as_result(name)
else:
name = value
assert name in self.dynamic_objects or self.has_name(
name
), f"Unknown dynamic object {d!r} (or {name!r}){self.get_debug_msg()}"
if self.has_rank(name):
assert self.get_rank(name) <= 1, (
f"Unexpected rank={self.get_rank(name)} "
"for a shape{self.get_debug_msg()}"
)
if self.get_rank(name) == 0:
r = self.op.UnsqueezeAnyOpset(
name, np.array([0], dtype=np.int64), name=f"_mkshape1_{name}"
)
self.set_type(r, self.get_type(name))
self.set_shape(r, (1,))
conc.append(r)
else:
conc.append(name)
else:
conc.append(name)
elif isinstance(d, self.torch.SymInt):
value = self._torch_sym_int(d)
if value in self.dynamic_objects_rev:
assert len(self.dynamic_objects_rev[value]) >= 1
name = self.dynamic_objects_rev[value][0][0]
assert not isinstance(name, tuple), (
f"Unexpected type {type(name)}, name={name!r}, value={value!r}"
f"{self.get_debug_msg()}"
)
name = self.get_dimension_as_result(name)
else:
name = value
if isinstance(name, self.torch.SymInt):
name = self._torch_sym_int(name)
assert not isinstance(name, self.torch.SymInt)
assert not isinstance(name, self.torch.SymInt)
assert name in self.dynamic_objects or self.has_name(
name
), f"Unknown dynamic object {d!r}-{name!r}{self.get_debug_msg()}"
if self.has_rank(name):
assert self.get_rank(name) <= 1, (
f"Unexpected rank={self.get_rank(name)} "
"for a shape{self.get_debug_msg()}"
)
if self.get_rank(name) == 0:
r = self.op.UnsqueezeAnyOpset(
name, np.array([0], dtype=np.int64), name=f"_mkshape2_{name}"
)
self.set_type(r, self.get_type(name))
self.set_shape(r, (1,))
conc.append(r)
else:
conc.append(name)
else:
conc.append(name)
else:
raise RuntimeError(
f"Unexpected type {type(d)} for a dimension in "
f"{shape}{self.get_debug_msg()}"
)
if len(conc) > 1:
res = self.make_node("Concat", conc, axis=0, name=f"_mkshape_{name}")
else:
assert len(conc) > 0, f"No shape to concatenate{self.get_debug_msg()}"
res = self.make_node("Identity", conc[0], name=f"_mkshape1_{name}")
self._cache_shape[key] = res
return res
[docs]
def make_initializer(
self,
name: str,
value: Any,
external: bool = False,
msg: str = "",
parameter_name: Optional[str] = None,
) -> str:
"""
Adds an initializer to the graph.
The function detects duplicated small containers, only if they are
integers. Other type might be used as weights. Even similar, they could
change after training.
:param name: name, if empty (`""`), a unique names is given, if not empty,
it is more like a prefix, the method might change it to make it unique
:param value: value (TensorProto)
:param external: external initializer or not (not stored in the graph model)
:param msg: added to the error message if something goes wrong
:param parameter_name: the parameter name is different than its name in the fx graph,
they are restored when the model is finally exported into onnx,
until then, the mapping is kept in attribute ``_parameter_renaming``
:return: name of the initializer
"""
if external:
raise NotImplementedError("External initializers are not implemented yet.")
if isinstance(value, int):
value = np.array(value, dtype=np.int64)
elif isinstance(value, float):
value = np.array(value, dtype=np.float32)
elif hasattr(value, "data"):
# torch.nn.parameter.Parameter -> np.ndarray
assert "FakeTensor" not in str(type(value)), (
f"FakeTensor {name!r} cannot be an initializer {type(value)}"
f"{self.get_debug_msg()}"
)
elif isinstance(value, TensorProto):
assert "FakeTensor" not in str(type(value)), (
f"FakeTensor {name!r} cannot be an initializer {type(value)}"
f"{self.get_debug_msg()}"
)
elif isinstance(value, np.ndarray):
pass
else:
raise RuntimeError(
f"Initializer name={name!r}, "
f"unexpected type {type(value)} for value={value!r} ({msg})"
f"{self.get_debug_msg()}"
)
key = self.make_key(value)
if key and key in self._values:
if name == "":
assert (
not parameter_name
), f"Empty name cannot be used with parameter_name={parameter_name!r}"
new_name = self._values[key]
assert new_name in self.initializers_dict, f"Unable to find {new_name!r}"
return new_name
return self.make_node(
"Identity",
[self._values[key]],
[name],
name="make_initializer",
insert_position="HEAD",
)
if isinstance(value, TensorProto):
itype = value.data_type
shape = tuple(value.dims)
else:
itype = _get_type(value.dtype)
shape = tuple(value.shape)
if name == "":
sh = "x".join(map(str, shape))
size = np.prod(value.size()) if hasattr(value, "detach") else value.size
sh2 = (
"_".join(map(str, value.ravel().tolist()))
if size <= 5 and value.dtype == np.int64
else ""
)
name = self.unique_name(f"init{itype}_s{sh}_{sh2}")
self.add_initializer(
name, value, itype=itype, shape=shape, key=key, parameter_name=parameter_name
)
return name
[docs]
def add_initializer(
self,
name: str,
value: Any,
itype: Optional[int] = None,
shape: Optional[Tuple[int, ...]] = None,
cst: Optional[Any] = None,
key: Optional[Any] = None,
existing: bool = False,
allow_empty: bool = False,
parameter_name: Optional[str] = None,
):
"""
Adds an initializer.
:param name: constant name
:param value: initializer
:param itype: to overwrite the type
:param shape: to overwrite the shape
:param cst: value to send to :meth:`update_node_constant
<experimental_experiment.xbuilder.GraphBuilder.update_node_constant>`
:param key: used to register the initializer
:param existing: if True, shape and type should exist,
if False, it should not exist, if None, both case are allowed
:param allow_empty: allow empty tensor anyway
:param parameter_name: the parameter name is different than its name in the fx graph,
they are restored when the model is finally exported into onnx,
until then, the mapping is kept in attribute ``_parameter_renaming``
"""
is_proto = isinstance(value, (TensorProto, NodeProto))
if shape is None:
shape = (
self._get_tensor_shape(value)
if is_proto
else tuple(int(i) for i in value.shape)
)
if itype is None:
itype = (
self._get_tensor_type(value)
if is_proto
else dtype_to_tensor_dtype(value.dtype)
)
if existing:
assert allow_empty or len(shape) == 0 or min(shape) > 0, (
f"Initializer {name!r} has an empty shape={shape}, itype={itype}, "
f"existing shape={self.get_shape(name) if self.has_shape(name) else '?'}, "
f"type={type(value)}{self.get_debug_msg()}"
)
assert self.has_name(name), (
f"value {name!r} is replaced by an initializer "
f"already exists{self.get_debug_msg()}"
)
assert not self.has_type(name) or itype == self.get_type(name), (
f"Type mismatch for {name!r}, existing type {self.get_type(name)}, "
f"new type {itype}{self.get_debug_msg()}"
)
assert not self.has_shape(name) or shape == self.get_shape(name), (
f"Type mismatch for {name!r}, existing shape "
f"{self.get_shape(name)}, new shape {shape}{self.get_debug_msg()}"
)
self.set_shape(name, shape)
self.set_type(name, itype)
else:
assert allow_empty or len(shape) == 0 or min(shape) > 0, (
f"Initializer {name!r} has an empty shape={shape}, itype={itype}, "
f"type={type(value)}{self.get_debug_msg()}"
)
assert existing is None or name not in self.initializers_dict, (
f"initializer {name!r} was already added (itype={itype}, shape={shape})"
f"{self.get_debug_msg()}"
)
assert existing is None or not self.has_name(
name
), f"initializer {name!r} already exists{self.get_debug_msg()}"
self.set_shape(name, shape)
self.set_type(name, itype)
if not self.has_name(name):
self.set_name(name, "make_initializer")
self._unique_names.add(name)
self.initializers_dict[name] = value
if parameter_name and parameter_name != name:
# We want a specific name for this one, let's keep that information in
# main so that we can rename them later.
assert not self.has_name(parameter_name), (
f"Parameter {name!r} cannot be renamed int {parameter_name!r} as it "
f"is already taken{self.get_debug_msg()}"
)
self._parameter_renaming[name] = parameter_name
self._parameter_norename.add(parameter_name)
self.initializers_dict[parameter_name] = value
self.set_name(parameter_name, marker="parameter_name")
self.set_shape(parameter_name, self.get_shape(name))
self.set_type(parameter_name, self.get_type(name))
if self.verbose and (
self.verbose > 3 or (self.verbose > 2 and np.prod(shape) > 100)
):
print(
f"[GraphBuilder-{self._hash()}.make_initializer] "
f"{name}[{itype}:{shape}] -> {parameter_name}"
)
else:
if self.verbose and (
self.verbose > 3 or (self.verbose > 2 and np.prod(shape) > 100)
):
print(
f"[GraphBuilder-{self._hash()}.make_initializer] {name}[{itype}:{shape}]"
)
if cst is None and isinstance(value, NodeProto):
cst = value
is_constant = self.update_node_constant(name, cst)
if is_constant and parameter_name:
self.update_node_constant(parameter_name, cst)
if key is None:
key = self.make_key(value)
if key not in self._values:
self._values[key] = name
return name
def is_dynamic_shape(
self, shape: DYNAMIC_SHAPE, verify: bool = True, allow_none: bool = False
) -> bool:
return all(
self.is_dynamic_dimension(x, verify=verify, allow_none=allow_none) for x in shape
)
[docs]
def is_constant_or_attribute(
self, node: NodeProto, input_index: int, att_name: str
) -> bool:
"""
Tells if an input is a constant or returns true if in an older
opset, it was named as an attribute.
"""
if input_index < len(node.input):
return self.is_constant(node.input[input_index])
return True
[docs]
def get_constant_or_attribute(
self, node: NodeProto, input_index: int, att_name: str
) -> Any:
"""
Tells if an input is a constant or returns true if in an older
opset, it was named as an attribute.
"""
if input_index < len(node.input):
return self.get_constant(node.input[input_index])
for att in node.attribute:
if att.name == att_name:
assert (
att.type == AttributeProto.INTS
), f"Not Implemented when att.type={att.type}{self.get_debug_msg()}"
return np.array(list(att.ints), dtype=np.int64)
return None
def simple_update_value_shape_with_node(self, node) -> bool:
if node.domain != "":
return False
if node.op_type not in {
"Abs",
"Add",
"Concat",
"Div",
"Gather",
"Greater",
"GreaterOrEqual",
"Equal",
"Identity",
"Less",
"LessOrEqual",
"Mod",
"Mul",
"Not",
"Range",
"Scatter",
"Shape",
"Slice",
"Squeeze",
"Sub",
"Unsqueeze",
}:
return False
if node.op_type == "Identity":
value = self.value_as_shape(node.input[0])
if value is not None:
node.doc_string += "#II1"
self.set_value_shape(
node.output[0], value, equal_to=(node.input[0], node.output[0])
)
return True
node.doc_string += "#II2"
return False
if node.op_type == "Squeeze":
if self.is_constant_or_attribute(node, 1, "axes"):
node.doc_string += "#ISq1"
y = self.value_as_shape(node.input[0])
if y is None:
return False
i = self.get_constant_or_attribute(node, 1, "axes")
if isinstance(i, int):
ii = i
elif (
isinstance(i, np.ndarray)
and i.dtype == np.int64
and i.shape in ((1,), tuple())
):
ii = int(i[0]) if i.shape == (1,) else int(i)
else:
raise RuntimeError(
f"Not implemented when node Squeeze with inputs={node.input}, "
f"y={y!r}, i={i!r}{self.get_debug_msg()}"
)
assert (
ii == 0
), f"A shape should only have one axis i={i}, y={y}{self.get_debug_msg()}"
if isinstance(y, str):
self.set_value_shape(node.output[0], f"squeeze({y})")
return True
if isinstance(y, int):
self.set_value_shape(node.output[0], y)
return True
assert isinstance(
y, tuple
), f"Unexpected type {type(y)} for y={y} and i={i}{self.get_debug_msg()}"
self.set_value_shape(node.output[0], y[0])
return True
node.doc_string += "#ISq2"
return False
if node.op_type == "Shape":
if len(node.attribute) == 0:
node.doc_string += "#IS1"
if self.has_shape(node.input[0]):
shape = self.get_shape(node.input[0])
self.set_value_shape(node.output[0], shape)
if all_int(shape):
self.update_node_constant(node.output[0], node)
self.set_shape(node.output[0], (len(shape),))
else:
self.set_value_shape(node.output[0], node.output[0])
return True
node.doc_string += "#IS2"
start = self.get_attribute(node, "start", exc=False) or 0
end = self.get_attribute(node, "end", exc=False)
if end is None:
if self.has_rank(node.input[0]):
end = self.get_rank(node.input[0])
if self.has_shape(node.input[0]):
node.doc_string += "#IS3"
shape = self.get_shape(node.input[0])
assert start.i < len(shape), (
f"Shape mismatch, start={start.i}, shape of {node.input[0]!r} "
f"is {shape}{self.get_debug_msg()}"
)
if end is None:
n_shape = shape[start.i :]
self.set_value_shape(node.output[0], n_shape)
if all_int(shape):
self.update_node_constant(node.output[0], node)
self.set_shape(node.output[0], (len(n_shape),))
node.doc_string += "#IS4"
return True
assert getattr(end, "i", end) <= len(shape), (
f"Shape mismatch, end={getattr(end, 'i', end)}, "
f"shape of {node.input[0]!r} "
f"is {shape}{self.get_debug_msg()}"
)
n_shape = shape[start.i : getattr(end, "i", end)]
if all_int(shape):
node.doc_string += "#IS5"
self.update_node_constant(node.output[0], node)
self.set_value_shape(node.output[0], n_shape)
self.set_shape(node.output[0], (len(n_shape),))
node.doc_string += "#IS6"
return True
if end is None:
self.set_value_shape(node.output[0], f"{node.input[0]}[{start.i}:]")
node.doc_string += "#IS6"
return False
self.set_value_shape(
node.output[0],
f"{node.input[0]}[{start.i}:{getattr(end, 'i', end)}]",
)
node.doc_string += "#IS7"
return True
if node.op_type == "Gather":
if self.is_constant(node.input[1]):
node.doc_string += "#IG1"
y = self.value_as_shape(node.input[0])
if y is None:
node.doc_string += "#IG2"
return False
i = self.get_constant(node.input[1], computed_value=True)
if isinstance(y, str) and isinstance(i, int):
self.set_value_shape(node.output[0], f"{y}[{i}]")
node.doc_string += "#IG3"
return True
if (
isinstance(y, str)
and isinstance(i, np.ndarray)
and i.dtype == np.int64
and i.shape in ((1,), tuple())
):
ii = int(i[0]) if i.shape == (1,) else int(i)
self.set_value_shape(node.output[0], f"{y}[{ii}]")
node.doc_string += "#IG4"
return True
if isinstance(y, tuple) and isinstance(i, int):
self.set_value_shape(node.output[0], y[i])
node.doc_string += "#IG5"
return True
if (
isinstance(y, tuple)
and isinstance(i, np.ndarray)
and i.dtype == np.int64
and i.shape in ((1,), tuple())
):
ii = int(i[0]) if i.shape == (1,) else int(i)
assert ii < len(y), (
f"Unexpected value for y={y!r}, i={i!r} in node Gather "
f"with inputs={node.input}{self.get_debug_msg()}"
)
self.set_value_shape(node.output[0], y[ii])
node.doc_string += "#IG6"
return True
raise RuntimeError(
f"Not implemented when node Gather with inputs={node.input}, "
f"y={y!r}, i={i!r}{self.get_debug_msg()}"
)
node.doc_string += "#IG7"
return False
values = [self.value_as_shape(x) for x in node.input[0]]
if any(x is None for x in values):
# it is not a shape
node.doc_string += "#IZ"
return False
if node.op_type == "Concat":
node.doc_string += "#IC1"
self.set_shape_value(node.output[0], tuple(values))
return True
raise RuntimeError(
f"Unable to compute a shape for node {node.op_type!r} "
f"with inputs={node.input}{self.get_debug_msg()}"
)
def is_dynamic_dimension(
self, dim: Any, verify: bool = True, allow_none: bool = False
) -> bool:
if allow_none and dim is None:
return True
if not isinstance(dim, (int, self.torch.SymInt, str)):
return False
if str(dim) in self._dynamic_alias:
dim = self._dynamic_alias[str(dim)]
assert (
not verify
or is_static_dimension(dim)
or str(dim) in self.dynamic_objects
or str(dim) in self.dynamic_objects_rev
or self.has_name(str(dim))
or self.parse_dimension_expression(dim)
), (
f"dim={dim!r} (type={type(dim)}) not in found in "
f"{self.dynamic_objects}, self.dynamic_shapes={self.dynamic_shapes}, "
f"self._dynamic_alias={self._dynamic_alias}{self.get_debug_msg()}"
)
return True
def get_dynamic_dimension(self, dim: Any, keep_const: bool = True) -> Any:
if isinstance(dim, int):
if keep_const:
return np.array([dim], dtype=np.int64)
return self.make_initializer("", np.array([dim], dtype=np.int64))
assert isinstance(
dim, (str, self.torch.SymInt)
), f"Unexpected type {type(dim)} for dim={dim}{self.get_debug_msg()}"
if isinstance(dim, str):
if self.has_name(dim):
return dim
assert dim in self.dynamic_objects, (
f"Unable to find a dynamic object for {dim:r}, "
f"list={list(self.dynamic_objects)}"
f"{self.get_debug_msg()}"
)
assert dim in self.dynamic_dimensions_source, (
f"Unable to find a result to express dim {dim:r}, "
f"sources={list(self.dynamic_dimensions_source)}"
f"{self.get_debug_msg()}"
)
raise NotImplementedError(
f"Source is available for {dim!r}, "
f"source={self.dynamic_dimensions_source[dim]}"
)
name = self._torch_sym_int_to_str(dim)
assert name, f"Unable to expression a dynamic dimension{self.get_debug_msg()}"
if self.has_name(name):
return name
assert name in self.dynamic_objects, (
f"Unable to find a dynamic object for {name:r}, "
f"list={list(self.dynamic_objects)}"
f"{self.get_debug_msg()}"
)
assert name in self.dynamic_dimensions_source, (
f"Unable to find a result to express dim {name:r}, "
f"sources={list(self.dynamic_dimensions_source)}"
f"{self.get_debug_msg()}"
)
raise NotImplementedError(
f"Source is available for {dim!r}, source={self.dynamic_dimensions_source[dim]}"
)
def _get_dynamic_dimension(self, name: str, dim: int) -> Optional[str]:
if self.dynamic_shapes is None:
return None
if not self.dynamic_shapes or name not in self.dynamic_shapes:
return None
dyn = self.dynamic_shapes[name]
if dim not in dyn:
return None
v = dyn[dim]
st = str(type(v))
if "_Dim" in st or "_DerivedDim" in st:
name = v.__name__
else:
name = v
return name
[docs]
def add_dynamic_object(
self,
key: str,
value: Any,
name: Optional[str] = None,
dim: Optional[int] = None,
parse: bool = False,
):
"""
Registers a dynamic object such as a dynamic dimension.
:param key: string
:param value: SymInt, Dim, _DerivedDim
:param name: input name it comes from
:param dim: dimension for this dimension in input
:param parse: parse the expression add pieces of it as well
"""
assert not isinstance(
value, self.torch.export.dynamic_shapes._Dim
), f"Unexpected dimension type {type(value)} for key={key!r}{self.get_debug_msg()}"
self.dynamic_objects[key] = (
self.WrapSym(value)
if isinstance(value, (self.torch.SymInt, self.torch.SymFloat))
else value
)
if name is not None and name in self.dynamic_shapes and dim is not None:
dyn_shape = self.dynamic_shapes[name]
if dim in dyn_shape:
dyndim = dyn_shape[dim]
self._dynamic_alias[key] = dyndim.__name__
if parse:
self.register_dynamic_objects_from_dim(key)
else:
tokens = parse_expression_tokens(key)
for t in tokens:
if isinstance(t, str):
assert t in self.dynamic_objects, (
f"Token {t!r} from {key!r} is not registered "
f"among {list(self.dynamic_objects)}{self.get_debug_msg()}"
)
[docs]
def has_dynamic_object(self, name: str) -> bool:
"""Tells if a result is a dynamic object, `torch.SymInt` for torch."""
return name in self.dynamic_objects
def add_dynamic_objects_rev(self, key, name_value):
if key not in self.dynamic_objects_rev:
self.dynamic_objects_rev[key] = []
self.dynamic_objects_rev[key].append(name_value)
def _add_dynamic_example(self, name: str, value: Union[int, float]):
if name not in self._dynamic_examples:
self._dynamic_examples[name] = set()
self._dynamic_examples[name].add(value)
def _torch_sym_int(self, d, add: bool = False) -> Optional[Union[int, str, float]]:
assert isinstance(
d, (self.torch.SymInt, str, self.torch.SymFloat)
), f"unexpected type for d={d}, type={type(d)}"
value = None
try:
# don't use 'expr'
dyn_val = str(d.node._expr)
value = dyn_val
except AttributeError:
pass
if isinstance(d, (str, int, float)):
return d
if value is None:
# Is it an integer?
value = d.node.maybe_as_int()
else:
# maybe an expression which is a single integer
try:
return int(value)
except (TypeError, ValueError):
pass
if isinstance(value, int):
assert not isinstance(value, self.torch.SymInt)
return value
if value is None:
value = str(d)
if value in self._dynamic_alias:
value = self._dynamic_alias[value]
if value not in self.dynamic_objects_rev:
# The dynamic dimension does not seem to be registered.
# Maybe it is constant.
try:
val_int = d.node.maybe_as_int()
if val_int is not None:
self._add_dynamic_example(value, val_int)
return value
except (TypeError, ValueError):
pass
if value in self.dynamic_objects:
assert not isinstance(value, self.torch.SymInt)
return value
if (
value not in self.dynamic_objects_rev
and value not in self._known_value_shape
and add
):
self.add_dynamic_object(value, value)
self.add_dynamic_object_rev(value, value)
assert value in self.dynamic_objects_rev or value in self._known_value_shape, (
f"value={value!r}, unable to find dimension {d!r} ({type(d)}) "
f"(str(d)={str(d)!r}) in {self.dynamic_objects_rev} "
f"or {self._dynamic_alias} or {self._known_value_shape}"
f"{dir(d)}{self.get_debug_msg()}"
)
assert not isinstance(
value, self.torch.SymInt
), f"Unexpected type {type(value)} for d={d!r}"
if value in self.dynamic_objects_rev:
new_value = self.dynamic_objects_rev[value]
assert isinstance(new_value, list), (
f"Unexpected type {type(new_value)} for value={value!r}, d={d}"
f"{self.get_debug_msg()}"
)
assert len(new_value) >= 1, (
f"Unexpected number of items in {new_value}, value={value!r}, d={d}"
f"{self.get_debug_msg()}"
)
# We assume if len(new_value) > 1 that all names are equivalent.
# The graph is doing the same computation multiple times.
final = new_value[0]
assert isinstance(final, str) or (isinstance(final, tuple) and len(final) == 2), (
f"Unexpected type {type(final)}, final={final}, value={value}, d={d}"
f"new_value={new_value}, {self.get_debug_msg()}"
)
if isinstance(final, str):
# A formula
return final
# An alias
name = final[0]
assert isinstance(name, str), (
f"Unexpected type {type(name)}, name={final}, value={value}, d={d}, "
f"new_value={new_value}, {self.get_debug_msg()}"
)
return name
# Its value is in self._known_value_shape. We still return its name.
return value
[docs]
def verify_dynamic_shape(
self,
shape: Any,
name: Optional[str] = None,
add: bool = True,
) -> DYNAMIC_SHAPE:
"""
The implementation of this method should be revisited.
"""
if shape is None:
return None
if is_static_shape(shape):
return tuple(int(i) for i in shape)
new_shape = []
for dim, d in enumerate(shape):
if isinstance(d, (self.torch.SymInt, str)):
dyn_name = self._get_dynamic_dimension(name, dim)
if dyn_name is not None:
if add:
self.add_dynamic_object(dyn_name, dyn_name, parse=True)
new_shape.append(dyn_name)
continue
value = self._torch_sym_int(d, add=add)
assert (
value is not None
), f"Unexpected type {type(d)} in shape={shape}{self.get_debug_msg()}"
new_shape.append(value)
continue
if (
isinstance(d, self.torch.export.dynamic_shapes._Dim)
or "_DerivedDim" in str(d)
or "_Dim" in str(d)
):
raise NotImplementedError(
f"verify_dynamic_shape not yet implemented for type(d)={type(d)}, d={d}"
)
if isinstance(d, int):
new_shape.append(d)
continue
assert (
d is not None
), f"Unexpected type {type(d)} in shape={shape}{self.get_debug_msg()}"
assert all(isinstance(d, (int, str)) for d in new_shape), (
f"Issue with shape={new_shape}, types={[type(d) for d in new_shape]}"
f"{self.get_debug_msg()}"
)
return tuple(new_shape)
[docs]
def register_users(self, name: str, users: Iterable[str]):
"""
Registers users. This is not used except to check the conversion
is valid.
"""
assert (
name not in self._registered_users
), f"{name!r} is already registered{self.get_debug_msg()}"
self._registered_users[name] = set(users)
def _dynamic_to_str(
self, obj: Any, exc: bool = True, register_if_not_exist: bool = False
) -> Optional[str]:
if register_if_not_exist:
res = self._dynamic_to_str(obj, exc, register_if_not_exist=False)
if res is None:
return res
if res not in self.dynamic_objects:
self.add_dynamic_object(res, res, parse=register_if_not_exist)
return res
if isinstance(obj, str):
return obj
if isinstance(obj, self.torch.export.dynamic_shapes._DerivedDim):
return obj.__name__
if isinstance(obj, self.torch.export.dynamic_shapes._Dim):
return obj.__name__
if isinstance(obj, self.torch.SymInt):
i = obj.node._expr
if "sympy" in str(type(i)):
return str(i)
if exc:
raise AssertionError(
f"Object has {type(obj)} but could not find a dynamic interpretation"
)
return None
raise AssertionError(f"Unexpected type {type(obj)} to convert into string")
def _is_dynamic_dimension(self, dyn: Any) -> bool:
return isinstance(
dyn,
(
str,
self.torch.export.dynamic_shapes._DerivedDim,
self.torch.export.dynamic_shapes._Dim,
self.torch.SymInt,
),
)
def _fill_dynamic_alias(
self, dyn_shape: Optional[Tuple[Any, ...]], name: str
) -> Optional[Tuple[Any, ...]]:
if dyn_shape is None:
return None
res = []
for pos, k in enumerate(dyn_shape):
if not self._is_dynamic_dimension(k):
res.append(None)
continue
alias = self._dynamic_to_str(k, exc=False)
if alias is None:
res.append(None)
continue
if not self.dynamic_shapes or name not in self.dynamic_shapes:
res.append(None)
continue
ds = self.dynamic_shapes[name]
if pos not in ds:
res.append(None)
continue
dim = ds[pos]
sdim = self._dynamic_to_str(dim, register_if_not_exist=True)
if sdim != alias:
self._dynamic_alias[alias] = sdim
res.append(sdim)
return tuple(res)
def _make_tensor_input_finalize(self, name: str, shape: Any, dyn_shape: Any):
tuple_shape = tuple(shape)
assert len(tuple_shape) == len(
dyn_shape
), f"mismatch between shape={shape}, dynamic_shape={dyn_shape}"
for _idim, (a, b) in enumerate(zip(tuple_shape, dyn_shape)):
if isinstance(a, int) and isinstance(b, int):
assert a == b, (
f"Unexpected shape mismatch shape={shape}, "
f"dyn_shape={dyn_shape}{self.get_debug_msg()}"
)
continue
if isinstance(a, str) and isinstance(b, str):
assert a == b, (
f"Unexpected shape mismatch shape={shape}, "
f"dyn_shape={dyn_shape}{self.get_debug_msg()}"
)
sb = b
if sb not in self.dynamic_objects:
self.add_dynamic_object(sb, sb)
continue
self._dynamic_to_str(b, register_if_not_exist=True)
self._dynamic_to_str(a, register_if_not_exist=True)
[docs]
def make_tensor_output(
self,
name: Union[str, List[str]],
elem_type: Optional[int] = None,
shape: Optional[STATIC_SHAPE] = None,
indexed: bool = True,
is_dimension: Optional[bool] = None,
) -> Union[str, List[str]]:
"""
Adds a tensor output to the onnx graph.
:param name: name
:param elem_type: element type
:param shape: shape
:param indexed: the name must be indexed?
:param is_dimension: torch is using torch.SymInt to add a dynamic input
to the graph
:return: output name
"""
assert is_dimension is not None, (
f"is_dimension must be specified for output name={name!r}, "
f"elem_type={elem_type}, shape={shape!r}."
)
if isinstance(name, list):
assert not is_dimension, f"name={name!r} not compatible with is_dimension=True"
res = []
for n in name:
res.append(self.make_tensor_output(n, elem_type, shape))
assert self.has_name(
n
), f"Output {n!r} among {name} not found{self.get_debug_msg()}"
return res
assert (
not indexed or "_" in name
), f"Name {name!r} is not indexed like 'output_0'{self.get_debug_msg()}"
assert (is_dimension and "_dim_" in name) or (
not is_dimension and "_dim_" not in name
), (
f"Inconsistence for input {name!r}, "
f"elem_type={elem_type}, shape={shape!r}, "
f"is_dimension={is_dimension}"
)
elem_type = _get_type(elem_type, False)
if not self.as_function and elem_type == 0:
raise RuntimeError(f"Undefined element type for {name!r}.")
dyn_shape = self.verify_shape(shape, name=name, elem_type=elem_type)
node = oh.make_tensor_value_info(name, elem_type, dyn_shape)
node.doc_string += ".\n" + self._info_shape_type([name]) + "\n"
self.outputs.append(node)
assert self.has_name(name), f"Output {name!r} not found{self.get_debug_msg()}"
if self.verbose > 1:
print(
f"[GraphBuilder-{self._hash()}.make_tensor_output] "
f"{name}[{elem_type}: {dyn_shape}]"
)
if dyn_shape:
self.set_shape(name, dyn_shape)
if elem_type:
self.set_type(name, elem_type)
return name
[docs]
def select_outputs(self, output_names: List[str]):
"""
Selects new outputs. The type is assumed to be unknown.
The method only wipes out the outputs to replace them by
others. It assumes the unused nodes are removed afterwards.
:param output_names: new outputs
"""
new_outputs = []
for name in output_names:
if self.has_type(name) and self.has_rank(name):
if self.has_shape(name):
out = oh.make_tensor_value_info(
name, self.get_type(name), self.get_shape(name)
)
if self.verbose > 1:
print(
f"[GraphBuilder-{self._hash()}.make_tensor_output] "
f"{name}[{self.get_type(name)}:R{self.get_shape(name)}]"
)
else:
out = oh.make_tensor_value_info(
name, self.get_type(name), [None] * self.get_rank(name)
)
if self.verbose > 1:
print(
f"[GraphBuilder-{self._hash()}.make_tensor_output] "
f"{name}[{self.get_type(name)}:R{self.get_rank(name)}]"
)
else:
out = oh.make_value_info(name, TypeProto())
if self.verbose > 1:
print(f"[GraphBuilder-{self._hash()}.make_tensor_output] {name}")
new_outputs.append(out)
self.outputs = new_outputs
def verify_shape(
self,
shape: Optional[DYNAMIC_SHAPE],
elem_type: int,
name: Optional[str] = None,
) -> Optional[DYNAMIC_SHAPE]:
assert isinstance(
elem_type, int
), f"elem_type must be an integer not {type(elem_type)}"
assert shape is None or isinstance(
shape, tuple
), f"Shape must be a tuple not {type(shape)}"
if shape is None:
return None
assert is_static_shape(shape) or self.is_dynamic_shape(shape, allow_none=True), (
f"Shape={shape} is not a shape (type={[type(i) for i in shape]}), "
f"name={name!r}, elem_type={elem_type}{self.get_debug_msg()}"
)
new_shape = self.verify_dynamic_shape(shape, name=name)
return new_shape
def _get_symbol(self, i: str) -> str:
k = 0
if self.has_type(i):
k += 1
if self.has_rank(i):
k += 2
if self.has_shape(i):
k += 4
return k
def _debug_string_inputs(
self, inputs: List[str], outputs: List[str], align: Optional[int] = None
) -> str:
"""
Meaning:
- ``"-"``: (0) none
- ``"T"``: (1) type
- ``"R"``: (2) rank
- ``"U"``: (3) rank + type
- ``"S"``: (4) shape
- ``"V"``: (5) shape + type
- ``"W"``: (6) shape + rank
- ``"#"``: (7) shape + type + rank
"""
st = ""
c = "-TRUSVW#"
for i in inputs:
st += c[self._get_symbol(i)]
st += ":"
for o in outputs:
st += c[self._get_symbol(o)]
if align and len(st) < align:
st += " " * (align - len(st))
return st
def _check_op_type(
self,
op_type: str,
inputs: List[str],
outputs: List[str],
domain: str,
name: str,
attributes: Optional[List[AttributeProto]] = None,
**kwargs: Dict[str, Any],
):
assert (
not op_type.startswith("Reduce")
or domain != ""
or (len(inputs) == 2 and "axes" not in kwargs)
or len(inputs) == 1
), (
f"Operator {op_type!r} defines twice the axes, kwargs={kwargs}, "
f"len(inputs)={len(inputs)}{self.get_debug_msg()}"
)
assert (
op_type != "Cast"
or domain != ""
or ("to" in kwargs and kwargs["to"] is not None)
or (attributes is not None and "to" in set(att.name for att in attributes))
), (
f"Operator Cast needs arguments to but kwargs={kwargs}, name={name!r}"
f"{self.get_debug_msg()}"
)
assert op_type != "Concat" or domain != "" or len(inputs) > 1, (
f"Concatenation of zero or one input is not necessary, "
f"len(inputs)={len(inputs)}{self.get_debug_msg()} "
)
if self.main_opset <= 11:
assert op_type != "Squeeze" or domain != "" or len(inputs) == 1, (
f"Operator Squeeze is not correclty specified for opset "
f"{self.main_opset}, inputs={inputs}, kwargs={kwargs}, "
f"atts={attributes}{self.get_debug_msg()}"
)
else:
n_entries = len(inputs) + len(attributes or []) + len(kwargs)
assert op_type != "Squeeze" or domain != "" or n_entries in (1, 2), (
f"Operator Squeeze is not correclty specified for opset "
f"{self.main_opset}, n_entries={n_entries}, "
f"inputs={inputs}, kwargs={kwargs}, "
f"atts={attributes}{self.get_debug_msg()}"
)
assert (
op_type not in {"NegXplus1", "ReplaceZero"} or domain != ""
), f"Type={op_type!r} and domain {domain!r} mismatch{self.get_debug_msg()}"
[docs]
def do_not_remove(self, node: NodeProto) -> bool:
"""Tells if a node should be removed or not."""
return node.name.startswith("_DONOTREMOVE_")
[docs]
def make_node(
self,
op_type: str,
inputs: Union[str, List[str]],
outputs: Union[int, List[str], str] = 1,
domain: str = "",
attributes: Optional[List[AttributeProto]] = None,
check: Optional[bool] = None,
name: Optional[str] = None,
sts: Optional[Dict[str, Any]] = None,
do_not_remove: bool = False,
insert_position: Optional[int] = None,
**kwargs,
) -> Union[str, List[str]]:
"""
Adds a node in the graph.
:param op_type: operator type
:param inputs: input names
:param outputs: output names, may be None, in that case,
the builder chooses them for the user
:param domain: domain
:param attributes: list of attributes to add as AttributeProto
:param check: do some verification
:param name: node name
:param sts: if not specified, tries to set the shape and the type of
the new results aftr the node is added, it is not possible
for every node, there is no tool which determines the output shape
of just one node
:param do_not_remove: prevent this node from being removed
:param insert_position: insert the node at the end (None) or
at the top (HEAD).
:param kwargs: additional attributes to add the node
:return: output names
"""
assert name is not None and not name.startswith("None"), (
f"It is good practice to give every node a name so that is "
f"easier to see where this node is created but name={name!r} "
f"and op_type={op_type!r}."
)
assert (
not kwargs or not attributes
), f"Only attributes or kwargs can be filled for node {op_type!r}."
if isinstance(inputs, tuple):
inputs = list(inputs)
if isinstance(outputs, int):
if outputs < 1:
raise ValueError(f"outputs={outputs} must be > 0.")
assert isinstance(op_type, str), f"Unexpected type {type(op_type)}: {op_type}"
lower = op_type.lower()
output_names = [self.unique_name(f"_onx_{lower}{i}") for i in range(outputs)]
elif isinstance(outputs, str):
output_names = [outputs]
else:
output_names = outputs
if isinstance(inputs, str):
inputs = [inputs]
inputs, kwargs = self._partial_rewrite_opset_version(
op_type, inputs, kwargs, domain=domain
)
if self.verbose == 2:
print(
f"[GraphBuilder-{self._hash()}.make_node]"
f"[{self._debug_string_inputs(inputs, output_names)}] "
f"{op_type}: {inputs}->{outputs}"
)
if check is not False:
for i in inputs:
assert isinstance(i, str), (
f"Unexpected type {type(i)} in {inputs}, op_type={op_type!r}, "
f"name={name!r}, {self.get_debug_msg()}"
)
if i == "":
# Optional input.
continue
assert self.has_name(i), (
f"Input {i!r} does not exist for operator {op_type!r}, "
f"inputs={inputs}, outputs={outputs}, name={name!r} "
f"({self._hash()}){self.get_debug_msg()}"
)
for i in output_names:
assert isinstance(
i, str
), f"Unexpected type {type(i)} in {output_names}{self.get_debug_msg()}"
if i == "":
# Optional output.
continue
assert not self.has_name(i), (
f"Output {i!r} already exists for operator {op_type!r} "
f"({self._hash()}){self.get_debug_msg()}"
)
if check is True:
for i in inputs:
assert self.has_shape(i), f"Input {i!r} has no known shape."
assert self.has_type(i), f"Input {i!r} has no known type."
if do_not_remove:
name = self.unique_node_name(f"_DONOTREMOVE_{name or ''}")
elif name:
name = self.unique_node_name(name)
self._check_op_type(
op_type,
inputs,
outputs,
domain=domain,
name=name,
attributes=attributes,
**kwargs,
)
# break?
# if op_type == "ReduceSum":
# raise AssertionError(f"MANUAL BREAK{self.get_debug_msg()}")
# next
try:
node = oh.make_node(
op_type, inputs, output_names, domain=domain, name=name, **kwargs
)
except TypeError as e:
iti = [type(i) for i in inputs]
ito = [type(o) for o in outputs] if isinstance(outputs, (tuple, list)) else outputs
raise TypeError(
f"A node {op_type!r} cannot be created with "
f"inputs={inputs} (types={iti}), outputs={outputs} (types={ito}), "
f"domain={domain!r}, kwargs={kwargs}."
) from e
assert len(node.output) == len(set(node.output)) or "" in node.output, (
f"Repeated outputs for node {node.op_type}({', '.join(node.input)}) -> "
f"{', '.join(node.output)}"
)
if attributes:
node.attribute.extend(attributes)
if node.domain == "" and node.op_type in {"Constant", "ConstantOfShape"}:
if len(node.attribute) == 1 and node.attribute[0].name == "value":
t = node.attribute[0].t
size = np.prod(t.dims)
assert size < self.optimization_options.constant_size, (
f"A node Constant is created with a size {size} greater than "
f"the limit {self.optimization_options.constant_size}"
f"{self.get_debug_msg()}"
)
# A exact constant may be already existing,
# In that case, we just return an identity node.
origin = self.is_exact_same_constant(node)
if origin is not None:
if self.verbose > 2:
print(
f"[GraphBuilder-{self._hash()}.make_node] "
f"duplicated constant detected for "
f"{node.op_type}: {node.input}->{node.output}"
)
node = oh.make_node(
"Identity",
[origin.output[0]],
[node.output[0]],
name=self.unique_node_name(".make_node"),
)
self.constants_alias_[node.output[0]] = origin.output[0]
else:
self.add_constant_node(node)
# constant handling, shape, type
self._make_node_set_type_shape_constant(node, sts=sts)
if self.verbose > 3:
print(
f"[GraphBuilder-{self._hash()}.make_node] "
f"[{self._debug_string_inputs(node.input, output_names)}] "
f"{node.op_type}: {node.input}->{node.output}"
)
# shape inference
shape_set = self.simple_update_value_shape_with_node(node)
# add the node
for o in node.output:
if o == "":
continue
self.set_name(o, f"make_node_{op_type}_{o}")
if insert_position == "HEAD":
self.nodes.insert(0, node)
else:
self.nodes.append(node)
if not shape_set:
# second try
self._make_node_set_type_shape(node)
node.doc_string += ".\n" + self._info_shape_type(node.output) + "\n"
if len(output_names) == 1:
return output_names[0]
return output_names
def _info_shape_type(self, outputs: List[str]) -> str:
rows = []
for o in outputs:
st = f"-T{self.get_type(o)}:" if self.has_type(o) else "T?:"
if self.has_shape(o):
st += "x".join(map(str, self.get_shape(o)))
elif self.has_rank(o):
st += f"R{self.get_rank(o)}"
rows.append(st)
return "/".join(rows)
@property
def last_added_node(self):
return self.nodes[-1] if self.nodes else None
def _partial_rewrite_opset_version(
self, op_type: str, inputs: List[str], kwargs: Dict[str, Any], domain: str
) -> Tuple[List[str], Dict[str, Any]]:
if domain == "":
opset = self.opsets[""]
if op_type == "Unsqueeze":
if opset < 13 and len(inputs) == 2:
assert isinstance(
inputs[1], (list, np.ndarray)
), f"Unexpected type for axis={inputs[1]} and operator Unsqueeze"
kwargs["axes"] = (
[inputs[1]] if isinstance(inputs[1], int) else inputs[1].tolist()
)
inputs = inputs[:1]
elif opset >= 13 and len(inputs) == 1:
name = self.make_initializer("", np.array(kwargs["axes"], dtype=np.int64))
inputs.append(name)
del kwargs["axes"]
return inputs, kwargs
def _make_node_set_type_shape_constant(
self, node: NodeProto, sts: Optional[Dict[str, Any]]
):
if node.domain != "":
return
if all(self.is_constant(i) for i in node.input):
for o in node.output:
self.update_node_constant(o, node)
if node.op_type == "Constant":
assert (
len(node.attribute) == 0
or node.attribute[0].name != "value"
or node.attribute[0].type != AttributeProto.GRAPH
), f"{node}"
if len(node.attribute) == 1 and node.attribute[0].name == "value":
size = np.prod(node.attribute[0].t.dims)
else:
size = len(node.SerializeToString())
assert size < self.optimization_options.constant_size, (
f"A node Constant holds a tensor bigger than "
f"the constant: {size} >= {self.optimization_options.constant_size}."
)
k = node.output[0]
self.update_node_constant(k, node)
node.doc_string += ":constant-3:"
shape = self._get_tensor_shape(node)
dtype = self._get_tensor_type(node)
self.set_shape(k, shape)
self.set_type(k, dtype)
if self.verbose > 2 or np.prod(shape) > 100:
print(f"[GraphBuilder-{self._hash()}.make_node] {k}[{dtype}: {shape}]")
elif node.op_type == "ConstantOfShape":
if len(node.attribute) == 1 and node.attribute[0].name == "value":
itype = node.attribute[0].t.data_type
else:
itype = TensorProto.FLOAT
self.set_type(node.output[0], itype)
if self.is_constant(node.input[0]):
value = self.get_constant(
node.input[0], computed_value=True, as_shape=True, exc=False
)
if value is not None:
self.set_shape(node.output[0], value)
node.doc_string += ":constant-9:"
elif node.op_type == "Identity":
if self.has_shape(node.input[0]):
self.set_shape(node.output[0], self.get_shape(node.input[0]))
elif self.has_rank(node.input[0]):
self.set_rank(node.output[0], self.get_rank(node.input[0]))
if self.has_type(node.input[0]):
self.set_type(node.output[0], self.get_type(node.input[0]))
if self.is_constant(node.input[0]):
self.update_node_constant(node.output[0], node)
node.doc_string += ":constant-4:"
elif node.op_type == "Expand":
if self.has_type(node.input[0]):
self.set_type(node.output[0], self.get_type(node.input[0]))
if (
self.has_shape(node.input[0])
and is_static_shape(self.get_shape(node.input[0]))
and self.is_constant(node.input[1])
):
cst, _ = self.compute_constant(node.input[1], exc=False, only_array=True)
if cst is not None:
assert not isinstance(
cst, self.torch._subclasses.fake_tensor.FakeTensor
), (
f"self.compute_constant returns a FakeTensor for {node.input[1]!r}"
f"\n{self.pretty_text()}"
)
assert (
not self.has_rank(node.input[1]) or self.get_rank(node.input[1]) == 1
), (
f"Unexpected rank {self.get_rank(node.input[1])} for {node.input[1]!r}"
f"{self.get_debug_msg()}"
)
assert len(cst.shape) == 1 and cst.min() > 0, (
f"Unexpected shape {cst.shape} "
f"for computed constant {node.input[1]!r}, "
f"cst={cst}{self.get_debug_msg()}"
)
shape = self.get_shape(node.input[0])
new_shape = tuple(int(i) for i in cst)
if len(shape) < len(new_shape):
shape = (1,) * (len(new_shape) - len(shape)) + shape
self.set_shape(
node.output[0], tuple(max(i, j) for i, j in zip(shape, new_shape))
)
elif node.op_type == "Reshape":
if self.has_type(node.input[0]):
self.set_type(node.output[0], self.get_type(node.input[0]))
if self.is_constant(node.input[1]):
cst, _ = self.compute_constant(
node.input[1], exc=False, only_array=True, allow_empty=True
)
if cst is not None:
shape_cst = tuple(int(i) for i in cst)
if -1 in shape_cst:
if self.has_shape(node.input[0]):
sh = self.get_shape(node.input[0])
if is_static_shape(sh):
self.set_shape(node.output[0], _reshape_shape(sh, shape_cst))
node.doc_string += ":constant-7:"
else:
self.set_shape(node.output[0], shape_cst)
node.doc_string += ":constant-7:"
elif node.op_type == "Shape":
self.set_type(node.output[0], TensorProto.INT64)
if self.has_shape(node.input[0]) and len(node.attribute) == 0:
shape = self.get_shape(node.input[0])
self.set_shape(node.output[0], (len(shape),))
else:
self.set_rank(node.output[0], 1)
if self.is_constant(node.input[0]):
self.update_node_constant(node.output[0], node)
node.doc_string += ":constant-2:"
elif node.op_type == "Size":
self.set_type(node.output[0], TensorProto.INT64)
self.set_shape(node.output[0], (1,))
if self.is_constant(node.input[0]):
self.update_node_constant(node.output[0], node)
node.doc_string += ":constant-2s:"
elif not sts:
if node.op_type == "GatherElements":
if self.has_type(node.input[0]):
self.set_type(node.output[0], self.get_type(node.input[0]))
if self.has_shape(node.input[1]):
self.set_shape(node.output[0], self.get_shape(node.input[1]))
elif self.has_rank(node.input[1]):
self.set_rank(node.output[0], self.get_rank(node.input[1]))
[docs]
def update_node_constant(self, name: str, node: NodeProto) -> bool:
"""
Updates a constant NodeProto.
"""
assert isinstance(name, str), f"Unexpected type {type(name)} for name"
assert node is None or isinstance(
node, NodeProto
), f"Unexpected type {type(node)} for name={name!r}"
if self.verbose > 2:
print(
f"[GraphBuilder.update_node_constant] new constant "
f"{name!r}, node={None if node is None else node.op_type}"
)
self.constants_[name] = node
return True
[docs]
def get_attribute(
self, node: NodeProto, att_name: str, exc: bool = True
) -> Optional[AttributeProto]:
"""
Returns an attribute for a node.
"""
for att in node.attribute:
if att.name == att_name:
return att
assert not exc, (
f"Unable to find attribute {att_name!r} for node "
f"type {node.op_type!r} in node {node}"
)
return None
[docs]
def get_attributes_with_default(self, node: NodeProto, **default_values) -> Dict[str, Any]:
"""
Returns int or float attributes. If missing, the default value is returned.
:param node: node
:param default_values: default values
"""
res = {}
for att in node.attribute:
if att.name in default_values:
def_val = default_values[att.name]
if isinstance(def_val, int):
res[att.name] = att.i
elif isinstance(def_val, float):
res[att.name] = att.f
elif isinstance(def_val, str):
res[att.name] = att.s
else:
raise TypeError(
f"Unexpected type {type(def_val)} for attribute name {att.name!r}, "
f"attribute={att}"
)
for k, v in default_values.items():
if k not in res:
res[k] = v
return res
def _make_node_set_type_shape(self, node: NodeProto):
if node.domain != "":
node.doc_string += "#Io1"
set_shape_type_custom(self, node)
else:
if node.input and not self.has_type(node.input[0]):
# It is probably coming from an inlined function.
return
node.doc_string += "#Io2"
set_shape_type_op_any(self, node)
[docs]
def make_nodes(
self,
builder: "GraphBuilder",
input_names: List[str],
output_names: List[str],
prefix: str = "",
function_options: Optional[FunctionOptions] = None,
optimize: bool = False,
) -> Union[str, List[str]]:
"""
Appends all nodes and initializers from another builder.
Handles the renaming of results.
The content stored in 'builder' is modified inplace to avoid copying.
:param builder: other builder
:param input_names: input names
:param output_names: output names
:param prefix: prefix all name from this builder if `function_options` is None
:param function_options: defines how to create a local function if needed
:param optimize: optimize the function
:return: output names
"""
if function_options is not None and function_options.export_as_function:
new_inits, (fdomain, fname) = self.make_local_function(
builder,
function_options=function_options,
optimize=optimize,
)
self.make_node(
fname,
[*input_names, *new_inits],
output_names,
domain=fdomain,
check=False,
name=fname,
)
# Shape information, needs to handle multiple outputs
# hopefully, the interpreter fills this information with what it knows
# fproto = self.functions[fdomain, fname]
# for o, no in zip(fproto.output, output_names):
# if builder.has_shape(o):
# shape = builder.get_shape(o)
# if None in shape:
# self.set_rank(no, len(shape))
# else:
# self.set_shape(no, shape)
# if builder.has_type(o):
# self.set_type(no, builder.get_type(o))
if fdomain not in self.opsets:
self.opsets[fdomain] = 1
else:
renaming = {}
for init, value in builder.initializers_dict.items():
if init in builder._parameter_renaming:
# Its copy already exists.
continue
if init in builder._parameter_norename:
name = init
assert not self.has_name(init), (
f"Parameter {init!r} must be renamed as another one "
f"already exists{self.get_debug_msg()}"
)
else:
name = self.unique_name(f"{prefix}{init}")
renaming[init] = name
if isinstance(value, TensorProto):
value.name = name
else:
assert "FakeTensor" not in str(type(value)), (
f"FakeTensor {name!r} cannot be an initializer {type(value)}"
f"{self.get_debug_msg()}"
)
self.add_initializer(
name,
value,
itype=builder._known_types[init],
shape=builder._known_shapes[init],
)
for k, v in builder.dynamic_objects.items():
self.make_dynamic_object(k, v)
assert len(input_names) == len(
builder.inputs
), f"Inconsistency between input_names={input_names} and inputs={builder.inputs}"
for name, inp in zip(input_names, builder.inputs):
new_name = self.unique_name(f"{prefix}{inp.name}")
renaming[inp.name] = new_name
self.make_node("Identity", [name], [new_name], name=".make_nodes")
for node in builder.nodes:
assert name is not None and not name.startswith("None"), (
f"It is good practice to give every node a name so that is "
f"easier to see where this node is created but name={name!r} "
f"and op_type={node.op_type!r}."
)
new_inputs = [renaming[i] for i in node.input]
new_outputs = [self.unique_name(f"{prefix}{o}") for o in node.output]
for o, no in zip(node.output, new_outputs):
renaming[o] = no
self.make_node(
node.op_type,
new_inputs,
new_outputs,
domain=node.domain,
attributes=node.attribute,
check=False,
name=node.name,
)
for o, no in zip(node.output, new_outputs):
if builder.has_shape(o):
shape = builder.get_shape(o)
if None in shape:
self.set_rank(no, len(shape))
else:
self.set_shape(no, shape)
if builder.has_type(o):
self.set_type(no, builder.get_type(o))
assert len(output_names) == len(builder.outputs), (
f"Inconsistency between output_names={output_names} and "
f"outputs={builder.outputs}, renaming={renaming}{self.get_debug_msg()}"
)
for name, out in zip(output_names, builder.outputs):
self.make_node("Identity", [renaming[out.name]], [name], name=".make_nodes")
# opsets and domains
for o, v in builder.opsets.items():
if o in self.opsets:
assert self.opsets[o] == builder.opsets[o], (
f"Opset mismatch for domain {o!r}, "
f"{self.opsets[o]} != {builder.opsets[o]}."
)
continue
self.opsets[o] = v
if len(output_names) == 1:
return output_names[0]
return output_names
def _build_large_initializers(
self, external_threshold: int, full_parameter_name: bool = True
):
assert isinstance(
external_threshold, int
), f"Unexpected type {type(external_threshold)} for external_threshold"
new_inits = {}
large_inits = {}
for k, v in self.initializers_dict.items():
if self._parameter_renaming and (
(full_parameter_name and k in self._parameter_renaming)
or (not full_parameter_name and k not in self._parameter_norename)
):
# Those parameters are present under another name already.
continue
itype = self.get_type(k)
shape = self.get_shape(k)
size = np.prod(shape) * self.elem_size(itype)
if size < external_threshold:
new_inits[k] = v
else:
location = f"#{k}"
nt = make_large_tensor_proto(location, k, itype, shape)
new_inits[k] = nt
large_inits[location] = v
return new_inits, large_inits
def _build_initializers(
self,
large_model: bool,
switch_low_high: bool,
external_threshold: Union[bool, int],
full_parameter_name: bool = True,
) -> Tuple[List[TensorProto], Dict[str, TensorProto]]:
"""
Builds initializers.
:param large_model: build with a large container
:param switch_low_high: invert low, high precision
:param external_threshold: size to use when moving a tensor to the list of tensors
stored outside the model, if can be False for none of them, true for all of them
or a number, if the threshold is specified and large_model is False,
then all tensors above this threshold are ignored
:param full_parameter_name: keeps the full name for the parameters or not
:return: a list of tensors to stored in the model,
another list to tensors stored outside the model
"""
if self.verbose:
begin = time.perf_counter()
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"start with {len(self.initializers_dict)} initializers, "
f"large_model={large_model}, external_threshold={external_threshold}"
)
init_dict, large_inits = (
self._build_large_initializers(
external_threshold, full_parameter_name=full_parameter_name
)
if large_model
else (self.initializers_dict, {})
)
if switch_low_high:
# Let's try to minimize the time.
if self.verbose:
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"switch low/high order"
)
initializer = []
for k, v in init_dict.items():
if self._parameter_renaming and (
(full_parameter_name and k in self._parameter_renaming)
or (not full_parameter_name and k not in self._parameter_norename)
):
# Those parameters are present under another name already.
continue
if isinstance(v, TensorProto):
if self.verbose > 1:
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"TensorProto-{k}:{v.data_type}[{tuple(v.dims)}]"
)
initializer.append(v)
continue
if self.verbose > 1:
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"<{v.__class__.__name__}>-{k}:{v.dtype}[{v.shape}]"
)
if isinstance(v, np.ndarray):
itype = dtype_to_tensor_dtype(v.dtype)
if itype in {
TensorProto.BOOL,
TensorProto.STRING,
TensorProto.UNDEFINED,
TensorProto.COMPLEX64,
TensorProto.COMPLEX128,
getattr(TensorProto, "UINT4", 0),
getattr(TensorProto, "INT4", 0),
}:
t = onh.from_array(v, name=k)
initializer.append(t)
continue
from_np = True
elif isinstance(v, np.float32):
# This should not happen.
t = onh.from_array(np.array([v], dtype=np.float32), name=k)
initializer.append(t)
continue
elif isinstance(v, np.float16):
# This should not happen.
t = onh.from_array(np.array([v], dtype=np.float16), name=k)
initializer.append(t)
continue
else:
assert isinstance(
v, self.torch.Tensor
), f"tensor {k!r} has un unexpected type {type(v)}"
assert "FakeTensor" not in str(
type(v)
), f"tensor {k!r} cannot be a FakeTensor: {type(v)}"
from_np = False
itype = dtype_to_tensor_dtype(v.dtype)
# How to avoid a copy?
if from_np:
tensor = TensorProto()
tensor.name = k
tensor.dims.extend(v.shape)
tensor.data_type = itype
tensor.raw_data = v.tobytes()
else:
with self.maybe_disable_fake_tensor_mode():
tensor = proto_from_array(v, name=k, verbose=self.verbose)
initializer.append(tensor)
if self.verbose:
begin = time.perf_counter()
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"done in {time.perf_counter() - begin}s "
f"with {len(initializer)} initializers, "
f"{len(large_inits)} large initializers"
)
return initializer, large_inits
assert (
not self.large_model
), "_build_initializers not implemented when large_model is True"
large_inits = {}
res = []
for k, v in sorted(init_dict.items()):
if self._parameter_renaming and (
(full_parameter_name and k in self._parameter_renaming)
or (not full_parameter_name and k not in self._parameter_norename)
):
# Those parameters are present under another name already.
continue
if isinstance(v, TensorProto):
res.append(v)
continue
if isinstance(v, self.torch.Tensor):
# no string tensor
t = self.from_array(v, name=k)
res.append(t)
continue
if isinstance(v, np.ndarray):
if self.verbose > 2 and np.prod(v.shape) > 100:
print(
f"[GraphBuilder-{self._hash()}._build_initializers]"
f"onh.from_array:{k}:{v.dtype}[{v.shape}]"
)
t = onh.from_array(v, name=k)
res.append(t)
continue
raise TypeError(
f"Unable to convert initializer {k!r} with type "
f"{type(v)} into a TensorProto."
)
if self.verbose:
begin = time.perf_counter()
print(
f"[GraphBuilder-{self._hash()}._build_initializers] "
f"done in {time.perf_counter() - begin}s "
f"with {len(res)} initializers, "
f"{len(large_inits)} large initializers"
)
return res, large_inits
[docs]
def get_initializer_size(self, name: str) -> int:
"""
Returns the size of an initializer.
:param name: name
:return: size
"""
assert name in self.initializers_dict, f"Initializer {name!r} was not found."
init = self.initializers_dict[name]
if hasattr(init, "numel"):
# torch.Tensor
return init.numel()
if hasattr(init, "size"):
# numpy array
return init.size
if hasattr(init, "size"):
# TensorProto
return np.prod(init.dims)
raise AssertionError(f"Unexpected type {type(init)} for initializer {name!r}")
[docs]
def get_debug_msg(self, limit: int = 1000) -> str:
"""
Returns a string providing as much information as possible
to help the developper understand why a conversion failed.
:param limit: limit the string if the model is big
:return: many pieces of informations about the on going conversion
"""
def _align(s, length):
if len(s) < length:
s += " " * (length - len(s))
return s
def _dtype(t):
if hasattr(t, "dtype"):
return t.dtype
if hasattr(t, "data_type"):
return t.data_type
raise RuntimeError(f"dtype unknown for type {type(t)}-{t}.")
def _shape(t):
if hasattr(t, "shape"):
return t.dtype
if hasattr(t, "dims"):
return tuple(t.dims)
raise RuntimeError(f"dtype unknown for type {type(t)}-{t}.")
def _size(t):
if hasattr(t, "numel"):
return t.numel()
if hasattr(t, "size"):
return t.size
if hasattr(t, "dims"):
return np.prod(tuple(t.dims))
raise RuntimeError(f"Size unknown for type {type(t)}-{t}.")
def _values(t):
if hasattr(t, "detach"):
return t.detach().cpu().flatten().tolist()
if hasattr(t, "size"):
return t.ravel().tolist()
if hasattr(t, "dims"):
a = onh.to_array(t)
return a.ravel().tolist()
raise RuntimeError(f"Values unknown for type {type(t)}-{t}.")
rows = ["", "--DEBUG--"]
hs = self._hash()
rows.append(
f"[GraphBuilder-{hs}] Message starts, there are "
f"{len(self.initializers_dict)} initializers, "
f"{len(self.nodes)} nodes, {len(self.inputs)} inputs, "
f"{len(self.inputs)} outputs."
)
rows.append("--LOCAL FUNCTIONS--")
for k, v in self.functions.items():
rows.append(f"{k[0]},{k[1]}({v.input}) -> {v.output}")
if self.constraints_:
rows.append("--CONSTRAINTS--")
for a, b in sorted(self.constraints_.items()):
rows.append(f"{a} = {b}")
rows.append("--PARAMETERS--")
rows.append("dynamic_examples=")
for i, (k, v) in enumerate(sorted(self._parameter_renaming.items())):
rows.append(f" {k} = {v!r}")
if i >= 10000:
break
rows.append("--SHAPE--")
rows.append("dynamic_examples=")
for i, (k, v) in enumerate(sorted(self._dynamic_examples.items())):
try:
rows.append(f" {k} = {v!r}")
except AttributeError:
rows.append(f" {k} = ERR: {type(v)!r}:{getattr(v, 'node', 'node=?')!r}")
if i >= 10000:
break
rows.append("dynamic_objects=")
for i, (k, v) in enumerate(sorted(self.dynamic_objects.items())):
try:
rows.append(f" {k} = {v!r}")
except AttributeError:
rows.append(f" {k} = ERR: {type(v)!r}:{getattr(v, 'node', 'node=?')!r}")
if i >= 10000:
break
rows.append("dynamic_objects_rev=")
for i, (k, v) in enumerate(sorted(self.dynamic_objects_rev.items())):
if isinstance(v, (list, tuple)):
rows.append(f" {k!r} = {type(v)}")
for vv in v:
if isinstance(vv, tuple):
rows.append(" tuple")
for vvv in vv:
try:
rows.append(f" {vvv!r}")
except AttributeError:
rows.append(
f" ERR**: {type(vvv)!r}:"
f"{getattr(vvv, 'node', 'node=?')!r}"
)
else:
try:
rows.append(f" {vv!r}")
except AttributeError:
rows.append(
f" ERR*: {type(vv)!r}:"
f"{getattr(vv, 'node', 'node=?')!r}"
)
else:
try:
rows.append(f" {k} = {v!r}")
except AttributeError:
rows.append(f" {k} = ERR-: {type(v)!r}:{getattr(v, 'node', 'node=?')!r}")
if i >= 10000:
break
rows.append(
f"dynamic_dimensions_source="
f"{pprint.pformat(self.dynamic_dimensions_source)[:10000]}"
)
rows.append(f"dynamic_alias={pprint.pformat(self._dynamic_alias)[:10000]}")
rows.append(f"dynamic_shapes={pprint.pformat(self.dynamic_shapes)[:10000]}")
rows.append(f"_known_value_shape={pprint.pformat(self._known_value_shape)[:10000]}")
rows.append(f"_known_types={pprint.pformat(self._known_types)[:10000]}")
rows.append(f"_known_shapes={pprint.pformat(self._known_shapes)[:10000]}")
rows.append(
f"_known_constants={pprint.pformat(list(sorted(self.constants_))[:10000])}"
)
reminaing_ranks = {
k: v for k, v in self._known_ranks.items() if k not in self._known_shapes
}
rows.append(f"_known_ranks={pprint.pformat(reminaing_ranks )[:10000]}")
rows.append("--TORCH-USERS--")
for k, v in sorted(self._registered_users.items()):
rows.append(f"{k} -> {v}")
rows.append("--TORCH-SHAPES--")
for kk, vv in self._known_torch_value.items():
rows.append(
f"{kk}: {vv} --- "
f"{self.get_type(kk) if self.has_type(kk) else ''}:"
f"{self.get_rank(kk) if self.has_rank(kk) else ''}:"
f"{self.get_shape(kk) if self.has_shape(kk) else ''}:"
)
if len(rows) > limit:
rows.append("...")
break
rows.append("--ONNX--")
for k, v in self._debug_msg.items():
rows.append(f"-- {k} --")
rows.append(pprint.pformat(v) if isinstance(v, dict) else str(v))
if len(rows) > limit:
rows.append("...")
break
rows.append("--")
for io in self.inputs:
shh = _nice_shape(io.type.tensor_type.shape)
rows.append(
f"[GraphBuilder-{hs}.make_tensor_input] {io.name}"
f"[{io.type.tensor_type.elem_type}:{shh}]"
)
for name, init in self.initializers_dict.items():
sval = "" if _size(init) > 5 else f":{_values(init)}"
rows.append(
f"[GraphBuilder-{hs}.make_initializer] "
f"{name}[{_dtype(init)}:{_shape(init)}{sval}]"
)
if len(rows) > limit:
rows.append("...")
break
for node in self.nodes:
if node is None:
continue
if node.op_type == "Cast":
ext = f">{node.attribute[0].i}"
else:
ext = ""
rows.append(
f"[GraphBuilder-{hs}.make_node] "
f"{_align(node.name, 15)} "
f"[{self._debug_string_inputs(node.input, node.output, 6)}] "
f"{node.op_type}{ext}:{node.input}->{node.output}"
)
if len(rows) > limit:
rows.append("...")
break
for io in self.outputs:
shh = _nice_shape(io.type.tensor_type.shape)
rows.append(
f"[GraphBuilder-{hs}.make_tensor_output] {io.name}"
f"[{io.type.tensor_type.elem_type}:{shh}]"
)
rows.append(
f"[GraphBuilder-{hs}] Message completed, there are "
f"{len(self.initializers_dict)} initializers, "
f"{len(self.nodes)} nodes, {len(self.inputs)} inputs, "
f"{len(self.inputs)} outputs."
)
return "\n".join(rows)
[docs]
def process(
self,
graph_module: "torch.fx.GraphModule", # noqa: F821
interpreter: "DynamoInterpreter", # noqa: F821
):
"""
Environment variable ``ONNX_BUILDER_PROGRESS=1`` can be used to show
a progress bar on big models.
"""
self._debug_msg["process.graph_module"] = graph_module.graph
# looks into output marked as "alias_of_input"
# see https://pytorch.org/functorch/main/_modules/torch/_functorch/aot_autograd.html
# in that case, gen_alias_from_base is mixing the input data and the output stride
# places = []
# for node in graph_module.graph.nodes:
# if node.op == "placeholder":
# places.append(node)
# for node in places:
# with graph_module.graph.inserting_after(node):
# cloned_node = graph_module.graph.call_method("clone", args=(node.target,))
# node.replace_all_uses_with(cloned_node)
# graph_module.recompile()
if int(os.environ.get("ONNX_BUILDER_PROGRESS", "0")) or (
self.verbose and len(graph_module.graph.nodes) > 100
):
try:
import tqdm
loop = tqdm.tqdm(list(enumerate(graph_module.graph.nodes)))
except ImportError:
loop = enumerate(graph_module.graph.nodes)
else:
loop = enumerate(graph_module.graph.nodes)
for i, node in loop:
self._debug_msg["process.progress"] = (
f"node {i}/{len(graph_module.graph.nodes)} target={node.target}"
)
interpreter.run_node(node)
def _extend_local_function_inputs(self) -> Tuple[Tuple[str, Any], Set[Tuple[str, str]]]:
"""
All the initializers may not be used in the main function.
The function filter out all the unused initializers.
The functions also filters out the unused local functions.
:param proto: function to modify, modified inplace
:param functions: other local functions
:param initializers_dict: initializers
:return: the new proto, the local functions names, and the used initializers
"""
# Let's sort the additional inputs by size, bigger is first.
used_initializers = self._get_used_initializers()
inits = [
(self.get_initializer_size(k), k, v)
for k, v in self.initializers_dict.items()
if k in used_initializers
]
inits.sort(reverse=True)
inits = [_[1:] for _ in inits]
inputs_to_add = [_[0] for _ in inits]
used_functions = self._get_used_local_functions()
return inputs_to_add, used_functions
def _check_constant(self, node: NodeProto, prefix: str):
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
if node.op_type != "Constant":
return
assert (
len(node.attribute) == 1
), f"{prefix}: unexpected number of attribute in node {node}"
assert (
node.attribute[0].type != AttributeProto.GRAPH
), f"{prefix}: wrong attribute type in node {node}"
def _check_constants(self, prefix="before-inline", add: Optional[Any] = None):
for v in self.constants_node_.values():
self._check_constant(v, prefix)
for v in self.nodes:
self._check_constant(v, prefix)
for k, v in self.functions.items():
for node in v.node:
self._check_constant(node, f"{prefix}-[{k}]")
if add is not None:
assert isinstance(add, FunctionProto), f"Not implemented for type {type(add)}"
for node in add.node:
self._check_constant(node, f"{prefix}-[add]")
[docs]
def to_onnx(
self,
optimize: bool = True,
large_model: bool = False,
external_threshold: int = 1024,
return_optimize_report: bool = False,
inline: bool = False,
function_options: Optional[FunctionOptions] = None,
) -> Union[FunctionProto, ModelProto, TorchModelContainer, Dict[str, Any]]:
"""
Conversion to onnx. Only then the initializers are converted into TensorProto.
:param optimize: disable or enable the optimization,
the optimization are set when the class constructor is called
:param large_model: if True returns a :class:`onnx.model_container.ModelContainer`,
it lets the user to decide later if the weights should be part of the model
or saved as external weights
:param external_threshold: if large_model is True, every tensor above this limit
is stored as external
:param return_optimize_report: return statistics about the optimization as well
:param inline: inline local functions, this is done before
any optimization takes place
:param function_options: to be set to export as a function
:return: the proto
"""
if function_options is None:
function_options = FunctionOptions()
if len(self.nodes) == 0:
raise RuntimeError(f"The onnx model is empty (no node).\n{self.get_debug_msg()}")
if inline:
self._check_constants("before-inline")
stats = self.inline_functions(verbose=self.verbose)
self._check_constants("after-inline")
else:
stats = None
if optimize:
statso = self.optimize()
if stats:
stats.extend(statso)
else:
stats = statso
if self._parameter_renaming:
# Adding the true names back.
update = {}
for k, v in self.initializers_dict.items():
if k in self._parameter_renaming:
update[self._parameter_renaming[k]] = v
self.initializers_dict.update(update)
assert len(self.nodes) > 0, (
f"The onnx model is empty after optimization (no node)."
f"\n{self.get_debug_msg()}"
)
opsets = [oh.make_opsetid(*o) for o in self.opsets.items()]
if function_options.export_as_function:
if self.verbose:
print(
f"[GraphBuilder-{self._hash()}.to_onnx] make_function "
f"{len(self.initializers_dict)} inits "
f"{len(self._parameter_renaming)} params"
)
assert (
function_options.name not in FunctionOptions.empty_names
and function_options.domain not in FunctionOptions.empty_names
), (
f"Function name={function_options.name!r} cannot be empty and "
f"Function domain={function_options.domain!r} cannot be empty."
)
key = function_options.domain, function_options.name
assert key not in self.functions, (
f"The given name {key} is already taken by a local function"
f"{self.pretty_text()}"
)
if function_options.move_initializer_to_constant:
self.move_initializers_to_constant(
full_parameter_name=False,
threshold=function_options.external_threshold,
verbose=max(0, self.verbose - 1),
)
# if self._parameter_renaming: we don't necessarily need to rename here.
# We better not if we want to make this function equivalent to it.
proto = oh.make_function(
function_options.domain,
function_options.name,
[i.name for i in self.inputs],
[o.name for o in self.outputs],
self.nodes,
opsets,
)
if not function_options.return_initializer:
return proto
if len(self.initializers_dict) == 0 and len(self.functions) == 0:
res = dict(proto=proto)
if self.functions:
used_functions = self._get_used_local_functions()
if used_functions:
res["functions"] = used_functions
return res
# We need to move the initializers as inputs, we sort than by decresing size
inits, functions = self._extend_local_function_inputs()
proto.input.extend(inits)
res = dict(
proto=proto,
initializers_name=inits,
initializers_dict={
self._parameter_renaming.get(k, k): v
for k, v in self.initializers_dict.items()
if k in set(inits)
},
)
if functions:
res["functions"] = [v for k, v in self.functions.items() if k in functions]
return res
if self.ir_version:
ir_version = self.ir_version
elif "" in self.opsets:
ir_version = _default_OPSET_TO_IR_VERSION()[self.opsets[""]]
if self.verbose:
print(
f"[GraphBuilder-{self._hash()}.to_onnx] make_model "
f"{len(self.initializers_dict)} inits "
f"{len(self._parameter_renaming)} params"
)
print(
f"[GraphBuilder-{self._hash()}.time_evaluation_constants_] "
f"{self.time_evaluation_constants_}"
)
# building the model
model = ModelProto()
model.graph.CopyFrom(GraphProto())
model.graph.name = "experiment"
model.graph.output.extend(self.outputs)
if self._parameter_renaming:
assert self.initializers_dict, (
f"Some parameters are renamed {self._parameter_renaming} "
f"but there is no initializer{self.get_debug_msg()}"
)
# We rename.
nodes_add = []
setp = set(self._parameter_renaming)
for node in self.nodes:
needs_rewrite = False
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
hidden = self._get_hidden_inputs(att.g)
if hidden & setp:
needs_rewrite = True
# needs rewrite
new_g = self._rename_inputs_in_subgraph(
att.g, self._parameter_renaming
)
if not needs_rewrite:
seti = set(node.input)
if not (seti & setp):
nodes_add.append(node)
continue
node2 = NodeProto()
node2.doc_string = node.doc_string
node2.name = node.name
node2.op_type = node.op_type
node2.domain = node.domain
node2.input.extend([self._parameter_renaming.get(i, i) for i in node.input])
node2.output.extend(node.output)
atts = []
for att in node.attribute:
if att.type != AttributeProto.GRAPH:
atts.append(att)
continue
new_g = self._rename_inputs_in_subgraph(att.g, self._parameter_renaming)
atts.append(oh.make_attribute(att.name, new_g))
node2.attribute.extend(atts)
nodes_add.append(node2)
model.graph.node.extend(nodes_add)
seti = set(i.name for i in self.inputs)
if seti & setp:
new_inputs = []
for i in self.inputs:
if i.name in self.setp:
v = ValueInfoProto()
v.ParseFromString(v.SerializeToString())
v.name = self._parameter_renaming[i.name]
new_inputs.append(v)
else:
new_inputs.append(i)
model.graph.input.extend(new_inputs)
else:
model.graph.input.extend(self.inputs)
else:
model.graph.node.extend(self.nodes)
model.graph.input.extend(self.inputs)
# initializer
initializers, large_initializers = self._build_initializers(
switch_low_high=sys.byteorder != "big",
large_model=large_model,
external_threshold=external_threshold,
full_parameter_name=True,
)
assert not self._parameter_renaming or initializers, (
f"Some parameters are renamed {self._parameter_renaming} "
f"self.initializers_dict={set(self.initializers_dict)} "
f"but there is no initializer{self.get_debug_msg()}"
)
model.graph.initializer.extend(initializers)
model.opset_import.extend(opsets)
model.functions.extend(self.functions.values())
model.ir_version = ir_version
self._add_shape_information(model)
doc_string = (
f"large_model={large_model}, inline={inline}, "
f"external_threshold={external_threshold}"
f"\nfunction_options={function_options!r}"
)
model.doc_string += doc_string + (
f"\noptimized:{self.optimization_options!r}" if optimize else "not-optimized"
)
assert (
not optimize
or not self.optimization_options.remove_identity
or len([n for n in model.graph.node if n.op_type == "Identity"])
<= len(model.graph.output)
), (
f"The optimization was not applied. There are two many nodes identity"
f"\n{self.pretty_text()}"
)
if large_model:
lm = TorchModelContainer()
lm.model_proto = model
if large_initializers:
lm.set_large_initializers(large_initializers)
lm.check_large_initializers()
return (lm, stats) if return_optimize_report else lm
return (model, stats) if return_optimize_report else model
def _add_shape_information(self, model: ModelProto):
if len(model.graph.node) == 0:
raise RuntimeError(
f"The onnx model is empty after export to onnx (no node)."
f"\n{self.get_debug_msg()}"
)
# restores the existing value_info
done = (
set(init.name for init in model.graph.initializer)
| set(i.name for i in model.graph.input)
| set(i.name for i in model.graph.output)
)
for val in self.value_info:
if self.has_name(val.name):
model.graph.value_info.append(val)
done.add(val.name)
# adding shape information
addition = []
for name in self._known_names:
if name in done or not self.has_type(name) or self.get_type(name) == 0:
continue
if self.has_shape(name):
addition.append(
oh.make_tensor_value_info(
name, self.get_type(name), list(self.get_shape(name))
)
)
elif self.has_rank(name):
addition.append(
oh.make_tensor_value_info(
name, self.get_type(name), [None] * self.get_rank(name)
)
)
if addition:
model.graph.value_info.extend(addition)
[docs]
def io_names(self):
"""
Returns the list of inputs, output for nodes.
"""
input_names = {i.name for i in self.inputs}
init_names = set(self.initializers_dict)
output_names = {i.name for i in self.outputs}
rows = [
"",
f"I<-[{','.join(sorted(input_names))}]",
f"C<-[{','.join(sorted(init_names))}]",
]
for node in self.nodes:
rows.append(
f"N:{node.op_type}:[{','.join(sorted(node.input))}]->[{','.join(sorted(node.output))}]"
)
rows.append(f"O->[{','.join(sorted(output_names))}]")
rows.append("")
return "\n".join(rows)
def _check(self, stats: List, step: str):
begin = time.perf_counter()
assert (
len(self.nodes) > 0
), f"The onnx model is empty (step {step}, no node){self.get_debug_msg()}"
known = set(n.name for n in self.inputs)
known |= set(self.initializers_dict)
for node in self.nodes:
assert (
node.domain in self.opsets
), f"Domain {node.domain!r} is not registered in {self.opsets}"
for i in node.input:
assert not i or self.has_name(i), (
f"Name {i!r} not registered, node type is "
f"{node.op_type!r}, node name is {node.name!r}, "
f"input are {node.input}{self.get_debug_msg()}"
)
if i == "":
continue
assert i in known, (
f"Unknown input {i!r}, step {step!r} in node type "
f"{node.op_type}, name is {node.name!r}\n{node}"
)
known |= set(node.output)
for o in self.outputs:
assert o.name in known, f"Unknown output {o.name!r}, step {step!r}"
stats.append(dict(pattern=f"check_{step}", time_in=time.perf_counter() - begin))
[docs]
def optimize(self) -> List[Dict[str, Any]]:
"""
Optimizes a graph.
Returns the list of applied processes.
"""
self._clean_values_cache()
statistics = []
main_begin = time.perf_counter()
if self.verbose or self.optimization_options.verbose:
print(f"[GraphBuilder.optimize] start with {len(self.nodes)} nodes")
if self.verbose > 1:
print(f"[GraphBuilder.optimize] options={self.optimization_options!r}")
else:
n_patterns = (
0
if self.optimization_options is None
or self.optimization_options.patterns is None
else len(self.optimization_options.patterns)
)
print(f"[GraphBuilder.optimize] #patterns={n_patterns}")
self._check(statistics, "A")
if self.optimization_options.remove_identity:
begin = time.perf_counter()
nr, na = self.remove_identity_nodes()
statistics.append(
dict(
pattern="remove_identity_nodes",
removed=nr,
added=na,
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "B")
if self.optimization_options.remove_unused:
begin = time.perf_counter()
n = self.remove_unused()
statistics.append(
dict(
pattern="remove_unused",
removed=n,
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "C")
if self.optimization_options.constant_folding:
# First constant removal
begin = time.perf_counter()
n = self.constant_folding()
statistics.append(
dict(
pattern="constant_folding",
removed=n,
time_in=time.perf_counter() - begin,
iteration=0,
)
)
self._check(statistics, "Da")
if self.optimization_options.remove_unused:
begin = time.perf_counter()
n = self.remove_unused()
statistics.append(
dict(
pattern="remove_unused",
removed=n,
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "Ea")
if self.optimization_options.patterns:
assert (
self.optimization_options.remove_unused
), "remove_unused must be positive for pattern optimizations"
n = len(self.nodes)
begin = time.perf_counter()
res = self.optimize_with_patterns()
statistics.extend(res)
statistics.append(
dict(
pattern="pattern_optimization",
removed=n - len(self.nodes),
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "F")
begin = time.perf_counter()
n = self.remove_unused()
statistics.append(
dict(
pattern="remove_unused",
removed=n,
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "G")
if self.optimization_options.constant_folding:
# Second constant removal
begin = time.perf_counter()
n = self.constant_folding()
statistics.append(
dict(
pattern="constant_folding",
removed=n,
time_in=time.perf_counter() - begin,
iteration=1,
)
)
self._check(statistics, "Db")
if self.optimization_options.remove_unused:
begin = time.perf_counter()
n = self.remove_unused()
statistics.append(
dict(
pattern="remove_unused",
removed=n,
time_in=time.perf_counter() - begin,
)
)
self._check(statistics, "Eb")
if self.optimization_options.order:
res = self.optimize_order()
statistics.extend(res)
self._check(statistics, "order")
if self.verbose or self.optimization_options.verbose:
duration = time.perf_counter() - main_begin
print(
f"[GraphBuilder.optimize] done with "
f"{len(self.nodes)} nodes in {duration:.3f}"
)
if self.verbose > 1:
print(self._compile_statistics(statistics))
return statistics
def _compile_statistics(self, statistics: List[Dict[str, Any]]) -> str:
stats = {}
for obs in statistics:
pattern = obs["pattern"]
if pattern not in stats:
stats[pattern] = {
"time_in": 0.0,
"iteration": [],
"match_index": [],
"removed": 0,
"added": 0,
"instances": 0,
}
o = stats[pattern]
for k, v in obs.items():
if k == "pattern":
continue
if k in {"time_in", "removed", "added", "instances"}:
o[k] += v
continue
if k in {"changed", "scale"}:
if k not in o:
o[k] = 0
o[k] += v
continue
if k in {"iter"}:
if k not in o:
o[k] = 0
o[k] = max(o[k], v)
continue
if k in {"algo"} and k not in o:
o[k] = []
o[k].append(v)
rows = []
for k, v in sorted(stats.items()):
line = (
f" STAT {k} +{v['added']} -{v['removed']} "
f"#it={len(set(v['iteration']))} "
f"maxmatch={max(v['match_index']) if v['match_index'] else 0} "
f"i={v['instances']} - time={v['time_in']}"
)
rows.append(line)
# adding statistics on node type
rows.append(self._compile_model_statistics(False))
rows.append(self._compile_model_statistics(True))
return "\n".join(rows)
def _compile_model_statistics(self, detailed: bool):
rows = [
f"--MODEL: {len(self.nodes)} nodes, {len(self.inputs)} inputs, "
f"{len(self.outputs)} outputs, "
f"{len(self.initializers_dict)} initializers--"
f"{'DETAILED--' if detailed else ''}"
]
if detailed:
def _shape(name):
if self.has_shape(name):
s = self.get_shape(name)
if len(s) == 0:
return "1"
return "x".join(map(str, s))
if self.has_rank(name):
r = self.get_rank(name)
if r == 0:
return "1"
return "x".join(["?"] * r)
return "?"
def _key(name):
if not name:
return ""
if isinstance(name, str):
tt = self.get_type(name) if self.has_type(name) else "?"
return f"{tt}t[{_shape(name)}]"
if name.op_type == "Transpose":
perm = ";".join(map(str, self.get_attribute(name, "perm").ints))
return f"{_key(name.input[0])}-perm={perm}"
return ", ".join(map(_key, name.input))
cc = Counter([_key(i) for i in self.input_names])
for k, v in sorted(cc.items()):
rows.append(f" INPUT: {v:3d} x {k}")
cc = Counter([_key(i) for i in self.output_names])
for k, v in sorted(cc.items()):
rows.append(f" OUTPUT: {v:3d} x {k}")
cc = Counter([_key(i) for i in self.initializers_dict])
for k, v in sorted(cc.items()):
rows.append(f" INIT: {v:3d} x {k}")
op_types = [(n.domain, n.op_type, _key(n)) for n in self.nodes]
cc = Counter(op_types)
for k, v in sorted(cc.items()):
if k[0] == "":
rows.append(f" NODE: {v:3d} x {k[1]} -SIG- {k[2]}")
else:
rows.append(f" NODE: {v:3d} x {k[0]}.{k[1]} -SIG- {k[2]}")
else:
cc = Counter([self.get_type(i) for i in self.input_names])
for k, v in sorted(cc.items()):
rows.append(f" INPUT: {v:3d} x {k}t")
cc = Counter([self.get_type(i) for i in self.output_names])
for k, v in sorted(cc.items()):
rows.append(f" OUTPUT: {v:3d} x {k}t")
cc = Counter([self.get_type(i) for i in self.initializers_dict])
for k, v in sorted(cc.items()):
rows.append(f" INIT: {v:3d} x {k}t")
op_types = [(n.domain, n.op_type) for n in self.nodes]
cc = Counter(op_types)
for k, v in sorted(cc.items()):
if k[0] == "":
rows.append(f" NODE: {v:3d} x {k[1]}")
else:
rows.append(f" NODE: {v:3d} x {k[0]}.{k[1]}")
return "\n".join(rows)
[docs]
def optimize_with_patterns(self) -> List[Dict[str, Any]]:
"""
Optimizes this graph with patterns.
"""
from ..xoptim import GraphBuilderPatternOptimization
gro = GraphBuilderPatternOptimization(
self,
verbose=max(self.verbose, self.optimization_options.verbose),
patterns=self.optimization_options.patterns,
recursive=self.optimization_options.recursive,
verifies=self.optimization_options.verifies,
dump_applied_patterns=self.optimization_options.dump_applied_patterns,
processor=self.optimization_options.processor,
)
return gro.optimize(
max_iter=self.optimization_options.max_iter,
remove_identity=self.optimization_options.remove_identity,
stop_after=self.optimization_options.stop_after,
)
def optimize_order(self):
from ..xoptim.order_optim import OrderOptimization
opt = OrderOptimization(
self,
algorithm=self.optimization_options.order,
verbose=max(self.verbose, self.optimization_options.verbose),
)
return opt.optimize()
@classmethod
def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
"""
Returns the hidden inputs (inputs coming from an upper context)
used by a subgraph.
"""
hidden = set()
memo = set(i.name for i in graph.initializer)
memo |= set(i.name for i in graph.sparse_initializer)
for node in graph.node:
for i in node.input:
if i not in memo:
hidden.add(i)
for att in node.attribute:
if att.type == AttributeProto.GRAPH and att.g:
hid = self._get_hidden_inputs(att.g)
less = set(h for h in hid if h not in memo)
hidden |= less
memo |= set(node.output)
return hidden
def _get_used_initializers(self) -> Set[str]:
"""
Returns the initializers name involved in the graph.
"""
hidden = set()
memo = set(i.name for i in self.inputs)
for node in self.nodes:
for i in node.input:
if i not in memo:
hidden.add(i)
for att in node.attribute:
if att.type == AttributeProto.GRAPH and att.g:
hid = self._get_hidden_inputs(att.g)
less = set(h for h in hid if h not in memo)
hidden |= less
memo |= set(node.output)
assert all(name in self.initializers_dict for name in hidden if name), (
f"Some hidden inputs in {sorted(hidden)!r} are not initializers "
f"{sorted(self.initializers_dict)}. It is unexpected."
)
return hidden
def _get_used_local_functions(
self, nodes: Optional[Sequence[NodeProto]] = None
) -> Set[Tuple[str, str]]:
"""
Returns the local functions used in the graph.
"""
if nodes is None:
nodes = self.nodes
used = set()
for node in nodes:
key = node.domain, node.op_type
if key in self.functions:
used.add(key)
for att in node.attribute:
if att.type == AttributeProto.GRAPH and att.g:
used |= self._get_used_local_functions(att.g.node)
# Looking into used functions
done = set()
stack = list(used)
while stack:
key = stack.pop()
if key not in done:
f = self.functions[key]
used |= self._get_used_local_functions(f.node)
done.add(key)
return used
@classmethod
def _enumerate_inputs_with_subgraph(cls, node: NodeProto) -> Iterator[str]:
"""
Enumerates all inputs from a node including all the hidden inputs
from subgraphs.
"""
yield from node.input
if node.op_type in {"Loop", "Scan", "If", "SequenceMap"}:
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
hidden_inputs = cls._get_hidden_inputs(att.g)
yield from hidden_inputs
[docs]
def remove_unused(self) -> int:
"""
Simple function to remove unused nodes.
It does not look into subgraphs and assumes there is none.
Everything is done in one pass.
Returns the number of removed nodes.
"""
self._clean_values_cache()
start = len(self.nodes)
# mark outputs
marked = {o.name: set() for o in self.outputs}
for node in reversed(self.nodes):
used = False
node_inputs = list(self._enumerate_inputs_with_subgraph(node))
for o in node.output:
if o in marked:
for i in node_inputs:
marked[o].add(i)
used = True
if used or self.do_not_remove(node):
for i in node_inputs:
marked[i] = set()
# removed nodes
removed = set()
marked_set = set(marked)
for ind, node in enumerate(self.nodes):
if not (set(node.output) & marked_set):
removed.add(ind)
if self.optimization_options.verbose > 1:
n_not_marked = 0
for i, (k, v) in enumerate(self.initializers_dict.items()):
if k not in marked:
n_not_marked += 1
v = self.initializers_dict[k]
if hasattr(v, "dtype") and hasattr(v, "shape"):
print(
f"[GraphBuilder.remove_unused] {i}/{len(self.initializers_dict)}"
f"remove_initializer:{k}:{v.dtype}[{v.shape}]"
)
else:
print(
f"[GraphBuilder.remove_unused] remove_initializer {n_not_marked}:"
f"{i}/{len(self.initializers_dict)}: {k}]"
)
self.initializers_dict = {
k: v for k, v in self.initializers_dict.items() if k in marked
}
self.constants_ = {k: v for k, v in self.constants_.items() if k in marked}
if self.optimization_options.verbose > 2:
for i in removed:
node = self.nodes[i]
print(
f"[GraphBuilder.remove_unused_node] remove "
f"{node.op_type}-{node.name} -> {node.output}"
)
self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed]
return start - len(self.nodes)
[docs]
def compute_constant(
self, name: str, exc: bool = True, only_array: bool = False, allow_empty: bool = False
) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
"""
Computes a constant.
:param name: constant name
:param exc: raises an exception if any failure
:param only_array: do not return TensorProto
:param allow_empty: allow empty result
:return: constant
If returns None if the constant is a FakeTensor.
"""
if self.main_opset < 18:
# This functionality is not enabled before that opset.
return None, None
assert self.is_constant(name), f"Name {name!r} is not a constant"
if name in self.initializers_dict:
value = self.initializers_dict[name]
assert not isinstance(
value, tuple
), f"Unexpected type {type(value)} for name={name!r}"
if only_array and isinstance(value, TensorProto):
# Should reuse memory buffer here.
v = onh.to_array(value)
self.add_initializer(name, v, existing=True, allow_empty=allow_empty)
return v, None
if isinstance(value, self.torch._subclasses.fake_tensor.FakeTensor):
return None, None
return value, None
v = self.constants_[name]
# It should not be None but a node as it is not an initializer.
assert isinstance(
v, NodeProto
), f"Unexpected type {type(v)} for constant name={name!r}"
if self._debug_get_constant:
print(f"[GraphBuilder.compute_constant] {self.pretty_node(v)}")
if v.op_type == "Shape":
if not self.has_shape(v.input[0]):
# We stop.
return None, None
shape = self.get_shape(v.input[0])
if is_static_shape(shape):
if v.attribute:
start = 0
end = None
for att in v.attribute:
if att.name == "start":
start = att.i
elif att.name == "end":
end = att.i
shape = shape[start:] if end is None else shape[start:end]
if self._debug_get_constant:
print(
f"[GraphBuilder.compute_constant] - SHAPE "
f"{name}: {shape}? start={start}, end={end}"
)
elif self._debug_get_constant:
print(f"[GraphBuilder.compute_constant] - SHAPE {name}: {shape}?")
return np.array(shape, dtype=np.int64), {
v.input[0]: self.ShapeConstant(v.input[0], shape, v)
}
if not self.is_constant(v.input[0]):
# One exception here as the input maybe not
# be constant but the shape may be known.
assert all_int(shape), (
f"Shape must be static ({shape}) if shape is constant in {v}"
f"{self.get_debug_msg()}"
)
with self.maybe_disable_fake_tensor_mode():
output = self._apply_shape_on_shape(v, shape)
if isinstance(output[0], self.torch.Tensor):
# We convert the tensor into numpy array,
# it is a small shape anyway so the FakeMode
# does not come up as an issue.
output = [output[0].detach().cpu().numpy()]
if self._debug_get_constant:
print(
f"[GraphBuilder.compute_constant] - A "
f"{name}: {self.pretty_tensor(output[0])}"
)
return output[0], {v.input[0]: self.ShapeConstant(v.input[0], shape, v)}
return None, None
feeds = {i: self.get_constant(i, exc=exc, computed_value=True) for i in v.input}
for kval, val in feeds.items():
if not exc and "FakeTensor" in str(type(val)):
return None, None
assert "FakeTensor" not in str(type(val)), (
f"FakeTensor {kval!r} cannot be an initializer {type(val)}, "
f"v.op_type={v.op_type!r}"
f"{self.get_debug_msg()}"
)
if val is None:
return None, None
assert len(val.shape) == 0 or min(val.shape) > 0, (
f"One input has a empty shape {val.shape}, name={kval!r}"
f"v.op_type={v.op_type!r}, v.name={v.name!r}{self.get_debug_msg()}"
)
with self.maybe_disable_fake_tensor_mode():
if v.op_type == "Identity":
# much faster this way
output = [feeds[v.input[0]]]
elif v.op_type in {"Mul", "Add", "Sub", "Div"}:
# bypassing onnx.numpy_helper.from_array, too slow
output = self._apply_binary_op(v, feeds)
elif v.op_type in {"Sqrt"}:
# bypassing onnx.numpy_helper.from_array, too slow
output = self._apply_unary_function(v, feeds)
elif hasattr(self, f"_apply_{v.op_type.lower()}"):
output = getattr(self, f"_apply_{v.op_type.lower()}")(v, feeds)
elif all(isinstance(v, np.ndarray) for v in feeds.values()):
# Let's avoid big computation on CPU.
max_dim = 0
for _v in feeds.values():
max_dim = max(max_dim, np.prod(_v.shape))
if max_dim >= 2**22:
if self.verbose > 1:
print(
f"[GraphBuilder.compute_constant] stop computing a "
f"constant as it may be too big, shapes are "
f"{[_.shape for _ in feeds.values()]}"
)
return None, None
begin = time.perf_counter()
ref = ExtendedReferenceEvaluator(v)
try:
output = ref.run(None, feeds)
except (ValueError, TypeError) as e:
sf = ", ".join(f"{k}:{v.dtype}:{v.shape}" for k, v in feeds.items())
if "warnings" not in self._debug_msg:
self._debug_msg["warnings"] = []
sv = str(v).replace("\n", " ")
self._debug_msg["warnings"].append(f"Issue with v={sv}, feeds={sf}, e={e}")
self.time_evaluation_constants_ += time.perf_counter() - begin
return None, None
self.time_evaluation_constants_ += time.perf_counter() - begin
else:
return None, None
cst = None
for n, val in zip(v.output, output):
assert not isinstance(val, tuple), f"Unexpected type {type(val)} for n={n!r}"
assert "FakeTensor" not in str(type(val)), (
f"FakeTensor detected {type(val)} in constant {name!r}, "
f"v.op_type={v.op_type!r}{self.get_debug_msg()}"
)
if self.has_type(n):
# numpy changes the expected type sometimes
# (like transpose(x: float36) --> float32)
itype = self.get_type(n)
if hasattr(val, "detach"):
val = val.to(onnx_dtype_to_torch_dtype(itype))
else:
val = val.astype(oh.tensor_dtype_to_np_dtype(itype))
self.constants_computed_[n] = val
if name == n:
cst = val
assert len(cst.shape) == 0 or min(cst.shape) > 0, (
f"Output has empty shape {cst.shape}, name={name!r}"
f"v.op_type={v.op_type!r}, v.name={v.name!r}{self.get_debug_msg()}"
)
assert cst is not None, f"Constant {name!r} was not found in {v.output}"
if isinstance(cst, self.torch._subclasses.fake_tensor.FakeTensor):
return None, None
if self._debug_get_constant:
print(f"[GraphBuilder.compute_constant] - A {name}: {self.pretty_tensor(cst)}")
return cst, feeds
[docs]
def constant_folding(self, convert_into_initializer: bool = True) -> int:
"""
Folds all constants. Constants are marked during the creation of the graph.
There is no need to propagate this information.
:param convert_into_initializer: moves the constant as an initializer,
otherwise, just evaluates it
:return: number of removed nodes
"""
if self.verbose > 1:
begin_ = time.perf_counter()
print(
f"[GraphBuilder.constant_folding] starts with "
f"{len(self.constants_)} constants and "
f"{len(self.nodes)} nodes."
)
if self.verbose >= 10:
for name in self._known_names:
print(
f"[GraphBuilder.constant_folding] cst:: "
f"{1 if self.is_constant(name) else '.'} :: {name}"
)
start = len(self.nodes)
updates = {}
node_to_remove = set()
for k, v in self.constants_.items():
if v is None:
# this is an initiliazer
if self.verbose > 4:
print(f"[GraphBuilder.constant_folding] initializer: {k}")
continue
assert isinstance(v, NodeProto), f"Unexpected type {type(v)} for k={k!r}"
if self.verbose > 4:
print(f"[GraphBuilder.constant_folding] from: {v.op_type}({k})")
# a node
if all(map(self.is_constant, v.input)):
# node evaluation
output, feeds = self.compute_constant(k, exc=False)
if output is None:
# Evaluation failed.
continue
if convert_into_initializer:
node_to_remove.add(tuple(v.output))
if not isinstance(output, tuple):
output = (output,)
for name, value in zip(v.output, output):
updates[name] = None
if convert_into_initializer:
assert "FakeTensor" not in str(type(value)), (
f"FakeTensor {name!r} cannot be an initializer {type(value)}, "
f"v.op_type={v.op_type!r} (input types: "
f"{[type(i) for i in feeds.values()]})"
f"{self.get_debug_msg()}"
)
self.add_initializer(name, value, existing=None)
else:
updates[name] = v
if self.verbose > 3:
print(
f"[GraphBuilder.constant_folding] fold_constant:"
f"{v.op_type}:{name}[{value.dtype}:"
f"{value.shape}]:from:{','.join(sorted(feeds))}"
)
for k, v in updates.items():
self.update_node_constant(k, v)
new_nodes = []
for node in self.nodes:
if self.do_not_remove(node):
new_nodes.append(node)
continue
if tuple(node.output) in node_to_remove:
continue
new_nodes.append(node)
self.nodes = new_nodes
if self.verbose > 1:
print(
f"[GraphBuilder.constant_folding] ends with "
f"{len(self.constants_)} constants and "
f"{len(self.nodes)} nodes in "
f"{time.perf_counter() - begin_} seconds"
)
return start - len(self.nodes)
def _clean_values_cache(self):
"""
Cleans the cache. The cache is used to avoid the creation of new constants
while creating a graph. It should be removed the graph is modified.
"""
if self._values:
self._values.clear()
@classmethod
def _rename_inputs_in_node(
cls, node: NodeProto, replacements: Dict[str, str], to_rename: Set[str]
) -> NodeProto:
set_inputs = set(cls._enumerate_inputs_with_subgraph(node))
if set_inputs & to_rename:
new_inputs = [replacements.get(i, i) for i in node.input]
new_outputs = [replacements.get(i, i) for i in node.output]
assert set(new_inputs) & set(new_outputs) in ({""}, set()), (
f"Node type {node.op_type}-{node.name} is incorrectly replaced "
f"{node.input}->{new_inputs} and {node.output}->{new_outputs}\n"
f"replacements are\n{pprint.pformat(replacements)}"
)
new_node = oh.make_node(
node.op_type,
new_inputs,
new_outputs,
domain=node.domain,
name=node.name,
)
if node.op_type in {"Loop", "Scan", "If", "SequenceMap"}:
# Hidden inputs must be taken care of.
node_attributes = []
for att in node.attribute:
node_attributes.append(
att
if att.type != AttributeProto.GRAPH
else oh.make_attribute(
att.name,
cls._rename_inputs_in_subgraph(att.g, replacements),
)
)
else:
node_attributes = node.attribute
new_node.attribute.extend(node_attributes)
return new_node
return node
@classmethod
def _rename_inputs_in_subgraph(
cls, graph: GraphProto, replacements: Dict[str, str]
) -> GraphProto:
"""
Renames inputs.
There is a weird case when one of the result defined in the inner context
is overriden by the subgraph itself. We should assume this case never happens.
"""
# graph inputs and outputs should not be changed, initializer as well
to_rename = set(replacements)
was_copied = False
if set(i.name for i in graph.input) & to_rename:
# An input of the graph is overrides one of the replacements.
# The replacement should noe take place then.
replacements = replacements.copy()
for i in graph.input:
if i.name in to_rename:
del replacements[i.name]
to_rename = set(replacements)
was_copied = True
nodes = []
for node in graph.node:
nodes.append(cls._rename_inputs_in_node(node, replacements, to_rename))
if set(node.output) & to_rename:
# An output overrides a replacement
if not was_copied:
replacements = replacements.copy()
was_copied = True
for i in node.output:
if i in to_rename:
del replacements[i]
to_rename = set(replacements)
return oh.make_graph(
nodes,
graph.name,
graph.input,
graph.output,
graph.initializer,
sparse_initializer=graph.sparse_initializer,
)
[docs]
def remove_identity_nodes(self) -> Tuple[int, int]:
"""
Removes identity nodes. Returns the number of removed nodes
and the number of added nodes.
.. note::
onnxruntime does not handle well when it is executing from domain
*'org.pytorch.aten'* (ATen for example) which outputs results
on CPU where the expected output is on CUDA. An identity node must be
kept or inserted in that case. In that particular case, a node can be
marked so that it does not get deleted: its name must start with
``'_DONOTREMOVE_'``.
"""
self._clean_values_cache()
# make_initializer
if self.verbose > 1:
begin_ = time.perf_counter()
print(f"[GraphBuilder.remove_identity_nodes] starts with {len(self.nodes)}")
# first pass: detect replacements
new_nodes = []
input_names = set(i.name for i in self.inputs)
output_names = set(i.name for i in self.outputs)
replacements = {}
replacements_rev = {}
removed = 0
for node in self.nodes:
if node.op_type != "Identity" or self.do_not_remove(node):
new_nodes.append(node)
continue
if node.output[0] not in output_names:
old_name, new_name = node.output[0], node.input[0]
elif (
node.input[0] not in input_names
and node.input[0] not in output_names
and node.input[0] not in replacements
):
old_name, new_name = node.input[0], node.output[0]
else:
new_nodes.append(node)
continue
# the new name can be set for replacements as well
if new_name in replacements:
new_name = replacements[new_name]
assert new_name not in replacements, (
f"Name {old_name!r} still in {replacements}, "
f"node.op_type={node.op_type!r}, "
f"node.input={node.input}, node.output={node.output}, "
f"input_names={input_names}, output_names={output_names}"
)
if old_name in replacements_rev:
old_old_name = replacements_rev[old_name]
replacements[old_old_name] = new_name
replacements_rev[new_name] = old_old_name
if old_name in replacements:
replacements[replacements[old_name]] = new_name
assert new_name not in replacements, (
f"Name {old_name!r} still in {replacements}, "
f"node.op_type={node.op_type!r}, "
f"node.input={node.input}, node.output={node.output}, "
f"input_names={input_names}, output_names={output_names}"
)
if old_name in replacements_rev:
# A tricky case:
# x -> Identity -> a -> Identity -> b -> Flatten -> output1
# x -> Identity -> output0
# How x should be renamed?
assert new_name in output_names, (
f"replacement {old_name}->{new_name} "
f"is not possible because of "
f"[{replacements_rev[old_name]}->{old_name}] "
f"and {new_name!r} is not an output"
)
updates = {}
for k, v in replacements.items():
if v == old_name:
updates[k] = new_name
replacements.update(updates)
replacements[old_name] = new_name
replacements_rev[new_name] = old_name
# verification
if len(replacements) < 500:
for k, v in replacements.items():
assert v not in replacements, (
f"replacement {k}->{v} is not possible because of "
f"[{v}->{replacements[v]}], old_name={old_name!r}, "
f"new_name={new_name!r}"
)
removed += 1
# second pass: replacements in initializer
if self.verbose > 1:
print(
f"[GraphBuilder.remove_identity_nodes] found "
f"{len(replacements)} replacements"
)
for k, v in replacements.items():
if k in self.initializers_dict:
if self.optimization_options.verbose > 2:
print(
f"[GraphBuilder.remove_identity_nodes] "
f"rename initializer {k!r} by {v!r}"
)
assert "FakeTensor" not in str(type(self.initializers_dict[k])), (
f"FakeTensor {k!r} cannot be an initializer "
f"{type(self.initializers_dict[k])}{self.get_debug_msg()}"
)
self.add_initializer(
v,
self.initializers_dict[k],
itype=self.get_type(k),
shape=self.get_shape(k),
cst=self.constants_[k],
existing=None,
)
del self.initializers_dict[k]
del self.constants_[k]
# third pass: replacements in node
if self.verbose > 1:
print(f"[GraphBuilder.remove_identity_nodes] kept {len(new_nodes)} nodes")
self.nodes = []
added = 0
for node in new_nodes:
repo = {o for o in node.output if o in replacements}
repi = {o for o in self._enumerate_inputs_with_subgraph(node) if o in replacements}
if repi or repo:
new_inputs = [replacements.get(i, i) for i in node.input]
new_outputs = [replacements.get(i, i) for i in node.output]
assert set(new_inputs) & set(new_outputs) in ({""}, set()), (
f"Node type {node.op_type}-{node.name} is incorrectly replaced "
f"{node.input}->{new_inputs} and {node.output}->{new_outputs}\n"
f"replacements are\n{pprint.pformat(replacements)}"
)
if self.optimization_options.verbose > 2:
print(
f"[GraphBuilder.remove_identity_nodes] "
f"node {node.op_type}-{node.name}:"
f"{node.input}->{new_inputs}:{node.output}->{new_outputs}"
)
new_node = oh.make_node(
node.op_type,
new_inputs,
new_outputs,
domain=node.domain,
name=node.name,
)
added += 1
removed += 1
if node.op_type in {"Loop", "Scan", "If", "SequenceMap"}:
# Hidden inputs must be taken care of.
node_attributes = []
for att in node.attribute:
node_attributes.append(
att
if att.type != AttributeProto.GRAPH
else oh.make_attribute(
att.name,
self._rename_inputs_in_subgraph(att.g, replacements),
)
)
else:
node_attributes = node.attribute
new_node.attribute.extend(node_attributes)
self.nodes.append(new_node)
if all(self.is_constant(i) for i in new_node.input):
for o in new_node.output:
self.update_node_constant(o, new_node)
else:
self.nodes.append(node)
if self.verbose > 1:
print(
f"[GraphBuilder.remove_identity_nodes] ends with "
f"{len(self.nodes)} nodes in "
f"{time.perf_counter() - begin_} seconds"
)
# fourth pass: simplify the graph.
identity_outputs = {}
for node in self.nodes:
if node.op_type != "Identity" or node.domain != "":
continue
anc = node.input[0]
while anc in identity_outputs:
anc = identity_outputs[anc]
identity_outputs[node.output[0]] = anc
for node in self.nodes:
new_inputs = []
rename = False
for i in node.input:
if i in identity_outputs:
new_inputs.append(identity_outputs[i])
rename = True
else:
new_inputs.append(i)
if rename:
del node.input[:]
node.input.extend(new_inputs)
# results
return removed, added
def _position_msg(
self, nodes: List[NodeProto], around: Optional[Tuple[int, int]] = None
) -> str:
"Buids an error message."
pos = {}
posi = {}
for i, n in enumerate(self.nodes):
if n is None:
continue
pos[id(n)] = i
for o in n.output:
pos[o] = i
for o in n.input:
if o not in posi:
posi[o] = i
rows = []
for node in nodes:
if node is None:
continue
rows.append(
f"{node.op_type}({', '.join(node.input)}) -> "
f"[{', '.join(node.output)}] -- {node.name}"
)
for i in node.input:
rows.append(f" -> pos({i}) = {pos.get(i, ' -')} -> {posi.get(i, ' -')}")
for i in node.output:
rows.append(f" <- pos({i}) = {pos.get(i, ' -')} -> {posi.get(i, ' -')}")
if around is None:
return "\n".join(rows)
rows.append("---")
for i in range(max(0, around[0] - 3), min(len(self.nodes), around[1] + 3)):
n = self.nodes[i]
if n is None:
continue
rows.append(
f"P{i}: {n.op_type}({', '.join(n.input)}) -> "
f"[{', '.join(n.output)}] -- {n.name}"
)
return "\n".join(rows)
def _needed_at_first_at(self):
"Needed by insert_and_remove_nodes."
needed_at = {}
first_at = {}
for i, node in enumerate(self.nodes):
for name in node.input:
if name not in needed_at:
needed_at[name] = i
for name in node.output:
if name not in first_at:
first_at[name] = i
return needed_at, first_at
def _move_node_position(self, pos: int) -> int:
"""Tries to move a node at position pos closed to the beginning."""
the_node = self.nodes[pos]
first_at = {}
for i, node in enumerate(self.nodes):
if i > pos:
break
for name in node.output:
if name not in first_at:
first_at[name] = i
can_be = max(first_at.get(i, 0) for i in the_node.input) + 1
if can_be >= pos:
return None
self.nodes[can_be + 1 : pos + 1] = self.nodes[can_be:pos]
self.nodes[can_be] = the_node
return can_be
[docs]
def insert_and_remove_nodes(
self,
insert_at: Optional[int],
new_nodes: List[NodeProto],
removed: List[int],
opsets: Optional[Dict[str, int]] = None,
debug: Optional[Any] = None,
) -> List[NodeProto]:
"""
Inserts new nodes and removes others.
:param insert_at: insert the new nodes at this position,
if empty, the function guesses where to add them
:param new_nodes: list of nodes to insert
:param removed: list of nodes to removed (based on their positions)
:param opsets: opsets used
:param debug: anything added to exception messages
:return: list of removed nodes
"""
assert insert_at is None or not removed or min(removed) <= insert_at, (
f"The position {insert_at} must be higher than the position "
f"of the removed nodes {removed}"
)
memo = []
for i in removed:
assert i < len(
self.nodes
), f"Unable to remove node position {i}, there are {len(self.nodes)}"
n = self.nodes[i]
if not n:
# already marked as removed
continue
assert n and not self.do_not_remove(
n
), f"Node {n.name!r} marked as 'DONOTREMOVE' cannot be removed."
memo.append(n)
# We need to remove the constant.
for o in n.output:
if o in self.constants_:
del self.constants_[o]
self.nodes[i] = None
n_existing = []
for node in new_nodes:
for i in self._enumerate_inputs_with_subgraph(node):
assert self.has_name(i), (
f"Input {i!r} does not exist for node {node.op_type!r}, "
f"debug={debug}{self.get_debug_msg()}"
)
for o in node.output:
if self.has_name(o):
# connecting to existing input
n_existing.append(o)
else:
self.set_name(o, f"insert_and_remove_nodes_{node.op_type}_{o}")
node_domain = node.domain or ""
if node_domain in self.opsets:
if opsets and node_domain in opsets:
assert compatible_opsets(
node_domain,
node.op_type,
current=self.opsets[node_domain],
new_version=opsets[node_domain],
), (
f"Incompatible opset for node {node_domain!r} "
f"from domain {node_domain!r}, "
f"current is {self.opsets[node_domain]}, "
f"new is {opsets[node_domain]}"
)
else:
if opsets and node_domain in opsets:
self.opsets[node_domain] = opsets[node_domain]
else:
self.opsets[node_domain] = choose_consistent_domain_opset(
node_domain,
opsets=self.opsets,
)
assert n_existing, "Any output of the new node is connected to existing names."
if insert_at is not None:
for i, n in enumerate(new_nodes):
assert isinstance(n, NodeProto), f"Unexpected type {type(n)} for a node"
self.nodes.insert(insert_at + i, n)
self._make_node_set_type_shape_constant(n, True)
self._make_node_set_type_shape(n)
self.nodes = [n for n in self.nodes if n is not None]
return memo
self.nodes = [n for n in self.nodes if n is not None]
# Needs to insert the nodes at the right location.
# Let's find out where the best position is.
needed_at, first_at = self._needed_at_first_at()
# First loop to check positions are ok otherwise move a node or two.
N = len(self.nodes)
inode = 0
while inode < len(new_nodes):
node = new_nodes[inode]
if node.input:
min_position = max(first_at.get(i, -1) for i in node.input) + 1
else:
# a constant node
min_position = 0
max_position = min(needed_at.get(o, N) for o in node.output)
if min_position <= max_position:
inode += 1
continue
# trouble, let's assume one move is ok.
mini = max((first_at.get(i, -1), i) for i in node.input)
pos, name = mini
assert (
name in self.nodes[pos].output
), f"Name {name!r} should be at node position {pos}"
new_position = self._move_node_position(pos)
assert new_position, (
f"Node at position {pos} cannot be moved.\n----\n"
f"{self._position_msg([node])}"
f"\n-------\n{self._position_msg(new_nodes)}"
)
needed_at, first_at = self._needed_at_first_at()
# guess the position to insert the nodes at
# the order of the new nodes is consistent but it may have to be changed
# if it does not fit the existing order
insert_needed_at = {}
insert_first_at = {}
N = len(self.nodes)
inserted_at = []
new_nodes_p = []
for init, node in enumerate(new_nodes):
if node.input:
min_position = max(first_at.get(i, -1) for i in node.input) + 1
else:
# a constant node
min_position = 0
max_position = min(needed_at.get(o, N) for o in node.output)
assert min_position <= max_position, (
f"Unable to insert node {self.pretty_node(node)}, "
f"min_position={min_position}, max_position={max_position}, "
f"len(nodes)={len(self.nodes)}, previous insertions={inserted_at}, "
f"insert_needed_at={insert_needed_at}, insert_first_at={insert_first_at}, "
f"inserted_at={inserted_at}\n{self._position_msg([node])}"
f"\n-------\n{self._position_msg(new_nodes)}"
)
if node.input:
local_min_position = max(insert_first_at.get(i, -1) for i in node.input)
else:
# a constant node
local_min_position = 0
local_max_position = min(insert_needed_at.get(o, N) for o in node.output)
assert local_min_position <= local_max_position, (
f"Unable to insert node {self.pretty_node(node)}, "
f"local_min_position={local_min_position}, "
f"local_max_position={local_max_position}, "
f"len(nodes)={len(self.nodes)}, previous insertions={inserted_at}, "
f"insert_needed_at={insert_needed_at}, insert_first_at={insert_first_at}"
)
insert_position = max(min_position, local_min_position)
new_nodes_p.append((insert_position, init, node))
for i in node.input:
insert_needed_at[i] = min(
insert_position, insert_needed_at.get(i, insert_position)
)
for i in node.output:
insert_first_at[i] = min(
insert_position, insert_first_at.get(i, insert_position)
)
assert len(new_nodes) == len(new_nodes_p)
new_nodes_p.sort()
# do the addition
init_nams = {}
for p, _, n in reversed(new_nodes_p):
assert isinstance(n, NodeProto), f"Unexpected type {type(n)} for a node"
self.nodes.insert(p, n)
for _, _, n in new_nodes_p:
if n.output[0] in init_nams:
continue
self._make_node_set_type_shape_constant(n, True)
self._make_node_set_type_shape(n)
return memo
@classmethod
def _clean_shapes(cls, proto: Union[GraphProto, ModelProto]):
# cleaning unresolved shapes
if isinstance(proto, ModelProto):
cls._clean_shapes(proto.graph)
return
new_shapes = []
for sh in proto.value_info:
if sh.type.tensor_type.elem_type == 0:
continue
new_shapes.append(sh)
del proto.value_info[:]
proto.value_info.extend(new_shapes)
def _update_shape_types_with_proto_one_result(self, val):
itype = val.type.tensor_type.elem_type
if itype > 0:
self.set_type(val.name, itype)
shape = tuple(
d.dim_param if d.dim_param else d.dim_value for d in val.type.tensor_type.shape.dim
)
if all(isinstance(s, int) for s in shape) and -1 in shape:
# Some converters uses -1 to specify a dynamic dimension.
# We replace the value with a string
new_shape = []
for index, s in enumerate(shape):
if s >= 0:
new_shape.append(s)
continue
dyn_name = f"{val.name}_{index}"
new_shape.append(dyn_name)
shape = tuple(new_shape)
for i, sh in enumerate(shape):
if isinstance(sh, int):
continue
self.make_dynamic_object(sh, self.torch.SymInt(sh), input_name=val.name, axis=i)
self.set_shape(val.name, shape, exc=False)
def _update_shape_types_with_proto(
self, proto: ModelProto, infer_shapes: Union[bool, str] = False
):
"""
Updates the shapes and types for an existing model.
:param proto: model proto
:param infer_shapes: infer shapes to fill information about type and shapes
run shape inference, if the value is `'new'`,
existing shapes are ignored
"""
if self.verbose > 1:
begin_ = time.perf_counter()
print(
f"[GraphBuilder._update_shape_types_with_proto] starts with "
f"{len(self.nodes)} nodes and {len(getattr(proto.graph, 'value_info', 0))} "
f"shapes."
)
assert isinstance(proto, ModelProto), f"Unexpected type {type(proto)} for proto"
if infer_shapes:
if self.verbose > 1:
print("[GraphBuilder._update_shape_types_with_proto] infer shapes")
if infer_shapes == "new":
del proto.graph.value_info[:]
new_proto = onnx_infer_shapes(proto)
if self.verbose > 1:
print(
"[GraphBuilder._update_shape_types_with_proto] "
f"infer shapes done {time.perf_counter() - begin_} seconds"
)
self._clean_shapes(new_proto)
if self.verbose > 1:
print(
"[GraphBuilder._update_shape_types_with_proto] "
f"_clean_shapes after {time.perf_counter() - begin_} seconds"
)
else:
new_proto = proto
if not hasattr(new_proto.graph, "value_info"):
if self.verbose > 1:
print(
f"[GraphBuilder._update_shape_types_with_proto] ends in "
f"{time.perf_counter() - begin_} seconds."
)
return
if self.verbose > 1:
begin_ = time.perf_counter()
print(
f"[GraphBuilder._update_shape_types_with_proto] "
f"walk through {len(proto.graph.value_info)} shapes."
)
for val in new_proto.graph.value_info:
self._update_shape_types_with_proto_one_result(val)
if self.verbose > 1:
print(
f"[GraphBuilder._update_shape_types_with_proto] ends in "
f"{time.perf_counter() - begin_} seconds."
)
def _update_structures_with_proto(self, proto: ModelProto, bypass_shape: bool):
"""
Updates the shapes and types for an existing model.
"""
if self.verbose > 1:
begin_ = time.perf_counter()
print(
f"[GraphBuilder._update_structures_with_proto] starts with "
f"{len(proto.graph.node)} nodes"
)
self.opsets = {d.domain: d.version for d in proto.opset_import}
if self.ir_version is None:
self.ir_version = proto.ir_version
self.nodes = list(proto.graph.node)
for i in proto.graph.initializer:
self.add_initializer(i.name, i, allow_empty=True)
for i in proto.graph.sparse_initializer:
self.add_initializer(i.name, i, allow_empty=True)
self.functions = {}
self.functions_builder = {}
for f in proto.functions:
self.add_function(f)
self.value_info = list(proto.graph.value_info)
self.inputs = list(proto.graph.input)
self.outputs = list(proto.graph.output)
self.input_names = [i.name for i in proto.graph.input]
if hasattr(proto.graph, "value_info"):
available_shapes = {v.name: v for v in proto.graph.value_info}
else:
available_shapes = {}
for i in self.inputs + self.outputs:
self.set_name(i.name, f"_update_structures_with_proto_{i}")
self.set_type(i.name, i.type.tensor_type.elem_type)
if i.type.tensor_type.shape.dim:
shape = tuple(
d.dim_param if d.dim_param else d.dim_value
for d in i.type.tensor_type.shape.dim
)
if all(isinstance(s, int) for s in shape) and -1 in shape:
# Some converters uses -1 to specify a dynamic dimension.
# We replace the value with a string
new_shape = []
for index, s in enumerate(shape):
if s >= 0:
new_shape.append(s)
continue
dyn_name = f"{i.name}_{index}"
new_shape.append(dyn_name)
shape = tuple(new_shape)
for axis, sh in enumerate(shape):
if isinstance(sh, int):
continue
self.make_dynamic_object(
sh, self.torch.SymInt(sh), input_name=i.name, axis=axis
)
self.set_shape(i.name, shape)
if (
self.get_type(i.name) == TensorProto.INT64
and self.has_shape(i.name)
and self.get_shape(i.name) in ((1,), tuple)
and "dim" in i.name
):
self.set_value_shape(i.name, (i.name,))
need_identity_removal = False
new_nodes = []
for node in self.nodes:
self._unique_names |= set(node.output)
shape_set = self.simple_update_value_shape_with_node(node)
if node.name:
self._unique_node_names.add(node.name)
if shape_set:
for o in node.output:
if not self.has_name(o):
self.set_name(o, f"_update_structures_with_proto_n_{o}")
new_nodes.append(node)
continue
if node.op_type == "SequenceConstruct":
dtypes = [(self.get_type(n) if self.has_type(n) else 0) for n in node.input]
ranks = [(self.get_rank(n) if self.has_rank(n) else -1) for n in node.input]
unique_dtypes = set(dtypes)
if 0 in unique_dtypes:
unique_dtypes = set(u for u in unique_dtypes if u != 0)
if len(unique_dtypes) == 0:
if not self.has_name(node.output[0]):
self.set_name(
node.output[0],
f"_update_structures_with_proto_SC1_{node.output[0]}",
)
self.set_sequence(
node.output[0], 0, shapes=None, ranks=None, unknown=True
)
new_nodes.append(node)
continue
assert len(unique_dtypes) == 1, (
f"A sequence has distinct dtype: {dtypes}, "
f"(unique_dtypes={unique_dtypes}), node.name={node.name}, "
f"node.input={node.input}"
)
if not self.has_name(node.output[0]):
self.set_name(
node.output[0], f"_update_structures_with_proto_SC2_{node.output[0]}"
)
self.set_sequence(
node.output[0],
next(_ for _ in unique_dtypes),
shapes=None,
ranks=ranks,
)
new_nodes.append(node)
continue
if node.op_type == "SequenceAt":
position = self.get_constant(node.input[1], computed_value=True)
seq = self.get_sequence(node.input[0])
dtype = seq["dtype"]
if isinstance(dtype, tuple):
# More than one type is allowed in torch sequences.
dtype = dtype[int(position)]
if not self.has_name(node.output[0]):
self.set_name(
node.output[0], f"_update_structures_with_proto_SAt_{node.output[0]}"
)
self.set_type(node.output[0], dtype)
if seq["ranks"] is None:
new_nodes.append(node)
continue
rank = seq["ranks"][int(position)]
self.set_rank(node.output[0], rank)
new_nodes.append(node)
continue
assert node.op_type not in {
"SplitToSequence",
"SequenceErase",
"SequenceInsert",
"SequenceAt",
}, (
f"Sequence operators are not supported yet and op_type={node.op_type!r}"
f"(name={node.name!r})."
)
if node.op_type == "Constant":
exist = self.is_exact_same_constant(node)
if exist is not None:
node = oh.make_node(
"Identity",
[exist.output[0]],
[node.output[0]],
name="._update_structures_with_proto",
)
replaced = True
need_identity_removal = True
else:
self.add_constant_node(node)
replaced = False
self.update_node_constant(node.output[0], node)
if not self.has_name(node.output[0]):
self.set_name(
node.output[0], f"_update_structures_with_proto_Cst_{node.output[0]}"
)
if replaced:
self.set_type(node.output[0], self.get_type(node.input[0]))
self.set_shape(node.output[0], self.get_shape(node.input[0]))
else:
self.set_shape(node.output[0], self._get_tensor_shape(node))
self.set_type(node.output[0], self._get_tensor_type(node))
elif node.op_type == "ConstantOfShape" and self.is_constant(node.input[0]):
exist = self.is_exact_same_constant(node)
if exist is not None:
node = oh.make_node(
"Identity",
[exist.output[0]],
[node.output[0]],
name="._update_structures_with_proto",
)
replaced = True
need_identity_removal = True
else:
self.add_constant_node(node)
replaced = False
self.update_node_constant(node.output[0], node)
if not self.has_name(node.output[0]):
self.set_name(
node.output[0], f"_update_structures_with_proto_CoF_{node.output[0]}"
)
if replaced:
self.set_type(node.output[0], self.get_type(node.input[0]))
self.set_shape(node.output[0], self.get_shape(node.input[0]))
else:
value = self.get_constant(node.input[0], computed_value=True)
shape = tuple(int(i) for i in value)
self.set_shape(node.output[0], shape)
if len(node.attribute) == 0:
self.set_type(node.output[0], TensorProto.FLOAT)
else:
value = node.attribute[0].t
self.set_type(node.output[0], value.data_type)
else:
for o in node.output:
if o == "":
continue
if not self.has_name(o):
self.set_name(o, f"_update_structures_with_proto_l_{o}")
self._make_node_set_type_shape_constant(node, True)
self._make_node_set_type_shape(node)
new_nodes.append(node)
for o in node.output:
if o in available_shapes:
self._update_shape_types_with_proto_one_result(available_shapes[o])
if not bypass_shape:
if any(
(x not in available_shapes and not self.has_type(x)) for x in node.output
):
# second try
self._make_node_set_type_shape(node)
# This test should be enabled when shape inference is complete.
# assert all(
# map(
# lambda x: x in available_shapes or self.has_type(x), node.output
# )
# ), (
# f"One output of node {node.op_type!r} "
# f"(name={node.name!r}) has no type: "
# f"{', '.join(o + ((':' + str(self.get_type(o))) "
# f"if self.has_type(o) else ':0') for o in node.output)}"
# )
self.nodes = new_nodes
if need_identity_removal:
self.remove_identity_nodes()
if self.verbose > 1:
print(
f"[GraphBuilder._update_structures_with_proto] ends with "
f"{len(self.nodes)} nodes in "
f"{time.perf_counter() - begin_}"
)
[docs]
def parse_dimension_expression(self, expr: str, exc: bool = True) -> Expression:
"""
Parses an expression involving dimension.
:param expr: expr
:param exc: raises an exception if it fails
:return: an expression or None if exc is False and the parsing failed
"""
try:
return parse_expression(expr, exc=exc, context=self.dynamic_objects)
except AssertionError as e:
raise AssertionError(
f"Unable to parse an expression expr=[{expr!r}] "
f"due to {e}{self.get_debug_msg()}"
) from e
def _constant_key(self, node: NodeProto) -> Optional[bytes]:
"""
Builds a unique key for a constant.
Returns None if the constant if too big.
"""
if node.op_type == "ConstantOfShape":
# We assume initializer are fused.
name = node.input[0]
while name in self.constants_alias_:
name = self.constants_alias_[name]
key = [node.op_type.encode(), name.encode()]
for att in node.attribute:
key.append(att.SerializeToString())
return b"|".join(key)
if node.op_type == "Constant":
shape = self._get_tensor_shape(node)
size = np.prod(shape) if shape else 1
if size > self.optimization_options.constant_size:
# It would be too long.
return None
key = [node.op_type.encode()]
for att in node.attribute:
key.append(att.SerializeToString())
return b"|".join(key)
raise RuntimeError(f"Unexpected node type {node.op_type!r}{self.get_debug_msg()}")
[docs]
def add_constant_node(self, node: NodeProto) -> Optional[bytes]:
"""
Adds a constant node. Any constant equivalent to this one
will be fused.
`self.optimization_options.constant_fusing` must be True.
"""
assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for a node"
if not self.optimization_options.constant_fusing:
return None
key = self._constant_key(node)
assert (
key not in self.constants_node_
), f"A constant with the same key {key!r} was already added{self.get_debug_msg()}"
self.constants_node_[key] = node
return key
[docs]
def is_exact_same_constant(self, node: NodeProto) -> Optional[NodeProto]:
"""
Adds a constant node. Any constant equivalent to this one
will be fused.
`self.optimization_options.constant_fusing` must be True.
"""
if not self.optimization_options.constant_fusing:
return None
key = self._constant_key(node)
if key in self.constants_node_:
return self.constants_node_[key]
return None
[docs]
def make_local_function(
self,
builder: "GraphBuilder",
function_options,
optimize: bool = False,
) -> Tuple[List[str], Tuple[str, str]]:
"""
Adds a local function to exiting graph.
:param builder: builder
:param function_options: to define how to handle weights
:param optimize: optimize the function
:return: the list of added initializers if
*move_initializer_to_constant* is True,
and the function name (domain, name),
it can be changed if one is already existing
Method :meth:`GraphBuilder.inline_functions`,
meth:`GraphBuilder.move_initializers_to_constant` are called on
the builder if *move_initializer_to_constant* is True.
It modifies the builder inplace.
"""
name = function_options.name
domain = function_options.domain
assert name, f"function_options is wrong {function_options!r}"
assert (
function_options.rename_allowed
or function_options.merge_allowed
or not self.has_local_function(name=name, domain=domain)
), f"Function {name!r}, domain={domain!r} already exists"
if function_options.move_initializer_to_constant:
if function_options.inline:
self._check_constants("before-inline_functions")
builder.inline_functions(verbose=max(0, self.verbose - 1))
self._check_constants("after-inline_functions")
assert function_options.return_initializer, (
f"incompatible options, return_initializer must be True "
f"but function_options={function_options!r}"
)
fct = builder.to_onnx(
function_options=function_options,
optimize=optimize,
)
assert isinstance(fct, dict), (
f"Unexpected type {type(fct)}, function_options={function_options}"
f"{self.get_debug_msg()}"
)
onx = fct["proto"]
doc_string = f"function_options={function_options!r}"
onx.doc_string += doc_string + (
f"\noptimized:{builder.optimization_options!r}" if optimize else "not-optimized"
)
to_rename = {}
keys = []
for key, f in builder.functions.items():
# What if on is already existing?
old_key = f.domain, f.name
new_key = self.add_function(
f,
rename_allowed=function_options.rename_allowed,
merge_allowed=function_options.merge_allowed,
builder=builder.functions_builder.get(key),
)
if new_key != old_key:
to_rename[old_key] = new_key
keys.append(new_key)
if to_rename:
# We rename the local functions.
onx = self.rename_in_local_functions(to_rename, keys, proto=onx)
# Let's rename the initializers.
if "initializers_dict" in fct:
repl = {}
inits = fct["initializers_dict"]
new_inits = []
for i in onx.input:
if i in inits:
new_inits.append(i)
for i in inits:
if i not in new_inits:
new_inits.append(i)
for k, v in inits.items():
new_name = self.add_initializer(self.unique_name(k), v)
repl[k] = new_name
new_inits = [repl.get(i, i) for i in new_inits]
else:
new_inits = []
assert isinstance(onx, FunctionProto), (
f"Unexpected type {type(onx)}, name={name!r}, domain={domain!r}, "
f"function_options={function_options}"
)
assert all(node.op_type != name or node.domain != domain for node in onx.node), (
f"Recursivity is not allowed in function {name!r}, domain={domain!r}, "
f"function_options={function_options}\n------ONNX----\n{pretty_onnx(onx)}"
f"{self.get_debug_msg()}"
)
assert (
not optimize
or not builder.optimization_options.remove_identity
or len([n for n in onx.node if n.op_type == "Identity"]) <= len(onx.output)
), (
f"The optimization was not applied. There are two many nodes identity"
f"\n{builder.pretty_text()}"
)
new_domain, new_name = self.add_function(
onx,
rename_allowed=function_options.rename_allowed,
merge_allowed=function_options.merge_allowed,
builder=builder,
)
if new_domain not in self.opsets:
self.opsets[new_domain] = 1
return new_inits, (new_domain, new_name)
[docs]
def rename_in_local_functions(
self,
replacements: Dict[Tuple[str, str], Tuple[str, str]],
list_keys: List[Tuple[str, str]],
proto: FunctionProto,
) -> FunctionProto:
"""
Renames local function in a given list of local functions.
:param replacements: replacements to make
:param list_keys: list of local function to modify
:param proto: one function to update as well
:return: the modified proto for proto
The function does not modify inplace the functions,
it creates a copy assuming this one is not too big.
"""
for key in list_keys:
assert (
key in self.functions
), f"Local function {key!r} is missing from {sorted(self.functions)}"
new_f = self._rename_op_type_in_local_functions(self.functions[key], replacements)
self.functions[key] = new_f
if proto is None:
return None
return self._rename_op_type_in_local_functions(proto, replacements)
@classmethod
def _detect_op_type_replacements(
cls,
proto: Union[FunctionProto, GraphProto],
replacements: Dict[Tuple[str, str], Tuple[str, str]],
) -> bool:
"""
Detects a replacements to make in a proto.
"""
for node in proto.node:
if (node.domain, node.op_type) in replacements:
return True
for att in node.attribute:
if att.type == AttributeProto.GRAPH:
if cls._detect_op_type_replacements(att.g, replacements):
return True
return False
def _rename_op_type_in_local_functions(
self,
proto: Union[FunctionProto, GraphProto],
replacements: Dict[Tuple[str, str], Tuple[str, str]],
) -> Union[FunctionProto, GraphProto]:
"""
Modifies the function to replace a call by another one.
:param proto: the proto to modify
:param replacements: the replacements to do
:return: the new proto, or the existing one if no replacements was found
"""
if not self._detect_op_type_replacements(proto, replacements):
return proto
if isinstance(proto, FunctionProto):
new_proto = FunctionProto()
elif isinstance(proto, GraphProto):
new_proto = GraphProto()
else:
raise AssertionError(f"Unexpected type {type(proto)}")
new_proto.ParseFromString(proto.SerializeToString())
self._check_constants("begin-renaming")
for node in new_proto.node:
key = node.domain, node.op_type
if key in replacements:
node.domain, node.op_type = replacements[key]
node_attributes = []
modified = False
for att in node.attribute:
if att.type != AttributeProto.GRAPH:
node_attributes.append(att)
continue
if not self._detect_op_type_replacements(att.g, replacements):
node_attributes.append(att)
continue
modified = True
node_attributes.append(
oh.make_attribute(
att.name,
self._rename_op_type_in_local_functions(att.g, replacements),
)
)
if modified:
del node.attribute[:]
node.attribute.extend(node_attributes)
self._check_constants("after-renaming", new_proto)
return new_proto
[docs]
def add_function(
self,
f: FunctionProto,
rename_allowed: bool = False,
merge_allowed: bool = False,
builder: Optional["GraphBuilder"] = None,
):
"""
Adds a new local function.
:param f: new function to register
:param rename_allowed: the function can be renamed if a function
with the same name already exists,
the proto is modified inplace
:param merge_allowed: the function is not added if another function
of the same name already exists and is the same
:param builder: GraphBuilder used to build the local function,
it contains shape information the function does not have
:return: function name
"""
key = f.domain, f.name
if merge_allowed and key in self.functions:
existing = self.functions[key]
if same_function_proto(existing, f):
# No need to add it again.
return f.domain, f.name
if rename_allowed and key in self.functions:
i = 2
new_name = f"{f.name}_2"
while (f.domain, new_name) in self.functions:
i += 1
new_name = f"{f.name}_{i}"
f.name = new_name
key = f.domain, f.name
assert key not in self.functions, (
f"Function {key} was already added, rename_allowed={rename_allowed}, "
f"merge_allowed={merge_allowed}, same: "
f"{same_function_proto(self.functions[key], f)}"
f"\n{self.pretty_text()}"
)
self.functions[key] = f
if builder:
self.functions_builder[key] = builder
return f.domain, f.name
[docs]
def has_local_function(self, name: str, domain: str = "") -> bool:
"""
Checks if a local function exists.
"""
return (domain, name) in self.functions
[docs]
def get_local_function(self, name: str, domain: str = "") -> FunctionProto:
"""
Returns a local function.
"""
return self.functions[domain, name]
[docs]
def get_local_function_outputs(self, name: str, domain: str = "") -> List[str]:
"""
Returns the outputs of a local function.
"""
return self.functions[domain, name].output
[docs]
def register_constraint_dimension(self, dim_name: str, value: Any):
"""
Registers a constraint on a dimension.
:param dim_name: dimension name
:param value: value to register
"""
if dim_name not in self.constraints_:
self.constraints_[dim_name] = set()
self.constraints_[dim_name].add(value)
def _to_torch_tensor(self, a: Any) -> "torch.Tensor": # noqa: F821
"""
Torch does not convert numpy dtype very well.
"""
if isinstance(a, self.torch.Tensor):
return a
if isinstance(a, np.ndarray):
if len(a.shape) == 0:
# Then torch may consider this as a the creation of empty array.
tt = self.torch.from_numpy(a.reshape((1,)).copy())
tt = tt[0]
else:
tt = self.torch.Tensor(a)
ttype = onnx_dtype_to_torch_dtype(dtype_to_tensor_dtype(a.dtype))
res = tt.to(ttype)
assert a.shape == tuple(res.shape), (
f"Unexpected shape {res.shape}, expecting shape={a.shape}, "
f"dtype={res.dtype}, expected dtype={a.dtype}"
)
return res
raise AssertionError(
f"Unsupported type {type(a)}, unable to convert to a torch.Tensor."
)
[docs]
def inline_functions(self, verbose: int = 0) -> int:
"""
Inlines local functions.
Returns the number of inlined nodes.
"""
if not self.functions:
# Nothing to do
return
begin0 = time.perf_counter()
# Checks opsets
for v in self.functions.values():
for op in v.opset_import:
if op.domain in self.opsets:
assert op.version == self.opsets[op.domain], (
f"Opset version mismatch for domain {op.domain!r}, "
f"existing version is {self.opsets[op.domain]}, "
f"version for function {v.name!r} is {op.version}"
)
else:
self.opsets[op.domain] = op.version
if verbose:
print(f"[inline_functions] begin graph {id(self)}")
stats = []
self._check(stats, step="before inline")
begin = time.perf_counter()
inlined = self._inline_functions_iteration(verbose=verbose)
stats.append(
dict(
pattern="inline",
time_inline=time.perf_counter() - begin,
iteration=0,
inlined=inlined,
)
)
self._check(stats, step="after inline iteration 0")
it = 0
while inlined:
it += 1
begin = time.perf_counter()
inlined = self._inline_functions_iteration(verbose=verbose)
stats.append(
dict(
pattern="inline",
time_inline=time.perf_counter() - begin,
iteration=0,
inlined=inlined,
)
)
self._check(stats, step=f"after inline iteration {it}")
# We can remove the local functions now.
self.functions = {}
if verbose:
print(f"[inline_functions] done graph {id(self)} in {time.perf_counter()-begin0}")
return stats
def local_functions_found(self, g: GraphProto) -> bool:
for node in g.node:
if node.domain == "":
continue
key = node.domain, node.op_type
if key in self.functions:
return True
for att in node.attribute:
assert att.type != AttributeProto.GRAPHS, (
f"node.op_type={node.op_type!r}, node.name={node.name!r}, "
f"att.name={att.name!r}, not implemented with att.type={att.type} "
f"(AttributeProto.GRAPHS)"
)
if att.type == AttributeProto.GRAPH:
if self.has_local_function(att.g):
return True
return False
def _inline_functions_iteration(self, verbose: int = 0) -> int:
"""
Inlines local functions. Returns the number of replacements.
"""
n_replacements = 0
replacements = []
for pos, node in enumerate(self.nodes):
for att in node.attribute:
assert att.type != AttributeProto.GRAPHS, (
f"node.op_type={node.op_type!r}, node.name={node.name!r}, "
f"att.name={att.name!r}, not implemented with att.type={att.type} "
f"(AttributeProto.GRAPHS)"
)
if self.local_functions_found(att.g):
# A function was detected in a subgraphs.
if verbose:
print(
f"[_inline_functions_iterations] replace local "
f"functions in node {node.op_type!r}, name={node.name!r}"
)
n_replacements += self._inline_functions_subgraph(att.g, verbose)
key = node.domain, node.op_type
if key not in self.functions:
continue
n_replacements += 1
if verbose:
print(
f"[_inline_functions_iterations] inline function "
f"{self.functions[key].name!r} domain {self.functions[key].domain!r}"
)
new_nodes = self._convert_function(node.input, node.output, self.functions[key])
if verbose:
print(
f"[_inline_functions_iterations] {len(new_nodes)} new nodes "
f"for {self.functions[key].name!r}, {self.functions[key].domain!r}"
)
replacements.append((pos, node, new_nodes))
if n_replacements == 0:
# No replacements to do.
return n_replacements
stat = []
for pos, node, new_nodes in reversed(replacements):
self.insert_and_remove_nodes(pos, new_nodes, removed=[pos])
self._check(stat, step=f"after inlining function {node.op_type!r}")
return n_replacements
def _inline_functions_subgraph(self, g: GraphProto, verbose: int = 0) -> int:
"""
Inlines local functions in subgraph (inplace).
Returns the number of replacements.
"""
stat = []
self._check(stat, step="before inline")
inlined = self._inline_functions_subgraph_iteration(g, verbose=verbose)
self._check(stat, step="after inline iteration 0")
total = inlined
it = 0
while inlined:
it += 1
inlined = self._inline_functions_subgraph_iteration(g, verbose=verbose)
total += inlined
it += 1
self._check(stat, step=f"after inline iteration {it}")
return total
def _inline_functions_subgraph_iteration(self, g: GraphProto, verbose: int = 0) -> int:
new_nodes = []
if verbose:
print(f"[_inline_functions_subgraph_iteration] begin with {id(g)}")
n_replacements = 0
for node in g.node:
for att in node.attribute:
assert att.type != AttributeProto.GRAPHS, (
f"node.op_type={node.op_type!r}, node.name={node.name!r}, "
f"att.name={att.name!r}, not implemented with att.type={att.type} "
f"(AttributeProto.GRAPHS)"
)
if self.local_functions_found(att.g):
# A function was detected in a subgraphs.
if verbose:
print(
f"[_inline_functions_subgraph_iteration] replace local "
f"functions in node {node.op_type!r}, name={node.name!r}"
)
n_replacements += self._inline_functions_subgraph(att.g, verbose)
key = node.domain, node.op_type
if key not in self.functions:
new_nodes.append(node)
continue
n_replacements += 1
if verbose:
print(
f"[_inline_functions_subgraph_iteration] inline function "
f"{self.functions[key].name!r} domain {self.functions[key].domain!r}"
)
functions_nodes = self._rename_results(
self._convert_function(node.input, node.output, self.functions[key]),
replacements={n: n for n in (set(node.input) | set(node.output))},
)
new_nodes.extend(functions_nodes)
if verbose:
print(
f"[_inline_functions_subgraph_iteration] {len(new_nodes)} new nodes "
f"for {self.functions[key].name!r}, {self.functions[key].domain!r}"
)
if n_replacements > 0:
del g.node[:]
g.node.extend(new_nodes)
if verbose:
print(
f"[_inline_functions_subgraph_iteration] done with "
f"{id(g)} and {n_replacements} replacements"
)
return n_replacements
def _rename_results(
self, nodes: List[NodeProto], replacements: Dict[str, str]
) -> List[NodeProto]:
new_nodes = []
for node in nodes:
assert all(i in replacements for i in node.input), (
f"An input from {node.input} is inkonwn input in node "
f"{node.op_type!r}, name={node.name!r}"
)
new_outputs = []
for o in node.output:
if o in replacements:
assert (
replacements[o] == o
), f"o={o!r} must be an output in {sorted(replacements)!r}"
new_outputs.append(o)
continue
new_name = self.unique_name(o)
replacements[o] = new_name
new_outputs.append(new_name)
new_node = oh.make_node(
node.op_type,
[replacements[i] for i in node.input],
new_outputs,
name=self.unique_node_name(node.name),
domain=node.domain,
)
new_atts = []
for att in node.attribute:
assert att.type != AttributeProto.GRAPHS, (
f"node.op_type={node.op_type!r}, node.name={node.name!r}, "
f"att.name={att.name!r}, not implemented with att.type={att.type} "
f"(AttributeProto.GRAPHS)"
)
new_atts.append(
oh.make_attribute(
att.name,
self._rename_results_in_subgraph(
att.g, replacements=replacements.copy()
),
)
if att.type == AttributeProto.GRAPH
else att
)
new_node.attribute.extend(new_atts)
new_nodes.append(new_node)
return new_nodes
def _rename_results_in_subgraph(
self, g: GraphProto, replacements: Dict[str, str]
) -> GraphProto:
set_rep = set(replacements)
new_nodes = []
do = False
for node in g:
diff = bool(set(node.input) & set_rep)
new_inputs = [replacements.get(i, i) for i in node.input]
new_atts = []
for att in node.attribute:
assert att.type != AttributeProto.GRAPHS, (
f"node.op_type={node.op_type!r}, node.name={node.name!r}, "
f"att.name={att.name!r}, not implemented with att.type={att.type} "
f"(AttributeProto.GRAPHS)"
)
new_g = self._rename_results_in_subgraph(
att.g, replacements=replacements.copy()
)
if id(new_g) != id(g):
diff = True
new_atts.append(
oh.make_attribute(att.name, new_g)
if att.type == AttributeProto.GRAPH
else att
)
else:
new_atts.append(att)
if diff:
do = True
new_node = oh.make_node(
node.op_type, new_inputs, node.output, domain=node.domain, name=node.name
)
new_node.attribute.extend(new_atts)
new_nodes.append(new_node)
else:
new_nodes.append(node)
if not do:
return g
g2 = oh.make_graph(
new_nodes, g.name, g.input, g.output, g.initializer, g.sparse_initializer
)
return g2
def _convert_function(
self, inputs: List[str], outputs: List[str], proto: FunctionProto
) -> List[NodeProto]:
"""
Converts a function into a list of nodes.
:param inputs: inputs in the calling nodes
:param outputs: outputs in the calling nodes
:param proto: function proto
:return: list of nodes
"""
renamed = dict(zip(proto.input, inputs))
renamed.update(dict(zip(proto.output, outputs)))
new_nodes = []
for node in proto.node:
new_inputs = []
for name in node.input:
assert name in renamed, f"Unable to find {name!r} in renamed={renamed}"
new_inputs.append(renamed[name])
new_outputs = []
for name in node.output:
if name in renamed:
new_outputs.append(renamed[name])
else:
new_name = self.unique_name(name)
renamed[name] = new_name
new_outputs.append(new_name)
new_node = oh.make_node(
node.op_type,
new_inputs,
new_outputs,
name=self.unique_node_name(node.name),
domain=node.domain,
)
new_node.attribute.extend(node.attribute)
new_nodes.append(new_node)
return new_nodes
[docs]
def move_initializers_to_constant(
self, full_parameter_name, threshold: Optional[int] = None, verbose: int = 0
) -> int:
"""
Moves initializers as constant nodes.
:param full_parameter_name: keeps the local name or the full name for the parameters
:param threshold: only move intializers to constant if their size is below this limit
:param verbose: verbosity
:return: number of moves iniatliazers
"""
if not self.initializers_dict:
return
initializers, _ = self._build_initializers(
switch_low_high=sys.byteorder != "big",
large_model=True,
external_threshold=threshold or 2**30,
full_parameter_name=full_parameter_name,
)
to_remove = []
cst_nodes = []
for proto in initializers:
if proto.external_data:
# external tensor
continue
to_remove.append(proto.name)
if self.verbose:
print(
f"[move_initializers_to_constant] convert "
f"{proto.name!r} into a node 'Constant'"
)
cst = oh.make_node(
"Constant",
[],
[proto.name],
value=proto,
name=self.unique_node_name("init2cst"),
)
cst_nodes.append(cst)
self.add_constant_node(cst)
self.update_node_constant(proto.name, cst)
assert self.has_type(proto.name), f"Type is missing for initializer {proto.name!r}"
assert self.has_shape(
proto.name
), f"Shape is missing for initializer {proto.name!r}"
if to_remove:
remove = set(to_remove)
self.initializers_dict = {
k: v for k, v in self.initializers_dict.items() if k not in remove
}
self.nodes = [*cst_nodes, *self.nodes]