Source code for onnx_diagnostic.tasks.automatic_speech_recognition
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import transformers
from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from ..helpers.config_helper import update_config, check_hasattr
# Name of the task this module builds random inputs for; it is the same task
# named in the ``get_inputs`` docstring.
__TASK__ = "automatic-speech-recognition"
[docs]
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size by capping its layer counts.

    Each attribute among ``num_decoder_layers``, ``decoder_layers`` and
    ``num_hidden_layers`` that ``config`` defines is capped to 2 (smaller
    values are kept as they are). The capped values are collected into a
    dictionary, applied to ``config`` with :func:`update_config` and returned,
    so the caller knows which attributes were changed and to what.

    :param config: model configuration to reduce
    :return: dictionary ``{attribute name: new value}`` applied to ``config``
    """
    # Only attributes present on the configuration are touched; the previous
    # version mutated ``config`` directly and always returned an empty dict.
    kwargs: Dict[str, Any] = {
        name: min(getattr(config, name), 2)
        for name in ("num_decoder_layers", "decoder_layers", "num_hidden_layers")
        if hasattr(config, name)
    }
    update_config(config, kwargs)
    return kwargs
[docs]
def _random_kv_cache(
    n_layers: int, batch_size: int, n_heads: int, length: int, head_dim: int
):
    """
    Builds a cache with :func:`make_dynamic_cache` from ``n_layers`` random
    (key, value) pairs, every tensor having shape
    ``(batch_size, n_heads, length, head_dim)``.
    """
    return make_dynamic_cache(
        [
            (
                torch.randn(batch_size, n_heads, length, head_dim),
                torch.randn(batch_size, n_heads, length, head_dim),
            )
            for _ in range(n_layers)
        ]
    )


def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    max_source_positions: int,
    d_model: int,
    num_hidden_layers: int,
    encoder_attention_heads: int,
    encoder_layers: int,
    decoder_layers: int,
    head_dim: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``automatic-speech-recognition``.

    The encoder is bypassed: a random ``encoder_outputs`` is given directly,
    together with random decoder token ids, their cache positions and an
    encoder-decoder cache filled with random tensors. Both the self-attention
    and the cross-attention caches hold one (key, value) pair per decoder
    layer, i.e. ``decoder_layers`` pairs.

    :param model: not used, only forwarded when building the second input
    :param config: not used, only forwarded when building the second input
    :param dummy_max_token_id: exclusive upper bound of the random token ids
    :param max_source_positions: length of the encoder output and of the
        cross-attention cache
    :param d_model: last dimension of the encoder output
    :param num_hidden_layers: no longer used to size the caches (kept for
        backward compatibility); presumably the encoder depth for
        Whisper-like configurations, whereas the caches belong to the decoder
    :param encoder_attention_heads: number of heads of both caches
    :param encoder_layers: reused as the past length of the self-attention cache
    :param decoder_layers: number of (key, value) pairs in each cache
    :param head_dim: last dimension of every cache tensor
    :param batch_size: batch dimension of every generated tensor
    :param sequence_length: number of decoder tokens
    :param add_second_input: when positive, an ``inputs2`` entry is added,
        built with ``batch_size + 1`` and ``sequence_length + add_second_input``
    :param kwargs: ignored (``decoder_attention_heads`` may land here);
        ``cls_cache`` is not supported
    :return: dictionary with keys ``inputs``, ``dynamic_shapes`` and, when
        requested, ``inputs2``

    Example:
    ::
        dict(
            cache_position:T7s4,
            past_key_values:EncoderDecoderCache(
                self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
                cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
            ),
            decoder_input_ids:T7s1x4,
            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
            use_cache:bool,return_dict:bool
        )
        dict(
            cache_position:T7s1,
            past_key_values:EncoderDecoderCache(
                self_attention_cache=DynamicCache[serialized](#2[
                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
                ]),
                cross_attention_cache=DynamicCache[serialized](#2[
                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
                ]),
            ),
            decoder_input_ids:T7s1x1,
            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
            use_cache:bool,return_dict:bool
        )
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"
    shapes = {
        "decoder_input_ids": {0: batch, 1: seq_length},
        "cache_position": {0: seq_length},
        "encoder_outputs": [{0: batch}],  # last_hidden_state
        # [key shapes, value shapes] for the self-attention cache, then the
        # same for the cross-attention cache, one entry per decoder layer.
        "past_key_values": [
            [[{0: batch} for _ in range(decoder_layers)] for _kv in range(2)]
            for _cache in range(2)
        ],
    }
    inputs = dict(
        decoder_input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length), dtype=torch.int64
        ),
        # NOTE(review): positions start at 5 while the self-attention cache
        # past length is ``encoder_layers``; confirm this offset is intended.
        cache_position=torch.arange(sequence_length, dtype=torch.int64) + 5,
        encoder_outputs=transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=torch.randn(batch_size, max_source_positions, d_model)
        ),
        past_key_values=make_encoder_decoder_cache(
            # NOTE(review): ``encoder_layers`` is reused as a small past length
            # for the self-attention cache; presumably arbitrary, confirm.
            _random_kv_cache(
                decoder_layers, batch_size, encoder_attention_heads, encoder_layers, head_dim
            ),
            _random_kv_cache(
                decoder_layers,
                batch_size,
                encoder_attention_heads,
                max_source_positions,
                head_dim,
            ),
        ),
        # Depending on the forward signature, a model may expect the encoder
        # output under another name (e.g. ``encoder_last_hidden_state``);
        # only ``encoder_outputs`` is generated here.
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            dummy_max_token_id=dummy_max_token_id,
            max_source_positions=max_source_positions,
            d_model=d_model,
            num_hidden_layers=num_hidden_layers,
            encoder_attention_heads=encoder_attention_heads,
            encoder_layers=encoder_layers,
            decoder_layers=decoder_layers,
            head_dim=head_dim,
            batch_size=batch_size + 1,
            sequence_length=sequence_length + add_second_input,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res
[docs]
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Returns the keyword arguments expected by :func:`get_inputs` together
    with :func:`get_inputs` itself.

    When ``config`` is given, its required attributes are checked and the
    dimensions are read from it, ``head_dim`` being derived as
    ``d_model // encoder_attention_heads``. Without a configuration, typical
    dimensions are used instead.
    """
    if config is None:
        dims = dict(
            dummy_max_token_id=31000,
            max_source_positions=1500,
            d_model=384,
            num_hidden_layers=4,
            encoder_attention_heads=6,
            encoder_layers=4,
            decoder_attention_heads=6,
            decoder_layers=4,
            head_dim=64,
        )
    else:
        check_hasattr(
            config,
            "d_model",
            "decoder_attention_heads",
            "decoder_layers",
            "encoder_attention_heads",
            "encoder_layers",
            "max_source_positions",
            "num_hidden_layers",
            "vocab_size",
        )
        dims = dict(
            dummy_max_token_id=config.vocab_size,
            max_source_positions=config.max_source_positions,
            d_model=config.d_model,
            num_hidden_layers=config.num_hidden_layers,
            encoder_attention_heads=config.encoder_attention_heads,
            encoder_layers=config.encoder_layers,
            decoder_attention_heads=config.decoder_attention_heads,
            decoder_layers=config.decoder_layers,
            head_dim=config.d_model // config.encoder_attention_heads,
        )
    # batch_size and sequence_length come first to keep the original key order.
    kwargs = dict(batch_size=2, sequence_length=30, **dims)
    return kwargs, get_inputs