from typing import Any, List, Tuple
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils


def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
"""
Returns the object in a different structure similar to what
the definition of the dynamic shapes should use.
:param obj: object from a custom class
:param use_dict: closer to the original result but
:func:`torch.export.export` only considers the values,
the context gives the dictionary keys but it is not expressed
in the dynamic shapes, these specifications seems to be different
for the strict and non strict mode.
:return: the serialized object
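
    A minimal usage sketch, assuming
    :class:`transformers.cache_utils.DynamicCache` is registered with
    ``torch.utils._pytree`` (see :func:`is_cache_dynamic_registered`):

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import (
            flatten_unflatten_for_dynamic_shapes,
            make_dynamic_cache,
        )

        cache = make_dynamic_cache(
            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(2)]
        )
        # Nested lists mirroring how the dynamic shapes should be expressed.
        flat = flatten_unflatten_for_dynamic_shapes(cache)
        print(string_type(flat, with_shape=True))
        # With use_dict=True, dictionary keys found in the context are restored.
        flat_dict = flatten_unflatten_for_dynamic_shapes(cache, use_dict=True)
        print(string_type(flat_dict, with_shape=True))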
"""
    if isinstance(obj, torch.Tensor):
        return obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
        subtrees.append(value)
        start = end
    if use_dict and (spec.type is dict or spec.context):
        # This is a dictionary.
        return dict(zip(spec.context, subtrees))
    # This is a list.
    return subtrees


def is_cache_dynamic_registered(fast: bool = False) -> bool:
"""
Tells class :class:`transformers.cache_utils.DynamicCache` can be
serialized and deserialized. Only then, :func:`torch.export.export`
can export a model.
:param fast: if True, do not check the serialization is ok as well
:return: result
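
    A minimal usage sketch:

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

        print(is_cache_dynamic_registered())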
"""
    if fast:
        return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            )
            for i in range(2)
        ]
    )
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    # The round trip must preserve the number of layers.
    return len(cache2.key_cache) == len(cache.key_cache)


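# DynamicCache construction changed in transformers 4.50: newer versions accept
# the list of (key, value) pairs directly, older ones require one update() per layer.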
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        return transformers.cache_utils.DynamicCache(key_value_pairs)
else:
    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers < 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache


def make_encoder_decoder_cache(
    self_attention_cache: transformers.cache_utils.DynamicCache,
    cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
"""Creates an EncoderDecoderCache."""
    return transformers.cache_utils.EncoderDecoderCache(
        self_attention_cache=self_attention_cache,
        cross_attention_cache=cross_attention_cache,
    )


def make_mamba_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.MambaCache:
"Creates a :class:`transformers.cache_utils.MambaCache`."

    class _config:
        def __init__(self):
            self.intermediate_size = key_value_pairs[0][0].shape[1]
            self.conv_kernel = key_value_pairs[0][0].shape[-1]
            self.state_size = key_value_pairs[0][1].shape[-1]
            self.num_hidden_layers = len(key_value_pairs)
            self.dtype = key_value_pairs[0][0].dtype

    cache = transformers.cache_utils.MambaCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
    )
    for i in range(len(key_value_pairs)):
        assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.conv_states[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
        assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
    return cache


def make_sliding_window_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.SlidingWindowCache:
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)
            self.sliding_window = key_value_pairs[0][0].shape[2]

    cache = transformers.cache_utils.SlidingWindowCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
    )
    for i in range(len(key_value_pairs)):
        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
    return cache