Source code for onnx_diagnostic.export.shape_helper

from typing import Any, Set
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes


def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Builds the dynamic shapes of ``inputs`` with every dimension of every
    tensor declared as dynamic.

    :param inputs: the inputs to describe; they are handed unchanged to
        :func:`flatten_unflatten_for_dynamic_shapes`, which is in charge
        of walking through the structure
    :param dim_prefix: if it is a string, it is used as a prefix and every
        axis receives its own name ``<dim_prefix>_<k>_<axis>`` where ``k``
        is the number of tensors already processed; any other value
        (``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``)
        is assigned as is to every axis
    :return: whatever :func:`flatten_unflatten_for_dynamic_shapes` returns
        when every tensor is replaced by a dictionary ``{axis: dimension}``

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96

        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
            ),
        )
        ds = all_dynamic_shape_from_inputs(inputs)
        pprint.pprint(ds)
    """
    if not isinstance(dim_prefix, str):

        def describe_axes(tensor):
            # The same marker object is shared by all the axes of all tensors.
            return dict.fromkeys(range(tensor.ndim), dim_prefix)

    else:
        issued_names: Set[str] = set()

        def describe_axes(tensor):
            # Every generated name is unique, the size of the set is
            # therefore the number of tensors seen so far.
            tensor_name = f"{dim_prefix}_{len(issued_names)}"
            issued_names.add(tensor_name)
            return {axis: f"{tensor_name}_{axis}" for axis in range(tensor.ndim)}

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=describe_axes, use_dict=True
    )