onnx_diagnostic.export.shape_helper
- onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = 'd') -> Any
Returns the dynamic shapes for the given inputs. Every dimension of every tensor is considered dynamic.
dim_prefix can be a string (the function uses it as a prefix for the generated dimension names), torch.export.Dim.AUTO, or torch.export.Dim.DYNAMIC.
<<<
import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

# batch size, number of heads, cached sequence length, head dimension
bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    # a cache holding one layer of (key, value) tensors
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)
# every axis of every tensor receives its own dynamic dimension name
ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)
>>>
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': {'key_cache': [{0: 'd_3_0',
                                    1: 'd_3_1',
                                    2: 'd_3_2',
                                    3: 'd_3_3'}],
                     'value_cache': [{0: 'd_4_0',
                                      1: 'd_4_1',
                                      2: 'd_4_2',
                                      3: 'd_4_3'}]},
 'position_ids': {0: 'd_2_0'}}
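The naming pattern in the output above can be reproduced with a minimal pure-Python sketch. This is a hypothetical re-implementation, not the library's code: `sketch_all_dynamic` is an invented name, tensors are replaced by plain shape tuples, and the cache is modelled as a plain dict with `key_cache` and `value_cache` lists (an assumption). It only mirrors what the documented output shows. Tensors are numbered in the order they appear in `inputs`, and axis `j` of tensor `i` is named `<prefix>_<i>_<j>`. `pprint` sorts dictionary keys, which is why `attention_mask` (tensor 1) is printed before `input_ids` (tensor 0). The sketch does not cover `torch.export.Dim.AUTO` or `torch.export.Dim.DYNAMIC`, nor caches with several layers, whose flattening order the example does not show.

```python
from typing import Any


def sketch_all_dynamic(inputs: Any, prefix: str = "d") -> Any:
    """Hypothetical sketch of the naming scheme seen in the documented output.

    Leaves are shape tuples standing in for tensors (an assumption). Leaves
    are numbered in traversal order and axis j of leaf i is named
    f"{prefix}_{i}_{j}". This is NOT the onnx_diagnostic implementation.
    """
    counter = [0]  # index of the next leaf (tensor) in traversal order

    def walk(obj: Any) -> Any:
        if isinstance(obj, tuple) and all(isinstance(d, int) for d in obj):
            # A leaf: one tensor shape becomes one {axis: name} mapping.
            i = counter[0]
            counter[0] += 1
            return {axis: f"{prefix}_{i}_{axis}" for axis in range(len(obj))}
        if isinstance(obj, dict):
            # Dicts keep their insertion order, so numbering follows it too.
            return {key: walk(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [walk(value) for value in obj]
        raise TypeError(f"unsupported input type: {type(obj).__name__}")

    return walk(inputs)


# Same shapes as the documented example; the cache is a hypothetical plain dict.
bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
    input_ids=(2, 3),
    attention_mask=(2, 33),
    position_ids=(3,),
    past_key_values={
        "key_cache": [(bsize, nheads, slen, dim)],
        "value_cache": [(bsize, nheads, slen, dim)],
    },
)
ds = sketch_all_dynamic(inputs)
```

With a single cache layer, `ds` holds the same names as the documented output; only the cache container is simplified.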