Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most used is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. This example shows a tool that determines the dynamic shapes for torch.export.export() from a set of valid inputs.

DynamicCache

torch.export.export() can serialize caches and any custom class as long as serialization functions are registered for them. This is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must then be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    """Toy model adding the first two cache layers (keys and values) to ``z``."""

    def forward(self, cache, z):
        # Accumulate in the same order as the chained expression
        # z + k0 + k1 + v0 + v1: keys of layers 0 and 1, then values of layers 0 and 1.
        total = z
        for per_layer in (cache.key_cache, cache.value_cache):
            for layer_index in (0, 1):
                total = total + per_layer[layer_index]
        return total


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
# One (key, value) pair per layer, each built as (bsize, nheads, slen, dim).
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for _ in range(n_layers)
    ]
)
# ``z`` is broadcast against every cache tensor, so its last axis must be ``dim``
# (derived rather than hard-coded so that changing ``dim`` keeps the example valid).
z = torch.randn((1, 1, 1, dim))
model(cache, z)  # to check it works.
tensor([[[[ 1.6211e+00, -4.4963e-01,  3.7138e-01, -1.1672e+00,  3.0473e+00,
           -3.8904e+00, -3.7442e+00],
          [ 5.1387e-01,  5.5463e-01,  3.2164e+00, -2.0300e+00, -4.9825e-01,
           -2.8044e+00, -3.4624e+00],
          [-3.1017e+00, -3.0667e+00, -8.5357e-01, -2.0056e+00,  1.5722e+00,
            1.9887e+00, -2.8111e+00]],

         [[-1.5656e+00,  2.6021e+00, -6.9981e-03, -6.1359e-01, -1.8002e+00,
           -1.8340e-02,  4.5578e-01],
          [ 3.2180e-01,  3.9261e-02,  2.3844e-01, -6.3152e-01, -2.4216e+00,
           -2.3339e+00,  1.7688e+00],
          [-5.2917e-01, -2.4031e+00,  7.8195e-01, -1.4755e+00,  2.1120e+00,
            3.2735e-01, -1.2401e+00]],

         [[-6.8606e-01, -1.3246e+00, -2.0586e+00,  4.8146e-02,  1.4194e+00,
            1.2322e+00, -7.9979e-01],
          [ 1.6524e+00, -9.8557e-01, -7.8124e-01, -6.3959e-01,  7.0277e-02,
            6.7879e-02, -1.3848e+00],
          [-2.8186e+00, -2.7293e+00, -3.2263e-01, -2.3429e-01, -1.7123e+00,
           -5.7148e-01, -9.9731e-02]],

         [[-4.7298e-01,  1.0095e+00,  6.8650e-01,  1.1478e-01, -2.0815e+00,
           -2.7628e+00, -1.5509e+00],
          [ 2.3727e+00, -4.8947e+00,  5.2629e-01, -1.2931e+00,  1.0776e+00,
           -2.6690e+00, -1.0616e+00],
          [-3.7991e-01, -2.4463e-01, -2.0551e+00,  2.7257e-01,  1.5132e+00,
           -1.2175e+00, -3.2151e-01]]],


        [[[ 3.7307e+00, -2.8901e-01, -1.8701e+00,  1.7091e+00,  2.9891e-01,
           -3.6944e+00, -8.1635e-01],
          [-2.5388e+00,  1.0909e+00, -6.7015e-01, -3.6232e+00, -2.1953e-01,
           -3.1885e+00,  2.2721e+00],
          [ 1.8095e-02, -2.6668e+00, -2.6200e+00,  6.7534e-01,  2.6317e+00,
           -1.5612e+00, -9.8545e-01]],

         [[ 1.6619e+00, -7.3568e-02,  1.8104e+00, -2.6276e+00, -1.5602e+00,
            1.8792e+00, -2.8923e+00],
          [ 3.1849e+00,  1.1824e+00, -1.1947e+00,  3.7644e-01, -2.5369e+00,
            4.8330e-03,  9.9245e-01],
          [ 1.0146e+00,  3.8534e-01,  4.0157e-01,  7.2531e-01,  1.9562e-01,
            7.5785e-01,  2.7832e-01]],

         [[-2.3417e+00, -2.9269e+00,  2.3163e+00, -8.0094e-02,  3.8486e+00,
           -4.0846e-01, -2.5925e+00],
          [ 9.1589e-01, -8.1134e-01,  3.5342e+00, -3.5822e+00,  2.8905e+00,
            1.7780e+00, -1.5687e+00],
          [ 3.0770e+00, -9.1214e-01,  4.6789e-01, -6.9587e-01, -5.3018e-01,
           -1.2687e-01, -1.2247e+00]],

         [[ 5.3347e+00,  1.7573e+00,  7.6813e-01, -1.4423e+00,  3.8983e+00,
            7.3556e-02,  1.2060e+00],
          [-1.6905e+00,  4.6349e-01, -2.4926e+00,  8.4244e-01, -2.1735e+00,
           -3.1046e-01, -1.7711e+00],
          [ 1.6472e+00,  1.2324e+00, -8.4641e-01, -7.6808e-01, -3.5611e+00,
           -2.4918e+00,  3.9598e-01]]]])

The cache looks like this:

# Print the cache structure with tensor shapes (e.g. T1s2x4x3x7 in the output below).
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
# A second cache whose batch, sequence and last dimensions all differ from
# ``cache``; only axes that vary across input sets can be guessed as dynamic.
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for _ in range(n_layers)
    ]
)
# Two valid input sets for ``model``. In each, the last axis of ``z`` follows
# the last axis of the cache it is paired with (``dim`` and ``dim + 1``).
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, dim + 1))),
]

And the second set of inputs looks like:

# Show the second input set: a DynamicCache with larger dimensions plus ``z``.
print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool, onnx_diagnostic.export.ModelInputs, can be used to guess the dynamic shapes from the input sets defined above, in the form torch.export.export() expects them.

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is straightforward with transformers>=4.50. Otherwise, transformers needs to be patched: onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache. These functions are modified so that the shape inference implemented in torch accepts them.

# With transformers>=4.50, DynamicCache serialization is already registered,
# so the model can be exported as is. Older versions need the patches.
cache_is_serializable = has_transformers("4.50")
if not cache_is_serializable:
    with torch_export_patches(patch_transformers=True) as modificator:
        # ``modificator`` adapts the inputs to the patched cache classes.
        patched_args = modificator(inputs[0])
        ep = torch.export.export(
            model, patched_args, dynamic_shapes=ds[0], strict=False
        )
else:
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s26, 4, s28, s1]", cache_key_cache_1: "f32[s26, 4, s28, s1]", cache_value_cache_0: "f32[s26, 4, s28, s1]", cache_value_cache_1: "f32[s26, 4, s28, s1]", z: "f32[1, 1, 1, s1]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None
            add_1: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s26: VR[2, int_oo], s28: VR[2, int_oo], s1: VR[2, int_oo]}

Use strings instead of DYNAMIC

The ONNX exporter accepts strings instead of DYNAMIC or AUTO, which gives a name to every dynamic dimension.

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

# Same call as earlier on this page: string_type walks the cache and prints its shapes.
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it to be without the custom class.

# Print the cache as nested lists of tensors without the DynamicCache wrapper
# (#2[#2[...],#2[...]] below), which is the nesting the dynamic shapes must follow.
print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().

# Draw the small legend figure displayed right below this cell.
doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
plot export with dynamic cache

Total running time of the script: (0 minutes 11.039 seconds)

Related examples

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Dynamic Shapes for *args, **kwargs

Dynamic Shapes for *args, **kwargs

Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Gallery generated by Sphinx-Gallery