Source code for onnx_diagnostic.export.shape_helper

from typing import Any, Dict, List, Set, Optional, Tuple, Union
from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from ..helpers.fake_tensor_helper import fake_reshape
from .dynamic_shapes import ModelInputs


def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
    """
    Returns the dynamic shapes for the given inputs.
    All dimensions are considered dynamic.
    ``dim_prefix`` can be a string (the function uses it as a prefix),
    ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.

    Depending on the version of :epkg:`transformers`, the serialization
    functions of class ``DynamicCache`` are registered automatically or not
    (>= 4.51, < 4.55).

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_export_patches import torch_export_patches

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
            ),
        )

        with torch_export_patches(patch_transformers=True):
            ds = all_dynamic_shapes_from_inputs(inputs)
            pprint.pprint(ds)

    For this function to work, patches must be enabled if :epkg:`transformers`
    does not implement the serialization functions.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import (
            make_dynamic_cache,
            make_encoder_decoder_cache,
            make_mamba_cache,
            make_sliding_window_cache,
            make_static_cache,
        )
        from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_export_patches import torch_export_patches

        caches = [
            make_dynamic_cache(
                [
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                ]
            ),
            make_encoder_decoder_cache(
                make_dynamic_cache(
                    [
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                        (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    ]
                ),
                make_dynamic_cache(
                    [
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                        (torch.rand((5, 5, 5)), torch.rand((5, 5, 5))),
                    ]
                ),
            ),
            make_sliding_window_cache(
                [
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                ]
            ),
            make_static_cache(
                [
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                ],
                max_cache_len=15,
            ),
            make_mamba_cache(
                [
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                    (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
                ]
            ),
        ]

        with torch_export_patches(patch_transformers=True):
            for cache in caches:
                print(f"-- {cache.__class__.__name__}")
                pprint.pprint(all_dynamic_shapes_from_inputs(cache))
    """
    if isinstance(dim_prefix, str):
        prefixes: Set[str] = set()

        def tensor_to_shape(tensor):
            # Every tensor gets its own prefix so that two tensors never
            # share a dimension name by accident.
            n = len(prefixes)
            p = f"{dim_prefix}_{n}"
            prefixes.add(p)
            return {i: f"{p}_{i}" for i in range(tensor.ndim)}

    else:

        def tensor_to_shape(tensor):
            # Every axis of every tensor is mapped to the same marker
            # (torch.export.Dim.AUTO or torch.export.Dim.DYNAMIC).
            return {i: dim_prefix for i in range(tensor.ndim)}  # noqa: C420

    return flatten_unflatten_for_dynamic_shapes(
        inputs, change_function=tensor_to_shape, use_dict=True
    )
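
# Not part of the module: a minimal sketch of the output structure, assuming the
# module is importable as onnx_diagnostic.export.shape_helper. With a non-string
# ``dim_prefix`` such as ``torch.export.Dim.DYNAMIC``, every axis of every tensor
# is mapped to that marker instead of a generated name.
#
#   import torch
#   from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
#
#   inputs = dict(
#       input_ids=torch.randint(10, size=(2, 3), dtype=torch.int64),
#       attention_mask=torch.ones((2, 3), dtype=torch.int64),
#   )
#   ds = all_dynamic_shapes_from_inputs(inputs, dim_prefix=torch.export.Dim.DYNAMIC)
#   # every tensor becomes {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}
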
def guess_dynamic_shapes_from_inputs(
    inputs: List[Any], auto: Union[bool, str] = False
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Guesses which dimensions are dynamic from a set of inputs.
    Every dimension taking different values across the input sets is
    considered dynamic. Every dimension that does not change remains static.

    :param inputs: a list of input sets
    :param auto: True for ``torch.export.Dim.AUTO``,
        False for ``torch.export.Dim.DYNAMIC``,
        a string to get a unique string for every dynamic dimension
    :return: args and kwargs

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs

        bsize, nheads, slen, dim = 2, 1, 30, 96
        inputs1 = dict(
            input_ids=torch.randint(15, size=(2, 3), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
            position_ids=torch.arange(3, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )

        bsize, nheads, slen, dim = 3, 1, 33, 96
        inputs2 = dict(
            input_ids=torch.randint(15, size=(3, 4), dtype=torch.int64),
            attention_mask=torch.randint(1, size=(3, 34), dtype=torch.int64),
            position_ids=torch.arange(4, dtype=torch.int64),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(bsize, nheads, slen, dim),
                        torch.randn(bsize, nheads, slen, dim),
                    ),
                ]
            ),
        )

        ds = guess_dynamic_shapes_from_inputs([inputs1, inputs2], auto="d")
        pprint.pprint(ds)

    This function returns something equivalent to
    :class:`torch.export.dynamic_shapes.AdditionalInputs`,
    but that class requires a model.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import guess_dynamic_shapes_from_inputs
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", add_second_input=True)
        ds = torch.export.dynamic_shapes.AdditionalInputs()
        ds.add((), data["inputs"])
        ds.add((), data["inputs2"])
        pprint.pprint(ds.dynamic_shapes(data["model"], (), data["inputs"]))
    """
    mi = ModelInputs(None, inputs)
    return mi.guess_dynamic_shapes(auto=auto)
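
# Not part of the module: a short sketch of how the returned pair is typically
# consumed, assuming a ``model`` and the two input sets ``inputs1``/``inputs2``
# from the docstring example above. The function returns dynamic shapes for the
# positional arguments and for the keyword arguments; since the example inputs
# are dictionaries, only the keyword part is filled.
#
#   ds_args, ds_kwargs = guess_dynamic_shapes_from_inputs([inputs1, inputs2])
#   ep = torch.export.export(model, (), kwargs=inputs1, dynamic_shapes=ds_kwargs)
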
def make_fake_with_dynamic_dimensions(
    x: Any,
    dynamic_shapes: Any,
    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
) -> Tuple[Any, "FakeTensorMode"]:  # noqa: F821
    """
    Replaces all tensors with fake tensors respecting the constraints
    defined by the given dynamic shapes.
    This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

        inputs, _ = make_fake_with_dynamic_dimensions(
            dict(
                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
                past_key_values=make_dynamic_cache(
                    [
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                    ]
                ),
            ),
            dynamic_shapes={
                "input_ids": {0: "batch", 1: "seq_length"},
                "attention_mask": {0: "batch", 1: "cache+seq"},
                "position_ids": {0: "batch", 1: "seq_length"},
                "past_key_values": [
                    [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
                    [{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
                ],
            },
        )
        pprint.pprint(inputs)
    """
    if x is None:
        return None, None
    if fake_mode is None:
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        from torch._subclasses.fake_tensor import FakeTensorMode

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)
    if isinstance(x, (list, tuple)):
        # Process every element with its own dynamic shapes, keep the container type.
        return (
            x.__class__(
                [
                    make_fake_with_dynamic_dimensions(
                        i, fake_mode=fake_mode, dynamic_shapes=ds
                    )[0]
                    for i, ds in zip(x, dynamic_shapes)
                ]
            ),
            fake_mode,
        )
    if isinstance(x, dict):
        return {
            k: make_fake_with_dynamic_dimensions(
                v, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[k]
            )[0]
            for k, v in x.items()
        }, fake_mode
    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
        # Caches are modified in place, layer by layer.
        assert hasattr(x, "layers"), (
            f"Use a more recent version of transformers (>= 4.55), "
            f"'layers' not found in class {type(x)}"
        )
        assert (
            isinstance(dynamic_shapes, list) and len(dynamic_shapes) == 2
        ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
        for il, layer in enumerate(x.layers):
            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
                f"Use a more recent version of transformers (>= 4.55), "
                f"'keys' or 'values' not found in class {type(layer)} ({dir(layer)})"
            )
            layer.keys = make_fake_with_dynamic_dimensions(
                layer.keys, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0][il]
            )[0]
            layer.values = make_fake_with_dynamic_dimensions(
                layer.values, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1][il]
            )[0]
        return x, fake_mode
    if x.__class__.__name__ == "EncoderDecoderCache":
        # Both inner caches are modified in place by the recursive calls.
        make_fake_with_dynamic_dimensions(
            x.self_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[0]
        )
        make_fake_with_dynamic_dimensions(
            x.cross_attention_cache, fake_mode=fake_mode, dynamic_shapes=dynamic_shapes[1]
        )
        return x, fake_mode
    if hasattr(x, "shape"):
        t = fake_reshape(x, dynamic_shapes, fake_mode=fake_mode)
        assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
        assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
        return t, fake_mode

    from ..helpers import string_type

    raise TypeError(
        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
    )
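
# Not part of the module: a minimal usage sketch, assuming ``inputs`` and
# ``dynamic_shapes`` as in the docstring example above. The returned structure
# mirrors ``inputs`` but every tensor is a FakeTensor allocated under the
# returned FakeTensorMode, with symbolic sizes for the axes listed in
# ``dynamic_shapes``.
#
#   fake_inputs, fake_mode = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)
#   print(fake_inputs["input_ids"].shape)  # e.g. torch.Size([s0, s1]) under fake_mode
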