yobx.torch.flatten

yobx.torch.flatten.flattening_functions(patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0) -> Dict[type, Callable[[], bool]]

Returns the flattening functions as a dictionary keyed by class.

yobx.torch.flatten.make_flattening_function_for_dataclass(cls: type, supported_classes: Set[type]) -> Tuple[Callable, Callable, Callable]

Automatically creates flattening functions for a class decorated with dataclasses.dataclass.

Parameters:
  • cls – the dataclass type

  • supported_classes – set in which the class is registered

Returns:

tuple of (flatten, flatten_with_keys, unflatten) callables
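
As a rough illustration of what such a triple of callables can look like, here is a minimal sketch in plain Python. It is not the yobx implementation: the helper name make_dataclass_flatteners and the class Pair are hypothetical, the supported_classes bookkeeping is left out, and plain strings are used as keys, whereas the real pytree API may expect dedicated key objects.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, List, Tuple


def make_dataclass_flatteners(cls: type) -> Tuple[Callable, Callable, Callable]:
    """Build (flatten, flatten_with_keys, unflatten) for a dataclass type."""
    names = [f.name for f in fields(cls)]

    def flatten(obj: Any) -> Tuple[List[Any], List[str]]:
        # Children are the field values; the context records the field names.
        return [getattr(obj, n) for n in names], names

    def flatten_with_keys(obj: Any) -> Tuple[List[Tuple[str, Any]], List[str]]:
        # Same as flatten, but every child is paired with its field name.
        return [(n, getattr(obj, n)) for n in names], names

    def unflatten(values: List[Any], context: List[str]) -> Any:
        # Rebuild the instance from the values and the recorded names.
        return cls(**dict(zip(context, values)))

    return flatten, flatten_with_keys, unflatten


@dataclass
class Pair:  # hypothetical example class
    left: int
    right: int


flatten, flatten_with_keys, unflatten = make_dataclass_flatteners(Pair)
values, context = flatten(Pair(1, 2))
restored = unflatten(values, context)
```

The round trip flatten then unflatten should give back an equal instance, which is the property a registration check would typically verify.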

yobx.torch.flatten.register_cache_flattening(patch_transformers: bool = False, verbose: int = 0) -> Dict[type, bool]

Registers many classes with yobx.torch.flatten.register_class_flattening(). Returns information needed to undo the registration.

Parameters:
  • patch_transformers – add flattening function for transformers package

  • patch_diffusers – add flattening function for diffusers package

  • verbose – verbosity level

Returns:

information to unpatch

yobx.torch.flatten.register_class_flattening(cls: type, f_flatten: Callable, f_unflatten: Callable, f_flatten_with_keys: Callable, f_check: Callable | None = None, verbose: int = 0) -> bool

Registers a class. It can be undone with yobx.torch.flatten.unregister_class_flattening().

Parameters:
  • cls – class to register

  • f_flatten – see pytree.register_pytree_node

  • f_unflatten – see pytree.register_pytree_node

  • f_flatten_with_keys – see pytree.register_pytree_node

  • f_check – called to check the registration was successful

  • verbose – verbosity

Returns:

registered or not

yobx.torch.flatten.register_flattening_functions(patch_transformers: bool = False, verbose: int = 0) -> Generator[Callable, None, None]

This context manager registers the flattening functions the exporter needs to handle custom classes. With them, the exporter can build a signature that takes only tensors as inputs even when the model code passes nested structures.

from yobx.torch import register_flattening_functions

with register_flattening_functions(patch_transformers=True):
    # ...
    pass

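
To make the idea of "a signature taking only tensors" concrete, the sketch below flattens a nested input into a flat list of leaves plus a spec that remembers the nesting, then rebuilds it. It is a plain-Python illustration under simplifying assumptions: flatten_nested and unflatten_nested are hypothetical names, strings stand in for tensors, and only list, tuple and dict containers are handled. The real exporter relies on registered flattening functions rather than this code.

```python
from typing import Any, List, Tuple


def flatten_nested(obj: Any) -> Tuple[List[Any], Any]:
    """Return (leaves, spec): leaves is a flat list, spec records the nesting."""
    if isinstance(obj, (list, tuple)):
        leaves, specs = [], []
        for item in obj:
            sub_leaves, sub_spec = flatten_nested(item)
            leaves.extend(sub_leaves)
            specs.append(sub_spec)
        return leaves, (type(obj), specs)
    if isinstance(obj, dict):
        leaves, specs = [], []
        for key, item in obj.items():
            sub_leaves, sub_spec = flatten_nested(item)
            leaves.extend(sub_leaves)
            specs.append((key, sub_spec))
        return leaves, (dict, specs)
    # Anything else is a leaf (a tensor in the exporter's case).
    return [obj], None


def unflatten_nested(leaves: List[Any], spec: Any) -> Any:
    """Rebuild the nested structure from the flat leaves and the spec."""
    it = iter(leaves)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        kind, children = s
        if kind is dict:
            return {key: build(sub) for key, sub in children}
        return kind(build(sub) for sub in children)

    return build(spec)


# Strings stand in for tensors in this illustration.
nested = {"ids": "t0", "cache": [("t1", "t2"), ("t3", "t4")]}
leaves, spec = flatten_nested(nested)
rebuilt = unflatten_nested(leaves, spec)
```

The flat list is what an exported signature would receive; the spec is what lets the nested view be restored around it.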
yobx.torch.flatten.replacement_before_exporting(args: Any) -> Any

Does replacements on the given inputs if needed.

yobx.torch.flatten.unregister_cache_flattening(undo: Dict[type, bool], verbose: int = 0)

Undoes the registration made by yobx.torch.flatten.register_cache_flattening().
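
The register/undo pairing can be pictured with a small stand-alone sketch. It assumes, without confirmation from the source, that the boolean stored for each class records whether that call actually added the registration, so that undoing leaves earlier registrations alone. The names _REGISTERED, register_many, unregister_many and registered are hypothetical.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Set

# Hypothetical stand-in for the real registry.
_REGISTERED: Set[type] = set()


def register_many(classes: List[type]) -> Dict[type, bool]:
    """Register classes; map each one to True if this call registered it."""
    undo: Dict[type, bool] = {}
    for cls in classes:
        undo[cls] = cls not in _REGISTERED
        _REGISTERED.add(cls)
    return undo


def unregister_many(undo: Dict[type, bool]) -> None:
    """Remove only the classes that the matching register call added."""
    for cls, added in undo.items():
        if added:
            _REGISTERED.discard(cls)


@contextmanager
def registered(classes: List[type]) -> Iterator[None]:
    undo = register_many(classes)
    try:
        yield
    finally:
        # Always restore the previous state, even if the body raises.
        unregister_many(undo)


class A:  # hypothetical class, registered before the block
    pass


class B:  # hypothetical class, registered only inside the block
    pass


_REGISTERED.add(A)
with registered([A, B]):
    inside = set(_REGISTERED)
after = set(_REGISTERED)
```

Keeping the undo information per class is what allows the cleanup to remove B while leaving A, which was registered beforehand, in place.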

yobx.torch.flatten.unregister_class_flattening(cls: type, verbose: int = 0)

Undoes the registration for a class.