Source code for onnx_diagnostic.helpers.fake_tensor_helper

from typing import Any, Dict, Optional, Tuple


_UNIQUE = set()


def _unique():
    i = 129 + 1
    while i in _UNIQUE:
        i += 1
    _UNIQUE.add(i)
    return i
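
# A minimal sketch (not part of the module) of what _unique provides: assuming
# the _UNIQUE set starts empty, every call returns a fresh integer strictly
# greater than 129 that was never handed out before, so a rewritten dimension
# cannot collide with a size already present in the tensor.
#
#   >>> _unique()
#   130
#   >>> _unique()
#   131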


def fake_reshape(
    true_tensor: "torch.Tensor",  # noqa: F821
    sh: Dict[int, Any],  # noqa: F821
    fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
) -> "FakeTensor":  # noqa: F821
    """
    Changes the shape of a true tensor to make it dynamic.

    :param true_tensor: true tensor
    :param sh: dynamic shape
    :param fake_tensor: fake tensor, if None, make a fake one
    :param fake_mode: fake tensor mode
    :return: fake tensor
    """
    import torch

    # deal with 0/1
    for i in sh:
        if true_tensor.shape[i] <= 1:
            expanded_shape = list(true_tensor.shape)
            expanded_shape[i] = _unique()
            true_tensor = torch.empty(
                tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
            )

    # deal with equivalent dimension
    new_shape = list(true_tensor.shape)
    mapping = {}
    for i, s in sh.items():
        d = true_tensor.shape[i]
        if d not in mapping:
            mapping[d] = s
        elif mapping[d] != s:
            d = _unique()
            mapping[d] = s
        new_shape[i] = d
    true_tensor = torch.empty(
        tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
    )

    # now switch to FakeTensor
    if fake_mode is None:
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        from torch._subclasses.fake_tensor import FakeTensorMode

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)
    if fake_tensor is None:
        fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
    assert fake_mode is not None, "fake_mode must be provided"
    new_shape = list(true_tensor.shape)
    for i in sh:
        new_shape[i] = fake_tensor.shape[i]
    reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
        axis=tuple(sorted(sh)), keepdim=True
    )
    return reduced_tensor.expand(*new_shape)
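
# A minimal usage sketch for fake_reshape (the tensor values and the
# dynamic-shape dictionary below are illustrative assumptions, not part of
# the module): dimension 0 of a (2, 8) tensor is declared dynamic, and the
# returned FakeTensor is expected to carry a symbolic size on that axis
# while dimension 1 stays 8.
#
#   import torch
#   from onnx_diagnostic.helpers.fake_tensor_helper import fake_reshape
#
#   t = torch.zeros((2, 8), dtype=torch.float32)
#   fake = fake_reshape(t, {0: "batch"})
#   print(fake.shape)  # e.g. torch.Size([s0, 8]) with a symbolic s0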
def make_fake(
    x: Any, fake_mode: Optional["FakeTensorMode"] = None  # noqa: F821
) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]:  # noqa: F821
    """
    Replaces all tensors with fake tensors.
    The modification happens in place for caches.
    Cache support requires ``transformers>=4.55``.

    .. runpython::
        :showcode:

        import pprint
        import torch
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake

        inputs, _ = make_fake(
            dict(
                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
                past_key_values=make_dynamic_cache(
                    [
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                        (
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
                        ),
                    ]
                ),
            )
        )
        pprint.pprint(inputs)
    """
    if x is None:
        return None, None
    if fake_mode is None:
        from torch.fx.experimental.symbolic_shapes import ShapeEnv
        from torch._subclasses.fake_tensor import FakeTensorMode

        shape_env = ShapeEnv()
        fake_mode = FakeTensorMode(shape_env=shape_env)
    if isinstance(x, (list, tuple)):
        return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
    if isinstance(x, dict):
        return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
        assert hasattr(x, "layers"), (
            f"Use a more recent version of transformers (>=4.55), "
            f"'layers' not found in class {type(x)}"
        )
        for layer in x.layers:
            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
                f"Use a more recent version of transformers (>=4.55), 'keys' or 'values' "
                f"not found in class {type(layer)} ({dir(layer)})"
            )
            layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
            layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
        return x, fake_mode
    if x.__class__.__name__ == "EncoderDecoderCache":
        make_fake(x.self_attention_cache, fake_mode=fake_mode)
        make_fake(x.cross_attention_cache, fake_mode=fake_mode)
        return x, fake_mode
    if hasattr(x, "shape"):
        t = fake_mode.from_tensor(x, static_shapes=False)
        return t, fake_mode
    from . import string_type

    raise TypeError(
        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
    )
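
# A small sketch (an assumption, not part of the module) showing how the
# returned FakeTensorMode can be threaded through several make_fake calls so
# that all resulting FakeTensors share one ShapeEnv:
#
#   import torch
#   from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
#
#   first, fake_mode = make_fake(torch.rand((2, 8)))
#   second, _ = make_fake(dict(x=torch.rand((2, 8))), fake_mode=fake_mode)
#   # `first` and second["x"] are FakeTensors created with the same fake_mode.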