Source code for onnx_diagnostic.export.shape_helper

from typing import Any, Dict, List, Set, Tuple, Union
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from .dynamic_shapes import ModelInputs


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Builds dynamic shapes for ``inputs`` where every dimension
    of every tensor is declared dynamic.

    :param inputs: the inputs, forwarded as is to
        :func:`flatten_unflatten_for_dynamic_shapes`
    :param dim_prefix: either a string, in which case every dimension
        receives the name ``f"{dim_prefix}_{n}_{i}"`` (``n`` is the number
        of tensors already processed, ``i`` the axis), or any other object
        such as ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``,
        which is then used unchanged for every dimension
    :return: the result of :func:`flatten_unflatten_for_dynamic_shapes`
        where every tensor was replaced by a dictionary
        ``{axis: dimension}``

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        ds = all_dynamic_shape_from_inputs(inputs)
        pprint.pprint(ds)
    """
    if not isinstance(dim_prefix, str):
        # Not a string: the very same marker object is attached to every axis.
        def _constant_dims(tensor):
            return dict.fromkeys(range(tensor.ndim), dim_prefix)

        return flatten_unflatten_for_dynamic_shapes(
            inputs, change_function=_constant_dims, use_dict=True
        )

    # The size of this set is the running index of the tensor being named,
    # it grows by exactly one on every call to the helper below.
    seen: Set[str] = set()

    def _named_dims(tensor):
        base = f"{dim_prefix}_{len(seen)}"
        seen.add(base)
        return {axis: f"{base}_{axis}" for axis in range(tensor.ndim)}

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=_named_dims, use_dict=True
    )
def guess_dynamic_shapes_from_inputs(
    inputs: List[Any], auto: Union[bool, str] = False
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Guesses which dimensions are dynamic from several sets of inputs.
    Every dimension taking different values across the sets of inputs
    is considered dynamic. Every dimension which does not change
    remains static.

    :param inputs: a list of input sets
    :param auto: True for ``torch.export.Dim.AUTO``,
        False for ``torch.export.Dim.DYNAMIC``,
        a string to get a unique string for every dynamic dimension
    :return: args and kwargs

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs1 = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        bsize, nheads, slen, dim = 3, 1, 33, 96
        inputs2 = dict(
            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
            position_ids=torch.arange(4, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )
        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
        pprint.pprint(ds)

    This function returns something equivalent to
    :class:`torch.export.dynamic_shapes.AdditionalInputs`,
    whereas that class needs a model.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
        ds = torch.export.dynamic_shapes.AdditionalInputs()
        ds.add((), data["inputs"])
        ds.add((), data["inputs2"])
        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
    """
    # No model is available here, so ModelInputs receives None in its place.
    return ModelInputs(None, inputs).guess_dynamic_shapes(auto=auto)