"""Source of ``experimental_experiment.xshape._shape_runtime``: runs through a few nodes used to compute shapes."""

import numpy as np
import onnx
from ._shape_helper import all_int


class _ShapeRuntime:
    """Runs through a few nodes often used to deal with shapes.

    The class is used as a mixin: every helper it calls (``value_as_shape``,
    ``set_value_shape``, ``get_constant``, ``has_shape``, ``get_rank``,
    ``get_debug_msg``, ...) is defined outside this class.
    """

    @staticmethod
    def _trunc_div(x: int, y: int) -> int:
        """Integer division rounding toward zero, as ONNX ``Div`` does on integers.

        Python's ``//`` floors toward minus infinity, which differs as soon as
        one operand is negative (``-3 // 2 == -2`` while ONNX gives ``-1``).
        Raises ZeroDivisionError when ``y == 0``, like ``//``.
        """
        quotient = abs(x) // abs(y)
        return quotient if (x >= 0) == (y >= 0) else -quotient

    def simple_update_value_shape_with_node(self, node: onnx.NodeProto) -> bool:
        """Updates ``_known_value_shape`` for a particular node.

        A value shape is what is known about the *value* of a tensor involved
        in a shape computation: a tuple of dimensions (ints or symbolic
        strings), a single int, or a symbolic string. Only a few operators of
        the default ONNX domain are handled. Every path appends a short tag
        (``#SV-...``) to ``node.doc_string`` to trace which branch was taken.

        :param node: node to process
        :return: True if the value shape of ``node.output[0]`` was inferred,
            False otherwise (a few False paths still register a partial
            symbolic value)
        :raises AssertionError: on inconsistent shapes or attributes
        :raises RuntimeError: on configurations this function does not handle
        """
        if node.domain != "" or node.op_type not in {
            "Abs",
            "Add",
            "Concat",
            "Div",
            "Gather",
            "Identity",
            "Mod",
            "Mul",
            "Range",
            "Scatter",
            "Shape",
            "Slice",
            "Squeeze",
            "Sub",
            "Unsqueeze",
        }:
            return False

        # Constant can be considered as possible shape.
        for i in node.input:
            if self.value_as_shape(i) is not None:
                continue
            if not self.is_constant(i):
                continue
            if not self.has_type(i) or self.get_type(i) != onnx.TensorProto.INT64:
                # No chance for this to be used a shape computation.
                continue
            cst = self.get_constant(i, exc=False, computed_value=True)
            if cst is None or len(cst.shape) > 1:
                continue
            with self.maybe_disable_fake_tensor_mode():
                tu = tuple(map(int, cst)) if len(cst.shape) > 0 else int(cst)
                self.set_value_shape(i, tu)

        if node.op_type in {"Identity", "Abs"}:
            value = self.value_as_shape(node.input[0])
            if value is None:
                node.doc_string += "#SV-Id/2"
                return False
            node.doc_string += "#SV-Id1"
            if node.op_type == "Abs":
                # The value may be a scalar or a tuple; keep the container type
                # (a tuple), symbolic values are left untouched.
                if isinstance(value, int):
                    value = abs(value)
                elif isinstance(value, tuple) and all(
                    isinstance(s, (int, float)) for s in value
                ):
                    value = tuple(abs(s) for s in value)
            # NOTE(review): for Abs, equal_to still declares the output equal to
            # the input, as before; confirm this matches set_value_shape semantics.
            self.set_value_shape(
                node.output[0], value, equal_to=(node.input[0], node.output[0])
            )
            return True

        if node.op_type == "Shape":
            if len(node.attribute) == 0:
                if self.has_shape(node.input[0]):
                    node.doc_string += "#SV-Sh1"
                    shape = self.get_shape(node.input[0])
                    self.set_value_shape(node.output[0], shape)
                    if all_int(shape):
                        self.update_node_constant(node.output[0], node)
                    self.set_shape(node.output[0], (len(shape),))
                    return True
                node.doc_string += "#SV-Sh/1"
                return False

            start = self.get_attribute(node, "start", exc=False)
            end = self.get_attribute(node, "end", exc=False)
            assert end is None or start is None or end.i < 0 or start.i < end.i, (
                f"Shape(..., end < start) is not implemented, node={self.pretty_node(node)}, "
                f"start={start}, end={end}{self.get_debug_msg()}"
            )
            if end is None and self.has_rank(node.input[0]):
                # Missing end means up to the last dimension.
                end = self.get_rank(node.input[0])
            # Missing start means from the first dimension.
            start_i = 0 if start is None else start.i

            if self.has_shape(node.input[0]):
                shape = self.get_shape(node.input[0])
                assert start is None or start.i < len(shape), (
                    f"Shape mismatch, start={0 if start is None else start.i}, "
                    f"shape of {node.input[0]!r} "
                    f"is {shape}{self.get_debug_msg()}"
                )
                if end is None:
                    n_shape = shape[start_i:]
                    self.set_value_shape(node.output[0], n_shape)
                    if all_int(shape):
                        self.update_node_constant(node.output[0], node)
                    self.set_shape(node.output[0], (len(n_shape),))
                    node.doc_string += "#SV-Sh4"
                    return True
                assert getattr(end, "i", end) <= len(shape), (
                    f"Shape mismatch, end={getattr(end, 'i', end)}, "
                    f"shape of {node.input[0]!r} "
                    f"is {shape}{self.get_debug_msg()}"
                )
                n_shape = shape[start_i : getattr(end, "i", end)]
                if all_int(shape):
                    self.update_node_constant(node.output[0], node)
                self.set_value_shape(node.output[0], n_shape)
                self.set_shape(node.output[0], (len(n_shape),))
                node.doc_string += "#SV-Sh6"
                return True

            if end is None:
                # Neither the shape nor the rank is known.
                self.set_value_shape(node.output[0], f"{node.input[0]}[{start_i}:]")
                node.doc_string += "#SV-Sh/6"
                return False

            end_i = getattr(end, "i", end)
            if self.has_rank(node.input[0]):
                # ONNX counts negative start/end from the back and clamps
                # both of them to [0, rank].
                rank = self.get_rank(node.input[0])
                if start_i < 0:
                    start_i += rank
                if end_i < 0:
                    end_i += rank
                start_i = max(0, min(start_i, rank))
                end_i = max(0, min(end_i, rank))
            if start_i >= 0 and end_i >= 0:
                self.set_value_shape(
                    node.output[0],
                    tuple(f"{node.input[0]}[{i}]" for i in range(start_i, end_i)),
                )
            else:
                # Negative bounds with an unknown rank: keep it symbolic.
                self.set_value_shape(node.output[0], f"{node.input[0]}[{start_i}:{end_i}]")
            node.doc_string += "#SV-Sh7"
            return True

        if node.op_type == "Gather":
            if self.has_type(node.input[0]):
                self.set_type(node.output[0], self.get_type(node.input[0]))
            if self.is_constant(node.input[1]):
                y = self.value_as_shape(node.input[0])
                if y is None:
                    node.doc_string += "#SV-Ga/2"
                    return False
                i = self.get_constant(node.input[1], computed_value=True, exc=True)
                if i is None:
                    node.doc_string += "#SV-Ga/3"
                    return False
                if isinstance(y, str) and isinstance(i, int):
                    self.set_value_shape(node.output[0], f"{y}[{i}]")
                    node.doc_string += "#SV-Ga3"
                    self.set_shape(node.output[0], tuple())
                    return True
                if (
                    isinstance(y, str)
                    and isinstance(i, np.ndarray)
                    and i.dtype == np.int64
                    and i.shape in ((1,), tuple())
                ):
                    ii = int(i[0]) if i.shape == (1,) else int(i)
                    self.set_value_shape(node.output[0], f"{y}[{ii}]")
                    node.doc_string += "#SV-Ga4"
                    self.set_shape(node.output[0], (1,) if i.shape == (1,) else tuple())
                    return True
                if isinstance(y, tuple) and isinstance(i, int):
                    self.set_value_shape(node.output[0], y[i])
                    node.doc_string += "#SV-Ga5"
                    self.set_shape(node.output[0], tuple())
                    return True
                if isinstance(y, tuple) and isinstance(i, tuple) and all_int(i):
                    self.set_value_shape(node.output[0], tuple(y[_] for _ in i))
                    self.set_shape(node.output[0], (len(i),))
                    node.doc_string += "#SV-Ga6"
                    return True
                if (
                    isinstance(y, tuple)
                    and isinstance(i, (self.torch.Tensor, np.ndarray))
                    and i.dtype in (np.int64, self.torch.int64)
                    and tuple(i.shape) in ((1,), tuple())
                ):
                    ishape = tuple(i.shape)
                    ii = int(i[0]) if ishape == (1,) else int(i)
                    if self._debug_quiet and ii >= len(y):
                        node.doc_string += "#SV-Ga/77"
                        return False
                    assert ii < len(y), (
                        f"Unexpected value for y={y!r}, i={i!r} in node Gather "
                        f"with inputs={node.input}{self.get_debug_msg()}"
                    )
                    self.set_value_shape(node.output[0], (y[ii],) if ishape == (1,) else y[ii])
                    self.set_shape(node.output[0], (1,) if ishape == (1,) else tuple())
                    node.doc_string += "#SV-Ga7"
                    return True
                # i may be an int or a tuple here, neither has a dtype.
                raise RuntimeError(
                    f"Not implemented when node Gather(x,i) with inputs={node.input}, "
                    f"shape(x)={y!r}, i={i!r}, i.dtype={getattr(i, 'dtype', '?')}"
                    f"{self.get_debug_msg()}"
                )
            node.doc_string += "#SV-Ga/7"
            return False

        if node.op_type == "Squeeze":
            if self.is_constant_or_attribute(node, 1, "axes"):
                y = self.value_as_shape(node.input[0])
                if y is None:
                    node.doc_string += "#SV-Sq/3"
                    return False
                i = self.get_constant_or_attribute(node, 1, "axes")
                if isinstance(i, int):
                    ii = i
                elif (
                    isinstance(i, np.ndarray)
                    and i.dtype == np.int64
                    and i.shape in ((1,), tuple())
                ):
                    ii = int(i[0]) if i.shape == (1,) else int(i)
                elif i is None and isinstance(y, tuple) and len(y) == 1:
                    # A dimension a tensor of 1 element turned into a scalar
                    node.doc_string += "#SV-SqDim"
                    self.set_value_shape(node.output[0], y[0])
                    return True
                else:
                    raise RuntimeError(
                        f"Not implemented when node Squeeze with inputs={node.input}, "
                        f"y={y!r}, i={i!r}{self.get_debug_msg()}"
                    )
                # For a 1-D value, axis -1 is the same as axis 0.
                assert ii in (
                    0,
                    -1,
                ), f"A shape should only have one axis i={i}, y={y}{self.get_debug_msg()}"
                if isinstance(y, str):
                    node.doc_string += "#SV-Sq1"
                    self.set_value_shape(node.output[0], f"squeeze({y})")
                    return True
                if isinstance(y, int):
                    node.doc_string += "#SV-Sq2"
                    self.set_value_shape(node.output[0], y)
                    return True
                assert isinstance(
                    y, tuple
                ), f"Unexpected type {type(y)} for y={y} and i={i}{self.get_debug_msg()}"
                node.doc_string += "#SV-Sq3"
                self.set_value_shape(node.output[0], y[0])
                return True
            node.doc_string += "#SV-Sq/2"
            return False

        if node.op_type == "Unsqueeze":
            values_0 = self.value_as_shape(node.input[0])
            if isinstance(values_0, tuple) and len(values_0) > 1:
                # This cannot be a shape anymore after this operation
                node.doc_string += "#SV-Unsq/1"
                return False
            if self.has_rank(node.input[0]) and self.get_rank(node.input[0]) > 0:
                # This cannot be a shape anymore.
                node.doc_string += "#SV-Unsq/2"
                return False
            if not self.has_rank(node.input[0]) and values_0 is None:
                node.doc_string += "#SV-Unsq/3"
                return False
            assert self.has_rank(node.input[0]), (
                f"Rank of {node.input[0]!r} is unknown but "
                f"its value is {values_0!r}{self.get_debug_msg()}"
            )
            if len(node.input) > 1:
                cst = self.get_constant(node.input[1], exc=False, computed_value=True)
                if cst is None:
                    # The axes are not a known constant, nothing can be inferred.
                    node.doc_string += "#SV-Unsq/4"
                    return False
                cst = tuple() if not cst.shape else tuple(cst)
            else:
                cst = tuple(self.get_attribute(node, "axes").ints)
            assert cst, f"Value={cst!r} is wrong for {node.input[0]}{self.get_debug_msg()}"
            if len(cst) == 1 and self.get_rank(node.input[0]) == 0:
                node.doc_string += "#SV-Unsq4"
                self.set_value_shape(
                    node.output[0], (node.input[0],) if values_0 is None else (values_0,)
                )
                return True
            # Several axes produce a tensor of rank >= 2, not a shape.
            node.doc_string += "#SV-Unsq/5"
            return False

        # after this point, it is all about operators between shapes.
        values = [self.value_as_shape(x) for x in node.input]
        if any(x is None for x in values):
            # it is not a shape
            node.doc_string += "#SV-All/0"
            return False

        if node.op_type == "Concat":
            node.doc_string += "#SV-Co1"
            concatenated = []
            for v in values:
                concatenated.extend(v if isinstance(v, tuple) else (v,))
            self.set_value_shape(node.output[0], tuple(concatenated))
            return True

        if node.op_type == "Range":
            if len(values) == 3:
                args = []
                for v in values:
                    if isinstance(v, int):
                        args.append(v)
                    elif len(v) == 1:
                        # Should not happen.
                        args.append(v[0])
                    else:
                        node.doc_string += "#SV-Ra/1"
                        return False
                if not all_int(args):
                    node.doc_string += "#SV-Ra/2"
                    return False
                node.doc_string += "#SV-Ra"
                self.set_value_shape(node.output[0], tuple(range(*args)))
                return True

        if node.op_type in {"Mul", "Add", "Div", "Sub", "Mod"}:
            fct, symbol = {
                "Add": ((lambda x, y: x + y), "+"),
                # ONNX integer division truncates toward zero.
                "Div": (self._trunc_div, "/"),
                "Mul": ((lambda x, y: x * y), "*"),
                "Sub": ((lambda x, y: x - y), "-"),
                "Mod": ((lambda x, y: x % y), "%"),
            }[node.op_type]

            def _combine(a, b):
                # Evaluate when both sides are known, otherwise build an expression.
                if isinstance(a, int) and isinstance(b, int):
                    return fct(a, b)
                return f"{a}{symbol}{b}"

            m1, m2 = values[0], values[1]
            if isinstance(m1, int) and isinstance(m2, int):
                node.doc_string += f"#SV-{node.op_type}1"
                self.set_value_shape(node.output[0], fct(m1, m2))
                return True
            if isinstance(m1, (int, str)) and isinstance(m2, (int, str)):
                node.doc_string += f"#SV-{node.op_type}2"
                self.set_value_shape(node.output[0], f"{m1}{symbol}{m2}")
                return True
            # One of them is a tuple.
            if not isinstance(m1, tuple):
                m1 = (m1,)
            if not isinstance(m2, tuple):
                m2 = (m2,)
            if len(m1) == len(m2):
                self.set_value_shape(
                    node.output[0], tuple(_combine(a, b) for a, b in zip(m1, m2))
                )
                node.doc_string += f"#SV-{node.op_type}3"
                return True
            if len(m1) == 1:
                self.set_value_shape(node.output[0], tuple(_combine(m1[0], b) for b in m2))
                node.doc_string += f"#SV-{node.op_type}4"
                return True
            if len(m2) == 1:
                self.set_value_shape(node.output[0], tuple(_combine(a, m2[0]) for a in m1))
                node.doc_string += f"#SV-{node.op_type}4"
                return True
            # This cannot be a shape anymore.
            node.doc_string += f"#SV-{node.op_type}/0"
            return False

        if node.op_type == "Gather":
            # NOTE(review): Gather nodes already returned in the dedicated branch
            # above, so this branch is currently not reached.
            if isinstance(values[1], tuple) and all_int(values[1]):
                shape = (values[0],) if not isinstance(values[0], tuple) else values[0]
                node.doc_string += "#SV-Ga1"
                assert max(values[1]) < len(shape), (
                    f"Unable to compute new value shape when values={values}"
                    f"{self.get_debug_msg()}"
                )
                self.set_value_shape(node.output[0], tuple(shape[i] for i in values[1]))
                return True

        if node.op_type == "Slice":
            # Inputs are data, starts, ends, axes, steps; missing steps means 1.
            steps = values[4] if len(values) > 4 else (1,)
            if steps != (1,):
                # Strided slices are not handled.
                node.doc_string += "#SV-Sl/5"
                return False
            if len(values) >= 3 and values[1] == (0,) and values[2] == (9223372036854775807,):
                # Slice [0:INT64_MAX] keeps everything.
                node.doc_string += "#SV-Sl1"
                self.set_value_shape(node.output[0], values[0])
                return True
            if len(values) < 4 or values[3] != (0,):
                # Not a shape.
                node.doc_string += "#SV-Sl/2"
                return False
            if all_int(values[1]) and all_int(values[2]):
                assert len(values[1]) == len(values[2]) == 1, (
                    f"Unexpected values {values} to compute a shape from node "
                    f"{self.pretty_node(node)}{self.get_debug_msg()}"
                )
                if not isinstance(values[0], tuple):
                    # Slicing a symbolic string would slice its characters.
                    node.doc_string += "#SV-Sl/4"
                    return False
                node.doc_string += "#SV-Sl3"
                self.set_value_shape(node.output[0], values[0][values[1][0] : values[2][0]])
                return True
            if (
                values[1] == (0,)
                and isinstance(values[2][0], str)
                and isinstance(values[3][0], int)
            ):
                # Maybe a shape but probably not.
                node.doc_string += "#SV-Sl/3"
                return False

        raise RuntimeError(
            f"Unable to compute a shape for node {self.pretty_node(node)} "
            f"with values={values}{self.get_debug_msg()}"
        )