Source code for onnx_diagnostic.helpers.cache_helper

from typing import Any, List, Tuple
import packaging.version as pv
import torch
import transformers
import transformers.cache_utils


def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> Any:
    """
    Returns the object in a different structure, similar to what the
    definition of the dynamic shapes should use.

    :param obj: object from a custom class
    :param use_dict: if True, the result is closer to the original structure;
        :func:`torch.export.export` only considers the values, while the pytree
        context provides the dictionary keys. The keys are not expressed in the
        dynamic shapes, and these specifications seem to differ between the
        strict and non-strict modes.
    :return: the serialized object
    """
    if isinstance(obj, torch.Tensor):
        return obj
    flat, spec = torch.utils._pytree.tree_flatten(obj)
    start = 0
    end = 0
    subtrees = []
    for subspec in spec.children_specs:
        end += subspec.num_leaves
        value = subspec.unflatten(flat[start:end])
        value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict)
        subtrees.append(value)
        start = end
    if use_dict and (spec.type is dict or spec.context):
        # This is a dictionary.
        return dict(zip(spec.context, subtrees))
    # This is a list.
    return subtrees
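
# Usage sketch (added for illustration, not part of the original module):
# applied to a registered cache class, the function returns nested lists of
# tensors mirroring the structure a dynamic-shapes specification should follow.
# Assuming DynamicCache is registered as a pytree node, the result is expected
# to be a two-element list holding the key tensors and the value tensors:
#
#     cache = make_dynamic_cache(
#         [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
#     )
#     structure = flatten_unflatten_for_dynamic_shapes(cache)
#     # structure ~ [[tensor, tensor], [tensor, tensor]]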

def is_cache_dynamic_registered(fast: bool = False) -> bool:
    """
    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
    serialized and deserialized. Only then can :func:`torch.export.export`
    export a model.

    :param fast: if True, only check the class is registered,
        do not check that the serialization works as well
    :return: result
    """
    if fast:
        return (
            transformers.cache_utils.DynamicCache
            in torch.utils._pytree.SUPPORTED_NODES
        )
    bsize, nheads, slen, dim = 2, 4, 3, 7
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            )
            for _ in range(2)
        ]
    )
    # Round-trips a small cache through the pytree registry and checks
    # that the number of layers survives.
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    return len(cache2.key_cache) == len(cache.value_cache)
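
# Illustration (a hedged sketch, not in the original module): a typical guard
# before exporting a model that takes a DynamicCache as input.
#
#     if not is_cache_dynamic_registered(fast=True):
#         raise RuntimeError(
#             "DynamicCache is not registered with torch.utils._pytree; "
#             "torch.export.export cannot serialize the cache inputs."
#         )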

if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers >= 4.50``.

        :param key_value_pairs: list of ``(key, value)`` pairs
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        return transformers.cache_utils.DynamicCache(key_value_pairs)

else:
    def make_dynamic_cache(
        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> transformers.cache_utils.DynamicCache:
        """
        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
        This version is valid for ``transformers < 4.50``.

        :param key_value_pairs: list of ``(key, value)`` pairs
        :return: :class:`transformers.cache_utils.DynamicCache`

        Example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.helpers import string_type
            from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

            n_layers = 2
            bsize, nheads, slen, dim = 2, 4, 3, 7

            past_key_values = make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    )
                    for i in range(n_layers)
                ]
            )
            print(string_type(past_key_values, with_shape=True))
        """
        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
        for i, (key, value) in enumerate(key_value_pairs):
            cache.update(key, value, i)
        return cache

def make_encoder_decoder_cache(
    self_attention_cache: transformers.cache_utils.DynamicCache,
    cross_attention_cache: transformers.cache_utils.DynamicCache,
) -> transformers.cache_utils.EncoderDecoderCache:
    """Creates an :class:`transformers.cache_utils.EncoderDecoderCache`."""
    return transformers.cache_utils.EncoderDecoderCache(
        self_attention_cache=self_attention_cache,
        cross_attention_cache=cross_attention_cache,
    )
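
# Usage sketch (illustrative shapes, not in the original module): an
# EncoderDecoderCache pairs two DynamicCache instances, one for self-attention
# and one for cross-attention, as used by encoder-decoder models.
#
#     self_cache = make_dynamic_cache(
#         [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
#     )
#     cross_cache = make_dynamic_cache(
#         [(torch.randn(2, 4, 6, 7), torch.randn(2, 4, 6, 7)) for _ in range(2)]
#     )
#     past_key_values = make_encoder_decoder_cache(self_cache, cross_cache)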

def make_mamba_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.MambaCache:
    """
    Creates a :class:`transformers.cache_utils.MambaCache`.
    Each pair holds the ``(conv_states, ssm_states)`` tensors of one layer.
    """

    class _config:
        def __init__(self):
            self.intermediate_size = key_value_pairs[0][0].shape[1]
            self.conv_kernel = key_value_pairs[0][0].shape[-1]
            self.state_size = key_value_pairs[0][1].shape[-1]
            self.num_hidden_layers = len(key_value_pairs)
            self.dtype = key_value_pairs[0][0].dtype

    cache = transformers.cache_utils.MambaCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
    )
    for i in range(len(key_value_pairs)):
        assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.conv_states[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.conv_states[i][:, :, :] = key_value_pairs[i][0]
        assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.ssm_states[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
    return cache
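
# Usage sketch (illustrative shapes, inferred from the code above): per layer,
# conv_states has shape (batch, intermediate_size, conv_kernel) and ssm_states
# has shape (batch, intermediate_size, state_size).
#
#     bsize, intermediate, kernel, state = 2, 16, 4, 8
#     cache = make_mamba_cache(
#         [
#             (
#                 torch.randn(bsize, intermediate, kernel),
#                 torch.randn(bsize, intermediate, state),
#             )
#             for _ in range(2)
#         ]
#     )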

def make_sliding_window_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.SlidingWindowCache:
    "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

    class _config:
        def __init__(self):
            self.head_dim = key_value_pairs[0][0].shape[-1]
            self.num_attention_heads = key_value_pairs[0][0].shape[1]
            self.num_hidden_layers = len(key_value_pairs)
            self.sliding_window = key_value_pairs[0][0].shape[2]

    cache = transformers.cache_utils.SlidingWindowCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
        device=key_value_pairs[0][0].device,
        dtype=key_value_pairs[0][0].dtype,
    )
    for i in range(len(key_value_pairs)):
        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
            f"got {key_value_pairs[i][0].shape}"
        )
        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
            f"got {key_value_pairs[i][1].shape}"
        )
        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
    return cache
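
# Usage sketch (illustrative shapes, inferred from the code above): pairs of
# (key, value) tensors of shape (batch, num_heads, sliding_window, head_dim);
# max_cache_len is taken from the third dimension.
#
#     bsize, nheads, window, dim = 2, 4, 5, 7
#     cache = make_sliding_window_cache(
#         [
#             (
#                 torch.randn(bsize, nheads, window, dim),
#                 torch.randn(bsize, nheads, window, dim),
#             )
#             for _ in range(2)
#         ]
#     )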