onnx_diagnostic.helpers.cache_helper

onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) → Any

Returns the object restructured into nested lists (or dictionaries), the structure that the definition of the dynamic shapes should follow.

Parameters:
  • obj – object from a custom class

  • use_dict – if True, the result is closer to the original object. However, torch.export.export() only considers the values: the context holds the dictionary keys, but they are not expressed in the dynamic shapes. These specifications seem to differ between strict and non-strict mode.

Returns:

the serialized object
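The idea can be illustrated without torch. The following is a minimal pure-Python sketch, not the library's implementation: ToyCache, FLATTENERS and flatten_for_shapes are hypothetical names, and the registry only stands in for the pytree registration that torch.export.export() relies on. A registered custom object is replaced by the list of its children, or by a dictionary keyed by its context when use_dict is True.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for a cache class such as DynamicCache."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


# Hypothetical registry: splits an instance into (context keys, children).
FLATTENERS: Dict[type, Callable[[Any], Tuple[List[str], List[Any]]]] = {
    ToyCache: lambda c: (["key_cache", "value_cache"], [c.key_cache, c.value_cache]),
}


def flatten_for_shapes(obj: Any, use_dict: bool = False) -> Any:
    """Recursively replaces registered objects by lists (or dicts if use_dict)."""
    flattener = FLATTENERS.get(type(obj))
    if flattener is not None:
        keys, children = flattener(obj)
        children = [flatten_for_shapes(c, use_dict) for c in children]
        return dict(zip(keys, children)) if use_dict else children
    if isinstance(obj, (list, tuple)):
        return type(obj)(flatten_for_shapes(o, use_dict) for o in obj)
    if isinstance(obj, dict):
        return {k: flatten_for_shapes(v, use_dict) for k, v in obj.items()}
    return obj  # leaves (tensors, in practice) are returned unchanged


# Strings stand in for the per-layer tensors.
cache = ToyCache(key_cache=["k0", "k1"], value_cache=["v0", "v1"])
print(flatten_for_shapes(cache))  # [['k0', 'k1'], ['v0', 'v1']]
print(flatten_for_shapes(cache, use_dict=True))
# {'key_cache': ['k0', 'k1'], 'value_cache': ['v0', 'v1']}
```

Under these assumptions, a dynamic shapes specification for such an input would be expected to mirror the first result: one list of per-layer shape descriptions for the keys and another for the values.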

onnx_diagnostic.helpers.cache_helper.is_cache_dynamic_registered(fast: bool = False) → bool

Tells whether class transformers.cache_utils.DynamicCache can be serialized and deserialized. Only then can torch.export.export() export a model.

Parameters:

fast – if True, do not also check that serialization works

Returns:

True if the class is registered and, unless fast is True, serialization works
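What "registered" means, and what the fast flag skips, can be sketched in pure Python. This is an illustrative analogue, not the library's code: ToyCache, REGISTRY, register and is_registered are hypothetical names, and the registry plays the role of the pytree registration torch.export.export() depends on. The fast check only looks up the registration; the full check also verifies that flattening then unflattening gives back an equal object.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for DynamicCache."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


# Hypothetical registry: class -> (flatten, unflatten).
REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register(cls: type, flatten: Callable, unflatten: Callable) -> None:
    REGISTRY[cls] = (flatten, unflatten)


def is_registered(cls: type, sample: Any, fast: bool = False) -> bool:
    """True if cls is registered and, unless fast, sample survives a round trip."""
    if cls not in REGISTRY:
        return False
    if fast:
        return True  # registration only, no serialization check
    flatten, unflatten = REGISTRY[cls]
    children, context = flatten(sample)
    return unflatten(children, context) == sample


cache = ToyCache(["k0"], ["v0"])
print(is_registered(ToyCache, cache))  # False: nothing registered yet

register(
    ToyCache,
    lambda c: ([c.key_cache, c.value_cache], ["key_cache", "value_cache"]),
    lambda children, context: ToyCache(*children),
)
print(is_registered(ToyCache, cache))  # True

# A registration whose unflatten loses the values passes the fast check only.
register(
    ToyCache,
    lambda c: ([c.key_cache, c.value_cache], ["key_cache", "value_cache"]),
    lambda children, context: ToyCache(children[0], []),
)
print(is_registered(ToyCache, cache, fast=True))  # True
print(is_registered(ToyCache, cache))  # False
```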

onnx_diagnostic.helpers.cache_helper.make_dynamic_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → DynamicCache

Creates an instance of transformers.cache_utils.DynamicCache. This version is valid for transformers >= 4.50.

Parameters:

key_value_pairs – list of (key, value) pairs, one per layer

Returns:

transformers.cache_utils.DynamicCache

Example:

<<<

import torch
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

n_layers = 2
# batch size, number of heads, sequence length, head dimension
bsize, nheads, slen, dim = 2, 4, 3, 7

# one (key, value) pair per layer
past_key_values = make_dynamic_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for _ in range(n_layers)
    ]
)
# with_shape=True includes the tensor shapes in the printed summary
print(string_type(past_key_values, with_shape=True))

>>>

    DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
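
The output above shows the list of (key, value) pairs stored as two parallel per-layer lists, key_cache and value_cache. A minimal pure-Python sketch of that rearrangement, using shape tuples in place of tensors; split_pairs is a hypothetical helper, not part of onnx_diagnostic or transformers.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def split_pairs(key_value_pairs: List[Tuple[Shape, Shape]]) -> Tuple[List[Shape], List[Shape]]:
    """Rearranges per-layer (key, value) pairs into (key_cache, value_cache)."""
    key_cache = [key for key, _ in key_value_pairs]
    value_cache = [value for _, value in key_value_pairs]
    return key_cache, value_cache


n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
# One (key, value) pair per layer, each shaped like the tensors in the example.
pairs = [((bsize, nheads, slen, dim), (bsize, nheads, slen, dim)) for _ in range(n_layers)]
key_cache, value_cache = split_pairs(pairs)
print(len(key_cache), key_cache[0])  # 2 (2, 4, 3, 7)
```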

onnx_diagnostic.helpers.cache_helper.make_encoder_decoder_cache(self_attention_cache: DynamicCache, cross_attention_cache: DynamicCache) → EncoderDecoderCache

Creates a transformers.cache_utils.EncoderDecoderCache from a self-attention cache and a cross-attention cache.
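In an encoder-decoder model, the self-attention cache covers the tokens the decoder has processed so far, while the cross-attention cache holds keys and values computed from the encoder output, so the two usually have different sequence lengths. The sketch below only tracks shapes to make that difference visible; ToyEncoderDecoderCache and sequence_lengths are hypothetical names, not the transformers class or its methods.

```python
from dataclasses import dataclass
from typing import List, Tuple

Shape = Tuple[int, int, int, int]  # (batch, heads, sequence, head dim)


@dataclass
class ToyEncoderDecoderCache:
    """Hypothetical shape-only analogue of EncoderDecoderCache."""

    self_attention: List[Tuple[Shape, Shape]]  # grows with decoded tokens
    cross_attention: List[Tuple[Shape, Shape]]  # fixed by the encoder length

    def sequence_lengths(self) -> Tuple[int, int]:
        """Returns (decoder length, encoder length) read from the first layer's key."""
        return self.self_attention[0][0][2], self.cross_attention[0][0][2]


batch, heads, dim = 1, 4, 8
dec_len, enc_len = 3, 11
n_layers = 2
cache = ToyEncoderDecoderCache(
    self_attention=[((batch, heads, dec_len, dim),) * 2 for _ in range(n_layers)],
    cross_attention=[((batch, heads, enc_len, dim),) * 2 for _ in range(n_layers)],
)
print(cache.sequence_lengths())  # (3, 11)
```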

onnx_diagnostic.helpers.cache_helper.make_mamba_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → MambaCache

Creates a transformers.cache_utils.MambaCache.

onnx_diagnostic.helpers.cache_helper.make_sliding_window_cache(key_value_pairs: List[Tuple[Tensor, Tensor]]) → SlidingWindowCache

Creates a transformers.cache_utils.SlidingWindowCache.
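A sliding-window cache bounds memory by keeping only the most recent positions, up to the window size, instead of the whole history. The following pure-Python sketch illustrates that eviction policy with a bounded deque; ToySlidingWindowCache, window and update are hypothetical names, and the actual transformers.cache_utils.SlidingWindowCache works on fixed-size tensors rather than Python containers.

```python
from collections import deque
from typing import Deque, List


class ToySlidingWindowCache:
    """Hypothetical cache that keeps at most `window` positions per layer."""

    def __init__(self, n_layers: int, window: int) -> None:
        # deque(maxlen=...) silently drops the oldest item once full
        self.layers: List[Deque[str]] = [deque(maxlen=window) for _ in range(n_layers)]

    def update(self, layer: int, token: str) -> List[str]:
        """Appends one position to a layer and returns what that layer still holds."""
        self.layers[layer].append(token)
        return list(self.layers[layer])


cache = ToySlidingWindowCache(n_layers=1, window=3)
for t in ["t0", "t1", "t2", "t3", "t4"]:
    visible = cache.update(0, t)
print(visible)  # ['t2', 't3', 't4']
```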