Export with DynamicCache and dynamic shapes

Every LLM implemented in transformers uses a cache. One of the most commonly used is transformers.cache_utils.DynamicCache. Its size is dynamic to cope with the growing context. This example shows a tool that determines the dynamic shapes for torch.export.export() from several sets of valid inputs.

Simple Examples

We first look at simple examples with positional and named parameters to understand how torch.export.export() works.

args

import pprint
import torch
from onnx_diagnostic.cache_helpers import make_dynamic_cache
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.export import ModelInputs


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to check it works

ep = torch.export.export(model, (x, y))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[5, 6]", y: "f32[1, 6]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:33 in forward, code: return x + y
            add: "f32[5, 6]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

As expected, there are no dynamic shapes. We use onnx_diagnostic.export.ModelInputs to guess them from two sets of valid inputs. These inputs must have different values for every dimension meant to be dynamic.

(({0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  {1: <_DimHint.DYNAMIC: 3>}),
 {})

The function returns a tuple of two objects: the first one holds the dynamic shapes of the positional arguments, the second one those of the named arguments. There are no named arguments here, so we use the first result to export.

ep = torch.export.export(model, (x, y), dynamic_shapes=ds[0])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:33 in forward, code: return x + y
            add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
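The guessing step used above compares the shapes of the input sets dimension by dimension: a dimension is dynamic when its size varies across the sets. The following is a simplified, hypothetical sketch of that idea in plain Python over shape tuples, not the actual ModelInputs.guess_dynamic_shapes implementation. The string "DYNAMIC" stands in for torch.export.Dim.DYNAMIC, and the helper names are illustrative.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]
DYNAMIC = "DYNAMIC"  # hypothetical placeholder standing in for torch.export.Dim.DYNAMIC


def guess_tensor_dims(shapes: Sequence[Shape]) -> Dict[int, str]:
    """Marks as dynamic every dimension whose size differs across the input sets."""
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError(f"rank mismatch between input sets: {list(shapes)}")
    rank = ranks.pop()
    return {d: DYNAMIC for d in range(rank) if len({s[d] for s in shapes}) > 1}


def guess_args_dims(inputs: List[Tuple[Shape, ...]]) -> Tuple[Dict[int, str], ...]:
    """inputs[i][j] is the shape of positional argument j in input set i."""
    counts = {len(args) for args in inputs}
    if len(counts) != 1:
        raise ValueError("every input set must have the same number of arguments")
    n_args = counts.pop()
    return tuple(guess_tensor_dims([args[j] for args in inputs]) for j in range(n_args))


# Shapes of (x, y) in two input sets: the first matches the example above,
# the second is an assumption (any different sizes would do).
inputs = [((5, 6), (1, 6)), ((7, 8), (1, 8))]
ds = guess_args_dims(inputs)
print(ds)  # ({0: 'DYNAMIC', 1: 'DYNAMIC'}, {1: 'DYNAMIC'})
```

With a single input set, or with two sets sharing a size, nothing differs and the dimension stays static, which is why the inputs must differ on every dimension meant to be dynamic.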

kwargs

We do the same with named arguments.

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x=x, y=y)  # to check it works
tensor([[-1.0201, -1.9076, -0.3050,  0.0533,  0.8182, -0.8327],
        [-2.2935, -0.9657,  0.9467,  0.6726, -0.0212,  0.6299],
        [-1.2506, -0.8341, -1.4627,  1.2358,  0.2910, -1.7509],
        [-1.1133, -0.9155, -0.7246, -1.1544,  0.5863, -1.2973],
        [-0.7934, -1.6219,  0.4357, -0.4643,  0.4957, -1.1040]])

We guess the dynamic shapes from two sets of valid inputs.

((),
 {'x': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'y': {1: <_DimHint.DYNAMIC: 3>}})

And we export.

ep = torch.export.export(model, (), kwargs=dict(x=x, y=y), dynamic_shapes=ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:74 in forward, code: return x + y
            add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}

args and kwargs

torch.export.export() does not handle dynamic shapes split between args and kwargs, so we need to define them all with a single mechanism.

class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y=y)  # to check it works
tensor([[-0.1624,  0.6861,  1.9658,  0.1582, -0.7413,  0.6881],
        [-1.6511, -0.7769,  0.7927,  1.3940, -1.0881,  2.4199],
        [-1.2740,  0.8089,  1.8449,  0.6138, -0.4120,  1.8403],
        [-1.0791,  0.4893,  1.6329, -0.5010,  0.8021,  1.3950],
        [-1.4956,  0.4924,  1.2462,  1.6752,  1.2960,  0.3391]])

Two sets of valid inputs with positional and named arguments.

inputs = [((x,), dict(y=y)), ((torch.randn((7, 8)),), dict(y=torch.randn((1, 8))))]
mi = ModelInputs(Model(), inputs)
ds = mi.guess_dynamic_shapes()
pprint.pprint(ds)
(({0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},),
 {'y': {1: <_DimHint.DYNAMIC: 3>}})

This does not work with torch.export.export(), so we use a method that moves the positional dynamic shapes to named ones. The method relies on the signature of the forward method to map each position to a parameter name.

((),
 {'x': {0: <_DimHint.DYNAMIC: 3>, 1: <_DimHint.DYNAMIC: 3>},
  'y': {1: <_DimHint.DYNAMIC: 3>}})

And we export.

ep = torch.export.export(model, new_args, kwargs=new_kwargs, dynamic_shapes=new_ds[1])
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, s1]", y: "f32[1, s1]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:104 in forward, code: return x + y
            add: "f32[s0, s1]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
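Moving positional dynamic shapes to named ones only requires the parameter names of forward, in order. A minimal sketch of that idea follows; move_to_named is a hypothetical helper, not the method onnx_diagnostic provides, and it assumes forward only has ordinary positional-or-keyword parameters (no *args). It uses inspect.signature to name each positional slot, and integers stand in for tensors.

```python
import inspect
from typing import Any, Callable, Dict, Tuple

DYNAMIC = "DYNAMIC"  # hypothetical placeholder for torch.export.Dim.DYNAMIC


def move_to_named(
    forward: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    dynamic_shapes: Tuple[Tuple[Any, ...], Dict[str, Any]],
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Tuple[Tuple[Any, ...], Dict[str, Any]]]:
    """Turns positional arguments and their dynamic shapes into named ones."""
    ds_args, ds_kwargs = dynamic_shapes
    # Parameter names in declaration order (a bound method already omits ``self``).
    names = list(inspect.signature(forward).parameters)
    if len(args) > len(names):
        raise ValueError("more positional arguments than parameters")
    positional_names = names[: len(args)]
    if set(positional_names) & set(kwargs):
        raise ValueError("an argument is given both by position and by name")
    new_kwargs = {**dict(zip(positional_names, args)), **kwargs}
    new_ds = {**dict(zip(positional_names, ds_args)), **ds_kwargs}
    return (), new_kwargs, ((), new_ds)


class Model:
    def forward(self, x, y):
        return x + y


# Integers stand in for tensors x and y; only the names and positions matter here.
ds = (({0: DYNAMIC, 1: DYNAMIC},), {"y": {1: DYNAMIC}})
new_args, new_kwargs, new_ds = move_to_named(Model().forward, (5,), {"y": 6}, ds)
print(new_ds)  # ((), {'x': {0: 'DYNAMIC', 1: 'DYNAMIC'}, 'y': {1: 'DYNAMIC'}})
```

The result has the same layout as the named dynamic shapes printed earlier in this section: an empty tuple for the positional part and one entry per parameter name.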

DynamicCache

torch.export.export() can serialize caches and any custom class, provided serialization functions are registered for them, which is the case for transformers.cache_utils.DynamicCache with transformers>=4.50. The dynamic shapes must then follow the serialized form.
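To make the serialized form concrete, the sketch below uses a hypothetical ToyCache in plain Python, with key_cache and value_cache lists, and a pair of flatten/unflatten functions of the kind torch.export relies on for custom classes. It is not the transformers implementation; it only assumes, consistently with the dynamic shapes printed later in this example, that a DynamicCache serializes into [key_cache, value_cache]. The dynamic shapes must then be nested the same way, with one dict per tensor.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple


@dataclass
class ToyCache:
    """Hypothetical stand-in for a DynamicCache: one key and one value per layer."""

    key_cache: List[Any] = field(default_factory=list)
    value_cache: List[Any] = field(default_factory=list)


def flatten_cache(cache: ToyCache) -> Tuple[List[List[Any]], List[str]]:
    """Returns the children (the serialized form) and a context to rebuild the object."""
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]


def unflatten_cache(children: List[List[Any]], context: List[str]) -> ToyCache:
    """Rebuilds the cache from the children produced by flatten_cache."""
    return ToyCache(**dict(zip(context, children)))


# Strings stand in for the tensors of a two-layer cache.
cache = ToyCache(key_cache=["k0", "k1"], value_cache=["v0", "v1"])
children, context = flatten_cache(cache)
print(children)  # [['k0', 'k1'], ['v0', 'v1']]

# Dynamic shapes mirror that nesting: one dict per tensor, key_cache first.
DYNAMIC = "DYNAMIC"  # hypothetical placeholder for torch.export.Dim.DYNAMIC
dims = {0: DYNAMIC, 2: DYNAMIC, 3: DYNAMIC}
cache_ds = [[dims, dims], [dims, dims]]
assert [len(c) for c in cache_ds] == [len(c) for c in children]
assert unflatten_cache(children, context) == cache
```

In practice such functions are registered with torch's pytree machinery; as stated above, transformers>=4.50 already provides them for DynamicCache, so nothing has to be registered by hand.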

class Model(torch.nn.Module):
    def forward(self, cache, z):
        return (
            z
            + cache.key_cache[0]
            + cache.key_cache[1]
            + cache.value_cache[0]
            + cache.value_cache[1]
        )


model = Model()

n_layers = 2
bsize, nheads, slen, dim = 2, 4, 3, 7
cache = make_dynamic_cache(
    [
        (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))
        for i in range(n_layers)
    ]
)
z = torch.randn((1, 1, 1, 7))
model(cache, z)  # to check it works.
tensor([[[[-3.2845,  4.2926, -1.7517, -4.8644, -1.5986,  0.4895,  1.0207],
          [-0.2886,  2.6616, -5.2097, -1.6180, -3.4835, -1.5230,  0.3657],
          [ 2.3324,  1.8868, -3.3126, -2.7687, -0.4793,  0.3189,  0.6251]],

         [[-1.6105,  1.3440, -3.6883, -3.2647, -0.5125, -2.4907, -0.2397],
          [ 0.3935, -0.6067, -0.3976, -2.5595,  0.3269, -1.7304,  2.2176],
          [ 0.4912,  0.4410,  1.2822,  1.2822,  0.1216,  2.8847,  2.4807]],

         [[-1.5902,  1.4074,  0.5187, -0.0664,  1.5466,  1.0184,  1.8762],
          [ 1.5781,  0.2618, -2.6422, -0.3781, -1.4884, -0.1200,  0.0077],
          [-0.1962,  0.2378, -1.1902, -3.4412, -1.3535,  0.5866,  4.8000]],

         [[ 0.6132,  1.6157, -3.8814, -2.3609,  1.5207, -1.3498,  2.9325],
          [-2.1361,  0.8503,  1.2289, -0.7826, -1.3881,  1.6127, -1.2768],
          [ 1.3806, -0.2831, -2.1345, -0.3700, -2.6012, -5.5707,  2.4474]]],


        [[[-1.0034,  0.9038, -0.0516, -1.8436,  2.0595, -0.1720,  1.4887],
          [-1.3535,  0.2050, -1.1001, -2.7771,  0.9746, -1.1793,  2.1590],
          [ 1.9655,  0.6035, -3.6063,  1.5803, -4.0936, -0.1030, -0.6102]],

         [[-0.0210,  0.0587, -0.4822, -0.3298,  0.1540, -1.6425,  2.8792],
          [ 1.8894,  0.4439, -2.1557, -2.9282, -1.4422, -3.1668,  0.4406],
          [-1.0992,  2.9409, -0.6066,  0.6485, -1.5444,  3.0429,  1.6119]],

         [[-1.5391, -0.6656, -2.8273,  1.2773,  0.7963,  1.1290,  3.4955],
          [-4.0890,  0.9967,  1.3557, -1.4801, -2.7495,  0.0787,  4.3855],
          [-5.0766,  3.4235,  0.6910, -5.3388, -3.9815,  0.8682,  2.6344]],

         [[ 1.4206,  1.0074, -1.8708, -1.2733, -4.6728,  1.6277,  1.3218],
          [-2.9718,  0.9872, -4.2313, -3.1923, -1.8926, -0.4952,  1.3969],
          [ 0.4730,  1.3637,  1.0247, -4.3649, -3.1086, -1.0412,  2.5493]]]])

The cache looks like this:

print(string_type(cache, with_shape=True))
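DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])

A second cache with a larger batch size, sequence length and last dimension gives a second set of valid inputs.
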
cache2 = make_dynamic_cache(
    [
        (
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
            torch.randn(bsize + 1, nheads, slen + 1, dim + 1),
        )
        for i in range(n_layers)
    ]
)
inputs = [
    (cache, z),
    (cache2, torch.randn((1, 1, 1, 8))),
]

And the first set of inputs looks like:

print(string_type(inputs[0], with_shape=True))
(DynamicCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7]),T1s1x1x1x7)

We can now compute the dynamic shapes.

(([[{0: <_DimHint.DYNAMIC: 3>,
     2: <_DimHint.DYNAMIC: 3>,
     3: <_DimHint.DYNAMIC: 3>},
    {0: <_DimHint.DYNAMIC: 3>,
     2: <_DimHint.DYNAMIC: 3>,
     3: <_DimHint.DYNAMIC: 3>}],
   [{0: <_DimHint.DYNAMIC: 3>,
     2: <_DimHint.DYNAMIC: 3>,
     3: <_DimHint.DYNAMIC: 3>},
    {0: <_DimHint.DYNAMIC: 3>,
     2: <_DimHint.DYNAMIC: 3>,
     3: <_DimHint.DYNAMIC: 3>}]],
  {3: <_DimHint.DYNAMIC: 3>}),
 {})
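The same dimension-by-dimension comparison extends to nested inputs once the cache is replaced by its serialized form. In this hypothetical plain-Python sketch, which is not the onnx_diagnostic implementation, each cache is described by shape tuples laid out as [key shapes, value shapes], and guess_nested recurses through lists and tuples until it reaches a shape. With the shapes used above, it yields the same layout as the first element of the result printed above.

```python
from typing import Any, Sequence

DYNAMIC = "DYNAMIC"  # hypothetical placeholder for torch.export.Dim.DYNAMIC


def is_shape(obj: Any) -> bool:
    """A shape is a tuple of ints; any other list or tuple is a container."""
    return isinstance(obj, tuple) and all(isinstance(d, int) for d in obj)


def guess_nested(values: Sequence[Any]) -> Any:
    """values[i] is the i-th input set; all sets must share the same nested structure."""
    first = values[0]
    if is_shape(first):
        if len({len(v) for v in values}) != 1:
            raise ValueError(f"rank mismatch: {list(values)}")
        return {d: DYNAMIC for d in range(len(first)) if len({v[d] for v in values}) > 1}
    if not all(isinstance(v, (list, tuple)) and len(v) == len(first) for v in values):
        raise ValueError("input sets do not share the same structure")
    # Recurse position by position and keep the container type (list or tuple).
    return type(first)(guess_nested([v[i] for v in values]) for i in range(len(first)))


def cache_shapes(bsize: int, nheads: int, slen: int, dim: int, n_layers: int = 2) -> list:
    """Serialized form of a cache as shapes: [key shapes, value shapes], one per layer."""
    shape = (bsize, nheads, slen, dim)
    return [[shape] * n_layers, [shape] * n_layers]


bsize, nheads, slen, dim = 2, 4, 3, 7
inputs = [
    (cache_shapes(bsize, nheads, slen, dim), (1, 1, 1, dim)),
    (cache_shapes(bsize + 1, nheads, slen + 1, dim + 1), (1, 1, 1, dim + 1)),
]
ds = guess_nested(inputs)
print(ds)
```

The number of heads (dimension 1) is 4 in both sets, so it stays static; that is why it appears as a constant 4 in the exported graph below.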

And finally the export.

ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, cache_key_cache_0: "f32[s0, 4, s1, s11]", cache_key_cache_1: "f32[s0, 4, s1, s11]", cache_value_cache_0: "f32[s0, 4, s1, s11]", cache_value_cache_1: "f32[s0, 4, s1, s11]", z: "f32[1, 1, 1, s11]"):
             # File: /home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_with_dynamic_cache.py:148 in forward, code: z
            add: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(z, cache_key_cache_0);  z = cache_key_cache_0 = None
            add_1: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add, cache_key_cache_1);  add = cache_key_cache_1 = None
            add_2: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add_1, cache_value_cache_0);  add_1 = cache_value_cache_0 = None
            add_3: "f32[s0, 4, s1, s11]" = torch.ops.aten.add.Tensor(add_2, cache_value_cache_1);  add_2 = cache_value_cache_1 = None
            return (add_3,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='cache_key_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='cache_key_cache_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='cache_value_cache_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='cache_value_cache_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo], s11: VR[2, int_oo]}

Related examples

Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints

Steel method forward to guess the dynamic shapes

Export a model with a control flow (If)