# Source code for experimental_experiment.xbuilder._shape_helper

from itertools import zip_longest
from typing import Any, Sequence, Tuple, Union
import numpy as np

# Shape whose dimensions are all plain integers.
STATIC_SHAPE = Tuple[int, ...]
# Shape whose dimensions may each be an int, a torch.SymInt, a torch.SymFloat,
# a float or a str. The torch types are written as strings (forward
# references), so this alias does not need torch to be imported here.
DYNAMIC_SHAPE = Tuple[Union[int, "torch.SymInt", "torch.SymFloat", float, str], ...]  # noqa: F821


def reshape_implementation_with_zero(
    data: Any, shape: Sequence[int], allowzero: int = 0
) -> Any:
    """
    Reshapes an array or a tensor, resolving the zeros of the requested
    shape (following onnx specifications).

    :param data: object to reshape, it must expose ``reshape``
        (checked by an assertion); ``shape`` and ``squeeze`` are used as well
    :param shape: requested shape, an empty shape returns ``data.squeeze()``
    :param allowzero: if 0, a zero at position *i* in *shape* is replaced
        by ``data.shape[i]``, otherwise *shape* is used unchanged
    :return: the result of ``data.reshape`` (or ``data.squeeze``)
    """
    assert hasattr(data, "reshape"), f"Missing 'reshape' for type {type(data)}"
    if len(shape) == 0:
        return data.squeeze()
    if allowzero != 0:
        # Zeros are genuine dimensions, nothing to resolve.
        return data.reshape(tuple(shape))

    # Pair every requested dimension with the existing one at the same
    # position. When more dimensions are requested than available,
    # the missing existing dimensions are paired as None.
    if len(shape) <= len(data.shape):
        pairs = zip(data.shape, shape)
    else:
        pairs = zip_longest(data.shape, shape)
    resolved = tuple(current if wanted == 0 else wanted for current, wanted in pairs)
    return data.reshape(resolved)
def all_int(seq: Sequence[Any]) -> bool:
    """Tells if every element of *seq* is an instance of ``int``."""
    for item in seq:
        if not isinstance(item, int):
            return False
    return True


def all_float(seq: Sequence[Any]) -> bool:
    """Tells if every element of *seq* is an instance of ``float``."""
    for item in seq:
        if not isinstance(item, float):
            return False
    return True


def all_int_or_float(seq: Sequence[Any]) -> bool:
    """Tells if every element of *seq* is an instance of ``int`` or ``float``."""
    for item in seq:
        if not isinstance(item, (int, float)):
            return False
    return True


def all_int_or_str(seq: Sequence[Any]) -> bool:
    """Tells if every element of *seq* is an instance of ``int`` or ``str``."""
    for item in seq:
        if not isinstance(item, (int, str)):
            return False
    return True


def is_static_shape(shape: DYNAMIC_SHAPE) -> bool:
    """
    Tells if every dimension of *shape* is static.
    ``None`` is not considered as a static shape.
    """
    if shape is None:
        return False
    return all(is_static_dimension(dim) for dim in shape)


def is_static_dimension(d: int) -> bool:
    """
    Tells if a dimension is static, meaning an integer.
    A ``_Dim`` from :mod:`torch.export.dynamic_shapes`, a string,
    a ``torch.SymInt`` or a ``torch.SymFloat`` is not static,
    any other type fails an assertion.
    """
    # torch is imported here and not at module level.
    import torch

    if isinstance(d, torch.export.dynamic_shapes._Dim):
        return False
    if not isinstance(d, int):
        assert isinstance(
            d, (str, torch.SymInt, torch.SymFloat)
        ), f"Unexpected type {type(d)} for a dimension {d!r}"
        return False
    assert not isinstance(
        d, (torch.SymInt, torch.SymFloat)
    ), f"Unexpected type {type(d)} for a dimension {d!r}"
    return True
def compatible_shapes(sh1: DYNAMIC_SHAPE, sh2: DYNAMIC_SHAPE) -> bool:
    """
    Checks that two shapes are compatible.
    Two static shapes must be equal. With dynamic dimensions,
    the shapes are compatible if nothing prevents them from being equal.

    :param sh1: first shape
    :param sh2: second shape
    :return: compatibility

    .. runpython::
        :showcode:

        from experimental_experiment.xbuilder._shape_helper import compatible_shapes

        print(compatible_shapes((1, 2), (1, 2)))  # True
        print(compatible_shapes((1, 2), (1, "D2")))  # True
        print(compatible_shapes(("D2", 2), (1, "D2")))  # False
        print(compatible_shapes(("D2", 2), (2, "D2")))  # True
    """
    assert isinstance(sh1, tuple), f"type(sh1)={type(sh1)} is unexpected"
    assert isinstance(sh2, tuple), f"type(sh2)={type(sh2)} is unexpected"
    if len(sh1) != len(sh2):
        # Different ranks can never match.
        return False
    assert all_int_or_str(
        sh1
    ), f"Shape sh1={sh1} has unexpected dimension type: {[type(i) for i in sh1]}"
    assert all_int_or_str(
        sh2
    ), f"Shape sh2={sh2} has unexpected dimension type: {[type(i) for i in sh2]}"

    # Maps a dimension name to the integer it was first matched against.
    bound = {}
    for left, right in zip(sh1, sh2):
        if left == right:
            continue
        if type(left) == type(right):  # noqa: E721
            # Two different integers or two different names:
            # the same name should be used for the same value.
            return False
        if isinstance(left, int):
            symbol, number = right, left
        else:
            symbol, number = left, right
        # The first occurrence records the value, the following ones
        # must agree with it.
        if bound.setdefault(symbol, number) != number:
            return False
    return True
def compatible_dimensions(*dims: Union[int, str]) -> bool:
    """
    Evaluates the fact all the dimensions can be equal or not.
    The dimensions are compatible as long as they hold at most one
    distinct integer, names are never considered as a conflict.

    :param dims: dimensions, every one of them is an integer or a name
        (checked by an assertion)
    :return: compatibility

    .. runpython::
        :showcode:

        from experimental_experiment.xbuilder._shape_helper import compatible_dimensions

        print(compatible_dimensions(1, 1))  # True
        print(compatible_dimensions(1, 2))  # False
        print(compatible_dimensions(1, "D"))  # True
        print(compatible_dimensions(1, "D", "DD"))  # True
    """
    assert all_int_or_str(dims), f"unexpected types in {dims} ({[type(i) for i in dims]})"
    # Two different integers can never be equal, this is the only conflict.
    distinct_ints = {d for d in dims if isinstance(d, int)}
    return len(distinct_ints) <= 1
def _reshape_shape(shape: Tuple[int, ...], new_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Computes the shape obtained after reshaping *shape* into *new_shape*.

    :param shape: shape before the reshape
    :param new_shape: requested shape, only integers (checked by an assertion),
        it is returned unchanged if it does not contain -1
    :return: *new_shape* where every negative dimension is replaced
        by the number of elements of *shape* divided by the product
        of the strictly positive dimensions of *new_shape*
    """
    assert all_int(new_shape), f"Unexpected new_shape={new_shape}"
    if -1 not in new_shape:
        return new_shape

    known = [dim for dim in new_shape if dim > 0]
    if not known:
        # Without any positive dimension, only a full flattening is supported.
        assert new_shape == (-1,), f"Unexpected new_shape={new_shape}, input shape is {shape}"
        return (int(np.prod(shape)),)

    # Both products do not depend on the position of the negative dimension.
    fixed = np.prod(known)
    total = np.prod(shape)
    assert total % fixed == 0, f"Incompatible shapes shape={shape}, reshaped into {new_shape}"
    inferred = int(total // fixed)
    return tuple(dim if dim >= 0 else inferred for dim in new_shape)