Builds dynamic shapes from any input

Getting dynamic shapes right for torch.export.export() is tricky when the inputs include a custom class such as transformers.cache_utils.DynamicCache. torch.export.export() cannot use a DynamicCache filled with dynamic shapes; instead, the dynamic shapes must follow the serialized form of the cache, a nested structure of lists and dictionaries.

Standard inputs for an LLM with a dynamic cache

import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

# Sizes of the fake key/value cache: batch size, number of heads,
# number of cached positions and last dimension of each cache tensor.
bsize, nheads, slen, dim = 2, 1, 30, 96
# Number of new tokens given to the model on top of the cached ones.
nnew = 3

inputs = dict(
    # token ids in [0, 15) for the new tokens
    input_ids=torch.randint(15, size=(bsize, nnew), dtype=torch.int64),
    # The mask spans the cached positions plus the new ones (slen + nnew).
    # torch.randint(1, ...) samples from [0, 1) and therefore only produced
    # zeros; an explicit mask of ones is used instead (same shape and dtype).
    attention_mask=torch.ones((bsize, slen + nnew), dtype=torch.int64),
    position_ids=torch.arange(nnew, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

# Summarize the nested inputs, including the shape of every tensor.
inputs_summary = string_type(inputs, with_shape=True)
print(inputs_summary)
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))

Function onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs() produces the corresponding dynamic shapes, assuming every dimension of every tensor is dynamic.

# Declare every dimension of every tensor found in ``inputs`` as dynamic,
# then display the resulting nested structure of dimension names.
ds = all_dynamic_shape_from_inputs(inputs)
print(pprint.pformat(ds))
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]},
 'position_ids': {0: 'd_2_0'}}

What about a StaticCache?

We use onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs() to get a model configuration and inputs that are both consistent with a static cache.

# Ask for a static cache both in the model configuration and in the
# generated dummy inputs so that both agree.
static_model_kwargs = {"cache_implementation": "static"}
static_inputs_kwargs = {"cls_cache": "StaticCache"}
data = get_untrained_model_with_inputs(
    "arnir0/Tiny-LLM",
    model_kwargs=static_model_kwargs,
    inputs_kwargs=static_inputs_kwargs,
)
inputs = data["inputs"]
inputs_summary = string_type(inputs, with_shape=True)
print(inputs_summary)
dict(input_ids:T7s2x3,attention_mask:T9s2x1x3x96,cache_position:T7s3,past_key_values:StaticCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))

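Now the dynamic shapes. If transformers does not implement a serialization function for StaticCache, the part of the dynamic shapes describing the cache is empty, and the serialization functions implemented in this package must be enabled with torch_export_patches().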

ds = all_dynamic_shape_from_inputs(inputs)
if not ds["past_key_values"]:
    # No dynamic shapes were produced for the cache, presumably because
    # transformers does not serialize StaticCache: compute them again with
    # the patches of this package enabled.
    print("We need to use serialization function implemented in this package.")
    with torch_export_patches(patch_transformers=True):
        ds = all_dynamic_shape_from_inputs(inputs)
else:
    print("transformers implemented serialization function for StaticCache.")
We need to use serialization function implemented in this package.

Printing the result with pprint.pprint(ds) gives:

{'attention_mask': {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2', 3: 'd_1_3'},
 'cache_position': {0: 'd_2_0'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]}}

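We can compare them with the dynamic shapes returned by get_untrained_model_with_inputs(). Those only make some dimensions dynamic (batch, sequence length and cache length) and reuse the same dimension names across inputs.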

# Dynamic shapes shipped with the model by get_untrained_model_with_inputs.
model_dynamic_shapes = data["dynamic_shapes"]
print(pprint.pformat(model_dynamic_shapes))
{'attention_mask': {0: Dim('batch', min=1, max=1024), 2: 'seq'},
 'cache_position': {0: 'seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]]}
# Draw the legend image of this example (presumably its gallery thumbnail).
legend_text = "dynamic shapes\nfrom inputs"
doc.plot_legend(legend_text, "dynamic shapes", "green")
plot dynamic shapes what

Total running time of the script: (0 minutes 3.615 seconds)

Related examples

- 0, 1, 2 for a Dynamic Dimension in the dummy example to export a model
- Half certain nonzero
- Do not use python int with dynamic shapes

Gallery generated by Sphinx-Gallery