from typing import Any, Callable, List, Optional, Tuple
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils
try:
from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
from transformers.cache_utils import MambaCache
class CacheKeyValue:
"""
    Starting with ``transformers>=4.54``, the cache API deprecates
    ``cache.key_cache`` and ``cache.value_cache``.
    This class wraps a cache independently of the transformers version
    and exposes the attributes ``key_cache`` and ``value_cache``.

    .. code-block:: python

        capi = CacheKeyValue(cache)
        capi.key_cache
        capi.value_cache
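
    Wrapping is cheap: the class only stores references to the cached tensors.
    A round-trip sketch (shapes below are illustrative):

    .. code-block:: python

        import torch
        from onnx_diagnostic.helpers.cache_helper import (
            CacheKeyValue,
            make_dynamic_cache,
        )

        cache = make_dynamic_cache(
            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
        )
        capi = CacheKeyValue(cache)
        print(len(capi.key_cache), len(capi.value_cache))  # 2 2
        restored = capi.make_dynamic_cache()  # back to a DynamicCache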
"""
def __init__(self, cache=None):
if hasattr(cache, "layers"):
layers = [
layer
for layer in cache.layers
if layer is not None and layer.keys is not None and layer.values is not None
]
self.key_cache = [layer.keys for layer in layers]
self.value_cache = [layer.values for layer in layers]
if None in self.key_cache or None in self.value_cache:
from .helper import string_type
raise AssertionError(
f"issue with key_cache={string_type(self.key_cache)}, "
f"or value_cache={string_type(self.value_cache)}, "
f"cache.layers={string_type(cache.layers)}"
)
elif cache is not None:
self.key_cache = cache.key_cache
self.value_cache = cache.value_cache
def make_dynamic_cache(self):
"""Do the reverse operation."""
return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
def flatten_unflatten_for_dynamic_shapes(
obj: Any,
use_dict: bool = False,
change_function: Optional[Callable[[torch.Tensor], Any]] = None,
) -> Any:
"""
    Returns the object in a different structure, similar to the one
    the definition of the dynamic shapes should follow.

    :param obj: object from a custom class
    :param use_dict: closer to the original result, but
        :func:`torch.export.export` only considers the values:
        the context gives the dictionary keys, but they are not expressed
        in the dynamic shapes, and these specifications seem to differ
        between the strict and non-strict modes. It also preserves tuples.
    :param change_function: a function applied to every tensor in the
        structure, e.g. to replace it with its shape
    :return: the serialized object
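
    A usage sketch (the nested structure is illustrative):

    .. code-block:: python

        import torch
        from onnx_diagnostic.helpers.cache_helper import (
            flatten_unflatten_for_dynamic_shapes,
        )

        obj = {"a": torch.randn(2, 3), "b": [torch.randn(4, 5), (torch.randn(6),)]}
        # keep the dictionary keys
        print(flatten_unflatten_for_dynamic_shapes(obj, use_dict=True))
        # replace every tensor by its shape
        print(
            flatten_unflatten_for_dynamic_shapes(
                obj, use_dict=True, change_function=lambda t: tuple(t.shape)
            )
        )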
"""
if isinstance(obj, torch.Tensor):
return change_function(obj) if change_function else obj
flat, spec = torch.utils._pytree.tree_flatten(obj)
start = 0
end = 0
subtrees = []
for subspec in spec.children_specs:
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(
value, use_dict=use_dict, change_function=change_function
)
subtrees.append(value)
start = end
if use_dict:
if spec.type is dict:
            # This is a dictionary.
return dict(zip(spec.context, subtrees))
if spec.type is tuple:
return tuple(subtrees)
if spec.type is list:
return list(subtrees)
if spec.context:
# This is a custom class with attributes.
# It is returned as a list.
return list(subtrees)
raise ValueError(
f"Unable to interpret spec type {spec.type} "
f"(type is {type(spec.type)}, context is {spec.context})."
)
# This is a list.
return subtrees
def is_cache_dynamic_registered(fast: bool = False) -> bool:
"""
    Tells whether :class:`transformers.cache_utils.DynamicCache` can be
    serialized and deserialized. Only then can :func:`torch.export.export`
    export a model.

    :param fast: if True, only check the class is registered,
        do not check that serialization works as well
    :return: result
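
    A usage sketch:

    .. code-block:: python

        from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

        # only check the class is registered with torch.utils._pytree
        print(is_cache_dynamic_registered(fast=True))
        # also run a flatten/unflatten round-trip on a small cache
        print(is_cache_dynamic_registered())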
"""
if fast:
return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(2)
]
)
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    # compare the round-trip cache with the original one
    return len(CacheKeyValue(cache2).key_cache) == len(CacheKeyValue(cache).key_cache)
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
def make_dynamic_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
"""
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
This version is valid for ``transformers >= 4.50``.
:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.DynamicCache`
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
past_key_values = make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
]
)
print(string_type(past_key_values, with_shape=True))
"""
cache = transformers.cache_utils.DynamicCache(key_value_pairs)
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
# The cache constructor contains the two following lines
# (in cache_utils.py) which append empty layers when the cache is
# initialized. We need to remove them.
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
# self.append_new_layers(self.num_hidden_layers - 1)
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
f"{len(key_value_pairs)} expected."
)
return cache
else:
def make_dynamic_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
"""
Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
This version is valid for ``transformers < 4.50``.
:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.DynamicCache`
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
past_key_values = make_dynamic_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
]
)
print(string_type(past_key_values, with_shape=True))
"""
cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
for i, (key, value) in enumerate(key_value_pairs):
cache.update(key, value, i)
return cache
def make_static_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
max_cache_len: Optional[int] = None,
) -> transformers.cache_utils.StaticCache:
"""
Creates an instance of :class:`transformers.cache_utils.StaticCache`.
:param key_value_pairs: list of pairs of (key, values)
:param max_cache_len: max_cache_length or something inferred from the vector
:return: :class:`transformers.cache_utils.StaticCache`
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_static_cache
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
past_key_values = make_static_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
],
max_cache_len=10,
)
print(string_type(past_key_values, with_shape=True))
"""
class _config:
def __init__(self):
self.head_dim = key_value_pairs[0][0].shape[-1]
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)
    assert max_cache_len is not None, (
        f"max_cache_len={max_cache_len} cannot be set up "
        f"automatically yet from shape {key_value_pairs[0][0].shape}"
    )
    torch._check(
        max_cache_len >= key_value_pairs[0][0].shape[2],
        # torch._check expects the message to be a callable
        lambda: (
            f"max_cache_len={max_cache_len} cannot be smaller than "
            f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
            f"{key_value_pairs[0][0].shape}"
        ),
    )
cache = transformers.cache_utils.StaticCache(
config=_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
max_cache_len=max_cache_len,
)
ca = CacheKeyValue(cache)
for i in range(len(key_value_pairs)):
assert (
key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
d = key_value_pairs[i][1].shape[2]
ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
# The cache constructor contains the two following lines
# (in cache_utils.py) which append empty layers when the cache is
# initialized. We need to remove them.
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
# self.append_new_layers(self.num_hidden_layers - 1)
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
f"{len(key_value_pairs)} expected."
)
return cache
def make_encoder_decoder_cache(
self_attention_cache: transformers.cache_utils.DynamicCache,
cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
"""Creates an EncoderDecoderCache."""
return transformers.cache_utils.EncoderDecoderCache(
self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
)
def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
"Creates a ``MambaCache``."
dtype = key_value_pairs[0][0].dtype
class _config:
def __init__(self):
self.intermediate_size = key_value_pairs[0][0].shape[1]
self.conv_kernel = key_value_pairs[0][0].shape[-1]
self.state_size = key_value_pairs[0][1].shape[-1]
self.num_hidden_layers = len(key_value_pairs)
self.dtype = dtype
cache = MambaCache(
_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
device=key_value_pairs[0][0].device,
dtype=dtype,
)
for i in range(len(key_value_pairs)):
assert cache.conv_states[i].dtype == dtype, (
f"Type mismatch for cache.conv_states[{i}].dtype="
f"{cache.conv_states[i].dtype} != {dtype}"
)
assert cache.ssm_states[i].dtype == dtype, (
f"Type mismatch for cache.ssm_states[{i}].dtype="
f"{cache.ssm_states[i].dtype} != {dtype}"
)
assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
f"Shape mismatch, expected {cache.conv_states[i].shape}, "
f"got {key_value_pairs[i][0].shape}"
)
cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
f"got {key_value_pairs[i][1].shape}"
)
cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
return cache
def make_sliding_window_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.SlidingWindowCache:
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
class _config:
def __init__(self):
self.head_dim = key_value_pairs[0][0].shape[-1]
self.num_attention_heads = key_value_pairs[0][0].shape[1]
self.num_hidden_layers = len(key_value_pairs)
self.sliding_window = key_value_pairs[0][0].shape[2]
cache = transformers.cache_utils.SlidingWindowCache(
config=_config(),
max_batch_size=key_value_pairs[0][0].shape[0],
max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window
device=key_value_pairs[0][0].device,
dtype=key_value_pairs[0][0].dtype,
)
ca = CacheKeyValue(cache)
for i in range(len(key_value_pairs)):
        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {ca.key_cache[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {ca.value_cache[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
# The cache constructor contains the two following lines
# (in cache_utils.py) which append empty layers when the cache is
# initialized. We need to remove them.
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
# self.append_new_layers(self.num_hidden_layers - 1)
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
f"{len(key_value_pairs)} expected."
)
return cache
def make_hybrid_cache(
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
max_cache_len: Optional[int] = None,
max_batch_size: Optional[int] = None,
sliding_window: Optional[int] = None,
) -> transformers.cache_utils.HybridCache:
"""
Creates an instance of :class:`transformers.cache_utils.HybridCache`.
This version is valid for ``transformers < 4.50``.
:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.HybridCache`
Example:
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
past_key_values = make_hybrid_cache(
[
(
torch.randn(bsize, nheads, slen, dim),
torch.randn(bsize, nheads, slen, dim),
)
for i in range(n_layers)
]
)
print(string_type(past_key_values, with_shape=True))
This part defines how the shapes are working in one HybridCache.
.. code-block:: python
self.max_cache_len = (
max_cache_len if max_cache_len is not None else config.max_position_embeddings)
# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size
self.head_dim = (
config.head_dim if hasattr(config, "head_dim")
else config.hidden_size // config.num_attention_heads
)
self._dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [
layer_type != "full_attention" for layer_type in config.layer_types]
else:
self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
self.sliding_window_len, self.head_dim)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
layer_device = layer_device_map[i] if layer_device_map is not None else device
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
new_layer_key_cache = torch.zeros(
cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(
cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
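
    When the pairs have two different sequence lengths, the shorter one is
    interpreted as the sliding window (a sketch with illustrative values):

    .. code-block:: python

        import torch
        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

        bsize, nheads, dim = 2, 4, 7
        past_key_values = make_hybrid_cache(
            [
                # full-attention layer (sequence length 6)
                (torch.randn(bsize, nheads, 6, dim), torch.randn(bsize, nheads, 6, dim)),
                # sliding-attention layer (sequence length 3)
                (torch.randn(bsize, nheads, 3, dim), torch.randn(bsize, nheads, 3, dim)),
            ]
        )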
"""
layer_types = None
if key_value_pairs:
assert (
not max_batch_size and not max_cache_len
), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
max_batch_size = key_value_pairs[0][0].shape[0]
sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
if len(sets_of_dim) == 1:
max_cache_len = sets_of_dim.pop()
sliding_window = max_cache_len
else:
assert (
len(sets_of_dim) == 2
), f"Not implemented for more than 2 dimensions {sets_of_dim}"
max_cache_len = max(sets_of_dim)
sliding_window = min(sets_of_dim)
layer_types = [
"full_attention" if i == max_cache_len else "sliding_attention"
for i in [kv[0].shape[2] for kv in key_value_pairs]
]
else:
assert (
max_batch_size and max_cache_len
), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
if sliding_window is None:
sliding_window = max_cache_len
_max_cache_len = max_cache_len
_sliding_window = sliding_window
class _config:
max_cache_len = _max_cache_len
batch_size = max_batch_size
num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
num_hidden_layers = len(key_value_pairs)
sliding_window = _sliding_window
        num_key_value_heads = (
            key_value_pairs[0][1].shape[1] if key_value_pairs else None
        )  # transformers 4.48.3
if layer_types:
_config.layer_types = layer_types # type: ignore[attr-defined]
cache = transformers.cache_utils.HybridCache(
config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
)
for i, (key, value) in enumerate(key_value_pairs):
cache.update(
key,
value,
i,
cache_kwargs={
"cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
key.device
)
},
)
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
# The cache constructor contains the two following lines
# (in cache_utils.py) which append empty layers when the cache is
# initialized. We need to remove them.
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
# self.append_new_layers(self.num_hidden_layers - 1)
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
f"{len(key_value_pairs)} expected."
)
return cache