onnx_diagnostic.tasks.text_generation
- onnx_diagnostic.tasks.text_generation.get_inputs(model: Module, config: Any | None, dummy_max_token_id: int, num_hidden_layers: int, batch_size: int = 2, sequence_length: int = 30, sequence_length2: int = 3, dynamic_rope: bool = False, num_key_value_heads: int | None = None, head_dim: int | None = None, cls_cache: type | str | None = None, **kwargs)
Generates inputs for the task text-generation.
- Parameters:
model – model used to retrieve missing information
config – configuration used to create the model
dummy_max_token_id – dummy maximum token id used when generating input ids
num_hidden_layers – number of hidden layers of the model
batch_size – batch size
sequence_length – sequence length
sequence_length2 – new sequence length
dynamic_rope – use dynamic RoPE scaling (see transformers.LlamaConfig)
num_key_value_heads – number of key/value heads in the cache
head_dim – last dimension of the cache
cls_cache – cache class, by default transformers.cache_utils.DynamicCache
- Returns:
dictionary
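To make the relationship between these parameters concrete, here is a minimal pure-Python sketch of the tensor shapes a text-generation step typically involves. It assumes (this page does not state it) that sequence_length tokens are already held in the cache, that sequence_length2 new tokens are fed to the model, and that the cache stores one key/value pair per hidden layer with shape (batch_size, num_key_value_heads, sequence_length, head_dim). The helper sketch_text_generation_shapes is hypothetical and not part of onnx_diagnostic; the real get_inputs builds actual tensors and a cache of class cls_cache, and its returned keys may differ.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def sketch_text_generation_shapes(
    num_hidden_layers: int,
    num_key_value_heads: int,
    head_dim: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
) -> Dict[str, object]:
    """Hypothetical helper: shapes of dummy inputs for one decoding step.

    Assumption: ``sequence_length`` tokens are already cached and
    ``sequence_length2`` new tokens are passed to the model.
    """
    # the attention mask covers cached tokens plus new tokens
    total_length = sequence_length + sequence_length2
    # each cached key or value tensor only covers the past tokens
    cache_entry: Shape = (batch_size, num_key_value_heads, sequence_length, head_dim)
    # one (key, value) pair per hidden layer
    past_key_values: List[Tuple[Shape, Shape]] = [
        (cache_entry, cache_entry) for _ in range(num_hidden_layers)
    ]
    return {
        "input_ids": (batch_size, sequence_length2),
        "attention_mask": (batch_size, total_length),
        "past_key_values": past_key_values,
    }


shapes = sketch_text_generation_shapes(num_hidden_layers=2, num_key_value_heads=4, head_dim=16)
print(shapes["input_ids"])             # (2, 3)
print(shapes["attention_mask"])        # (2, 33)
print(len(shapes["past_key_values"]))  # 2
```

With the defaults, the attention mask spans 30 + 3 = 33 positions while input_ids only holds the 3 new tokens, which is why sequence_length and sequence_length2 are separate parameters.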