Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most common is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() serializes caches and any custom class as long as serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.
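To see what that serialized form looks like, a cache can be flattened with torch.utils._pytree. A minimal sketch, assuming transformers>=4.50 so that the serialization functions are already registered:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7), layer_idx=0)
flat, _spec = torch.utils._pytree.tree_flatten(cache)
# the dynamic shapes given to torch.export.export() must follow this
# flattened structure, one entry per tensor
print([t.shape for t in flat])

The full example below builds the cache with a helper from onnx_diagnostic instead.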

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-1.1534, -0.4810, -2.7559,  3.5524, -2.6642, -6.1048,  1.8736],
          [-3.9018,  0.7266, -0.9381,  2.3329, -0.8829, -3.8779,  5.9116],
          [-1.8228,  0.0299, -0.7934,  4.2816, -6.3535, -4.2360,  0.2349]],

         [[-4.2640, -0.9247, -0.1411, -0.1737, -2.9008,  0.1997, -2.6438],
          [ 0.0840,  0.3889, -1.0523, -2.6057,  1.6853, -2.1530,  0.4002],
          [-1.9219,  3.2423, -1.8033, -0.0108,  2.5272, -3.9575, -0.7794]],

         [[ 0.4002,  0.3289, -2.8407, -0.4206, -1.8948, -3.8877,  0.5374],
          [-2.2765, -0.7927, -4.1071, -0.1069,  0.7096, -4.7109, -1.4282],
          [-0.4841, -1.5750, -5.2166,  2.3293,  0.1353, -3.6288,  2.4231]],

         [[ 1.7191,  4.5104, -3.7494,  1.9053,  0.3526, -2.1625, -2.1564],
          [-0.8744,  0.0809, -1.1439,  0.9984, -1.3885, -1.2748, -0.3671],
          [-0.1743, -2.0525, -1.0917,  0.8684,  0.4196,  0.3833,  1.2645]]],


        [[[-0.3088,  0.2936,  0.0545,  2.4301, -2.2080, -1.8228,  1.0070],
          [ 1.7969,  1.2497, -3.1076,  1.4958, -0.7197, -4.4798,  1.3573],
          [-0.7751,  1.8067, -0.6388, -4.9977,  1.1114, -0.6417,  1.9996]],

         [[ 0.3497, -0.9197, -5.6521,  1.2137,  1.2347, -2.8634,  1.7814],
          [ 0.4986, -2.1289, -0.8862,  2.4905, -3.3386,  1.3207,  1.8182],
          [-0.2019, -6.7506, -0.7204,  2.6980, -1.3016, -3.3119,  1.6042]],

         [[-0.1474, -0.1328, -1.6456,  1.4732,  0.2217, -4.2840,  0.6970],
          [ 2.9944,  3.0970, -0.5592,  2.8136,  0.2917,  0.7529,  0.5986],
          [-0.7545, -0.6447, -2.5607,  2.9409, -0.1490, -0.8170, -3.2199]],

         [[ 1.5678, -1.4138, -0.4724, -0.3759,  3.2607, -1.3677,  0.0093],
          [ 0.5380,  0.7210,  0.5958, -0.3307, -3.4815, -0.7025,  0.2686],
          [-1.3187, -0.6886, -2.0378,  3.1369,  1.6712,  0.9054,  0.0266]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
A second cache with different shapes is built so that every dynamic dimension varies across the two sets of inputs:

cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like:

print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.
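The code producing the output below is reconstructed here from the imports above and the later use of ds[0]; a minimal sketch, assuming onnx_diagnostic.export.ModelInputs and its guess_dynamic_shapes() method:

mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)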

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple with transformers>=4.50; otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers the functions serializing DynamicCache, in a version slightly modified to keep the shape inference implemented in torch happy.

if has_transformers("4.50"):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
    with torch_export_patches(patch_transformers=True) as modificator:
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            sym_size_int_1: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 2)
            sym_size_int_2: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 3)
            sym_size_int_3: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 0)
            sym_size_int_4: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 2)
            sym_size_int_5: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 3)
            sym_size_int_6: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 0)
            sym_size_int_7: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 2)
            sym_size_int_8: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 3)
            sym_size_int_9: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 0)
            sym_size_int_10: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 2)
            sym_size_int_11: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 3)
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None

             #
            eq: "Sym(True)" = sym_size_int_12 == sym_size_int_2;  sym_size_int_2 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s38, s92) on node 'eq'");  eq = _assert_scalar_default = None
            eq_1: "Sym(True)" = sym_size_int_12 == sym_size_int_5;  sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s38, s57) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_1 == sym_size_int_4;  sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s10, s29) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
            eq_3: "Sym(True)" = sym_size_int == sym_size_int_3;  sym_size_int_3 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s33, s93) on node 'eq_3'");  eq_3 = _assert_scalar_default_3 = None
            eq_4: "Sym(True)" = sym_size_int_12 == sym_size_int_8;  sym_size_int_8 = None
            _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(eq_4, "Runtime assertion failed for expression Eq(s38, s53) on node 'eq_4'");  eq_4 = _assert_scalar_default_4 = None
            eq_5: "Sym(True)" = sym_size_int_1 == sym_size_int_7;  sym_size_int_7 = None
            _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(eq_5, "Runtime assertion failed for expression Eq(s10, s19) on node 'eq_5'");  eq_5 = _assert_scalar_default_5 = None
            eq_6: "Sym(True)" = sym_size_int == sym_size_int_6;  sym_size_int = None
            _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(eq_6, "Runtime assertion failed for expression Eq(s33, s13) on node 'eq_6'");  eq_6 = _assert_scalar_default_6 = None
            eq_7: "Sym(True)" = sym_size_int_12 == sym_size_int_11;  sym_size_int_12 = sym_size_int_11 = None
            _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(s38, s97) on node 'eq_7'");  eq_7 = _assert_scalar_default_7 = None
            eq_8: "Sym(True)" = sym_size_int_1 == sym_size_int_10;  sym_size_int_1 = sym_size_int_10 = None
            _assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(s10, s42) on node 'eq_8'");  eq_8 = _assert_scalar_default_8 = None
            eq_9: "Sym(True)" = sym_size_int_6 == sym_size_int_9;  sym_size_int_6 = sym_size_int_9 = None
            _assert_scalar_default_9 = torch.ops.aten._assert_scalar.default(eq_9, "Runtime assertion failed for expression Eq(s13, s25) on node 'eq_9'");  eq_9 = _assert_scalar_default_9 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}
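A quick way to check the dimensions are really dynamic is to run the exported program on the second set of inputs, whose shapes all differ from the first. A minimal sketch (it assumes the DynamicCache serialization is registered at call time, i.e. transformers>=4.50 or the patch context):

got = ep.module()(*inputs[1])
# broadcasting (1, 1, 1, 8) against (3, 4, 4, 8) should give torch.Size([3, 4, 4, 8])
print(got.shape)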

Use strings instead of DYNAMIC

The ONNX exporter uses strings instead of DYNAMIC or AUTO to give a name to every dimension.
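The output below presumably comes from the same helper asked to generate names; a hedged sketch, assuming guess_dynamic_shapes accepts an auto argument (the dim_* pattern in the output suggests auto="dim"):

ds = mi.guess_dynamic_shapes(auto="dim")
pprint.pprint(ds)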

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class as long as it was registered with torch.utils._pytree.register_pytree_node().
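As an illustration, here is a minimal sketch registering a hypothetical KVPair class (not part of transformers or onnx_diagnostic) with torch.utils._pytree.register_pytree_node():

import torch


class KVPair:
    # hypothetical container holding two tensors, for illustration only
    def __init__(self, key, value):
        self.key = key
        self.value = value


def _flatten(p):
    # returns (children, static context); no context is needed here
    return [p.key, p.value], None


def _unflatten(children, context):
    return KVPair(*children)


torch.utils._pytree.register_pytree_node(
    KVPair, _flatten, _unflatten, serialized_type_name="KVPair"
)

pair = KVPair(torch.randn(2, 3), torch.randn(2, 3))
# two flattened tensors: this is the structure the dynamic shapes must follow
print(string_type(flatten_unflatten_for_dynamic_shapes(pair), with_shape=True))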

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
[figure: plot export with dynamic cache]

Total running time of the script: (0 minutes 0.362 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Gallery generated by Sphinx-Gallery