Source code for onnx_diagnostic.tasks.text_generation

from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import transformers
from ..helpers.cache_helper import (
    make_dynamic_cache,
    make_mamba_cache,
    make_sliding_window_cache,
)
from ..helpers.config_helper import update_config, check_hasattr, _pick

__TASK__ = "text-generation"


[docs] def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
    """Reduces a model size."""
    # FalconMambaConfig: use_mambapy
    check_hasattr(
        config,
        ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
        "num_hidden_layers",
        ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
        "intermediate_size",
        "hidden_size",
        "vocab_size",
    )
    if config.__class__.__name__ == "FalconMambaConfig":
        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
        kwargs = dict(
            num_hidden_layers=min(config.num_hidden_layers, 2),
            intermediate_size=256 if config is None else min(512, config.intermediate_size),
            hidden_size=256 if config is None else min(256, config.hidden_size),
            cls_cache="MambaCache",
            state_size=8 if config is None else getattr(config, "state_size", None),
            conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
        )
    else:
        kwargs = dict(
            head_dim=getattr(
                config, "head_dim", config.hidden_size // config.num_attention_heads
            ),
            num_hidden_layers=min(config.num_hidden_layers, 2),
            num_key_value_heads=(
                config.num_key_value_heads
                if hasattr(config, "num_key_value_heads")
                else config.num_attention_heads
            ),
            intermediate_size=(
                min(config.intermediate_size, 24576 // 4)
                if config.intermediate_size % 4 == 0
                else config.intermediate_size
            ),
            hidden_size=(
                min(config.hidden_size, 3072 // 4)
                if config.hidden_size % 4 == 0
                else config.hidden_size
            ),
        )
    update_config(config, kwargs)
    return kwargs
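
A minimal usage sketch for ``reduce_model_config``: the configuration is updated in place
through ``update_config`` and the reduced hyperparameters are returned. The checkpoint id
below is an assumption chosen for illustration, not something this module requires::

    from transformers import AutoConfig
    from onnx_diagnostic.tasks.text_generation import reduce_model_config

    config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-135M")  # assumed checkpoint
    kwargs = reduce_model_config(config, task="text-generation")
    # config is modified in place; kwargs holds the reduced values,
    # e.g. num_hidden_layers is capped at 2.
    print(kwargs)
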
[docs] def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_hidden_layers: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    dynamic_rope: bool = False,
    num_key_value_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    cls_cache: Optional[Union[type, str]] = None,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``text-generation``.

    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param dummy_max_token_id: dummy max token id
    :param num_hidden_layers: number of hidden layers
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param num_key_value_heads: number of key/value heads
    :param head_dim: last dimension of the cache
    :param cls_cache: cache class, by default it is
        :class:`transformers.cache_utils.DynamicCache`
    :return: dictionary
    """
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

    if config is not None and config.__class__.__name__ == "FalconMambaConfig":
        assert cls_cache in (
            "MambaCache",
            transformers.cache_utils.MambaCache,
        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
        seq_length_multiple = 8
        sequence_length = (
            (sequence_length + seq_length_multiple)
            // seq_length_multiple
            * seq_length_multiple
        )
        # sequence_inc = seq_length_multiple
        sequence_length2 = seq_length_multiple
        shapes = {
            "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
            "attention_mask": {
                0: batch,
                1: "cache+seq",  # cache_length + seq_length
            },
            "cache_position": {
                0: batch,
                1: "cache+seq",  # cache_length + seq_length
            },
            "cache_params": [
                [{0: batch} for _ in range(num_hidden_layers)],
                [{0: batch} for _ in range(num_hidden_layers)],
            ],
        }
        inputs = dict(
            input_ids=torch.randint(
                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
            ).to(torch.int64),
            attention_mask=torch.ones(
                (batch_size, sequence_length + sequence_length2)
            ).to(torch.int64),
            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
            # .expand((batch_size, -1))
            cache_params=make_mamba_cache(
                [
                    (
                        torch.randn(
                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
                        ),
                        torch.randn(
                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
                        ),
                    )
                    for i in range(num_hidden_layers)
                ]
            ),
        )
        return dict(inputs=inputs, dynamic_shapes=shapes)

    if head_dim is None:
        assert config, "head_dim is None, the value cannot be set without a configuration"
        head_dim = config.hidden_size // config.num_attention_heads

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            0: batch,
            1: "cache+seq",  # cache_length + seq_length
        },
        "position_ids": {
            0: batch,
            1: "cache+seq",  # cache_length + seq_length
        },
        "past_key_values": [
            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
        ],
    }
    make_cache = (
        make_sliding_window_cache
        if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache)
        else make_dynamic_cache
    )
    inputs = dict(
        input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length2)
        ).to(torch.int64),
        attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
            torch.int64
        ),
        position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
        .to(torch.int64)
        .expand((batch_size, -1)),
        past_key_values=make_cache(
            [
                (
                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
                )
                for i in range(num_hidden_layers)
            ]
        ),
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
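
A hedged sketch of how the returned dictionary can be paired with :func:`torch.export.export`.
The checkpoint id is an assumption, and whether the export actually succeeds depends on the
model and on the installed ``torch``/``transformers`` versions; this module only builds the
inputs and dynamic shapes::

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    from onnx_diagnostic.tasks.text_generation import get_inputs

    model_id = "HuggingFaceTB/SmolLM-135M"  # assumed checkpoint
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    data = get_inputs(
        model,
        config,
        dummy_max_token_id=config.vocab_size - 1,
        num_hidden_layers=config.num_hidden_layers,
        num_key_value_heads=getattr(
            config, "num_key_value_heads", config.num_attention_heads
        ),
    )
    ep = torch.export.export(
        model,
        (),
        kwargs=data["inputs"],
        dynamic_shapes=data["dynamic_shapes"],
        strict=False,
    )
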
[docs] def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(
            config,
            "vocab_size",
            ("num_attention_heads", "use_mambapy"),
            ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
            "intermediate_size",
            "hidden_size",
        )
    if config.__class__.__name__ == "FalconMambaConfig":
        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
        kwargs = dict(
            batch_size=2,
            sequence_length=30,
            sequence_length2=3,
            dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
            num_hidden_layers=4 if config is None else config.num_hidden_layers,
            intermediate_size=256 if config is None else config.intermediate_size,
            cls_cache="MambaCache",
            state_size=8 if config is None else getattr(config, "state_size", None),
            conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
        )
    else:
        kwargs = dict(
            batch_size=2,
            sequence_length=30,
            sequence_length2=3,
            head_dim=(
                16
                if config is None
                else getattr(
                    config, "head_dim", config.hidden_size // config.num_attention_heads
                )
            ),
            dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
            num_hidden_layers=4 if config is None else config.num_hidden_layers,
            num_key_value_heads=(
                24
                if config is None
                else _pick(config, "num_key_value_heads", "num_attention_heads")
            ),
            intermediate_size=1024 if config is None else config.intermediate_size,
            hidden_size=512 if config is None else config.hidden_size,
        )
    return kwargs, get_inputs
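
``random_input_kwargs`` is meant to be used together with :func:`get_inputs`: it returns
default dimensions plus the generator to call. A short sketch, reusing the assumed
checkpoint id from the previous example::

    from transformers import AutoConfig, AutoModelForCausalLM
    from onnx_diagnostic.tasks.text_generation import random_input_kwargs

    model_id = "HuggingFaceTB/SmolLM-135M"  # assumed checkpoint
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    kwargs, fct = random_input_kwargs(config, "text-generation")
    data = fct(model, config, **kwargs)  # fct is get_inputs for this task
    print(sorted(data["inputs"]))  # attention_mask, input_ids, past_key_values, position_ids
    print(data["dynamic_shapes"]["input_ids"])
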