Export with DynamicCache and dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most common is transformers.cache_utils.DynamicCache. Its size is dynamic so that it can cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.
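To illustrate the growing context, here is a small sketch (not part of the original example, and assuming transformers>=4.50): DynamicCache.update() appends new key/value states along the sequence dimension.

import torch
from transformers.cache_utils import DynamicCache

growing = DynamicCache()
for step in range(3):
    # one new token per step for layer 0: batch=2, heads=4, seq=1, head_dim=7
    growing.update(torch.randn(2, 4, 1, 7), torch.randn(2, 4, 1, 7), layer_idx=0)
print(growing.key_cache[0].shape)  # the sequence dimension grew to 3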

DynamicCache

torch.export.export() serializes caches and any custom class as long as serialization functions are registered, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.
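A quick way to check that this registration is in place (a sketch, not part of the original example): if DynamicCache is registered as a pytree node, torch.utils._pytree.tree_flatten() returns plain tensors rather than the cache itself.

import torch
import torch.utils._pytree as pytree
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

probe = make_dynamic_cache([(torch.randn(1, 2, 3, 4), torch.randn(1, 2, 3, 4))])
flat, _spec = pytree.tree_flatten(probe)
# With a registered DynamicCache, ``flat`` only contains tensors.
print([type(t).__name__ for t in flat])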

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-0.6074,  0.4564, -2.5022, -3.0228, -0.7389,  0.8869, -0.2922],
          [ 2.8649,  1.5606, -1.6871, -2.7368,  0.3228, -3.1701,  1.2549],
          [ 2.5762,  0.9670, -4.0031, -3.1900, -0.0721,  2.7567, -0.9342]],

         [[ 1.2530, -1.2665, -4.8364,  0.4504, -4.6639,  2.3003, -3.8798],
          [ 1.3169,  0.8476, -1.9826,  0.7616,  1.8153,  0.1259,  1.2098],
          [ 0.8704, -1.3857, -3.4236, -1.5669,  0.2834,  0.8064,  0.7629]],

         [[ 0.5600, -1.2178, -2.3031, -1.4461,  0.4515, -4.2972,  2.2226],
          [ 1.7402, -0.1522, -2.2510,  0.0965, -0.0474, -0.6318, -0.0480],
          [-0.9039,  1.2713, -1.1893, -0.4839, -0.1363,  3.8304,  2.6745]],

         [[-0.0085,  3.1588, -2.5299, -2.3592, -0.1064,  2.1106,  1.6984],
          [ 2.3459, -3.0749, -0.2180, -0.9902, -1.8782,  0.0546,  3.0867],
          [ 2.4693, -0.3514, -2.8640, -1.7124,  1.3470,  0.0556,  4.4353]]],


        [[[ 0.7560,  1.5433, -1.9611, -2.8960,  0.3127,  0.8502,  3.7807],
          [ 0.4000,  1.4497, -3.1303, -5.8418,  0.3858,  2.6408, -1.0751],
          [ 3.0079, -0.3757, -0.3050, -2.1079,  0.8492, -0.6222,  1.2014]],

         [[ 2.1729,  3.5238, -2.8492, -4.3999,  0.2071,  3.3403,  2.5480],
          [ 0.8098,  0.0974, -1.1923, -0.4168, -0.8248, -0.6386,  3.3686],
          [ 0.4880, -0.7822,  3.6328, -0.9024, -2.1802, -0.7438,  3.2056]],

         [[ 3.9566, -2.9754,  1.6396, -1.5234, -2.6085,  3.2171,  0.6216],
          [ 5.3821,  1.7626, -0.1185, -3.9268,  1.5964, -1.6121, -1.1457],
          [ 0.4399, -2.2553,  2.6655, -6.1374, -0.6601, -3.7089,  1.9773]],

         [[-1.5957, -5.7879,  2.7818, -0.3937, -1.1999, -1.6934,  2.2525],
          [ 0.8051,  1.0250,  1.5346, -1.4982, -2.2119,  0.2248, -1.8336],
          [-0.1183, -2.4124,  0.3726,  0.3715,  0.4399,  1.1047,  5.3878]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))
DynamicCache[serialized](#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]])
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like:

print(string_type(inputs[1], with_shape=True))
(DynamicCache[serialized](#2[#2[T1s3x4x4x8,T1s3x4x4x8],#2[T1s3x4x4x8,T1s3x4x4x8]]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.
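The code producing the output below is a sketch of that step, assuming ModelInputs.guess_dynamic_shapes() is the tool in question; it defines the ds used later for the export.

mi = ModelInputs(Model(), inputs)  # the model and several valid input sets
ds = mi.guess_dynamic_shapes()     # dimensions varying across the input sets become dynamic
pprint.pprint(ds)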

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple if transformers>=4.50; otherwise transformers needs to be patched. onnx_diagnostic.torch_export_patches.bypass_export_some_errors() registers the functions needed to serialize DynamicCache, modified to make the shape inference implemented in torch happy.

if has_transformers("4.50"):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
    with bypass_export_some_errors(patch_transformers=True) as modificator:
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
print(ep)

# Do we need to guess?
# ++++++++++++++++++++
#
# Function onnx_diagnostic.helpers.string_type uses
# the serialization functions to print out the DynamicCache the way
# torch.export.export() expects it.

print(string_type(cache, with_shape=True))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s26, 4, s28, s1]", cache_key_cache_1: "f32[s26, 4, s28, s1]", cache_value_cache_0: "f32[s26, 4, s28, s1]", cache_value_cache_1: "f32[s26, 4, s28, s1]", z: "f32[1, 1, 1, s1]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None
            add_1: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s26: VR[2, int_oo], s28: VR[2, int_oo], s1: VR[2, int_oo]}

DynamicCache[serialized](#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]])

You can also use the function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().
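As an illustration, here is a minimal sketch of such a registration (not part of the original example; MyCache is a hypothetical class):

import torch
import torch.utils._pytree as pytree


class MyCache:
    def __init__(self, keys, values):
        self.keys = keys      # list of tensors
        self.values = values  # list of tensors


pytree.register_pytree_node(
    MyCache,
    # flatten: return the children to trace through and a context to rebuild the class
    lambda c: ((c.keys, c.values), None),
    # unflatten: rebuild the class from the children and the context
    lambda children, _context: MyCache(*children),
)

flat, _spec = pytree.tree_flatten(MyCache([torch.ones(2)], [torch.zeros(2)]))
print(flat)  # the individual tensors, in the serialized order the dynamic shapes must follow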

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
[figure: plot export with dynamic cache]

Total running time of the script: (0 minutes 0.322 seconds)

Related examples

Export Tiny-LLM with patches

Dynamic Shapes for *args, **kwargs

Steel method forward to guess the dynamic shapes (with Tiny-LLM)