Export with DynamicCache and dynamic shapes¶
Every LLM implemented in transformers uses a cache.
One of the most commonly used is transformers.cache_utils.DynamicCache.
The cache size is dynamic to cope with the growing context.
This example shows a tool that determines the dynamic shapes
for torch.export.export() based on a set of valid inputs.
Simple Examples¶
We first look at examples playing with positional and named parameters
to understand how torch.export.export() works.
args¶
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.export import ModelInputs
We need an additional import in case transformers<4.50:
exporting DynamicCache is not supported before that version.
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y) # to check it works
ep = torch.export.export(model, (x, y))
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[5, 6]", y: "f32[1, 6]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: return x + y
add: "f32[5, 6]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {}
As expected, there are no dynamic shapes.
We use onnx_diagnostic.export.ModelInputs
to define them from two sets of valid inputs.
These inputs must have different values for the dynamic
dimensions.
inputs = [(x, y), (torch.randn((7, 8)), torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
{1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}),
{})
The function returns a tuple with two objects: the first one for the positional arguments, the other one for the named arguments. There are no named arguments here, so we use the first result to export.
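A minimal sketch of an export call that yields the output below, passing the positional part of the guessed shapes (ds[0]):
# dynamic_shapes mirrors the structure of the positional arguments
ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
print(ep)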
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: return x + y
add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
kwargs¶
We do the same with named arguments.
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y) # to check it works
tensor([[-0.6455, -2.0931, -0.3581, 1.0811, 0.8050, -0.0346],
[-1.5568, 1.0470, -1.3065, 2.6465, 1.5073, -0.0994],
[-1.8165, 0.2948, 1.0558, 3.5286, 1.2399, 1.6393],
[-2.2680, -0.4382, -1.1139, 2.1643, 1.0748, -0.6413],
[-2.1564, -1.6748, 1.2996, 3.3037, 0.5191, 0.5734]])
Two sets of valid inputs.
inputs = [dict(x=x, y=y), dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
((),
{'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
And we export.
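A minimal sketch of an export call that yields the output below, passing the named part of the guessed shapes (ds[1]):
# no positional arguments, dynamic_shapes mirrors the named arguments
ep = torch.export.export(model, (), kwargs=inputs[0], dynamic_shapes=ds[1])
print(ep)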
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:81 in forward, code: return x + y
add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
args and kwargs¶
torch.export.export() does not like having dynamic shapes
for both args and kwargs, so we need to define them using a single mechanism.
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y
model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y) # to check it works
tensor([[ 1.1243, -0.0781, 0.4323, -0.1759, -0.8099, -1.3753],
[ 1.2757, -0.5779, -0.3899, 0.1688, -0.7676, 1.9258],
[-0.4012, -0.2270, 0.7940, -0.1863, -1.2159, -0.4396],
[-0.1810, -0.1303, -1.6347, -0.2527, 0.2530, -1.4620],
[ 0.6797, -0.4122, -1.1754, -1.0088, 0.0953, 1.2012]])
Two sets of valid inputs with positional and named arguments.
inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},),
{'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
This does not work with torch.export.export(), so
we use a method to move the positional dynamic shapes to
named ones. The method relies on the signature of the
forward method.
new_args, new_kwargs, new_ds = mi.move_to_kwargs(*mi.inputs[0], ds)
pprint.pprint(new_ds)
((),
{'x': {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
'y': {1: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}})
And we export.
ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:111 in forward, code: return x + y
add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y); x = y = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
DynamicCache¶
torch.export.export() serializes caches and any custom class
as long as serialization functions are provided, which is the case for
transformers.cache_utils.DynamicCache with transformers>=4.50.
The dynamic shapes must be provided following the serialized form.
class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )
model = Model()
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z) # to check it works.
tensor([[[[-2.7320, 1.1865, -1.4100, -0.8340, 2.9263, 0.0171, -0.9859],
[ 1.3497, 1.0734, -0.1092, 0.3638, 2.3535, 1.7591, 3.7069],
[-0.8386, -1.6677, 0.8543, 2.4365, -1.1363, -4.4741, 2.6431]],
[[ 1.1639, -0.6753, -0.2704, -3.1417, -2.4369, -2.1266, -2.0043],
[ 2.4236, -0.5340, 1.5569, -2.9683, 1.7018, -3.1168, 0.3381],
[-1.4968, -0.7079, -2.5629, -0.9208, 2.3547, 1.4676, 1.5403]],
[[ 3.8567, 0.3247, -0.4750, -3.7025, 0.1587, -0.3677, 1.8134],
[-1.3193, -0.3737, 1.9140, -1.2858, -2.6532, -3.6816, -1.3043],
[ 4.2593, 1.6604, -1.5416, -2.3577, 2.0926, -1.2807, 1.3090]],
[[-0.7791, 3.6148, 1.6548, -2.7349, -1.2399, -2.2513, 0.0310],
[-4.4242, 2.5574, -2.5203, 0.7911, -1.3226, -0.2234, 0.7806],
[ 1.2048, 1.1645, -0.3480, 4.9251, 0.4152, -2.2854, -0.2640]]],
[[[ 1.4907, 3.9079, -1.7886, -4.3854, 1.8592, -2.0317, 0.5876],
[ 0.5658, 1.5225, -1.5805, 1.4893, 5.7968, -0.7570, 0.1889],
[-4.6421, 0.5355, -1.5047, 0.7374, 0.8151, 2.7162, -2.2045]],
[[-1.5634, 2.7409, -1.1134, 0.6146, 2.4468, -1.8587, 0.9671],
[ 1.6605, -1.0035, -1.8223, -1.2326, 3.0069, -1.0524, 0.6003],
[ 0.2546, -0.6998, 1.9875, -0.3369, 3.8767, 0.2644, 2.0182]],
[[-1.4111, 1.2572, 1.7940, 3.5517, -0.1263, -4.9677, -2.0635],
[-1.0432, 0.2125, 1.8386, 1.5169, -1.7548, 2.0075, 0.3302],
[-1.2682, 2.9526, -2.0307, -1.0864, 3.6513, 2.3558, 0.7038]],
[[ 5.0862, 1.9283, -6.2322, -1.3224, 2.7094, 0.0978, 0.7379],
[ 0.3515, 0.1038, -1.2784, -2.0600, 1.3356, 1.7110, 0.2679],
[-0.3298, 2.2258, -2.9578, -3.5191, 1.6338, 0.0698, 4.9869]]]])
The cache looks like this:
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]
And the first set of inputs looks like:
print(string_type(inputs[0], with_shape=True))
(DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7]),T1s1x1x1x7)
We can now compute the dynamic shapes.
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}],
[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>)},
{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
2: _DimHint(type=<_DimHintType.DYNAMIC: 3>),
3: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}]],
{3: _DimHint(type=<_DimHintType.DYNAMIC: 3>)}),
{})
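For reference, the same dynamic shapes can be written by hand following the serialized form of the cache: a list for key_cache followed by a list for value_cache, each holding one dictionary per layer. A minimal sketch, assuming torch.export.Dim.DYNAMIC is available (PyTorch 2.6 or later):
DYN = torch.export.Dim.DYNAMIC  # shorthand for a dynamic dimension
ds_manual = (
    (
        [  # the DynamicCache is flattened into [key_cache, value_cache]
            [{0: DYN, 2: DYN, 3: DYN}, {0: DYN, 2: DYN, 3: DYN}],  # key_cache, one dict per layer
            [{0: DYN, 2: DYN, 3: DYN}, {0: DYN, 2: DYN, 3: DYN}],  # value_cache, one dict per layer
        ],
        {3: DYN},  # z: only the last dimension differs between the two input sets
    ),
    {},  # no named arguments
)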
And finally the export.
The export is simple with transformers>=4.50; otherwise,
transformers needs to be patched.
onnx_diagnostic.torch_export_patches.bypass_export_some_errors()
registers functions to serialize DynamicCache.
The registered serialization is modified to make
the shape inference implemented in torch happy.
if has_transformers("4.50"):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
    with bypass_export_some_errors(patch_transformers=True) as modificator:
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
print(ep)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, cache_key_cache_0: "f32[s0, 4, s1, s11]", cache_key_cache_1: "f32[s0, 4, s1, s11]", cache_value_cache_0: "f32[s0, 4, s1, s11]", cache_value_cache_1: "f32[s0, 4, s1, s11]", z: "f32[1, 1, 1, s11]"):
# File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:155 in forward, code: z
add: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0); z = cache_key_cache_0 = None
add_1: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1); add = cache_key_cache_1 = None
add_2: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0); add_1 = cache_value_cache_0 = None
add_3: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1); add_2 = cache_value_cache_1 = None
return (add_3,)
Graph signature:
# inputs
cache_key_cache_0: USER_INPUT
cache_key_cache_1: USER_INPUT
cache_value_cache_0: USER_INPUT
cache_value_cache_1: USER_INPUT
z: USER_INPUT
# outputs
add_3: USER_OUTPUT
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s11: VR[2, int_oo]}
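As a quick sanity check (a sketch, not in the script above), the exported program can be run on the second set of inputs, whose dimensions differ from the ones used at export time. This assumes the cache serialization is registered, i.e. transformers>=4.50 or the patch applied:
# the exported module accepts a different batch size, sequence length and hidden dimension
z2 = torch.randn((1, 1, 1, dim + 1))
torch.testing.assert_close(model(cache2, z2), ep.module()(cache2, z2))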
doc.plot_legend("dynamic shapes\nfor cache", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 5.465 seconds)