Source code for onnx_diagnostic.tasks.zero_shot_image_classification
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import update_config, check_hasattr
__TASK__ = "zero-shot-image-classification"
[docs]
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
"""Reduces a model size."""
check_hasattr(config, "vision_config", "text_config")
check_hasattr(config.vision_config, "num_hidden_layers", "num_attention_heads")
check_hasattr(config.text_config, "num_hidden_layers", "num_attention_heads")
kwargs = dict(
vision_config=dict(
num_hidden_layers=min(2, config.vision_config.num_hidden_layers),
num_attention_heads=min(2, config.vision_config.num_attention_heads),
),
text_config=dict(
num_hidden_layers=min(2, config.text_config.num_hidden_layers),
num_attention_heads=min(2, config.text_config.num_attention_heads),
),
)
update_config(config, kwargs)
return kwargs
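

# A minimal usage sketch (hypothetical, not part of the library API): reduce a
# CLIP-style configuration, assuming ``transformers`` is installed and that the
# configuration exposes both ``vision_config`` and ``text_config``.
def _example_reduce_model_config():
    from transformers import CLIPConfig

    config = CLIPConfig()
    overrides = reduce_model_config(config, __TASK__)
    # ``config`` is updated in place; ``overrides`` records the applied values.
    assert config.vision_config.num_hidden_layers <= 2
    assert config.text_config.num_hidden_layers <= 2
    return overrides

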
def get_inputs(
model: torch.nn.Module,
config: Optional[Any],
dummy_max_token_id: int,
batch_size: int = 2,
sequence_length: int = 30,
input_width: int = 224,
input_height: int = 224,
input_channels: int = 3,
    batch_size_image: int = 3,
    **kwargs,  # unused
) -> Dict[str, Any]:
"""
    Generates inputs for task ``zero-shot-image-classification``.

:param model: model to get the missing information
:param config: configuration used to generate the model
:param dummy_max_token_id: vocabulary size
    :param batch_size: batch size for the text inputs
    :param sequence_length: sequence length
    :param input_width: input image width
    :param input_height: input image height
    :param input_channels: number of input channels
    :param batch_size_image: number of images (batch size for ``pixel_values``)
    :return: dictionary with keys ``inputs`` and ``dynamic_shapes``

    Expected tensors in the library's short notation
    (``T7`` = int64, ``T1`` = float32, ``s...`` gives the shape)::

        input_ids:T7s2x7
        attention_mask:T7s2x7
        pixel_values:T1s2x3x224x224
"""
    assert isinstance(
        input_width, int
    ), f"Unexpected type for input_width {type(input_width)}{config}"
    assert isinstance(
        input_height, int
    ), f"Unexpected type for input_height {type(input_height)}{config}"
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
shapes = {
"input_ids": {0: batch, 1: seq_length},
"attention_mask": {0: batch, 1: seq_length},
"pixel_values": {
0: torch.export.Dim("batch_img", min=1, max=1024),
# 2: torch.export.Dim("width", min=1, max=4096),
# 3: torch.export.Dim("height", min=1, max=4096),
},
}
inputs = dict(
input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
torch.int64
),
attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
        pixel_values=torch.randn(
            batch_size_image, input_channels, input_height, input_width
        ).clamp(-1, 1),
)
return dict(inputs=inputs, dynamic_shapes=shapes)
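

# Quick shape check (a sketch, not part of the library API): neither ``model``
# nor ``config`` is required to build the dummy inputs, so ``None`` can be
# passed for both. With the default dimensions this yields input_ids (2, 30),
# attention_mask (2, 30) and pixel_values (3, 3, 224, 224).
def _example_get_inputs():
    res = get_inputs(None, None, dummy_max_token_id=49408)
    return {k: tuple(v.shape) for k, v in res["inputs"].items()}

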
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
"""
    Builds the keyword arguments used to generate random inputs.
If the configuration is None, the function selects typical dimensions.
"""
if config is not None:
check_hasattr(config, "vision_config", "text_config")
check_hasattr(config.vision_config, "image_size", "num_channels")
check_hasattr(config.text_config, "vocab_size")
kwargs = dict(
batch_size=2,
batch_size_image=3,
sequence_length=30,
dummy_max_token_id=(49408 if config is None else (config.text_config.vocab_size - 1)),
input_width=224 if config is None else config.vision_config.image_size,
input_height=224 if config is None else config.vision_config.image_size,
input_channels=3 if config is None else config.vision_config.num_channels,
)
return kwargs, get_inputs
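

if __name__ == "__main__":
    # Small self-check (a sketch): with ``config=None`` the function falls back
    # to typical CLIP-like dimensions and returns the keyword arguments
    # together with the input generator (:func:`get_inputs`).
    kwargs, fct = random_input_kwargs(None, __TASK__)
    res = fct(None, None, **kwargs)
    for name, tensor in res["inputs"].items():
        print(name, tuple(tensor.shape), tensor.dtype)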