Source code for experimental_experiment.xbuilder.model_container

import os
import time
import sys
from typing import Any, List, Dict, Optional
import numpy as np
from onnx import GraphProto, ModelProto, StringStringEntryProto, TensorProto, load_model
from onnx.model_container import ModelContainer, _set_external_data
from onnx.external_data_helper import _get_all_tensors, uses_external_data
from onnx.inliner import inline_local_functions
from ..helpers import tensor_dtype_to_np_dtype, torch_dtype_to_onnx_dtype
from ..mini_onnx_builder import proto_from_array


STORAGE_TYPE = {
    TensorProto.FLOAT16: np.int16,
    TensorProto.BFLOAT16: np.int16,
}
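
# float16 and bfloat16 are stored through their raw 16-bit payload: numpy has
# no bfloat16 dtype, so both map to int16 views over the same bytes. A minimal
# sketch of that reinterpretation (values are illustrative):
#
#   raw = np.zeros(4, dtype=np.float16).view(STORAGE_TYPE[TensorProto.FLOAT16])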


def _get_type(elem_type: Any, exc: bool = True) -> int:
    if not isinstance(elem_type, int):
        st = str(elem_type)
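        # The checks below are ordered so that longer dtype names win:
        # "float16" is a substring of "bfloat16", and "int64"/"int32"/"int16"/
        # "int8" are substrings of their unsigned counterparts.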
        if "float32" in st:
            elem_type = TensorProto.FLOAT
        elif "float64" in st:
            elem_type = TensorProto.DOUBLE
        elif "bfloat16" in st:
            elem_type = TensorProto.BFLOAT16
        elif "float16" in st:
            elem_type = TensorProto.FLOAT16
        elif "uint64" in st:
            elem_type = TensorProto.UINT64
        elif "int64" in st:
            elem_type = TensorProto.INT64
        elif "uint32" in st:
            elem_type = TensorProto.UINT32
        elif "int32" in st:
            elem_type = TensorProto.INT32
        elif "uint16" in st:
            elem_type = TensorProto.UINT16
        elif "int16" in st:
            elem_type = TensorProto.INT16
        elif "bool" in st:
            elem_type = TensorProto.BOOL
        elif "uint8" in st:
            elem_type = TensorProto.UINT8
        elif "int8" in st:
            elem_type = TensorProto.INT8
        elif "complex64" in st:
            elem_type = TensorProto.COMPLEX64
        elif "complex128" in st:
            elem_type = TensorProto.COMPLEX128
        elif elem_type is None:
            elem_type = TensorProto.UNDEFINED
        elif exc:
            raise ValueError(f"Unable to interpret elem_type {elem_type!r}.")
    return elem_type
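
# A quick sanity check for ``_get_type``: only ``str(elem_type)`` is inspected,
# so numpy dtypes, torch dtypes, and plain strings all work (a sketch):
#
#   assert _get_type(np.float32) == TensorProto.FLOAT
#   assert _get_type("bfloat16") == TensorProto.BFLOAT16
#   assert _get_type(None) == TensorProto.UNDEFINED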


class TorchModelContainer(ModelContainer):
    """
    Overwrites :class:`onnx.model_container.ModelContainer`
    to support torch tensors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._stats = {
            "time_export_write_model": 0,
            "time_export_byteswap_tobytes": 0,
            "time_export_tobytes": 0,
            "time_export_proto_from_array": 0,
            "time_export_write_tensor_bytes": 0,
            "time_export_inline_model": 0,
        }
        self.inline = False
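
    # ``_stats`` accumulates wall-clock timings for the export steps, and
    # ``inline`` may be set to True before calling :meth:`save` so that the
    # saved copy has its local functions inlined (see ``_save_external``).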

    def save(self, file_path: str, all_tensors_to_one_file: bool = True) -> ModelProto:
        """
        Saves the large model.

        The function returns a ModelProto: the current one if the model did
        not need any modification, or a modified copy of it if it required
        changes such as giving file names to every external tensor.

        :param file_path: model file
        :param all_tensors_to_one_file: saves all large tensors in one file
            or one file per large tensor
        :return: the saved ModelProto
        """
        return self._save_external(file_path, all_tensors_to_one_file=all_tensors_to_one_file)

    def load(
        self, file_path: str, load_large_initializers: bool = True
    ) -> "TorchModelContainer":
        """
        Loads the large model.

        :param file_path: model file
        :param load_large_initializers: loads the large initializers;
            if not done, the model is incomplete but can still be inspected
            without being executed, and method :meth:`_load_large_initializers`
            can be used to load them later
        :return: self
        """
        self.model_proto_ = load_model(file_path, load_external_data=False)
        if load_large_initializers:
            self._load_large_initializers(file_path)
        return self
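
    # A minimal round-trip sketch. It assumes ``model_proto`` already marks
    # its big initializers as external data whose ``location`` entries match
    # the keys of ``large_initializers`` (names and shapes are illustrative):
    #
    #   container = TorchModelContainer()
    #   container.model_proto_ = model_proto
    #   container.large_initializers["#w1"] = torch.randn(4096, 4096)
    #   container.save("model.onnx", all_tensors_to_one_file=True)
    #   restored = TorchModelContainer().load("model.onnx")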

    def _save_external(
        self,
        file_path: str,
        all_tensors_to_one_file: bool,
    ) -> ModelProto:
        """Saves the large model into a main onnx file and, depending on
        ``all_tensors_to_one_file``, one external file or one file per tensor.

        Follows the same format as :func:`write_external_data_tensors
        <onnx.external_data_helper.write_external_data_tensors>`.
        The main model needs to be modified to update the file location,
        the function returns this modified copy.

        Arguments:
            file_path: model file
            all_tensors_to_one_file: all tensors in one file

        Returns:
            modified main model proto
        """

        def _clean_name(prefix: str, name: str, unique_names: dict[str, int]) -> str:
            if prefix:
                name = f"{prefix}-{name}"
            for c in ":/\\;,!#":
                name = name.replace(c, "")
            base_name = name
            if name in unique_names:
                i = unique_names[name] + 1
                unique_names[name] = i
                return f"{base_name}_{i}"
            unique_names[name] = 1
            return name

        unique_names: dict[str, int] = {}
        folder = os.path.dirname(file_path)
        if folder and not os.path.exists(folder):
            raise FileNotFoundError(f"Folder {folder!r} does not exist.")
        proto = self.model_proto.SerializeToString()
        copy = ModelProto()
        copy.ParseFromString(proto)
        prefix = os.path.splitext(os.path.split(file_path)[-1])[0]

        if all_tensors_to_one_file:
            file_weight = f"{os.path.split(file_path)[1]}.data"
            full_file_weight = f"{file_path}.data"
            offset = 0
            # Creates or truncates the external weight file.
            with open(full_file_weight, "wb"):
                pass

        for tensor in _get_all_tensors(copy):
            if not uses_external_data(tensor):
                continue
            prop: Optional[StringStringEntryProto] = None
            for ext in tensor.external_data:  # type: ignore[assignment]
                if ext.key == "location":  # type: ignore[attr-defined]
                    prop = ext  # type: ignore[assignment]
            if prop is None:
                raise RuntimeError(f"No location found for tensor name {tensor.name!r}.")
            if prop.value not in self.large_initializers:
                raise RuntimeError(
                    f"Unable to find large tensor named {tensor.name!r} "
                    f"with location {prop.value!r} in "
                    f"{sorted(self.large_initializers)}."
                )
            np_tensor = self.large_initializers[prop.value]

            if sys.byteorder == "big":
                # Convert endian from little to big.
                begin = time.perf_counter()
                tensor_bytes = np_tensor.byteswap().tobytes()
                self._stats["time_export_byteswap_tobytes"] += time.perf_counter() - begin
            elif isinstance(np_tensor, np.ndarray):
                begin = time.perf_counter()
                tensor_bytes = np_tensor.tobytes()
                self._stats["time_export_tobytes"] += time.perf_counter() - begin
            elif isinstance(np_tensor, TensorProto):
                tensor_bytes = np_tensor.raw_data
                assert len(tensor_bytes) > 0, f"One tensor is null, np_tensor={np_tensor}."
            else:
                import torch

                if isinstance(np_tensor, torch.nn.Parameter):
                    pt = np_tensor.data
                elif isinstance(np_tensor, torch.Tensor):
                    pt = np_tensor
                else:
                    raise NotImplementedError(
                        f"Handling of type {type(np_tensor)} as large initializer "
                        f"is not implemented yet."
                    )
                begin = time.perf_counter()
                proto = proto_from_array(pt, name="dummy")
                self._stats["time_export_proto_from_array"] += time.perf_counter() - begin
                tensor_bytes = proto.raw_data
                assert (
                    pt.dtype != torch.float32
                    or len(tensor_bytes) == np.prod(pt.shape) * 4
                ), (
                    f"Unexpected size mismatch, buffer size is {len(tensor_bytes)}, "
                    f"but tensor size={np.prod(pt.shape) * 4}, "
                    f"shape={pt.shape}, dtype={pt.dtype}"
                )

            begin = time.perf_counter()
            if all_tensors_to_one_file:
                _set_external_data(
                    tensor,
                    location=file_weight,
                    offset=offset,
                    length=len(tensor_bytes),
                )
                offset += len(tensor_bytes)
                with open(full_file_weight, "ab") as f:
                    f.write(tensor_bytes)
            else:
                name = f"{_clean_name(prefix, prop.value, unique_names)}.weight"
                _set_external_data(tensor, location=name)
                full_name = os.path.join(folder, name)
                prop.value = name
                with open(full_name, "wb") as f:
                    f.write(tensor_bytes)
            self._stats["time_export_write_tensor_bytes"] += time.perf_counter() - begin

        if self.inline:
            begin = time.perf_counter()
            copy = inline_local_functions(copy)
            self._stats["time_export_inline_model"] += time.perf_counter() - begin

        begin = time.perf_counter()
        with open(file_path, "wb") as f:
            f.write(copy.SerializeToString())
        self._stats["time_export_write_model"] += time.perf_counter() - begin
        return copy

    def _deserialize_graph(
        self,
        proto: GraphProto,
        scoped_values: List[Dict[str, "onnxscript.ir.Value"]],  # noqa: F821
    ) -> "onnxscript.ir.Graph":  # noqa: F821
        """See :epkg:`onnxscript`."""
        import onnxscript.ir as oir
        import onnxscript.ir.serde as oirs
        from ..reference import to_array_extended

        # Maps tensor names to their quantization annotations.
        quantization_annotations = {
            annotation.tensor_name: annotation
            for annotation in proto.quantization_annotation
        }
        initializer_tensors = []
        for tensor in proto.initializer:
            if uses_external_data(tensor):
                prop = None
                for ext in tensor.external_data:  # type: ignore[assignment]
                    if ext.key == "location":  # type: ignore[attr-defined]
                        prop = ext  # type: ignore[assignment]
                assert prop is not None, f"No location found for tensor name {tensor.name!r}."
                assert prop.value in self.large_initializers, (
                    f"Unable to find large tensor named {tensor.name!r} "
                    f"with location {prop.value!r} in "
                    f"{sorted(self.large_initializers)}."
                )
                np_tensor = self.large_initializers[prop.value]
                if isinstance(np_tensor, np.ndarray):
                    t = oir.Tensor(
                        np_tensor,
                        name=tensor.name,
                        doc_string=tensor.doc_string,
                        metadata_props=oirs.deserialize_metadata_props(tensor.metadata_props),
                    )
                elif hasattr(np_tensor, "shape"):
                    # Assumed to be a torch tensor.
                    t = oir.Tensor(
                        np_tensor.detach(),
                        name=tensor.name,
                        dtype=oir.DataType.from_numpy(
                            tensor_dtype_to_np_dtype(
                                torch_dtype_to_onnx_dtype(np_tensor.dtype)
                            )
                        ),
                        doc_string=tensor.doc_string,
                        metadata_props=oirs.deserialize_metadata_props(tensor.metadata_props),
                    )
                else:
                    t = oir.Tensor(
                        to_array_extended(np_tensor),
                        name=tensor.name,
                        doc_string=tensor.doc_string,
                        metadata_props=oirs.deserialize_metadata_props(tensor.metadata_props),
                    )
            else:
                t = oirs.deserialize_tensor(tensor)
            initializer_tensors.append(t)

        inputs = [oir.Input(info.name) for info in proto.input]
        for info, value in zip(proto.input, inputs):
            oirs.deserialize_value_info_proto(info, value)
            if value.name in quantization_annotations:
                oirs._deserialize_quantization_annotation(
                    quantization_annotations[value.name], value
                )

        values = {v.name: v for v in inputs}
        scoped_values.append(values)

        initializer_values = []
        for i, tensor in enumerate(initializer_tensors):
            initializer_name = tensor.name
            assert initializer_name, f"Initializer {i} has no name, it should not be there."
            assert initializer_name not in values, f"Duplicated name {initializer_name!r}"
            initializer_value = oir.Value(
                None,
                index=None,
                name=initializer_name,
                type=oir.TensorType(tensor.dtype),
                shape=tensor.shape,
                const_value=tensor,
            )
            if initializer_value.name in quantization_annotations:
                oirs._deserialize_quantization_annotation(
                    quantization_annotations[initializer_value.name], initializer_value
                )
            values[initializer_name] = initializer_value
            initializer_values.append(initializer_value)

        value_info = {info.name: info for info in proto.value_info}
        nodes = [
            oirs._deserialize_node(node, scoped_values, value_info, quantization_annotations)
            for node in proto.node
        ]

        outputs = []
        for info in proto.output:
            # Fill in values for graph outputs.
            output_name = info.name
            assert output_name in values, f"Missing output_name={output_name!r} in {values}"
            value = values[output_name]
            oirs.deserialize_value_info_proto(info, value)
            outputs.append(value)

        # Exit the graph scope by popping the values for this scope from the stack.
        scoped_values.pop()
        return oir.Graph(
            inputs,
            outputs,
            nodes=nodes,
            initializers=initializer_values,
            doc_string=self._get_field(proto, "doc_string"),
            name=self._get_field(proto, "name"),
            metadata_props=oirs.deserialize_metadata_props(proto.metadata_props),
        )

    @classmethod
    def _get_field(cls, proto: Any, field: str) -> Any:
        if proto.HasField(field):
            return getattr(proto, field)
        return None
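
    # The single-file layout written by ``_save_external`` is a plain
    # concatenation of tensor payloads: every tensor keeps ``location``,
    # ``offset`` and ``length`` entries in ``external_data``, so the bytes can
    # be recovered independently (a sketch, ``t`` being one TensorProto):
    #
    #   info = {e.key: e.value for e in t.external_data}
    #   with open(info["location"], "rb") as f:
    #       f.seek(int(info["offset"]))
    #       raw = f.read(int(info["length"]))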

    def to_ir(self) -> "onnxscript.ir.Model":  # noqa: F821
        """Conversion to :class:`onnxscript.ir.Model`."""
        import onnxscript.ir as oir
        import onnxscript.ir.serde as oirs

        proto = self.model_proto
        graph = self._deserialize_graph(proto.graph, [])
        graph.opset_imports.update(oirs.deserialize_opset_import(proto.opset_import))
        functions = []
        for func in proto.functions:
            functions.append(oirs.deserialize_function(func))
        model = oir.Model(
            graph,
            ir_version=proto.ir_version,
            producer_name=self._get_field(proto, "producer_name"),
            producer_version=self._get_field(proto, "producer_version"),
            domain=self._get_field(proto, "domain"),
            model_version=self._get_field(proto, "model_version"),
            doc_string=self._get_field(proto, "doc_string"),
            functions=functions,
            meta_data_props=oirs.deserialize_metadata_props(proto.metadata_props),
        )
        return model
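
    # Usage sketch: after ``load``, ``to_ir`` produces an onnxscript IR model
    # whose initializers wrap the large tensors instead of serializing them
    # (the file name is illustrative):
    #
    #   container = TorchModelContainer().load("model.onnx")
    #   ir_model = container.to_ir()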