Source code for onnx_diagnostic.ci_models.export_qwen25_vl

r"""
Export visual and embedding parts of Qwen/Qwen2.5-VL-7B-Instruct
================================================================

Requirements
++++++++++++

::

    git+https://github.com/sdpython/experimental-experiment.git  # optional
    huggingface_hub>=1.2.1
    onnx-diagnostic>=0.8.6
    onnxruntime>=1.23
    torch>=2.9  # weekly is better
    tqdm
    transformers>=4.57
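
The packages (excluding the optional first line) can be installed with pip,
for instance:

.. code-block:: bash

    pip install "huggingface_hub>=1.2.1" "onnx-diagnostic>=0.8.6" \
        "onnxruntime>=1.23" "torch>=2.9" tqdm "transformers>=4.57"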

Examples
++++++++

.. code-block:: bash

    python -m onnx_diagnostic.ci_models.export_qwen25_vl \
        -m Qwen/Qwen2.5-VL-7B-Instruct \
        --device cpu --dtype float32 --exporter onnx-dynamo --pretrained --second-input --zip
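
The same export can also be triggered from Python by calling the function
``main`` defined in this module. A minimal sketch mirroring the command line
above:

.. code-block:: python

    from onnx_diagnostic.ci_models.export_qwen25_vl import main

    main(
        model_id="Qwen/Qwen2.5-VL-7B-Instruct",
        device="cpu",
        dtype="float32",
        exporter="onnx-dynamo",
        pretrained=True,
        second_input=True,
        make_zip=True,
    )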

To choose a specific Attention schema:

.. code-block:: bash

    QWEN25ATTENTION=LOOPMHA python -m onnx_diagnostic.ci_models.export_qwen25_vl \
        -m Qwen/Qwen2.5-VL-7B-Instruct \
        --device cpu --dtype float32 --exporter onnx-dynamo --pretrained --second-input --zip

Cheat sheet for tar commands. To create an archive:
``tar -czvf model.tar.gz model.onnx model.data``
and to extract it:
``tar -xzvf model.tar.gz``.

Rewritings
++++++++++

* `overview <https://sdpython.github.io/doc/onnx-diagnostic/dev/status/patches_diff.html#auto-patch-transformers-qwen2-5-vlforconditionalgeneration-prepare-inputs-for-generation-patched-qwen2-5-vlforconditionalgeneration-prepare-inputs-for-generation>`_
* code: `_patch_transformers_qwen2_5.py <https://github.com/sdpython/onnx-diagnostic/blob/main/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py>`_

Attention
+++++++++

The attention is implemented either with ``MultiHeadAttention`` in a loop
or with ``PackedMultiHeadAttention``. The choice is made based on the device.
It is possible to override this choice by setting the environment variable
``QWEN25ATTENTION`` to one of the following values:

* ``PACKED``: PackedMultiHeadAttention
* ``LOOPMHA``: Loop over MultiHeadAttention
* ``LOOPA23``: Loop over Attention(23), needs opset 23+.
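
For example, to force the attention operator available from opset 23
(the script then raises the target opset to 23 automatically):

.. code-block:: bash

    QWEN25ATTENTION=LOOPA23 python -m onnx_diagnostic.ci_models.export_qwen25_vl \
        -m Qwen/Qwen2.5-VL-7B-Instruct \
        --device cpu --dtype float32 --exporter onnx-dynamo --pretrained --second-input --zip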
"""

import os
import sys
import time
from typing import Any, Dict, List, Tuple
from .ci_helpers import (
    check_for_discrepancies_and_log_everything_into_a_json_file,
    compute_expected_outputs,
    get_parser,
    get_torch_dtype_from_command_line_args,
    remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences,
    simplify_model_id_for_a_filename,
    zip_model_and_data_into_a_single_file,
)


def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict[str, Any]:
    """
    Returns an untrained model.

    :param model_id: model id
    :param second_input: second input set
    :param verbose: verbosity
    :return: model and data
    """
    from ..torch_models.hghub.model_inputs import get_untrained_model_with_inputs

    if model_id == "arnir0/Tiny-LLM":
        # used to run a unit test
        _config_reduction = None
    else:

        def _config_reduction(config, task):
            return {
                # "num_hidden_layers": 2,
                "vision_config": {"depth": 2},
                "text_config": {
                    "num_hidden_layers": 2,
                    "layer_types": ["full_attention", "full_attention"],
                },
                # "_attn_implementation": "flash_attention_2",
                "_attn_implementation": "sdpa",
                "dtype": "float16",
            }

    config_reduction = _config_reduction
    data = get_untrained_model_with_inputs(
        model_id,
        verbose=verbose,
        add_second_input=second_input,
        config_reduction=config_reduction,
    )
    return data
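

# Example (sketch, not executed): build a reduced, untrained checkpoint for quick tests.
# The returned dictionary contains at least the keys "model" and "configuration",
# which is how ``main`` consumes it below.
#
#   data = get_untrained_model("Qwen/Qwen2.5-VL-7B-Instruct", second_input=True, verbose=1)
#   model, config = data["model"], data["configuration"]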


def get_inputs_for_part(
    part: str,
    torch_dtype: "torch.dtype",  # noqa: F821
    device: str,
    second_input: bool,
    image_token_id: int,
    bos_token_id: int,
    eos_token_id: int,
) -> Tuple[Dict[str, "torch.Tensor"], List[Dict[str, "torch.Tensor"]]]:  # noqa: F821
    import torch

    if part == "visual":
        export_inputs = dict(
            pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
            image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
        )
        other_inputs = []
        if second_input:
            other_inputs = [
                dict(
                    pixel_values=torch.randn((1292, 1176), dtype=torch_dtype).to(device),
                    image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
                ),
                dict(
                    pixel_values=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
                    image_grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
                ),
                dict(
                    pixel_values=torch.randn((14308, 1176), dtype=torch_dtype).to(device),
                    image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
                ),
                dict(
                    pixel_values=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
                    image_grid_thw=torch.tensor([[1, 98, 146]], dtype=torch.int64).to(device),
                ),
            ]
        return export_inputs, other_inputs

    if part == "embedding":

        def fix(inputs):
            img_start_index = 3
            img_end_index = img_start_index + patches_per_image  # 3 + 3577 = 3580
            # Fill in with image token index
            inputs["input_ids"][0][2] = bos_token_id  # <start_of_image>
            inputs["input_ids"][0][img_start_index:img_end_index] = image_token_id  # <image>
            inputs["input_ids"][0][img_end_index] = eos_token_id  # <end_of_image>
            inputs["input_ids"][1][2] = bos_token_id  # <start_of_image>
            inputs["input_ids"][1][img_start_index:img_end_index] = image_token_id  # <image>
            inputs["input_ids"][1][img_end_index] = eos_token_id  # <end_of_image>
            return inputs

        batch_size, sequence_length, patches_per_image, out_hidden_size = (2, 3606, 3577, 3584)
        num_logical_patches = batch_size * patches_per_image

        def draw():
            return {
                "input_ids": torch.randint(
                    low=0,
                    high=image_token_id,
                    size=(batch_size, sequence_length),
                    dtype=torch.int64,
                ).to(device),
                "image_features": torch.randn(
                    num_logical_patches, out_hidden_size, dtype=torch_dtype
                ).to(device),
            }

        return fix(draw()), ([fix(draw()), fix(draw())] if second_input else [])

    raise NotImplementedError(f"No inputs implemented yet for part={part!r}")
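

# Example (sketch, not executed): synthetic inputs for the visual part, mirroring the
# call made in ``main``; the token ids come from the model configuration.
#
#   export_inputs, other_inputs = get_inputs_for_part(
#       "visual", torch.float32, "cpu", second_input=True,
#       image_token_id=config.image_token_id,
#       bos_token_id=config.bos_token_id,
#       eos_token_id=config.eos_token_id,
#   )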


def main(
    model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    device: str = "cpu",
    dtype: str = "float32",
    exporter: str = "onnx-dynamo",
    pretrained: bool = True,
    second_input: bool = True,
    make_zip: bool = False,
    output_folder: str = "dump_models",
    existing_onnx: str | None = None,
    part: str = "visual",
    atol: float = 0.01,
    mismatch01: float = 0.1,
    profile_exporter: bool = False,
):
    """
    Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
    The script applies as well to other models based on the same architecture.
    The function saves everything on disk. It does not generate new inputs
    on the second run but reuses the saved ones. The same goes for the
    expected outputs, which are also saved on disk.

    :param model_id: model id
    :param device: device
    :param dtype: dtype
    :param exporter: exporter to use
    :param pretrained: pretrained=False is usually used to test
    :param second_input: checks discrepancies on more examples
    :param make_zip: creates a zip at the end
    :param output_folder: output folder
    :param existing_onnx: reuses an existing ONNX model instead of exporting a new one
    :param part: ``""`` to export the whole model, ``"visual"`` for the visual part,
        ``"embedding"`` for the embedding part
    :param atol: raises an exception if the absolute error is above that threshold
    :param mismatch01: raises an exception if the ratio of mismatches is above that threshold
    :param profile_exporter: profiles the exporter
    """
    prefix = simplify_model_id_for_a_filename(model_id)
    if "QWEN25ATTENTION" in os.environ:
        prefix = f"{prefix}.{os.environ['QWEN25ATTENTION']}"
    basename = os.path.join(
        output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}"
    )
    filename = f"{basename}.onnx"
    stat_file = f"{basename}.stats"
    print("------------------------------------------------------------------")
    print(f"-- model_id={model_id}")
    print(f"-- part={part}")
    print(f"-- device={device}")
    print(f"-- dtype={dtype}")
    print(f"-- exporter={exporter}")
    print(f"-- pretrained={pretrained}")
    print(f"-- second_input={second_input}")
    print(f"-- make_zip={make_zip}")
    print(f"-- output_folder={output_folder}")
    print(f"-- atol={atol}")
    print(f"-- mismatch01={mismatch01}")
    print(f"-- profile_exporter={profile_exporter}")
    print("------------------------------------------------------------------")
    print(f"-- prefix={prefix}")
    print(f"-- export in {filename!r}")
    print("------------------------------------------------------------------")
    if os.path.exists(stat_file) and not existing_onnx:
        print(f"-- skipping because {stat_file!r} already exists")
        return

    print("-- import torch and others")
    import torch
    from transformers import AutoModel, AutoProcessor
    from ..helpers import string_type
    from ..torch_export_patches.patches._patch_transformers_qwen2_5 import (
        PLUGS,
    )
    from ..torch_export_patches import torch_export_patches
    from ..export.api import to_onnx

    if output_folder and output_folder != ".":
        os.makedirs(output_folder, exist_ok=True)
    print(f"-- create model {model_id!r}")
    print(
        f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
        f"pretrained={pretrained!r}"
    )
    torch_dtype = get_torch_dtype_from_command_line_args(dtype)
    if pretrained:
        print("-- pretrained model")
        model = AutoModel.from_pretrained(
            model_id, device_map=device, dtype=torch_dtype, attn_implementation="sdpa"
        ).eval()
        data = dict(model=model)
        config = model.config
    else:
        print("-- random model")
        data = get_untrained_model(model_id, second_input=second_input, verbose=1)
        model = data["model"]
        config = data["configuration"]

    assert (
        hasattr(config, "bos_token_id") and config.bos_token_id
    ), f"missing 'bos_token_id' from config\n{config}"
    assert (
        hasattr(config, "eos_token_id") and config.eos_token_id
    ), f"missing 'eos_token_id' from config\n{config}"

    model = model.to(device).to(getattr(torch, dtype))
    print(f"-- config._attn_implementation={model.config._attn_implementation}")
    print(f"-- model.dtype={model.dtype}")
    print(f"-- model.device={model.device}")
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print(f"-- processor={type(processor)}")

    export_inputs, other_inputs = None, None
    if not part:
        # used to unit test
        from ..helpers.torch_helper import to_any

        assert "inputs" in data, f"key 'inputs' is missing from data (available {set(data)})"
        model_to_export = data["model"]
        export_inputs = to_any(to_any(data["inputs"], device), torch_dtype)
        other_inputs = [
            v for k, v in data.items() if k.startswith("inputs_") if k != "inputs_prompt"
        ]
        dynamic_shapes = data["dynamic_shapes"]
        assert other_inputs, f"No other inputs were found from data (available {set(data)})"
    elif part == "visual":

        class VisualPart(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, pixel_values, image_grid_thw):
                return model.get_image_features(pixel_values, image_grid_thw)

        assert hasattr(
            model, "get_image_features"
        ), f"get_image_features not found in class {type(model)}"
        model_to_export = VisualPart(model)
        dynamic_shapes = dict(
            pixel_values={0: "hidden_width", 1: "hidden_height"},
            image_grid_thw={},  # {0: "n_images"},  # TODO: fix
        )
    elif part == "embedding":

        class EmbeddingPart(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, image_features):
                inputs_embeds = None
                if inputs_embeds is None:
                    inputs_embeds = self.model.get_input_embeddings()(input_ids)

                def process_image(inputs_embeds, image_features):
                    image_embeds = image_features
                    image_embeds = torch.cat((image_embeds,), dim=0).to(
                        inputs_embeds.device, inputs_embeds.dtype
                    )
                    image_mask, _ = self.model.model.get_placeholder_mask(
                        input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
                    )
                    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
                    return inputs_embeds

                return torch.cond(
                    image_features.shape[0] == 0,
                    (lambda embs, _imgf: embs.clone()),
                    process_image,
                    [inputs_embeds, image_features],
                )

        assert hasattr(
            model, "get_image_features"
        ), f"get_image_features not found in class {type(model)}"
        model_to_export = EmbeddingPart(model)
        dynamic_shapes = {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "image_features": {0: "num_logical_patches"},
        }
    else:
        raise NotImplementedError(f"no export yet for part={part!r}")

    print(f"-- part={part!r}")
    print(f"-- model_to_export={type(model_to_export)}")
    print(f"-- dynamic_shapes={dynamic_shapes}")

    print("-- ############")
    print("-- INPUT/OUTPUT")
    print("-- ############")
    input_filename = os.path.join(output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt")
    if os.path.exists(input_filename):
        print(f"-- restore inputs from {input_filename!r}")
        data = torch.load(input_filename, weights_only=False)
        export_inputs = data["export_inputs"]
        other_inputs = data["other_inputs"]
        dynamic_shapes = data["dynamic_shapes"]
    elif export_inputs is not None:
        data = dict(
            export_inputs=export_inputs,
            other_inputs=other_inputs,
            dynamic_shapes=dynamic_shapes,
        )
        print(f"-- dump inputs into {input_filename!r}")
        torch.save(data, input_filename)
    else:
        export_inputs, other_inputs = get_inputs_for_part(
            part,
            torch_dtype,
            device,
            second_input,
            image_token_id=config.image_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
        )
        data = dict(
            export_inputs=export_inputs,
            other_inputs=other_inputs,
            dynamic_shapes=dynamic_shapes,
        )
        print(f"-- dump inputs into {input_filename!r}")
        torch.save(data, input_filename)

    print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
    print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
    print(f"-- dynamic_shapes={dynamic_shapes}")

    output_filename = os.path.join(
        output_folder, f"expected.{prefix}.visual.{device}.{dtype}.pt"
    )
    print("-- ##################")
    print("-- # EXPECTED_OUTPUTS")
    print("-- ##################")
    compute_expected_outputs(output_filename, model_to_export, input_filename)

    if existing_onnx and os.path.exists(existing_onnx):
        print("-- ######")
        print(f"-- USING EXISTING ONNX {existing_onnx!r}")
        print("-- ######")
        exporter = existing_onnx
        filename = existing_onnx
        target_opset = None
        export_duration = 0.0  # no export happened, the existing model is reused
    else:
        print("-- ######")
        print("-- EXPORT")
        print("-- ######")
        if exporter != "custom":
            import packaging.version as pv

            try:
                import onnxscript

                v_onnxscript = onnxscript.__version__
                if pv.Version(v_onnxscript) <= pv.Version("0.5.6"):
                    print(f"-- onnxscript=={v_onnxscript} not recent enough")
                    print("-- stop.")
                    return
            except AttributeError:
                pass
            except ImportError:
                print("-- missing onnxscript, cannot continue")
                print("-- stop.")
                return

        begin = time.perf_counter()
        target_opset = 22
        if (
            exporter == "onnx-dynamo"
            and device == "cuda"
            and "QWEN25ATTENTION" not in os.environ
        ):
            os.environ["QWEN25ATTENTION"] = "PACKED"
        elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
            target_opset = 23
        with torch_export_patches(
            patch_torch=False,
            patch_sympy=False,
            patch_transformers=True,
            verbose=1,
            stop_if_static=2,
            profile=(f"{basename}.profile.html" if profile_exporter else None),
        ):
            to_onnx(
                model_to_export,
                kwargs=export_inputs,
                dynamic_shapes=dynamic_shapes,
                filename=filename,
                exporter=exporter,
                verbose=1,
                save_ep=None,
                target_opset=target_opset,
                optimize=True,
                onnx_plugs=PLUGS,
            )
        export_duration = time.perf_counter() - begin

        if exporter == "onnx-dynamo":
            # onnx-dynamo fails at producing function body with sequences as input / output.
            # They are replaced by tensor type one step in the model.
            print("-- remove_body_last_input_output_for_loop")
            remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences(
                filename
            )
            print("-- done.")

    print("-- ###############")
    print("-- # DISCREPANCIES")
    print("-- ###############")
    info = {
        "model_id": model_id,
        "part": part,
        "device": device,
        "dtype": dtype,
        "exporter": exporter,
        "pretrained": pretrained,
        "attention": os.environ.get("QWEN25ATTENTION", "default"),
    }
    check_for_discrepancies_and_log_everything_into_a_json_file(
        agg_stat_file=os.path.join(output_folder, "collection_statistics.js"),
        stat_file=stat_file,
        export_duration=export_duration,
        device=device,
        model_file=filename,
        cached_inputs=input_filename,
        cached_expected_outputs=output_filename,
        main_info=info,
        atol=atol,
        mismatch01=mismatch01,
    )

    if make_zip:
        print("-- #####")
        print("-- # ZIP")
        print("-- #####")
        zip_model_and_data_into_a_single_file(f"{basename}.zip", filename)
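

# Example (sketch, not executed): re-run only the discrepancy checks against an already
# exported model; the path below is a placeholder for a previously produced file.
#
#   main(
#       model_id="Qwen/Qwen2.5-VL-7B-Instruct",
#       part="visual", device="cpu", dtype="float32",
#       existing_onnx="dump_models/model.existing.onnx",
#   )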


if __name__ == "__main__":
    parser = get_parser("qwen25")
    args = parser.parse_args(sys.argv[1:])
    main(
        model_id=args.mid,
        device=args.device,
        dtype=args.dtype,
        exporter=args.exporter,
        pretrained=args.pretrained,
        second_input=args.second_input,
        make_zip=args.zip,
        output_folder=args.output_folder,
        existing_onnx=args.existing_onnx,
        part=args.part,
        atol=args.atol,
        mismatch01=args.mismatch01,
        profile_exporter=args.profile_exporter,
    )