onnx_diagnostic.helpers.cache_helper
- class onnx_diagnostic.helpers.cache_helper.CacheKeyValue(cache=None)[source]
Starting with transformers>=4.54, the cache API deprecates cache.key_cache and cache.value_cache. This class wraps a cache independently of the transformers version and exposes the attributes key_cache and value_cache.
capi = CacheKeyValue(cache)
capi.key_cache
capi.value_cache
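A minimal runnable sketch, assuming the cache comes from make_dynamic_cache below (the shapes are illustrative):
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
capi = CacheKeyValue(cache)
# key_cache and value_cache are lists with one tensor per layer.
print(len(capi.key_cache), len(capi.value_cache))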
- onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False, change_function: Callable[[Tensor], Any] | None = None) Any [source]
Returns the object in a different structure, mirroring the one the definition of the dynamic shapes should follow.
- Parameters:
obj – object from a custom class
use_dict – closer to the original result, but torch.export.export() only considers the values; the context gives the dictionary keys, but these are not expressed in the dynamic shapes. The specifications seem to differ between strict and non-strict mode. It also preserves tuples.
change_function – modifies the tensors in the structure itself, e.g. replaces them by their shape
- Returns:
the serialized object
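For instance, passing a change_function that replaces every tensor by its shape makes the structure visible. A minimal sketch (the printed structure depends on how the cache class is registered):
import torch
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)

cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
# Replace every tensor by its shape to see the structure
# the dynamic shapes definition should follow.
print(
    flatten_unflatten_for_dynamic_shapes(
        cache, change_function=lambda t: tuple(t.shape)
    )
)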
- onnx_diagnostic.helpers.cache_helper.is_cache_dynamic_registered(fast: bool = False) bool [source]
Tells whether class transformers.cache_utils.DynamicCache can be serialized and deserialized. Only then can torch.export.export() export a model.
- Parameters:
fast – if True, do not also check that serialization works
- Returns:
True if transformers.cache_utils.DynamicCache is registered for serialization
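Usage, as a minimal sketch (the result depends on the installed transformers version):
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

if is_cache_dynamic_registered():
    # torch.export.export can handle models taking a DynamicCache as input.
    ...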
- onnx_diagnostic.helpers.cache_helper.make_dynamic_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) DynamicCache [source]
Creates an instance of transformers.cache_utils.DynamicCache. This version is valid for transformers >= 4.50.
- Parameters:
key_value_pairs – list of pairs of (key, values)
- Returns:
transformers.cache_utils.DynamicCache
Example:
<<<
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ]
)
print(string_type(past_key_values, with_shape=True))
>>>
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
- onnx_diagnostic.helpers.cache_helper.make_encoder_decoder_cache(self_attention_cache: DynamicCache, cross_attention_cache: DynamicCache) EncoderDecoderCache [source]
Creates an EncoderDecoderCache.
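A minimal sketch (the shapes are illustrative; the cross-attention cache typically has a different sequence length than the self-attention cache):
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

self_attention_cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
cross_attention_cache = make_dynamic_cache(
    [(torch.randn(2, 4, 5, 7), torch.randn(2, 4, 5, 7)) for _ in range(2)]
)
cache = make_encoder_decoder_cache(self_attention_cache, cross_attention_cache)
print(string_type(cache, with_shape=True))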
- onnx_diagnostic.helpers.cache_helper.make_hybrid_cache(key_value_pairs: List[Tuple[Tensor, Tensor]], max_cache_len: int | None = None, max_batch_size: int | None = None, sliding_window: int | None = None) HybridCache [source]
Creates an instance of transformers.cache_utils.HybridCache. This version is valid for transformers < 4.50.
- Parameters:
key_value_pairs – list of pairs of (key, values)
- Returns:
transformers.cache_utils.HybridCache
Example:
<<<
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_hybrid_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ]
)
print(string_type(past_key_values, with_shape=True))
>>>
HybridCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
This excerpt from transformers shows how the shapes are defined inside a HybridCache:
self.max_cache_len = (
    max_cache_len if max_cache_len is not None else config.max_position_embeddings
)
# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size
self.head_dim = (
    config.head_dim
    if hasattr(config, "head_dim")
    else config.hidden_size // config.num_attention_heads
)
self._dtype = dtype
self.num_key_value_heads = (
    config.num_attention_heads
    if getattr(config, "num_key_value_heads", None) is None
    else config.num_key_value_heads
)
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
    self.is_sliding = [
        layer_type != "full_attention" for layer_type in config.layer_types
    ]
else:
    self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
global_cache_shape = (
    self.max_batch_size,
    self.num_key_value_heads,
    self.max_cache_len,
    self.head_dim,
)
sliding_cache_shape = (
    self.max_batch_size,
    self.num_key_value_heads,
    self.sliding_window_len,
    self.head_dim,
)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
    layer_device = layer_device_map[i] if layer_device_map is not None else device
    cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
    new_layer_key_cache = torch.zeros(
        cache_shape, dtype=self._dtype, device=layer_device
    )
    new_layer_value_cache = torch.zeros(
        cache_shape, dtype=self._dtype, device=layer_device
    )
    torch._dynamo.mark_static_address(new_layer_key_cache)
    torch._dynamo.mark_static_address(new_layer_value_cache)
    self.key_cache.append(new_layer_key_cache)
    self.value_cache.append(new_layer_value_cache)
- onnx_diagnostic.helpers.cache_helper.make_mamba_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) MambaCache [source]
Creates a MambaCache.
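A minimal sketch. The shapes below are illustrative assumptions: a MambaCache stores convolution and SSM states rather than attention keys and values, so the pairs are interpreted accordingly:
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_mamba_cache

# Hypothetical shapes: (batch, intermediate_size, conv_kernel_size) for the
# first tensor and (batch, intermediate_size, ssm_state_size) for the second.
cache = make_mamba_cache(
    [(torch.randn(2, 16, 4), torch.randn(2, 16, 8)) for _ in range(2)]
)
print(string_type(cache, with_shape=True))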
- onnx_diagnostic.helpers.cache_helper.make_sliding_window_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) SlidingWindowCache [source]
Creates a transformers.cache_utils.SlidingWindowCache.
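A minimal sketch (the shapes are illustrative; in a SlidingWindowCache the sequence dimension corresponds to the sliding window length):
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

cache = make_sliding_window_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
print(string_type(cache, with_shape=True))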
- onnx_diagnostic.helpers.cache_helper.make_static_cache(key_value_pairs: List[Tuple[Tensor, Tensor]], max_cache_len: int | None = None) StaticCache [source]
Creates an instance of transformers.cache_utils.StaticCache.
- Parameters:
key_value_pairs – list of pairs of (key, values)
max_cache_len – maximum cache length; if None, it is inferred from the tensors
- Returns:
transformers.cache_utils.StaticCache
Example:
<<<
import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_static_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_static_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ],
    max_cache_len=10,
)
print(string_type(past_key_values, with_shape=True))
>>>
StaticCache(key_cache=#2[T1s2x4x10x7,T1s2x4x10x7], value_cache=#2[T1s2x4x10x7,T1s2x4x10x7])