Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most used is transformers.cache_utils.DynamicCache. Its size is dynamic to cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.
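A minimal sketch of why the sequence axis has to be dynamic. It assumes that new keys are concatenated along axis -2, which is what DynamicCache.update does in the transformers versions we know of. Plain tensors stand in for the cache and every name here is illustrative only.

```python
import torch

# Illustrative stand-in for one layer of a key cache,
# shaped (batch, n_heads, seq_len, head_dim) like the tensors used below.
batch, n_heads, head_dim = 2, 4, 7
key_cache = torch.zeros(batch, n_heads, 0, head_dim)

for step in range(3):
    # each decoding step appends the key of one new token along the sequence axis
    new_key = torch.randn(batch, n_heads, 1, head_dim)
    key_cache = torch.cat([key_cache, new_key], dim=-2)
    print(step, tuple(key_cache.shape))
```

Only the third axis grows in this sketch. In the example below, the batch size also changes between the two input sets, so it ends up dynamic as well.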

DynamicCache

torch.export.export() can serialize caches, and any custom class, if serialization functions are registered for them, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must then follow the serialized (flattened) form of the inputs.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
    CacheKeyValue,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache, z):
        cache = CacheKeyValue(cache)
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-3.0480, -1.8073,  0.6495, -0.7381, -0.8798,  0.5898, -2.8004],
          [ 2.1842, -1.8548, -3.0003, -1.6655,  2.3879,  1.8654, -1.1088],
          [-1.7419, -2.0483, -2.1441,  1.7263,  4.8609, -0.9125, -1.6990]],

         [[-0.4475,  1.7122, -3.2443, -0.4818,  0.9170,  0.1613, -0.2010],
          [ 0.2781, -0.0750, -0.6615, -0.1611,  0.5996, -0.1890,  3.3484],
          [-0.7312, -1.6378,  0.4120,  1.5301, -2.1836, -0.8944, -3.0169]],

         [[ 3.9477, -0.5752, -2.4382,  2.6695, -2.8392,  0.0642, -0.9968],
          [-1.1000,  0.0556, -1.9869, -2.0721,  2.3018, -5.9606,  4.7270],
          [ 1.2055,  1.9346, -3.4212,  1.4522,  1.0991, -5.5324,  0.9484]],

         [[ 2.4892,  0.8174, -4.6810,  2.3765,  0.5757,  0.2363,  2.1117],
          [ 0.7241, -4.0972,  0.0493,  0.6650, -2.3097, -2.5479, -0.7381],
          [-0.1486, -0.8967, -0.6537,  1.8371,  1.6975, -3.0013,  2.1476]]],


        [[[-0.4016, -1.4695, -4.8510,  0.7859,  2.9136,  0.7801,  1.2864],
          [-1.4630, -1.6314,  0.2547,  5.5747,  2.5220, -3.9564,  1.0161],
          [ 0.9736, -0.3289, -1.5126,  0.4241,  1.9342,  2.9816,  0.6241]],

         [[-1.2355, -0.6681,  1.0512, -0.2258,  4.1040, -1.7015,  3.7405],
          [ 4.3105,  0.4773, -2.0226,  0.2613,  3.1922,  1.3685,  3.4226],
          [-3.1643, -2.3023, -1.7919,  2.6631, -1.6354, -4.4506, -0.5819]],

         [[-0.3613,  3.2567,  0.2353, -4.7029, -0.6312, -0.1357,  0.9231],
          [-0.2292, -1.0579, -1.3913, -0.5051, -0.3587, -0.4330, -0.1460],
          [ 0.4194, -4.6027, -3.1670,  0.8345,  0.5804, -1.5147, -1.8612]],

         [[-2.3072,  0.2388, -6.0574, -0.2454,  2.4291, -3.4830,  2.0750],
          [ 1.1975, -0.3019, -2.6884,  2.4735,  0.0962, -1.7582, -0.2080],
          [ 3.1556, -7.0135,  1.6012,  3.5653,  0.8449, -1.1759,  4.7246]]]])

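The cache looks like this (string_type() writes T1s2x4x3x7 for a float32 tensor of shape 2x4x3x7 and #2[...] for a list of two elements):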

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

A second set of inputs is needed to guess which dimensions are dynamic. In this one, the batch size, the sequence length and the last dimension change, while the number of heads stays the same. The last dimension of z changes accordingly.

cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like:

print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.

mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)

(([{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)}],
  {3: DimHint(DYNAMIC)}),
 {})
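The idea behind the guess can be sketched in a few lines. This is not the onnx_diagnostic implementation, only an illustration under the assumption that a dimension is dynamic when its size differs across the input sets. The inputs are flattened with torch.utils._pytree.tree_flatten() and, for each tensor, the axes whose size changes are kept. Plain lists replace DynamicCache so that the sketch does not depend on how a given transformers version registers it; guess_dynamic_axes is a hypothetical name.

```python
import torch
from torch.utils._pytree import tree_flatten


def guess_dynamic_axes(*input_sets):
    """Hypothetical helper: for each flattened tensor, list the axes whose size varies."""
    flat_sets = [tree_flatten(inputs)[0] for inputs in input_sets]
    n_tensors = len(flat_sets[0])
    if any(len(flat) != n_tensors for flat in flat_sets):
        raise ValueError("all input sets must have the same structure")
    axes = []
    for i in range(n_tensors):
        shapes = [tuple(flat[i].shape) for flat in flat_sets]
        rank = len(shapes[0])
        # an axis is considered dynamic if at least two input sets disagree on its size
        axes.append([a for a in range(rank) if len({s[a] for s in shapes}) > 1])
    return axes


# two input sets shaped like the ones above: 4 cache tensors, then z
set1 = ([torch.randn(2, 4, 3, 7) for _ in range(4)], torch.randn(1, 1, 1, 7))
set2 = ([torch.randn(3, 4, 4, 8) for _ in range(4)], torch.randn(1, 1, 1, 8))
print(guess_dynamic_axes(set1, set2))
# [[0, 2, 3], [0, 2, 3], [0, 2, 3], [0, 2, 3], [3]]
```

Axis 1 (the number of heads) has the same size in both sets, so it stays static, as in the guess printed above. With a single input set, every axis stays static.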

And finally the export. With transformers>=4.50, the export works as is; with older versions, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache. The patched version is modified so that the shape inference implemented in torch accepts it.

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_0: "f32[s50, 4, s15, s19]", cache_value_0: "f32[s50, 4, s15, s19]", cache_key_1: "f32[s50, 4, s15, s19]", cache_value_1: "f32[s50, 4, s15, s19]", z: "f32[1, 1, 1, s19]"):
            # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: z
            add: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(z, cache_key_0);  z = cache_key_0 = None
            add_1: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add, cache_key_1);  add = cache_key_1 = None
            add_2: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add_1, cache_value_0);  add_1 = cache_value_0 = None
            add_3: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add_2, cache_value_1);  add_2 = cache_value_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_0: USER_INPUT
    cache_value_0: USER_INPUT
    cache_key_1: USER_INPUT
    cache_value_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s50: VR[2, int_oo], s15: VR[2, int_oo], s19: VR[2, int_oo]}

Use string instead of DYNAMIC

The ONNX exporter accepts strings instead of DYNAMIC or AUTO, which gives a name to every dynamic dimension.

(([{0: 'dim_0I_0o0', 2: 'dim_0I_0o2', 3: 'dim_0I_0o3'},
   {0: 'dim_0I_1o0', 2: 'dim_0I_1o2', 3: 'dim_0I_1o3'},
   {0: 'dim_0I_2o0', 2: 'dim_0I_2o2', 3: 'dim_0I_2o3'},
   {0: 'dim_0I_3o0', 2: 'dim_0I_3o2', 3: 'dim_0I_3o3'}],
  {3: 'dim_1I3'}),
 {})
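A hypothetical helper, name_dynamic_dims, showing how such names can be attached. It walks a nested dynamic_shapes structure and replaces every marked axis with a unique string built from its position. The naming scheme is an assumption for illustration and differs from the one onnx_diagnostic produces above.

```python
from pprint import pprint


def name_dynamic_dims(ds, prefix="dim"):
    """Hypothetical helper: give every dynamic axis a unique string name."""

    def walk(obj, path):
        # a dict with integer keys describes the dynamic axes of one tensor
        if isinstance(obj, dict) and all(isinstance(k, int) for k in obj):
            position = "_".join(map(str, path))
            return {axis: f"{prefix}_{position}_{axis}" for axis in obj}
        # a dict with string keys maps keyword arguments to their own specs
        if isinstance(obj, dict):
            return {k: walk(v, path + (k,)) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(walk(v, path + (i,)) for i, v in enumerate(obj))
        return obj

    return walk(ds, ())


# "D" marks a dynamic axis: (positional args, keyword args)
ds = (([{0: "D", 2: "D"}, {0: "D", 2: "D"}], {3: "D"}), {})
pprint(name_dynamic_dims(ds))
```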

Do we need to guess?

The function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#4[T1s2x4x3x7,T1s2x4x3x7,T1s2x4x3x7,T1s2x4x3x7]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 0.204 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Find where a model is failing by running submodels
