Source code for onnx_diagnostic.helpers.cache_helper

from typing import Any, Callable, List, Optional, Tuple
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils

try:
    from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
    from transformers.cache_utils import MambaCache


class CacheKeyValue:
    """
    Starting with ``transformers>=4.54``, the cache API deprecates
    ``cache.key_cache`` and ``cache.value_cache``. This class wraps a cache
    independently of the transformers version and exposes attributes
    ``key_cache`` and ``value_cache``.

    .. code-block:: python

        capi = CacheKeyValue(cache)
        capi.key_cache
        capi.value_cache
    """

    def __init__(self, cache=None):
        if hasattr(cache, "layers"):
            layers = [
                layer
                for layer in cache.layers
                if layer is not None and layer.keys is not None and layer.values is not None
            ]
            self.key_cache = [layer.keys for layer in layers]
            self.value_cache = [layer.values for layer in layers]
            if None in self.key_cache or None in self.value_cache:
                from .helper import string_type

                raise AssertionError(
                    f"issue with key_cache={string_type(self.key_cache)}, "
                    f"or value_cache={string_type(self.value_cache)}, "
                    f"cache.layers={string_type(cache.layers)}"
                )
        elif cache is not None:
            self.key_cache = cache.key_cache
            self.value_cache = cache.value_cache

    def make_dynamic_cache(self):
        """Does the reverse operation, rebuilding a DynamicCache from the wrapped tensors."""
        return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
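
# Usage sketch (illustrative, not part of the original module): wrap a cache built
# with make_dynamic_cache and read its tensors through CacheKeyValue, which works
# whether or not the installed transformers version still exposes cache.key_cache
# and cache.value_cache directly.
def _example_cache_key_value():
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    )
    capi = CacheKeyValue(cache)
    assert len(capi.key_cache) == len(capi.value_cache) == 1  # one pair per layer
    # The reverse operation rebuilds an equivalent DynamicCache.
    return capi.make_dynamic_cache()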

def flatten_unflatten_for_dynamic_shapes(
    obj: Any,
    use_dict: bool = False,
    change_function: Optional[Callable[[torch.Tensor], Any]] = None,
) -> Any:
    """
    Returns the object in a different structure similar to what
    the definition of the dynamic shapes should use.

    :param obj: object from a custom class
    :param use_dict: closer to the original result but
        :func:`torch.export.export` only considers the values,
        the context gives the dictionary keys but it is not expressed
        in the dynamic shapes, these specifications seem to be different
        for the strict and non strict mode. It also preserves tuples.
    :param change_function: modifies the tensors in the structure itself,
        for example to replace them by their shape
    :return: the serialized object
    """
    if isinstance(obj, torch.Tensor):
        return change_function(obj) if change_function else obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_for_dynamic_shapes(
            value, use_dict=use_dict, change_function=change_function
        )
        subtrees.append(value)
        start = end
    if use_dict:
        if spec.type is dict:
            # This is a dictionary.
            return dict(zip(spec.context, subtrees))
        if spec.type is tuple:
            return tuple(subtrees)
        if spec.type is list:
            return list(subtrees)
        if spec.context:
            # This is a custom class with attributes.
            # It is returned as a list.
            return list(subtrees)
        raise ValueError(
            f"Unable to interpret spec type {spec.type} "
            f"(type is {type(spec.type)}, context is {spec.context})."
        )
    # This is a list.
    return subtrees
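
# Illustrative sketch (not part of the original module): flatten a DynamicCache into
# the nested structure that dynamic shapes follow, replacing every tensor by its
# shape through change_function. This assumes DynamicCache serialization is
# registered with torch's pytree (see is_cache_dynamic_registered below).
def _example_flatten_unflatten():
    cache = make_dynamic_cache(
        [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
    )
    # Leaves become shapes, e.g. nested lists ending in (2, 4, 3, 7) tuples.
    return flatten_unflatten_for_dynamic_shapes(
        cache, change_function=lambda t: tuple(t.shape)
    )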

def is_cache_dynamic_registered(fast: bool = False) -> bool:
    """
    Tells whether class :class:`transformers.cache_utils.DynamicCache` can be
    serialized and deserialized. Only then can :func:`torch.export.export`
    export a model.

    :param fast: if True, only checks the class is registered and does not
        verify that serialization round-trips
    :return: result
    """
    if fast:
        return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            )
            for i in range(2)
        ]
    )
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    return len(cache2.key_cache) == len(cache.value_cache)

if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        cache = transformers.cache_utils.DynamicCache(key_value_pairs)
        if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
            # The cache constructor contains the two following lines
            # (in cache_utils.py) which append empty layers when the cache is
            # initialized. We need to remove them.
            # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
            # self.append_new_layers(self.num_hidden_layers - 1)
            cache.layers[:] = cache.layers[-len(key_value_pairs) :]
        assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
            f"Unexpected number of layers in the cache ({len(cache.layers)}), "
            f"{len(key_value_pairs)} expected."
        )
        return cache

else:
    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers < 4.50``.

        :param key_value_pairs: list of pairs of (key, values)
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache

def make_static_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    max_cache_len: Optional[int] = None,
) -> transformers.cache_utils.StaticCache:
    """
    Creates an instance of :class:`transformers.cache_utils.StaticCache`.

    :param key_value_pairs: list of pairs of (key, values)
    :param max_cache_len: maximum cache length, it must be specified,
        it cannot be inferred from the tensor shapes yet
    :return: :class:`transformers.cache_utils.StaticCache`

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_static_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_static_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ],
            max_cache_len=10,
        )
        print(string_type(past_key_values, with_shape=True))
    """

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)

    assert max_cache_len is not None, (
        f"max_cache_len={max_cache_len} cannot be set up "
        f"automatically yet from shape {key_value_pairs[0][0].shape}"
    )
    torch._check(
        max_cache_len >= key_value_pairs[0][0].shape[2],
        (
            f"max_cache_len={max_cache_len} cannot be smaller than "
            f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
            f"{key_value_pairs[0][0].shape}"
        ),
    )
    cache = transformers.cache_utils.StaticCache(
        config=_config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
        max_cache_len=max_cache_len,
    )
    ca = CacheKeyValue(cache)
    for i in range(len(key_value_pairs)):
        assert (
            key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
        ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
        d = key_value_pairs[i][1].shape[2]
        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return cache

def make_encoder_decoder_cache(
    self_attention_cache: transformers.cache_utils.DynamicCache,
    cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
    """Creates an EncoderDecoderCache."""
    return transformers.cache_utils.EncoderDecoderCache(
        self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
    )
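
# Illustrative sketch (not part of the original module): an EncoderDecoderCache simply
# pairs a self-attention cache with a cross-attention cache, both built here with
# make_dynamic_cache; their sequence lengths may differ.
def _example_encoder_decoder_cache():
    bsize, nheads, dim = 2, 4, 7
    self_cache = make_dynamic_cache(
        [(torch.randn(bsize, nheads, 3, dim), torch.randn(bsize, nheads, 3, dim))]
    )
    cross_cache = make_dynamic_cache(
        [(torch.randn(bsize, nheads, 5, dim), torch.randn(bsize, nheads, 5, dim))]
    )
    return make_encoder_decoder_cache(self_cache, cross_cache)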

def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
    "Creates a ``MambaCache``."
    dtype = key_value_pairs[0][0].dtype

    class _config:
        def __init__(self):
            self.intermediate_size = key_value_pairs[0][0].shape[1]
            self.conv_kernel = key_value_pairs[0][0].shape[-1]
            self.state_size = key_value_pairs[0][1].shape[-1]
            self.num_hidden_layers = len(key_value_pairs)
            self.dtype = dtype

    cache = MambaCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
        dtype=dtype,
    )
    for i in range(len(key_value_pairs)):
        assert cache.conv_states[i].dtype == dtype, (
            f"Type mismatch for cache.conv_states[{i}].dtype="
            f"{cache.conv_states[i].dtype} != {dtype}"
        )
        assert cache.ssm_states[i].dtype == dtype, (
            f"Type mismatch for cache.ssm_states[{i}].dtype="
            f"{cache.ssm_states[i].dtype} != {dtype}"
        )
        assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.conv_states[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
        assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
    return cache
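
# Illustrative sketch (not part of the original module): for a MambaCache the first
# tensor of each pair feeds conv_states with shape (batch, intermediate_size,
# conv_kernel) and the second feeds ssm_states with shape (batch, intermediate_size,
# state_size); the sizes below are arbitrary.
def _example_mamba_cache():
    batch, intermediate_size, conv_kernel, state_size = 2, 16, 4, 8
    pairs = [
        (
            torch.randn(batch, intermediate_size, conv_kernel),
            torch.randn(batch, intermediate_size, state_size),
        )
        for _ in range(2)
    ]
    return make_mamba_cache(pairs)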

def make_sliding_window_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.SlidingWindowCache:
    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)
            self.sliding_window = key_value_pairs[0][0].shape[2]

    cache = transformers.cache_utils.SlidingWindowCache(
        config=_config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
    )
    ca = CacheKeyValue(cache)
    for i in range(len(key_value_pairs)):
        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {ca.key_cache[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {ca.value_cache[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return cache
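
# Illustrative sketch (not part of the original module): key/value tensors are expected
# as (batch, num_heads, sliding_window, head_dim); the window size is taken from
# dimension 2 of the first key.
def _example_sliding_window_cache():
    bsize, nheads, window, dim = 2, 4, 3, 7
    pairs = [
        (torch.randn(bsize, nheads, window, dim), torch.randn(bsize, nheads, window, dim))
        for _ in range(2)
    ]
    return make_sliding_window_cache(pairs)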

def make_hybrid_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    max_cache_len: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    sliding_window: Optional[int] = None,
) -> transformers.cache_utils.HybridCache:
    """
    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
    This version is valid for ``transformers < 4.50``.

    :param key_value_pairs: list of pairs of (key, values)
    :param max_cache_len: maximum cache length, required when ``key_value_pairs``
        is empty, must be left unspecified otherwise
    :param max_batch_size: batch size, required when ``key_value_pairs``
        is empty, must be left unspecified otherwise
    :param sliding_window: sliding window size, inferred from ``key_value_pairs``
        when it is not empty, defaults to ``max_cache_len`` otherwise
    :return: :class:`transformers.cache_utils.HybridCache`

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_hybrid_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ]
        )
        print(string_type(past_key_values, with_shape=True))

    The following excerpt shows how the shapes are defined inside a ``HybridCache``.

    .. code-block:: python

        self.max_cache_len = (
            max_cache_len if max_cache_len is not None
            else config.max_position_embeddings
        )
        # Sliding layers can't be larger than the overall max cache len
        self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
        self.max_batch_size = max_batch_size
        self.head_dim = (
            config.head_dim
            if hasattr(config, "head_dim")
            else config.hidden_size // config.num_attention_heads
        )
        self._dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        # If the attribute does not exist in the config, fallback to a simple StaticCache
        if hasattr(config, "layer_types"):
            self.is_sliding = [
                layer_type != "full_attention" for layer_type in config.layer_types
            ]
        else:
            self.is_sliding = [False] * config.num_hidden_layers

        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        global_cache_shape = (
            self.max_batch_size, self.num_key_value_heads,
            self.max_cache_len, self.head_dim,
        )
        sliding_cache_shape = (
            self.max_batch_size, self.num_key_value_heads,
            self.sliding_window_len, self.head_dim,
        )
        self.sliding_window = min(config.sliding_window, max_cache_len)
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            layer_device = layer_device_map[i] if layer_device_map is not None else device
            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
            new_layer_key_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=layer_device
            )
            new_layer_value_cache = torch.zeros(
                cache_shape, dtype=self._dtype, device=layer_device
            )
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
    """
    layer_types = None
    if key_value_pairs:
        assert (
            not max_batch_size and not max_cache_len
        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
        max_batch_size = key_value_pairs[0][0].shape[0]
        sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
        if len(sets_of_dim) == 1:
            max_cache_len = sets_of_dim.pop()
            sliding_window = max_cache_len
        else:
            assert (
                len(sets_of_dim) == 2
            ), f"Not implemented for more than 2 dimensions {sets_of_dim}"
            max_cache_len = max(sets_of_dim)
            sliding_window = min(sets_of_dim)
            layer_types = [
                "full_attention" if i == max_cache_len else "sliding_attention"
                for i in [kv[0].shape[2] for kv in key_value_pairs]
            ]
    else:
        assert (
            max_batch_size and max_cache_len
        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
    if sliding_window is None:
        sliding_window = max_cache_len

    _max_cache_len = max_cache_len
    _sliding_window = sliding_window

    class _config:
        max_cache_len = _max_cache_len
        batch_size = max_batch_size
        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
        num_hidden_layers = len(key_value_pairs)
        sliding_window = _sliding_window
        num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3

    if layer_types:
        _config.layer_types = layer_types  # type: ignore[attr-defined]

    cache = transformers.cache_utils.HybridCache(
        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
    )
    for i, (key, value) in enumerate(key_value_pairs):
        cache.update(
            key,
            value,
            i,
            cache_kwargs={
                "cache_position": torch.arange(0, key.shape[2], dtype=torch.int64).to(
                    key.device
                )
            },
        )
    if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
        # The cache constructor contains the two following lines
        # (in cache_utils.py) which append empty layers when the cache is
        # initialized. We need to remove them.
        # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
        # self.append_new_layers(self.num_hidden_layers - 1)
        cache.layers[:] = cache.layers[-len(key_value_pairs) :]
    assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
        f"Unexpected number of layers in the cache ({len(cache.layers)}), "
        f"{len(key_value_pairs)} expected."
    )
    return cache