onnx_diagnostic.helpers.fake_tensor_helper

onnx_diagnostic.helpers.fake_tensor_helper.fake_reshape(true_tensor: torch.Tensor, sh: Dict[int, Any], fake_tensor: FakeTensor | None = None, fake_mode: FakeTensorMode | None = None) → FakeTensor[source]

Turns the static shape of a true tensor into a dynamic one, returning a fake tensor.

Parameters:
  • true_tensor – the real tensor whose shape should become dynamic

  • sh – dynamic shape, mapping an axis index to its dynamic dimension

  • fake_tensor – an existing fake tensor; if None, one is created from true_tensor

  • fake_mode – the fake tensor mode to use

Returns:

a fake tensor with the requested dynamic shape
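As background, a plain fake tensor (without dynamic dimensions) can be obtained directly from torch's FakeTensorMode; a minimal sketch of the underlying mechanism, independent of this helper:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# A FakeTensor carries metadata (shape, dtype, device) but no real storage.
fake_mode = FakeTensorMode()
real = torch.rand(2, 3, dtype=torch.float32)
fake = fake_mode.from_tensor(real)
print(type(fake).__name__)        # FakeTensor
print(tuple(fake.shape), fake.dtype)  # (2, 3) torch.float32
```

fake_reshape goes one step further and replaces the static sizes selected by sh with symbolic dimensions.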

onnx_diagnostic.helpers.fake_tensor_helper.make_fake(x: Any, fake_mode: FakeTensorMode | None = None) → Tuple[FakeTensor | None, FakeTensorMode | None][source]

Replaces all tensors with fake tensors. For caches, the modification happens in place. This function is only implemented for caches with transformers>=4.55.

<<<

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.fake_tensor_helper import make_fake

inputs, _ = make_fake(
    dict(
        input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
        attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
        position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
                (
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                    torch.rand((2, 32, 30, 96), dtype=torch.float16),
                ),
            ]
        ),
    )
)
pprint.pprint(inputs)

>>>

    {'attention_mask': FakeTensor(..., size=(s26, s70), dtype=torch.int64),
     'input_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64),
     'past_key_values': DynamicCache(layers=[DynamicLayer, DynamicLayer]),
     'position_ids': FakeTensor(..., size=(s26, s49), dtype=torch.int64)}