onnx_diagnostic.export.shape_helper¶
- onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = 'd') → Any [source]¶
Returns the dynamic shapes for the given inputs. All dimensions are considered as dynamic.
dim_prefix can be a string (the function uses it as a prefix), torch.export.Dim.AUTO, or torch.export.Dim.DYNAMIC.
<<<
import pprint import torch from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs bsize, nheads, slen, dim = 2, 1, 30, 96 inputs = dict( input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), position_ids=torch.arange(3, dtype=torch.int64), past_key_values=make_dynamic_cache( [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))] ), ) ds = all_dynamic_shape_from_inputs(inputs) pprint.pprint(ds)
>>>
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'}, 'input_ids': {0: 'd_0_0', 1: 'd_0_1'}, 'past_key_values': {'key_cache': [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2', 3: 'd_3_3'}], 'value_cache': [{0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2', 3: 'd_4_3'}]}, 'position_ids': {0: 'd_2_0'}}
- onnx_diagnostic.export.shape_helper.guess_dynamic_shapes_from_inputs(inputs: List[Any], auto: bool | str = False) → Tuple[Tuple[Any, ...], Dict[str, Any]] [source]¶
Guesses which dimensions are dynamic from a set of inputs. Every dimension whose value differs across the sets of inputs is considered dynamic. Every dimension that does not change remains static.
- Parameters:
inputs – a list of input sets
auto – True for torch.export.Dim.AUTO, False for torch.export.Dim.DYNAMIC, or a string used as a prefix to get a unique name for every dynamic dimension
- Returns:
a tuple (args, kwargs): the dynamic shapes guessed for the positional arguments and for the keyword arguments
<<<
import pprint import torch from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs bsize, nheads, slen, dim = 2, 1, 30, 96 inputs1 = dict( input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64), attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64), position_ids=torch.arange(3, dtype=torch.int64), past_key_values=make_dynamic_cache( [ ( torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim), ), ] ), ) bsize, nheads, slen, dim = 3, 1, 33, 96 inputs2 = dict( input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64), attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64), position_ids=torch.arange(4, dtype=torch.int64), past_key_values=make_dynamic_cache( [ ( torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim), ), ] ), ) ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d") pprint.pprint(ds)
>>>
((), {'attention_mask': {0: 'd_0I0', 1: 'd_0I1'}, 'input_ids': {0: 'd_1I0', 1: 'd_1I1'}, 'past_key_values': [[{0: 'd_2I_0o_0l0', 2: 'd_2I_0o_0l2'}], [{0: 'd_2I_1o_0l0', 2: 'd_2I_1o_0l2'}]], 'position_ids': {0: 'd_3I0'}})
This function returns something equivalent to what the class torch.export.dynamic_shapes.AdditionalInputs produces, but AdditionalInputs needs a model whereas this function only needs the inputs.
<<<
import pprint import torch from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True) ds = torch.export.dynamic_shapes.AdditionalInputs() ds.add((), data["inputs"]) ds.add((), data["inputs2"]) pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
>>>
{'attention_mask': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)), 'input_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)), 'past_key_values': [[(_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None)], [(_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None)]], 'position_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True))}