# Source code for onnx_diagnostic.torch_export_patches.serialization
import re
from typing import Any, Callable, List, Set, Tuple
import torch
def _lower_name_with_(name):
    """Converts a CamelCase name into lower case with ``_`` between words."""
    # Pass 1 puts ``_`` before a capitalized word (capital followed by
    # lowercase letters); pass 2 puts ``_`` between a lowercase letter or
    # digit and the capital right after it.
    converted = name
    for pattern in ("(.)([A-Z][a-z]+)", "([a-z0-9])([A-Z])"):
        converted = re.sub(pattern, r"\1_\2", converted)
    return converted.lower()
[docs]
def make_serialization_function_for_dataclass(
    cls: type, supported_classes: Set[type]
) -> Tuple[Callable, Callable, Callable]:
    """
    Automatically creates serialization function for a class decorated with
    ``dataclasses.dataclass``.

    The generated functions rely on ``obj.keys()`` and ``obj.values()``,
    so instances of *cls* must expose that mapping-style interface, and
    *cls* must accept those keys as keyword arguments to be rebuilt.

    :param cls: class the functions are created for
    :param supported_classes: set updated in place, *cls* is added to it
    :return: three functions ``(flatten, flatten_with_keys, unflatten)``,
        renamed ``flatten_<name>``, ``flatten_with_keys_<name>``,
        ``unflatten_<name>`` where ``<name>`` is the snake-case class name
    """
    cls_name = cls.__name__

    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:  # type: ignore[valid-type]
        # The list of keys is the context needed to rebuild the instance.
        return list(obj.values()), list(obj.keys())

    def flatten_with_keys_cls(
        obj: cls,  # type: ignore[valid-type]
    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
        values, context = list(obj.values()), list(obj.keys())
        return [
            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
        ], context

    def unflatten_cls(
        values: List[Any], context: torch.utils._pytree.Context, output_type=None
    ) -> cls:  # type: ignore[valid-type]
        # ``output_type`` is accepted but not used.
        return cls(**dict(zip(context, values)))

    name = _lower_name_with_(cls_name)
    flatten_cls.__name__ = f"flatten_{name}"
    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
    unflatten_cls.__name__ = f"unflatten_{name}"

    # The docstrings are assigned rather than formatted from ``__doc__``
    # with ``%``: docstrings are stripped under ``python -OO``, ``__doc__``
    # is then ``None`` and ``None % cls_name`` raises ``TypeError``.
    flatten_cls.__doc__ = f"Serializes a ``{cls_name}`` with python objects."
    flatten_with_keys_cls.__doc__ = (
        f"Serializes a ``{cls_name}`` with python objects with keys."
    )
    unflatten_cls.__doc__ = (
        f"Restores an instance of ``{cls_name}`` from python objects."
    )

    supported_classes.add(cls)
    return flatten_cls, flatten_with_keys_cls, unflatten_cls