Source code for onnx_diagnostic.tasks.automatic_speech_recognition

from typing import Any, Callable, Dict, Optional, Tuple
import torch
import transformers
from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
from ..helpers.config_helper import update_config, check_hasattr

__TASK__ = "automatic-speech-recognition"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size."""
    kwargs: Dict[str, Any] = {}
    if hasattr(config, "num_decoder_layers"):
        config.num_decoder_layers = min(config.num_decoder_layers, 2)
    if hasattr(config, "decoder_layers"):
        config.decoder_layers = min(config.decoder_layers, 2)
    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, 2)
    update_config(config, kwargs)
    return kwargs
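
# A hedged usage note (illustrative only, not part of the library's tests):
# ``reduce_model_config`` mutates the configuration in place, e.g. with a
# locally built Whisper-style configuration::
#
#     cfg = transformers.WhisperConfig()
#     reduce_model_config(cfg)
#     assert cfg.decoder_layers <= 2
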
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    max_source_positions: int,
    d_model: int,
    num_hidden_layers: int,
    encoder_attention_heads: int,
    encoder_layers: int,
    decoder_layers: int,
    head_dim: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    add_second_input: bool = False,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``automatic-speech-recognition``.

    Example:

    ::

        dict(
            cache_position:T7s4,
            past_key_values:EncoderDecoderCache(
                self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
                cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
            ),
            decoder_input_ids:T7s1x4,
            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
            use_cache:bool,return_dict:bool
        )
        dict(
            cache_position:T7s1,
            past_key_values:EncoderDecoderCache(
                self_attention_cache=DynamicCache[serialized](#2[
                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
                ]),
                cross_attention_cache=DynamicCache[serialized](#2[
                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
                ]),
            ),
            decoder_input_ids:T7s1x1,
            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
            use_cache:bool,return_dict:bool
        )
    """
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"
    shapes = {
        "decoder_input_ids": {0: batch, 1: seq_length},
        "cache_position": {0: seq_length},
        "encoder_outputs": [{0: batch}],  # last_hidden_state
        "past_key_values": [
            [
                [{0: batch} for _ in range(num_hidden_layers)],
                [{0: batch} for _ in range(num_hidden_layers)],
            ],
            [
                [{0: batch} for _ in range(num_hidden_layers)],
                [{0: batch} for _ in range(num_hidden_layers)],
            ],
        ],
    }

    inputs = dict(
        decoder_input_ids=torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length)
        ).to(torch.int64),
        cache_position=(torch.arange(sequence_length) + 5).to(torch.int64),
        encoder_outputs=transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=torch.randn(batch_size, max_source_positions, d_model)
        ),
        past_key_values=make_encoder_decoder_cache(
            make_dynamic_cache(
                [
                    (
                        torch.randn(
                            batch_size, encoder_attention_heads, encoder_layers, head_dim
                        ),
                        torch.randn(
                            batch_size, encoder_attention_heads, encoder_layers, head_dim
                        ),
                    )
                    for i in range(num_hidden_layers)
                ]
            ),
            make_dynamic_cache(
                [
                    (
                        torch.randn(
                            batch_size,
                            encoder_attention_heads,
                            max_source_positions,
                            head_dim,
                        ),
                        torch.randn(
                            batch_size,
                            encoder_attention_heads,
                            max_source_positions,
                            head_dim,
                        ),
                    )
                    for i in range(num_hidden_layers)
                ]
            ),
        ),
        # one of these is selected based on the forward method signature
        # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
        # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            dummy_max_token_id=dummy_max_token_id,
            max_source_positions=max_source_positions,
            d_model=d_model,
            num_hidden_layers=num_hidden_layers,
            encoder_attention_heads=encoder_attention_heads,
            encoder_layers=encoder_layers,
            decoder_layers=decoder_layers,
            head_dim=head_dim,
            batch_size=batch_size + 1,
            sequence_length=sequence_length + 1,
            **kwargs,
        )["inputs"]
    return res
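
# A hedged sketch of how the returned dictionary is typically consumed; the
# export call below is an assumption for illustration (``model`` and
# ``input_kwargs`` are placeholders, e.g. the kwargs produced by
# ``random_input_kwargs`` defined below)::
#
#     res = get_inputs(model, config, **input_kwargs)
#     ep = torch.export.export(
#         model,
#         (),
#         kwargs=res["inputs"],
#         dynamic_shapes=res["dynamic_shapes"],
#     )
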
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(
            config,
            "d_model",
            "decoder_attention_heads",
            "decoder_layers",
            "encoder_attention_heads",
            "encoder_layers",
            "max_source_positions",
            "num_hidden_layers",
            "vocab_size",
        )
    kwargs = dict(
        batch_size=2,
        sequence_length=30,
        dummy_max_token_id=31000 if config is None else config.vocab_size,
        max_source_positions=1500 if config is None else config.max_source_positions,
        d_model=384 if config is None else config.d_model,
        num_hidden_layers=4 if config is None else config.num_hidden_layers,
        encoder_attention_heads=6 if config is None else config.encoder_attention_heads,
        encoder_layers=4 if config is None else config.encoder_layers,
        decoder_attention_heads=6 if config is None else config.decoder_attention_heads,
        decoder_layers=4 if config is None else config.decoder_layers,
        head_dim=(
            64 if config is None else (config.d_model // config.encoder_attention_heads)
        ),
    )
    return kwargs, get_inputs
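

if __name__ == "__main__":
    # A minimal, self-contained sketch (an illustration, not part of the
    # library's test suite): build a small Whisper-like configuration locally,
    # assuming it exposes the attributes checked by ``random_input_kwargs``,
    # then generate dummy inputs and dynamic shapes from it. The model is not
    # needed to build the inputs themselves, so None is passed here.
    cfg = transformers.WhisperConfig()
    input_kwargs, fct = random_input_kwargs(cfg)
    res = fct(model=None, config=cfg, **input_kwargs)
    print({k: type(v).__name__ for k, v in res["inputs"].items()})
    print(res["dynamic_shapes"])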