Note
Go to the end to download the full example code.
Export with DynamicCache and guessed dynamic shapes¶
Every LLM implemented in transformers uses a cache.
One of the most used is transformers.cache_utils.DynamicCache.
The cache size is dynamic to cope with the growing context.
The example shows a tool which determines the dynamic shapes
for torch.export.export() based on a set of valid inputs.
DynamicCache¶
torch.export.export() serializes caches and any custom class
if serialization functions are provided, which is the case for
transformers.cache_utils.DynamicCache and transformers>=4.50.
The dynamic shapes must be provided following the serialized form.
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.ext_test_case import has_transformers
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import (
    flatten_unflatten_for_dynamic_shapes,
    make_dynamic_cache,
)
from onnx_diagnostic.export import ModelInputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
class Model(torch.nn.Module):
    """Toy module summing ``z`` with the first two key and value cache entries."""

    def forward(self, cache, z):
        """Return ``z + key_cache[0] + key_cache[1] + value_cache[0] + value_cache[1]``.

        :param cache: object exposing ``key_cache`` and ``value_cache``,
            each indexable at positions 0 and 1
        :param z: tensor added to the cache tensors
        :return: the accumulated sum
        """
        keys, values = cache.key_cache, cache.value_cache
        # Accumulate left to right, in the same order as a chained ``+``.
        total = z + keys[0]
        total = total + keys[1]
        total = total + values[0]
        return total + values[1]
# Instantiate the toy model and build a cache with one random
# (key, value) pair of tensors per layer.
model = Model()
n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for _ in range(n_layers)
    ]
)
# ``z`` shares only its last dimension with the cache tensors.
z = torch.randn((1, 1, 1, dim))
model(cache, z)  # to check it works.
tensor([[[[-1.1534, -0.4810, -2.7559,  3.5524, -2.6642, -6.1048,  1.8736],
          [-3.9018,  0.7266, -0.9381,  2.3329, -0.8829, -3.8779,  5.9116],
          [-1.8228,  0.0299, -0.7934,  4.2816, -6.3535, -4.2360,  0.2349]],
         [[-4.2640, -0.9247, -0.1411, -0.1737, -2.9008,  0.1997, -2.6438],
          [ 0.0840,  0.3889, -1.0523, -2.6057,  1.6853, -2.1530,  0.4002],
          [-1.9219,  3.2423, -1.8033, -0.0108,  2.5272, -3.9575, -0.7794]],
         [[ 0.4002,  0.3289, -2.8407, -0.4206, -1.8948, -3.8877,  0.5374],
          [-2.2765, -0.7927, -4.1071, -0.1069,  0.7096, -4.7109, -1.4282],
          [-0.4841, -1.5750, -5.2166,  2.3293,  0.1353, -3.6288,  2.4231]],
         [[ 1.7191,  4.5104, -3.7494,  1.9053,  0.3526, -2.1625, -2.1564],
          [-0.8744,  0.0809, -1.1439,  0.9984, -1.3885, -1.2748, -0.3671],
          [-0.1743, -2.0525, -1.0917,  0.8684,  0.4196,  0.3833,  1.2645]]],
        [[[-0.3088,  0.2936,  0.0545,  2.4301, -2.2080, -1.8228,  1.0070],
          [ 1.7969,  1.2497, -3.1076,  1.4958, -0.7197, -4.4798,  1.3573],
          [-0.7751,  1.8067, -0.6388, -4.9977,  1.1114, -0.6417,  1.9996]],
         [[ 0.3497, -0.9197, -5.6521,  1.2137,  1.2347, -2.8634,  1.7814],
          [ 0.4986, -2.1289, -0.8862,  2.4905, -3.3386,  1.3207,  1.8182],
          [-0.2019, -6.7506, -0.7204,  2.6980, -1.3016, -3.3119,  1.6042]],
         [[-0.1474, -0.1328, -1.6456,  1.4732,  0.2217, -4.2840,  0.6970],
          [ 2.9944,  3.0970, -0.5592,  2.8136,  0.2917,  0.7529,  0.5986],
          [-0.7545, -0.6447, -2.5607,  2.9409, -0.1490, -0.8170, -3.2199]],
         [[ 1.5678, -1.4138, -0.4724, -0.3759,  3.2607, -1.3677,  0.0093],
          [ 0.5380,  0.7210,  0.5958, -0.3307, -3.4815, -0.7025,  0.2686],
          [-1.3187, -0.6886, -2.0378,  3.1369,  1.6712,  0.9054,  0.0266]]]])
The cache looks like this:
# Print a compact description of the cache, including tensor shapes.
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
# A second cache whose first, third and last dimensions are each one larger
# than in ``cache``; the number of heads is unchanged.
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for _ in range(n_layers)
    ]
)
# Two sets of inputs, each a ``(cache, z)`` pair matching ``Model.forward``.
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, dim + 1))),
]
And the second set of inputs looks like:
# Print the second set of inputs (cache and ``z``) with tensor shapes.
print(string_type(inputs[1], with_shape=True))
(DynamicCache(key_cache=#2[T1s3x4x4x8,T1s3x4x4x8], value_cache=#2[T1s3x4x4x8,T1s3x4x4x8]),T1s1x1x1x8)
Guess the dynamic shapes¶
The following tool can be used to guess the dynamic shapes
the way torch.export.export() expects them.
# Wrap the model with the two sets of inputs defined above.
mi = ModelInputs(Model(), inputs)
# ``ds`` is a pair: dynamic shapes for positional arguments, then for keyword
# arguments (empty here since the model is only called positionally).
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(([[{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}],
   [{0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)},
    {0: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     2: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True),
     3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
                 min=None,
                 max=None,
                 _factory=True)}]],
  {3: _DimHint(type=<_DimHintType.DYNAMIC: 3>,
               min=None,
               max=None,
               _factory=True)}),
 {})
And finally the export.
The export is simple if transformers>=4.50, otherwise,
transformers needs to be patched.
onnx_diagnostic.torch_export_patches.torch_export_patches()
registers functions to serialize DynamicCache. This one is modified to make
the shape inference implemented in torch happy.
# Before transformers 4.50, DynamicCache must be patched before exporting;
# from 4.50 on, the export can be called directly.
if not has_transformers("4.50"):
    with torch_export_patches(patch_transformers=True) as modificator:
        # ``modificator`` is applied to the inputs before they are exported.
        ep = torch.export.export(
            model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
        )
else:
    ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s13, 4, s10, s38]", cache_key_cache_1: "f32[s13, 4, s10, s38]", cache_value_cache_0: "f32[s13, 4, s10, s38]", cache_value_cache_1: "f32[s13, 4, s10, s38]", z: "f32[1, 1, 1, s38]"):
             #
            sym_size_int: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 0)
            sym_size_int_1: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 2)
            sym_size_int_2: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_0, 3)
            sym_size_int_3: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 0)
            sym_size_int_4: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 2)
            sym_size_int_5: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_key_cache_1, 3)
            sym_size_int_6: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 0)
            sym_size_int_7: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 2)
            sym_size_int_8: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_0, 3)
            sym_size_int_9: "Sym(s13)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 0)
            sym_size_int_10: "Sym(s10)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 2)
            sym_size_int_11: "Sym(s38)" = torch.ops.aten.sym_size.int(cache_value_cache_1, 3)
            sym_size_int_12: "Sym(s38)" = torch.ops.aten.sym_size.int(z, 3)
             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None
             #
            eq: "Sym(True)" = sym_size_int_12 == sym_size_int_2;  sym_size_int_2 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s38, s92) on node 'eq'");  eq = _assert_scalar_default = None
            eq_1: "Sym(True)" = sym_size_int_12 == sym_size_int_5;  sym_size_int_5 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(eq_1, "Runtime assertion failed for expression Eq(s38, s57) on node 'eq_1'");  eq_1 = _assert_scalar_default_1 = None
            eq_2: "Sym(True)" = sym_size_int_1 == sym_size_int_4;  sym_size_int_4 = None
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(s10, s29) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None
            eq_3: "Sym(True)" = sym_size_int == sym_size_int_3;  sym_size_int_3 = None
            _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(s33, s93) on node 'eq_3'");  eq_3 = _assert_scalar_default_3 = None
            eq_4: "Sym(True)" = sym_size_int_12 == sym_size_int_8;  sym_size_int_8 = None
            _assert_scalar_default_4 = torch.ops.aten._assert_scalar.default(eq_4, "Runtime assertion failed for expression Eq(s38, s53) on node 'eq_4'");  eq_4 = _assert_scalar_default_4 = None
            eq_5: "Sym(True)" = sym_size_int_1 == sym_size_int_7;  sym_size_int_7 = None
            _assert_scalar_default_5 = torch.ops.aten._assert_scalar.default(eq_5, "Runtime assertion failed for expression Eq(s10, s19) on node 'eq_5'");  eq_5 = _assert_scalar_default_5 = None
            eq_6: "Sym(True)" = sym_size_int == sym_size_int_6;  sym_size_int = None
            _assert_scalar_default_6 = torch.ops.aten._assert_scalar.default(eq_6, "Runtime assertion failed for expression Eq(s33, s13) on node 'eq_6'");  eq_6 = _assert_scalar_default_6 = None
            eq_7: "Sym(True)" = sym_size_int_12 == sym_size_int_11;  sym_size_int_12 = sym_size_int_11 = None
            _assert_scalar_default_7 = torch.ops.aten._assert_scalar.default(eq_7, "Runtime assertion failed for expression Eq(s38, s97) on node 'eq_7'");  eq_7 = _assert_scalar_default_7 = None
            eq_8: "Sym(True)" = sym_size_int_1 == sym_size_int_10;  sym_size_int_1 = sym_size_int_10 = None
            _assert_scalar_default_8 = torch.ops.aten._assert_scalar.default(eq_8, "Runtime assertion failed for expression Eq(s10, s42) on node 'eq_8'");  eq_8 = _assert_scalar_default_8 = None
            eq_9: "Sym(True)" = sym_size_int_6 == sym_size_int_9;  sym_size_int_6 = sym_size_int_9 = None
            _assert_scalar_default_9 = torch.ops.aten._assert_scalar.default(eq_9, "Runtime assertion failed for expression Eq(s13, s25) on node 'eq_9'");  eq_9 = _assert_scalar_default_9 = None
             # File: ~/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:39 in forward, code: z
            add_1: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s13, 4, s10, s38]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)
Graph signature:
    # inputs
    cache_key_cache_0: USER_INPUT
    cache_key_cache_1: USER_INPUT
    cache_value_cache_0: USER_INPUT
    cache_value_cache_1: USER_INPUT
    z: USER_INPUT
    # outputs
    add_3: USER_OUTPUT
Range constraints: {s13: VR[2, int_oo], s10: VR[2, int_oo], s38: VR[2, int_oo]}
Use string instead of DYNAMIC¶
The ONNX exporter accepts strings instead of DYNAMIC or AUTO in order to give a name to every dimension.
# Same guess as before, but each dynamic dimension gets a string name
# built from the ``"dim"`` prefix instead of a DYNAMIC hint.
dss = mi.guess_dynamic_shapes(auto="dim")
pprint.pprint(dss)
(([[{0: 'dim_0I_0o_0l0', 2: 'dim_0I_0o_0l2', 3: 'dim_0I_0o_0l3'},
    {0: 'dim_0I_0o_1l0', 2: 'dim_0I_0o_1l2', 3: 'dim_0I_0o_1l3'}],
   [{0: 'dim_0I_1o_0l0', 2: 'dim_0I_1o_0l2', 3: 'dim_0I_1o_0l3'},
    {0: 'dim_0I_1o_1l0', 2: 'dim_0I_1o_1l2', 3: 'dim_0I_1o_1l3'}]],
  {3: 'dim_1I3'}),
 {})
Do we need to guess?¶
Function onnx_diagnostic.helpers.string_type() uses
the serialization functions to print out the DynamicCache the way
torch.export.export() expects it.
# Print the DynamicCache with tensor shapes.
print(string_type(cache, with_shape=True))
DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
You can also use function
onnx_diagnostic.helpers.cache_helper.flatten_unflatten_for_dynamic_shapes()
to show a DynamicCache restructured the way torch.export.export() expects
it to be without the custom class.
# Restructure the cache into plain nested lists of tensors (no DynamicCache
# wrapper) and print the result with shapes.
print(string_type(flatten_unflatten_for_dynamic_shapes(cache), with_shape=True))
#2[#2[T1s2x4x3x7,T1s2x4x3x7],#2[T1s2x4x3x7,T1s2x4x3x7]]
This code works for any custom class if it was registered
with torch.utils._pytree.register_pytree_node().
# NOTE(review): presumably draws the legend image used as this example's
# gallery thumbnail — confirm against onnx_diagnostic.doc.plot_legend.
doc.plot_legend("dynamic shapes\nfor DynamicCache", "torch.export.export", "tomato")

Total running time of the script: (0 minutes 0.362 seconds)
Related examples
 
Steal method forward to guess inputs and dynamic shapes (with Tiny-LLM)
 
