Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most used is transformers.cache_utils.DynamicCache. Its size is dynamic to cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() serializes caches and any custom class if serialization functions are registered for them, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
    CacheKeyValue,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    """Toy model adding every key and value tensor of a cache to ``z``.

    It only exists to demonstrate :func:`torch.export.export` with a cache
    input whose dynamic shapes are guessed from several sets of inputs.
    """

    def forward(self, cache, z):
        """Return ``z + k_0 + ... + k_{n-1} + v_0 + ... + v_{n-1}``.

        :param cache: cache wrapped with ``CacheKeyValue`` to access its
            ``key_cache`` and ``value_cache`` sequences (assumed to hold one
            tensor per layer, as built by ``make_dynamic_cache``)
        :param z: tensor broadcastable against every cached tensor
        :return: the broadcast sum
        """
        cache = CacheKeyValue(cache)
        result = z
        # Keys first, then values: for two layers this is exactly the original
        # ``z + k0 + k1 + v0 + v1`` (same order, same exported graph), but any
        # number of layers is now supported instead of exactly two.
        for key in cache.key_cache:
            result = result + key
        for value in cache.value_cache:
            result = result + value
        return result


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
# One (key, value) pair of shape (bsize, nheads, slen, dim) per layer.
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for _ in range(n_layers)
    ]
)
# z is broadcast against the cached tensors, so its last dimension follows dim.
z = torch.randn((1, 1, 1, dim))
model(cache, z)  # to check it works.
tensor([[[[-1.2028, -2.0094,  3.7977,  0.4071,  4.1437, -1.0659, -3.5433],
          [ 0.7795,  2.5811,  1.6132, -2.6445,  3.6840, -0.9590, -0.5871],
          [-1.5822, -0.2605,  2.9487,  1.5768,  4.6462, -0.8236,  0.8009]],

         [[-4.3123, -0.9214,  1.7492, -0.5123,  0.9861,  0.6289,  3.4213],
          [ 1.4077,  2.3735,  0.3241, -4.4911,  4.6145, -1.2205,  0.1954],
          [-5.2932, -1.5919,  1.8236, -2.2762,  5.0160, -1.0473, -4.0363]],

         [[-0.9383,  1.9626,  2.5874, -1.1910,  4.5216,  1.7335, -2.7622],
          [-2.6795, -2.6711,  0.2925,  0.4326,  1.5071, -5.2493,  3.5553],
          [-1.5060,  0.8493, -2.0780,  0.0415,  0.6608,  1.3561, -4.1775]],

         [[-2.3508,  2.0689,  3.0377,  0.8420,  3.3470, -0.8217,  0.6004],
          [-3.0873,  0.8042,  0.5232, -0.2497,  3.0900,  0.7104,  0.1347],
          [-6.1725,  1.6906,  1.3505, -4.7555,  3.3785, -1.2545, -2.0412]]],


        [[[-3.3421,  1.8475,  1.8580, -1.6977, -1.1821,  0.9402, -2.1180],
          [-4.5139, -1.6307,  3.3223,  0.1565, -0.1814, -1.6915,  4.9428],
          [-5.9893,  3.8975,  2.8916,  0.1185,  3.7593, -0.0508,  2.9601]],

         [[ 0.4163,  1.5915, -1.6576,  0.9460,  0.2980, -2.8963,  0.9655],
          [-0.7329, -0.3600,  2.2106, -1.5638,  1.6809, -4.6662, -1.6885],
          [-1.1773,  2.6792, -1.6306, -0.7380,  1.2095,  1.0324, -1.0806]],

         [[-4.2685, -0.9429, -2.5708, -1.6719,  2.9741, -1.0510, -2.0483],
          [-0.1653, -1.1639,  0.0829,  2.3112,  2.8938, -0.0239, -2.9088],
          [-3.6575,  0.9655, -1.4765,  0.3372,  4.0444, -0.5254, -2.4673]],

         [[-1.4710, -2.1635, -0.0500,  0.6997,  2.0070, -1.4374,  1.1824],
          [-0.7175, -0.4376, -0.2689, -1.3328, -0.4840, -3.4535, -0.2466],
          [-2.0959,  3.3507,  1.4909, -0.7559,  1.0200, -3.6894,  3.9002]]]])

The cache looks like this:

# Describe the cache: one entry per key/value tensor, with its shape.
cache_description = string_type(cache, with_shape=True)
print(cache_description)
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
# A second set of inputs where batch size, sequence length and last dimension
# all differ from the first set, so those axes can be detected as dynamic.
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for _ in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    # Last dimension tied to cache2's (dim + 1) so the broadcast still works.
    (cache2, torch.randn((1, 1, 1, dim + 1))),
]

And the second set of inputs looks like:

# The second set of inputs: (cache, z) with larger dimensions.
second_set = inputs[1]
print(string_type(second_set, with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool, onnx_diagnostic.export.ModelInputs, can be used to guess the dynamic shapes from the sets of inputs defined above, in the form torch.export.export() expects them.

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple with transformers>=4.50; otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache. This version is modified to keep the shape inference implemented in torch happy.

# The patches are only active while exporting.
# NOTE(review): ds comes from the dynamic-shape guessing step, whose code is not
# visible here; ds[0] is assumed to be the dynamic shapes of the positional
# arguments (the printed guess has the form (args_shapes, kwargs_shapes)).
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model,
        inputs[0],
        dynamic_shapes=ds[0],
        strict=False,
    )
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)

            # No stacktrace found for following nodes
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            empty: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty, cache_key_cache_0], -2);  empty = cache_key_cache_0 = None
            cat_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_1, cache_value_cache_0], -2);  empty_1 = cache_value_cache_0 = None
            empty_2: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_3: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = sym_size_int_12 = None
            cat_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_2, cache_key_cache_1], -2);  empty_2 = cache_key_cache_1 = None
            cat_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_3, cache_value_cache_1], -2);  empty_3 = cache_value_cache_1 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cat);  z = cat = None
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cat_2);  add = cat_2 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cat_1);  add_1 = cat_1 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cat_3);  add_2 = cat_3 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}

Use string instead of DYNAMIC

The ONNX exporter accepts strings instead of DYNAMIC or AUTO, which gives a name to every dimension.

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

# Same description as before: the serialized view of the DynamicCache.
serialized_view = string_type(cache, with_shape=True)
print(serialized_view)
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured, without the custom class, the way torch.export.export() expects it.

# Nested lists of tensors replacing the DynamicCache class.
flat_cache = flatten_unflatten_for_dynamic_shapes(cache)
print(string_type(flat_cache, with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().

# Thumbnail legend for the gallery page.
legend_text = "dynamic shapes\nfor DynamicCache"
doc.plot_legend(legend_text, "torch.export.export", "tomato")
plot export with dynamic cache

Total running time of the script: (0 minutes 0.777 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Gallery generated by Sphinx-Gallery