Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most used is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. The example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() serializes caches and any custom class if the serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache and transformers>=4.50. The dynamic shapes must be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
    CacheKeyValue,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    """Toy module that adds ``z`` to the first two key and value tensors of a cache."""

    def forward(self, cache, z):
        """Return ``z + key[0] + key[1] + value[0] + value[1]``.

        :param cache: cache object, wrapped in ``CacheKeyValue`` to read
            its ``key_cache`` and ``value_cache`` sequences
        :param z: tensor added to the cached tensors
        :return: the accumulated sum
        """
        wrapped = CacheKeyValue(cache)
        keys = wrapped.key_cache
        values = wrapped.value_cache
        # The additions are chained left to right in this exact order
        # (keys first, then values) so the traced graph stays the same.
        total = z + keys[0]
        total = total + keys[1]
        total = total + values[0]
        total = total + values[1]
        return total


model = Model()

# Every layer stores one (key, value) pair; both tensors share the same shape.
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
kv_shape = (bsize, nheads, slen, dim)
cache = make_dynamic_cache(
    [(torch.randn(kv_shape), torch.randn(kv_shape)) for _ in range(n_layers)]
)
# Only the last axis of z is not 1, and it matches the last axis of the
# cached tensors so that the additions in Model.forward broadcast.
z = torch.randn((1, 1, 1, dim))
model(cache, z)  # to check it works.
tensor([[[[ 1.2065e+00,  4.2592e+00, -5.0596e-01,  3.7851e+00,  1.6995e+00,
           -1.2802e+00, -1.5096e+00],
          [-8.6155e-01, -1.9605e+00,  1.9830e+00,  2.5508e+00,  2.0251e+00,
            2.2704e+00,  4.0910e-02],
          [ 1.5919e-01, -9.7137e-01,  9.0396e-01,  3.2832e+00,  3.9668e+00,
            1.5895e+00, -4.3324e+00]],

         [[ 3.3668e+00, -1.4471e+00, -2.2153e-01,  1.8183e+00,  4.5027e-01,
            1.2779e+00,  3.7411e-01],
          [-1.9475e+00, -4.5119e-01,  1.2478e+00,  2.5660e+00,  2.7818e+00,
           -1.8669e+00,  4.1276e-01],
          [ 1.8003e+00,  5.3775e-01,  1.5675e+00, -4.7837e-01,  5.3404e-01,
            4.6559e-01,  8.6670e-01]],

         [[ 2.4356e+00,  4.0516e-01,  1.2190e+00,  2.9211e+00,  4.7782e-01,
           -1.4041e+00, -8.1459e-01],
          [ 2.0599e+00, -1.1782e+00,  1.6312e+00,  1.5066e+00,  1.9275e+00,
            3.4996e+00,  7.6782e-01],
          [ 3.4057e+00,  3.9317e-01,  2.6442e+00, -3.0205e-01,  1.1599e+00,
            2.8338e+00, -1.2043e+00]],

         [[-2.3608e+00,  3.0060e+00, -1.7603e-01,  1.0714e+00, -8.0234e-01,
           -1.2117e-01,  6.9854e-01],
          [-2.1983e-01,  1.6605e+00,  2.8058e+00, -1.4066e-02, -6.4126e-01,
           -2.0381e+00,  1.0112e+00],
          [ 2.1301e-01, -1.8910e+00,  1.2207e+00,  1.6962e+00,  3.8539e+00,
            2.5187e+00,  4.4274e-01]]],


        [[[-1.0759e+00,  8.8161e-01, -1.2863e+00, -6.8479e-01,  3.4326e+00,
           -9.7247e-01,  1.9581e+00],
          [-4.9716e+00, -7.2541e-01,  2.1058e+00, -1.4056e+00, -1.7904e+00,
            3.2410e+00,  1.1777e-01],
          [-1.3734e+00, -1.0456e+00,  4.2197e-01,  2.0193e+00,  7.3696e-01,
           -1.0122e+00,  1.0726e+00]],

         [[-7.7517e-01, -2.5451e+00,  3.4221e+00,  1.3313e+00,  2.8932e+00,
           -8.5541e-01,  2.1524e+00],
          [-1.6473e+00, -2.5183e+00, -5.5573e-01, -4.3489e-03,  1.2963e+00,
            1.5930e+00, -1.7775e+00],
          [-2.1222e+00,  1.3105e+00,  1.7864e+00, -1.0606e+00,  2.2835e+00,
            2.3810e+00, -4.6853e-01]],

         [[ 5.2642e-01,  3.6147e-01, -1.1354e+00,  2.2141e+00,  2.3492e+00,
            3.7347e-01, -9.1122e-01],
          [-9.5342e-01,  6.7934e-01,  1.3236e+00,  2.9452e+00,  3.2378e+00,
            3.8519e+00,  2.2207e+00],
          [ 1.1340e+00, -5.2686e-01,  3.0743e+00,  5.1192e-01,  2.7586e+00,
           -2.3406e-01, -4.7315e+00]],

         [[-1.7984e-01, -5.5359e+00,  4.4568e+00, -2.4558e+00,  2.8290e+00,
           -1.3754e+00, -6.1635e-01],
          [ 4.3120e-01, -9.4078e-01,  4.4439e+00,  3.2635e+00,  1.1788e+00,
            2.2735e+00,  5.2567e-01],
          [ 1.7179e+00,  7.3110e-01,  3.8122e+00,  8.0327e-01,  6.0578e+00,
           -9.5705e-01,  2.2877e-01]]]])

The cache looks like this:

# One-line summary of the cache; with_shape=True includes the shape of every tensor.
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
# Second set of inputs: the batch, sequence and feature axes each grow by one
# while the number of heads stays the same, so only axes 0, 2 and 3 differ
# between the two sets.
grown_shape = (bsize + 1, nheads, slen + 1, dim + 1)
cache2 = make_dynamic_cache(
    [(torch.randn(grown_shape), torch.randn(grown_shape)) for _ in range(n_layers)]
)
# The last axis of the second z follows the enlarged feature dimension.
z2 = torch.randn((1, 1, 1, dim + 1))
inputs = [(cache, z), (cache2, z2)]

And the second set of inputs looks like:

# Summary of the second (cache, z) pair, shapes included.
print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

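The following tool (onnx_diagnostic.export.ModelInputs, imported above) can be used to guess the dynamic shapes the way torch.export.export() expects them. The result, stored in ds and printed below, is a pair: the dynamic shapes for the positional arguments followed by the dynamic shapes for the keyword arguments.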

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple if transformers>=4.50, otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache. This one is modified to make the shape inference implemented in torch happy.

# Export inside the patching context so that DynamicCache can be serialized
# by torch.export.export().
with torch_export_patches(patch_transformers=True):
    # inputs[0] is the first (cache, z) pair; ds[0] is expected to hold the
    # dynamic shapes guessed for these positional arguments.
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int_6: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 3)
            sym_stride_int: "Sym(s38)" = torch.ops.aten.sym_stride.int(cache_value_cache_0, 2)
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)

            # No stacktrace found for following nodes
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)

             #
            eq: "Sym(True)" = sym_stride_int == sym_size_int_6;  sym_stride_int = sym_size_int_6 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s53, s92) on node 'eq'");  eq = _assert_scalar_default = None

            # No stacktrace found for following nodes
            empty: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_1: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            cat: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty, cache_key_cache_0], -2);  empty = cache_key_cache_0 = None
            cat_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_1, cache_value_cache_0], -2);  empty_1 = cache_value_cache_0 = None
            empty_2: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            empty_3: "f32[s13, 4, 0, s38]" = torch.ops.aten.empty.memory_format([sym_size_int, 4, 0, sym_size_int_12], dtype = torch.float32, device = device(type='cpu'), pin_memory = False);  sym_size_int = sym_size_int_12 = None
            cat_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_2, cache_key_cache_1], -2);  empty_2 = cache_key_cache_1 = None
            cat_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.cat.default([empty_3, cache_value_cache_1], -2);  empty_3 = cache_value_cache_1 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cat);  z = cat = None
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cat_2);  add = cat_2 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cat_1);  add_1 = cat_1 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cat_3);  add_2 = cat_3 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}

Use string instead of DYNAMIC

The ONNX exporter takes strings instead of DYNAMIC or AUTO in order to give a name to every dimension.

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

# Prints the DynamicCache as key_cache / value_cache lists with tensor shapes.
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it to be without the custom class.

# Same cache restructured as nested lists, without the DynamicCache class.
print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().

# NOTE(review): presumably draws the thumbnail shown for this example in the
# gallery; confirm against onnx_diagnostic.doc.plot_legend.
doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
plot export with dynamic cache

Total running time of the script: (0 minutes 5.635 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Gallery generated by Sphinx-Gallery