Source code for onnx_diagnostic.tasks.mixture_of_expert
from typing import Any, Callable, Dict, Optional, Tuple
import torch
# from ..helpers.cache_helper import make_dynamic_cache
from ..helpers.config_helper import update_config  # , check_hasattr, _pick
__TASK__ = "MoE"
[docs]
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size."""
    kwargs: Dict[str, Any] = {}
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, 2)
    if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
        config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
    if hasattr(config, "audio_processor") and hasattr(
        config.audio_processor, "num_hidden_layers"
    ):
        config.audio_processor.num_hidden_layers = min(
            config.audio_processor.num_hidden_layers, 2
        )
    if hasattr(config, "audio_processor") and hasattr(config.audio_processor, "attention_dim"):
        config.audio_processor.attention_dim = min(config.audio_processor.attention_dim, 2)
    update_config(config, kwargs)
    return kwargs
[docs]
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    sequence_length2: int = 3,
    n_images: int = 2,
    dynamic_rope: bool = False,
    add_second_input: bool = False,
    **kwargs,  # unused
):
    """
    Generates input for task ``MoE``.
    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param head_dim: last dimension of the cache
    :param dummy_max_token_id: dummy max token id
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param n_images: number of images
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :return: dictionary
    """
    assert not add_second_input, "add_second_input=True not yet implemented"
    raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")
[docs]
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.
    If the configuration is None, the function selects typical dimensions.
    """
    raise NotImplementedError(
        f"random_input_kwargs not yet implemented for task {__TASK__!r}."
    )