Source code for onnx_diagnostic.tasks.text2text_generation
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from ..helpers.config_helper import update_config, check_hasattr, _pick
__TASK__ = "text2text-generation"
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
"""Reduces a model size."""
kwargs: Dict[str, Any] = {}
if hasattr(config, "num_decoder_layers"):
config.num_decoder_layers = min(config.num_decoder_layers, 2)
if hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = min(config.num_hidden_layers, 2)
update_config(config, kwargs)
return kwargs
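# Illustrative sketch, not part of the original module: ``reduce_model_config`` caps the
# layer counts so a dummy model stays small. ``SimpleNamespace`` stands in for a real
# configuration object here; the attribute values below are arbitrary assumptions.
def _example_reduce_model_config():
    from types import SimpleNamespace

    config = SimpleNamespace(num_decoder_layers=12, num_hidden_layers=12)
    reduce_model_config(config, __TASK__)
    # both layer counts are expected to be capped at 2
    assert config.num_decoder_layers == 2
    assert config.num_hidden_layers == 2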
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    head_dim: int,
    encoder_dim: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    **kwargs,  # unused
):
"""
Generates input for task ``text2text-generation``.
:param model: model to get the missing information
:param config: configuration used to generate the model
:param head_dim: last dimension of the cache
:param dummy_max_token_id: dummy max token id
:param batch_size: batch size
:param encoder_dim: last dimension of encoder_last_hidden_state
:param sequence_length: sequence length
:param sequence_length2: new sequence length
:return: dictionary
Stolen inputs for one model.
::
cache_position:T7s1
past_key_values:EncoderDecoderCache(
self_attention_cache=DynamicCache(
key_cache=#6[T1s1x8x1x64,...],
value_cache=#6[T1s1x8x1x64,...]),
cross_attention_cache=DynamicCache(
key_cache=#6[T1s1x8x16x64,...],
value_cache=#6[T1s1x8x16x64,...])),
decoder_input_ids:T7s1x1,
encoder_outputs:dict(last_hidden_state:T1s1x16x512)
"""
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "decoder_input_ids": {0: batch, 1: "seq_ids"},
        "attention_mask": {0: batch, 1: "seq_mask"},
        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
        "past_key_values": [
            [
                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
            ],
            [
                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
            ],
        ],
        # one of these is selected based on the forward method signature
        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
    }
    inputs = dict(
        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
        decoder_input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length2)
        ).to(torch.int64),
        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
        # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
        # .to(torch.int64)
        # .expand((batch_size, -1)),
        past_key_values=make_encoder_decoder_cache(
            make_dynamic_cache(
                [
                    (
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length, head_dim
                        ),
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length, head_dim
                        ),
                    )
                    for i in range(num_hidden_layers)
                ]
            ),
            make_dynamic_cache(
                [
                    (
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length2, head_dim
                        ),
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length2, head_dim
                        ),
                    )
                    for i in range(num_hidden_layers)
                ]
            ),
        ),
        # one of these is selected based on the forward method signature
        # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
        # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
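# Illustrative sketch, not part of the original module: calling ``get_inputs`` with typical
# dimensions (the values below are assumptions, not taken from a specific checkpoint) and
# checking that the generated inputs and the dynamic shapes share the same keys.
def _example_get_inputs():
    res = get_inputs(
        model=None,
        config=None,
        dummy_max_token_id=31999,
        num_key_value_heads=8,
        num_hidden_layers=2,
        head_dim=64,
        encoder_dim=512,
        batch_size=2,
        sequence_length=30,
        sequence_length2=3,
    )
    assert set(res) == {"inputs", "dynamic_shapes"}
    assert set(res["inputs"]) == set(res["dynamic_shapes"])
    assert res["inputs"]["input_ids"].shape == (2, 30)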
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
    """
    Returns the input kwargs and the function generating the inputs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(
            config,
            "vocab_size",
            "hidden_size",
            "num_attention_heads",
            ("num_hidden_layers", "num_layers"),
            ("n_positions", "d_model"),
            (
                "num_key_value_heads",
                "num_heads",
                ("decoder_attention_heads", "encoder_attention_heads"),
            ),
        )
    kwargs = dict(
        batch_size=2,
        sequence_length=30,
        sequence_length2=3,
        head_dim=16 if config is None else (config.d_kv if hasattr(config, "d_kv") else 1),
        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
        num_hidden_layers=(
            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
        ),
        num_key_value_heads=(
            16
            if config is None
            else _pick(
                config,
                "num_key_value_heads",
                "num_heads",
                (sum, "encoder_attention_heads", "decoder_attention_heads"),
            )
        ),
        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
    )
    return kwargs, get_inputs
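# Illustrative sketch, not part of the original module: with no configuration,
# ``random_input_kwargs`` falls back to typical dimensions and returns ``get_inputs`` as
# the generator; the returned kwargs can be forwarded to it directly.
def _example_random_input_kwargs():
    kwargs, fct = random_input_kwargs(None, __TASK__)
    assert fct is get_inputs
    res = fct(model=None, config=None, **kwargs)
    # default batch_size=2 and sequence_length2=3 give this decoder input shape
    assert res["inputs"]["decoder_input_ids"].shape == (2, 3)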