Source code for onnx_diagnostic.tasks.zero_shot_image_classification

from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import update_config, check_hasattr

__TASK__ = "zero-shot-image-classification"


def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
    """Reduces the size of a model."""
    check_hasattr(config, "vision_config", "text_config")
    check_hasattr(config.vision_config, "num_hidden_layers", "num_attention_heads")
    check_hasattr(config.text_config, "num_hidden_layers", "num_attention_heads")
    kwargs = dict(
        vision_config=dict(
            num_hidden_layers=min(2, config.vision_config.num_hidden_layers),
            num_attention_heads=min(2, config.vision_config.num_attention_heads),
        ),
        text_config=dict(
            num_hidden_layers=min(2, config.text_config.num_hidden_layers),
            num_attention_heads=min(2, config.text_config.num_attention_heads),
        ),
    )
    update_config(config, kwargs)
    return kwargs
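
# Illustrative usage (a sketch, not part of the original module):
# ``transformers.CLIPConfig`` is assumed to be available; any configuration
# exposing ``vision_config``/``text_config`` with the checked attributes
# would work the same way.
#
#     from transformers import CLIPConfig
#
#     cfg = CLIPConfig()
#     reduced = reduce_model_config(cfg, __TASK__)
#     # cfg is updated in place through ``update_config``:
#     # cfg.vision_config.num_hidden_layers == 2
#     # cfg.text_config.num_attention_heads == 2
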
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    batch_size: int = 2,
    sequence_length: int = 30,
    input_width: int = 224,
    input_height: int = 224,
    input_channels: int = 3,
    batch_size_image: int = 3,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``zero-shot-image-classification``.

    :param model: model to get the missing information from
    :param config: configuration used to generate the model
    :param dummy_max_token_id: vocabulary size
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param batch_size_image: number of images
    :param input_channels: number of input channels
    :param input_width: input width
    :param input_height: input height
    :return: dictionary

    # input_ids:T7s2x7
    # attention_mask:T7s2x7
    # pixel_values:T1s2x3x224x224
    """
    assert isinstance(
        input_width, int
    ), f"Unexpected type for input_width {type(input_width)}{config}"
    assert isinstance(
        input_height, int
    ), f"Unexpected type for input_height {type(input_height)}{config}"
    batch = torch.export.Dim("batch", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: seq_length},
        "pixel_values": {
            0: torch.export.Dim("batch_img", min=1, max=1024),
            # 2: torch.export.Dim("width", min=1, max=4096),
            # 3: torch.export.Dim("height", min=1, max=4096),
        },
    }
    inputs = dict(
        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
            torch.int64
        ),
        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
        pixel_values=torch.randn(
            batch_size_image, input_channels, input_width, input_height
        ).clamp(-1, 1),
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
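
# Illustrative usage (a sketch, not part of the original module): ``model`` is
# not read by the function body, so ``None`` stands in here; 49408 is the CLIP
# vocabulary size also used as the fallback in ``random_input_kwargs``.
#
#     res = get_inputs(model=None, config=None, dummy_max_token_id=49408)
#     inputs, dynamic_shapes = res["inputs"], res["dynamic_shapes"]
#     # inputs["input_ids"].shape == (2, 30), dtype int64
#     # inputs["pixel_values"].shape == (3, 3, 224, 224)
#     # dynamic_shapes["pixel_values"] only marks the image batch dimension
#     # as dynamic; width and height stay static (see the commented lines above).
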
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(config, "vision_config", "text_config")
        check_hasattr(config.vision_config, "image_size", "num_channels")
        check_hasattr(config.text_config, "vocab_size")
    kwargs = dict(
        batch_size=2,
        batch_size_image=3,
        sequence_length=30,
        dummy_max_token_id=(
            49408 if config is None else (config.text_config.vocab_size - 1)
        ),
        input_width=224 if config is None else config.vision_config.image_size,
        input_height=224 if config is None else config.vision_config.image_size,
        input_channels=3 if config is None else config.vision_config.num_channels,
    )
    return kwargs, get_inputs
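
# Illustrative usage (a sketch, not part of the original module): the returned
# kwargs are meant to be unpacked into the returned callable, here ``get_inputs``.
#
#     kwargs, fct = random_input_kwargs(None, __TASK__)
#     data = fct(None, None, **kwargs)
#     inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]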