onnx_diagnostic.helpers.fake_tensor_helper¶
- onnx_diagnostic.helpers.fake_tensor_helper.fake_reshape(true_tensor: torch.Tensor, sh: Dict[int, Any], fake_tensor: FakeTensor | None = None, fake_mode: FakeTensorMode | None = None) → FakeTensor[source]¶
Changes the shape of a true tensor to make it dynamic.
- Parameters:
true_tensor – the real (non-fake) tensor
sh – dynamic shape specification, mapping a dimension index to its dynamic dimension
fake_tensor – fake tensor; if None, a fake one is created from true_tensor
fake_mode – fake tensor mode; if None, a new FakeTensorMode is used
- Returns:
a fake tensor with the requested dynamic dimensions
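To see what fake_reshape produces, it helps to know what a FakeTensor is: a tensor that carries shape, dtype, and device metadata without allocating real storage. The first step of fakifying a real tensor can be sketched with torch's own FakeTensorMode (imported here from the semi-public torch._subclasses module, not from onnx_diagnostic); fake_reshape additionally swaps the dimensions listed in sh for symbolic ones.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A real tensor with concrete storage.
real = torch.rand((2, 32, 30, 96), dtype=torch.float16)

# from_tensor builds a FakeTensor mirroring the real tensor's metadata
# (shape, dtype, device) without copying its data.
fake_mode = FakeTensorMode()
fake = fake_mode.from_tensor(real)

print(type(fake).__name__, tuple(fake.shape), fake.dtype)
```

The shapes here are still static; turning a chosen dimension into a symbolic size (the `s26`, `s49`, … seen in the output below) requires a ShapeEnv, which is the part fake_reshape takes care of.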
- onnx_diagnostic.helpers.fake_tensor_helper.make_fake(x: Any, fake_mode: FakeTensorMode | None = None) → Tuple[FakeTensor | None, FakeTensorMode | None][source]¶
Replaces all tensors with fake tensors. For caches, the modification happens in place. Cache support requires
transformers>=4.55.
<<<

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake

inputs, _ = make_fake(
    dict(
        input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
        attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
        position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
            ]
        ),
    )
)
pprint.pprint(inputs)
>>>
{'attention_mask': FakeTensor(..., size=(s26, s70), dtype=torch.int64),
 'input_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64),
 'past_key_values': DynamicCache(layers=[DynamicLayer, DynamicLayer]),
 'position_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64)}