Flattening Functionalities (torch)#
Note
This section covers functionality that is specific to PyTorch.
It is only relevant when exporting torch.nn.Module models with
torch.export.export() and has no bearing on ONNX models built
directly with the builder APIs.
torch.export.export() and the torch pytree machinery require every
Python object that appears as a model input or output to be registered as a
pytree node. When a class is not registered, exporting fails with a cryptic
error.
yobx.torch.flatten provides utilities to register and unregister
such classes, and to compose those registrations cleanly.
Why flattening matters#
torch.export.export() traces a PyTorch model into a portable
torch.fx.GraphModule. During tracing every input and every output
must be decomposable into a flat list of torch.Tensor objects. The
decomposition is handled by torch.utils._pytree, which knows about
built-in Python containers (list, tuple, dict) but not about
arbitrary user-defined classes.
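The limitation is easy to observe directly (a minimal illustration using the public torch.utils._pytree.tree_flatten helper; the Point class is a made-up example): built-in containers are decomposed into their leaves, while an unregistered class is treated as a single opaque leaf, so its inner tensors are invisible to the pytree machinery and hence to torch.export.export().

```python
import torch
from torch.utils._pytree import tree_flatten


class Point:  # hypothetical class, NOT registered as a pytree node
    def __init__(self, x, y):
        self.x, self.y = x, y


# Built-in containers are decomposed into their tensor leaves...
flat, _ = tree_flatten({"a": torch.tensor([1.0]), "b": [torch.tensor([2.0])]})
print(len(flat))  # 2 leaves

# ...but an unregistered class stays a single opaque leaf.
flat, _ = tree_flatten(Point(torch.tensor([1.0]), torch.tensor([2.0])))
print(len(flat), type(flat[0]).__name__)  # 1 Point
```

Registering Point as a pytree node, as described below, is what lets tree_flatten reach the tensors inside it.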
If a model returns — or receives as input — a class like
transformers.DynamicCache, exporting will fail unless that class has been
registered as a pytree node with:
a flatten function — extracts the tensors and a serialisable context object that describes the structure,
an unflatten function — recreates the original object from the flat list and the context,
a flatten-with-keys function — same as flatten but pairs each tensor with a
torch.utils._pytree.MappingKey that names it.
Core helpers#
register_class_flattening#
yobx.torch.flatten.register_class_flattening() is a thin wrapper
around torch.utils._pytree.register_pytree_node that:
skips the registration silently when the class is already registered (avoids duplicate-registration errors),
optionally runs a user-supplied check callable to verify the round-trip immediately after registration.
<<<
import torch

from yobx.torch.flatten import (
    register_class_flattening,
    unregister_class_flattening,
)


# A minimal dict-like container.
class MyOutput(dict):
    """Simple dict subclass that can hold named tensors."""


def flatten_my_output(obj):
    keys = list(obj.keys())
    return [obj[k] for k in keys], keys


def flatten_with_keys_my_output(obj):
    keys = list(obj.keys())
    values = [obj[k] for k in keys]
    return [
        (torch.utils._pytree.MappingKey(k), v) for k, v in zip(keys, values)
    ], keys


def unflatten_my_output(values, context, output_type=None):
    return MyOutput(zip(context, values))


register_class_flattening(
    MyOutput,
    flatten_my_output,
    unflatten_my_output,
    flatten_with_keys_my_output,
)
# Flatten and unflatten a MyOutput object.
obj = MyOutput(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0]))
flat, spec = torch.utils._pytree.tree_flatten(obj)
print("flat tensors:", [t.tolist() for t in flat])
restored = torch.utils._pytree.tree_unflatten(flat, spec)
print("restored keys:", list(restored.keys()))
# Clean up so that subsequent runs of the docs are not affected.
unregister_class_flattening(MyOutput)
>>>
flat tensors: [[1.0, 2.0], [3.0]]
restored keys: ['a', 'b']
make_flattening_function_for_dataclass#
yobx.torch.flatten.make_flattening_function_for_dataclass()
auto-generates the three required callables for any class that exposes
.keys() / .values() like a mapping (e.g. transformers.ModelOutput
subclasses).
<<<
import torch

from yobx.torch.flatten import (
    make_flattening_function_for_dataclass,
    register_class_flattening,
    unregister_class_flattening,
)


class HiddenState(dict):
    """Dict-like container for a transformer hidden state."""


flatten_fn, flatten_with_keys_fn, unflatten_fn = make_flattening_function_for_dataclass(
    HiddenState, set()
)

print("generated function names:")
print(" ", flatten_fn.__name__)
print(" ", flatten_with_keys_fn.__name__)
print(" ", unflatten_fn.__name__)

register_class_flattening(HiddenState, flatten_fn, unflatten_fn, flatten_with_keys_fn)
obj = HiddenState(last_hidden_state=torch.zeros(2, 3))
flat, spec = torch.utils._pytree.tree_flatten(obj)
print("flat:", [t.shape for t in flat])
unregister_class_flattening(HiddenState)
>>>
generated function names:
flatten_hidden_state
flatten_with_keys_hidden_state
unflatten_hidden_state
flat: [torch.Size([2, 3])]
register_cache_flattening and the context manager#
yobx.torch.flatten.register_cache_flattening() registers a
collection of classes in one call and returns a dict that can be passed to
yobx.torch.flatten.unregister_cache_flattening() to undo every
registration.
yobx.torch.flatten.register_flattening_functions() combines the two in
a context manager (built with contextlib.contextmanager()) so that the
registrations are automatically undone when the with block exits:
from yobx.torch import register_flattening_functions

with register_flattening_functions(patch_transformers=True):
    # Inside this block:
    # DynamicCache, StaticCache, EncoderDecoderCache, and BaseModelOutput
    # are registered as pytree nodes.
    exported = torch.export.export(model, (inputs,))
After the with block all registrations are rolled back, leaving
torch.utils._pytree.SUPPORTED_NODES exactly as it was before.
Transformers-specific registrations#
When patch_transformers=True is passed to
register_cache_flattening() (or
register_flattening_functions()), the
following classes from transformers are registered:
| Class | Description |
|---|---|
| DynamicCache | Key-value cache whose layers grow as new tokens are decoded |
| StaticCache | Pre-allocated key-value cache with a fixed maximum length |
| EncoderDecoderCache | Wraps a self-attention and a cross-attention cache |
| BaseModelOutput | Generic output container (dict-like dataclass) |
The flatten functions are defined in
yobx.torch.in_transformers.flatten_class. The module also patches
registrations that are already present but known to be incompatible with
torch.export.export() (see WRONG_REGISTRATIONS).
DynamicCache layers#
Note
The layer-type-aware flattening described below relies on the layers
attribute of transformers.cache_utils.DynamicCache, which was
introduced in transformers >= 4.50. On older versions of
transformers the cache is serialized with plain key_<i> /
value_<i> keys and no per-layer type information is preserved. Use
flatten_dynamic_cache() only
with transformers >= 4.50 if you need to round-trip mixed layer types.
A DynamicCache can contain layers of different types
(DynamicLayer, DynamicSlidingWindowLayer, etc.). The flatten
context encodes each layer type as a short letter code so that the
correct layer class and its kwargs are recreated on unflatten:
| Layer class | Code |
|---|---|
| DynamicLayer | D |
| DynamicSlidingWindowLayer | W |
|  | S |
|  | X |
The following example builds a cache whose first layer is a plain
DynamicLayer and whose second layer is a DynamicSlidingWindowLayer.
It then round-trips the cache through
flatten_dynamic_cache() /
unflatten_dynamic_cache() and
shows that the layer types and the sliding_window parameter are
preserved:
<<<
import torch
import transformers

from yobx.helpers import string_type
from yobx.torch.in_transformers.cache_helper import make_dynamic_cache
from yobx.torch.in_transformers.flatten_class import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

# DynamicSlidingWindowLayer was added in transformers >= 4.50.
if not hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
    print("DynamicSlidingWindowLayer not available, skipping example.")
else:
    bsize, nheads, slen, dim = 2, 4, 3, 7
    sliding_window = slen  # sliding_window must be >= the sequence length

    # Build a DynamicCache with two different layer types.
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
        ],
        cls_layers=[
            transformers.cache_utils.DynamicLayer,
            transformers.cache_utils.DynamicSlidingWindowLayer,
        ],
        cls_kwargs=[{}, {"sliding_window": sliding_window}],
    )
    print("cache:", string_type(cache, with_shape=True))

    # Flatten (serialize) the cache — the context encodes each layer's type.
    flat, context = flatten_dynamic_cache(cache)
    print("context keys:", context)

    # Unflatten (deserialize) the cache — layer types are restored from the context.
    rebuilt = unflatten_dynamic_cache(flat, context)
    print("layer types:", [type(layer).__name__ for layer in rebuilt.layers])
    print("sliding_window:", rebuilt.layers[1].sliding_window)
>>>
cache: DynamicCache(DynamicLayer(T1s2x4x3x7, T1s2x4x3x7), DynamicSlidingWindowLayer(T1s2x4x3x7, T1s2x4x3x7))
context keys: ['key_D_0', 'value_D_0', 'key_W3_1', 'value_W3_1']
layer types: ['DynamicLayer', 'DynamicSlidingWindowLayer']
sliding_window: 3
List of supported classes#
See Flattening List.
See also
Registering a custom class as a pytree node — sphinx-gallery example demonstrating registration of a custom class and the round-trip flatten / unflatten.
Interesting Helpers — the MiniOnnxBuilder which serialises
pytree-flattened tensors to ONNX.