Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most common is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() can serialize caches and any custom class as long as serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
    CacheKeyValue,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    """Toy model: adds ``z`` to every tensor stored in a two-layer cache."""

    def forward(self, cache, z):
        # Wrap the DynamicCache to get uniform list access to keys/values.
        kv = CacheKeyValue(cache)
        # Accumulate in the same left-to-right order as the original
        # expression: z + key[0] + key[1] + value[0] + value[1].
        out = z
        for t in (kv.key_cache[0], kv.key_cache[1],
                  kv.value_cache[0], kv.value_cache[1]):
            out = out + t
        return out


model = Model()

# Cache geometry: two layers, each holding a (key, value) pair shaped
# (batch, heads, sequence_length, head_dim).
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (
            torch.randn(bsize, nheads, slen, dim),
            torch.randn(bsize, nheads, slen, dim),
        )
        for _ in range(n_layers)
    ]
)
# z broadcasts over the first three dimensions of the cached tensors.
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # sanity check that the model runs on these inputs
tensor([[[[-1.3243,  3.5940, -1.5383,  4.3467, -2.7638, -2.4726, -3.9603],
          [-2.1588,  1.6287,  0.5150,  2.5715,  2.5243,  0.4649, -2.6767],
          [-2.3344, -0.3294,  0.1188, -1.2026, -2.3646, -2.9379, -2.2982]],

         [[ 2.2873, -2.3053,  3.1252, -0.0739, -2.1389, -1.3499, -4.6829],
          [-1.0878,  1.5745,  1.5970,  0.1743, -1.1593, -1.3129,  1.6321],
          [-2.1987,  1.5210,  0.4771,  0.2057, -2.1477, -0.6378,  0.3050]],

         [[-2.2894,  3.6120,  1.5569,  2.7154, -0.4645,  2.6321, -1.1310],
          [-0.4714,  3.9739,  3.7457,  0.6174, -3.8009, -1.4083, -1.8550],
          [-0.4066,  1.3037,  0.5129,  2.7619, -3.0350, -3.6074,  0.2571]],

         [[-0.3311, -4.1867,  3.2018, -2.0846, -3.0482, -1.3925, -2.7517],
          [ 0.1169, -2.2490,  2.2928, -0.7347, -1.2158, -2.0195, -4.9589],
          [-1.8431, -0.5776, -1.7848,  0.7713, -3.1344,  0.3506,  1.0216]]],


        [[[ 3.1785,  0.6576,  2.8110,  4.7334,  0.8651,  1.1011, -5.7397],
          [-0.7448, -6.0601,  1.0587,  2.7385, -1.9832,  1.8390, -1.7844],
          [ 1.4589, -1.8981, -2.3090, -5.7446,  0.2227, -1.3392, -2.6937]],

         [[-3.0801,  3.3459,  0.0348,  2.9240,  0.7747, -2.7957, -0.1278],
          [ 2.8298,  0.5449,  1.4096, -0.4177, -1.6747, -0.4843, -3.1566],
          [ 0.6165,  1.6457, -1.2478, -1.1977, -2.4278,  0.5445, -0.4396]],

         [[ 0.0746,  1.1348,  1.3557,  0.9144, -4.0249,  1.3906,  3.0441],
          [-2.3378, -1.6887,  0.7975,  2.5484,  2.9201, -1.6022,  2.8166],
          [-0.0416, -1.6263,  2.6844,  1.5435, -1.6131, -2.1283, -3.7662]],

         [[ 0.9089,  5.2677,  3.5966,  0.4998, -2.4410, -2.3898, -1.2408],
          [-3.2475, -0.5465, -3.6560, -0.5062, -4.8458, -0.1494, -1.6229],
          [ 1.3604,  0.8401, -0.5811,  0.9929, -0.9201, -0.0325,  2.8906]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))  # compact dtype/shape summary of each cached tensor
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
# Build a second input set where every dynamic dimension differs from the
# first set, so the guessing tool can tell static and dynamic dimensions apart.
second_layers = []
for _ in range(n_layers):
    shape = (bsize + 1, nheads, slen + 1, dim + 1)
    second_layers.append((torch.randn(*shape), torch.randn(*shape)))
cache2 = make_dynamic_cache(second_layers)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like:

print(string_type(inputs[1], with_shape=True))  # second sample: dimensions differ from the first set
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool, onnx_diagnostic.export.ModelInputs, can be used to guess the dynamic shapes the way torch.export.export() expects them.

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple if transformers>=4.50, otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache. This one is modified to make the shape inference implemented in torch happy.

with torch_export_patches(patch_transformers=True):
    # NOTE(review): ``ds`` holds the guessed dynamic shapes printed above;
    # presumably produced by ModelInputs(model, inputs).guess_dynamic_shapes()
    # — its defining statement is missing from this excerpt, confirm upstream.
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)

            # No stacktrace found for following nodes
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            empty: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty, cache_key_cache_0], -2);  empty = cache_key_cache_0 = None
            cat_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_1, cache_value_cache_0], -2);  empty_1 = cache_value_cache_0 = None
            empty_2: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_3: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = sym_size_int_12 = None
            cat_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_2, cache_key_cache_1], -2);  empty_2 = cache_key_cache_1 = None
            cat_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_3, cache_value_cache_1], -2);  empty_3 = cache_value_cache_1 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cat);  z = cat = None
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cat_2);  add = cat_2 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cat_1);  add_1 = cat_1 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cat_3);  add_2 = cat_3 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}

Use string instead of DYNAMIC

The ONNX exporter accepts strings instead of DYNAMIC or AUTO, which gives a name to every dynamic dimension.

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))  # serialized view: key_cache and value_cache tensor lists
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it to be without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))  # nested lists, no custom class
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")  # renders the gallery thumbnail
plot export with dynamic cache

Total running time of the script: (0 minutes 0.460 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Gallery generated by Sphinx-Gallery