from typing import Any, Callable, Dict, List, Tuple
from . import (
    automatic_speech_recognition,
    feature_extraction,
    fill_mask,
    image_classification,
    image_text_to_text,
    mixture_of_expert,
    object_detection,
    sentence_similarity,
    summarization,
    text_classification,
    text_generation,
    text_to_image,
    text2text_generation,
    zero_shot_image_classification,
)
# Registry of the task modules handled by this package. The functions below
# read two attributes from every entry: ``__TASK__`` (the task name) and
# ``reduce_model_config`` (the task-specific configuration reducer).
__TASKS__ = [
    automatic_speech_recognition,
    feature_extraction,
    fill_mask,
    image_classification,
    image_text_to_text,
    mixture_of_expert,
    object_detection,
    sentence_similarity,
    summarization,
    text_classification,
    text_generation,
    text_to_image,
    text2text_generation,
    zero_shot_image_classification,
]
# [docs]
def supported_tasks() -> List[str]:
    """Returns the names of all registered tasks, sorted in ascending order."""
    # Collect the ``__TASK__`` name declared by every registered task module,
    # then order the names in place.
    names = [module.__TASK__ for module in __TASKS__]
    names.sort()
    return names
# [docs]
def _head_size(config: Any) -> Any:
    """Returns the attention head size described by ``config``, or ``None``.

    A truthy ``head_dim`` attribute wins; otherwise the size is derived as
    ``hidden_size // num_attention_heads`` when both attributes exist.
    ``None`` means the head size cannot be determined from ``config``.
    """
    head_dim = getattr(config, "head_dim", None)
    if head_dim:
        return head_dim
    if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    return None


def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
    """Reduces a model size.

    Dispatches to the ``reduce_model_config`` function of the registered task
    module whose ``__TASK__`` equals ``task`` and returns its result as is.
    When the head size of ``config`` is known before the call and the result
    contains a ``"head_dim"`` key, the head size is read from ``config`` again
    afterwards and checked: it must either be unchanged or be a multiple of 16.

    :param config: model configuration; ``head_dim``, ``hidden_size`` and
        ``num_attention_heads`` are read from it when present
    :param task: name of the task, one of :func:`supported_tasks`
    :return: the value returned by the task-specific reducer
    :raises AssertionError: if ``task`` is not a registered task, or if the
        head-size check described above fails
    """
    head_size_before = _head_size(config)
    reducers = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
    if task not in reducers:
        # Raised explicitly (rather than through ``assert``) so that the
        # validation of the caller's argument is not stripped by ``python -O``;
        # the exception type and message are the historical ones.
        raise AssertionError(f"Task {task!r} not found in {sorted(reducers)}")
    res = reducers[task](config)
    if head_size_before and "head_dim" in res:
        # The head size is re-read from ``config`` itself, so the reducer is
        # presumably expected to update ``config`` in place -- TODO confirm.
        head_size = _head_size(config)
        # Internal sanity check on the reducer's output, hence a plain assert.
        assert head_size is not None and (
            head_size == head_size_before or head_size % 16 == 0
        ), (
            f"head_size should be a multiple of 16 "
            f"(head_size={head_size}, head_size0={head_size_before}), res={res}, "
            f"config=\n{config}"
        )
    return res