Builds dynamic shapes from any input¶
This example shows how to get dynamic shapes right for torch.export.export()
when the inputs include a custom class such as a transformers.cache_utils.DynamicCache.
torch.export.export() cannot take dynamic shapes attached to a DynamicCache directly;
instead, it expects them to follow the flattened (serialized) form of the cache.
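To see what that flattened form looks like, we can flatten a cache with torch's
pytree utilities. The following is only a minimal sketch, assuming
torch_export_patches (or a recent transformers version) registers DynamicCache
with torch's pytree:

import torch

from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches
from torch.utils._pytree import tree_flatten

# Under the patches, flattening the cache yields the key/value tensors,
# the leaves torch.export.export() actually works with: the dynamic
# shapes must follow this nested structure, not the cache object itself.
cache = make_dynamic_cache([(torch.randn(2, 1, 30, 96), torch.randn(2, 1, 30, 96))])
with torch_export_patches(patch_transformers=True):
    flat, _spec = tree_flatten(cache)
print([tuple(t.shape) for t in flat])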
Standard inputs for an LLM with a dynamic cache¶
import pprint
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    # one layer of cache: a (key, value) pair of shape (batch, heads, cache_length, head_dim)
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s3,past_key_values:DynamicCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
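In this notation, the number after T is the element type (here 7 for int64 and,
for the cache tensors, 1 for float32, presumably following the ONNX element type
codes) and the digits after s give the shape, so T7s2x3 is an int64 tensor of
shape 2x3.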
The function onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs()
produces the corresponding dynamic shapes, assuming every dimension is dynamic.
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]},
 'position_ids': {0: 'd_2_0'}}
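The generated names apparently follow the pattern d_<i>_<j>, where i is the
position of the (flattened) input and j the axis: input_ids is input 0,
attention_mask input 1, position_ids input 2, and the flattened cache
contributes inputs 3 and 4.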
What about a StaticCache?¶
We use onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs()
to get a model and a consistent set of inputs configured with a static cache.
data = get_untrained_model_with_inputs(
    "arnir0/Tiny-LLM",
    model_kwargs=dict(cache_implementation="static"),
    inputs_kwargs=dict(cls_cache="StaticCache"),
)
inputs = data["inputs"]
print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T9s2x1x3x96,cache_position:T7s3,past_key_values:StaticCache(key_cache=#1[T1s2x1x30x96], value_cache=#1[T1s2x1x30x96]))
And the dynamic shapes for these inputs.
ds = all_dynamic_shape_from_inputs(inputs)
if ds["past_key_values"]:
    print("transformers implements a serialization function for StaticCache.")
else:
    print("We need to use the serialization function implemented in this package.")
    with torch_export_patches(patch_transformers=True):
        ds = all_dynamic_shape_from_inputs(inputs)
We need to use the serialization function implemented in this package.
That gives:
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2', 3: 'd_1_3'},
 'cache_position': {0: 'd_2_0'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]}}
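To see why the patches are needed, we can check how torch's pytree treats the
cache outside of them. This is a sketch based on the assumption that StaticCache
is not registered with torch's pytree by default, which is why its flattened
form exposes no tensors:

from torch.utils._pytree import tree_flatten

# Without the patches, the cache is an opaque pytree leaf: flattening
# returns the cache object itself, so there are no tensor axes to which
# dynamic shapes could be attached.
flat, _spec = tree_flatten(inputs["past_key_values"])
print([type(x).__name__ for x in flat])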
We can compare them with the dynamic shapes returned by get_untrained_model_with_inputs().
pprint.pprint(data["dynamic_shapes"])
{'attention_mask': {0: Dim('batch', min=1, max=1024), 2: 'seq'},
 'cache_position': {0: 'seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]]}
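These dynamic shapes are meant to be passed to torch.export.export(). The
following is a minimal sketch, assuming data["model"] (the untrained model
returned together with the inputs) exports cleanly once the patches make the
cache serializable:

# Export with the patches enabled so the cache class can be flattened
# and the dynamic shapes matched against its tensors.
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        data["model"],
        args=(),
        kwargs=data["inputs"],
        dynamic_shapes=data["dynamic_shapes"],
    )
print(type(ep))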
doc.plot_legend("dynamic shapes\nfrom inputs", "dynamic shapes", "green")

Total running time of the script: (0 minutes 3.615 seconds)