yobx.torch.fake_tensor_helper#
- class yobx.torch.fake_tensor_helper.FakeTensorContext(fake_mode: FakeTensorMode | None = None)[source]#
Stores information so that dimensions sharing the same name are mapped to the same dynamic dimension.
- fake_reshape(true_tensor: Tensor, sh: Dict[int, Any], fake_tensor: FakeTensor | None = None) FakeTensor[source]#
Changes the shape of a true tensor to make the requested dimensions dynamic.
- Parameters:
true_tensor – the true tensor
sh – dynamic shape, mapping a dimension index to its dynamic specification
fake_tensor – an existing fake tensor to reuse; if None, a new one is created
- Returns:
the fake tensor
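The FakeTensor objects returned by these helpers come from PyTorch's torch._subclasses.fake_tensor module. As background (this is plain PyTorch, not part of this helper's API), a minimal sketch of creating a fake tensor directly with a recent PyTorch:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# FakeTensorMode wraps real tensors into FakeTensors: objects that carry
# shape, dtype and device metadata but allocate no real storage.
mode = FakeTensorMode()
real = torch.rand((2, 3), dtype=torch.float32)
fake = mode.from_tensor(real)
print(type(fake).__name__, tuple(fake.shape), fake.dtype)
```

The helpers below build on this mechanism and additionally keep dimension names consistent across tensors through a shared FakeTensorContext.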
- yobx.torch.fake_tensor_helper.make_fake(x: Any, context: FakeTensorContext | None = None) Tuple[FakeTensor | None, FakeTensorContext | None][source]#
Replaces all tensors with fake tensors. For caches, the modification happens in place. This function only supports caches with transformers>=4.55.
<<<
import pprint

import torch

from yobx.torch.fake_tensor_helper import make_fake
from yobx.torch.in_transformers.cache_helper import make_dynamic_cache

inputs, _ = make_fake(
    dict(
        input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
        attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
        position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
            ]
        ),
    )
)
pprint.pprint(inputs)
>>>
{'attention_mask': FakeTensor(..., size=(s26, s70), dtype=torch.int64),
 'input_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64),
 'past_key_values': DynamicCache(layers=[DynamicLayer, DynamicLayer]),
 'position_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64)}
- yobx.torch.fake_tensor_helper.make_fake_with_dynamic_dimensions(x: Any, dynamic_shapes: Any, context: FakeTensorContext | None = None) Tuple[Any | None, FakeTensorContext | None][source]#
Replaces all tensors with fake tensors that satisfy the constraints expressed by the given dynamic shapes. This relies on function yobx.torch.fake_tensor_helper.make_fake(). Parameter existing is used to reuse the same object when a dynamic dimension is given the same name as another one. This function works with caches only if transformers>=4.57.

A simple tensor:
<<<
import torch

from yobx.torch.fake_tensor_helper import make_fake_with_dynamic_dimensions

inputs, _ = make_fake_with_dynamic_dimensions(
    torch.rand((2, 3, 4, 5), dtype=torch.float32),
    {0: "batch", 2: "cache_length"},
)
print(inputs)
>>>
FakeTensor(..., size=(s26, 3, s36, 5))
Two tensors:
<<<
import torch

from yobx.torch.fake_tensor_helper import make_fake_with_dynamic_dimensions

inputs, _ = make_fake_with_dynamic_dimensions(
    (
        torch.rand((2, 3, 4, 5), dtype=torch.float32),
        torch.rand((2, 3, 4, 5), dtype=torch.float32),
    ),
    ({0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}),
)
print(inputs)
>>>
(FakeTensor(..., size=(s26, 3, s36, 5)), FakeTensor(..., size=(s26, 3, s36, 5)))
With a cache:
<<<
import pprint

import torch

from yobx.torch.fake_tensor_helper import make_fake_with_dynamic_dimensions
from yobx.torch.in_transformers.cache_helper import make_dynamic_cache

inputs, _ = make_fake_with_dynamic_dimensions(
    dict(
        input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
        attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
        position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
            ]
        ),
    ),
    dynamic_shapes={
        "input_ids": {0: "batch", 1: "seq_length"},
        "attention_mask": {0: "batch", 1: "cache+seq"},
        "position_ids": {0: "batch", 1: "seq_length"},
        "past_key_values": [
            {0: "batch", 2: "cache_length"},
            {0: "batch", 2: "cache_length"},
            {0: "batch", 2: "cache_length"},
            {0: "batch", 2: "cache_length"},
        ],
    },
)
pprint.pprint(inputs)
>>>
{'attention_mask': FakeTensor(..., size=(s26, s19), dtype=torch.int64),
 'input_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64),
 'past_key_values': DynamicCache(layers=[DynamicLayer, DynamicLayer]),
 'position_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64)}