Source code for onnx_diagnostic.tasks.image_text_to_text

from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
from ..helpers.config_helper import update_config, check_hasattr, _pick

# Name of the task this module handles; ``get_inputs`` generates dummy inputs for it.
__TASK__ = "image-text-to-text"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size.

    The configuration is modified **in place**: every attribute listed below
    is capped to a small value when (and only when) it exists.

    * ``config``: ``num_hidden_layers``, ``mm_tokens_per_image``
    * ``config.vision_config``: ``num_hidden_layers``, ``image_size``,
      ``intermediate_size``, ``patch_size``, ``hidden_size``
    * ``config.text_config``: ``intermediate_size``, ``hidden_size``,
      ``num_hidden_layers``, ``num_attention_heads``; ``layer_types`` is
      truncated to the (already capped) number of hidden layers.

    :param config: configuration to shrink
    :return: the dictionary handed to ``update_config``, it is always empty
        because every change is written directly on ``config``
    """

    def _cap(owner: Any, limits: Tuple[Tuple[str, int], ...]) -> None:
        # Lowers each existing attribute to its ceiling, missing ones are skipped.
        for attr, ceiling in limits:
            if hasattr(owner, attr):
                setattr(owner, attr, min(getattr(owner, attr), ceiling))

    _cap(config, (("num_hidden_layers", 2), ("mm_tokens_per_image", 2)))

    if hasattr(config, "vision_config"):
        _cap(
            config.vision_config,
            (
                ("num_hidden_layers", 2),
                ("image_size", 96),
                ("intermediate_size", 1076),
                ("patch_size", 2),
                ("hidden_size", 16),
            ),
        )

    if hasattr(config, "text_config"):
        text = config.text_config
        _cap(
            text,
            (
                ("intermediate_size", 320),
                ("hidden_size", 16),
                ("num_hidden_layers", 2),
            ),
        )
        # Must run after num_hidden_layers was capped: one layer type per layer.
        if hasattr(text, "layer_types"):
            text.layer_types = text.layer_types[: text.num_hidden_layers]
        _cap(text, (("num_attention_heads", 2),))

    overrides: Dict[str, Any] = {}
    update_config(config, overrides)
    return overrides
def _get_inputs_gemma3(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: int = 2,
    sequence_length: int = 43,
    sequence_length2: int = 43,
    n_images: int = 2,
    dynamic_rope: bool = False,
    max_sequence_length: int = 380,
    **kwargs,  # unused
):
    """
    Generates dummy inputs and dynamic shapes for ``Gemma3`` models.

    The cache is built with ``make_hybrid_cache`` and every key/value tensor
    has shape ``(batch_size, num_key_value_heads, max_sequence_length, head_dim)``.
    The attention mask is a dictionary of two random 4D tensors
    (``full_attention``, ``sliding_attention``) of shape
    ``(batch_size, 1, sequence_length, max_sequence_length)``.

    Reference layout of the inputs:

    ::

        dict(input_ids:T7s1x281,
            pixel_values:T16s1x3x896x896,
            attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
            position_ids:T7s1x281,
            past_key_values:HybridCache(
                key_cache=#34[T1s1x4x380x256,...],
                value_cache=#34[T1s1x4x380x256,...]),
            token_type_ids:T7s1x281,
            cache_position:T7s281,
            logits_to_keep:1)
        dict(input_ids:T7s1x1,
            pixel_values:None,
            attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
            position_ids:T7s1x1,
            past_key_values:HybridCache(
                key_cache=#34[T1s1x4x380x256,...],
                value_cache=#34[T1s1x4x380x256,...]),
            token_type_ids:T7s1x1,
            cache_position:T7s1,
            logits_to_keep:1)

    .. note::
        ``input_ids``, ``token_type_ids`` and ``image_attention_mask`` are sized
        with ``sequence_length2`` whereas ``attention_mask``, ``position_ids``
        and ``cache_position`` are sized with ``sequence_length``. They all share
        the dynamic dimension ``seq_length``, both values are expected to be
        equal (they both default to 43).

    :param model: unused, kept to mirror the signature of ``get_inputs``
    :param config: unused, kept to mirror the signature of ``get_inputs``
    :param dummy_max_token_id: exclusive upper bound of the random token ids
    :param num_key_value_heads: second dimension of every cache tensor
    :param num_hidden_layers: number of (key, value) pairs in the cache
    :param pad_token_id: unused, kept to mirror the signature of ``get_inputs``
    :param image_token_index: token id written in column 1 of ``input_ids``
    :param head_dim: last dimension of the cache
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param n_images: number of images
    :param dynamic_rope: unused, kept to mirror the signature of ``get_inputs``
    :param max_sequence_length: length of the cache and last dimension of the
        attention masks
    :return: dictionary with keys ``inputs`` and ``dynamic_shapes``
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = torch.export.Dim("batch", min=1, max=1024)
    # Plain strings name the remaining dynamic dimensions instead of
    # torch.export.Dim objects.
    seq_length = "seq_length"
    images = "images"

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            "full_attention": {0: batch, 2: seq_length},
            "sliding_attention": {0: batch, 2: seq_length},
        },
        "position_ids": {0: batch, 1: seq_length},
        # cache_position is a 1D tensor, its only axis is axis 0.
        "cache_position": {0: seq_length},
        # The cache has a fixed length (max_sequence_length), only the batch moves.
        "past_key_values": [
            [{0: batch} for _ in range(num_hidden_layers)],
            [{0: batch} for _ in range(num_hidden_layers)],
        ],
        "pixel_values": {0: batch},
        # Every key of ``inputs`` needs an entry, this one was missing.
        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
        "use_cache": None,
    }

    input_ids = torch.randint(
        0, dummy_max_token_id, (batch_size, sequence_length2), dtype=torch.int64
    )
    input_ids[:, 1] = image_token_index
    # token_type_ids flags the positions holding an image token.
    token_type_ids = torch.zeros_like(input_ids)
    token_type_ids[input_ids == image_token_index] = 1

    inputs = dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=dict(
            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
        ),
        cache_position=torch.arange(0, sequence_length).to(torch.int64),
        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
        past_key_values=make_hybrid_cache(
            [
                (
                    torch.randn(
                        batch_size, num_key_value_heads, max_sequence_length, head_dim
                    ),
                    torch.randn(
                        batch_size, num_key_value_heads, max_sequence_length, head_dim
                    ),
                )
                for _ in range(num_hidden_layers)
            ]
        ),
        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
            torch.int64
        ),
        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: int = 2,
    sequence_length: int = 43,
    sequence_length2: int = 43,
    n_images: int = 2,
    dynamic_rope: bool = False,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates input for task ``image-text-to-text``.

    Models whose class name starts with ``Gemma3`` are delegated to
    ``_get_inputs_gemma3``. Any other model receives a dynamic cache of length
    ``sequence_length`` followed by ``sequence_length2`` new tokens.

    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param head_dim: last dimension of the cache
    :param dummy_max_token_id: dummy max token id
    :param num_key_value_heads: second dimension of every cache tensor
    :param num_hidden_layers: number of (key, value) pairs in the cache
    :param pad_token_id: pad_token_id
    :param image_token_index: image_token_index
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param sequence_length2: new sequence length
    :param n_images: number of images
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
    :param add_second_input: if not null, a second set of inputs is generated
        with ``batch_size + 1``, ``sequence_length + add_second_input``,
        ``sequence_length2 + 1``, ``n_images + 1`` and stored under key
        ``inputs2``, negative values are not supported
    :return: dictionary with keys ``inputs``, ``dynamic_shapes`` and, if
        requested, ``inputs2``
    """
    class_name = model.__class__.__name__
    if class_name.startswith("Gemma3"):
        res = _get_inputs_gemma3(
            model,
            config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads=num_key_value_heads,
            num_hidden_layers=num_hidden_layers,
            pad_token_id=pad_token_id,
            image_token_index=image_token_index,
            head_dim=head_dim,
            width=width,
            height=height,
            num_channels=num_channels,
            batch_size=batch_size,
            sequence_length=sequence_length,
            sequence_length2=sequence_length2,
            n_images=n_images,
            dynamic_rope=dynamic_rope,
            **kwargs,
        )
    else:
        assert (
            "cls_cache" not in kwargs
        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
        # Idefics expects one group of images per sample (5D pixel_values).
        is_idefics = class_name == "IdeficsForVisionText2Text"

        batch = torch.export.Dim("batch", min=1, max=1024)
        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
        # Plain strings name the remaining dynamic dimensions instead of
        # torch.export.Dim objects.
        seq_length = "seq_length"
        cache_length = "cache_length"
        images = "images"

        shapes = {
            "input_ids": {0: batch, 1: seq_length},
            "token_type_ids": {0: batch, 1: seq_length},
            "attention_mask": {0: batch, 1: "cache+seq"},
            # position_ids only covers the new tokens, same length as input_ids.
            "position_ids": {0: batch, 1: seq_length},
            # Keys and values share the same shape: axis 2 (the cache length)
            # is dynamic for both of them.
            "past_key_values": [
                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
            ],
            "pixel_values": ({0: batch, 1: images} if is_idefics else {0: batch_img}),
            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
            "use_cache": None,
        }

        input_ids = torch.randint(
            0, dummy_max_token_id, (batch_size, sequence_length2), dtype=torch.int64
        )
        # One image token in each of the first two rows, on the diagonal
        # (positions [0, 0] and [1, 1]). Indices are bounded so that
        # batch_size == 1 or sequence_length2 == 1 does not raise IndexError.
        for row in range(min(batch_size, 2)):
            input_ids[row, min(row, sequence_length2 - 1)] = image_token_index
        # token_type_ids flags the positions holding an image token.
        token_type_ids = torch.zeros_like(input_ids)
        token_type_ids[input_ids == image_token_index] = 1

        inputs = dict(
            input_ids=input_ids,
            # Cached positions are all attended, new positions unless padded.
            attention_mask=torch.cat(
                [
                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
                    input_ids.ne(pad_token_id).to(torch.int64),
                ],
                dim=-1,
            ),
            position_ids=torch.arange(0, sequence_length2)
            .to(torch.int64)
            .expand((batch_size, -1)),
            past_key_values=make_dynamic_cache(
                [
                    (
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length, head_dim
                        ),
                        torch.randn(
                            batch_size, num_key_value_heads, sequence_length, head_dim
                        ),
                    )
                    for _ in range(num_hidden_layers)
                ]
            ),
            pixel_values=(
                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
                if is_idefics
                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
            ),
            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
                torch.int64
            ),
            token_type_ids=token_type_ids,
            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
        )
        res = dict(inputs=inputs, dynamic_shapes=shapes)

    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        # Same inputs with every dynamic dimension changed,
        # add_second_input=0 stops the recursion.
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads=num_key_value_heads,
            num_hidden_layers=num_hidden_layers,
            head_dim=head_dim,
            width=width,
            height=height,
            num_channels=num_channels,
            batch_size=batch_size + 1,
            sequence_length=sequence_length + add_second_input,
            sequence_length2=sequence_length2 + 1,
            n_images=n_images + 1,
            dynamic_rope=dynamic_rope,
            pad_token_id=pad_token_id,
            image_token_index=image_token_index,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    Otherwise the dimensions are read from ``config.text_config`` when it
    exists, from ``config`` itself if not, and the image dimensions come
    from ``config.vision_config``.

    :param config: model configuration or None
    :return: the keyword arguments and the function ``get_inputs``
        which consumes them
    """
    if config is None:
        defaults = dict(
            batch_size=2,
            sequence_length=43,
            sequence_length2=43,
            head_dim=16,
            dummy_max_token_id=31999,
            num_hidden_layers=4,
            num_key_value_heads=8,
            intermediate_size=1024,
            hidden_size=512,
            width=224,
            height=224,
            num_channels=3,
            pad_token_id=0,
            image_token_index=4,
        )
        return defaults, get_inputs

    has_text = hasattr(config, "text_config")
    if has_text:
        check_hasattr(
            config.text_config,
            "vocab_size",
            "hidden_size",
            "num_attention_heads",
            ("num_key_value_heads", "num_attention_heads"),
            "intermediate_size",
            "hidden_size",
            "pad_token_id",
        )
        check_hasattr(config, "vision_config", "image_token_index")
    else:
        check_hasattr(
            config,
            "vocab_size",
            "hidden_size",
            "num_attention_heads",
            ("num_key_value_heads", "num_attention_heads"),
            "intermediate_size",
            "hidden_size",
            "vision_config",
        )
    check_hasattr(config.vision_config, "image_size", "num_channels")

    # Object holding the text model dimensions.
    source = config.text_config if has_text else config
    vision = config.vision_config

    # Used when config has no head_dim: text_config.head_dim if available,
    # hidden_size // num_attention_heads if not.
    if has_text and hasattr(config.text_config, "head_dim"):
        fallback_head_dim = config.text_config.head_dim
    else:
        fallback_head_dim = source.hidden_size // source.num_attention_heads

    kwargs = dict(
        batch_size=2,
        sequence_length=43,
        sequence_length2=43,
        head_dim=getattr(config, "head_dim", fallback_head_dim),
        dummy_max_token_id=source.vocab_size - 1,
        num_hidden_layers=source.num_hidden_layers,
        num_key_value_heads=_pick(source, "num_key_value_heads", "num_attention_heads"),
        intermediate_size=source.intermediate_size,
        hidden_size=source.hidden_size,
        width=vision.image_size,
        height=vision.image_size,
        num_channels=vision.num_channels,
        pad_token_id=config.text_config.pad_token_id if has_text else 0,
        image_token_index=(
            config.image_token_index if hasattr(config, "image_token_index") else 4
        ),
    )
    return kwargs, get_inputs