Source code for onnx_diagnostic.tasks.text_to_image

from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import update_config, check_hasattr, pick

__TASK__ = "text-to-image"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces a model size."""
    check_hasattr(config, "sample_size", "cross_attention_dim")
    kwargs = dict(
        sample_size=min(config["sample_size"], 32),
        cross_attention_dim=min(config["cross_attention_dim"], 64),
    )
    update_config(config, kwargs)
    return kwargs
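
# Hedged usage sketch, not part of the original module. ``_TinyConfig`` is a
# hypothetical stand-in config: the sketch assumes ``check_hasattr`` only needs
# attribute access, the function body reads values by subscript, and
# ``update_config`` writes the reduced values back as plain attributes.
class _TinyConfig:
    sample_size = 96
    cross_attention_dim = 1024

    def __getitem__(self, name):
        return getattr(self, name)


_reduced = reduce_model_config(_TinyConfig())
# _reduced == {"sample_size": 32, "cross_attention_dim": 64}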
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    batch_size: int,
    sequence_length: int,
    cache_length: int,
    in_channels: int,
    sample_size: int,
    cross_attention_dim: int,
    add_second_input: bool = False,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``text-to-image``.

    Example:

    ::

        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
        timestep:T7s=101
        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = "batch"
    shapes = {
        "sample": {0: batch},
        "timestep": {},
        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
    }
    inputs = dict(
        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
            torch.float32
        ),
        timestep=torch.tensor([101], dtype=torch.int64),
        encoder_hidden_states=torch.randn(
            (batch_size, sequence_length, cross_attention_dim)
        ).to(torch.float32),
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            batch_size=batch_size + 1,
            sequence_length=sequence_length,
            cache_length=cache_length + 1,
            in_channels=in_channels,
            sample_size=sample_size,
            cross_attention_dim=cross_attention_dim,
            **kwargs,
        )["inputs"]
    return res
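
# Hedged usage sketch, not part of the original module: ``model`` is only
# forwarded to the recursive call, so a placeholder ``torch.nn.Identity()``
# is enough to illustrate the returned structure.
_example = get_inputs(
    model=torch.nn.Identity(),
    config=None,
    batch_size=2,
    sequence_length=4,
    cache_length=77,
    in_channels=4,
    sample_size=96,
    cross_attention_dim=1024,
    add_second_input=True,
)
# _example["inputs"]["sample"].shape == (2, 4, 96, 96)
# _example["inputs2"]["sample"].shape == (3, 4, 96, 96), batch_size incremented
# _example["dynamic_shapes"]["encoder_hidden_states"] == {0: "batch", 1: "encoder_length"}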
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
    kwargs = dict(
        batch_size=2,
        sequence_length=pick(config, "in_channels", 4),
        cache_length=77,
        in_channels=pick(config, "in_channels", 4),
        sample_size=pick(config, "sample_size", 32),
        cross_attention_dim=pick(config, "cross_attention_dim", 64),
    )
    return kwargs, get_inputs
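
# Hedged usage sketch, not part of the original module: with ``config=None``
# each ``pick`` call is expected to fall back to its default, so the generated
# inputs use the typical dimensions mentioned in the docstring.
_kwargs, _fn = random_input_kwargs(None)
_res = _fn(model=torch.nn.Identity(), config=None, **_kwargs)
# _res["inputs"]["sample"].shape == (2, 4, 32, 32)
# _res["inputs"]["encoder_hidden_states"].shape == (2, 4, 64)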