Source code for onnx_diagnostic.helpers.torch_helper

import contextlib
import ctypes
import inspect
import os
import sys
import warnings
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
import torch
from .helper import string_type, size_type
from .cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
    make_sliding_window_cache,
    make_mamba_cache,
)
from .mini_onnx_builder import create_onnx_model_from_input_tensors
from .onnx_helper import (
    to_array_extended,
    tensor_dtype_to_np_dtype,
    _STORAGE_TYPE,
    onnx_dtype_name,
)


def proto_from_tensor(
    arr: "torch.Tensor",  # noqa: F821
    name: Optional[str] = None,
    verbose: int = 0,
) -> onnx.TensorProto:
    """
    Converts a torch Tensor into a TensorProto.

    :param arr: tensor
    :param name: name of the initializer
    :param verbose: display the type and shape
    :return: a TensorProto
    """
    import torch

    if not isinstance(arr, torch.Tensor):
        raise TypeError(f"Unexpected type {type(arr)}.")
    if arr.is_sparse:
        raise NotImplementedError(
            f"Sparse tensor is not supported yet but initializer {name!r} is."
        )

    # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
    if arr.is_contiguous():
        arr_cpu = arr.cpu()
    else:
        arr_cpu = arr.contiguous().cpu()

    numel = torch.numel(arr_cpu)
    element_size = arr_cpu.element_size()

    if arr_cpu.dtype in {torch.bfloat16}:
        np_arr = arr_cpu
    elif arr_cpu.data_ptr() == arr.data_ptr():
        copy = arr_cpu.clone().detach().requires_grad_(False)
        assert (
            arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
        ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
        np_arr = np.from_dlpack(copy)
    else:
        np_arr = np.from_dlpack(arr_cpu.detach())

    tensor = onnx.TensorProto()
    tensor.dims.extend(arr_cpu.shape)
    if name:
        tensor.name = name
    itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
    assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
        onnx.TensorProto.INT4,
        onnx.TensorProto.UINT4,
    }, f"Type {arr.dtype} is not supported yet for name={name!r}"
    tensor.data_type = itype

    if verbose > 1 and numel > 100:
        print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")

    if isinstance(np_arr, torch.Tensor):
        byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
        tensor.raw_data = bytes(byte_data)

        if sys.byteorder == "big":
            np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)  # type: ignore
    else:
        tensor.raw_data = np_arr.tobytes()
        if sys.byteorder == "big":
            np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
            np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)

    return tensor
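# Usage sketch (illustrative only, not executed by this module): converting a small
# tensor into a TensorProto; the variable names below are arbitrary.
#
#   t = torch.arange(6, dtype=torch.float32).reshape((2, 3))
#   proto = proto_from_tensor(t, name="t")
#   # proto.dims == [2, 3] and proto.data_type == onnx.TensorProto.FLOAT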
def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
    """
    Converts an onnx type into a torch dtype.

    :param itype: onnx dtype
    :return: torch dtype
    """
    if itype == onnx.TensorProto.FLOAT:
        return torch.float32
    if itype == onnx.TensorProto.FLOAT16:
        return torch.float16
    if itype == onnx.TensorProto.BFLOAT16:
        return torch.bfloat16
    if itype == onnx.TensorProto.DOUBLE:
        return torch.float64
    if itype == onnx.TensorProto.INT32:
        return torch.int32
    if itype == onnx.TensorProto.INT64:
        return torch.int64
    if itype == onnx.TensorProto.UINT32:
        return torch.uint32
    if itype == onnx.TensorProto.UINT64:
        return torch.uint64
    if itype == onnx.TensorProto.BOOL:
        return torch.bool
    if itype == onnx.TensorProto.INT16:
        return torch.int16
    if itype == onnx.TensorProto.UINT16:
        return torch.uint16
    if itype == onnx.TensorProto.INT8:
        return torch.int8
    if itype == onnx.TensorProto.UINT8:
        return torch.uint8
    if itype == onnx.TensorProto.COMPLEX64:
        return torch.complex64
    if itype == onnx.TensorProto.COMPLEX128:
        return torch.complex128
    raise NotImplementedError(
        f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
    )
def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
    """
    Converts a torch dtype into an onnx element type.

    :param to: torch dtype
    :return: onnx type
    """
    import torch

    if to == torch.float32:
        return onnx.TensorProto.FLOAT
    if to == torch.float16:
        return onnx.TensorProto.FLOAT16
    if to == torch.bfloat16:
        return onnx.TensorProto.BFLOAT16
    if to == torch.float64:
        return onnx.TensorProto.DOUBLE
    if to == torch.int64:
        return onnx.TensorProto.INT64
    if to == torch.int32:
        return onnx.TensorProto.INT32
    if to == torch.uint64:
        return onnx.TensorProto.UINT64
    if to == torch.uint32:
        return onnx.TensorProto.UINT32
    if to == torch.bool:
        return onnx.TensorProto.BOOL
    if to == torch.SymInt:
        return onnx.TensorProto.INT64
    if to == torch.int16:
        return onnx.TensorProto.INT16
    if to == torch.uint16:
        return onnx.TensorProto.UINT16
    if to == torch.int8:
        return onnx.TensorProto.INT8
    if to == torch.uint8:
        return onnx.TensorProto.UINT8
    if to == torch.SymFloat:
        return onnx.TensorProto.FLOAT
    if to == torch.complex64:
        return onnx.TensorProto.COMPLEX64
    if to == torch.complex128:
        return onnx.TensorProto.COMPLEX128
    raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
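# Usage sketch (illustrative only): for the common dtypes the two conversions above
# are inverses of each other.
#
#   itype = torch_dtype_to_onnx_dtype(torch.float16)   # onnx.TensorProto.FLOAT16
#   assert onnx_dtype_to_torch_dtype(itype) == torch.float16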
def _forward_(
    *args,
    _f=None,
    _fprint=string_type,
    _prefix="",
    _context=None,
    _storage=None,
    _storage_limit=2**27,
    _verbose=0,
    **kwargs,
):
    assert _f is not None, "_f cannot be None"
    assert _context is not None, "_context cannot be None"
    indent = " " * (len(_prefix) - len(_prefix.lstrip()))
    _prefix = _prefix.lstrip()
    print(
        f"{indent}+{_prefix} -- stolen forward for class {_context['class_name']} "
        f"-- iteration {_context['iteration']}"
    )
    kws = dict(
        with_shape=_context.get("with_shape", False),
        with_min_max=_context.get("with_min_max", False),
    )
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
    if _storage is not None:
        it = _context["iteration"]
        key = (_prefix, it)
        _storage[(*key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print(f"{indent} -> {_fprint(res, **kws)}")
    print(f"{indent}-{_prefix}.")
    if _storage is not None:
        size = torch_tensor_size(res)
        if size < _storage_limit:
            if _verbose:
                print(
                    f"-- stores key={key}, size {size // 2**10}Kb -- "
                    f"{string_type(res, with_shape=True)}"
                )
            _storage[(*key, "O")] = torch_deepcopy(res)
        else:
            if _verbose:
                print(
                    f"-- skips key={key}, size {size // 2**10}Kb -- "
                    f"{string_type(res, with_shape=True)}"
                )
    _context["iteration"] += 1
    return res


_steal_forward_status = [False]
_additional_stolen_objects = {}
def is_stealing() -> bool:
    """Returns True while inside the :func:`steal_forward` context."""
    return _steal_forward_status[0]
def steal_append(name: str, obj: Any):
    """
    Even outside a forward method, it is still possible to register a python object
    containing tensors so that it is dumped after the execution of the model.

    .. code-block:: python

        steal_append("quantize", [t1, t2])

    The same code can be executed multiple times; in that case the name is
    extended with a number.
    """
    if is_stealing():
        if name in _additional_stolen_objects:
            i = 1
            n = f"{name}_{i}"
            while n in _additional_stolen_objects:
                i += 1
                n = f"{name}_{i}"
            print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
            _additional_stolen_objects[n] = obj
        else:
            print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
            _additional_stolen_objects[name] = obj
@contextlib.contextmanager
def steal_forward(
    model: Union[
        Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
        List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
    ],
    fprint: Callable = string_type,
    dump_file: Optional[str] = None,
    submodules: bool = False,
    verbose: int = 0,
    storage_limit: int = 2**27,
    **kwargs,
):
    """
    Steals the forward method of a model and prints out inputs and outputs
    using :func:`onnx_diagnostic.helpers.string_type`.
    See example :ref:`l-plot-tiny-llm-export`.

    :param model: a model or a list of models to monitor, every model can also be
        a tuple(name, model), the name is then displayed as well
    :param fprint: function used to print out (or dump), by default, it is
        :func:`onnx_diagnostic.helpers.string_type`
    :param kwargs: additional parameters sent to
        :func:`onnx_diagnostic.helpers.string_type` or any other function
        defined by ``fprint``
    :param dump_file: dumps stolen inputs and outputs in an onnx model,
        they can be restored with
        :func:`create_input_tensors_from_onnx_model
        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
    :param submodules: if True and model is a module, the list is extended with
        all the submodules the module contains
    :param verbose: verbosity
    :param storage_limit: do not store objects bigger than this

    The following example shows how to steal and dump all the inputs / outputs
    of a module and its submodules, then restore them.

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers.torch_helper import steal_forward
        from onnx_diagnostic.helpers.mini_onnx_builder import (
            create_input_tensors_from_onnx_model,
        )

        class SubModel(torch.nn.Module):
            def forward(self, x):
                return x * x

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.s1 = SubModel()
                self.s2 = SubModel()

            def forward(self, x, y):
                return self.s1(x) + self.s2(y)

        inputs = torch.rand(2, 1), torch.rand(2, 1)
        model = Model()
        dump_file = "dump_steal_forward_submodules.onnx"
        with steal_forward(model, submodules=True, dump_file=dump_file):
            model(*inputs)

        # Let's restore the stolen data.
        restored = create_input_tensors_from_onnx_model(dump_file)
        for k, v in sorted(restored.items()):
            if isinstance(v, tuple):
                args, kwargs = v
                print("input", k, args, kwargs)
            else:
                print("output", k, v)

    Function :func:`steal_append` can be used to dump more tensors.
    :func:`is_stealing` returns True when inside the context, False otherwise.
    """
    assert not is_stealing(), "steal_forward was already called."
    # We clear the cache.
    _steal_forward_status[0] = True
    _additional_stolen_objects.clear()
    assert not submodules or isinstance(
        model, torch.nn.Module
    ), f"submodules can only be True if model is a module but it is {type(model)}."
    context = dict(iteration=0, **kwargs)
    if "with_shape" not in context and fprint == string_type:
        context["with_shape"] = True
    if not isinstance(model, list):
        assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
        if submodules:
            models = []
            for idx, m in model.named_modules():
                level = str(idx).split(".")
                ll = len(level)
                try:
                    _, start_line = inspect.getsourcelines(m.forward)
                except OSError:
                    # The code is not available.
                    start_line = 0
                name = f"{idx}-{m.__class__.__name__}-{start_line}"
                models.append((f"{' ' * ll}{name}", m))
            model = models
        else:
            model = [model]
    keep_model_forward = {}
    storage: Optional[Dict[Any, Any]] = {} if dump_file else None
    for mt in model:
        name, m = mt if isinstance(mt, tuple) else ("", mt)
        keep_model_forward[id(m)] = (m, m.forward)
        c = context.copy()
        c["class_name"] = m.__class__.__name__
        m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
            *args,
            _f=_f,
            _fprint=_fp,
            _context=_c,
            _prefix=_p,
            _storage=_s,
            _verbose=_v,
            _storage_limit=_sl,
            **kws,
        )
    try:
        yield
    finally:
        _steal_forward_status[0] = False
        for f in keep_model_forward.values():
            f[0].forward = f[1]
        if dump_file:
            # Let's add the cached tensors.
            assert storage is not None, "storage cannot be None but mypy is confused here."
            storage.update(_additional_stolen_objects)
            # We clear the cache.
            _additional_stolen_objects.clear()
            if verbose:
                size = torch_tensor_size(storage)
                print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
            proto = create_onnx_model_from_input_tensors(storage)
            if verbose:
                print("-- dumps stored objects")
            onnx.save(
                proto,
                dump_file,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=f"{os.path.split(dump_file)[-1]}.data",
            )
            if verbose:
                print("-- done dump stored objects")
@contextlib.contextmanager
def fake_torchdynamo_exporting():
    """
    Sets ``torch.compiler._is_exporting_flag`` to True to trigger pieces of code
    only enabled during export.
    """
    memorize = torch.compiler._is_exporting_flag
    torch.compiler._is_exporting_flag = True
    try:
        yield
    finally:
        torch.compiler._is_exporting_flag = memorize
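# Usage sketch (illustrative only): temporarily pretend the exporter is running;
# with torch>=2.7 the flag is what ``torch.compiler.is_exporting()`` reads back.
#
#   with fake_torchdynamo_exporting():
#       assert is_torchdynamo_exporting()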
def is_torchdynamo_exporting() -> bool:
    """
    Tells if :epkg:`torch` is exporting a model.
    Relies on ``torch.compiler.is_exporting()``.
    """
    import torch

    if not hasattr(torch.compiler, "is_exporting"):
        # torch.compiler.is_exporting requires torch>=2.7
        return False
    try:
        return torch.compiler.is_exporting()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_exporting()  # type: ignore
        except Exception:
            return False
def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
    """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
    try:
        return tensor.numpy()
    except TypeError:
        # We try with ml_dtypes.
        pass

    import ml_dtypes

    conv = {torch.bfloat16: ml_dtypes.bfloat16}
    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
    return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
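# Usage sketch (illustrative only): bfloat16 has no native numpy dtype, so the
# helper falls back on ml_dtypes in that case.
#
#   a = to_numpy(torch.tensor([1.5, 2.5], dtype=torch.bfloat16))
#   # a.dtype is ml_dtypes.bfloat16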
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
    """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
    import torch

    if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
        return dynamic_shapes
    if isinstance(dynamic_shapes, str):
        return torch.export.Dim.DYNAMIC
    if not dynamic_shapes:
        return dynamic_shapes
    if isinstance(dynamic_shapes, (tuple, list)):
        return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
    if isinstance(dynamic_shapes, dict):
        return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
    raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
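# Usage sketch (illustrative only): every string in a dynamic shapes definition is
# replaced by ``torch.export.Dim.DYNAMIC``, the nested structure is preserved.
#
#   ds = {"input_ids": {0: "batch", 1: "seq_length"}}
#   replace_string_by_dynamic(ds)
#   # -> {"input_ids": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}}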
def dummy_llm(
    cls_name: Optional[str] = None,
    dynamic_shapes: bool = False,
) -> Union[
    Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
    Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
]:
    """
    Creates a dummy LLM for test purposes.

    :param cls_name: None for the whole model or the name of one of its pieces
    :param dynamic_shapes: returns dynamic shapes as well

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers.torch_helper import dummy_llm

        print(dummy_llm())
    """

    class Embedding(torch.nn.Module):
        def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
            super().__init__()
            self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
            self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

        def forward(self, x):
            word_emb = self.embedding(x)
            word_pe = self.pe(x)
            return word_emb + word_pe

    class AttentionBlock(torch.nn.Module):
        def __init__(self, embedding_dim: int = 16, context_size: int = 256):
            super().__init__()
            self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            # torch.nn.Buffer is not fully handled by symbolic tracing,
            # Buffer(...)[:Proxy()] is not working.
            self.mask = torch.nn.Parameter(
                torch.tril(
                    input=torch.ones(size=[context_size, context_size], dtype=torch.float)
                )
            )

        def forward(self, x):
            B, T, C = x.shape

            query = self.query(x)
            key = self.key(x)
            value = self.value(x)

            qk = query @ key.transpose(-2, -1) * C**-0.5
            attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
            attention = torch.nn.functional.softmax(input=attention, dim=-1)

            out = attention @ value
            return out

    class MultiAttentionBlock(torch.nn.Module):
        def __init__(
            self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
        ):
            super().__init__()
            self.attention = torch.nn.ModuleList(
                modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
            )
            self.linear = torch.nn.Linear(
                in_features=embedding_dim * num_heads, out_features=embedding_dim
            )

        def forward(self, x):
            out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
            x = self.linear(out)
            return x

    class FeedForward(torch.nn.Module):
        def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
            super().__init__()
            self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
            self.relu = torch.nn.ReLU()
            self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

        def forward(self, x):
            x = self.linear_1(x)
            x = self.relu(x)
            x = self.linear_2(x)
            return x

    class DecoderLayer(torch.nn.Module):
        def __init__(
            self,
            embedding_dim: int = 16,
            num_heads: int = 2,
            context_size: int = 256,
            ff_dim: int = 128,
        ):
            super().__init__()
            self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
            self.feed_forward = FeedForward(embedding_dim, ff_dim)
            self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
            self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

        def forward(self, x):
            x_norm = self.norm_1(x)
            attention = self.attention(x_norm)
            attention = attention + x

            attention_norm = self.norm_2(attention)
            ff = self.feed_forward(attention_norm)
            ff = ff + attention

            return ff

    class LLM(torch.nn.Module):
        def __init__(
            self,
            vocab_size: int = 1024,
            embedding_dim: int = 16,
            num_heads: int = 2,
            context_size: int = 256,
            ff_dim: int = 128,
        ):
            super().__init__()
            self.embedding = Embedding(vocab_size, embedding_dim)
            self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

        def forward(self, input_ids):
            x = self.embedding(input_ids)
            y = self.decoder(x)
            return y

    if cls_name in (None, "LLM"):
        dec: torch.nn.Module = LLM()
        x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "input_ids": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "DecoderLayer":
        LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))

        dec = DecoderLayer()
        x = Embedding()(
            torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
        )
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "MultiAttentionBlock":
        dec = MultiAttentionBlock()
        x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "AttentionBlock":
        dec = AttentionBlock()
        x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    raise NotImplementedError(f"cls_name={cls_name}")
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
    """Applies ``torch.to`` if applicable. Goes recursively."""
    if isinstance(value, (torch.nn.Module, torch.Tensor)):
        return value.to(to_value)
    if isinstance(value, list):
        return [to_any(t, to_value) for t in value]
    if isinstance(value, tuple):
        return tuple(to_any(t, to_value) for t in value)
    if isinstance(value, set):
        return {to_any(t, to_value) for t in value}
    if isinstance(value, dict):
        return {k: to_any(t, to_value) for k, t in value.items()}
    if hasattr(value, "to"):
        return value.to(to_value)
    if value.__class__.__name__ == "DynamicCache":
        return make_dynamic_cache(
            list(
                zip(
                    [t.to(to_value) for t in value.key_cache],
                    [t.to(to_value) for t in value.value_cache],
                )
            )
        )
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        args, spec = torch.utils._pytree.tree_flatten(value)
        new_args = to_any(args, to_value)
        return torch.utils._pytree.tree_unflatten(new_args, spec)

    assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
    return value
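# Usage sketch (illustrative only): the conversion goes through nested containers.
#
#   data = {"x": torch.ones(2, 2), "others": [torch.zeros(3)]}
#   half = to_any(data, torch.float16)
#   # half["x"].dtype == torch.float16 and half["others"][0].dtype == torch.float16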
def torch_deepcopy(value: Any) -> Any:
    """Makes a deepcopy."""
    if value is None:
        return None
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, tuple):
        return tuple(torch_deepcopy(v) for v in value)
    if isinstance(value, list):
        return [torch_deepcopy(v) for v in value]
    if isinstance(value, set):
        return {torch_deepcopy(v) for v in value}
    if isinstance(value, dict):
        if type(value) is dict:
            return {k: torch_deepcopy(v) for k, v in value.items()}
        # for BaseModelOutput
        return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()})
    if isinstance(value, np.ndarray):
        return value.copy()
    if hasattr(value, "clone"):
        return value.clone()
    if value.__class__.__name__ == "DynamicCache":
        return make_dynamic_cache(
            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
        )
    if value.__class__.__name__ == "SlidingWindowCache":
        return make_sliding_window_cache(
            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
        )
    if value.__class__.__name__ == "EncoderDecoderCache":
        return make_encoder_decoder_cache(
            torch_deepcopy(value.self_attention_cache),
            torch_deepcopy(value.cross_attention_cache),
        )
    if value.__class__.__name__ == "MambaCache":
        return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        args, spec = torch.utils._pytree.tree_flatten(value)
        new_args = torch_deepcopy(args)
        return torch.utils._pytree.tree_unflatten(new_args, spec)

    # We should have a code using serialization, deserialization assuming a model
    # cannot be exported without them.
    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
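# Usage sketch (illustrative only): duplicates nested structures containing tensors,
# including the cache classes from transformers handled above.
#
#   t = torch.arange(4.0)
#   copied = torch_deepcopy([t, {"k": t}])
#   # copied[0] is a new tensor, modifying it does not change t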
def torch_tensor_size(value: Any) -> Any:
    """Returns the number of bytes stored in tensors."""
    if value is None:
        return 0
    if isinstance(value, (int, float, str)):
        return 0
    if isinstance(value, (tuple, list, set)):
        return sum(torch_tensor_size(v) for v in value)
    if isinstance(value, dict):
        return sum(torch_tensor_size(v) for v in value.values())
    if isinstance(value, np.ndarray):
        return value.nbytes
    if hasattr(value, "clone"):
        return value.numel() * size_type(value.dtype)
    if value.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
        return torch_tensor_size(value.key_cache) + torch_tensor_size(value.value_cache)
    if value.__class__.__name__ == "EncoderDecoderCache":
        return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
            value.cross_attention_cache
        )
    if value.__class__.__name__ == "MambaCache":
        return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        args, spec = torch.utils._pytree.tree_flatten(value)
        return sum(torch_tensor_size(a) for a in args)

    # We should have a code using serialization, deserialization assuming a model
    # cannot be exported without them.
    raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
def model_statistics(model: torch.nn.Module):
    """Returns statistics on a model in a dictionary."""
    n_subs = len(list(model.modules()))
    sizes = {}
    param_size = 0
    for param in model.parameters():
        size = param.nelement() * param.element_size()
        param_size += size
        name = str(param.dtype).replace("torch.", "")
        if name not in sizes:
            sizes[name] = 0
        sizes[name] += size

    buffer_size = 0
    for buffer in model.buffers():
        size = buffer.nelement() * buffer.element_size()
        buffer_size += size
        name = str(buffer.dtype).replace("torch.", "")
        if name not in sizes:
            sizes[name] = 0
        sizes[name] += size

    res = dict(
        type=model.__class__.__name__,
        n_modules=n_subs,
        param_size=param_size,
        buffer_size=buffer_size,
        size_mb=(param_size + buffer_size) // 2**20,
    )
    res.update(sizes)
    return res
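# Usage sketch (illustrative only): summary for a tiny model.
#
#   stats = model_statistics(torch.nn.Linear(4, 8))
#   # {'type': 'Linear', 'n_modules': 1, 'param_size': 160,
#   #  'buffer_size': 0, 'size_mb': 0, 'float32': 160}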
def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
    """
    Converts a TensorProto to a torch tensor.

    :param tensor: a TensorProto object
    :param base_dir: if external tensor exists, base_dir can help to find the path to it
    :return: the converted tensor
    """
    assert not tensor.HasField("segment"), "Currently not supporting loading segments."
    assert (
        tensor.data_type != onnx.TensorProto.UNDEFINED
    ), "The element type in the input tensor is not defined."
    assert tensor.data_type != onnx.TensorProto.STRING, "to_tensor not implemented for strings"

    tensor_dtype = tensor.data_type
    torch_dtype = onnx_dtype_to_torch_dtype(tensor_dtype)
    dims = tuple(tensor.dims)
    if uses_external_data(tensor):
        # Load raw data from external tensor if it exists.
        load_external_data_for_tensor(tensor, base_dir)
    if tensor.HasField("raw_data"):
        raw_data = tensor.raw_data
        if sys.byteorder == "big":
            # Convert endianness from little to big.
            raw_data = torch.frombuffer(raw_data, dtype=torch_dtype).byteswap().tobytes()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)

    # Otherwise, it should be a small tensor. We use numpy.
    np_tensor = to_array_extended(tensor)
    return torch.from_numpy(np_tensor)
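# Usage sketch (illustrative only): ``to_tensor`` is the reverse of
# ``proto_from_tensor`` defined above.
#
#   t = torch.rand(2, 3)
#   assert torch.equal(to_tensor(proto_from_tensor(t)), t)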