onnx_diagnostic.helpers.cache_helper

class onnx_diagnostic.helpers.cache_helper.CacheKeyValue(cache=None)[source]

Starting with transformers>=4.54, the cache API deprecates cache.key_cache and cache.value_cache. This class wraps a cache independently of the transformers version and exposes the attributes key_cache and value_cache.

capi = CacheKeyValue(cache)
capi.key_cache
capi.value_cache
make_dynamic_cache()[source]

Does the reverse operation: rebuilds a transformers.cache_utils.DynamicCache from the wrapped key_cache and value_cache.
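
For instance, a minimal sketch of both directions, assuming a cache built with make_dynamic_cache() (shapes are illustrative):

import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))]
)
capi = CacheKeyValue(cache)
# Uniform access regardless of the transformers version.
keys, values = capi.key_cache, capi.value_cache
# Reverse operation: rebuild a DynamicCache from the wrapper.
rebuilt = capi.make_dynamic_cache()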

onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False, change_function: Callable[[Tensor], Any] | None = None) → Any[source]

Returns the object restructured to mirror the structure the corresponding dynamic shapes should follow.

Parameters:
  • obj – object from a custom class

  • use_dict – closer to the original result, but torch.export.export() only considers the values; the context gives the dictionary keys, yet they are not expressed in the dynamic shapes, and these specifications seem to differ between strict and non-strict modes. This option also preserves tuples.

  • change_function – modifies the tensors inside the structure itself, e.g. replaces them by their shapes

Returns:

the serialized object
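
A minimal sketch of a typical call, assuming a DynamicCache input and using change_function to replace every tensor by its shape:

import torch
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)

cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
# Restructure the cache the way dynamic shapes are expressed,
# replacing each tensor by its shape.
restructured = flatten_unflatten_for_dynamic_shapes(
    cache, change_function=lambda t: tuple(t.shape)
)
print(restructured)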

onnx_diagnostic.helpers.cache_helper.is_cache_dynamic_registered(fast: bool = False) → bool[source]

Tells whether transformers.cache_utils.DynamicCache can be serialized and deserialized. Only then can torch.export.export() export a model.

Parameters:

fast – if True, skips the additional check that serialization actually works

Returns:

result
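
A minimal sketch of the intended check before calling torch.export.export():

from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered

# Verify DynamicCache serialization is registered before exporting;
# fast=True skips the extra serialization round-trip check.
if not is_cache_dynamic_registered(fast=True):
    raise RuntimeError("DynamicCache is not registered for serialization.")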

onnx_diagnostic.helpers.cache_helper.make_dynamic_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → DynamicCache[source]

Creates an instance of transformers.cache_utils.DynamicCache. This version is valid for transformers >= 4.50.

Parameters:

key_value_pairs – list of pairs of (key, value) tensors

Returns:

transformers.cache_utils.DynamicCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ]
)
print(string_type(past_key_values, with_shape=True))

>>>

    DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

onnx_diagnostic.helpers.cache_helper.make_encoder_decoder_cache(self_attention_cache: DynamicCache, cross_attention_cache: DynamicCache) → EncoderDecoderCache[source]

Creates an EncoderDecoderCache.
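
A minimal sketch, assuming both caches are built with make_dynamic_cache() (shapes are illustrative):

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

self_attention_cache = make_dynamic_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))]
)
cross_attention_cache = make_dynamic_cache(
    [(torch.randn(2, 4, 5, 7), torch.randn(2, 4, 5, 7))]
)
cache = make_encoder_decoder_cache(self_attention_cache, cross_attention_cache)
print(string_type(cache, with_shape=True))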

onnx_diagnostic.helpers.cache_helper.make_hybrid_cache(key_value_pairs: List[Tuple[Tensor, Tensor]], max_cache_len: int | None = None, max_batch_size: int | None = None, sliding_window: int | None = None) → HybridCache[source]

Creates an instance of transformers.cache_utils.HybridCache. This version is valid for transformers < 4.50.

Parameters:

  • key_value_pairs – list of pairs of (key, value) tensors

  • max_cache_len – maximum cache length

  • max_batch_size – maximum batch size

  • sliding_window – sliding window size

Returns:

transformers.cache_utils.HybridCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_hybrid_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ]
)
print(string_type(past_key_values, with_shape=True))

>>>

    HybridCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

The following excerpt, taken from the constructor of transformers.cache_utils.HybridCache, shows how the shapes are computed inside one HybridCache.

self.max_cache_len = (
    max_cache_len if max_cache_len is not None else config.max_position_embeddings)

# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size

self.head_dim = (
    config.head_dim if hasattr(config, "head_dim")
    else config.hidden_size // config.num_attention_heads
)

self._dtype = dtype
self.num_key_value_heads = (
    config.num_attention_heads
    if getattr(config, "num_key_value_heads", None) is None
    else config.num_key_value_heads
)

# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
    self.is_sliding = [
        layer_type != "full_attention" for layer_type in config.layer_types]
else:
    self.is_sliding = [False] * config.num_hidden_layers

self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads,
                      self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads,
                       self.sliding_window_len, self.head_dim)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
    layer_device = layer_device_map[i] if layer_device_map is not None else device
    cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
    new_layer_key_cache = torch.zeros(
        cache_shape, dtype=self._dtype, device=layer_device)
    new_layer_value_cache = torch.zeros(
        cache_shape, dtype=self._dtype, device=layer_device)
    torch._dynamo.mark_static_address(new_layer_key_cache)
    torch._dynamo.mark_static_address(new_layer_value_cache)
    self.key_cache.append(new_layer_key_cache)
    self.value_cache.append(new_layer_value_cache)

onnx_diagnostic.helpers.cache_helper.make_mamba_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → MambaCache[source]

Creates a MambaCache.
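
A hedged sketch; MambaCache stores convolution and SSM states rather than attention keys and values, so the pair shapes below are illustrative assumptions, not a documented contract:

import torch
from onnx_diagnostic.helpers.cache_helper import make_mamba_cache

# Illustrative shapes only: one (state, state)-like pair per layer.
cache = make_mamba_cache(
    [(torch.randn(2, 16, 4), torch.randn(2, 16, 8)) for _ in range(2)]
)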

onnx_diagnostic.helpers.cache_helper.make_sliding_window_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → SlidingWindowCache[source]

Creates a transformers.cache_utils.SlidingWindowCache.
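
A minimal sketch following the same (key, value) pattern as the other helpers in this module (shapes are illustrative):

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

past_key_values = make_sliding_window_cache(
    [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for _ in range(2)]
)
print(string_type(past_key_values, with_shape=True))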

onnx_diagnostic.helpers.cache_helper.make_static_cache(key_value_pairs: List[Tuple[Tensor, Tensor]], max_cache_len: int | None = None) → StaticCache[source]

Creates an instance of transformers.cache_utils.StaticCache.

Parameters:
  • key_value_pairs – list of pairs of (key, value) tensors

  • max_cache_len – maximum cache length, or a value inferred from the tensors if not given

Returns:

transformers.cache_utils.StaticCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_static_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_static_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ],
    max_cache_len=10,
)
print(string_type(past_key_values, with_shape=True))

>>>

    StaticCache(key_cache=#2[T1s2x4x10x7,T1s2x4x10x7], value_cache=#2[T1s2x4x10x7,T1s2x4x10x7])