Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most common is transformers.cache_utils.DynamicCache. The cache size is dynamic so it can cope with the growing context. This example shows a tool which determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() can serialize caches and any custom class as long as serialization functions are registered, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[ 1.9451, -0.9775,  1.9891, -2.8906, -0.9322,  2.3389, -0.7793],
          [ 3.4316, -1.8031, -0.0809, -0.8129, -0.9304, -0.4021, -1.0814],
          [-2.0624,  0.4372,  1.8958, -0.0731, -1.2476,  3.4713,  0.1929]],

         [[-0.8630,  4.0720, -2.4035,  0.5047,  0.1148, -0.5308, -1.5520],
          [ 0.2530, -1.2233,  0.7213,  3.8158,  2.4889,  6.7033, -0.0673],
          [-0.9790,  0.2772,  0.9464,  0.6800, -2.3057,  3.3593, -3.1132]],

         [[ 1.9605,  1.4405, -0.5980,  1.6304, -0.8142,  0.8462, -2.8358],
          [-1.2914,  0.6726, -2.6104,  3.2841, -1.5837,  2.6457, -6.2198],
          [-0.9668,  1.5010,  3.7879,  3.1526, -3.6653,  2.2060, -0.7938]],

         [[-2.4279,  1.2829,  0.9340,  0.1455, -0.2796, -0.8558, -1.0883],
          [ 1.4644,  0.1280, -2.4767,  1.4294, -1.3800, -1.5716,  0.3952],
          [-0.8183,  0.7517, -2.0131,  2.1516,  0.3757, -1.5169, -2.8511]]],


        [[[-1.6263,  0.6243,  2.0107,  0.4918, -0.5069,  2.3118,  0.6086],
          [ 2.7925,  0.0598, -0.5227,  0.9858,  2.6931,  5.3477, -0.6900],
          [ 5.8429,  0.8649,  2.9418, -0.6360,  0.0350,  2.6317,  0.0601]],

         [[ 0.3755,  1.5707,  2.8121,  1.5394, -5.9745,  5.3113,  1.0914],
          [-0.5124,  3.3240, -0.1680, -1.0117, -1.0346,  2.9493, -1.3418],
          [ 0.6812,  0.0858,  2.9680, -1.3174, -2.9940,  0.7741, -1.5145]],

         [[ 1.4348,  1.6729,  3.1750,  2.3692, -4.0115, -0.9219,  2.0874],
          [ 1.5981, -0.0751,  0.0094,  5.3372, -1.4340,  2.3836,  0.6001],
          [ 1.7987, -2.2685,  1.9256,  1.6146, -0.0967,  1.9729, -3.0380]],

         [[-2.3862,  1.0332, -0.9314,  0.9250,  2.5707,  1.0058, -3.2224],
          [ 1.7988,  1.5077,  2.9759,  0.0201, -4.6480, -0.1775, -1.6810],
          [ 0.0271,  1.2673, -3.4595,  2.0131, -2.4796,  0.4162, -0.0272]]]])
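
make_dynamic_cache builds the cache from a list of (key, value) pairs, one pair per layer. Below is a minimal sketch of an equivalent construction, assuming transformers' DynamicCache.update API rather than the helper's actual implementation:

from transformers.cache_utils import DynamicCache


def naive_make_dynamic_cache(key_value_pairs):
    # feed every layer's key and value tensors into a fresh cache
    cache = DynamicCache()
    for layer_idx, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, layer_idx)
    return cache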

The cache looks like this:

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
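
Since recent transformers registers DynamicCache as a pytree node, torch can flatten it into plain tensors. A quick check, relying on that registration:

import torch.utils._pytree as pytree

# 4 leaves expected: 2 layers x (key, value)
print(len(pytree.tree_leaves(cache)))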

A second set of inputs with different dimensions lets the tool detect which dimensions are dynamic.

cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like this:

print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool can be used to guess the dynamic shapes the way torch.export.export() expects them.
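
Class onnx_diagnostic.export.ModelInputs guesses them from the two sets of inputs defined above. The following lines are a sketch consistent with the imports at the top of the script; the ds variable is reused for the export below.

mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)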

(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})

Note that dimension 1 (the number of heads) is not guessed as dynamic: it equals 4 in both sets of inputs, so the tool assumes it is static.

And finally the export. The export is straightforward with transformers>=4.50; otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers functions to serialize DynamicCache, modified so that the shape inference implemented in torch succeeds.

if has_transformers("4.50"):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
else:
    with torch_export_patches(patch_transformers=True) as modificator:
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s26, 4, s28, s1]", cache_key_cache_1: "f32[s26, 4, s28, s1]", cache_value_cache_0: "f32[s26, 4, s28, s1]", cache_value_cache_1: "f32[s26, 4, s28, s1]", z: "f32[1, 1, 1, s1]"):
             #
            sym_size_int: "Sym(s26)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            sym_size_int_1: "Sym(s28)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 2)
            sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 3)
            sym_size_int_3: "Sym(s26)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 0)
            sym_size_int_4: "Sym(s28)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 2)
            sym_size_int_5: "Sym(s1)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 3)
            sym_size_int_6: "Sym(s26)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 0)
            sym_size_int_7: "Sym(s28)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 2)
            sym_size_int_8: "Sym(s1)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 3)
            sym_size_int_9: "Sym(s26)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 0)
            sym_size_int_10: "Sym(s28)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 2)
            sym_size_int_11: "Sym(s1)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 3)
            sym_size_int_12: "Sym(s1)" = torch.ops.aten.sym_size.int(z, 3)

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None

             #
            eq: "Sym(True)" = sym_size_int_12 == sym_size_int_2;  sym_size_int_12 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s98, s25) on node 'eq'");  eq = _assert_scalar_default = None
            eq_1: "Sym(True)" = sym_size_int_2 == sym_size_int_5;  sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s25, s52) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_1 == sym_size_int_4;  sym_size_int_1 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s83, s28) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
            eq_3: "Sym(True)" = sym_size_int == sym_size_int_3;  sym_size_int_3 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s26, s46) on node 'eq_3'");  eq_3 = _assert_scalar_default_3 = None
            eq_4: "Sym(True)" = sym_size_int_2 == sym_size_int_8;  sym_size_int_8 = None
            _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(eq_4, "Runtime assertion failed for expression Eq(s25, s84) on node 'eq_4'");  eq_4 = _assert_scalar_default_4 = None
            eq_5: "Sym(True)" = sym_size_int_4 == sym_size_int_7;  sym_size_int_7 = None
            _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(eq_5, "Runtime assertion failed for expression Eq(s28, s59) on node 'eq_5'");  eq_5 = _assert_scalar_default_5 = None
            eq_6: "Sym(True)" = sym_size_int == sym_size_int_6;  sym_size_int_6 = None
            _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(eq_6, "Runtime assertion failed for expression Eq(s26, s54) on node 'eq_6'");  eq_6 = _assert_scalar_default_6 = None
            eq_7: "Sym(True)" = sym_size_int_2 == sym_size_int_11;  sym_size_int_2 = sym_size_int_11 = None
            _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(s25, s1) on node 'eq_7'");  eq_7 = _assert_scalar_default_7 = None
            eq_8: "Sym(True)" = sym_size_int_4 == sym_size_int_10;  sym_size_int_4 = sym_size_int_10 = None
            _assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(s28, s74) on node 'eq_8'");  eq_8 = _assert_scalar_default_8 = None
            eq_9: "Sym(True)" = sym_size_int == sym_size_int_9;  sym_size_int = sym_size_int_9 = None
            _assert_scalar_default_9 = torch.ops.aten._assert_scalar.default(eq_9, "Runtime assertion failed for expression Eq(s26, s29) on node 'eq_9'");  eq_9 = _assert_scalar_default_9 = None

             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add_1: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s26, 4, s28, s1]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s26: VR[2, int_oo], s28: VR[2, int_oo], s1: VR[2, int_oo]}
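
As a quick sanity check (a sketch, not part of the original script), the exported module should also accept the second set of inputs, whose dimensions differ along every dynamic axis:

got = ep.module()(cache2, torch.randn((1, 1, 1, 8)))
print(got.shape)  # expected: torch.Size([3, 4, 4, 8])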

Use strings instead of DYNAMIC

The ONNX exporter uses strings instead of DYNAMIC or AUTO in order to give a name to every dimension.
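
The output below can presumably be obtained with a call such as the following; the auto argument and its exact name are an assumption, so check ModelInputs.guess_dynamic_shapes in the onnx_diagnostic documentation:

# assumed API: an auto argument selecting a string naming scheme
ds = mi.guess_dynamic_shapes(auto="dim")
pprint.pprint(ds)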

(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out a DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]
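
Once the flattened structure is known, guessing becomes optional: the dynamic shapes can be written by hand. A sketch mirroring the [[keys], [values]] layout shown above:

DYN = torch.export.Dim.DYNAMIC
axes = {0: DYN, 2: DYN, 3: DYN}  # batch, sequence length, last dimension
ds_manual = (
    [[axes, axes], [axes, axes]],  # key_cache and value_cache, 2 layers each
    {3: DYN},  # z: only its last dimension varies
)

ds_manual plays the same role as ds[0] guessed earlier and can be passed directly as dynamic_shapes to torch.export.export().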

This code works for any custom class if it was registered with torch.utils._pytree.register_pytree_node().
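
For illustration, registering a hypothetical two-field container would look like this (KV and its flatten/unflatten functions are made up for the example):

import torch.utils._pytree as pytree


class KV:
    def __init__(self, keys, values):
        self.keys = keys
        self.values = values


pytree.register_pytree_node(
    KV,
    lambda kv: ([kv.keys, kv.values], None),  # flatten: children + context
    lambda children, _context: KV(*children),  # unflatten
    serialized_type_name="KV",
)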

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 0.252 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Gallery generated by Sphinx-Gallery