Registering a custom class as a pytree node

torch.export.export() requires every object that appears as a model input or output to be decomposable into a flat list of torch.Tensor objects. torch.utils._pytree handles this decomposition, but it only knows about standard Python containers such as list, tuple, and dict. Custom classes, including all the cache types from transformers, must be explicitly registered before exporting.

This example walks through the three steps:

  1. Writing flatten / unflatten / flatten-with-keys callables for a custom dict-like class.

  2. Registering the class with register_class_flattening.

  3. Verifying the round-trip and then cleaning up with unregister_class_flattening.

See Flattening Functionalities (torch) for a full description of the flattening design including the transformers cache registrations.
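
To see why registration matters, the sketch below compares a built-in container with an unregistered dict subclass. The class name Bag is made up for this illustration. tree_flatten itself does not raise on an unknown class; it simply returns the whole object as a single opaque leaf, which is not a tensor.

```python
import torch
import torch.utils._pytree as pytree


class Bag(dict):
    """Hypothetical dict subclass that is never registered as a pytree node."""


# Built-in containers are decomposed recursively into their leaves.
nested = {"a": [torch.zeros(2), torch.ones(3)], "b": (torch.arange(4),)}
leaves, spec = pytree.tree_flatten(nested)
print("leaves in nested dict:", len(leaves))

# An unregistered class is opaque: the whole object comes back as one leaf.
bag = Bag(x=torch.zeros(2), y=torch.ones(3))
bag_leaves, _ = pytree.tree_flatten(bag)
print("leaves in Bag:", len(bag_leaves), "| same object:", bag_leaves[0] is bag)
```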

from dataclasses import dataclass
import torch
import torch.utils._pytree
from yobx.torch.flatten import (
    make_flattening_function_for_dataclass,
    register_class_flattening,
    unregister_class_flattening,
)

1. Define a custom dict-like container

We create a minimal dict subclass that stores named tensors. This pattern mirrors how transformers.modeling_outputs.ModelOutput works.

class EncoderOutput(dict):
    """Holds the output tensors produced by a (mock) encoder."""

2. Write the three required callables

  • flatten — extract a flat list of tensors plus a context (the key order) that is needed to reconstruct the original object.

  • flatten_with_keys — same, but pair each tensor with a torch.utils._pytree.MappingKey so that torch.export.export() can refer to each leaf by name.

  • unflatten — given the flat tensors and the context, recreate the original EncoderOutput.

def flatten_encoder_output(obj):
    keys = list(obj.keys())
    return [obj[k] for k in keys], keys


def flatten_with_keys_encoder_output(obj):
    keys = list(obj.keys())
    values = [obj[k] for k in keys]
    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(keys, values)], keys


def unflatten_encoder_output(values, context, output_type=None):
    return EncoderOutput(zip(context, values))
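
Before registering anything, the contract between flatten and unflatten can be checked by calling them directly. The sketch below repeats the two callables so that it runs on its own, and uses plain floats as stand-in leaves (an assumption made only to keep it free of dependencies; the contract does not depend on the leaf type).

```python
class EncoderOutput(dict):
    """Same minimal dict subclass as in step 1."""


def flatten_encoder_output(obj):
    keys = list(obj.keys())
    return [obj[k] for k in keys], keys


def unflatten_encoder_output(values, context, output_type=None):
    return EncoderOutput(zip(context, values))


# Keys are deliberately inserted out of alphabetical order.
original = EncoderOutput(t2=2.0, t1=1.0)
values, context = flatten_encoder_output(original)
rebuilt = unflatten_encoder_output(values, context)

# The context carries the key order, so the rebuilt object matches exactly.
print("values :", values)
print("context:", context)
print("same type and content:", type(rebuilt) is EncoderOutput and rebuilt == original)
```

If the context did not record the key order, unflatten would have no way to know which value belongs to which key.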

3. Register the class

register_class_flattening wraps torch.utils._pytree.register_pytree_node and returns True when the registration succeeds (False when the class is already registered).

registered = register_class_flattening(
    EncoderOutput,
    flatten_encoder_output,
    unflatten_encoder_output,
    flatten_with_keys_encoder_output,
)
assert EncoderOutput in torch.utils._pytree.SUPPORTED_NODES
print("registered:", registered)
registered: True

4. Flatten a nested structure

Once registered, torch.utils._pytree.tree_flatten can decompose any nested Python structure that contains EncoderOutput objects.

output = EncoderOutput(t1=torch.zeros(2, 5, 8), t2=torch.ones(2, 8))

flat, spec = torch.utils._pytree.tree_flatten(output)
print("number of leaf tensors:", len(flat))
for i, t in enumerate(flat):
    print(f"  leaf[{i}]: shape={tuple(t.shape)}, dtype={t.dtype}")
number of leaf tensors: 2
  leaf[0]: shape=(2, 5, 8), dtype=torch.float32
  leaf[1]: shape=(2, 8), dtype=torch.float32

5. Unflatten and verify the round-trip

torch.utils._pytree.tree_unflatten() reconstructs the original EncoderOutput from the flat list using the spec returned by tree_flatten.

restored = torch.utils._pytree.tree_unflatten(flat, spec)
print("restored type :", type(restored).__name__)
print("restored keys :", list(restored.keys()))
assert torch.equal(restored["t1"], output["t1"])
assert torch.equal(restored["t2"], output["t2"])
print("round-trip OK")
restored type : EncoderOutput
restored keys : ['t1', 't2']
round-trip OK
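
The spec does not care which leaves it receives, only how many. The sketch below uses a plain dict (so no registration is needed) to show that the same spec can rebuild the structure around transformed leaves, here a cast to float16 chosen purely for illustration.

```python
import torch
import torch.utils._pytree as pytree

tree = {"t1": torch.zeros(2, 5, 8), "t2": torch.ones(2, 8)}
leaves, spec = pytree.tree_flatten(tree)

# Transform the leaves, then rebuild the same structure around the new ones.
half_leaves = [leaf.to(torch.float16) for leaf in leaves]
converted = pytree.tree_unflatten(half_leaves, spec)

print("keys  :", list(converted.keys()))
print("dtypes:", [t.dtype for t in converted.values()])
```

This is the same mechanism tree_map relies on: flatten, apply a function to every leaf, unflatten with the original spec.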

6. Auto-generate callables with make_flattening_function_for_dataclass

For classes that already expose .keys() / .values() (all transformers.modeling_outputs.ModelOutput subclasses do), make_flattening_function_for_dataclass generates the three required callables automatically.

@dataclass
class EncoderOutput2:
    """Holds the output tensors produced by a (mock) encoder."""

    t1: torch.Tensor
    t2: torch.Tensor


supported = set()
flatten_fn, flatten_with_keys_fn, unflatten_fn = make_flattening_function_for_dataclass(
    EncoderOutput2, supported
)

print("auto-generated names:")
print(" ", flatten_fn.__name__)
print(" ", flatten_with_keys_fn.__name__)
print(" ", unflatten_fn.__name__)
print("supported set:", {c.__name__ for c in supported})
auto-generated names:
  flatten_encoder_output2
  flatten_with_keys_encoder_output2
  unflatten_encoder_output2
supported set: {'EncoderOutput2'}
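
To give an idea of what such a factory involves, here is a minimal sketch that builds the three callables from dataclasses.fields. The function make_dataclass_flatteners and the class Point are hypothetical, and this is not the yobx implementation: it reads fields by attribute and pairs each value with a plain string, where real torch code would use a key object such as torch.utils._pytree.MappingKey. The snake_case conversion is only a guess at how names like flatten_encoder_output2 could be derived.

```python
import dataclasses
import re


def make_dataclass_flatteners(cls):
    """Hypothetical factory returning (flatten, flatten_with_keys, unflatten)."""
    assert dataclasses.is_dataclass(cls), f"{cls} is not a dataclass"
    names = [f.name for f in dataclasses.fields(cls)]
    # CamelCase -> snake_case, e.g. "EncoderOutput2" -> "encoder_output2".
    snake = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__).lower()

    def flatten(obj):
        return [getattr(obj, n) for n in names], list(names)

    def flatten_with_keys(obj):
        # Plain strings stand in for pytree key objects in this sketch.
        return [(n, getattr(obj, n)) for n in names], list(names)

    def unflatten(values, context, output_type=None):
        return cls(**dict(zip(context, values)))

    flatten.__name__ = f"flatten_{snake}"
    flatten_with_keys.__name__ = f"flatten_with_keys_{snake}"
    unflatten.__name__ = f"unflatten_{snake}"
    return flatten, flatten_with_keys, unflatten


@dataclasses.dataclass
class Point:
    x: float
    y: float


flatten_fn, flatten_with_keys_fn, unflatten_fn = make_dataclass_flatteners(Point)
values, context = flatten_fn(Point(1.0, 2.0))
print(flatten_fn.__name__, values, context)
print(unflatten_fn(values, context))
```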

Register the dataclass with the generated callables.

registered = register_class_flattening(
    EncoderOutput2, flatten_fn, unflatten_fn, flatten_with_keys_fn
)
assert EncoderOutput2 in torch.utils._pytree.SUPPORTED_NODES
print("registered:", registered)
registered: True

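Run the round-trip check again. Note that the snippet below flattens output, the dict-based EncoderOutput from step 4, rather than output2, which is why the restored type is reported as EncoderOutput.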

output2 = EncoderOutput2(t1=torch.zeros(2, 5, 8), t2=torch.ones(2, 8))

flat, spec = torch.utils._pytree.tree_flatten(output)
restored = torch.utils._pytree.tree_unflatten(flat, spec)
print("restored type :", type(restored).__name__)
print("restored keys :", list(restored.keys()))
assert torch.equal(restored["t1"], output["t1"])
assert torch.equal(restored["t2"], output["t2"])
print("round-trip OK again")
restored type : EncoderOutput
restored keys : ['t1', 't2']
round-trip OK again

7. Unregister to restore the original state

After exporting (or when running inside a test) call unregister_class_flattening to undo the registration and leave torch.utils._pytree.SUPPORTED_NODES exactly as it was before.

assert EncoderOutput in torch.utils._pytree.SUPPORTED_NODES
assert EncoderOutput2 in torch.utils._pytree.SUPPORTED_NODES
unregister_class_flattening(EncoderOutput)
unregister_class_flattening(EncoderOutput2)
print("EncoderOutput and EncoderOutput2 unregistered")
assert EncoderOutput not in torch.utils._pytree.SUPPORTED_NODES
assert EncoderOutput2 not in torch.utils._pytree.SUPPORTED_NODES
EncoderOutput and EncoderOutput2 unregistered
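
In tests it is easy to forget the cleanup when an assertion or an export fails halfway. One way to guard against that is a context manager built on try / finally. The sketch below is generic: temporarily_registered is a hypothetical helper, and a plain set stands in for the pytree registry so that the block runs without torch or yobx.

```python
from contextlib import contextmanager


@contextmanager
def temporarily_registered(register, unregister, cls):
    """Hypothetical helper: register cls, then always undo it on exit."""
    register(cls)
    try:
        yield cls
    finally:
        # Runs even when the body raises, so the registry is always restored.
        unregister(cls)


class Dummy:
    """Stand-in for a custom container class."""


registry = set()  # stand-in for the real pytree registry

try:
    with temporarily_registered(registry.add, registry.discard, Dummy):
        inside = Dummy in registry
        raise RuntimeError("simulated export failure")
except RuntimeError:
    pass

print("registered inside the block:", inside)
print("registered after the block :", Dummy in registry)
```

With the functions from this example, register would be a small closure that calls register_class_flattening with the three callables, and unregister would be unregister_class_flattening.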

Plot: pytree flatten / unflatten diagram

The diagram below illustrates the flatten→leaf-list→unflatten round-trip for an EncoderOutput container with two tensor fields.

import matplotlib.pyplot as plt  # noqa: E402
import matplotlib.patches as mpatches  # noqa: E402

fig, ax = plt.subplots(figsize=(7, 3.5))
ax.set_xlim(0, 10)
ax.set_ylim(0, 5)
ax.axis("off")
ax.set_title("pytree flatten / unflatten round-trip", fontsize=11)

# Container box
container = mpatches.FancyBboxPatch(
    (0.3, 1.5),
    2.6,
    2.0,
    boxstyle="round,pad=0.15",
    linewidth=1.5,
    edgecolor="#4c72b0",
    facecolor="#dce9f5",
)
ax.add_patch(container)
ax.text(1.6, 3.7, "EncoderOutput", ha="center", va="center", fontsize=9, fontweight="bold")
ax.text(1.6, 2.85, "t1: Tensor(2,5,8)", ha="center", va="center", fontsize=8)
ax.text(1.6, 2.35, "t2: Tensor(2,8)", ha="center", va="center", fontsize=8)

# Flat list box
flat_box = mpatches.FancyBboxPatch(
    (3.8, 1.5),
    2.6,
    2.0,
    boxstyle="round,pad=0.15",
    linewidth=1.5,
    edgecolor="#dd8452",
    facecolor="#fde8d8",
)
ax.add_patch(flat_box)
ax.text(5.1, 3.7, "flat list + spec", ha="center", va="center", fontsize=9, fontweight="bold")
ax.text(5.1, 2.85, "[Tensor(2,5,8),", ha="center", va="center", fontsize=8)
ax.text(5.1, 2.35, " Tensor(2,8)]", ha="center", va="center", fontsize=8)

# Restored box
restored_box = mpatches.FancyBboxPatch(
    (7.3, 1.5),
    2.4,
    2.0,
    boxstyle="round,pad=0.15",
    linewidth=1.5,
    edgecolor="#4c72b0",
    facecolor="#dce9f5",
)
ax.add_patch(restored_box)
ax.text(8.5, 3.7, "Restored", ha="center", va="center", fontsize=9, fontweight="bold")
ax.text(8.5, 2.85, "t1: Tensor(2,5,8)", ha="center", va="center", fontsize=8)
ax.text(8.5, 2.35, "t2: Tensor(2,8)", ha="center", va="center", fontsize=8)

# Arrows
ax.annotate(
    "",
    xy=(3.8, 2.5),
    xytext=(2.9, 2.5),
    arrowprops=dict(arrowstyle="->", color="#dd8452", lw=1.5),
)
ax.text(3.35, 2.75, "flatten", ha="center", va="bottom", fontsize=8, color="#dd8452")

ax.annotate(
    "",
    xy=(7.3, 2.5),
    xytext=(6.4, 2.5),
    arrowprops=dict(arrowstyle="->", color="#4c72b0", lw=1.5),
)
ax.text(6.85, 2.75, "unflatten", ha="center", va="bottom", fontsize=8, color="#4c72b0")

plt.tight_layout()
plt.show()

[Figure: pytree flatten / unflatten round-trip]

Related examples

InputObserver: recording inputs for ONNX export

Applying patches to a model and displaying the diff

Export a LLM with InputObserver (with Tiny-LLM)
