Source code for experimental_experiment.torch_models.diffusion_model_helper

from typing import Any, Dict, Tuple, Union
from . import assert_found


def get_stable_diffusion_2_unet(
    inputs_as_dict: bool = False,
    overwrite: bool = False,
    **kwargs,
) -> Tuple[Any, Union[Tuple[Any, ...], Dict[str, Any]]]:
    """
    Builds a ``UNet2DConditionModel`` from a configuration, without loading
    any pretrained weights, and returns it with a set of dummy inputs.

    :param inputs_as_dict: if True, the dummy inputs are returned as a
        dictionary keyed by argument name (``sample``, ``timestep``,
        ``encoder_hidden_states``), otherwise as a tuple in that same order
    :param overwrite: if True, the reference configuration is ignored and
        the model is built from *kwargs* only
    :param kwargs: values overwriting the configuration,
        example ``num_hidden_layers=1``
    :return: model, inputs

    See `StableDiffusion2Unet
    <https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/unet/config.json>`_.
    """
    # Heavy dependencies are imported lazily so that importing this module
    # does not require torch or diffusers.
    import torch
    from diffusers import UNet2DConditionModel

    config = {
        "_class_name": "UNet2DConditionModel",
        "_diffusers_version": "0.8.0",
        "_name_or_path": "hf-models/stable-diffusion-v2-768x768/unet",
        "act_fn": "silu",
        "attention_head_dim": [5, 10, 20, 20],
        "block_out_channels": [320, 640, 1280, 1280],
        "center_input_sample": False,
        "cross_attention_dim": 1024,
        "down_block_types": [
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ],
        "downsample_padding": 1,
        "dual_cross_attention": False,
        "flip_sin_to_cos": True,
        "freq_shift": 0,
        "in_channels": 4,
        "layers_per_block": 2,
        "mid_block_scale_factor": 1,
        "norm_eps": 1e-05,
        "norm_num_groups": 32,
        "out_channels": 4,
        "sample_size": 96,
        "up_block_types": [
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ],
        "use_linear_projection": True,
    }

    if overwrite:
        # The caller takes full responsibility for the configuration.
        config = kwargs
    else:
        # NOTE(review): assert_found is defined elsewhere in the package;
        # presumably it checks every key of kwargs is known to config.
        assert_found(kwargs, config)
        config.update(**kwargs)

    model = UNet2DConditionModel(**config)
    model.eval()

    # The last dimension of encoder_hidden_states is 1024, the
    # cross_attention_dim of the reference configuration, or 32 when the
    # configuration is overwritten.
    # NOTE(review): 32 presumably has to match the cross_attention_dim
    # given in kwargs when overwrite is True; verify against callers.
    inputs = dict(
        sample=torch.randn(1, 4, 128, 128),
        timestep=torch.tensor([1.0]),
        encoder_hidden_states=torch.randn(1, 1, 32 if overwrite else 1024),
    )
    if not inputs_as_dict:
        # dict preserves insertion order: sample, timestep,
        # encoder_hidden_states.
        inputs = tuple(inputs.values())
    return model, inputs