Builds dynamic shapes from any input

Getting dynamic shapes right for torch.export.export() is tricky when the inputs include a custom class such as transformers.cache_utils.DynamicCache. torch.export.export() cannot use a DynamicCache filled with dynamic shapes. Instead, the dynamic shapes must follow the flattened (serialized) structure of the cache. In the outputs below, the cache becomes a dictionary with a key_cache list and a value_cache list.

Standard inputs for an LLM with a dynamic cache

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Sizes named by the variables: batch size, number of heads, length already
# stored in the cache, and the last dimension of each cached tensor.
bsize, nheads, slen, dim = 2, 1, 30, 96
# Number of new tokens given through input_ids / position_ids.
seq = 3

inputs = dict(
    input_ids=torch.randint(15, size=(bsize, seq), dtype=torch.int64),
    # Width is slen + seq (30 + 3 = 33), presumably cached + new tokens.
    # torch.randint(1, ...) would only yield zeros (the upper bound is
    # exclusive), i.e. a mask hiding every token, so use ones instead.
    attention_mask=torch.ones((bsize, slen + seq), dtype=torch.int64),
    position_ids=torch.arange(seq, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

# Show a compact description of every input with its dtype and shape.
inputs_summary = string_type(inputs, with_shape=True)
print(inputs_summary)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))

The function onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs() produces the corresponding dynamic shapes, assuming every dimension is dynamic. Printing the result with pprint.pprint(all_dynamic_shape_from_inputs(inputs)) gives:

{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]},
 'position_ids': {0: 'd_2_0'}}

What about a StaticCache?

We use onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs() to get a model configuration and inputs consistent with a static cache.

# Request an untrained Tiny-LLM configured for a static cache, together with
# inputs built around a StaticCache; the result also carries the helper's own
# dynamic shapes (used further below).
model_kwargs = dict(cache_implementation="static")
inputs_kwargs = dict(cls_cache="StaticCache")
data = get_untrained_model_with_inputs(
    "arnir0/Tiny-LLM", model_kwargs=model_kwargs, inputs_kwargs=inputs_kwargs
)
inputs = data["inputs"]
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T9s2x1x3x96,cache_position:T7s3,past_key_values:StaticCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))

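Next, we compute the dynamic shapes for these inputs. If past_key_values receives no dynamic shape, transformers does not provide a serialization function for StaticCache. In that case we compute the shapes again inside torch_export_patches(patch_transformers=True), which enables the serialization functions implemented in this package.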

# First attempt: rely on whatever serialization transformers registers.
ds = all_dynamic_shape_from_inputs(inputs)
if not ds["past_key_values"]:
    # No shapes for the cache: fall back to this package's serialization
    # functions, enabled while the patches are active.
    print("We need to use serialization function implemented in this package.")
    with torch_export_patches(patch_transformers=True):
        ds = all_dynamic_shape_from_inputs(inputs)
else:
    print("transformers implemented serialization function for StaticCache.")
We need to use serialization function implemented in this package.

Printing ds with pprint.pprint(ds) gives:

{'attention_mask': {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2', 3: 'd_1_3'},
 'cache_position': {0: 'd_2_0'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]}}

We can compare them with the dynamic shapes returned by get_untrained_model_with_inputs().

# Dynamic shapes proposed by get_untrained_model_with_inputs, for comparison.
reference_ds = data["dynamic_shapes"]
pprint.pprint(reference_ds)
{'attention_mask': {0: Dim('batch', min=1, max=1024), 2: 'seq'},
 'cache_position': {0: 'seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]]}
# Draw the gallery thumbnail legend for this example.
legend_text = "dynamic shapes\nfrom inputs"
doc.plot_legend(legend_text, "dynamic shapes", "green")
[Figure: legend "dynamic shapes from inputs" (plot_dynamic_shapes_what)]

Total running time of the script: (0 minutes 3.615 seconds)

Related examples

JSON returns list when the original dynamic shapes are list or tuple

JSON returns list when the original dynamic shapes are list or tuple

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

0, 1, 2 for a Dynamic Dimension in the dummy example to export a model

Do not use python int with dynamic shapes

Do not use python int with dynamic shapes

Gallery generated by Sphinx-Gallery