from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
@dataclass
class patched_AttentionMaskConverter:
    """
    Patches
    ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
    """
    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """Patched method."""
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat(
                [
                    torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),
                    mask,
                ],
                dim=-1,
            )

        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window - 1
            context_mask = torch.tril(
                torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal
            )
            # In this case, the current implementation of torch fails (17/12/2024).
            # Try model Phi-3.5-Mini-Instruct.
            mask = mask.masked_fill(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(
            bsz, 1, tgt_len, tgt_len + past_key_values_length
        )
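
# A minimal usage sketch (not part of the original source): the patched method can be
# called directly, or swapped in for the original ``AttentionMaskConverter`` static
# method named in the docstring above before exporting a model.
#
#   mask = patched_AttentionMaskConverter._make_causal_mask(
#       torch.Size((1, 4)), torch.float32, torch.device("cpu"),
#       past_key_values_length=2, sliding_window=3,
#   )
#   assert mask.shape == (1, 1, 4, 6)  # 4 query tokens, 2 cached + 4 new key positions
#
#   from transformers.modeling_attn_mask_utils import AttentionMaskConverter
#   AttentionMaskConverter._make_causal_mask = staticmethod(
#       patched_AttentionMaskConverter._make_causal_mask
#   )
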
class patched_DynamicCache:
    """
    Removes the dependency on :class:`torch.nn.Module`
    from :class:`transformers.cache_utils.DynamicCache`.
    """
    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        # ``num_hidden_layers`` is kept for signature compatibility and is not used.
        self._seen_tokens = 0
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(
                f"Cache only has {len(self)} layers, "
                f"attempted to access layer with index {layer_idx}"
            )

    def __iter__(self):
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        return len(self.key_cache)
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Appends ``key_states`` and ``value_states`` to the cache of layer ``layer_idx``
        and returns the full cached keys and values for that layer.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            # elif (
            #     len(self.key_cache[layer_idx]) == 0
            # ):  # fills previously skipped layers; checking for tensor causes errors
            #     self.key_cache[layer_idx] = key_states
            #     self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], key_states], dim=-2
                )
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], value_states], dim=-2
                )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
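
    # Hedged sketch (not part of the original code) of how ``update`` grows the cache
    # along the sequence axis, shapes chosen arbitrarily for illustration:
    #
    #   cache = patched_DynamicCache()
    #   k = torch.zeros((1, 8, 4, 64))   # (batch, heads, seq_len, head_dim)
    #   cache.update(k, k, layer_idx=0)  # first call stores the tensors, seq_len == 4
    #   cache.update(k, k, layer_idx=0)  # second call concatenates on dim=-2, seq_len == 8
    #   assert cache.get_seq_length() == 8
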
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the number of tokens cached for layer ``layer_idx``."""
        if not self.key_cache:
            return 0
        assert layer_idx < len(
            self.key_cache
        ), f"Unexpected layer_idx={layer_idx}, len(key_cache)={len(self.key_cache)}"
        return self.key_cache[layer_idx].shape[-2]
    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Same as :meth:`get_seq_length`, this cache has no maximum length."""
        return self.get_seq_length(layer_idx)
    def get_max_cache_shape(self) -> Optional[int]:
        """Returns ``None``, this cache is not bounded."""
        return None
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the cache into the legacy format, one ``(key, value)`` pair per layer."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache
    @classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple[Tuple["torch.Tensor"]]] = None,
        num_hidden_layers: Optional[int] = None,
    ) -> "transformers.cache_utils.DynamicCache":  # noqa: F821
        """Builds a cache from the legacy format, one ``(key, value)`` pair per layer."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
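
    # Hedged sketch (not part of the original code): round trip through the legacy
    # ``past_key_values`` format, starting from a cache built as in the sketch above.
    #
    #   legacy = cache.to_legacy_cache()
    #   restored = patched_DynamicCache.from_legacy_cache(legacy)
    #   assert len(restored) == len(cache)
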
    def crop(self, max_length: int):
        """Truncates the cache to ``max_length`` tokens, negative values count from the end."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            if self.key_cache[idx] != []:
                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
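
    # Worked example of the negative-length convention handled above (illustration only):
    # with 8 cached tokens, ``crop(-1)`` becomes ``crop(8 - 1)`` and keeps the first 7.
    #
    #   cache.crop(-1)
    #   assert cache.get_seq_length() == 7
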
    def batch_split(
        self, full_batch_size: int, split_size: int, num_hidden_layers: Optional[int] = None
    ) -> List["transformers.cache_utils.DynamicCache"]:  # noqa: F821
        """Splits the cache along the batch dimension into chunks of ``split_size``."""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = patched_DynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [
                tensor[i : i + split_size] for tensor in self.value_cache
            ]
            out.append(current_split)
        return out
    @classmethod
    def from_batch_splits(
        cls,
        splits: List["transformers.cache_utils.DynamicCache"],  # noqa: F821
        num_hidden_layers: Optional[int] = None,
    ) -> "transformers.cache_utils.DynamicCache":  # noqa: F821
        """Concatenates caches produced by :meth:`batch_split` back along the batch dimension."""
        cache = cls()
        for idx in range(len(splits[0])):
            key_cache = [
                current.key_cache[idx] for current in splits if current.key_cache[idx] != []
            ]
            value_cache = [
                current.value_cache[idx]
                for current in splits
                if current.value_cache[idx] != []
            ]
            if key_cache != []:
                layer_keys = torch.cat(key_cache, dim=0)
                layer_values = torch.cat(value_cache, dim=0)
                cache.update(layer_keys, layer_values, idx)
        return cache
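
    # Hedged sketch (not part of the original code) of the split/merge round trip,
    # assuming ``cache`` holds tensors with a batch dimension of 4:
    #
    #   parts = cache.batch_split(full_batch_size=4, split_size=2)
    #   merged = patched_DynamicCache.from_batch_splits(parts)
    #   assert merged.key_cache[0].shape[0] == 4
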
    def batch_repeat_interleave(self, repeats: int):
        """Repeats every batch entry ``repeats`` times, in place."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(
                repeats, dim=0
            )
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(
                repeats, dim=0
            )
    def batch_select_indices(self, indices: torch.Tensor):
        """Keeps only the batch entries selected by ``indices``, in place."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
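
    # Hedged sketch (not part of the original code): the two methods above provide the
    # expand/reorder steps used by beam-search style generation.
    #
    #   cache.batch_repeat_interleave(2)                  # batch size b -> 2 * b
    #   cache.batch_select_indices(torch.tensor([0, 3]))  # keep only the selected rows
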