.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples_torch/plot_input_observer_transformers.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_torch_plot_input_observer_transformers.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_torch_plot_input_observer_transformers.py:

.. _l-plot-input-observer-transformers:

InputObserver with Transformers Cache
======================================

:class:`InputObserver` is a context manager that **steals** a model's forward
method during inference to record every set of inputs and outputs.
After the context exits, the collected data can be used to:

* infer which tensor dimensions are **dynamic** across the observed calls, and
* build a representative set of **export arguments** (with empty tensors for
  optional inputs that were missing in some calls).

These two pieces of information are exactly what :func:`torch.export.export`
and :func:`torch.onnx.export` need.

The example below shows three progressively richer scenarios:

1. **Simple model**: two plain-tensor inputs with varying batch and sequence
   lengths.
2. **LLM-like model**: inputs that include a
   :class:`transformers.cache_utils.DynamicCache` (key-value cache), which
   requires registering custom pytree flattening rules via
   :func:`register_flattening_functions`.
3. **Multimodal model**: a model that receives ``pixel_values`` only on the
   very first call (the *prefill* step). The ``value_if_missing`` argument
   tells the observer what to substitute when the input is absent, so that
   the dynamic shape analysis remains possible.

.. GENERATED FROM PYTHON SOURCE LINES 33-35

Imports
--------

.. GENERATED FROM PYTHON SOURCE LINES 35-42

.. code-block:: Python

    import torch
    from yobx.helpers import string_type
    from yobx.torch import register_flattening_functions
    from yobx.torch.input_observer import InputObserver
    from yobx.torch.in_transformers.cache_helper import make_dynamic_cache

.. GENERATED FROM PYTHON SOURCE LINES 43-49

1. Simple model - two tensor inputs
-----------------------------------

We start with the most basic case: a model that takes two float tensors and
returns their sum. We run it with three different shapes so that the observer
can detect that both the batch and the sequence dimension are dynamic.

.. GENERATED FROM PYTHON SOURCE LINES 49-75

.. code-block:: Python


    class AddModel(torch.nn.Module):
        """Adds two tensors element-wise (broadcasting on the batch dimension)."""

        def forward(self, x, y):
            return x + y


    model_add = AddModel()

    inputs_add = [
        (torch.randn(2, 6), torch.randn(1, 6)),
        (torch.randn(3, 7), torch.randn(1, 7)),
        (torch.randn(4, 8), torch.randn(1, 8)),
    ]

    observer_add = InputObserver()
    with observer_add(model_add):
        for x, y in inputs_add:
            model_add(x, y)

    # InputObserver captures at most 3 calls by default (store_n_calls=3).
    print("Observations stored:", observer_add.num_obs())
    assert observer_add.num_obs() == 3

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Observations stored: 3

.. GENERATED FROM PYTHON SOURCE LINES 76-78

``infer_dynamic_shapes`` returns a tuple of per-argument shape specs, using
``torch.export.Dim.DYNAMIC`` as a placeholder wherever a dimension varies.

.. GENERATED FROM PYTHON SOURCE LINES 78-82

.. code-block:: Python

    dyn_add = observer_add.infer_dynamic_shapes()
    print("Dynamic shapes (add model):", dyn_add)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Dynamic shapes (add model): ({0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, {1: DimHint(DYNAMIC)})

.. GENERATED FROM PYTHON SOURCE LINES 83-85

``infer_arguments`` returns one representative set of inputs with empty
tensors substituted for any optional argument that was missing in some calls.

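
The dynamic-shape result printed earlier for ``AddModel`` can be understood
with a small standalone sketch. It is not how ``yobx`` implements
``infer_dynamic_shapes``: the helper ``varying_dims`` is a hypothetical name
introduced here, and the string ``"DYNAMIC"`` merely stands in for
``torch.export.Dim.DYNAMIC``. The assumption is that a dimension counts as
dynamic when its size differs between at least two observed calls.

.. code-block:: Python

    def varying_dims(shapes):
        """Return the indices of dimensions whose size differs across ``shapes``.

        ``shapes`` holds one shape tuple per observed call for a single argument.
        Hypothetical helper for illustration only; it is not part of yobx.
        """
        ranks = {len(s) for s in shapes}
        if len(ranks) != 1:
            raise ValueError(f"inconsistent ranks: {sorted(ranks)}")
        # zip(*shapes) groups the sizes observed for each dimension index.
        return [i for i, sizes in enumerate(zip(*shapes)) if len(set(sizes)) > 1]


    # Shapes of x and y in the three AddModel calls shown earlier.
    x_shapes = [(2, 6), (3, 7), (4, 8)]
    y_shapes = [(1, 6), (1, 7), (1, 8)]

    # One dict per positional argument; "DYNAMIC" is a placeholder string.
    dynamic = tuple(
        {d: "DYNAMIC" for d in varying_dims(shapes)} for shapes in (x_shapes, y_shapes)
    )
    print(dynamic)
    # ({0: 'DYNAMIC', 1: 'DYNAMIC'}, {1: 'DYNAMIC'})

Because ``y`` always has a first dimension of size 1, only its dimension 1 is
reported, which mirrors the observer's output above. The example then
continues with ``infer_arguments`` on the same observer.
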
.. GENERATED FROM PYTHON SOURCE LINES 85-89

.. code-block:: Python

    args_add = observer_add.infer_arguments()
    print("Inferred arguments:", string_type(args_add, with_shape=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Inferred arguments: (T1s2x6,T1s1x6)

.. GENERATED FROM PYTHON SOURCE LINES 90-104

2. LLM-like model - inputs with a DynamicCache
----------------------------------------------

Transformer language models store previously computed key/value pairs in a
:class:`transformers.cache_utils.DynamicCache`. Because ``DynamicCache`` is a
custom class (not a plain Python container), we must register it as a
*pytree node* before ``torch.utils._pytree.tree_flatten`` can decompose it.

:func:`register_flattening_functions` is a context manager that registers all
supported cache types (``DynamicCache``, ``EncoderDecoderCache`` …) on entry
and unregisters them on exit.

We simulate two decoding steps: one with a short sequence in the cache and
one with a slightly longer sequence.

.. GENERATED FROM PYTHON SOURCE LINES 104-159

.. code-block:: Python


    class LLMLikeModel(torch.nn.Module):
        """Minimal stand-in for a causal LM forward pass."""

        def forward(self, input_ids, attention_mask=None, past_key_values=None):
            # A real model would compute hidden states here.
            # We just return the inputs unchanged so the example is self-contained.
            return input_ids, past_key_values


    n_layers = 2
    n_heads = 4
    head_dim = 32

    # Prefill step: the KV-cache holds 10 tokens.
    cache_prefill = make_dynamic_cache(
        [
            (torch.rand(1, n_heads, 10, head_dim), torch.rand(1, n_heads, 10, head_dim))
            for _ in range(n_layers)
        ]
    )

    # First decode step: the KV-cache now holds 11 tokens.
    cache_decode = make_dynamic_cache(
        [
            (torch.rand(1, n_heads, 11, head_dim), torch.rand(1, n_heads, 11, head_dim))
            for _ in range(n_layers)
        ]
    )

    llm_inputs = [
        dict(
            input_ids=torch.randint(0, 1000, (1, 10)),
            attention_mask=torch.ones(1, 10, dtype=torch.int64),
            past_key_values=cache_prefill,
        ),
        dict(
            input_ids=torch.randint(0, 1000, (1, 1)),
            attention_mask=torch.ones(1, 11, dtype=torch.int64),
            past_key_values=cache_decode,
        ),
    ]

    model_llm = LLMLikeModel()
    observer_llm = InputObserver()

    # The `register_flattening_functions` context manager must wrap *both* the
    # inference calls and the subsequent shape / argument inference.
    with register_flattening_functions(patch_transformers=True), observer_llm(model_llm):
        for kwargs in llm_inputs:
            model_llm(**kwargs)

    print("\nObservations stored (LLM):", observer_llm.num_obs())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Observations stored (LLM): 2

.. GENERATED FROM PYTHON SOURCE LINES 160-162

Retrieve dynamic shapes. We pass ``set_batch_dimension_for=True`` to mark the
first dimension of every tensor as dynamic even though both calls used batch=1.

.. GENERATED FROM PYTHON SOURCE LINES 162-170

.. code-block:: Python

    with register_flattening_functions(patch_transformers=True):
        dyn_llm = observer_llm.infer_dynamic_shapes(set_batch_dimension_for=True)
        kwargs_llm = observer_llm.infer_arguments()

    print("Dynamic shapes (LLM):", dyn_llm)
    print("Inferred kwargs:", string_type(kwargs_llm, with_shape=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Dynamic shapes (LLM): {'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}]}
    Inferred kwargs: dict(input_ids:T7s1x10,attention_mask:T7s1x10,past_key_values:DynamicCache(key_cache=#2[T1s1x4x10x32,T1s1x4x10x32], value_cache=#2[T1s1x4x10x32,T1s1x4x10x32]))

.. GENERATED FROM PYTHON SOURCE LINES 171-175

The shapes for ``past_key_values`` are expressed as a flat list (one entry per
key or value tensor across all layers). Both dimension 0 (batch) and
dimension 2 (sequence length) are marked dynamic, while dimension 1 (heads)
and dimension 3 (head dimension) are static.

.. GENERATED FROM PYTHON SOURCE LINES 177-187

3. Multimodal model - pixel_values present only on the first call
-----------------------------------------------------------------

Vision-language models like Gemma3 or LLaVA receive ``pixel_values`` only
during the prefill step. Subsequent decode steps omit that argument and
introduce ``past_key_values`` instead.

Without extra information the observer cannot infer an empty tensor for
``pixel_values`` (it was never seen as an empty tensor). The
``value_if_missing`` argument provides this information explicitly.

.. GENERATED FROM PYTHON SOURCE LINES 187-254

.. code-block:: Python


    class MultimodalModel(torch.nn.Module):
        """Minimal stand-in for a vision-language model forward pass."""

        def forward(self, input_ids, pixel_values=None, attention_mask=None, past_key_values=None):
            return input_ids, past_key_values


    image_h, image_w = 224, 224

    cache_mm_step1 = make_dynamic_cache(
        [
            (torch.rand(1, n_heads, 20, head_dim), torch.rand(1, n_heads, 20, head_dim))
            for _ in range(n_layers)
        ]
    )
    cache_mm_step2 = make_dynamic_cache(
        [
            (torch.rand(1, n_heads, 21, head_dim), torch.rand(1, n_heads, 21, head_dim))
            for _ in range(n_layers)
        ]
    )

    mm_inputs = [
        # Prefill: image + text, no past cache yet.
        dict(
            input_ids=torch.randint(0, 1000, (1, 20)),
            pixel_values=torch.rand(1, 3, image_h, image_w),
            attention_mask=torch.ones(1, 20, dtype=torch.int64),
        ),
        # Decode step 1: no image, but a growing KV-cache.
        dict(
            input_ids=torch.randint(0, 1000, (1, 1)),
            attention_mask=torch.ones(1, 21, dtype=torch.int64),
            past_key_values=cache_mm_step1,
        ),
        # Decode step 2.
        dict(
            input_ids=torch.randint(0, 1000, (1, 1)),
            attention_mask=torch.ones(1, 22, dtype=torch.int64),
            past_key_values=cache_mm_step2,
        ),
    ]

    model_mm = MultimodalModel()

    # Provide an empty tensor (batch=0) for pixel_values so the observer knows its
    # shape and dtype when it is absent. The zero batch dimension signals
    # "optional but with this shape when present".
    observer_mm = InputObserver(
        value_if_missing=dict(pixel_values=torch.empty((0, 3, image_h, image_w), dtype=torch.float32))
    )

    with register_flattening_functions(patch_transformers=True), observer_mm(model_mm):
        for kwargs in mm_inputs:
            model_mm(**kwargs)

    print("\nObservations stored (multimodal):", observer_mm.num_obs())

    with register_flattening_functions(patch_transformers=True):
        dyn_mm = observer_mm.infer_dynamic_shapes(set_batch_dimension_for=True)
        kwargs_mm = observer_mm.infer_arguments()

    print("Dynamic shapes (multimodal):", dyn_mm)
    print("Inferred kwargs:", string_type(kwargs_mm, with_shape=True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Observations stored (multimodal): 3
    Dynamic shapes (multimodal): {'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'pixel_values': {0: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}]}
    Inferred kwargs: dict(input_ids:T7s1x1,pixel_values:T1s0x3x224x224,attention_mask:T7s1x21,past_key_values:DynamicCache(key_cache=#2[T1s1x4x20x32,T1s1x4x20x32], value_cache=#2[T1s1x4x20x32,T1s1x4x20x32]))

.. GENERATED FROM PYTHON SOURCE LINES 255-273

Note that ``pixel_values`` now appears in the inferred arguments with an
empty first dimension (batch=0) even though it was absent in two of the three
calls. The spatial dimensions 2 and 3 (height and width) are not dynamic
because they were always 224x224.

These shapes and arguments can be passed directly to
:func:`torch.export.export` or :func:`torch.onnx.export`:

.. code-block:: python

    import torch

    ep = torch.export.export(
        model_mm,
        (),
        kwargs=kwargs_mm,
        dynamic_shapes=dyn_mm,
    )

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.027 seconds)

.. _sphx_glr_download_auto_examples_torch_plot_input_observer_transformers.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_input_observer_transformers.ipynb <plot_input_observer_transformers.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_input_observer_transformers.py <plot_input_observer_transformers.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_input_observer_transformers.zip <plot_input_observer_transformers.zip>`

.. include:: plot_input_observer_transformers.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_