Export with DynamicCache and guessed dynamic shapes

Every LLM implemented in transformers uses a cache, and one of the most common is transformers.cache_utils.DynamicCache. The cache size is dynamic to cope with the growing context. This example shows a tool that determines the dynamic shapes for torch.export.export() based on a set of valid inputs.

DynamicCache

torch.export.export() serializes caches and any custom class as long as serialization functions are provided, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must be provided following the serialized form.

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
    CacheKeyValue,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache, z):
        cache = CacheKeyValue(cache)
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-0.5314, -1.4908, -0.0666, -3.2749, -0.2032,  1.7854, -0.8906],
          [ 2.0694, -0.8017, -0.3830,  4.1040, -1.4288,  2.1290, -2.9534],
          [-1.5209, -0.8063,  3.0383, -3.5673,  0.6447, -3.1693,  3.7234]],

         [[ 1.2455, -3.7786,  0.4174, -1.8757, -3.2814, -1.5201,  0.6478],
          [ 3.4078, -0.6659,  0.9920,  0.0815, -1.9871,  2.2403, -2.9788],
          [-0.6886, -1.2344, -0.0147,  0.7816, -1.3841,  0.4532, -0.3960]],

         [[-0.8871, -0.8897,  0.8777, -3.7946,  3.0315, -0.6781,  0.6775],
          [ 0.8041, -3.3262,  1.1141,  2.2795, -2.0306, -1.5042, -4.7703],
          [ 1.9435, -3.9818,  2.5127,  1.3357,  0.5610,  1.1620, -0.9880]],

         [[ 0.6236,  1.1592, -3.5121,  3.9601, -2.4747,  2.2185, -0.5802],
          [-4.5215, -3.6361, -3.0839,  0.1815,  0.9497,  1.7112, -1.1339],
          [-1.6913, -2.5742,  0.2322,  1.3448, -0.8155, -3.5691, -1.2406]]],


        [[[ 1.5571, -0.3599,  1.9405,  1.8743, -1.6881, -0.3685, -0.3881],
          [-1.5306, -1.0521,  1.9741, -2.7603, -3.3677, -0.9413, -4.5854],
          [ 0.0598,  0.0661,  2.1115, -1.0206, -1.5212, -0.2721,  0.5203]],

         [[ 4.8487,  1.3795, -1.1567,  0.1153, -1.5890,  2.1465, -0.1232],
          [-0.3019,  0.7917,  5.4301, -0.9354, -0.1785, -0.0601, -1.1738],
          [ 2.3280,  2.9614,  0.7471,  2.9736, -1.3895, -0.7652,  1.5043]],

         [[-1.2072, -1.1084,  3.8170,  2.6557, -4.7038,  1.2627, -1.2545],
          [ 1.1424, -3.9008, -3.1237, -1.3487, -1.9008,  3.7849,  2.6191],
          [ 1.5717,  2.3284, -0.8651,  1.3904, -1.7688, -0.5297, -1.1258]],

         [[ 1.4793, -1.9195,  1.3703, -2.1846,  1.0788, -0.6701, -0.0589],
          [-0.3917, -0.4453,  1.9841,  3.1166,  0.3045,  0.0644, -2.6443],
          [ 3.0999,  1.3315,  2.7655,  2.5182, -5.7268,  2.0175, -1.5017]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the second set of inputs looks like:

print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)

Guess the dynamic shapes

The following tool guesses the dynamic shapes the way torch.export.export() expects them.

mi = ModelInputs(model, inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)

(([{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)},
   {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC), 3: DimHint(DYNAMIC)}],
  {3: DimHint(DYNAMIC)}),
 {})
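For comparison, the same specification could be written by hand. Below is a minimal sketch of the structure, using illustrative string axis names (which recent versions of torch.export.export() also accept) instead of DimHint(DYNAMIC); the names are a choice made here, not something the tool produces.

```python
# Hand-written dynamic shapes mirroring the guessed structure above:
# one dict per flattened cache tensor (key_0, value_0, key_1, value_1),
# one dict for z, grouped as (args, kwargs).
cache_axes = {0: "batch", 2: "seq_length", 3: "head_dim"}
ds_manual = (([dict(cache_axes) for _ in range(4)], {3: "zdim"}), {})
print(ds_manual)
```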

And finally the export. The export is straightforward with transformers>=4.50; otherwise, transformers needs to be patched. onnx_diagnostic.torch_export_patches.torch_export_patches() registers the functions needed to serialize DynamicCache, slightly modified to keep the shape inference implemented in torch happy.

with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_0: "f32[s50, 4, s15, s19]", cache_value_0: "f32[s50, 4, s15, s19]", cache_key_1: "f32[s50, 4, s15, s19]", cache_value_1: "f32[s50, 4, s15, s19]", z: "f32[1, 1, 1, s19]"):
             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:40 in forward, code: z
            add: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(z, cache_key_0);  z = cache_key_0 = None
            add_1: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add, cache_key_1);  add = cache_key_1 = None
            add_2: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add_1, cache_value_0);  add_1 = cache_value_0 = None
            add_3: "f32[s50, 4, s15, s19]" = torch.ops.aten.add.Tensor(add_2, cache_value_1);  add_2 = cache_value_1 = None
            return (add_3,)

Graph signature:
    # inputs
    cache_key_0: USER_INPUT
    cache_value_0: USER_INPUT
    cache_key_1: USER_INPUT
    cache_value_1: USER_INPUT
    z: USER_INPUT

    # outputs
    add_3: USER_OUTPUT

Range constraints: {s50: VR[2, int_oo], s15: VR[2, int_oo], s19: VR[2, int_oo]}

Use strings instead of DYNAMIC

The ONNX exporter uses strings instead of DYNAMIC or AUTO in order to give a name to every dimension.

(([{0: 'dim_0I_0o0', 2: 'dim_0I_0o2', 3: 'dim_0I_0o3'},
   {0: 'dim_0I_1o0', 2: 'dim_0I_1o2', 3: 'dim_0I_1o3'},
   {0: 'dim_0I_2o0', 2: 'dim_0I_2o2', 3: 'dim_0I_2o3'},
   {0: 'dim_0I_3o0', 2: 'dim_0I_3o2', 3: 'dim_0I_3o3'}],
  {3: 'dim_1I3'}),
 {})

Do we need to guess?

Function onnx_diagnostic.helpers.string_type() uses the serialization functions to print out a DynamicCache the way torch.export.export() expects it.

print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

You can also use function onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes() to show a DynamicCache restructured the way torch.export.export() expects it, without the custom class.

print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#4[T1s2x4x3x7,T1s2x4x3x7,T1s2x4x3x7,T1s2x4x3x7]
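The same flattening machinery is available directly in torch. A small sketch with a plain nested structure (the dictionary and its field names are illustrative, standing in for a serialized cache) shows the leaf ordering that the dynamic shapes must follow:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A plain nested structure standing in for a serialized cache.
nested = {
    "key_cache": [torch.ones(2, 3), torch.ones(2, 3)],
    "value_cache": [torch.zeros(2, 3), torch.zeros(2, 3)],
}
leaves, spec = tree_flatten(nested)
print(len(leaves))  # 4 tensor leaves; dynamic shapes follow this order
restored = tree_unflatten(leaves, spec)
```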

This works for any custom class, provided it was registered with torch.utils._pytree.register_pytree_node().

doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")
(figure: plot export with dynamic cache)

Total running time of the script: (0 minutes 0.372 seconds)

Related examples

Dynamic Shapes for *args, **kwargs

Export Tiny-LLM with patches

Find where a model is failing by running submodels

Gallery generated by Sphinx-Gallery