Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache, and one of the most common is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() serializes caches and any custom class as long as serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.
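
A minimal sketch first, to show why the cache shapes are dynamic: each call to DynamicCache.update() appends the new key/value states along the sequence axis (axis 2), so the cached context grows at every generation step.

import torch
from transformers.cache_utils import DynamicCache

growing = DynamicCache()
for step in range(3):
    # one new token per step: (batch, n_heads, new_seq_len=1, head_dim)
    k = torch.randn(2, 4, 1, 7)
    v = torch.randn(2, 4, 1, 7)
    growing.update(k, v, layer_idx=0)

# the sequence axis grew from 1 to 3
print(growing.key_cache[0].shape)  # torch.Size([2, 4, 3, 7])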

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-6.7487,  2.8074, -2.1156, -1.7443,  2.5823, -0.0527,  4.6709],
          [-0.3532,  0.6203,  0.0793,  1.1731,  1.5046, -0.1676, -3.6105],
          [-0.1634, -4.4359, -0.8863,  4.9206,  0.6399,  1.3288, -1.2070]],

         [[-1.4193, -1.2084,  0.0660,  1.7674, -4.1382,  0.5093, -1.1658],
          [-1.2666, -0.2242, -3.2339, -1.0332, -3.1076,  2.0630, -0.0520],
          [ 1.3251, -0.4194,  1.6129,  1.7321,  1.2646,  1.0794, -0.4160]],

         [[ 1.0982,  0.7019,  2.6615, -0.3968,  5.2689, -4.9040,  4.6290],
          [ 0.9750,  0.3986, -1.2881,  0.0122, -0.0859, -2.9513,  2.3343],
          [-0.4828, -2.9750,  0.3233,  1.2443, -4.0761,  0.8178, -1.2120]],

         [[ 2.7832,  4.4936,  0.5110, -1.5235, -0.2975, -0.7391,  2.4967],
          [ 0.5324, -1.6645, -1.3523,  1.3111, -0.9262, -2.2256, -1.4111],
          [ 4.8621, -2.1650,  1.5928,  3.3269,  2.0797, -2.3475, -1.9845]]],


        [[[-3.0808, -0.6199,  3.7311, -2.5959, -0.6766,  3.2606,  1.3339],
          [ 3.5423,  1.3965, -0.6126,  4.1131,  3.7087,  0.8477,  0.5217],
          [ 1.1949,  1.7462,  0.5437,  1.3466,  2.0851, -2.4452, -2.7668]],

         [[ 2.5374, -4.0315,  1.2335, -0.0569, -1.6998,  2.0486, -0.1644],
          [ 1.1178, -0.7027,  4.8050,  0.1444,  3.4636,  0.5256,  2.6754],
          [-1.1968, -1.1635,  0.9116,  1.1848,  0.2559,  0.6731, -3.4396]],

         [[ 0.6223,  0.7736, -0.4264,  2.0707,  0.4150, -0.9640,  3.6546],
          [ 3.1889,  0.1215, -1.6996,  4.0112,  1.0167, -3.5966,  0.0551],
          [ 0.1605,  0.1657,  3.9746,  1.4932,  0.6577,  0.9781, -0.1961]],

         [[ 3.8955, -3.4226,  0.4445,  1.6890, -0.8967,  3.0276,  0.4698],
          [ 0.1494, -1.3543, -1.0624,  1.0602,  0.6952,  2.2687,  1.7411],
          [ 0.0138, -3.2099, -4.7813, -0.3297, -2.9317,  0.1004,  0.4046]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs, in which every dimension meant to be dynamic differs from the first set, looks like:

print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.
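
The class onnx_diagnostic.export.ModelInputs compares the two input sets: a dimension which changes from one set to the next is marked as dynamic. A sketch of the call producing the output below:

mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)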

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

And finally the export. The export is simple with transformers>=4.50; otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers the serialization functions for DynamicCache, slightly modified to make the shape inference implemented in torch happy.

if has_transformers("4.50"):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
    with torch_export_patches(patch_transformers=True) as modificator:
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            sym_size_int_1: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 2)
            sym_size_int_2: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 3)
            sym_size_int_3: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 0)
            sym_size_int_4: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 2)
            sym_size_int_5: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 3)
            sym_size_int_6: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 0)
            sym_size_int_7: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 2)
            sym_size_int_8: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 3)
            sym_size_int_9: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 0)
            sym_size_int_10: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 2)
            sym_size_int_11: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 3)
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None

             #
            eq: "Sym(True)" = sym_size_int_12 == sym_size_int_2;  sym_size_int_2 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s38, s92) on node 'eq'");  eq = _assert_scalar_default = None
            eq_1: "Sym(True)" = sym_size_int_12 == sym_size_int_5;  sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s38, s57) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_1 == sym_size_int_4;  sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s10, s29) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
            eq_3: "Sym(True)" = sym_size_int == sym_size_int_3;  sym_size_int_3 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s33, s93) on node 'eq_3'");  eq_3 = _assert_scalar_default_3 = None
            eq_4: "Sym(True)" = sym_size_int_12 == sym_size_int_8;  sym_size_int_8 = None
            _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(eq_4, "Runtime assertion failed for expression Eq(s38, s53) on node 'eq_4'");  eq_4 = _assert_scalar_default_4 = None
            eq_5: "Sym(True)" = sym_size_int_1 == sym_size_int_7;  sym_size_int_7 = None
            _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(eq_5, "Runtime assertion failed for expression Eq(s10, s19) on node 'eq_5'");  eq_5 = _assert_scalar_default_5 = None
            eq_6: "Sym(True)" = sym_size_int == sym_size_int_6;  sym_size_int = None
            _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(eq_6, "Runtime assertion failed for expression Eq(s33, s13) on node 'eq_6'");  eq_6 = _assert_scalar_default_6 = None
            eq_7: "Sym(True)" = sym_size_int_12 == sym_size_int_11;  sym_size_int_12 = sym_size_int_11 = None
            _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(s38, s97) on node 'eq_7'");  eq_7 = _assert_scalar_default_7 = None
            eq_8: "Sym(True)" = sym_size_int_1 == sym_size_int_10;  sym_size_int_1 = sym_size_int_10 = None
            _assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(s10, s42) on node 'eq_8'");  eq_8 = _assert_scalar_default_8 = None
            eq_9: "Sym(True)" = sym_size_int_6 == sym_size_int_9;  sym_size_int_6 = sym_size_int_9 = None
            _assert_scalar_default_9 = torch.ops.aten._assert_scalar.default(eq_9, "Runtime assertion failed for expression Eq(s13, s25) on node 'eq_9'");  eq_9 = _assert_scalar_default_9 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}
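
As a quick sanity check, the exported program should also accept the second input set, whose batch, sequence and last dimensions all differ from the first (assuming transformers>=4.50 so that DynamicCache stays registered as a pytree node):

got = ep.module()(*inputs[1])
print(got.shape)  # expected: torch.Size([3, 4, 4, 8])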

Use string instead of DYNAMIC

The ONNX exporter accepts strings instead of DYNAMIC or AUTO in order to give a name to every dynamic dimension.
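
A sketch of the call producing the names below, assuming guess_dynamic_shapes() accepts a string prefix through its auto argument:

ds_str = mi.guess_dynamic_shapes(auto="dim")
pprint.pprint(ds_str)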

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out the DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().
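
As a minimal sketch, a hypothetical container can be registered the same way (KVPair and its flatten/unflatten functions below are illustrative, not part of transformers):

import torch
import torch.utils._pytree as pytree


class KVPair:
    def __init__(self, key, value):
        self.key = key
        self.value = value


pytree.register_pytree_node(
    KVPair,
    lambda p: ([p.key, p.value], None),  # flatten: (children, context)
    lambda children, _context: KVPair(*children),  # unflatten
    serialized_type_name="KVPair",
)

pair = KVPair(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7))
print(string_type(flatten_unflatten_for_dynamic_shapes(pair), with_shape=True))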

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
[figure: plot export with dynamic cache]

Total running time of the script: (0 minutes 0.467 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)