onnx_diagnostic.tasks.text_generation
- onnx_diagnostic.tasks.text_generation.get_inputs(model: Module, config: Any | None, dummy_max_token_id: int, num_hidden_layers: int, batch_size: int = 2, sequence_length: int = 30, sequence_length2: int = 3, dynamic_rope: bool = False, num_key_value_heads: int | None = None, head_dim: int | None = None, cls_cache: type | str | None = None, add_second_input: int = 1, **kwargs)
- Generates inputs for the text-generation task.
- Parameters:
- model – model to get the missing information 
- config – configuration used to generate the model 
- head_dim – last dimension of the cache 
- dummy_max_token_id – dummy max token id 
- batch_size – batch size 
- sequence_length – sequence length 
- sequence_length2 – new sequence length 
- dynamic_rope – use dynamic rope (see transformers.LlamaConfig)
- cls_cache – cache class, by default transformers.cache_utils.DynamicCache
 
- Returns:
- dictionary
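
The size parameters above interact, so a small sketch may help in reading them. The code below does not call onnx_diagnostic and is not its implementation: `expected_shapes` is a hypothetical helper that assumes the usual layout of a decoder with a key/value cache, where `sequence_length` is the number of tokens already stored in the cache and `sequence_length2` is the number of new tokens. The values chosen for `num_hidden_layers`, `num_key_value_heads` and `head_dim` are illustrative only.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    num_hidden_layers: int = 2,    # illustrative value
    batch_size: int = 2,
    sequence_length: int = 30,     # assumed: tokens already in the cache
    sequence_length2: int = 3,     # assumed: new tokens fed to the model
    num_key_value_heads: int = 4,  # illustrative value
    head_dim: int = 16,            # last dimension of the cache
) -> Dict[str, object]:
    """Hypothetical helper: shapes under the assumed KV-cache layout."""
    cache_shape: Shape = (batch_size, num_key_value_heads, sequence_length, head_dim)
    return {
        # only the new tokens are passed as input ids
        "input_ids": (batch_size, sequence_length2),
        # the mask covers cached tokens plus new tokens
        "attention_mask": (batch_size, sequence_length + sequence_length2),
        # one (key, value) pair of identical shapes per hidden layer
        "past_key_values": [(cache_shape, cache_shape) for _ in range(num_hidden_layers)],
    }


shapes = expected_shapes()
print(shapes["input_ids"], shapes["attention_mask"])
```

Under these assumptions, the last dimension of every cache tensor is `head_dim`, and the attention mask is the only input whose length depends on both sequence lengths.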