Source code for onnx_diagnostic.tasks.image_text_to_text
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
from ..helpers.config_helper import update_config, check_hasattr, _pick
__TASK__ = "image-text-to-text"
def reduce_model_config(config: Any) -> Dict[str, Any]:
"""Reduces a model size."""
kwargs: Dict[str, Any] = {}
if hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = min(config.num_hidden_layers, 2)
if hasattr(config, "mm_tokens_per_image"):
config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
if hasattr(config, "vision_config"):
if hasattr(config.vision_config, "num_hidden_layers"):
config.vision_config.num_hidden_layers = min(
config.vision_config.num_hidden_layers, 2
)
if hasattr(config.vision_config, "image_size"):
config.vision_config.image_size = min(config.vision_config.image_size, 96)
if hasattr(config.vision_config, "intermediate_size"):
config.vision_config.intermediate_size = min(
config.vision_config.intermediate_size, 1076
)
if hasattr(config.vision_config, "patch_size"):
config.vision_config.patch_size = min(config.vision_config.patch_size, 2)
if hasattr(config.vision_config, "hidden_size"):
config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
if hasattr(config, "text_config"):
if hasattr(config.text_config, "intermediate_size"):
config.text_config.intermediate_size = min(
config.text_config.intermediate_size, 320
)
if hasattr(config.text_config, "hidden_size"):
config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
if hasattr(config.text_config, "num_hidden_layers"):
config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
if hasattr(config.text_config, "layer_types"):
config.text_config.layer_types = config.text_config.layer_types[
: config.text_config.num_hidden_layers
]
if hasattr(config.text_config, "num_attention_heads"):
config.text_config.num_attention_heads = min(
config.text_config.num_attention_heads, 2
)
update_config(config, kwargs)
return kwargs
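# Illustrative usage sketch (an assumption, not part of the original module):
# ``reduce_model_config`` mutates a ``transformers`` configuration in place so that a
# tiny test model can be built from it. The checkpoint id below is a hypothetical
# placeholder.
#
#     from transformers import AutoConfig
#     config = AutoConfig.from_pretrained("<model-id>")
#     reduce_model_config(config)  # caps layers, hidden sizes, image size, ...
#     # ... then instantiate a small model from the reduced configuration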
def _get_inputs_gemma3(
model: torch.nn.Module,
config: Optional[Any],
dummy_max_token_id: int,
num_key_value_heads: int,
num_hidden_layers: int,
pad_token_id: int,
image_token_index: int,
head_dim: int,
width: int,
height: int,
num_channels: int,
batch_size: int = 2,
sequence_length: int = 43,
sequence_length2: int = 43,
n_images: int = 2,
dynamic_rope: bool = False,
max_sequence_length: int = 380,
**kwargs, # unused
):
"""
::
dict(input_ids:T7s1x281,
pixel_values:T16s1x3x896x896,
attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
position_ids:T7s1x281,
past_key_values:HybridCache(
key_cache=#34[T1s1x4x380x256,...],
value_cache=#34[T1s1x4x380x256,...]),
token_type_ids:T7s1x281,
cache_position:T7s281,
logits_to_keep:1)
dict(input_ids:T7s1x1,
pixel_values:None,
attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
position_ids:T7s1x1,
past_key_values:HybridCache(
key_cache=#34[T1s1x4x380x256,...],
value_cache=#34[T1s1x4x380x256,...]),
token_type_ids:T7s1x1,
cache_position:T7s1,
logits_to_keep:1)
"""
assert (
"cls_cache" not in kwargs
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
# cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
shapes = {
"input_ids": {0: batch, 1: seq_length},
"token_type_ids": {0: batch, 1: seq_length},
"attention_mask": {
"full_attention": {0: batch, 2: seq_length},
"sliding_attention": {0: batch, 2: seq_length},
},
"position_ids": {0: batch, 1: seq_length},
"cache_position": {1: seq_length},
"past_key_values": [
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch} for _ in range(num_hidden_layers)],
],
"pixel_values": {0: batch},
"use_cache": None,
}
input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
torch.int64
)
input_ids[:, 1] = image_token_index
# input_ids[input_ids == image_token_index] = pad_token_id
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == image_token_index] = 1
inputs = dict(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=dict(
full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
),
cache_position=torch.arange(0, sequence_length).to(torch.int64),
position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
past_key_values=make_hybrid_cache(
[
(
torch.randn(
batch_size, num_key_value_heads, max_sequence_length, head_dim
),
torch.randn(
batch_size, num_key_value_heads, max_sequence_length, head_dim
),
)
for i in range(num_hidden_layers)
]
),
pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
torch.int64
),
use_cache=True, # Gemma3 does not set this value to true when a cache is provided
)
return dict(inputs=inputs, dynamic_shapes=shapes)
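# Illustrative sketch (an assumption, not part of the original module): the cache built
# above receives one ``(key, value)`` pair per layer, each of shape
# ``(batch_size, num_key_value_heads, max_sequence_length, head_dim)``, matching the
# ``HybridCache(key_cache=#N[T1s1x4x380x256, ...])`` layout shown in the docstring.
#
#     cache = make_hybrid_cache(
#         [(torch.randn(1, 4, 380, 256), torch.randn(1, 4, 380, 256))
#          for _ in range(num_hidden_layers)]  # hypothetical layer count
#     )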
def get_inputs(
model: torch.nn.Module,
config: Optional[Any],
dummy_max_token_id: int,
num_key_value_heads: int,
num_hidden_layers: int,
pad_token_id: int,
image_token_index: int,
head_dim: int,
width: int,
height: int,
num_channels: int,
batch_size: int = 2,
sequence_length: int = 43,
sequence_length2: int = 43,
n_images: int = 2,
dynamic_rope: bool = False,
add_second_input: int = 1,
**kwargs, # unused
):
"""
Generates inputs for task ``image-text-to-text``.
:param model: model used to get the missing information
:param config: configuration used to generate the model
:param dummy_max_token_id: dummy max token id (upper bound for the random token ids)
:param num_key_value_heads: number of key/value heads in the cache
:param num_hidden_layers: number of hidden layers (number of cache entries)
:param pad_token_id: pad_token_id
:param image_token_index: image_token_index
:param head_dim: last dimension of the cache
:param batch_size: batch size
:param sequence_length: sequence length
:param sequence_length2: new sequence length
:param n_images: number of images
:param width: width of the image
:param height: height of the image
:param num_channels: number of channels
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
:param add_second_input: if positive, adds a second set of inputs ``inputs2``
    with different dimensions (the sequence length is increased by this value)
:return: dictionary with ``inputs`` and ``dynamic_shapes``
    (and ``inputs2`` when ``add_second_input`` is set)
"""
if model.__class__.__name__.startswith("Gemma3"):
res = _get_inputs_gemma3(
model,
config,
dummy_max_token_id=dummy_max_token_id,
num_key_value_heads=num_key_value_heads,
num_hidden_layers=num_hidden_layers,
pad_token_id=pad_token_id,
image_token_index=image_token_index,
head_dim=head_dim,
width=width,
height=height,
num_channels=num_channels,
batch_size=batch_size,
sequence_length=sequence_length,
sequence_length2=sequence_length2,
n_images=n_images,
dynamic_rope=dynamic_rope,
**kwargs,
)
else:
assert (
"cls_cache" not in kwargs
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
batch = torch.export.Dim("batch", min=1, max=1024)
batch_img = torch.export.Dim("batch_img", min=1, max=1024)
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
images = "images" # torch.export.Dim("images", min=1, max=4096)
shapes = {
"input_ids": {0: batch, 1: seq_length},
"token_type_ids": {0: batch, 1: seq_length},
"attention_mask": {0: batch, 1: "cache+seq"},
"position_ids": {0: batch, 1: "cache+seq"},
"past_key_values": [
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
],
"pixel_values": (
{0: batch, 1: images}
if model.__class__.__name__ == "IdeficsForVisionText2Text"
else {0: batch_img}
),
"image_attention_mask": {0: batch, 1: seq_length, 2: images},
"use_cache": None,
}
input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
torch.int64
)
input_ids[0, 0] = image_token_index
input_ids[1, 1] = image_token_index
# input_ids[input_ids == image_token_index] = pad_token_id
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == image_token_index] = 1
inputs = dict(
input_ids=input_ids,
attention_mask=torch.cat(
[
torch.ones((batch_size, sequence_length), dtype=torch.int64),
input_ids.ne(pad_token_id).to(torch.int64),
],
axis=-1,
),
position_ids=torch.arange(0, sequence_length2)
.to(torch.int64)
.expand((batch_size, -1)),
past_key_values=make_dynamic_cache(
[
(
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
),
torch.randn(
batch_size, num_key_value_heads, sequence_length, head_dim
),
)
for i in range(num_hidden_layers)
]
),
pixel_values=(
torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
if model.__class__.__name__ == "IdeficsForVisionText2Text"
else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
),
image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
torch.int64
),
token_type_ids=token_type_ids,
use_cache=True, # Gemma3 does not set this value to true when a cache is provided
)
res = dict(inputs=inputs, dynamic_shapes=shapes)
if add_second_input:
assert (
add_second_input > 0
), f"Not implemented for add_second_input={add_second_input}."
res["inputs2"] = get_inputs(
model=model,
config=config,
dummy_max_token_id=dummy_max_token_id,
num_key_value_heads=num_key_value_heads,
num_hidden_layers=num_hidden_layers,
head_dim=head_dim,
width=width,
height=height,
num_channels=num_channels,
batch_size=batch_size + 1,
sequence_length=sequence_length + add_second_input,
sequence_length2=sequence_length2 + 1,
n_images=n_images + 1,
dynamic_rope=dynamic_rope,
pad_token_id=pad_token_id,
image_token_index=image_token_index,
add_second_input=0,
**kwargs,
)["inputs"]
return res
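# Illustrative sketch (an assumption, not part of the original module): the returned
# dictionary is designed to feed :func:`torch.export.export`, the ``inputs`` entry as
# keyword arguments and ``dynamic_shapes`` as the matching shape specification.
#
#     res = get_inputs(model, config, **kwargs)  # kwargs from random_input_kwargs
#     ep = torch.export.export(
#         model, (), kwargs=res["inputs"], dynamic_shapes=res["dynamic_shapes"]
#     )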
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
"""
Returns the keyword arguments needed to generate dummy inputs for this task,
together with the function producing them (:func:`get_inputs`).
If the configuration is None, the function selects typical dimensions.
"""
if config is not None:
if hasattr(config, "text_config"):
check_hasattr(
config.text_config,
"vocab_size",
"hidden_size",
"num_attention_heads",
("num_key_value_heads", "num_attention_heads"),
"intermediate_size",
"hidden_size",
"pad_token_id",
)
check_hasattr(config, "vision_config", "image_token_index")
text_config = True
else:
check_hasattr(
config,
"vocab_size",
"hidden_size",
"num_attention_heads",
("num_key_value_heads", "num_attention_heads"),
"intermediate_size",
"hidden_size",
"vision_config",
)
text_config = False
check_hasattr(config.vision_config, "image_size", "num_channels")
kwargs = dict(
batch_size=2,
sequence_length=43,
sequence_length2=43,
head_dim=(
16
if config is None
else getattr(
config,
"head_dim",
(
config.text_config.head_dim
if text_config and hasattr(config.text_config, "head_dim")
else (
(config.text_config.hidden_size if text_config else config.hidden_size)
// (
config.text_config.num_attention_heads
if text_config
else config.num_attention_heads
)
)
),
)
),
dummy_max_token_id=(
31999
if config is None
else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
),
num_hidden_layers=(
4
if config is None
else (
config.text_config.num_hidden_layers
if text_config
else config.num_hidden_layers
)
),
num_key_value_heads=(
8
if config is None
else (
_pick(config.text_config, "num_key_value_heads", "num_attention_heads")
if text_config
else _pick(config, "num_key_value_heads", "num_attention_heads")
)
),
intermediate_size=(
1024
if config is None
else (
config.text_config.intermediate_size
if text_config
else config.intermediate_size
)
),
hidden_size=(
512
if config is None
else (config.text_config.hidden_size if text_config else config.hidden_size)
),
width=224 if config is None else config.vision_config.image_size,
height=224 if config is None else config.vision_config.image_size,
num_channels=3 if config is None else config.vision_config.num_channels,
pad_token_id=(
0
if config is None or not hasattr(config, "text_config")
else config.text_config.pad_token_id
),
image_token_index=(
4
if config is None or not hasattr(config, "image_token_index")
else config.image_token_index
),
)
return kwargs, get_inputs
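# Illustrative, runnable sketch (an assumption about usage, not part of the original
# module): with ``config=None`` the helper falls back to typical dimensions, and any
# model whose class name does not start with ``Gemma3`` takes the generic branch of
# ``get_inputs``.
if __name__ == "__main__":
    kwargs, fct = random_input_kwargs(None)
    res = fct(model=torch.nn.Module(), config=None, **kwargs)
    print(sorted(res["inputs"]))          # input_ids, attention_mask, pixel_values, ...
    print(sorted(res["dynamic_shapes"]))  # matching dynamic shape specification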