Export with DynamicCache and dynamic shapes
Every LLM implemented in transformers uses a cache.
One of the most commonly used is transformers.cache_utils.DynamicCache.
The cache size is dynamic to cope with the growing context.
The example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.
DynamicCache
torch.export.export() serializes caches and any custom class as long as serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50.
The dynamic shapes must be provided following the serialized form.
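For instance, for a cache with two layers, the serialized form is a pair of lists, [[key_0, key_1], [value_0, value_1]], so hand-written dynamic shapes mirror that nesting. The sketch below is only an illustration (it assumes torch.export.Dim.DYNAMIC is available, i.e. a recent torch version); the rest of the example derives the same structure automatically from a set of valid inputs.
import torch

DYN = torch.export.Dim.DYNAMIC  # a dimension the exporter is free to vary

# One dict per tensor, following the serialized structure of a two-layer
# DynamicCache: [[key_0, key_1], [value_0, value_1]].
dyn = {0: DYN, 2: DYN, 3: DYN}  # batch size, sequence length, last dimension
manual_dynamic_shapes = (
    [[dyn, dyn], [dyn, dyn]],  # dynamic shapes for the cache
    {3: DYN},  # dynamic shapes for the second input z
)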
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
flatten_unflatten_for_dynamic_shapes,
make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
class Model(torch.nn.Module):
def forward(self, cache, z):
return (
z
+ cache.key_cache[0]
+ cache.key_cache[1]
+ cache.value_cache[0]
+ cache.value_cache[1]
)
model = Model()
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
[
(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
for i in range(n_layers)
]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z) # to check it works.
tensor([[[[ 2.0641, 1.8414, 2.1144, -2.1858, -1.4553, 2.1089, -3.9414],
[ 2.5480, 0.4189, 0.8150, -1.8488, -1.5655, 1.6026, -0.3057],
[ 3.8920, -1.8116, 0.7533, -2.1390, 0.9199, -0.8702, 2.0602]],
[[ 0.4179, -0.7373, 0.9346, 1.8390, -3.5611, 3.4647, 0.2530],
[-2.0087, 1.1262, -1.9001, 3.3054, -2.2922, 2.4464, 0.9105],
[ 1.6731, 0.8197, -0.3959, 4.5894, -1.9321, -2.5554, -0.1521]],
[[-0.4913, -1.6713, -0.8794, 2.4646, -1.6725, -3.4529, -2.7974],
[ 4.2529, -4.1975, 0.4804, -1.7241, -2.2581, -1.1349, -1.2331],
[ 3.0714, -0.8710, -4.8588, 3.2943, -2.8378, -0.1089, -1.8729]],
[[ 1.0005, 3.5533, -2.5484, -1.0387, -1.3163, 1.9498, -1.7814],
[-1.0856, -2.3275, 3.0730, -0.1047, -3.7895, 1.5157, 0.0685],
[ 0.7635, 1.6033, 1.8651, -0.7784, -1.7362, -1.0176, 0.9097]]],
[[[ 0.4167, -3.1517, 0.4268, -0.6477, -2.1251, -2.2218, -4.0205],
[-0.3417, -0.5707, -1.0588, -0.4271, -0.6683, -1.1274, -3.2400],
[ 6.3509, -0.7531, 1.9027, 1.8726, -4.0646, 0.4263, 2.2371]],
[[ 0.2658, 3.4582, -4.4293, 1.3780, -2.8750, -1.7799, 0.5450],
[ 0.7751, -1.5647, 0.0314, 3.6410, -2.6702, 0.5702, 2.2426],
[-0.6385, -0.9095, 1.5091, -1.2678, -3.0067, -0.2964, 1.6353]],
[[ 1.9896, 2.1162, 2.8666, 0.4280, -2.3102, -1.7958, 0.4322],
[ 3.9568, 0.4840, -2.0115, 0.9279, -2.8587, 0.3561, -0.0295],
[ 0.6160, 0.1291, 0.5133, -2.9567, -2.2645, 0.7163, 3.2457]],
[[ 0.8378, 0.2365, 0.0756, -0.3337, -2.9120, -1.5043, -0.7300],
[-0.2882, 1.1744, -3.2432, -3.2175, -0.3126, -0.3074, -1.8022],
[ 0.7445, -0.6921, -1.7870, 2.7280, -6.2481, 4.1910, 0.4509]]]])
The cache looks like this:
print(string_type(cache, with_shape=True))
DynamicCache[serialized](#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]])
cache2 = make_dynamic_cache(
[
(
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
)
for i in range(n_layers)
]
)
inputs = [
(cache, z),
(cache2, torch.randn((1, 1, 1, 8))),
]
And the second set of inputs looks like:
print(string_type(inputs[1], with_shape=True))
(DynamicCache[serialized](#2[#2[T1s3x4x4x8,T1s3x4x4x8],#2[T1s3x4x4x8,T1s3x4x4x8]]),T1s1x1x1x8)
Guess the dynamic shapes
The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}],
[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)},
{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}]],
{3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
min=None,
max=None,
_factory=True)}),
{})
And finally the export. The export is simple with transformers>=4.50; otherwise, transformers needs to be patched.
onnx_diagnostic.torch_export_patches.bypass_export_some_errors() registers the functions needed to serialize DynamicCache; the registered version is modified to keep the shape inference implemented in torch happy.
guess_dynamic_shapes() returns a pair (dynamic shapes for the positional arguments, dynamic shapes for the named arguments); the model only takes positional arguments, so ds[0] is given to torch.export.export().
if has_transformers("4.50"):
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
with bypass_export_some_errors(patch_transformers=True) as modificator:
ep = torch.export.export(
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
)
print(ep)
# Do we need to guess?
# ++++++++++++++++++++
#
# Function :func:`onnx_diagnostic.helpers.string_type` uses
# the serialization functions to print out the DynamicCache the way
# :func:`torch.export.export` expects it.
print(string_type(cache, with_shape=True))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, cache_key_cache_0: "f32[s26, 4, s28, s1]", cache_key_cache_1: "f32[s26, 4, s28, s1]", cache_value_cache_0: "f32[s26, 4, s28, s1]", cache_value_cache_1: "f32[s26, 4, s28, s1]", z: "f32[1, 1, 1, s1]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
add: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0); z = cache_key_cache_0 = None
add_1: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1); add = cache_key_cache_1 = None
add_2: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0); add_1 = cache_value_cache_0 = None
add_3: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1); add_2 = cache_value_cache_1 = None
return (add_3,)
Graph signature:
# inputs
cache_key_cache_0: USER_INPUT
cache_key_cache_1: USER_INPUT
cache_value_cache_0: USER_INPUT
cache_value_cache_1: USER_INPUT
z: USER_INPUT
# outputs
add_3: USER_OUTPUT
Range constraints: {s26: VR[2, int_oo], s28: VR[2, int_oo], s1: VR[2, int_oo]}
DynamicCache[serialized](#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]])
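To verify that the dynamic dimensions are effective, the exported program can be run on the second set of inputs, whose shapes differ from the ones used for the export. This small check is not part of the original script and only relies on the objects defined above.
# The exported module accepts the same input structure as the eager model.
got = ep.module()(*inputs[1])
expected = model(*inputs[1])
torch.testing.assert_close(expected, got)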
You can also use the function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.
print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]
This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().
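As a quick illustration, the sketch below registers a hypothetical container MyCache (not a class from transformers or onnx_diagnostic) with torch.utils._pytree.register_pytree_node() and exports a model taking it as input; it assumes torch accepts dynamic shapes given as the flattened structure, as it does for DynamicCache above, here simply [key, value].
import torch
import torch.utils._pytree as pytree


class MyCache:
    def __init__(self, key, value):
        self.key = key
        self.value = value


pytree.register_pytree_node(
    MyCache,
    # flatten: return the tensor children and an empty context
    lambda c: ([c.key, c.value], None),
    # unflatten: rebuild the container from its children
    lambda children, _context: MyCache(*children),
    serialized_type_name="MyCache",
)


class ModelWithMyCache(torch.nn.Module):
    def forward(self, cache, z):
        return z + cache.key + cache.value


DYN = torch.export.Dim.DYNAMIC
my_cache = MyCache(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))
ep2 = torch.export.export(
    ModelWithMyCache(),
    (my_cache, torch.randn(1, 1, 1, 7)),
    # dynamic shapes follow the flattened structure of MyCache: [key, value]
    dynamic_shapes=([{0: DYN, 2: DYN}, {0: DYN, 2: DYN}], {3: DYN}),
    strict=False,
)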
doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 0.197 seconds)