InputObserver#
Note
This section covers functionality that is specific to PyTorch.
It is only relevant when exporting torch.nn.Module models with
torch.export.export() and has no bearing on ONNX models built
directly with the builder APIs.
torch.export.export() requires callers to supply both the model inputs
(args / kwargs) and a dynamic_shapes specification describing
which tensor dimensions are symbolic at export time. Assembling these two
artefacts by hand is tedious and error-prone, especially for large language
models whose inputs change between the prefill phase (first token) and the
decode phase (subsequent tokens).
InputObserver automates this task: it
temporarily replaces the model’s forward method with a thin wrapper
that records every call, then reconstructs the export arguments and dynamic
shapes from those observations.
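The wrapping mechanism can be sketched in plain Python; observe_calls and the Toy model below are illustrative stand-ins, not the library's implementation:

```python
import contextlib

@contextlib.contextmanager
def observe_calls(obj, method_name):
    """Temporarily replace a method with a wrapper that records every call."""
    original = getattr(obj, method_name)
    records = []

    def wrapper(*args, **kwargs):
        records.append((args, kwargs))    # remember the inputs of this call
        return original(*args, **kwargs)  # then delegate to the real method

    setattr(obj, method_name, wrapper)
    try:
        yield records
    finally:
        setattr(obj, method_name, original)  # always restore the original

class Toy:
    def forward(self, x, y=None):
        return x

model = Toy()
with observe_calls(model, "forward") as calls:
    model.forward(1, y=2)
    model.forward(3)

print(len(calls))  # number of observed calls
print(calls[0])    # args/kwargs of the first call
```

The real observer does the same thing with model.forward, then derives export arguments and dynamic shapes from the recorded calls instead of merely storing them.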
Why manual argument construction is hard#
A typical LLM forward signature includes optional arguments such as
attention_mask, past_key_values, and pixel_values (for
vision-language models). The set of arguments that is actually present
changes between calls:
The prefill call includes pixel_values but no past_key_values.
Decode calls include past_key_values but no pixel_values.
torch.export.export() needs a single set of representative inputs that
covers all paths, with None placeholders for optional arguments. Figuring
out which arguments are optional and what their shapes look like normally
requires reading model source code. InputObserver infers this
automatically from real forward calls.
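The inference step can be pictured as merging the argument sets seen across calls into one representative set; merge_observed_kwargs below is a hypothetical sketch, not the library's API:

```python
def merge_observed_kwargs(calls):
    """Build one representative kwargs dict from several observed calls."""
    merged = {}
    for kwargs in calls:
        for name, value in kwargs.items():
            # Keep the first concrete value seen for each argument name.
            merged.setdefault(name, value)
    return merged

# Strings stand in for tensors to keep the sketch self-contained.
prefill = {"input_ids": "ids", "pixel_values": "pix"}       # first token
decode = {"input_ids": "ids", "past_key_values": "cache"}   # later tokens
merged = merge_observed_kwargs([prefill, decode])
print(sorted(merged))  # union of arguments across both phases
```

An argument that appears in any observed call ends up in the merged set; arguments the model accepts but that were never observed are the case value_if_missing (described below) exists for.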
Basic usage#
Wrap the model in a with block and run one or more forward passes (or
generate() calls for LLMs). After the block exits, the observer holds enough
information to build the export arguments:
import torch

from yobx.torch.input_observer import InputObserver

observer = InputObserver()
with observer(model):
    # Run one or more forward passes with representative inputs.
    model(x1, y1)
    model(x2, y2)

# Build export arguments from the observed inputs.
args = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
ep = torch.export.export(model, args, dynamic_shapes=dynamic_shapes)
For LLMs the entire token-generation loop can be observed via
generate():
observer = InputObserver()
with observer(model):
    model.generate(input_ids)

ep = torch.export.export(
    model,
    (),
    kwargs=observer.infer_arguments(),
    dynamic_shapes=observer.infer_dynamic_shapes(),
)
Handling optional arguments with value_if_missing#
When an argument appears only in some observed calls (e.g. pixel_values
only in the prefill pass), the observer cannot automatically fabricate a
representative empty tensor for it. Pass value_if_missing to supply
default shapes for such arguments:
observer = InputObserver(
    value_if_missing=dict(
        pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16)
    )
)
with observer(model):
    model.generate(input_ids)

args = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
The values in value_if_missing are only used to infer shapes and
argument structures; the actual tensor data is not passed to the model.
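This shape-only behaviour can be sketched as follows; FakeTensor and harvest_spec are illustrative names, not part of the library:

```python
class FakeTensor:
    """Stand-in for a tensor: carries shape and dtype but no data."""
    def __init__(self, shape, dtype="float16"):
        self.shape = shape
        self.dtype = dtype

def harvest_spec(value_if_missing):
    """Keep only structural metadata from each placeholder value."""
    return {
        name: {"shape": tuple(v.shape), "dtype": v.dtype}
        for name, v in value_if_missing.items()
    }

spec = harvest_spec({"pixel_values": FakeTensor((0, 3, 896, 896))})
# Only shape/dtype survive; the placeholder is never fed through the model.
print(spec["pixel_values"]["shape"])
```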
Inferring dynamic shapes#
infer_dynamic_shapes()
compares the shapes observed across multiple forward calls and marks any
dimension whose size varies as dynamic. The batch dimension is a common
blind spot: most models support a dynamic batch size, yet all observed
calls typically use the same batch size, because generating inputs at
several batch sizes just to expose the dimension is expensive.
Use set_batch_dimension_for to mark the first axis of selected inputs as
dynamic even when all observations use the same batch size:
dynamic_shapes = observer.infer_dynamic_shapes(
set_batch_dimension_for={"input_ids", "attention_mask"}
)
Pass True to mark the first dimension of all inputs as dynamic:
dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
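The underlying inference rule, plus the set_batch_dimension_for override, can be sketched in plain Python (infer_dynamic_dims is a hypothetical helper, not the library's API):

```python
def infer_dynamic_dims(observed_shapes, set_batch_dimension_for=()):
    """Mark each axis whose size varies across observations as dynamic.

    observed_shapes maps an input name to the list of shapes seen for it
    across forward calls. Returns {name: set of dynamic axis indices}.
    """
    dynamic = {}
    for name, shapes in observed_shapes.items():
        axes = set()
        for axis in range(len(shapes[0])):
            sizes = {shape[axis] for shape in shapes}
            if len(sizes) > 1:  # the size changed between calls
                axes.add(axis)
        if set_batch_dimension_for is True or name in set_batch_dimension_for:
            axes.add(0)  # force the batch axis dynamic even if it never varied
        dynamic[name] = axes
    return dynamic

shapes = {
    "input_ids": [(1, 12), (1, 1), (1, 1)],  # sequence length varies
    "attention_mask": [(1, 12), (1, 13), (1, 14)],
}
dims = infer_dynamic_dims(shapes, set_batch_dimension_for={"input_ids"})
print(dims)
```

Here axis 1 of both inputs is detected as dynamic from the observations alone, while axis 0 of input_ids becomes dynamic only because of the override.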
Inspecting observations#
The num_obs attribute reports how many
forward calls were captured. The info attribute exposes the raw
InputObserverInfo object, which contains
the full per-call input and output records along with the inferred argument
alignment.
See also
Flattening Functionalities (torch) — registering pytree nodes for DynamicCache
and other transformers classes, which is typically needed alongside
InputObserver when working with LLMs.
Patches (torch export) — applying patches to torch and transformers
internals to enable symbolic tracing during torch.export.export().
yobx.torch.input_observer — API reference for
InputObserver, InputObserverInfo, and
InputCandidate.