.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_torch/plot_input_observer_tiny_llm.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_torch_plot_input_observer_tiny_llm.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_torch_plot_input_observer_tiny_llm.py:

.. _l-plot-input-observer-tiny-llm:

Export an LLM with InputObserver (with Tiny-LLM)
================================================

The main issue when exporting an LLM is that the example published on
HuggingFace relies on the ``generate`` method, while only the ``forward``
method needs to be exported. Example :ref:`l-plot-input-observer-transformers`
shows in detail how to guess the dummy inputs and dynamic shapes required
to do so. Let's see how to simplify that.

Dummy Example
+++++++++++++

Let's use the example provided on
`arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.

.. GENERATED FROM PYTHON SOURCE LINES 19-72

.. code-block:: Python

    import pandas
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from yobx import doc
    from yobx.helpers import string_type
    from yobx.helpers.rt_helper import onnx_generate
    from yobx.torch import (
        register_flattening_functions,
        apply_patches_for_model,
        to_onnx,
        InputObserver,
    )

    MODEL_NAME = "arnir0/Tiny-LLM"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


    def generate_text(
        prompt,
        model,
        tokenizer,
        max_length=50,
        temperature=0.01,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    ):
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
        )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text


    # Define your prompt
    prompt = "Continue: it rains, what should I do?"

    generated_text = generate_text(prompt, model, tokenizer)
    print("-----------------")
    print(generated_text)
    print("-----------------")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Loading weights:   0%|          | 0/12 [00:00

We also need to register additional patches for :epkg:`transformers`.
Then :epkg:`pytorch` knows how to flatten/unflatten the inputs.

.. GENERATED FROM PYTHON SOURCE LINES 80-88

.. code-block:: Python

    observer = InputObserver()
    with register_flattening_functions(patch_transformers=True), observer(model):
        generate_text(prompt, model, tokenizer)

    print(f"number of stored inputs: {len(observer.info)}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    number of stored inputs: 3

.. GENERATED FROM PYTHON SOURCE LINES 89-95

Exports
+++++++

The ``InputObserver`` now has enough data to infer the arguments and the
dynamic shapes. Flattening alone is not enough: patches are also required
to export the model. The inferred dynamic shapes look like this:

.. GENERATED FROM PYTHON SOURCE LINES 95-99

.. code-block:: Python

    with register_flattening_functions(patch_transformers=True):
        dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
        kwargs = observer.infer_arguments()

.. GENERATED FROM PYTHON SOURCE LINES 100-101

and the inferred arguments:

.. GENERATED FROM PYTHON SOURCE LINES 101-104

.. code-block:: Python

    print("dynamic_shapes:", dynamic_shapes)
    print("kwargs:", string_type(kwargs, with_shape=True))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    dynamic_shapes: {'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'position_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}], 'logits_to_keep': None}
    kwargs: dict(input_ids:T7s1x13,attention_mask:T7s1x13,position_ids:T7s1x13,past_key_values:DynamicCache(key_cache=#1[T1s1x1x0x96], value_cache=#1[T1s1x1x0x96]),logits_to_keep:int)

.. GENERATED FROM PYTHON SOURCE LINES 105-106

Let's export.

.. GENERATED FROM PYTHON SOURCE LINES 106-120

.. code-block:: Python

    filenamec = "plot_input_observer_tiny_llm.onnx"

    with (
        register_flattening_functions(patch_transformers=True),
        apply_patches_for_model(patch_torch=True, patch_transformers=True, model=model),
    ):
        to_onnx(
            model,
            (),
            kwargs=observer.infer_arguments(),
            dynamic_shapes=observer.infer_dynamic_shapes(set_batch_dimension_for=True),
            filename=filenamec,
        )

.. GENERATED FROM PYTHON SOURCE LINES 121-126

Check discrepancies
+++++++++++++++++++

The model is now exported to ONNX. We reuse the stored inputs and outputs
to verify that the exported model produces the same outputs.

.. GENERATED FROM PYTHON SOURCE LINES 126-131

.. code-block:: Python

    data = observer.check_discrepancies(filenamec, progress_bar=True)
    print(pandas.DataFrame(data))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      0%|          | 0/3 [00:00

``onnx_generate`` runs the exported ONNX model in a greedy auto-regressive
loop, feeding the *present* key/value tensors back as *past* key/values on
every subsequent call, just like the HuggingFace ``generate`` method.

.. GENERATED FROM PYTHON SOURCE LINES 186-199

.. code-block:: Python

    onnx_tokens = onnx_generate(
        filenamec,
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        eos_token_id=model.config.eos_token_id,
        max_new_tokens=50,
    )
    onnx_generated_text = tokenizer.decode(onnx_tokens[0], skip_special_tokens=True)
    print("-----------------")
    print(onnx_generated_text)
    print("-----------------")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    -----------------
    Continue: it rains, what should I do? I have a lot of people who are in the world. I have a lot of people who are in the world, and I have a lot of people who are in the world. I have a lot of people who are in the world,
    -----------------

.. GENERATED FROM PYTHON SOURCE LINES 200-201

.. code-block:: Python

    doc.save_fig(doc.plot_dot(filename), f"(unknown).png", dpi=400)

.. image-sg:: /auto_examples_torch/images/sphx_glr_plot_input_observer_tiny_llm_001.png
   :alt: plot input observer tiny llm
   :srcset: /auto_examples_torch/images/sphx_glr_plot_input_observer_tiny_llm_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 14.143 seconds)

.. _sphx_glr_download_auto_examples_torch_plot_input_observer_tiny_llm.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_input_observer_tiny_llm.ipynb <plot_input_observer_tiny_llm.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_input_observer_tiny_llm.py <plot_input_observer_tiny_llm.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_input_observer_tiny_llm.zip <plot_input_observer_tiny_llm.zip>`

.. include:: plot_input_observer_tiny_llm.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
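The example above describes ``onnx_generate`` as a greedy auto-regressive loop that feeds *present* key/values back as *past* key/values. To make those mechanics concrete, here is a minimal, self-contained sketch of such a loop. It is not yobx's actual implementation: ``greedy_generate``, ``step_fn`` and ``toy_step`` are illustrative names introduced here, and ``step_fn`` merely stands in for one inference call on the exported model.

```python
from typing import Callable, List


def greedy_generate(
    step_fn: Callable,
    input_ids: List[int],
    eos_token_id: int,
    max_new_tokens: int,
) -> List[int]:
    """Greedy auto-regressive decoding loop (illustrative sketch).

    ``step_fn(feed, past)`` stands in for one run of the exported model:
    it returns ``(logits, present)`` and ``present`` is fed back as
    ``past`` on the next call, which is the KV-cache pattern the text
    describes.
    """
    tokens = list(input_ids)
    past = None      # no cache on the very first call
    feed = tokens    # the first call consumes the whole prompt
    for _ in range(max_new_tokens):
        logits, past = step_fn(feed, past)  # present comes back as past
        # Greedy decoding: pick the most likely next token.
        next_token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_token)
        if next_token == eos_token_id:
            break
        feed = [next_token]  # later calls feed only the new token
    return tokens


# Toy stand-in for the exported model over a 5-token vocabulary:
# it always predicts (last fed token + 1) % 5 and uses ``past`` as a
# simple call counter to mimic a growing cache.
def toy_step(feed, past):
    logits = [0.0] * 5
    logits[(feed[-1] + 1) % 5] = 1.0
    return logits, (0 if past is None else past) + 1


print(greedy_generate(toy_step, [0], eos_token_id=4, max_new_tokens=10))
# → [0, 1, 2, 3, 4]  (stops early on the EOS token)
```

In the real function, ``step_fn`` would presumably correspond to a run of the ONNX model with an inference session, with the model's past/present key/value tensors in place of the toy counter.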