yobx.torch.flatten
- yobx.torch.flatten.flattening_functions(patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0) → Dict[type, Callable[[], bool]]
Returns the flattening functions as a dictionary keyed by type, where each value is a callable that takes no argument and returns a bool.
- yobx.torch.flatten.make_flattening_function_for_dataclass(cls: type, supported_classes: Set[type]) → Tuple[Callable, Callable, Callable]
Automatically creates flattening functions for a class decorated with
dataclasses.dataclass.
- Parameters:
cls – the dataclass type
supported_classes – set to register the class into
- Returns:
tuple of (flatten, flatten_with_keys, unflatten) callables
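As a rough illustration of the three returned callables, here is a minimal pure-Python sketch, assuming the shapes used by `pytree.register_pytree_node`: flatten returns `(children, context)`, flatten_with_keys returns `(list of (key, child) pairs, context)`, and unflatten rebuilds the object from `(children, context)`. The helper `make_dataclass_flatteners` and the `Point` class are hypothetical. This is not yobx's implementation: it uses plain strings as keys where torch expects key-entry objects, and it assumes every field is accepted by `__init__`.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, List, Tuple


def make_dataclass_flatteners(cls: type) -> Tuple[Callable, Callable, Callable]:
    """Hypothetical helper building (flatten, flatten_with_keys, unflatten) for a dataclass."""
    # the field names are computed once and reused as the pytree "context"
    names = [f.name for f in fields(cls)]

    def flatten(obj: Any) -> Tuple[List[Any], List[str]]:
        # children are the field values, in declaration order
        return [getattr(obj, n) for n in names], names

    def flatten_with_keys(obj: Any) -> Tuple[List[Tuple[str, Any]], List[str]]:
        # same children, each paired with the name of its field (plain strings here)
        return [(n, getattr(obj, n)) for n in names], names

    def unflatten(children: List[Any], context: List[str]) -> Any:
        # assumes every field is an __init__ argument
        return cls(**dict(zip(context, children)))

    return flatten, flatten_with_keys, unflatten


@dataclass
class Point:
    x: float
    y: float


flatten, flatten_with_keys, unflatten = make_dataclass_flatteners(Point)
children, context = flatten(Point(1.0, 2.0))
print(children, context)             # [1.0, 2.0] ['x', 'y']
print(unflatten(children, context))  # Point(x=1.0, y=2.0)
```

Keeping the field names as the context is what lets unflatten rebuild the instance without holding on to the original object.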
- yobx.torch.flatten.register_cache_flattening(patch_transformers: bool = False, verbose: int = 0) → Dict[type, bool]
Registers many classes with
yobx.torch.flatten.register_class_flattening() and returns the information needed to undo the registration.
- Parameters:
patch_transformers – add flattening functions for the transformers package
patch_diffusers – add flattening functions for the diffusers package
verbose – verbosity level
- Returns:
information to unpatch
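One plausible way to use a `Dict[type, bool]` as undo information is to record, per class, whether this call actually added it, so that undoing the registration leaves classes registered by other code untouched. A minimal sketch under that assumption; `REGISTRY`, `register_many`, `unregister_many`, `OldCache` and `NewCache` are hypothetical stand-ins, not yobx functions or classes.

```python
from typing import Dict, Iterable, Set

# Hypothetical registry of classes the flattener currently knows about.
REGISTRY: Set[type] = set()


def register_many(classes: Iterable[type]) -> Dict[type, bool]:
    """Registers every class and records whether this call added it."""
    undo_info: Dict[type, bool] = {}
    for cls in classes:
        undo_info[cls] = cls not in REGISTRY  # True only if newly registered
        REGISTRY.add(cls)
    return undo_info


def unregister_many(undo_info: Dict[type, bool]) -> None:
    """Removes only the classes that register_many itself added."""
    for cls, newly_added in undo_info.items():
        if newly_added:
            REGISTRY.discard(cls)


class OldCache:
    pass


class NewCache:
    pass


REGISTRY.add(OldCache)  # registered beforehand by other code
info = register_many([OldCache, NewCache])
print(info[OldCache], info[NewCache])  # False True
unregister_many(info)
print(OldCache in REGISTRY, NewCache in REGISTRY)  # True False
```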
- yobx.torch.flatten.register_class_flattening(cls: type, f_flatten: Callable, f_unflatten: Callable, f_flatten_with_keys: Callable, f_check: Callable | None = None, verbose: int = 0) → bool
Registers a class. It can be undone with
yobx.torch.flatten.unregister_class_flattening().
- Parameters:
cls – class to register
f_flatten – see pytree.register_pytree_node
f_unflatten – see pytree.register_pytree_node
f_flatten_with_keys – see pytree.register_pytree_node
f_check – called to check the registration was successful
verbose – verbosity
- Returns:
whether the class was registered
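The register, check and undo contract can be sketched in plain Python. Everything below is hypothetical: `REGISTRY` stands in for the pytree node registry, and `register_class` / `unregister_class` only mimic the documented behavior (return whether the class was registered, call `f_check` afterwards). Rolling back and raising when the check fails is an assumption of this sketch, not documented yobx behavior.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical stand-in for the pytree node registry: class -> its callables.
REGISTRY: Dict[type, Tuple[Callable, Callable, Callable]] = {}


def register_class(
    cls: type,
    f_flatten: Callable,
    f_unflatten: Callable,
    f_flatten_with_keys: Callable,
    f_check: Optional[Callable[[], bool]] = None,
) -> bool:
    """Registers cls; returns True if this call registered it, False if it already was."""
    if cls in REGISTRY:
        return False
    REGISTRY[cls] = (f_flatten, f_unflatten, f_flatten_with_keys)
    if f_check is not None and not f_check():
        # assumption of this sketch: roll back and fail loudly if the check fails
        del REGISTRY[cls]
        raise RuntimeError(f"registration check failed for {cls.__name__}")
    return True


def unregister_class(cls: type) -> None:
    """Undoes register_class; does nothing if cls is not registered."""
    REGISTRY.pop(cls, None)


class Box:
    def __init__(self, value: Any) -> None:
        self.value = value


def flatten_box(box: Box) -> Tuple[List[Any], None]:
    return [box.value], None


def unflatten_box(children: List[Any], context: None) -> Box:
    return Box(children[0])


def flatten_box_with_keys(box: Box) -> Tuple[List[Tuple[str, Any]], None]:
    return [("value", box.value)], None


def check_box() -> bool:
    # a round trip through the registered flatten/unflatten must preserve the value
    children, context = REGISTRY[Box][0](Box(3))
    return REGISTRY[Box][1](children, context).value == 3


print(register_class(Box, flatten_box, unflatten_box, flatten_box_with_keys, check_box))  # True
print(register_class(Box, flatten_box, unflatten_box, flatten_box_with_keys, check_box))  # False
unregister_class(Box)
```

Returning False for an already registered class, instead of raising, is what makes it safe to call the registration twice and only undo what the first call added.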
- yobx.torch.flatten.register_flattening_functions(patch_transformers: bool = False, verbose: int = 0) → Generator[Callable, None, None]
This context manager registers the flattening functions the exporter needs to handle custom classes. Flattening lets the exporter build a signature that takes only tensors as inputs, even when the model's inputs are nested structures.
```python
from yobx.torch import register_flattening_functions

with register_flattening_functions(patch_transformers=True):
    # ... export the model here
    ...
```
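The pattern behind such a context manager can be sketched without torch: register flatteners on entry, remove only the ones that were added on exit (even if an exception is raised), and, while they are active, flatten nested inputs into a flat list of leaves. `REGISTRY`, `tree_leaves`, `register_flatteners`, `KVCache` and `flatten_kv_cache` are hypothetical; the leaves stand in for the tensors the exporter would see, and the sketch does not reproduce what the yobx context manager yields.

```python
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Iterator, List

# Hypothetical registry: class -> function returning the object's children.
REGISTRY: Dict[type, Callable[[Any], List[Any]]] = {}


def tree_leaves(obj: Any) -> List[Any]:
    """Flattens lists, tuples, dicts and registered classes into a list of leaves."""
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in tree_leaves(item)]
    if isinstance(obj, dict):
        # dict values are visited in insertion order
        return [leaf for value in obj.values() for leaf in tree_leaves(value)]
    flatten = REGISTRY.get(type(obj))
    if flatten is not None:
        return [leaf for child in flatten(obj) for leaf in tree_leaves(child)]
    return [obj]  # anything unregistered is an opaque leaf


@contextmanager
def register_flatteners(
    flatteners: Dict[type, Callable[[Any], List[Any]]]
) -> Iterator[None]:
    """Registers flatteners on entry and removes the ones it added on exit."""
    added = [cls for cls in flatteners if cls not in REGISTRY]
    for cls in added:
        REGISTRY[cls] = flatteners[cls]
    try:
        yield
    finally:
        # runs even if the body raises, so the registry is always restored
        for cls in added:
            del REGISTRY[cls]


@dataclass
class KVCache:
    keys: List[float]
    values: List[float]


def flatten_kv_cache(cache: KVCache) -> List[Any]:
    return [getattr(cache, f.name) for f in fields(cache)]


inputs = {"input_ids": [1, 2], "cache": KVCache(keys=[0.1], values=[0.2])}
with register_flatteners({KVCache: flatten_kv_cache}):
    print(tree_leaves(inputs))  # [1, 2, 0.1, 0.2]
print(len(tree_leaves(inputs)))  # 3: the cache is an opaque leaf again
```

The `try`/`finally` around `yield` is what guarantees the registry is restored when the code inside the `with` block raises.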
- yobx.torch.flatten.replacement_before_exporting(args: Any) → Any
Applies replacements to the given inputs, if needed.