import ctypes
import sys
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
from onnx import GraphProto, ModelProto, TensorProto
import onnx.helper as oh
import onnx.numpy_helper as onh
from .helpers import string_type, tensor_dtype_to_np_dtype
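# 16-bit float types are byte-swapped through a same-width integer view on
# big-endian platforms (numpy has no bfloat16 dtype of its own).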
STORAGE_TYPE = {
TensorProto.FLOAT16: np.int16,
TensorProto.BFLOAT16: np.int16,
}
def _get_type(elem_type: Any, exc: bool = True) -> int:
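    """
    Maps ``elem_type`` (an onnx element type, a dtype-like object whose string
    form names a type such as ``"torch.float32"``, or ``None``) to a
    ``TensorProto`` element type. Integers are returned unchanged.
    Example: ``_get_type("torch.float32")`` returns ``TensorProto.FLOAT``.
    """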
if not isinstance(elem_type, int):
st = str(elem_type)
if "float32" in st:
elem_type = TensorProto.FLOAT
elif "float64" in st:
elem_type = TensorProto.DOUBLE
elif "bfloat16" in st:
elem_type = TensorProto.BFLOAT16
elif "float16" in st:
elem_type = TensorProto.FLOAT16
elif "uint64" in st:
elem_type = TensorProto.UINT64
elif "int64" in st:
elem_type = TensorProto.INT64
elif "uint32" in st:
elem_type = TensorProto.UINT32
elif "int32" in st:
elem_type = TensorProto.INT32
elif "uint16" in st:
elem_type = TensorProto.UINT16
elif "int16" in st:
elem_type = TensorProto.INT16
elif "bool" in st:
elem_type = TensorProto.BOOL
elif "uint8" in st:
elem_type = TensorProto.UINT8
elif "int8" in st:
elem_type = TensorProto.INT8
elif "complex64" in st:
elem_type = TensorProto.COMPLEX64
elif "complex128" in st:
elem_type = TensorProto.COMPLEX128
elif elem_type is None:
elem_type = TensorProto.UNDEFINED
elif exc:
raise ValueError(f"Unable to interpret elem_type {elem_type!r}.")
return elem_type
def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int: # noqa: F821
"""
    Converts a torch dtype into an onnx element type.
:param to: torch dtype
:return: onnx type
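
    A quick example::

        import torch
        from onnx import TensorProto

        assert torch_dtype_to_onnx_dtype(torch.float16) == TensorProto.FLOAT16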
"""
import torch
if to == torch.float32:
return TensorProto.FLOAT
if to == torch.float16:
return TensorProto.FLOAT16
if to == torch.bfloat16:
return TensorProto.BFLOAT16
if to == torch.float64:
return TensorProto.DOUBLE
if to == torch.int64:
return TensorProto.INT64
if to == torch.int32:
return TensorProto.INT32
if to == torch.bool:
return TensorProto.BOOL
if to == torch.SymInt:
return TensorProto.INT64
if to == torch.SymFloat:
return TensorProto.FLOAT
if to == torch.complex64:
return TensorProto.COMPLEX64
if to == torch.complex128:
return TensorProto.COMPLEX128
raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
def dtype_to_tensor_dtype(dt: "dtype") -> int: # noqa: F821
"""
    Converts a torch dtype or numpy dtype into an onnx element type.
    :param dt: dtype
:return: onnx type
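
    A quick example::

        import numpy as np

        assert dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT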
"""
try:
return oh.np_dtype_to_tensor_dtype(dt)
except (KeyError, TypeError):
pass
return torch_dtype_to_onnx_dtype(dt)
def proto_from_array(
arr: "torch.Tensor", # noqa: F821
name: Optional[str] = None,
verbose: int = 0,
) -> TensorProto:
"""
Converts a torch Tensor into a TensorProto.
    :param arr: tensor
    :param name: name of the initializer
    :param verbose: display the type and shape
:return: a TensorProto
"""
import torch
if not isinstance(arr, torch.Tensor):
raise TypeError(f"Unexpected type {type(arr)}.")
if arr.is_sparse:
raise NotImplementedError(
f"Sparse tensor is not supported yet but initializer {name!r} is."
)
# arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
if arr.is_contiguous():
arr_cpu = arr.cpu()
else:
arr_cpu = arr.contiguous().cpu()
numel = torch.numel(arr_cpu)
element_size = arr_cpu.element_size()
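    # numpy has no bfloat16 dtype: keep the torch tensor here and copy its
    # raw bytes directly further below instead of going through DLPack.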
if arr_cpu.dtype in {torch.bfloat16}:
np_arr = arr_cpu
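    # .cpu() can return the same storage for CPU tensors: clone first so the
    # DLPack view does not alias the original tensor's memory.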
elif arr_cpu.data_ptr() == arr.data_ptr():
copy = arr_cpu.clone().detach().requires_grad_(False)
assert arr_cpu.data_ptr() != copy.data_ptr()
np_arr = np.from_dlpack(copy)
else:
np_arr = np.from_dlpack(arr_cpu.detach())
tensor = TensorProto()
tensor.dims.extend(arr_cpu.shape)
tensor.name = name
itype = _get_type(arr_cpu.dtype)
assert not hasattr(TensorProto, "INT4") or itype not in {
TensorProto.INT4,
TensorProto.UINT4,
}, f"Type {arr.dtype} is not supported yet for name={name!r}"
tensor.data_type = itype
if verbose > 1 and numel > 100:
print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")
if isinstance(np_arr, torch.Tensor):
byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
tensor.raw_data = bytes(byte_data)
if sys.byteorder == "big":
            np_dtype = STORAGE_TYPE[tensor.data_type]
            tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
else:
tensor.raw_data = np_arr.tobytes()
if sys.byteorder == "big":
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
            tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
return tensor
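# Minimal usage sketch (assuming torch is installed):
#   t = proto_from_array(torch.arange(6, dtype=torch.float32).reshape(2, 3), name="w")
#   assert t.data_type == TensorProto.FLOAT and list(t.dims) == [2, 3]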
class MiniOnnxBuilder:
"""
    Simplified builder to build very simple models.
    :param target_opset: opset to specify
    :param ir_version: IR version to use
    :param sep: separator to build output names
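
    A minimal usage sketch, building a model whose only output is a stored
    constant::

        builder = MiniOnnxBuilder()
        builder.append_output_initializer("x", np.arange(4, dtype=np.float32))
        model = builder.to_onnx()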
"""
def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___"):
import torch
self.initializers_dict = {}
self.inputs = []
self.outputs = []
self.nodes = []
self.opsets = {"": target_opset}
self.ir_version = ir_version
self.torch = torch
self.sep = sep
def append_output_initializer(
self,
name: str,
tensor: Union[np.ndarray, "torch.Tensor"], # noqa: F821
randomize: bool = False,
): # noqa: F821
"""
Adds an initializer as an output.
The initializer name is prefixed by ``t_``.
The output name is *name*.
        If ``randomize`` is True, the tensor itself is not stored: it is
        replaced by a random generator node reproducing its dtype, shape,
        and basic statistics.
"""
if randomize:
dtype = dtype_to_tensor_dtype(tensor.dtype)
if dtype in {
TensorProto.FLOAT,
TensorProto.FLOAT16,
TensorProto.DOUBLE,
TensorProto.BFLOAT16,
}:
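                # Values spanning zero are approximated with a RandomNormal
                # (matching mean and std); otherwise a RandomUniform over
                # [min, max] is used.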
mini, maxi = tensor.min(), tensor.max()
if mini < 0 and maxi > 0:
op_type = "RandomNormal"
kwargs = {
"mean": float(tensor.mean()),
"scale": float(tensor.std()),
"seed": 0.0,
}
else:
op_type = "RandomUniform"
kwargs = {
"low": float(mini),
"high": float(maxi),
"seed": 0.0,
}
shape = tuple(map(int, tensor.shape))
self.nodes.append(
oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs)
)
self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
return
init_name = f"t_{name}"
self.initializers_dict[init_name] = tensor
shape = tuple(map(int, tensor.shape))
self.outputs.append(
oh.make_tensor_value_info(name, dtype_to_tensor_dtype(tensor.dtype), shape)
)
self.nodes.append(oh.make_node("Identity", [init_name], [name]))
def append_output_sequence(
self, name: str, tensors: List[Union[np.ndarray, "torch.Tensor"]] # noqa: F821
): # noqa: F821
"""
Adds a sequence of initializers as an output.
        The initializer names are prefixed by ``seq_``.
The output name is ``name``.
"""
if not tensors:
# empty list
self.nodes.append(oh.make_node("SequenceEmpty", [], [name]))
tensor_type_proto = oh.make_tensor_type_proto(
elem_type=TensorProto.FLOAT, shape=None
)
else:
assert all(
isinstance(t, (np.ndarray, self.torch.Tensor)) for t in tensors
), f"Nested sequences are not supported, types are {[type(t) for t in tensors]}"
names = []
for i, t in enumerate(tensors):
init_name = f"seq_{name}_{i}"
self.initializers_dict[init_name] = t
names.append(init_name)
self.nodes.append(oh.make_node("SequenceConstruct", names, [name]))
tensor_type_proto = oh.make_tensor_type_proto(
elem_type=dtype_to_tensor_dtype(tensors[0].dtype), shape=None
)
sequence_type_proto = oh.make_sequence_type_proto(tensor_type_proto)
output = oh.make_value_info(name, type_proto=sequence_type_proto)
self.outputs.append(output)
def append_output_dict(
self, name: str, tensors: Dict[str, Union[np.ndarray, "torch.Tensor"]] # noqa: F821
): # noqa: F821
"""
        Adds two outputs: a string tensor for the keys and a sequence of
        tensors for the values.
        The output names are ``name___keys`` and ``name___values``
        (``___`` being the default separator ``sep``).
"""
keys = []
values = []
for k, v in tensors.items():
keys.append(k)
values.append(v)
self.append_output_initializer(f"{name}{self.sep}keys", np.array(keys, dtype=np.str_))
self.append_output_sequence(f"{name}{self.sep}values", values)
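    # Example: append_output_dict("cache", {"k": a, "v": b}) creates the
    # outputs "cache___keys" (a string tensor with the keys) and
    # "cache___values" (a sequence holding a and b), using the default sep.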
def _build_initializers(
self, switch_low_high: bool
    ) -> List[TensorProto]:
"""
Builds initializers.
        :param switch_low_high: if True, serialize tensors through their raw
            bytes (the fast path, valid on little-endian platforms)
        :return: a list of tensors to be stored in the model
"""
init_dict = self.initializers_dict
if switch_low_high:
# Let's try to minimize the time.
initializer = []
for k, v in init_dict.items():
if isinstance(v, TensorProto):
initializer.append(v)
continue
if isinstance(v, np.ndarray):
itype = dtype_to_tensor_dtype(v.dtype)
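                    # These types are routed through onnx.numpy_helper.from_array
                    # instead of the raw-bytes fast path below.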
if itype in {
TensorProto.BOOL,
TensorProto.STRING,
TensorProto.UNDEFINED,
TensorProto.COMPLEX64,
TensorProto.COMPLEX128,
getattr(TensorProto, "UINT4", 0),
getattr(TensorProto, "INT4", 0),
}:
t = onh.from_array(v, name=k)
initializer.append(t)
continue
from_np = True
elif isinstance(v, np.float32):
t = onh.from_array(np.array([v], dtype=np.float32), name=k)
initializer.append(t)
continue
elif isinstance(v, np.float64):
t = onh.from_array(np.array([v], dtype=np.float64), name=k)
initializer.append(t)
continue
elif isinstance(v, np.float16):
t = onh.from_array(np.array([v], dtype=np.float16), name=k)
initializer.append(t)
continue
else:
assert isinstance(
v, self.torch.Tensor
), f"tensor {k!r} has un unexpected type {type(v)}"
assert "FakeTensor" not in str(
type(v)
), f"tensor {k!r} cannot be a FakeTensor: {type(v)}"
from_np = False
itype = dtype_to_tensor_dtype(v.dtype)
# How to avoid a copy?
if from_np:
tensor = TensorProto()
tensor.name = k
tensor.dims.extend(v.shape)
tensor.data_type = itype
tensor.raw_data = v.tobytes()
else:
tensor = proto_from_array(v, name=k)
initializer.append(tensor)
return initializer
res = []
for k, v in init_dict.items():
if isinstance(v, TensorProto):
res.append(v)
continue
if isinstance(v, self.torch.Tensor):
                # torch has no string tensors, proto_from_array handles all cases
                t = proto_from_array(v, name=k)
res.append(t)
continue
if isinstance(v, np.ndarray):
t = onh.from_array(v, name=k)
res.append(t)
continue
raise TypeError(
f"Unable to convert initializer {k!r} with type "
f"{type(v)} into a TensorProto."
)
return res
def to_onnx(self) -> ModelProto:
"""
Conversion to onnx.
:return: the proto
"""
opsets = [oh.make_opsetid(*o) for o in self.opsets.items()]
ir_version = self.ir_version
model = ModelProto()
model.graph.CopyFrom(GraphProto())
model.graph.name = "mini_model"
model.graph.input.extend(self.inputs)
model.graph.node.extend(self.nodes)
model.graph.output.extend(self.outputs)
initializers = self._build_initializers(switch_low_high=sys.byteorder != "big")
model.graph.initializer.extend(initializers)
model.opset_import.extend(opsets)
model.ir_version = ir_version
return model
def flatten_iterator(obj: Any, sep: str) -> Iterator:
"""
    Iterates over a nested structure (arrays, tensors, scalars, tuples,
    lists, dictionaries, ``DynamicCache``) and yields ``(name, tensor)``
    pairs whose names encode the structure for :func:`unflatten`.
"""
if obj is not None:
import torch
if isinstance(obj, np.ndarray):
yield "array", obj
elif isinstance(obj, torch.Tensor):
yield "tensor", obj
elif isinstance(obj, bool):
yield "bool", np.array([obj], dtype=np.bool_)
elif isinstance(obj, int):
yield "int", np.array([obj], dtype=np.int64)
elif isinstance(obj, float):
yield "float", np.array([obj], dtype=np.float64)
elif isinstance(obj, tuple):
if not obj:
yield f"tuple.{sep}empty", None
else:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in flatten_iterator(o, sep):
yield f"tuple.{sep}{p}", oo
else:
for p, oo in flatten_iterator(o, sep):
yield f"tuple{sep}{p}", oo
elif isinstance(obj, list):
if not obj:
yield f"list.{sep}empty", None
else:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in flatten_iterator(o, sep):
yield f"list.{sep}{p}", oo
else:
for p, oo in flatten_iterator(o, sep):
yield f"list{sep}{p}", oo
elif isinstance(obj, dict):
if not obj:
yield f"dict.{sep}empty", None
else:
for i, (k, v) in enumerate(obj.items()):
assert sep not in k, (
f"Key {k!r} cannot contain '{sep}'. "
f"It would interfer with the serialization."
)
if i == len(obj) - 1:
for p, o in flatten_iterator(v, sep):
yield f"dict._{k}{sep}{p}", o
else:
for p, o in flatten_iterator(v, sep):
yield f"dict_{k}{sep}{p}", o
elif obj.__class__.__name__ == "DynamicCache":
# transformers
import transformers
assert isinstance(
obj, transformers.cache_utils.DynamicCache
), f"Unexpected type {type(obj)}"
atts = ["key_cache", "value_cache"]
for i, att in enumerate(atts):
if i == len(atts) - 1:
for p, o in flatten_iterator(getattr(obj, att), sep):
yield f"DynamicCache._{att}{sep}{p}", o
else:
for p, o in flatten_iterator(getattr(obj, att), sep):
yield f"DynamicCache_{att}{sep}{p}", o
else:
raise NotImplementedError(f"Unexpected type {type(obj)}")
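# Example, with sep="___":
#   list(flatten_iterator({"a": 1}, "___"))
#   returns [("dict._a___int", array([1]))]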
def unflatten(
sep: str,
names: List[str],
outputs: List[Any],
pos: int = 0,
level: int = 0,
device: str = "cpu",
) -> Tuple[int, Any]:
"""
    Unflattens a list of outputs flattened with :func:`flatten_iterator`.
    :param sep: separator used to flatten the names
    :param names: flattened names describing the structure
    :param outputs: flat list of outputs
    :param pos: current position in ``names``
    :param level: current nesting level
    :param device: device onto which rebuilt torch tensors are moved
    :return: tuple (next position, rebuilt object)
    """
name = names[pos]
spl = name.split(sep)
if len(spl) == level + 1:
# A tensor.
if spl[-1] == "empty":
return pos + 1, None
if spl[-1] == "bool":
return pos + 1, bool(outputs[pos][0])
if spl[-1] == "int":
return pos + 1, int(outputs[pos][0])
if spl[-1] == "float":
return pos + 1, float(outputs[pos][0])
if spl[-1] == "array":
return pos + 1, outputs[pos]
if spl[-1] == "tensor":
import torch
return pos + 1, torch.from_numpy(outputs[pos]).to(device)
raise AssertionError(f"Unexpected name {name!r} in {names}")
res = []
while True:
assert pos < len(names), f"Something went wrong with names={names!r}\nres={res!r}"
name = names[pos]
spl = name.split(sep)
prefix = spl[level]
next_pos, value = unflatten(
sep, names, outputs, pos=pos, level=level + 1, device=device
)
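        # By convention a "." after the container name ("list.", "tuple.",
        # "dict._<key>" on the last dictionary entry) marks the container's
        # last element; see flatten_iterator.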
if prefix.startswith("DynamicCache"):
key = prefix.split("_", maxsplit=1)[-1]
res.append((key, value))
lp = len("DynamicCache")
end = len(prefix) > lp and prefix[lp] == "."
elif prefix.startswith("dict"):
key = prefix.split("_", maxsplit=1)[-1]
res.append((key, value))
end = len(prefix) > 4 and prefix[4] == "."
else:
res.append(value)
end = prefix[-1] == "."
if end:
if prefix.startswith("dict"):
ty = dict
elif prefix.startswith("list"):
ty = list
elif prefix.startswith("tuple"):
ty = tuple
elif prefix.startswith("DynamicCache"):
from transformers.cache_utils import DynamicCache
ty = DynamicCache
else:
raise AssertionError(f"Unexpected prefix={prefix!r}")
break
pos = next_pos
def _make(ty: type, res: Any) -> Any:
if ty.__name__ == "DynamicCache":
r = ty()
for k, v in res:
setattr(r, k, v)
return r
return ty(res)
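    # An empty container was flattened into a single None marker
    # (("dict.", None) for dictionaries): rebuild it as an empty instance.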
return next_pos, (
ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
)