onnx_diagnostic.export.shape_helper
- onnx_diagnostic.export.shape_helper.all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = 'd') → Any [source]
Returns the dynamic shapes for the given inputs. All dimensions are considered dynamic.
dim_prefix can be a string (the function uses it as a prefix), torch.export.Dim.AUTO, or torch.export.Dim.DYNAMIC.
<<<
import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96

inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)
>>>
{'attention_mask': {0: 'd_1_0', 1: 'd_1_1'},
 'input_ids': {0: 'd_0_0', 1: 'd_0_1'},
 'past_key_values': [[{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2', 3: 'd_3_3'}],
                     [{0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2', 3: 'd_4_3'}]],
 'position_ids': {0: 'd_2_0'}}
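The dim_prefix argument also accepts the export hints mentioned above. A minimal sketch, not an executed example, replacing the string prefix with torch.export.Dim.DYNAMIC:
<<<
import pprint
import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

inputs = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
)

# With Dim.DYNAMIC, every dimension becomes an export hint
# instead of a named string such as 'd_0_0'.
ds = all_dynamic_shape_from_inputs(inputs, dim_prefix=torch.export.Dim.DYNAMIC)
pprint.pprint(ds)
>>>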
For this function to work on caches, the patches must be enabled whenever transformers does not implement the serialization functions for them.
<<<
import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
    make_mamba_cache,
    make_sliding_window_cache,
    make_static_cache,
)
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

caches = [
    make_dynamic_cache(
        [
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
        ]
    ),
    make_encoder_decoder_cache(
        make_dynamic_cache(
            [
                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
            ]
        ),
        make_dynamic_cache(
            [
                (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
            ]
        ),
    ),
    make_sliding_window_cache(
        [
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        ]
    ),
    make_static_cache(
        [
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        ],
        max_cache_len=15,
    ),
    make_mamba_cache(
        [
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
            (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
        ]
    ),
]

with torch_export_patches(patch_transformers=True):
    for cache in caches:
        print(f"-- {cache.__class__.__name__}")
        pprint.pprint(all_dynamic_shape_from_inputs(cache))
>>>
-- DynamicCache
[[{0: 'd_0_0', 1: 'd_0_1', 2: 'd_0_2'},
  {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2'},
  {0: 'd_2_0', 1: 'd_2_1', 2: 'd_2_2'}],
 [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2'},
  {0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2'},
  {0: 'd_5_0', 1: 'd_5_1', 2: 'd_5_2'}]]
-- EncoderDecoderCache
[[[{0: 'd_0_0', 1: 'd_0_1', 2: 'd_0_2'},
   {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2'},
   {0: 'd_2_0', 1: 'd_2_1', 2: 'd_2_2'}],
  [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2'},
   {0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2'},
   {0: 'd_5_0', 1: 'd_5_1', 2: 'd_5_2'}]],
 [[{0: 'd_6_0', 1: 'd_6_1', 2: 'd_6_2'},
   {0: 'd_7_0', 1: 'd_7_1', 2: 'd_7_2'},
   {0: 'd_8_0', 1: 'd_8_1', 2: 'd_8_2'}],
  [{0: 'd_9_0', 1: 'd_9_1', 2: 'd_9_2'},
   {0: 'd_10_0', 1: 'd_10_1', 2: 'd_10_2'},
   {0: 'd_11_0', 1: 'd_11_1', 2: 'd_11_2'}]]]
-- SlidingWindowCache
[[{0: 'd_0_0', 1: 'd_0_1', 2: 'd_0_2', 3: 'd_0_3'},
  {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2', 3: 'd_1_3'},
  {0: 'd_2_0', 1: 'd_2_1', 2: 'd_2_2', 3: 'd_2_3'}],
 [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2', 3: 'd_3_3'},
  {0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2', 3: 'd_4_3'},
  {0: 'd_5_0', 1: 'd_5_1', 2: 'd_5_2', 3: 'd_5_3'}]]
-- StaticCache
[[{0: 'd_0_0', 1: 'd_0_1', 2: 'd_0_2', 3: 'd_0_3'},
  {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2', 3: 'd_1_3'},
  {0: 'd_2_0', 1: 'd_2_1', 2: 'd_2_2', 3: 'd_2_3'}],
 [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2', 3: 'd_3_3'},
  {0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2', 3: 'd_4_3'},
  {0: 'd_5_0', 1: 'd_5_1', 2: 'd_5_2', 3: 'd_5_3'}]]
-- MambaCache
[[{0: 'd_0_0', 1: 'd_0_1', 2: 'd_0_2'},
  {0: 'd_1_0', 1: 'd_1_1', 2: 'd_1_2'},
  {0: 'd_2_0', 1: 'd_2_1', 2: 'd_2_2'}],
 [{0: 'd_3_0', 1: 'd_3_1', 2: 'd_3_2'},
  {0: 'd_4_0', 1: 'd_4_1', 2: 'd_4_2'},
  {0: 'd_5_0', 1: 'd_5_1', 2: 'd_5_2'}]]
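The returned structure follows the layout torch.export.export expects for its dynamic_shapes argument. A minimal sketch, assuming the installed torch accepts string names as dynamic dimensions (TinyModel is a hypothetical toy model, not part of the library):
<<<
import torch
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs


class TinyModel(torch.nn.Module):  # hypothetical toy model
    def forward(self, x):
        return x * 2


inputs = dict(x=torch.randn(2, 4))
# expected to look like {'x': {0: 'd_0_0', 1: 'd_0_1'}}
ds = all_dynamic_shape_from_inputs(inputs)
ep = torch.export.export(TinyModel(), (), kwargs=inputs, dynamic_shapes=ds)
print(ep)
>>>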
- onnx_diagnostic.export.shape_helper.guess_dynamic_shapes_from_inputs(inputs: List[Any], auto: bool | str = False) → Tuple[Tuple[Any, ...], Dict[str, Any]] [source]
Guesses which dimensions are dynamic from a set of inputs. Every dimension taking different values across the input sets is considered dynamic; every dimension that does not change remains static.
- Parameters:
inputs – a list of input sets
auto – True for torch.export.Dim.AUTO, False for torch.export.Dim.DYNAMIC, or a string to get a unique name for every dynamic dimension
- Returns:
args and kwargs
<<<
import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

bsize, nheads, slen, dim = 2, 1, 30, 96
inputs1 = dict(
    input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
        ]
    ),
)

bsize, nheads, slen, dim = 3, 1, 33, 96
inputs2 = dict(
    input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
    attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
    position_ids=torch.arange(4, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
        ]
    ),
)

ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
pprint.pprint(ds)
>>>
((),
 {'attention_mask': {0: 'd_0I0', 1: 'd_0I1'},
  'input_ids': {0: 'd_1I0', 1: 'd_1I1'},
  'past_key_values': [[{0: 'd_2I_0o_0l0', 2: 'd_2I_0o_0l2'}],
                      [{0: 'd_2I_1o_0l0', 2: 'd_2I_1o_0l2'}]],
  'position_ids': {0: 'd_3I0'}})
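Since unchanged dimensions remain static, only the varying dimensions appear in the guess. A minimal sketch with a hypothetical toy model (TinyModel is not part of the library), exporting with the guessed shapes:
<<<
import torch
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs


class TinyModel(torch.nn.Module):  # hypothetical toy model
    def forward(self, x):
        return x.sum(dim=-1)


# dimension 0 varies between the two sets, dimension 1 does not
sets = [dict(x=torch.randn(2, 8)), dict(x=torch.randn(5, 8))]
args_ds, kwargs_ds = guess_dynamic_shapes_from_inputs(sets, auto="d")
# only dimension 0 is reported as dynamic
ep = torch.export.export(TinyModel(), (), kwargs=sets[0], dynamic_shapes=kwargs_ds)
print(ep)
>>>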
This function returns something equivalent to torch.export.dynamic_shapes.AdditionalInputs, except that AdditionalInputs needs a model.
<<<
import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
ds = torch.export.dynamic_shapes.AdditionalInputs()
ds.add((), data["inputs"])
ds.add((), data["inputs2"])
pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
>>>
{'attention_mask': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                    _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)),
 'input_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
               _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True)),
 'past_key_values': [[(_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                       None,
                       _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                       None)],
                     [(_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                       None,
                       _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                       None)]],
 'position_ids': (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),
                  _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True))}