import ctypes
import sys
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from .onnx_helper import dtype_to_tensor_dtype, tensor_dtype_to_np_dtype
from . import string_type
[docs]
def proto_from_array(
arr: "torch.Tensor", # type: ignore[name-defined] # noqa: F821
name: Optional[str] = None,
verbose: int = 0,
) -> onnx.TensorProto:
"""
Converts a torch Tensor into a onnx.TensorProto.
:param arr: tensor
:param verbose: display the type and shape
:return: a onnx.TensorProto
"""
import torch
if not isinstance(arr, torch.Tensor):
raise TypeError(f"Unexpected type {type(arr)}.")
if arr.is_sparse:
raise NotImplementedError(
f"Sparse tensor is not supported yet but initializer {name!r} is."
)
# arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
arr_cpu = arr.cpu() if arr.is_contiguous() else arr.contiguous().cpu()
import torch
numel = torch.numel(arr_cpu)
element_size = arr_cpu.element_size()
import torch
if arr_cpu.dtype in {torch.bfloat16}:
np_arr = arr_cpu
elif arr_cpu.data_ptr() == arr.data_ptr():
copy = arr_cpu.clone().detach().requires_grad_(False)
assert (
arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
np_arr = np.from_dlpack(copy) # type: ignore[assignment]
else:
np_arr = np.from_dlpack(arr_cpu.detach()) # type: ignore[assignment]
tensor = onnx.TensorProto()
tensor.dims.extend(arr_cpu.shape)
tensor.name = name or ""
itype = dtype_to_tensor_dtype(arr_cpu.dtype)
assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
onnx.TensorProto.INT4,
onnx.TensorProto.UINT4,
}, f"Type {arr.dtype} is not supported yet for name={name!r}"
tensor.data_type = itype
if verbose > 1 and numel > 100:
print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")
import torch
if isinstance(np_arr, torch.Tensor):
byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
tensor.raw_data = bytes(byte_data)
if sys.byteorder == "big":
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
else:
tensor.raw_data = np_arr.tobytes()
if sys.byteorder == "big":
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
return tensor
[docs]
class MiniOnnxBuilder:
"""
Simplified builder to create very simple ONNX models that store
tensors (numpy arrays or torch tensors) as initializers and expose
them as model outputs. The resulting model has **no inputs** — it
simply returns the stored values when executed.
:param target_opset: default ONNX opset version (default: 18)
:param ir_version: ONNX IR version (default: 10)
:param sep: separator used to build composite output names
(default: ``"___"``)
Typical usage — save a plain numpy array and round-trip it through ONNX:
.. runpython::
:showcode:
import numpy as np
from yobx.helpers.mini_onnx_builder import MiniOnnxBuilder
from yobx.reference import ExtendedReferenceEvaluator
builder = MiniOnnxBuilder()
builder.append_output_initializer("weights", np.array([1.0, 2.0, 3.0], dtype=np.float32))
model = builder.to_onnx()
ref = ExtendedReferenceEvaluator(model)
(weights,) = ref.run(None, {})
print(weights) # np.array([1.0, 2.0, 3.0], dtype=np.float32)
For serializing arbitrary nested Python structures (dicts, tuples, lists,
torch tensors, ``DynamicCache`` …) prefer the higher-level helpers
:func:`create_onnx_model_from_input_tensors` and
:func:`create_input_tensors_from_onnx_model`.
"""
def __init__(self, target_opset: int = 18, ir_version: int = 10, sep: str = "___"):
self.initializers_dict: Dict[str, Any] = {}
self.inputs: List[Any] = []
self.outputs: List[Any] = []
self.nodes: List[onnx.NodeProto] = []
self.opsets = {"": target_opset}
self.ir_version = ir_version
self.sep = sep
[docs]
def append_output_initializer(
self,
name: str,
tensor: Union[np.ndarray, "torch.Tensor"], # type: ignore[name-defined] # noqa: F821
randomize: bool = False,
):
"""
Adds an initializer as an output.
The initializer name is prefixed by ``t_``.
The output name is *name*.
If `randomize` is True, the tensor is not stored but replaced by a random generator.
"""
if randomize:
dtype = dtype_to_tensor_dtype(tensor.dtype)
if dtype in {
onnx.TensorProto.FLOAT,
onnx.TensorProto.FLOAT16,
onnx.TensorProto.DOUBLE,
onnx.TensorProto.BFLOAT16,
}:
mini, maxi = tensor.min(), tensor.max()
if mini < 0 and maxi > 0:
op_type = "RandomNormal"
kwargs = {
"mean": float(tensor.mean()),
"scale": float(tensor.std()),
"seed": 0.0,
}
else:
op_type = "RandomUniform"
kwargs = {
"low": float(mini),
"high": float(maxi),
"seed": 0.0,
}
shape = tuple(map(int, tensor.shape))
self.nodes.append(
# pyrefly: ignore[bad-argument-type]
oh.make_node(op_type, [], [name], dtype=dtype, shape=shape, **kwargs) # type: ignore[arg-type]
)
self.outputs.append(oh.make_tensor_value_info(name, dtype, shape))
return
init_name = f"t_{name}"
assert (
init_name not in self.initializers_dict
), f"name={init_name!r} already in {sorted(self.initializers_dict)}"
self.initializers_dict[init_name] = tensor
shape = tuple(map(int, tensor.shape))
self.outputs.append(
oh.make_tensor_value_info(name, dtype_to_tensor_dtype(tensor.dtype), shape)
)
self.nodes.append(oh.make_node("Identity", [init_name], [name]))
[docs]
def append_output_sequence(
self, name: str, tensors: List[Union[np.ndarray, "torch.Tensor"]] # type: ignore[name-defined] # noqa: F821
):
"""
Adds a sequence of initializers as an output.
The initializers names are prefixed by ``seq_``.
The output name is ``name``.
"""
if not tensors:
# empty list
self.nodes.append(
oh.make_node("SequenceEmpty", [], [name], dtype=onnx.TensorProto.FLOAT)
)
tensor_type_proto = oh.make_tensor_type_proto(
elem_type=onnx.TensorProto.FLOAT, shape=None
)
else:
import torch
assert all(
isinstance(t, (np.ndarray, torch.Tensor)) for t in tensors
), f"Nested sequences are not supported, types are {[type(t) for t in tensors]}"
names = []
for i, t in enumerate(tensors):
init_name = f"seq_{name}_{i}"
self.initializers_dict[init_name] = t
names.append(init_name)
self.nodes.append(oh.make_node("SequenceConstruct", names, [name]))
tensor_type_proto = oh.make_tensor_type_proto(
elem_type=dtype_to_tensor_dtype(tensors[0].dtype), shape=None
)
sequence_type_proto = oh.make_sequence_type_proto(tensor_type_proto)
output = oh.make_value_info(name, type_proto=sequence_type_proto)
self.outputs.append(output)
[docs]
def append_output_dict(
self, name: str, tensors: Dict[str, Union[np.ndarray, "torch.Tensor"]] # type: ignore[name-defined] # noqa: F821
):
"""
Adds two outputs, a string tensors for the keys and a sequence of tensors
for the values.
The output name is ``name___keys`` and ``name___values``.
"""
keys = []
values = []
for k, v in tensors.items():
keys.append(k)
values.append(v)
self.append_output_initializer(f"{name}{self.sep}keys", np.array(keys, dtype=np.str_))
self.append_output_sequence(f"{name}{self.sep}values", values)
def _build_initializers(self, switch_low_high: bool) -> List[onnx.TensorProto]:
"""
Builds initializers.
:param switch_low_high: invert low, high precision
:return: a list of tensors to stored in the model
"""
init_dict = self.initializers_dict
if switch_low_high:
# Let's try to minimize the time.
initializer: List[onnx.TensorProto] = []
for k, v in init_dict.items():
if isinstance(v, onnx.TensorProto):
initializer.append(v)
continue
if isinstance(v, np.ndarray):
itype = dtype_to_tensor_dtype(v.dtype)
if itype in {
onnx.TensorProto.BOOL,
onnx.TensorProto.STRING,
onnx.TensorProto.UNDEFINED,
onnx.TensorProto.COMPLEX64,
onnx.TensorProto.COMPLEX128,
getattr(onnx.TensorProto, "UINT4", 0),
getattr(onnx.TensorProto, "INT4", 0),
}:
t = onh.from_array(v, name=k)
initializer.append(t)
continue
from_np = True
elif isinstance(v, np.float32):
t = onh.from_array(np.array([v], dtype=np.float32), name=k)
initializer.append(t)
continue
elif isinstance(v, np.float64):
t = onh.from_array(np.array([v], dtype=np.float64), name=k)
initializer.append(t)
continue
elif isinstance(v, np.float16):
t = onh.from_array(np.array([v], dtype=np.float16), name=k)
initializer.append(t)
continue
else:
import torch
assert isinstance(
v, torch.Tensor
), f"tensor {k!r} has un unexpected type {type(v)}"
assert "FakeTensor" not in str(
type(v)
), f"tensor {k!r} cannot be a FakeTensor: {type(v)}"
from_np = False
itype = dtype_to_tensor_dtype(v.dtype)
# How to avoid a copy?
if from_np:
tensor = onnx.TensorProto()
tensor.name = k
tensor.dims.extend(v.shape)
tensor.data_type = itype
tensor.raw_data = v.tobytes()
else:
tensor = proto_from_array(v, name=k)
initializer.append(tensor)
return initializer
res: List[onnx.TensorProto] = []
for k, v in init_dict.items():
if isinstance(v, onnx.TensorProto):
res.append(v)
continue
if isinstance(v, np.ndarray):
t = onh.from_array(v, name=k)
res.append(t)
continue
import torch
if isinstance(v, torch.Tensor):
# no string tensor
t = proto_from_array(v, name=k)
res.append(t)
continue
raise TypeError(
f"Unable to convert initializer {k!r} with type {type(v)} into a TensorProto."
)
return res
[docs]
def to_onnx(self) -> onnx.ModelProto:
"""
Conversion to onnx.
:return: the proto
"""
opsets = [oh.make_opsetid(*o) for o in self.opsets.items()]
ir_version = self.ir_version
model = onnx.ModelProto()
model.graph.CopyFrom(onnx.GraphProto())
model.graph.name = "mini_model"
model.graph.input.extend(self.inputs)
model.graph.node.extend(self.nodes)
model.graph.output.extend(self.outputs)
initializers = self._build_initializers(switch_low_high=sys.byteorder != "big")
model.graph.initializer.extend(initializers)
model.opset_import.extend(opsets)
model.ir_version = ir_version
return model
def _flatten_iterator(obj: Any, sep: str) -> Iterator:
"""Iterates on all object."""
if obj is not None:
if isinstance(obj, np.ndarray):
yield "array", obj
elif isinstance(obj, bool):
yield "bool", np.array([obj], dtype=np.bool_)
elif isinstance(obj, int):
yield "int", np.array([obj], dtype=np.int64)
elif isinstance(obj, float):
yield "float", np.array([obj], dtype=np.float64)
elif isinstance(obj, tuple):
if not obj:
yield f"tuple.{sep}empty", None
else:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in _flatten_iterator(o, sep):
yield f"tuple_{i}.{sep}{p}", oo
else:
for p, oo in _flatten_iterator(o, sep):
yield f"tuple_{i}{sep}{p}", oo
elif isinstance(obj, list):
if not obj:
yield f"list.{sep}empty", None
else:
for i, o in enumerate(obj):
if i == len(obj) - 1:
for p, oo in _flatten_iterator(o, sep):
yield f"list_{i}.{sep}{p}", oo
else:
for p, oo in _flatten_iterator(o, sep):
yield f"list_{i}{sep}{p}", oo
elif isinstance(obj, dict):
if not obj:
yield f"dict.{sep}empty", None
else:
for i, (k, v) in enumerate(obj.items()):
assert sep not in k, (
f"Key {k!r} cannot contain '{sep}'. "
f"It would interfere with the serialization."
)
def _mk(k):
if isinstance(k, tuple):
# this assumes the tuple contains simple types
return f"(({','.join(map(str,k))}))"
return str(k)
if i == len(obj) - 1:
for p, o in _flatten_iterator(v, sep):
yield f"dict._{_mk(k)}{sep}{p}", o
else:
for p, o in _flatten_iterator(v, sep):
yield f"dict_{_mk(k)}{sep}{p}", o
elif obj.__class__.__name__ == "DynamicCache":
# transformers
import transformers
from .cache_helper import CacheKeyValue
assert isinstance(
obj, transformers.cache_utils.DynamicCache
), f"Unexpected type {type(obj)}"
obj = CacheKeyValue(obj)
atts = ["key_cache", "value_cache"]
for i, att in enumerate(atts):
if i == len(atts) - 1:
for p, o in _flatten_iterator(getattr(obj, att), sep):
yield f"DynamicCache._{att}{sep}{p}", o
else:
for p, o in _flatten_iterator(getattr(obj, att), sep):
yield f"DynamicCache_{att}{sep}{p}", o
elif obj.__class__.__name__ == "StaticCache":
# transformers
import transformers
from .cache_helper import CacheKeyValue
assert isinstance(
obj, transformers.cache_utils.StaticCache
), f"Unexpected type {type(obj)}"
obj = CacheKeyValue(obj)
atts = ["key_cache", "value_cache"]
for i, att in enumerate(atts):
if i == len(atts) - 1:
for p, o in _flatten_iterator(getattr(obj, att), sep):
yield f"StaticCache._{att}{sep}{p}", o
else:
for p, o in _flatten_iterator(getattr(obj, att), sep):
yield f"StaticCache_{att}{sep}{p}", o
else:
import torch
if isinstance(obj, torch.Tensor):
yield "tensor", obj
else:
raise NotImplementedError(f"Unexpected type {type(obj)}")
def _unflatten(
sep: str,
names: List[str],
outputs: List[Any],
pos: int = 0,
level: int = 0,
device: str = "cpu",
) -> Tuple[int, Any]:
"""Unflattens a list of outputs flattened with :func:`flatten_iterator`."""
name = names[pos]
spl = name.split(sep)
if len(spl) == level + 1:
# A tensor.
if spl[-1] == "empty":
return pos + 1, None
if spl[-1] == "bool":
return pos + 1, bool(outputs[pos][0])
if spl[-1] == "int":
return pos + 1, int(outputs[pos][0])
if spl[-1] == "float":
return pos + 1, float(outputs[pos][0])
if spl[-1] == "array":
return pos + 1, outputs[pos]
if spl[-1] == "tensor":
import torch
try:
return pos + 1, torch.from_numpy(outputs[pos]).to(device)
except TypeError:
# it should be more robust
import ml_dtypes
if outputs[pos].dtype == ml_dtypes.bfloat16:
import torch
return pos + 1, torch.from_numpy(outputs[pos].astype(float)).to(device).to(
torch.bfloat16
)
raise
raise AssertionError(f"Unexpected name {name!r} in {names}")
res: List[Any] = []
while True:
assert pos < len(names), f"Something went wrong with names={names!r}\nres={res!r}"
name = names[pos]
spl = name.split(sep)
prefix = spl[level]
next_pos, value = _unflatten(sep, names, outputs, pos=pos, level=level + 1, device=device)
if prefix.startswith("DynamicCache"):
key = prefix.split("_", maxsplit=1)[-1]
res.append((key, value))
lp = len("DynamicCache")
end = len(prefix) > lp and prefix[lp] == "."
elif prefix.startswith("dict"):
key = prefix.split("_", maxsplit=1)[-1]
res.append((key, value))
end = len(prefix) > 4 and prefix[4] == "."
else:
res.append(value)
end = prefix[-1] == "."
if end:
if prefix.startswith("dict"):
ty: type = dict
elif prefix.startswith("list"):
ty = list
elif prefix.startswith("tuple"):
ty = tuple
elif prefix.startswith("DynamicCache"):
from transformers.cache_utils import DynamicCache
ty = DynamicCache
else:
raise AssertionError(f"Unexpected prefix={prefix!r}")
break
pos = next_pos
def _tryint(s):
try:
return int(s)
except (ValueError, TypeError):
if s in {"True", "False"}:
return s == "True"
return s
def _make(ty: type, res: Any) -> Any:
if ty.__name__ == "DynamicCache":
from .cache_helper import CacheKeyValue
cc = CacheKeyValue()
for k, v in res:
setattr(cc, k, v)
r = cc.make_dynamic_cache()
return r
if ty is dict:
d = {}
for k, v in res:
if k.startswith("((") and k.endswith("))"):
spl = k[2:-2].split(",")
key = tuple(_tryint(s) for s in spl)
else:
key = _tryint(k)
d[key] = v
return d
return ty(res)
if end and len(res) == 1:
if res[0] is None:
return next_pos, ty()
if isinstance(res[0], tuple) and len(res[0]) == 2 and res[0] == ("dict.", None):
return next_pos, ty()
return next_pos, _make(ty, res)
return next_pos, (
ty() if len(res) == 1 and res[0] in (("dict.", None), None) else _make(ty, res)
)