from typing import Any, Dict, Optional, Set, Tuple
class FakeTensorContext:
    """Stores the information used to reuse the same symbolic dimension
    for the same dimension name."""

    def __init__(self, fake_mode: Optional["FakeTensorMode"] = None):  # noqa: F821
        if fake_mode is None:
            from torch.fx.experimental.symbolic_shapes import ShapeEnv
            from torch._subclasses.fake_tensor import FakeTensorMode

            shape_env = ShapeEnv()
            self.fake_mode = FakeTensorMode(shape_env=shape_env)
        else:
            self.fake_mode = fake_mode
        self._candidates = self._first_primes()
        self._unique_: Set[int] = set()
        self._mapping_int: Dict[int, Any] = {}
        self._mapping_str: Dict[str, int] = {}

    @classmethod
    def _first_primes(cls, n=1000):
        """Returns the prime numbers below ``n``, starting at 13."""
        sieve = [True] * (n + 1)
        sieve[0:2] = [False, False]
        for i in range(2, int(n**0.5) + 1):
            if sieve[i]:
                # Eliminate the multiples of i.
                sieve[i * i : n + 1 : i] = [False] * len(range(i * i, n + 1, i))
        return [i for i, prime in enumerate(sieve) if prime and i >= 13]

    def _unique(self) -> int:
        """Returns a prime number not used yet as a dimension."""
        i = 0
        c = self._candidates[i]
        while c in self._unique_ or c in self._mapping_int:
            i += 1
            assert i < len(
                self._candidates
            ), f"Too many unique dimensions to generate, requested: {len(self._unique_)}"
            c = self._candidates[i]
        self._unique_.add(c)
        return c

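    # A minimal sketch of the two helpers above (illustrative only, not part
    # of the API): the sieve yields primes >= 13 and ``_unique`` hands them
    # out one at a time, skipping values already used as a dimension size.
    #
    #   ctx = FakeTensorContext()
    #   ctx._first_primes(30)  # -> [13, 17, 19, 23, 29]
    #   ctx._unique()          # -> 13
    #   ctx._unique()          # -> 17 since 13 is already taken
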
    def from_tensor(self, x, static_shapes=False) -> "FakeTensor":  # noqa: F821
        """
        Returns a fake tensor.
        ``pytorch`` returns the same symbolic dimension for the same dimension size.
        """
        fake = self.fake_mode.from_tensor(x, static_shapes=static_shapes)
        for i, s in zip(x.shape, fake.shape):
            assert i not in self._mapping_int or self._mapping_int[i] == s, (
                f"Inconsistency between {x.shape} and {fake.shape}, "
                f"mapping has {self._mapping_int[i]} and s={s}"
            )
            self._mapping_int[i] = s
        return fake

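    # Usage sketch (illustrative only): two real tensors sharing a dimension
    # size receive the same symbolic dimension, the association size -> symbol
    # is remembered in ``_mapping_int`` and checked by the assertion above.
    #
    #   import torch
    #   ctx = FakeTensorContext()
    #   a = ctx.from_tensor(torch.zeros((2, 8)))
    #   b = ctx.from_tensor(torch.zeros((2, 3)))
    #   # a.shape[0] and b.shape[0] are the same symbolic dimension.
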
    def fake_reshape(
        self,
        true_tensor: "torch.Tensor",  # noqa: F821
        sh: Dict[int, Any],
        fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
    ) -> "FakeTensor":  # noqa: F821
        """
        Changes the shape of a true tensor to make it dynamic.

        :param true_tensor: true tensor
        :param sh: dynamic shape
        :param fake_tensor: fake tensor, if None, makes a fake one
        :return: fake tensor
        """
        import torch

        # Deal with dimensions equal to 0 or 1: replace them with a unique
        # prime so they cannot be specialized as static dimensions.
        for i in sh:
            if true_tensor.shape[i] <= 1:
                expanded_shape = list(true_tensor.shape)
                expanded_shape[i] = self._unique()
                true_tensor = torch.empty(
                    tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
                )
        # Deal with equivalent dimensions: two dimensions with the same size
        # but different dynamic names must receive distinct sizes.
        new_shape = list(true_tensor.shape)
        mapping = {}
        for i, s in sh.items():
            d = true_tensor.shape[i]
            if d not in mapping:
                mapping[d] = s
            elif mapping[d] != s:
                d = self._unique()
                mapping[d] = s
            new_shape[i] = d
        true_tensor = torch.empty(
            tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
        )
        # Now switch to FakeTensor.
        if fake_tensor is None:
            fake_tensor = self.from_tensor(true_tensor, static_shapes=False)
        new_shape = list(true_tensor.shape)
        for i in sh:
            new_shape[i] = fake_tensor.shape[i]
        reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
            axis=tuple(sorted(sh)), keepdim=True
        )
        return reduced_tensor.expand(*new_shape)

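    # A small usage sketch (illustrative only): ``sh`` maps a dimension index
    # to a dynamic dimension name, dimensions equal to 0 or 1 are first
    # replaced with a unique prime so they survive as dynamic.
    #
    #   import torch
    #   ctx = FakeTensorContext()
    #   fake = ctx.fake_reshape(torch.zeros((1, 8, 8)), {0: "batch", 1: "seq"})
    #   # fake.shape[0] and fake.shape[1] are symbolic, fake.shape[2] is 8.
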
    def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
        """See :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`."""
        if x is None:
            return None
        if isinstance(x, (list, tuple)):
            return x.__class__([self.make_fake(i) for i in x])
        if isinstance(x, dict):
            return {k: self.make_fake(v) for k, v in x.items()}
        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
            assert hasattr(x, "layers"), (
                f"Use a more recent version of transformers (>=4.55), "
                f"'layers' not found in class {type(x)}"
            )
            for layer in x.layers:
                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
                    f"Use a more recent version of transformers (>=4.55), "
                    f"'keys' and 'values' not found in class {type(layer)} ({dir(layer)})"
                )
                layer.keys = self.make_fake(layer.keys)
                layer.values = self.make_fake(layer.values)
            return x
        if x.__class__.__name__ == "EncoderDecoderCache":
            self.make_fake(x.self_attention_cache)
            self.make_fake(x.cross_attention_cache)
            return x
        if hasattr(x, "shape"):
            return self.from_tensor(x, static_shapes=False)
        from . import string_type

        raise TypeError(
            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
        )

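    # Illustrative sketch: the method recurses into nested containers and
    # rewrites cache layers in place, only tensors end up being faked.
    #
    #   import torch
    #   ctx = FakeTensorContext()
    #   fake = ctx.make_fake({"x": torch.zeros((2, 3)), "y": [torch.ones(4)]})
    #   # fake["x"] and fake["y"][0] are FakeTensor instances.
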
    def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
        """
        See
        :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
        """
        if x is None:
            return None
        if isinstance(x, (list, tuple)):
            return x.__class__(
                [
                    self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
                    for i, ds in zip(x, dynamic_shapes)
                ]
            )
        if isinstance(x, dict):
            return {
                k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
                for k, v in x.items()
            }
        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
            assert hasattr(x, "layers"), (
                f"Use a more recent version of transformers (>=4.55), "
                f"'layers' not found in class {type(x)}"
            )
            assert isinstance(dynamic_shapes, list) and (
                not dynamic_shapes or not isinstance(dynamic_shapes[0], list)
            ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
            for il, layer in enumerate(x.layers):
                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
                    f"Use a more recent version of transformers (>=4.55), "
                    f"'keys' and 'values' not found in class {type(layer)} ({dir(layer)})"
                )
                layer.keys = self.make_fake_with_dynamic_dimensions(
                    layer.keys, dynamic_shapes=dynamic_shapes[il * 2]
                )
                layer.values = self.make_fake_with_dynamic_dimensions(
                    layer.values, dynamic_shapes=dynamic_shapes[il * 2 + 1]
                )
            return x
        if x.__class__.__name__ == "EncoderDecoderCache":
            self.make_fake_with_dynamic_dimensions(
                x.self_attention_cache, dynamic_shapes=dynamic_shapes[0]
            )
            self.make_fake_with_dynamic_dimensions(
                x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
            )
            return x
        if hasattr(x, "shape"):
            assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
                f"dynamic_shapes must be a dictionary at this stage but "
                f"dynamic_shapes={dynamic_shapes}"
            )
            # We need to overwrite the dimensions mapped to a dynamic name.
            new_shape = []
            for idim, dim in enumerate(x.shape):
                if dynamic_shapes is not None and idim in dynamic_shapes:
                    s = dynamic_shapes[idim]
                    assert isinstance(s, str), (
                        f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
                        f"at index {idim}"
                    )
                    if s in self._mapping_str:
                        dim = self._mapping_str[s]
                    else:
                        i = self._unique()
                        self._mapping_str[s] = i
                        dim = i
                assert isinstance(dim, int), (
                    f"Unexpected type {type(dim)}, dynamic_shapes={dynamic_shapes} "
                    f"at index {idim}, dim={dim}"
                )
                new_shape.append(dim)
            if tuple(new_shape) != x.shape:
                import torch

                x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
            t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
            assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
            assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
            return t
        from . import string_type

        raise TypeError(
            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
        )

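    # Illustrative sketch (the dimension names are arbitrary): the structure
    # of ``dynamic_shapes`` mirrors the structure of ``x``, with one
    # ``{axis_index: name}`` dictionary per tensor; the same name is mapped
    # to the same symbolic dimension through ``_mapping_str``.
    #
    #   import torch
    #   ctx = FakeTensorContext()
    #   fakes = ctx.make_fake_with_dynamic_dimensions(
    #       [torch.zeros((2, 8)), torch.zeros((2, 8))],
    #       dynamic_shapes=[{0: "batch"}, {0: "batch", 1: "seq"}],
    #   )
    #   # fakes[0].shape[0] and fakes[1].shape[0] share the "batch" symbol.
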
def make_fake(
    x: Any, context: Optional[FakeTensorContext] = None
) -> Tuple[Optional["FakeTensor"], Optional[FakeTensorContext]]:  # noqa: F821
    """
    Replaces all tensors by fake tensors.
    This modification happens inplace for caches.
    This function is only implemented for caches with
    ``transformers>=4.55``.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake

        inputs, _ = make_fake(
            dict(
                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
                past_key_values=make_dynamic_cache(
                    [
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                    ]
                ),
            )
        )
        pprint.pprint(inputs)
    """
    if x is None:
        return None, None
    if context is None:
        context = FakeTensorContext()
    return context.make_fake(x), context