Source code for experimental_experiment.xshape._inference_runtime

import time
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import onnx
from onnx_diagnostic.helpers import string_type
from ..helpers import tensor_dtype_to_np_dtype, onnx_dtype_to_torch_dtype
from ..reference import ExtendedReferenceEvaluator
from ._shape_helper import (
    all_int,
    _reshape_shape,
    is_static_shape,
    reshape_implementation_with_zero,
)
from .shape_type_compute import set_shape_type_op_any, set_shape_type_custom


class _InferenceRuntime:
    """Sets shape and type."""

    def make_dimension_name_if_necessary(
        self, a: Union[int, str], b: Union[int, str], op: str
    ) -> str:
        """Creates a new dimension."""
        if op == "^":
            # very simple trick for the time being
            if a == b:
                return a
            if isinstance(a, str) and a.endswith(f"^{b}"):
                return a
            if isinstance(b, str) and b.startswith(f"{a}^"):
                return b
        if isinstance(a, str) and set(a) & set("+/*-^"):
            a = f"({a})"
        if isinstance(b, str) and set(b) & set("+/*-^"):
            b = f"({b})"
        return f"{a}{op}{b}"

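    # Illustrative examples (added comment, not part of the original source): the
    # combined dimension name is built with simple string rules, e.g.
    #   make_dimension_name_if_necessary("batch", "seq", "*")  -> "batch*seq"
    #   make_dimension_name_if_necessary("a+b", "c", "*")      -> "(a+b)*c"  (composite names are parenthesized)
    #   make_dimension_name_if_necessary("d", "d", "^")        -> "d"        (identical operands collapse for "^")
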
    def _make_node_set_type_shape(self, node: onnx.NodeProto, exc: bool = False):
        """Updates shapes for a node."""
        update = self._make_node_set_type_shape_constant(node, {})
        if update is None:
            if node.domain == "":
                node.doc_string += "#Io1"
                update = set_shape_type_op_any(self, node, exc=exc)
            else:
                # Missing type means it is probably coming from an inlined function.
                node.doc_string += (
                    "#Io3" if node.input and not self.has_type(node.input[0]) else "#Io2"
                )
                update = set_shape_type_custom(self, node, exc=exc)
        if update:
            self._calls.append(
                (node.name, node.domain, node.op_type, node.input, node.output, update)
            )
        assert update is not None or not self._debug_shape_missing, (
            f"Shape missing for node type {node.op_type!r}, inputs={node.input}, "
            f"outputs={node.output}\n----\n{node}\n{self.get_debug_msg()}"
        )
        return update

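    # Note (added comment): the "#Io1"/"#Io2"/"#Io3" doc_string markers record which
    # inference path handled the node: standard ONNX operators (#Io1), custom domains
    # with a known first input type (#Io2), or custom domains whose first input has no
    # known type, typically nodes coming from an inlined function (#Io3).
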
    def update_node_constant(self, name: str, node: onnx.NodeProto) -> bool:
        """Updates a constant NodeProto."""
        assert isinstance(name, str), f"Unexpected type {type(name)} for name"
        assert node is None or isinstance(
            node, onnx.NodeProto
        ), f"Unexpected type {type(node)} for name={name!r}"
        if node is not None and node.op_type.startswith("Random"):
            return False
        if self.verbose > 2:
            print(
                f"[GraphBuilder-{self._hash()}.update_node_constant] new constant "
                f"{name!r}, node={None if node is None else node.op_type}"
            )
        assert (
            node is None
            or node.op_type == "Shape"
            or all(self.is_constant(i) for i in node.input if i not in {"", None, "None"})
        ), (
            f"Output {name!r} is constant (node={self.pretty_node(node)}) "
            f"only if every input from {node.input} is constant "
            f"but constants={[self.is_constant(i) for i in node.input]}{self.get_debug_msg()}"
        )
        self.constants_[name] = node
        return True

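    # Note (added comment): update_node_constant only registers the producing NodeProto
    # in ``self.constants_``; the actual value is materialized lazily by
    # ``compute_constant`` below. Nodes with a random operator (``Random*``) are never
    # registered as constants.
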
    def _make_node_set_type_shape_constant(
        self, node: onnx.NodeProto, sts: Optional[Dict[str, Any]]
    ):
        if node.domain != "":
            return
        if all(self.is_constant(i) for i in node.input):
            for o in node.output:
                self.update_node_constant(o, node)

        if node.op_type == "Constant":
            assert (
                len(node.attribute) == 0
                or node.attribute[0].name != "value"
                or node.attribute[0].type != onnx.AttributeProto.GRAPH
            ), f"{node}"
            if len(node.attribute) == 1 and node.attribute[0].name == "value":
                size = np.prod(node.attribute[0].t.dims)
            else:
                size = len(node.SerializeToString())
            assert size < self.optimization_options.constant_size, (
                f"A node Constant holds a tensor bigger than "
                f"the constant: {size} >= {self.optimization_options.constant_size}."
            )
            k = node.output[0]
            self.update_node_constant(k, node)
            node.doc_string += ":constant-3:"
            shape = self._get_tensor_shape(node)
            dtype = self._get_tensor_type(node)
            self.set_shape(k, shape)
            self.set_type(k, dtype)
            if self.verbose > 2 or np.prod(shape) > 100:
                print(f"[GraphBuilder-{self._hash()}.make_node] {k}[{dtype}: {shape}]")
            return shape
        elif node.op_type == "ConstantOfShape":
            if len(node.attribute) == 1 and node.attribute[0].name == "value":
                itype = node.attribute[0].t.data_type
            else:
                itype = onnx.TensorProto.FLOAT
            self.set_type(node.output[0], itype)
            if self.is_constant(node.input[0]):
                value = self.get_constant(
                    node.input[0], computed_value=True, as_shape=True, exc=False
                )
                if value is not None:
                    # This is needed when concatenating caches.
                    self.set_shape(node.output[0], value, allow_zero=True)
                    node.doc_string += ":constant-9:"
                    return value
            vs = self.value_as_shape(node.input[0])
            if vs is not None:
                self.set_shape(node.output[0], vs, allow_zero=True)
                return vs
            if self.has_shape(node.input[0]):
                shape = self.get_shape(node.input[0])
                if is_static_shape(shape):
                    self.set_rank(node.output[0], shape[0])
                    return True
        elif node.op_type == "Identity":
            shape = None
            if self.has_shape(node.input[0]):
                # allow_zero is True but if it fails here, it means it did not fail
                # before when it should have.
                shape = self.get_shape(node.input[0])
                self.set_shape(node.output[0], shape, allow_zero=True)
            elif self.has_rank(node.input[0]):
                self.set_rank(node.output[0], self.get_rank(node.input[0]))
            if self.has_type(node.input[0]):
                self.set_type(node.output[0], self.get_type(node.input[0]))
            if self.is_constant(node.input[0]):
                self.update_node_constant(node.output[0], node)
                node.doc_string += ":constant-4:"
            return shape
        elif node.op_type == "Expand":
            if self.has_type(node.input[0]):
                self.set_type(node.output[0], self.get_type(node.input[0]))
            if (
                self.has_shape(node.input[0])
                and is_static_shape(self.get_shape(node.input[0]))
                and self.is_constant(node.input[1])
            ):
                cst, _ = self.compute_constant(node.input[1], exc=False, only_array=True)
                if cst is not None:
                    assert not isinstance(
                        cst, self.torch._subclasses.fake_tensor.FakeTensor
                    ), (
                        f"self.compute_constant returns a FakeTensor for {node.input[1]!r}"
                        f"\n{self.pretty_text()}"
                    )
                    assert (
                        not self.has_rank(node.input[1]) or self.get_rank(node.input[1]) == 1
                    ), (
                        f"Unexpected rank {self.get_rank(node.input[1])} for {node.input[1]!r}"
                        f"{self.get_debug_msg()}"
                    )
                    with self.maybe_disable_fake_tensor_mode():
                        assert len(cst.shape) == 1 and cst[-1] > 0, (
                            f"Unexpected shape {cst.shape} "
                            f"for computed constant {node.input[1]!r}, "
                            f"input={node.input}, cst={cst}{self.get_debug_msg()}"
                        )
                        shape = self.get_shape(node.input[0])
                        new_shape = tuple(int(i) for i in cst)
                        if len(shape) < len(new_shape):
                            shape = (1,) * (len(new_shape) - len(shape)) + shape
                        new_shape = tuple(max(i, j) for i, j in zip(shape, new_shape))
                        self.set_shape(node.output[0], new_shape, allow_zero=0 in shape)
                        return new_shape
        elif node.op_type == "Reshape":
            if self.has_type(node.input[0]):
                self.set_type(node.output[0], self.get_type(node.input[0]))
            if self.is_constant(node.input[1]):
                cst, _ = self.compute_constant(
                    node.input[1], exc=False, only_array=True, allow_empty=True
                )
                if cst is not None:
                    shape_cst = tuple(int(i) for i in cst)
                    if 0 in shape_cst:
                        if self.has_shape(node.input[0]):
                            sh = self.get_shape(node.input[0])
                            shape_cst_last_zero = shape_cst[
                                : len(shape_cst) - 1 - shape_cst[::-1].index(0) + 1
                            ]
                            assert len(sh) >= len(shape_cst_last_zero), (
                                f"Shape discrepancies for name={node.input[0]!r} "
                                f"node.name={node.name!r} "
                                f"between sh={sh} and shape_cst={shape_cst}, "
                                f"shape_cst_last_zero={shape_cst_last_zero}"
                                f"\ncst={cst}{self.get_debug_msg()}"
                            )
                            shape_cst = tuple(
                                [
                                    shape_cst[i] if shape_cst[i] != 0 else sh[i]
                                    for i in range(len(shape_cst))
                                ]
                            )
                        else:
                            shape_cst = None
                    if shape_cst is not None:
                        if -1 in shape_cst:
                            if self.has_shape(node.input[0]):
                                sh = self.get_shape(node.input[0])
                                if is_static_shape(sh):
                                    new_shape = _reshape_shape(sh, shape_cst)
                                    self.set_shape(
                                        node.output[0], new_shape, allow_zero=0 in sh
                                    )
                                    node.doc_string += ":constant-7a:"
                                    return new_shape
                        else:
                            self.set_shape(node.output[0], shape_cst)
                            node.doc_string += ":constant-7b:"
                            return shape_cst
        elif node.op_type == "Shape":
            ret_shape = None
            self.set_type(node.output[0], onnx.TensorProto.INT64)
            if self.has_rank(node.input[0]):
                rk = self.get_rank(node.input[0])
                if len(node.attribute) == 0:
                    self.set_shape(node.output[0], (rk,))
                else:
                    start = self.get_attribute_with_default(node, "start", 0)
                    if start < 0:
                        start += rk
                    end = self.get_attribute_with_default(node, "end", rk)
                    if end < 0:
                        end += rk
                    self.set_shape(node.output[0], (end - start,))
                    ret_shape = (end - start,)
            elif node.attribute:
                start = self.get_attribute_with_default(node, "start", 0)
                end = self.get_attribute_with_default(node, "end", None)
                if end is not None and end - start > 0:
                    self.set_shape(node.output[0], (end - start,))
                else:
                    self.set_rank(node.output[0], 1)
                    assert not self._debug_shape_missing, (
                        f"Unable to compute the shape of this shape: "
                        f"{self.pretty_node(node, shape=True)}{self.get_debug_msg()}"
                    )
            else:
                self.set_rank(node.output[0], 1)
                assert not self._debug_shape_missing, (
                    f"Unable to compute the shape of this shape: "
                    f"{self.pretty_node(node, shape=True)}{self.get_debug_msg()}"
                )
            if self.is_constant(node.input[0]) or (
                self.has_shape(node.input[0]) and all_int(self.get_shape(node.input[0]))
            ):
                self.update_node_constant(node.output[0], node)
                node.doc_string += ":constant-2:"
            return ret_shape
        elif node.op_type == "Size":
            self.set_type(node.output[0], onnx.TensorProto.INT64)
            self.set_shape(node.output[0], tuple())
            if self.is_constant(node.input[0]):
                self.update_node_constant(node.output[0], node)
                node.doc_string += ":constant-2s:"
            return tuple()
        elif not sts:
            if node.op_type == "GatherElements":
                if self.has_type(node.input[0]):
                    self.set_type(node.output[0], self.get_type(node.input[0]))
                if self.has_shape(node.input[1]):
                    self.set_shape(node.output[0], self.get_shape(node.input[1]))
                    return self.get_shape(node.input[1])
                elif self.has_rank(node.input[1]):
                    self.set_rank(node.output[0], self.get_rank(node.input[1]))

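    # Note (added comment): the ``:constant-N:`` doc_string tags added above
    # (":constant-2:", ":constant-3:", ":constant-4:", ":constant-7a:", ...) are only
    # breadcrumbs identifying which branch produced a shape or constant when debugging
    # an exported graph; they do not change the inference result.
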
    def compute_constant(
        self, name: str, exc: bool = True, only_array: bool = False, allow_empty: bool = False
    ) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
        """
        Computes a constant.

        :param name: constant name
        :param exc: raises an exception if any failure
        :param only_array: do not return TensorProto
        :param allow_empty: allow empty result
        :return: constant

        It returns None if the constant is a FakeTensor.
        """
        assert self.is_constant(name), f"Name {name!r} is not a constant"
        v = self.constants_[name]
        # It should not be None but a node as it is not an initializer.
        if isinstance(v, onnx.TensorProto):
            return self.get_constant(name, computed_value=True, exc=exc), None
        assert isinstance(
            v, onnx.NodeProto
        ), f"Unexpected type {type(v)} for constant name={name!r}"
        if self._debug_get_constant:
            print(f"[GraphBuilder-{self._hash()}.compute_constant] {self.pretty_node(v)}")
        if v.op_type == "Shape":
            if not self.has_shape(v.input[0]):
                # We stop.
                assert not self._debug_constant_folding, (
                    f"Unable to compute constant because {v.input[0]!r} has no shape "
                    f"in node {self.pretty_node(v)}{self.get_debug_msg()}"
                )
                return None, None
            shape = self.get_shape(v.input[0])
            if is_static_shape(shape):
                if v.attribute:
                    start = 0
                    end = None
                    for att in v.attribute:
                        if att.name == "start":
                            start = att.i
                        elif att.name == "end":
                            end = att.i
                    shape = shape[start:] if end is None else shape[start:end]
                    if self._debug_get_constant:
                        print(
                            f"[GraphBuilder-{self._hash()}.compute_constant] - SHAPE "
                            f"{name}: {shape}? start={start}, end={end}"
                        )
                elif self._debug_get_constant:
                    print(
                        f"[GraphBuilder-{self._hash()}.compute_constant] "
                        f" - SHAPE {name}: {shape}?"
                    )
                return np.array(shape, dtype=np.int64), {
                    v.input[0]: self.ShapeConstant(v.input[0], shape, v)
                }
            if not self.is_constant(v.input[0]):
                # One exception here as the input may not
                # be constant but the shape may be known.
                assert all_int(shape), (
                    f"Shape must be static ({shape}) if shape is constant in {v} in "
                    f"{self.pretty_node(v)}{self.get_debug_msg()}"
                )
                with self.maybe_disable_fake_tensor_mode():
                    output = self._apply_shape_on_shape(v, shape)
                    if isinstance(output[0], self.torch.Tensor):
                        # We convert the tensor into a numpy array,
                        # it is a small shape anyway so the FakeMode
                        # does not come up as an issue.
                        output = [output[0].detach().cpu().numpy()]
                if self._debug_get_constant:
                    print(
                        f"[GraphBuilder-{self._hash()}.compute_constant] - A "
                        f"{name}: {self.pretty_tensor(output[0])}"
                    )
                return output[0], {v.input[0]: self.ShapeConstant(v.input[0], shape, v)}
            assert not self._debug_constant_folding, (
                f"Unable to compute constant for node {self.pretty_node(v)}"
                f"{self.get_debug_msg()}"
            )
            return None, None

        feeds = {i: self.get_constant(i, exc=exc, computed_value=True) for i in v.input}
        for kval, val in feeds.items():
            if not exc and "FakeTensor" in str(type(val)):
                assert not self._debug_constant_folding, (
                    f"Unable to compute constant for node {self.pretty_node(v)} "
                    f"because a FakeTensor appeared{self.get_debug_msg()}"
                )
                return None, None
            assert "FakeTensor" not in str(type(val)), (
                f"FakeTensor {kval!r} cannot be an initializer {type(val)}, "
                f"v.op_type={v.op_type!r}"
                f"{self.get_debug_msg()}"
            )
            if val is None:
                assert not self._debug_constant_folding, (
                    f"Unable to compute constant for node {self.pretty_node(v)} "
                    f"because val=None{self.get_debug_msg()}"
                )
                return None, None

        with self.maybe_disable_fake_tensor_mode():
            if v.op_type == "Identity":
                # much faster this way
                output = [feeds[v.input[0]]]
            elif v.op_type == "Reshape":
                # much faster this way
                output = [
                    reshape_implementation_with_zero(
                        feeds[v.input[0]], tuple(feeds[v.input[1]])
                    )
                ]
            elif v.op_type in {"Add", "Div", "Mul", "Sub"}:
                # bypassing onnx.numpy_helper.from_array, too slow
                output = self._apply_binary_op(v, feeds)
            elif (
                v.op_type == "Pow"
                and self.has_type(v.input[0])
                and self.has_type(v.input[1])
                and self.get_type(v.input[0]) == self.get_type(v.input[1])
            ):
                output = self._apply_binary_op(v, feeds)
            elif v.op_type in {"Exp", "Log", "Reciprocal", "Sqrt"}:
                # bypassing onnx.numpy_helper.from_array, too slow
                output = self._apply_unary_function(v, feeds)
            elif hasattr(self, f"_apply_{v.op_type.lower()}"):
                output = getattr(self, f"_apply_{v.op_type.lower()}")(v, feeds)
            elif all(isinstance(v, np.ndarray) for v in feeds.values()):
                if v.op_type not in {"Constant", "ConstantOfShape"} and self.main_opset < 18:
                    # This functionality is not enabled before that opset.
                    if self._debug_get_constant:
                        print(
                            f"[GraphBuilder-{self._hash()}.compute_constant] fails "
                            f"because opset={self.main_opset} for name={name!r}, "
                            f"node={self.pretty_node(v)}"
                        )
                    assert not self._debug_constant_folding, (
                        f"Unable to compute constant opset={self.main_opset}<18 "
                        f"for name={name!r}{self.get_debug_msg()}"
                    )
                    return None, None

                # Let's avoid big computation on CPU.
                max_dim = 0
                for _v in feeds.values():
                    max_dim = max(max_dim, np.prod(_v.shape))
                if max_dim >= 2**22:
                    if self.verbose > 1:
                        print(
                            f"[GraphBuilder-{self._hash()}.compute_constant] stop computing a "
                            f"constant as it may be too big, shapes are "
                            f"{[_.shape for _ in feeds.values()]}"
                        )
                    assert not self._debug_constant_folding, (
                        f"Unable to compute constant for node {self.pretty_node(v)} "
                        f"because max_dim={max_dim} (shape={_v.shape}){self.get_debug_msg()}"
                    )
                    return None, None

                begin = time.perf_counter()
                ref = ExtendedReferenceEvaluator(v)
                try:
                    output = ref.run(None, feeds)
                except (ValueError, TypeError) as e:
                    sf = ", ".join(f"{k}:{v.dtype}:{v.shape}" for k, v in feeds.items())
                    if "warnings" not in self._debug_msg:
                        self._debug_msg["warnings"] = []
                    sv = str(v).replace("\n", " ")
                    self._debug_msg["warnings"].append(f"Issue with v={sv}, feeds={sf}, e={e}")
                    self.time_evaluation_constants_ += time.perf_counter() - begin
                    assert not self._debug_constant_folding, (
                        f"Unable to compute constant for node {self.pretty_node(v)} "
                        f"due to {e}{self.get_debug_msg()}"
                    )
                    return None, None
                self.time_evaluation_constants_ += time.perf_counter() - begin
            else:
                assert not self._debug_constant_folding, (
                    f"Unable to compute constant for node {self.pretty_node(v)}, "
                    f"feeds={string_type(feeds, with_shape=True, with_min_max=True, limit=20)}"
                    f"{self.get_debug_msg()}"
                )
                return None, None

            cst = None
            for n, val in zip(v.output, output):
                assert not isinstance(val, tuple), f"Unexpected type {type(val)} for n={n!r}"
                assert "FakeTensor" not in str(type(val)), (
                    f"FakeTensor detected {type(val)} in constant {name!r}, "
                    f"v.op_type={v.op_type!r}{self.get_debug_msg()}"
                )
                if self.has_type(n):
                    # numpy changes the expected type sometimes
                    # (like transpose(x: float36) --> float32)
                    itype = self.get_type(n)
                    if hasattr(val, "detach"):
                        val = val.to(onnx_dtype_to_torch_dtype(itype))
                    else:
                        val = val.astype(tensor_dtype_to_np_dtype(itype))
                self.constants_computed_[n] = val
                if name == n:
                    cst = val
                    assert (
                        len(cst.shape) == 0
                        or min(cst.shape) > 0
                        or (v.op_type in {"ConstantOfShape", "Cast", "Identity", "Constant"})
                    ), (
                        f"Output has empty shape {cst.shape}, name={name!r} "
                        f"v.op_type={v.op_type!r}, v.name={v.name!r}{self.get_debug_msg()}"
                    )

        assert cst is not None, f"Constant {name!r} was not found in {v.output}"
        if hasattr(self, "torch") and isinstance(
            cst, self.torch._subclasses.fake_tensor.FakeTensor
        ):
            assert not self._debug_constant_folding, (
                f"Unable to compute constant for node {self.pretty_node(v)} "
                f"because a FakeTensor appeared{self.get_debug_msg()}"
            )
            return None, None
        if self._debug_get_constant:
            print(
                f"[GraphBuilder-{self._hash()}.compute_constant] "
                f" - A {name}: {self.pretty_tensor(cst)}"
            )
        assert (
            not self._debug_constant_folding or cst is not None
        ), f"Unable to compute constant for node {self.pretty_node(v)}{self.get_debug_msg()}"
        return cst, feeds

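    # Illustrative usage (added comment, hypothetical builder ``g`` and constant result
    # name ``"shape_out"``): ``compute_constant`` returns the folded value together with
    # the feeds used to compute it, or ``(None, None)`` when folding is not possible.
    #
    #   cst, feeds = g.compute_constant("shape_out", exc=False)
    #   if cst is not None:
    #       print(cst.shape, cst.dtype)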