onnx_diagnostic.export.shape_helper

onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = 'd') -> Any

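Returns the dynamic shapes for the given inputs. Every dimension of every input is treated as dynamic. dim_prefix can be a string, in which case the function uses it as the prefix of the generated dimension names (for example 'd_0_1' with the default prefix 'd', as in the output below), or it can be torch.export.Dim.AUTO or torch.export.Dim.DYNAMIC.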

<<<

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)

>>>

    {'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
     'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
     'past_key_values': {'key_cache': [{0: 'd_3_0',
                                        1: 'd_3_1',
                                        2: 'd_3_2',
                                        3: 'd_3_3'}],
                         'value_cache': [{0: 'd_4_0',
                                          1: 'd_4_1',
                                          2: 'd_4_2',
                                          3: 'd_4_3'}]},
     'position_ids': {0: 'd_2_0'}}