Flattening Functionalities (torch)

Note

This section covers functionality that is specific to PyTorch. It is only relevant when exporting torch.nn.Module models with torch.export.export() and has no bearing on ONNX models built directly with the builder APIs.

torch.export.export() relies on the torch pytree machinery, which requires every container class that appears in a model's inputs or outputs to be registered as a pytree node. When such a class is not registered, export fails with an error that does not clearly point to the missing registration.

yobx.torch.flatten provides utilities to register, unregister, and compose such registrations cleanly.

Why flattening matters

torch.export.export() traces a PyTorch model into a portable torch.export.ExportedProgram, whose graph is a torch.fx.GraphModule. During tracing, every input and every output must be decomposable into a flat list of leaves, mostly torch.Tensor objects. The decomposition is handled by torch.utils._pytree, which knows about built-in Python containers (list, tuple, dict) but not about arbitrary user-defined classes.
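The difference is easy to observe by calling torch.utils._pytree.tree_flatten directly. In the sketch below, the class Unregistered is a hypothetical stand-in for a user-defined container. Built-in containers are split into tensor leaves, whereas an instance of an unregistered class comes back as a single opaque leaf, so the tensor it holds never reaches the tracer.

```python
import torch
import torch.utils._pytree as pytree

# Built-in containers (dict, tuple, list) are decomposed recursively.
nested = {"x": torch.ones(2), "pair": (torch.zeros(3), [torch.arange(4)])}
leaves, spec = pytree.tree_flatten(nested)
print("leaves:", [tuple(t.shape) for t in leaves])


class Unregistered:
    """Hypothetical user-defined container, not registered with pytree."""

    def __init__(self, tensor):
        self.tensor = tensor


# The unregistered object is not decomposed: it is returned as one leaf
# that is not a tensor, so the tensor it wraps stays hidden.
opaque_leaves, _ = pytree.tree_flatten({"obj": Unregistered(torch.ones(2))})
print("leaf types:", [type(leaf).__name__ for leaf in opaque_leaves])
```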

If a model returns — or receives as input — a class like transformers.DynamicCache, exporting will fail unless that class has been registered as a pytree node with:

  • a flatten function — extracts the tensors and a serialisable context object that describes the structure,

  • an unflatten function — recreates the original object from the flat list and the context,

  • a flatten-with-keys function — same as flatten but pairs each tensor with a torch.utils._pytree.MappingKey that names it.

Core helpers

register_class_flattening

yobx.torch.flatten.register_class_flattening() is a thin wrapper around torch.utils._pytree.register_pytree_node that:

  • skips the registration silently when the class is already registered (avoids duplicate-registration errors),

  • optionally runs a user-supplied check callable to verify the round-trip immediately after registration.

<<<

import torch
from yobx.torch.flatten import (
    register_class_flattening,
    unregister_class_flattening,
)


# A minimal dict-like container.
class MyOutput(dict):
    """Simple dict subclass that can hold named tensors."""


def flatten_my_output(obj):
    keys = list(obj.keys())
    return [obj[k] for k in keys], keys


def flatten_with_keys_my_output(obj):
    keys = list(obj.keys())
    values = [obj[k] for k in keys]
    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(keys, values)], keys


def unflatten_my_output(values, context, output_type=None):
    return MyOutput(zip(context, values))


register_class_flattening(
    MyOutput,
    flatten_my_output,
    unflatten_my_output,
    flatten_with_keys_my_output,
)

# Flatten and unflatten a MyOutput object.
obj = MyOutput(a=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0]))
flat, spec = torch.utils._pytree.tree_flatten(obj)
print("flat tensors:", [t.tolist() for t in flat])

restored = torch.utils._pytree.tree_unflatten(flat, spec)
print("restored keys:", list(restored.keys()))

# Clean up so that subsequent runs of the docs are not affected.
unregister_class_flattening(MyOutput)

>>>

    flat tensors: [[1.0, 2.0], [3.0]]
    restored keys: ['a', 'b']

make_flattening_function_for_dataclass

yobx.torch.flatten.make_flattening_function_for_dataclass() auto-generates the three required callables for any class that exposes .keys() / .values() like a mapping (e.g. transformers.ModelOutput subclasses).

<<<

import torch
from yobx.torch.flatten import (
    make_flattening_function_for_dataclass,
    register_class_flattening,
    unregister_class_flattening,
)


class HiddenState(dict):
    """Dict-like container for a transformer hidden state."""


flatten_fn, flatten_with_keys_fn, unflatten_fn = make_flattening_function_for_dataclass(
    HiddenState, set()
)

print("generated function names:")
print(" ", flatten_fn.__name__)
print(" ", flatten_with_keys_fn.__name__)
print(" ", unflatten_fn.__name__)

register_class_flattening(HiddenState, flatten_fn, unflatten_fn, flatten_with_keys_fn)

obj = HiddenState(last_hidden_state=torch.zeros(2, 3))
flat, spec = torch.utils._pytree.tree_flatten(obj)
print("flat:", [t.shape for t in flat])

unregister_class_flattening(HiddenState)

>>>

    generated function names:
      flatten_hidden_state
      flatten_with_keys_hidden_state
      unflatten_hidden_state
    flat: [torch.Size([2, 3])]

register_cache_flattening and the context manager

yobx.torch.flatten.register_cache_flattening() registers a collection of classes in one call and returns a dict that can be passed to yobx.torch.flatten.unregister_cache_flattening() to undo every registration.

yobx.torch.flatten.register_flattening_functions() wraps both in a contextlib.contextmanager() so that registrations are automatically undone when the with block exits:

import torch
from yobx.torch import register_flattening_functions

# model and inputs stand for your own torch.nn.Module and its example input.
with register_flattening_functions(patch_transformers=True):
    # Inside this block, DynamicCache, StaticCache, EncoderDecoderCache and
    # BaseModelOutput are registered as pytree nodes.
    exported = torch.export.export(model, (inputs,))

After the with block all registrations are rolled back, leaving torch.utils._pytree.SUPPORTED_NODES exactly as it was before.

Transformers-specific registrations

When patch_transformers=True is passed to register_cache_flattening() (or register_flattening_functions()), the following classes from transformers are registered:

Class                  Description
DynamicCache           Key-value cache whose layers grow as new tokens are decoded
StaticCache            Pre-allocated key-value cache with a fixed maximum length
EncoderDecoderCache    Wraps a self-attention and a cross-attention cache
BaseModelOutput        Generic output container (dict-like dataclass)

The flatten functions are defined in yobx.torch.in_transformers.flatten_class. The module also patches registrations that are already present but known to be incompatible with torch.export.export() (see WRONG_REGISTRATIONS).

DynamicCache layers

Note

The layer-type-aware flattening described below relies on the layers attribute of transformers.cache_utils.DynamicCache, which was introduced in transformers >= 4.50. On older versions of transformers the cache is serialized with plain key_<i> / value_<i> keys and no per-layer type information is preserved. Use flatten_dynamic_cache() only with transformers >= 4.50 if you need to round-trip mixed layer types.

A DynamicCache can contain layers of different types (DynamicLayer, DynamicSlidingWindowLayer, etc.). The flatten context encodes each layer type as a short letter code so that the correct layer class and its kwargs are recreated on unflatten:

Layer class                  Code
DynamicLayer                 D
DynamicSlidingWindowLayer    W
StaticLayer                  S
StaticSlidingWindowLayer     X

The following example builds a cache whose first layer is a plain DynamicLayer and whose second layer is a DynamicSlidingWindowLayer. It then round-trips the cache through flatten_dynamic_cache() / unflatten_dynamic_cache() and shows that the layer types and the sliding_window parameter are preserved:

<<<

import torch
import transformers
from yobx.helpers import string_type
from yobx.torch.in_transformers.cache_helper import make_dynamic_cache
from yobx.torch.in_transformers.flatten_class import (
    flatten_dynamic_cache,
    unflatten_dynamic_cache,
)

# DynamicSlidingWindowLayer was added in transformers >= 4.50.
if not hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
    print("DynamicSlidingWindowLayer not available, skipping example.")
else:
    bsize, nheads, slen, dim = 2, 4, 3, 7
    sliding_window = slen  # sliding_window must be >= the sequence length

    # Build a DynamicCache with two different layer types.
    cache = make_dynamic_cache(
        [
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
            (
                torch.randn(bsize, nheads, slen, dim),
                torch.randn(bsize, nheads, slen, dim),
            ),
        ],
        cls_layers=[
            transformers.cache_utils.DynamicLayer,
            transformers.cache_utils.DynamicSlidingWindowLayer,
        ],
        cls_kwargs=[{}, {"sliding_window": sliding_window}],
    )
    print("cache:", string_type(cache, with_shape=True))

    # Flatten (serialize) the cache — the context encodes each layer's type.
    flat, context = flatten_dynamic_cache(cache)
    print("context keys:", context)

    # Unflatten (deserialize) the cache — layer types are restored from the context.
    rebuilt = unflatten_dynamic_cache(flat, context)
    print("layer types:", [type(layer).__name__ for layer in rebuilt.layers])
    print("sliding_window:", rebuilt.layers[1].sliding_window)

>>>

    cache: DynamicCache(DynamicLayer(T1s2x4x3x7, T1s2x4x3x7), DynamicSlidingWindowLayer(T1s2x4x3x7, T1s2x4x3x7))
    context keys: ['key_D_0', 'value_D_0', 'key_W3_1', 'value_W3_1']
    layer types: ['DynamicLayer', 'DynamicSlidingWindowLayer']
    sliding_window: 3
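The context keys printed by this example follow a visible pattern: key or value, the layer's letter code from the table, an optional number, and the layer index. The sketch below decodes such names. The function parse_context_key, its regular expression, and the reading of the optional number as the sliding_window value (3 here) are assumptions inferred from this single output, not a documented format.

```python
import re

# Letter codes taken from the layer-class table.
LAYER_CODES = {
    "D": "DynamicLayer",
    "W": "DynamicSlidingWindowLayer",
    "S": "StaticLayer",
    "X": "StaticSlidingWindowLayer",
}

# <key|value>_<code><optional number>_<layer index>, e.g. "key_W3_1".
_KEY_RE = re.compile(r"^(key|value)_([DWSX])(\d*)_(\d+)$")


def parse_context_key(name):
    """Hypothetical decoder returning (kind, layer class name, window, index)."""
    match = _KEY_RE.match(name)
    if match is None:
        raise ValueError(f"unexpected context key: {name!r}")
    kind, code, window, index = match.groups()
    # Assumption: the optional digits after the letter are the window size.
    return kind, LAYER_CODES[code], int(window) if window else None, int(index)


for name in ["key_D_0", "value_D_0", "key_W3_1", "value_W3_1"]:
    print(name, "->", parse_context_key(name))
```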

List of supported classes

See Flattening List.

See also

Registering a custom class as a pytree node — sphinx-gallery example demonstrating registration of a custom class and the round-trip flatten / unflatten.

Interesting Helpers — the MiniOnnxBuilder which serialises pytree-flattened tensors to ONNX.