Source code for experimental_experiment.xbuilder._shape_helper

from typing import Any, Sequence, Tuple, Union
import numpy as np

STATIC_SHAPE = Tuple[int, ...]
DYNAMIC_SHAPE = Tuple[
    Union[int, "torch.SymInt", "torch.SymFloat", float, str], ...  # noqa: F821
]


def all_int(seq: Sequence[Any]) -> bool:
    return all(isinstance(i, int) for i in seq)


def all_float(seq: Sequence[Any]) -> bool:
    return all(isinstance(i, float) for i in seq)


def all_int_or_float(seq: Sequence[Any]) -> bool:
    return all(isinstance(i, (int, float)) for i in seq)


def all_int_or_str(seq: Sequence[Any]) -> bool:
    return all(isinstance(i, (int, str)) for i in seq)


def is_static_shape(shape: DYNAMIC_SHAPE) -> bool:
    if shape is None:
        return False
    return all(map(is_static_dimension, shape))


def is_static_dimension(d: int) -> bool:
    import torch

    if isinstance(d, torch.export.dynamic_shapes._Dim):
        return False
    if isinstance(d, int):
        assert not isinstance(
            d, (torch.SymInt, torch.SymFloat)
        ), f"Unexpected type {type(d)} for a dimension {d!r}"
        return True

    assert isinstance(
        d, (str, torch.SymInt, torch.SymFloat)
    ), f"Unexpected type {type(d)} for a dimension {d!r}"
    return False


[docs] def compatible_shapes(sh1: DYNAMIC_SHAPE, sh2: DYNAMIC_SHAPE) -> bool: """ Checks that two shapes are compatible. If both static, they must be equal. If dynamic, the variable part must be compatible meaning they could be equal. :param sh1: first shape :param sh2: seconde shape :return: compatibility .. runpython:: :showcode: from experimental_experiment.xbuilder._shape_helper import compatible_shapes print(compatible_shapes((1, 2), (1, 2))) # True print(compatible_shapes((1, 2), (1, "D2"))) # True print(compatible_shapes(("D2", 2), (1, "D2"))) # False print(compatible_shapes(("D2", 2), (2, "D2"))) # True """ assert isinstance(sh1, tuple), f"type(sh1)={type(sh1)} is unexpected" assert isinstance(sh2, tuple), f"type(sh2)={type(sh2)} is unexpected" if len(sh1) != len(sh2): # incompatible rank return False assert all_int_or_str( sh1 ), f"Shape sh1={sh1} has unexpected dimension type: {[type(i) for i in sh1]}" assert all_int_or_str( sh2 ), f"Shape sh2={sh2} has unexpected dimension type: {[type(i) for i in sh2]}" constraints = {} for a, b in zip(sh1, sh2): if a == b: continue if type(a) == type(b): # noqa: E721 # The same name should be used for the same value. if a != b: return False continue name, value = (b, a) if isinstance(a, int) else (a, b) if name not in constraints: constraints[name] = value continue if constraints[name] != value: return False return True
[docs] def compatible_dimensions(*dims: Sequence[Union[int, str]]) -> bool: """ Evaluates the fact all the dimensions can be equal or not. :param dims: dimensions :return: compatibility .. runpython:: :showcode: from experimental_experiment.xbuilder._shape_helper import compatible_dimensions print(compatible_dimensions(1, 1)) # True print(compatible_dimensions(1, 2)) # False print(compatible_dimensions(1, "D")) # True print(compatible_dimensions(1, "D", "DD")) # True """ assert all_int_or_str(dims), f"unexpected types in {dims} ({[type(i) for i in dims]})" unique = set(dims) ints = [i for i in unique if isinstance(i, int)] if len(ints) > 1: return False return True
def _reshape_shape(shape: Tuple[int, ...], new_shape: Tuple[int, ...]) -> Tuple[int, ...]: """ Computes the shape of the reshaped shape. """ assert all_int(new_shape), f"Unexpected new_shape={new_shape}" if -1 not in new_shape: return new_shape pos = [i for i in new_shape if i > 0] if not pos: assert new_shape == (-1,), f"Unexpected new_shape={new_shape}, input shape is {shape}" return (int(np.prod(shape)),) res = [] for i in new_shape: if i >= 0: res.append(i) else: p = np.prod(pos) s = np.prod(shape) assert s % p == 0, f"Incompatible shapes shape={shape}, reshaped into {new_shape}" res.append(int(s // p)) return tuple(res)