onnx_diagnostic.helpers.cache_helper

class onnx_diagnostic.helpers.cache_helper.CacheKeyValue(cache=None, cls_layers=None)[source]

Starting with transformers>=4.54, the cache API deprecates cache.key_cache and cache.value_cache. This class wraps a cache independently of the transformers version and exposes attributes key_cache and value_cache.

capi = CacheKeyValue(cache)
capi.key_cache
capi.value_cache
aslist() List[Tensor][source]

Returns tensors in a list.

make_dynamic_cache()[source]

Does the reverse operation: rebuilds a DynamicCache from the wrapped key and value tensors.

property n_layers: int

Returns the number of layers.

onnx_diagnostic.helpers.cache_helper.finalize_cache(cache: Cache) Cache[source]

Ensures the created cache is consistent. Returns the cache modified in place.

onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False, change_function: Callable[[Tensor], Any] | None = None) Any[source]

Returns the object in a different structure, similar to the one the definition of the dynamic shapes should use.

Parameters:
  • obj – object from a custom class

  • use_dict – closer to the original result, but torch.export.export() only considers the values; the context gives the dictionary keys, but they are not expressed in the dynamic shapes, and these specifications seem to differ between strict and non-strict mode. It also preserves tuples.

  • change_function – modifies the tensors inside the structure itself, for example to replace them by their shapes

Returns:

the serialized object
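The transformation can be pictured with a stdlib-only sketch, where strings stand in for tensors; flatten_unflatten_sketch is a hypothetical simplification, not the real implementation, which handles torch tensors and registered cache classes:

```python
# Stdlib-only sketch of the flatten/unflatten idea (simplified).
def flatten_unflatten_sketch(obj, use_dict=False, change_function=None):
    if isinstance(obj, dict):
        out = {k: flatten_unflatten_sketch(v, use_dict, change_function)
               for k, v in obj.items()}
        # use_dict=True keeps the keys; otherwise only the values remain,
        # mirroring what torch.export.export() considers.
        return out if use_dict else list(out.values())
    if isinstance(obj, (list, tuple)):
        seq = [flatten_unflatten_sketch(v, use_dict, change_function)
               for v in obj]
        # use_dict=True also preserves tuples
        return tuple(seq) if use_dict and isinstance(obj, tuple) else seq
    # leaf: change_function may replace a tensor, e.g. by its shape
    return change_function(obj) if change_function else obj

# strings stand in for tensors: a fake two-layer key/value cache
fake_cache = {"layer0": ("K0", "V0"), "layer1": ("K1", "V1")}
print(flatten_unflatten_sketch(fake_cache))
# [['K0', 'V0'], ['K1', 'V1']]
print(flatten_unflatten_sketch(fake_cache, use_dict=True))
# {'layer0': ('K0', 'V0'), 'layer1': ('K1', 'V1')}
```

With change_function, each leaf is mapped before being placed into the rebuilt structure, which is how tensors can be replaced by their shapes.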

onnx_diagnostic.helpers.cache_helper.is_cache_dynamic_registered(fast: bool = False) bool[source]

Tells if class transformers.cache_utils.DynamicCache can be serialized and deserialized. Only then can torch.export.export() export a model.

Parameters:

fast – if True, does not also check that serialization works

Returns:

result

onnx_diagnostic.helpers.cache_helper.make_dynamic_cache(key_value_pairs: List[Tensor] | List[Tuple[Tensor, Tensor]], cls_layers: str | List[type] | None = None) DynamicCache[source]

Creates an instance of transformers.cache_utils.DynamicCache. This version is valid for transformers >= 4.50.

Parameters:
  • key_value_pairs – list of pairs of (key, value)

  • cls_layers – selects the appropriate class to use for each layer; if specified, sliding_window is ignored; it can be a string if all layers are expected to follow the same class

Returns:

transformers.cache_utils.DynamicCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_dynamic_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ]
)
print(string_type(past_key_values, with_shape=True))

>>>

    DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

With transformers>=4.56, the function fully handles FakeTensor with dynamic dimensions. Before that version, only FakeTensor with static dimensions is supported.

onnx_diagnostic.helpers.cache_helper.make_dynamic_shapes_kv_cache(cache: Cache, shape_of_one: Dict[int, Any]) List[Dict[int, Any]][source]

Returns the dynamic shapes for a key-value cache.

Parameters:
  • cache – a cache

  • shape_of_one – shape of one element

Returns:

dynamic shapes
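The shape of the result can be pictured with a stdlib-only sketch. replicate_shape_spec is a hypothetical helper, and the assumption that the spec of one element is simply replicated once per tensor in the cache is this sketch's, not necessarily the real function's:

```python
# Stdlib-only sketch (hypothetical): replicate one shape spec per tensor,
# making independent copies so editing one spec does not affect the others.
def replicate_shape_spec(n_tensors, shape_of_one):
    return [dict(shape_of_one) for _ in range(n_tensors)]

# e.g. a spec marking dimension 0 as the batch axis and 2 as the sequence
# axis, replicated for the 4 tensors of a 2-layer key/value cache
specs = replicate_shape_spec(4, {0: "batch", 2: "seq_length"})
print(specs)
# [{0: 'batch', 2: 'seq_length'}, {0: 'batch', 2: 'seq_length'},
#  {0: 'batch', 2: 'seq_length'}, {0: 'batch', 2: 'seq_length'}]
```

The independent copies matter: a List[Dict[int, Any]] that aliased one shared dict would let a later edit of one entry silently change all of them.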

onnx_diagnostic.helpers.cache_helper.make_encoder_decoder_cache(self_attention_cache: DynamicCache, cross_attention_cache: DynamicCache) EncoderDecoderCache[source]

Creates an EncoderDecoderCache.

onnx_diagnostic.helpers.cache_helper.make_mamba_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) MambaCache[source]

Creates a MambaCache.

onnx_diagnostic.helpers.cache_helper.make_static_cache(key_value_pairs: List[Tensor] | List[Tuple[Tensor, Tensor]], max_cache_len: int | None = None) StaticCache[source]

Creates an instance of transformers.cache_utils.StaticCache.

Parameters:
  • key_value_pairs – list of pairs of (key, value)

  • max_cache_len – the maximum cache length, or a value inferred from the tensors if not given

Returns:

transformers.cache_utils.StaticCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_static_cache

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7

past_key_values = make_static_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for i in range(n_layers)
    ],
    max_cache_len=10,
)
print(string_type(past_key_values, with_shape=True))

>>>

    StaticCache(key_cache=#2[T1s2x4x10x7,T1s2x4x10x7], value_cache=#2[T1s2x4x10x7,T1s2x4x10x7])