Source code for onnx_diagnostic.tasks.mixture_of_expert

from typing import Any, Callable, Dict, Optional, Tuple
import torch

# from ..helpers.cache_helper import make_dynamic_cache
from ..helpers.config_helper import update_config  # , check_hasattr, _pick

__TASK__ = "MoE"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces the model size."""
    kwargs: Dict[str, Any] = {}
    # Cap the number of layers (text, vision, audio) at 2 so that the
    # reduced model stays small; the attributes are modified in place.
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, 2)
    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
    if hasattr(config, "audio_processor") and hasattr(
        config.audio_processor, "num_hidden_layers"
    ):
        config.audio_processor.num_hidden_layers = min(
            config.audio_processor.num_hidden_layers, 2
        )
    if hasattr(config, "audio_processor") and hasattr(config.audio_processor, "attention_dim"):
        config.audio_processor.attention_dim = min(config.audio_processor.attention_dim, 2)
    # kwargs stays empty here; the caps above are applied directly on config.
    update_config(config, kwargs)
    return kwargs
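
# A minimal usage sketch (illustration only; ``SimpleNamespace`` is a
# hypothetical stand-in for a real ``transformers`` configuration object,
# and ``update_config`` is assumed to be a no-op for an empty kwargs dict):
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(num_hidden_layers=32)
#     reduce_model_config(cfg)  # caps num_hidden_layers at 2, in place
#     assert cfg.num_hidden_layers == 2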


def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    n_images: int = 2,
    dynamic_rope: bool = False,
    add_second_input: bool = False,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``MoE``.

    :param model: model used to get the missing information
    :param config: configuration used to generate the model
    :param dummy_max_token_id: dummy max token id
    :param num_key_value_heads: number of key/value heads
    :param num_hidden_layers: number of hidden layers
    :param head_dim: last dimension of the cache
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param n_images: number of images
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param add_second_input: generates a second set of inputs as well
        (not yet implemented here)
    :return: dictionary
    """
    assert not add_second_input, "add_second_input=True not yet implemented"
    raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
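
# What a future implementation might return, sketched for illustration only
# (assumption: MoE models here consume the usual text-generation inputs;
# the actual keys and shapes are not yet defined by the library):
#
#     input_ids = torch.randint(
#         0, dummy_max_token_id, (batch_size, sequence_length), dtype=torch.int64
#     )
#     attention_mask = torch.ones_like(input_ids)
#     return dict(input_ids=input_ids, attention_mask=attention_mask)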


def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    raise NotImplementedError(
        f"random_input_kwargs not yet implemented for task {__TASK__!r}."
    )
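
# A sketch of the expected contract, suggested by the return annotation
# ``Tuple[Dict[str, Any], Callable]`` (assumption: the callable is
# ``get_inputs`` and the dict holds keyword arguments for it; the values
# below are hypothetical placeholders):
#
#     kwargs = dict(
#         dummy_max_token_id=31999,
#         num_key_value_heads=8,
#         num_hidden_layers=2,
#         head_dim=64,
#         width=224,
#         height=224,
#         num_channels=3,
#     )
#     return kwargs, get_inputs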