.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_transformers/plot_llm_to_onnx.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_transformers_plot_llm_to_onnx.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_transformers_plot_llm_to_onnx.py:

.. _l-plot-llm-to-onnx:

Export an LLM to ONNX with InputObserver
========================================

This example shows how to export a HuggingFace :epkg:`transformers` LLM to
ONNX using :class:`InputObserver <yobx.torch.InputObserver>`.

The key challenge when exporting an LLM is that the HuggingFace examples
typically call ``model.generate``, but we only need to export the ``forward``
method. :class:`InputObserver <yobx.torch.InputObserver>` intercepts the
forward calls during generation to record the actual inputs and outputs,
which are then used to infer:

* the **dynamic shapes** (which tensor dimensions vary across calls), and
* a representative set of **export arguments** (with empty tensors for
  optional inputs that were absent in some calls).

We use :epkg:`arnir0/Tiny-LLM` — a very small causal language model — so the
example runs without a GPU.

**Command-line options**

Run with pre-trained weights (default) or a randomly initialised model::

    python plot_llm_to_onnx.py                            # pre-trained weights (default)
    python plot_llm_to_onnx.py --no-trained               # random weights — fast
    python plot_llm_to_onnx.py --num-hidden-layers 2      # use only 2 transformer layers
    python plot_llm_to_onnx.py --model Qwen/Qwen2-0.5B-Instruct  # use a different model

When ``--trained`` is given (the default), the full checkpoint is downloaded
(~hundreds of MB) and the exported ONNX model produces meaningful text.
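The forward-interception idea at the heart of this example can be sketched in
plain Python. The names below (``TinyModel``, ``ToyObserver``) are invented
for the illustration; this shows the general pattern only, not the yobx
implementation:

.. code-block:: Python

    from contextlib import contextmanager


    class TinyModel:
        """Stand-in for a transformers model; ``forward`` returns fake logits."""

        def forward(self, input_ids, attention_mask=None):
            return [len(input_ids)]


    class ToyObserver:
        """Records (args, kwargs, output) for every intercepted forward call."""

        def __init__(self):
            self.calls = []

        @contextmanager
        def __call__(self, model):
            original = model.forward

            def recording_forward(*args, **kwargs):
                out = original(*args, **kwargs)
                self.calls.append((args, kwargs, out))
                return out

            model.forward = recording_forward
            try:
                yield self
            finally:
                model.forward = original  # always restore the real method


    model = TinyModel()
    observer = ToyObserver()
    with observer(model):  # every forward call inside the block is recorded
        model.forward([1, 2, 3])
        model.forward([4], attention_mask=[1])
    print(len(observer.calls))  # 2

Seeing several calls with different sequence lengths is what later makes
shape inference possible.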
Pass ``--no-trained`` to build the model from the config with random weights
via :func:`transformers.AutoModelForCausalLM.from_config` — only the
tokenizer and the architecture config are downloaded (~few KB), which is
useful for quick testing and CI.

``--num-hidden-layers`` overrides ``config.num_hidden_layers`` before the
model is instantiated, which shrinks the number of transformer decoder
blocks. This is useful for reducing memory use and export time during
development.

.. GENERATED FROM PYTHON SOURCE LINES 45-47

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 47-65

.. code-block:: Python

    import argparse
    import sys

    import pandas
    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    from yobx import doc
    from yobx.helpers import string_type
    from yobx.helpers.rt_helper import onnx_generate
    from yobx.torch import (
        InputObserver,
        apply_patches_for_model,
        register_flattening_functions,
        to_onnx,
    )

.. GENERATED FROM PYTHON SOURCE LINES 66-81

Command-line arguments
----------------------

``--trained`` / ``--no-trained`` controls whether the full pre-trained
checkpoint is loaded (default: ``--trained``). Pass ``--no-trained`` to build
a randomly initialised model from the architecture config only (faster, no
large download, suitable for CI).

``--num-hidden-layers`` overrides the number of transformer decoder blocks in
the config before the model is built. Use a small value (e.g. ``2``) to speed
up export and reduce memory during development.

``--model`` selects the HuggingFace model ID to use (default:
``arnir0/Tiny-LLM``). Any :class:`transformers.AutoModelForCausalLM`-compatible
model can be passed here.

.. GENERATED FROM PYTHON SOURCE LINES 81-118

.. code-block:: Python

    _DEFAULT_MODEL = "arnir0/Tiny-LLM"

    parser = argparse.ArgumentParser(description="Export a HuggingFace LLM to ONNX.")
    parser.add_argument(
        "--model",
        default=_DEFAULT_MODEL,
        metavar="MODEL_ID",
        help=(
            f"HuggingFace model ID to export (default: {_DEFAULT_MODEL!r}). "
            "Any AutoModelForCausalLM-compatible model can be used."
        ),
    )
    parser.add_argument(
        "--trained",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Load the full pre-trained weights from HuggingFace Hub (default). "
            "Pass --no-trained to build a randomly initialised model from the "
            "config (no weight download, suitable for CI)."
        ),
    )
    parser.add_argument(
        "--num-hidden-layers",
        type=int,
        default=None,
        metavar="LAYERS",
        help=(
            "Override config.num_hidden_layers to N before building the model. "
            "Reduces the number of transformer decoder blocks, which lowers "
            "memory use and speeds up export. Defaults to the value in the "
            "model config."
        ),
    )

    # parse_known_args avoids failures when sphinx-gallery passes extra arguments.
    args, _ = parser.parse_known_args(sys.argv[1:])

.. GENERATED FROM PYTHON SOURCE LINES 119-128

Load model and tokenizer
------------------------

The tokenizer is always fetched from HuggingFace (small download). The
architecture config is fetched next; if ``--num-hidden-layers`` was given,
the corresponding config attribute is overridden before the model is built.

By default the model is loaded with pre-trained weights (``--trained``).
Pass ``--no-trained`` to use random weights instead (much faster, no large
download).

.. GENERATED FROM PYTHON SOURCE LINES 128-157

.. code-block:: Python

    MODEL_NAME = args.model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    config = AutoConfig.from_pretrained(MODEL_NAME)
    if args.num_hidden_layers is not None:
        print(
            f"Overriding num_hidden_layers: "
            f"{config.num_hidden_layers} -> {args.num_hidden_layers}"
        )
        config.num_hidden_layers = args.num_hidden_layers

    if args.trained:
        print(f"Loading pre-trained weights for {MODEL_NAME!r} ...")
        # ignore_mismatched_sizes=True is required when num_hidden_layers has been
        # reduced: the checkpoint contains weights for all original layers, and
        # without this flag from_pretrained would raise an error on the missing keys.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, config=config, ignore_mismatched_sizes=True
        )
    else:
        print(f"Building randomly initialised model from config for {MODEL_NAME!r} ...")
        model = AutoModelForCausalLM.from_config(config)

    print(
        f" trained={args.trained} num_hidden_layers={config.num_hidden_layers} "
        f"#params={sum(p.numel() for p in model.parameters()):,}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loading pre-trained weights for 'arnir0/Tiny-LLM' ...

Observe forward calls during generation
---------------------------------------

:class:`InputObserver <yobx.torch.InputObserver>` acts as a context manager
that replaces the model's ``forward`` method. Every time ``forward`` is
called (internally by ``model.generate``), the inputs and outputs are
recorded.

:func:`register_flattening_functions <yobx.torch.register_flattening_functions>`
must wrap the observation because the KV-cache
(:class:`transformers.cache_utils.DynamicCache`) is a custom Python class that
needs to be registered as a pytree node before :mod:`torch.utils._pytree` can
flatten it.

.. GENERATED FROM PYTHON SOURCE LINES 182-203

.. code-block:: Python

    prompt = "Continue: it rains, what should I do?"
    # run on GPU when available; the tokenized inputs must live on the
    # same device as the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    observer = InputObserver()
    with (
        register_flattening_functions(patch_transformers=True),
        apply_patches_for_model(patch_transformers=True, model=model),
        observer(model),
    ):
        model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            do_sample=False,
            max_new_tokens=10,
        )

    print(f"number of stored forward calls: {observer.num_obs()}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    number of stored forward calls: 3

.. GENERATED FROM PYTHON SOURCE LINES 204-214

Infer dynamic shapes and representative arguments
-------------------------------------------------

After generation the observer has seen several forward calls, each with
different sequence lengths and KV-cache sizes.
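The core of this inference can be sketched in a few lines of plain Python.
This is a toy illustration with hypothetical names, not the yobx algorithm:
collect the shape of each input across the recorded calls and mark every axis
whose size varies as dynamic.

.. code-block:: Python

    def infer_dynamic_axes(shapes):
        """Mark every axis whose size differs across recorded calls as dynamic.

        ``shapes`` holds one shape tuple per recorded forward call of a single
        input; axes absent from the result are treated as static.
        """
        first = shapes[0]
        return {
            axis: "DYNAMIC"
            for axis in range(len(first))
            if any(shape[axis] != first[axis] for shape in shapes[1:])
        }


    # input_ids is (1, 13) on the prefill call, then (1, 1) on each decode step:
    print(infer_dynamic_axes([(1, 13), (1, 1), (1, 1)]))    # {1: 'DYNAMIC'}
    # attention_mask grows by one token per step:
    print(infer_dynamic_axes([(1, 13), (1, 14), (1, 15)]))  # {1: 'DYNAMIC'}

Both calls report only the sequence axis as dynamic; forcing the batch axis
to be dynamic as well (as ``set_batch_dimension_for=True`` does) is a
separate, deliberate choice.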
We can now ask it to infer:

* ``dynamic_shapes`` — a nested structure of ``torch.export.Dim`` values
  describing which dimensions must be treated as dynamic during export.
* ``kwargs`` — one representative set of inputs that can be passed directly
  to :func:`torch.export.export` or :func:`yobx.torch.to_onnx`.

.. GENERATED FROM PYTHON SOURCE LINES 214-222

.. code-block:: Python

    with register_flattening_functions(patch_transformers=True):
        dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
        kwargs = observer.infer_arguments()

    print("dynamic_shapes:", dynamic_shapes)
    print("kwargs:", string_type(kwargs, with_shape=True))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    dynamic_shapes: {'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
    kwargs: dict(input_ids:T7s1x13,attention_mask:T7s1x13,position_ids:T7s1x13,past_key_values:DynamicCache(key_cache=#1[T1s1x1x0x96], value_cache=#1[T1s1x1x0x96]),logits_to_keep:int)

.. GENERATED FROM PYTHON SOURCE LINES 223-231

Export to ONNX
--------------

We now export the model. Both
:func:`register_flattening_functions <yobx.torch.register_flattening_functions>`
and :func:`apply_patches_for_model <yobx.torch.apply_patches_for_model>` must
be active during export so that the exporter can correctly handle the
KV-cache type and any PyTorch ops that need patching.

.. GENERATED FROM PYTHON SOURCE LINES 231-246

.. code-block:: Python

    filename = "plot_llm_to_onnx.onnx"
    with (
        register_flattening_functions(patch_transformers=True),
        apply_patches_for_model(patch_torch=True, patch_transformers=True, model=model),
    ):
        to_onnx(
            model,
            (),
            kwargs=observer.infer_arguments(),
            dynamic_shapes=observer.infer_dynamic_shapes(set_batch_dimension_for=True),
            filename=filename,
        )
.. GENERATED FROM PYTHON SOURCE LINES 247-254

Verify: check discrepancies
---------------------------

:meth:`check_discrepancies <yobx.torch.InputObserver.check_discrepancies>`
runs every recorded set of inputs through both the original PyTorch model and
the exported ONNX model, then reports the maximum absolute difference for
each output. Values close to zero confirm that the export is correct.

.. GENERATED FROM PYTHON SOURCE LINES 254-258

.. code-block:: Python

    data = observer.check_discrepancies(filename, progress_bar=True)
    print(pandas.DataFrame(data))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    >0.1 >0.01 SUCCESS index duration_torch ort_duration n_inputs n_none n_empty inputs outputs_torch outputs_ort
    0 0.006306 0.328028 26.505746 34496.0 0 0 0 0 False 0 0.178980 0.700125 5 2 2 dict(input_ids:T7s1x13,attention_mask:T7s1x13,... #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96] #3[T1s1x1x32000,T1s1x1x13x96,T1s1x1x13x96]
    1 0.000009 0.002013 0.062643 34688.0 0 0 0 0 True 1 0.041286 0.028306 5 0 0 dict(input_ids:T7s1x1,attention_mask:T7s1x14,p... #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96] #3[T1s1x1x32000,T1s1x1x14x96,T1s1x1x14x96]
    2 0.000012 0.002171 0.065408 34880.0 0 0 0 0 True 2 0.008502 0.021704 5 0 0 dict(input_ids:T7s1x1,attention_mask:T7s1x15,p... #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96] #3[T1s1x1x32000,T1s1x1x15x96,T1s1x1x15x96]

.. GENERATED FROM PYTHON SOURCE LINES 259-267

Run the ONNX model in a greedy auto-regressive loop
---------------------------------------------------

:func:`onnx_generate <yobx.helpers.rt_helper.onnx_generate>` mimics
``model.generate`` for the exported ONNX model: it feeds the *present*
key/value tensors back as *past* key/values on every decoding step. (With
random weights the output tokens will be meaningless, but the pipeline itself
is exercised end-to-end.)

.. GENERATED FROM PYTHON SOURCE LINES 267-280
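The loop such a helper performs can be sketched without onnxruntime.
``toy_step`` below is an invented stand-in for one ONNX session run (all
names here are hypothetical, and this is not the yobx code): it returns fake
logits plus a *present* cache that the loop feeds back as *past* on the next
step.

.. code-block:: Python

    def toy_step(token, past):
        """Stand-in for one ONNX forward: returns logits and the updated cache."""
        logits = [0.0] * 5
        logits[(token + 1) % 5] = 1.0  # fake model: always prefer token + 1 mod 5
        present = past + [token]       # the cache accumulates everything seen
        return logits, present


    def greedy_generate(first_token, max_new_tokens, eos_token_id=None):
        tokens, past = [first_token], []
        for _ in range(max_new_tokens):
            logits, past = toy_step(tokens[-1], past)  # present becomes next past
            next_token = max(range(len(logits)), key=logits.__getitem__)
            tokens.append(next_token)
            if next_token == eos_token_id:             # stop at end-of-sequence
                break
        return tokens


    print(greedy_generate(0, 6))                  # [0, 1, 2, 3, 4, 0, 1]
    print(greedy_generate(0, 6, eos_token_id=3))  # stops early: [0, 1, 2, 3]

The real helper does the same bookkeeping with tensors: the prompt is the
prefill call, and every later step feeds a single new token plus the cache.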
.. code-block:: Python

    onnx_tokens = onnx_generate(
        filename,
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        eos_token_id=model.config.eos_token_id,
        max_new_tokens=50,
    )
    onnx_generated_text = tokenizer.decode(onnx_tokens[0], skip_special_tokens=True)
    print("-----------------")
    print(onnx_generated_text)
    print("-----------------")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    -----------------
    Continue: it rains, what should I do? I have a lot of people who are in the world. I have a lot of people who are in the world, and I have a lot of people who are in the world. I have a lot of people who are in the world,
    -----------------

.. GENERATED FROM PYTHON SOURCE LINES 281-285

Visualise the ONNX graph
------------------------

Render the exported ONNX model as a DOT graph.

.. GENERATED FROM PYTHON SOURCE LINES 285-287

.. code-block:: Python

    doc.save_fig(doc.plot_dot(filename), f"(unknown).png", dpi=400)

.. image-sg:: /auto_examples_transformers/images/sphx_glr_plot_llm_to_onnx_001.png
   :alt: plot llm to onnx
   :srcset: /auto_examples_transformers/images/sphx_glr_plot_llm_to_onnx_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 21.173 seconds)

.. _sphx_glr_download_auto_examples_transformers_plot_llm_to_onnx.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_llm_to_onnx.ipynb <plot_llm_to_onnx.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_llm_to_onnx.py <plot_llm_to_onnx.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_llm_to_onnx.zip <plot_llm_to_onnx.zip>`

.. include:: plot_llm_to_onnx.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_