import base64
import copy
import io
import pickle
import re
import zlib
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from ..helpers import string_type
def choose_kwargs_for_dynamic_shapes(
ds_args: Tuple[Dict[int, Any], ...],
ds_kwargs: Dict[str, Dict[int, Any]],
forward_names: List[str],
) -> Dict[str, Dict[int, Any]]:
"""
Chooses a dictionary to express dynamic shapes when the module uses both
unnamed arguments and named ones.
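
    A minimal sketch of the expected behaviour (names and axes are
    illustrative, not taken from a real model)::

        ds_args = ({0: "batch"},)
        ds_kwargs = {"attention_mask": {0: "batch"}}
        ds = choose_kwargs_for_dynamic_shapes(
            ds_args, ds_kwargs, forward_names=["input_ids", "attention_mask"]
        )
        # ds == {"attention_mask": {0: "batch"}, "input_ids": {0: "batch"}}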
"""
    assert ds_args and ds_kwargs and forward_names, (
        f"No need to choose as ds_args={ds_args} or ds_kwargs={ds_kwargs} "
        f"or forward_names={forward_names} is empty."
    )
if len(forward_names) >= len(ds_args):
# No *args.
ds = ds_kwargs
for k, v in zip(forward_names, ds_args):
ds[k] = v
return ds
raise NotImplementedError(
f"forward_positioned_parameter_names={forward_names}, "
f"ds_args={ds_args}, ds_kwargs={ds_kwargs}"
)
def serialize_one(
obj: Any, name: Union[str, int], schema: str
) -> Optional[Union[torch.Tensor, List[torch.Tensor], bool, int, float]]:
"""
Serializes one object into a tensor or a list of tensors.
    *name* and *schema* are only used to build better error messages.
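
    A small sketch of typical outputs (shapes are illustrative)::

        t = torch.rand(2, 3)
        serialize_one(t, name="x", schema="")  # returns t unchanged
        serialize_one({"b": t, "a": t}, name="d", schema="")
        # returns [t, t], values ordered by the sorted keys ("a", "b")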
"""
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, (tuple, list)) and all(isinstance(t, torch.Tensor) for t in obj):
return obj
if isinstance(obj, dict) and all(isinstance(t, torch.Tensor) for t in obj.values()):
sorted_items = sorted(obj.items())
return [_[1] for _ in sorted_items]
if obj.__class__.__name__ in {"DynamicCache", "patched_DynamicCache"}:
return [*obj.key_cache, *obj.value_cache]
if obj is None:
return None
if isinstance(obj, (bool, int, float)):
# This cannot be traced.
return obj
# Let's flatten the structure using pytorch.
flat_list, _spec = torch.utils._pytree.tree_flatten(obj)
return flat_list
def serialize_args(
args: Tuple[Any],
kwargs: Optional[Dict[str, Any]],
schema: str,
args_names: Optional[List[str]] = None,
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
"""
    Serializes args and kwargs before calling a custom op.
    :param args: unnamed arguments
    :param kwargs: named arguments
    :param schema: function schema
    :param args_names: ordered argument names; they must be specified
        when *kwargs* is
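
    A minimal sketch with purely positional tensors (illustrative shapes)::

        x, y = torch.rand(2), torch.rand(3)
        new_args, new_kwargs = serialize_args((x, y), {}, schema="")
        # new_args == (x, y) and new_kwargs == {}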
"""
if isinstance(args, torch.Tensor):
new_args = args
n_args = 1
is_tensor = True
elif isinstance(args, dict) and all(isinstance(t, torch.Tensor) for t in args.values()):
sorted_items = sorted(args.items())
new_args = [_[1] for _ in sorted_items]
n_args = len(new_args)
is_tensor = False
elif isinstance(args, (list, tuple)) and all(isinstance(i, torch.Tensor) for i in args):
new_args = []
for i, a in enumerate(args):
r = serialize_one(a, name=i, schema=schema)
if r is None or isinstance(r, torch.Tensor):
new_args.append(r)
else:
new_args.extend(r)
new_args = tuple(new_args)
n_args = len(new_args)
is_tensor = False
else:
# Let's use pytorch serialization.
new_args, spec = torch.utils._pytree.tree_flatten(args)
n_args = len(new_args)
is_tensor = False
if not kwargs:
if kwargs is None:
return new_args if is_tensor else tuple(new_args)
return new_args, {}
assert args_names is not None or schema, (
f"Not implemented when args_names={args_names}, "
f"args={string_type(args, with_shape=True)}, "
f"kwargs={string_type(kwargs, with_shape=True)}, "
f"schema={schema!r}"
)
if args_names is None:
args_names = extract_names_from_schema(schema)
new_args = [new_args] if is_tensor else list(new_args)
args_names = args_names[n_args:]
assert args_names, (
f"kwargs={string_type(kwargs, with_shape=True)} is specified "
f"then args_names should be as well, schema={schema!r}, "
f"args={string_type(args, with_shape=True)}"
)
# handling arguments
for name in args_names:
if name not in kwargs:
new_args.append(None)
continue
v = kwargs[name]
r = serialize_one(v, name=name, schema=schema)
if r is None or isinstance(r, (torch.Tensor, int, bool, float)):
new_args.append(r)
else:
new_args.extend(r)
# remaining arguments (**kwargs, *args)
set_names = set(args_names)
for k, v in kwargs.items():
if k in set_names:
continue
        if v is None:
            new_args.append(None)
            continue
        r = serialize_one(v, name=k, schema=schema)
if r is None or isinstance(r, torch.Tensor):
new_args.append(r)
else:
new_args.extend(r)
return tuple(new_args), {}
def type_as_str_with_info(obj: Any) -> str:
"""Returns a string with information about how to deserialize."""
if isinstance(obj, torch.Tensor):
return "Tensor"
if isinstance(obj, list) and all(isinstance(t, torch.Tensor) for t in obj):
return f"list__{len(obj)}"
if isinstance(obj, dict) and all(isinstance(t, torch.Tensor) for t in obj.values()):
sorted_keys = "__".join(sorted(obj))
return f"dict__{len(obj)}_{sorted_keys}"
if obj.__class__.__name__ in {"DynamicCache", "patched_DynamicCache"}:
return f"{obj.__class__.__name__}__{len(obj.key_cache)}_{len(obj.value_cache)}"
if obj is None:
return "None"
if isinstance(obj, bool):
return "bool"
if isinstance(obj, float):
return "float"
if isinstance(obj, int):
return "int"
# Let's use pytorch serialization.
flat_list, spec = torch.utils._pytree.tree_flatten(obj)
return tree_spec_as_name(spec, len(flat_list))
def deserialize_args(
res: List[torch.Tensor],
expected_types: List[str],
clone: bool = False,
return_n_args: bool = False,
) -> Tuple[Any, ...]:
"""
    Deserializes output results coming from the custom op and restores
    the Python classes attached to them.
:param res: args to deserialize
:param expected_types: information on how to deserialize
:param clone: clone tensors before returning them
    :param return_n_args: if True, the function also returns the number of
        deserialized arguments; if False, it assumes this number is equal
        to the number of expected types
:return: new args
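
    A minimal sketch (illustrative shapes)::

        t1, t2 = torch.rand(2), torch.rand(3)
        deserialize_args([t1, t2], ["Tensor", "Tensor"])
        # -> [t1, t2]
        deserialize_args([t1, t2], ["list__2"])
        # -> [[t1, t2]], both tensors grouped back into one list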
"""
assert isinstance(res, (list, tuple, torch.Tensor)), f"unexpected type for res {type(res)}"
if isinstance(res, torch.Tensor):
assert expected_types == [
"Tensor"
], f"Mismatch information, expected_types={expected_types!r}"
return res
assert all(
t is None or isinstance(t, (list, torch.Tensor, bool, int, float)) for t in res
), f"unexpected element type in res: {string_type(res)}"
des = []
pos_res = 0
for tt in expected_types:
if tt in (None, "None"):
des.append(None)
pos_res += 1
continue
if tt == "Tensor":
des.append(res[pos_res].clone() if clone else res[pos_res])
pos_res += 1
continue
if tt in ("bool", "int", "float"):
des.append(res[pos_res])
pos_res += 1
continue
if tt.startswith("dict__"):
nl = tt[6:].split("_", maxsplit=1)
n = int(nl[0])
keys = nl[1].split("__")
assert len(keys) == n, f"Unable to parse {tt!r}, expecting {n} keys but got {keys}"
value = res[pos_res]
if isinstance(res[pos_res], torch.Tensor):
values = res[pos_res : pos_res + n]
des.append(dict(zip(keys, values)))
pos_res += n
else:
des.append(dict(zip(keys, value)))
pos_res += 1
continue
if tt.startswith("list__"):
n = int(tt[6:])
value = res[pos_res]
if isinstance(res[pos_res], torch.Tensor):
values = res[pos_res : pos_res + n]
des.append(values)
pos_res += n
else:
des.append(value)
pos_res += 1
continue
if tt.startswith(("DynamicCache__", "patched_DynamicCache__")):
info = tt.split("__")[-1]
n1, n2 = tuple(map(int, info.split("_")))
assert n1 == n2, f"Unexpected sizes for n1={n1} and n2={n2} for a DynamicCache"
if isinstance(res[pos_res], torch.Tensor):
# All flattened.
key_cache = res[pos_res : pos_res + n1]
value_cache = res[pos_res + n1 : pos_res + n1 + n2]
pos_res += n1 + n2
else:
value = res[pos_res]
assert isinstance(value, list) and all(
isinstance(t, torch.Tensor) for t in value
), (
f"Unexpected type at position {pos_res}: "
f"{string_type(value, with_shape=True)}, "
f"deserialized into {tt}"
)
assert len(value) % 2 == 0 and len(value) == n1 + n2, (
f"Number of tensors at position {pos_res} "
f"in {string_type(value, with_shape=True)} "
f"should be even. Unable to deserialize into {tt}, "
f"n1={n1}, n2={n2}, len(res[pos_res])={len(value)}"
)
key_cache = value[:n1]
value_cache = value[n1:]
pos_res += 1
if tt.startswith("DynamicCache__"):
import transformers
cache = transformers.cache_utils.DynamicCache()
elif tt.startswith("patched_DynamicCache__"):
from .patches.patch_transformers import patched_DynamicCache
cache = patched_DynamicCache()
else:
raise NotImplementedError(f"Unable to handle type info(2) {tt!r}")
if clone:
cache.key_cache = [t.clone() for t in key_cache]
cache.value_cache = [t.clone() for t in value_cache]
else:
cache.key_cache = key_cache
cache.value_cache = value_cache
des.append(cache)
continue
if tt.startswith("___"):
n_args, spec = tree_spec_from_name(tt)
obj = torch.utils._pytree.tree_unflatten(res[pos_res : pos_res + n_args], spec)
pos_res += n_args
des.append(obj)
continue
raise NotImplementedError(
f"Unable to handle type info {tt!r}, "
f"expected_types={expected_types!r}, "
f"res={string_type(res, with_shape=True)}, "
f"des={string_type(des, with_shape=True)}"
)
if return_n_args:
return des, pos_res
assert pos_res == len(res) or (
pos_res == len(res) - 1 and isinstance(res[-1], list) and len(res[-1]) == 0
), (
f"Deserialization went wrong, pos_res={pos_res}, len(res)={len(res)}, "
f"expected_types={expected_types}, "
f"input types={string_type(res)}"
)
return des
def deserialize_args_kwargs(
args: List[torch.Tensor],
kwargs: Dict[str, Any],
expected_types: Tuple[List[str], List[str]],
clone: bool = False,
ordered_names: Optional[List[str]] = None,
fill_kwargs: bool = False,
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
    Deserializes a list of tensors or lists of tensors into args and kwargs.
    *kwargs* should be empty since named arguments are not allowed
    in the serialized form.
:param args: arguments
:param kwargs: named arguments, they should be empty
:param expected_types: needed to understand how to deserialize
:param clone: clone every tensor
    :param ordered_names: ordered names, needed to restore **kwargs
:param fill_kwargs: if True, the last parameter is **kwargs
and it should be empty
:return: new args, new named args
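
    A minimal sketch when no named argument needs to be restored
    (illustrative shapes)::

        t = torch.rand(2, 3)
        new_args, new_kwargs = deserialize_args_kwargs([t], {}, (["Tensor"], {}))
        # new_args == [t] and new_kwargs == {}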
"""
assert not kwargs, (
f"inputs coming from C++ functions should not have "
f"named arguments but kwargs={string_type(kwargs, with_shape=True)}."
)
assert (
isinstance(expected_types, tuple)
and not kwargs
and len(expected_types) == 2
and (not expected_types[1] or ordered_names)
), (
f"Unexpected value for expected_types={expected_types}, "
f"args={string_type(args, with_shape=True)}, "
f"kwargs={string_type(kwargs, with_shape=True)}, "
f"ordered_names={ordered_names}"
)
if expected_types[1]:
new_args, n_args = deserialize_args(
args, expected_types[0], clone=clone, return_n_args=True
)
left_args = args[n_args:]
left_names = ordered_names[n_args:]
new_kwargs = {}
pos_res = 0
for name in left_names:
if name not in expected_types[1]:
# no **kwargs
continue
if expected_types[1][name] == "Tensor":
new_kwargs[name] = left_args[pos_res]
pos_res += 1
continue
if expected_types[1][name] in ("bool", "int", "float", "None"):
new_kwargs[name] = left_args[pos_res]
pos_res += 1
continue
assert isinstance(expected_types[1][name], str), (
f"Unexpected type {type(expected_types[1][name])} for name={name!r} in "
f"expected_types={expected_types}"
)
a, n = deserialize_args(
left_args[pos_res:], [expected_types[1][name]], clone=clone, return_n_args=True
)
assert len(a) == 1, (
f"Unexpected length, a={string_type(a, limit=20)}, "
f"expected_types[1][name]={expected_types[1][name]!r}"
)
pos_res += n
new_kwargs[name] = a[0]
assert pos_res + n_args + (1 if fill_kwargs else 0) == len(args), (
f"Deserialization went wrong, pos_res={pos_res + n_args}, "
f"n_args={n_args}, len(args)={len(args)}, "
f"\nfill_kwargs={fill_kwargs}, "
f"\nexpected_types={expected_types}, "
f"\nargs={string_type(args, limit=20)}, "
f"\nnew_args={string_type(new_args, limit=20)}, "
f"\nnew_kwargs={string_type(new_kwargs)}, "
f"\nordered_names={ordered_names}"
)
return new_args, new_kwargs
new_args = deserialize_args(args, expected_types[0], clone=clone)
return new_args, {}
def make_copy(obj: Any) -> Any:
"""Makes a copy of the objects."""
if isinstance(obj, np.ndarray):
return obj.copy()
if isinstance(obj, tuple):
return tuple(make_copy(_) for _ in obj)
if isinstance(obj, list):
return [make_copy(_) for _ in obj]
if isinstance(obj, dict):
return {k: make_copy(v) for k, v in obj.items()}
if hasattr(obj, "clone"):
return obj.clone()
if obj.__class__.__name__ in ("DynamicCache", "patched_DynamicCache"):
cache = obj.__class__()
if hasattr(obj, "_seen_tokens"):
cache._seen_tokens = obj._seen_tokens
cache.key_cache = make_copy(obj.key_cache)
cache.value_cache = make_copy(obj.value_cache)
return cache
try:
return copy.deepcopy(obj)
except RuntimeError as e:
raise RuntimeError(
f"deepcopy did not work on type {type(obj)}: {string_type(obj)}"
) from e
def tree_spec_as_name(tree_spec: torch.utils._pytree.TreeSpec, n_elements: int) -> str:
"""
    Returns a string encoding all the information needed to rebuild
    an instance of TreeSpec.
:param tree_spec: instance of TreeSpec to convert into a name
:param n_elements: number of elements it serializes
:return: string
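
    A round-trip sketch with :func:`tree_spec_from_name` (illustrative values)::

        flat, spec = torch.utils._pytree.tree_flatten({"a": 0, "b": [1, 2]})
        name = tree_spec_as_name(spec, len(flat))
        n, spec2 = tree_spec_from_name(name)
        # n == 3 and spec2 rebuilds the original structure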
"""
buffer = io.BytesIO()
pickle.dump((n_elements, tree_spec), buffer)
data = buffer.getvalue()
compressed_data = zlib.compress(data)
chars = base64.b32encode(compressed_data).decode("utf-8")
chars = chars.replace("=", "_")
return f"___{chars}"
def tree_spec_from_name(name: str) -> Tuple[int, torch.utils._pytree.TreeSpec]:
"""
Restores the instance of TreeSpec converted into a name by
function :func:`tree_spec_as_name`.
:param name: name of the TreeSpec
    :return: number of elements it contains, instance of TreeSpec
"""
assert len(name) > 3 and name[:3] == "___", f"Incorrect name {name!r}"
name = name[3:].replace("_", "=")
buffer = base64.b32decode(name)
unzip = zlib.decompress(buffer)
n_elements, treespec = pickle.load(io.BytesIO(unzip))
return n_elements, treespec