Note
Go to the end to download the full example code.
Do not use Module as inputs!
This continues example Export a model using a custom type as input.
Custom classes work fine
DynamicCache is a replica of transformers.cache_utils.DynamicCache, but it does not inherit from transformers.cache_utils.Cache.
from typing import Any, Dict, List, Optional, Tuple
import torch
class DynamicCache:
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
self._seen_tokens = (
0 # Used in `generate` to keep tally of how many tokens the cache has seen
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
len(self.key_cache[layer_idx]) == 0
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat(
[self.key_cache[layer_idx], key_states], dim=-2
)
self.value_cache[layer_idx] = torch.cat(
[self.value_cache[layer_idx], value_states], dim=-2
)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache)
<= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length
A model uses the class we introduced.
class ModelTakingDynamicCacheAsInput(torch.nn.Module):
def forward(self, x, dc):
kc = torch.cat(dc.key_cache, axis=1)
vc = torch.cat(dc.value_cache, axis=1)
length = dc.get_seq_length() if dc is not None else 0
ones = torch.zeros(
(
dc.key_cache[0].shape[0],
dc.key_cache[0].shape[1],
length,
dc.key_cache[0].shape[-1],
)
)
w = vc + kc + ones
y = w.sum(axis=2, keepdim=True)
return x + y
Let’s check the model runs.
x = torch.randn(3, 8, 7, 1)
cache = DynamicCache(1)
cache.update(torch.ones((3, 8, 5, 6)), (torch.ones((3, 8, 5, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([3, 8, 7, 6])
Let’s check it works with other shapes.
x = torch.randn(4, 8, 7, 1)
cache = DynamicCache(1)
cache.update(torch.ones((4, 8, 11, 6)), (torch.ones((4, 8, 11, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([4, 8, 7, 6])
Let’s export after registering the serialization functions, as shown in Export a model using a custom type as input.
def flatten_dynamic_cache(
dynamic_cache: DynamicCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
flat = [
(k, getattr(dynamic_cache, k))
for k in ["key_cache", "value_cache"]
if hasattr(dynamic_cache, k)
]
return [f[1] for f in flat], [f[0] for f in flat]
def unflatten_dynamic_cache(
values: List[Any],
context: torch.utils._pytree.Context,
output_type=None,
) -> DynamicCache:
cache = DynamicCache()
values = dict(zip(context, values))
for k, v in values.items():
setattr(cache, k, v)
return cache
def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[
List[Tuple[torch.utils._pytree.KeyEntry, Any]],
torch.utils._pytree.Context,
]:
values, context = flatten_dynamic_cache(d)
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
torch.utils._pytree.register_pytree_node(
DynamicCache,
flatten_dynamic_cache,
unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
)
torch.fx._pytree.register_pytree_flatten_spec(
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
)
Let’s export with dynamic shapes.
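The export call is not reproduced on this page; a sketch consistent with the output below, and with the call used in the second part of this example, would be:
batch = torch.export.Dim("batch", min=1, max=1024)
clength = torch.export.Dim("clength", min=1, max=1024)

ep = torch.export.export(
    model,
    (x, cache),
    dynamic_shapes=({0: batch}, [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]),
)
print(ep)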
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, 8, 7, 1]", dc_key_cache_0: "f32[s0, 8, s2, 6]", dc_value_cache_0: "f32[s0, 8, s2, 6]"):
#
sym_size_int_2: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_3: "Sym(s2)" = torch.ops.aten.sym_size.int(dc_key_cache_0, 2)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:80 in forward, code: kc = torch.cat(dc.key_cache, axis=1)
cat: "f32[s0, 8, s2, 6]" = torch.ops.aten.cat.default([dc_key_cache_0], 1); dc_key_cache_0 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:81 in forward, code: vc = torch.cat(dc.value_cache, axis=1)
cat_1: "f32[s0, 8, s2, 6]" = torch.ops.aten.cat.default([dc_value_cache_0], 1); dc_value_cache_0 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:83 in forward, code: ones = torch.zeros(
zeros: "f32[s0, 8, s2, 6]" = torch.ops.aten.zeros.default([sym_size_int_2, 8, sym_size_int_3, 6], device = device(type='cpu'), pin_memory = False); sym_size_int_2 = sym_size_int_3 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:91 in forward, code: w = vc + kc + ones
add: "f32[s0, 8, s2, 6]" = torch.ops.aten.add.Tensor(cat_1, cat); cat_1 = cat = None
add_1: "f32[s0, 8, s2, 6]" = torch.ops.aten.add.Tensor(add, zeros); add = zeros = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:92 in forward, code: y = w.sum(axis=2, keepdim=True)
sum_1: "f32[s0, 8, 1, 6]" = torch.ops.aten.sum.dim_IntList(add_1, [2], True); add_1 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:93 in forward, code: return x + y
add_2: "f32[s0, 8, 7, 6]" = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='dc_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='dc_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {s0: VR[1, 1024], s2: VR[1, 1024]}
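The range constraints allow any batch size and cache length in [1, 1024], so the exported program should accept yet another shape. This extra check (not part of the original script) assumes ep.module() re-applies the registered pytree flattening to the DynamicCache input:
x2 = torch.randn(2, 8, 7, 1)
cache2 = DynamicCache(1)
cache2.update(torch.ones((2, 8, 3, 6)), torch.ones((2, 8, 3, 6)) * 2, 0)
print(ep.module()(x2, cache2).shape)  # expected: torch.Size([2, 8, 7, 6])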
We remove the changes made to PyTorch.
torch.utils._pytree.SUPPORTED_NODES.pop(DynamicCache)
torch.fx._pytree.SUPPORTED_NODES.pop(DynamicCache)
torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH.pop(DynamicCache)
Everything looks fine, but now…
DynamicCache(torch.nn.Module)
That’s the only change we make. Everything else is the same.
class DynamicCache(torch.nn.Module):
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
self._seen_tokens = (
0 # Used in `generate` to keep tally of how many tokens the cache has seen
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self.key_cache.append(key_states)
self.value_cache.append(value_states)
elif (
len(self.key_cache[layer_idx]) == 0
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat(
[self.key_cache[layer_idx], key_states], dim=-2
)
self.value_cache[layer_idx] = torch.cat(
[self.value_cache[layer_idx], value_states], dim=-2
)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
is_empty_layer = (
len(self.key_cache) == 0 # no cache in any layer
or len(self.key_cache)
<= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
return layer_seq_length
A model uses the class we introduced.
class ModelTakingDynamicCacheAsInput(torch.nn.Module):
def forward(self, x, dc):
kc = torch.cat(dc.key_cache, axis=1)
vc = torch.cat(dc.value_cache, axis=1)
length = dc.get_seq_length() if dc is not None else 0
ones = torch.zeros(
(
dc.key_cache[0].shape[0],
dc.key_cache[0].shape[1],
length,
dc.key_cache[0].shape[-1],
)
)
w = vc + kc + ones
y = w.sum(axis=2, keepdim=True)
return x + y
Let’s check the model runs.
x = torch.randn(3, 8, 7, 1)
cache = DynamicCache(1)
cache.update(torch.ones((3, 8, 5, 6)), (torch.ones((3, 8, 5, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([3, 8, 7, 6])
Let’s check it works with other shapes.
x = torch.randn(4, 8, 7, 1)
cache = DynamicCache(1)
cache.update(torch.ones((4, 8, 11, 6)), (torch.ones((4, 8, 11, 6)) * 2), 0)
model = ModelTakingDynamicCacheAsInput()
expected = model(x, cache)
print(expected.shape)
torch.Size([4, 8, 7, 6])
Let’s export after registering the serialization functions, as shown in Export a model using a custom type as input.
def flatten_dynamic_cache(
dynamic_cache: DynamicCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
flat = [
(k, getattr(dynamic_cache, k))
for k in ["key_cache", "value_cache"]
if hasattr(dynamic_cache, k)
]
return [f[1] for f in flat], [f[0] for f in flat]
def unflatten_dynamic_cache(
values: List[Any],
context: torch.utils._pytree.Context,
output_type=None,
) -> DynamicCache:
cache = DynamicCache()
values = dict(zip(context, values))
for k, v in values.items():
setattr(cache, k, v)
return cache
def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[
List[Tuple[torch.utils._pytree.KeyEntry, Any]],
torch.utils._pytree.Context,
]:
values, context = flatten_dynamic_cache(d)
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
torch.utils._pytree.register_pytree_node(
DynamicCache,
flatten_dynamic_cache,
unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
)
torch.fx._pytree.register_pytree_flatten_spec(
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
)
Let’s export with dynamic shapes.
batch = torch.export.Dim("batch", min=1, max=1024)
clength = torch.export.Dim("clength", min=1, max=1024)
try:
ep = torch.export.export(
model,
(x, cache),
dynamic_shapes=({0: batch}, [[{0: batch, 2: clength}], [{0: batch, 2: clength}]]),
)
print(ep)
except Exception as e:
print(f"It did not work: {e}")
It did not work: Constraints violated (batch)! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of batch = L['x'].size()[0] in the specified range batch <= 1024 are valid because batch was inferred to be a constant (4).
Suggested fixes:
batch = 4
There exists a little trick to bypass that issue: change the base class so that the cache no longer inherits from torch.nn.Module before exporting.
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, 8, 7, 1]", dc_key_cache_0: "f32[s0, 8, s2, 6]", dc_value_cache_0: "f32[s0, 8, s2, 6]"):
#
sym_size_int_2: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_3: "Sym(s2)" = torch.ops.aten.sym_size.int(dc_key_cache_0, 2)
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:259 in forward, code: kc = torch.cat(dc.key_cache, axis=1)
cat: "f32[s0, 8, s2, 6]" = torch.ops.aten.cat.default([dc_key_cache_0], 1); dc_key_cache_0 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:260 in forward, code: vc = torch.cat(dc.value_cache, axis=1)
cat_1: "f32[s0, 8, s2, 6]" = torch.ops.aten.cat.default([dc_value_cache_0], 1); dc_value_cache_0 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:262 in forward, code: ones = torch.zeros(
zeros: "f32[s0, 8, s2, 6]" = torch.ops.aten.zeros.default([sym_size_int_2, 8, sym_size_int_3, 6], device = device(type='cpu'), pin_memory = False); sym_size_int_2 = sym_size_int_3 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:270 in forward, code: w = vc + kc + ones
add: "f32[s0, 8, s2, 6]" = torch.ops.aten.add.Tensor(cat_1, cat); cat_1 = cat = None
add_1: "f32[s0, 8, s2, 6]" = torch.ops.aten.add.Tensor(add, zeros); add = zeros = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:271 in forward, code: y = w.sum(axis=2, keepdim=True)
sum_1: "f32[s0, 8, 1, 6]" = torch.ops.aten.sum.dim_IntList(add_1, [2], True); add_1 = None
# File: /home/xadupre/github/experimental-experiment/_doc/recipes/plot_exporter_exporter_inputs.py:272 in forward, code: return x + y
add_2: "f32[s0, 8, 7, 6]" = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
return (add_2,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='dc_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='dc_value_cache_0'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {s0: VR[1, 1024], s2: VR[1, 1024]}
We remove the changes made to PyTorch.
torch.utils._pytree.SUPPORTED_NODES.pop(DynamicCache)
torch.fx._pytree.SUPPORTED_NODES.pop(DynamicCache)
torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH.pop(DynamicCache)
Total running time of the script: (0 minutes 0.825 seconds)
Related examples
Export a model using a custom type as input
Measures the exporter success on many test cases
to_onnx and submodules from LLMs