import contextlib
import ctypes
import inspect
import os
import sys
import warnings
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
import torch
from .helper import string_type, size_type
from .cache_helper import (
make_dynamic_cache,
make_encoder_decoder_cache,
make_sliding_window_cache,
make_mamba_cache,
)
from .mini_onnx_builder import create_onnx_model_from_input_tensors
from .onnx_helper import (
to_array_extended,
tensor_dtype_to_np_dtype,
_STORAGE_TYPE,
onnx_dtype_name,
)
def proto_from_tensor(
arr: "torch.Tensor", # noqa: F821
name: Optional[str] = None,
verbose: int = 0,
) -> onnx.TensorProto:
"""
Converts a torch Tensor into a TensorProto.
    :param arr: tensor
    :param name: optional name given to the TensorProto
    :param verbose: display the type and shape
:return: a TensorProto
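    A minimal usage sketch (assuming a small float32 tensor):

    .. code-block:: python

        import torch
        from onnx_diagnostic.helpers.torch_helper import proto_from_tensor

        proto = proto_from_tensor(torch.arange(6, dtype=torch.float32).reshape(2, 3), name="W")
        print(proto.name, proto.dims, proto.data_type)  # W [2, 3] 1 (1 means FLOAT)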
"""
import torch
if not isinstance(arr, torch.Tensor):
raise TypeError(f"Unexpected type {type(arr)}.")
if arr.is_sparse:
raise NotImplementedError(
f"Sparse tensor is not supported yet but initializer {name!r} is."
)
# arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
if arr.is_contiguous():
arr_cpu = arr.cpu()
else:
arr_cpu = arr.contiguous().cpu()
numel = torch.numel(arr_cpu)
element_size = arr_cpu.element_size()
if arr_cpu.dtype in {torch.bfloat16}:
np_arr = arr_cpu
elif arr_cpu.data_ptr() == arr.data_ptr():
copy = arr_cpu.clone().detach().requires_grad_(False)
assert (
arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
np_arr = np.from_dlpack(copy)
else:
np_arr = np.from_dlpack(arr_cpu.detach())
tensor = onnx.TensorProto()
tensor.dims.extend(arr_cpu.shape)
if name:
tensor.name = name
itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
onnx.TensorProto.INT4,
onnx.TensorProto.UINT4,
}, f"Type {arr.dtype} is not supported yet for name={name!r}"
tensor.data_type = itype
if verbose > 1 and numel > 100:
print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")
if isinstance(np_arr, torch.Tensor):
byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
tensor.raw_data = bytes(byte_data)
if sys.byteorder == "big":
np_dtype = _STORAGE_TYPE[tensor.data_type] # type: ignore
np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True) # type: ignore
else:
tensor.raw_data = np_arr.tobytes()
if sys.byteorder == "big":
np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
np.byteswap(np.frombuffer(tensor.raw_data, dtype=np_dtype), inplace=True)
return tensor
def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype": # noqa: F821
"""
Converts an onnx type into a torch dtype.
    :param itype: onnx dtype
:return: torch dtype
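    A short check as an example:

    .. code-block:: python

        import onnx
        import torch
        from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype

        assert onnx_dtype_to_torch_dtype(onnx.TensorProto.BFLOAT16) == torch.bfloat16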
"""
if itype == onnx.TensorProto.FLOAT:
return torch.float32
if itype == onnx.TensorProto.FLOAT16:
return torch.float16
if itype == onnx.TensorProto.BFLOAT16:
return torch.bfloat16
if itype == onnx.TensorProto.DOUBLE:
return torch.float64
if itype == onnx.TensorProto.INT32:
return torch.int32
if itype == onnx.TensorProto.INT64:
return torch.int64
if itype == onnx.TensorProto.UINT32:
return torch.uint32
if itype == onnx.TensorProto.UINT64:
return torch.uint64
if itype == onnx.TensorProto.BOOL:
return torch.bool
if itype == onnx.TensorProto.INT16:
return torch.int16
if itype == onnx.TensorProto.UINT16:
return torch.uint16
if itype == onnx.TensorProto.INT8:
return torch.int8
if itype == onnx.TensorProto.UINT8:
return torch.uint8
if itype == onnx.TensorProto.COMPLEX64:
return torch.complex64
if itype == onnx.TensorProto.COMPLEX128:
return torch.complex128
raise NotImplementedError(
f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
)
def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int: # noqa: F821
"""
    Converts a torch dtype into an onnx element type.
:param to: torch dtype
:return: onnx type
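    For example, a round-trip with :func:`onnx_dtype_to_torch_dtype`:

    .. code-block:: python

        import onnx
        import torch
        from onnx_diagnostic.helpers.torch_helper import (
            onnx_dtype_to_torch_dtype,
            torch_dtype_to_onnx_dtype,
        )

        itype = torch_dtype_to_onnx_dtype(torch.float16)
        assert itype == onnx.TensorProto.FLOAT16
        assert onnx_dtype_to_torch_dtype(itype) == torch.float16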
"""
import torch
if to == torch.float32:
return onnx.TensorProto.FLOAT
if to == torch.float16:
return onnx.TensorProto.FLOAT16
if to == torch.bfloat16:
return onnx.TensorProto.BFLOAT16
if to == torch.float64:
return onnx.TensorProto.DOUBLE
if to == torch.int64:
return onnx.TensorProto.INT64
if to == torch.int32:
return onnx.TensorProto.INT32
if to == torch.uint64:
return onnx.TensorProto.UINT64
if to == torch.uint32:
return onnx.TensorProto.UINT32
if to == torch.bool:
return onnx.TensorProto.BOOL
if to == torch.SymInt:
return onnx.TensorProto.INT64
if to == torch.int16:
return onnx.TensorProto.INT16
if to == torch.uint16:
return onnx.TensorProto.UINT16
if to == torch.int8:
return onnx.TensorProto.INT8
if to == torch.uint8:
return onnx.TensorProto.UINT8
if to == torch.SymFloat:
return onnx.TensorProto.FLOAT
if to == torch.complex64:
return onnx.TensorProto.COMPLEX64
if to == torch.complex128:
return onnx.TensorProto.COMPLEX128
raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
def _forward_(
*args,
_f=None,
_fprint=string_type,
_prefix="",
_context=None,
_storage=None,
_storage_limit=2**27,
_verbose=0,
**kwargs,
):
assert _f is not None, "_f cannot be None"
assert _context is not None, "_context cannot be None"
indent = " " * (len(_prefix) - len(_prefix.lstrip()))
_prefix = _prefix.lstrip()
print(
f"{indent}+{_prefix} -- stolen forward for class {_context['class_name']} "
f"-- iteration {_context['iteration']}"
)
kws = dict(
with_shape=_context.get("with_shape", False),
with_min_max=_context.get("with_min_max", False),
)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
# torch.compiler.is_exporting requires torch>=2.7
print(f"{indent} <- args={_fprint(args, **kws)} --- kwargs={_fprint(kwargs, **kws)}")
if _storage is not None:
it = _context["iteration"]
key = (_prefix, it)
_storage[(*key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))
res = _f(*args, **kwargs)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
print(f"{indent} -> {_fprint(res, **kws)}")
print(f"{indent}-{_prefix}.")
if _storage is not None:
size = torch_tensor_size(res)
if size < _storage_limit:
if _verbose:
print(
f"-- stores key={key}, size {size // 2**10}Kb -- "
f"{string_type(res, with_shape=True)}"
)
_storage[(*key, "O")] = torch_deepcopy(res)
else:
if _verbose:
print(
f"-- skips key={key}, size {size // 2**10}Kb -- "
f"{string_type(res, with_shape=True)}"
)
_context["iteration"] += 1
return res
_steal_forward_status = [False]
_additional_stolen_objects = {}
def is_stealing() -> bool:
"""Returns true if :func:`steal_forward` was yielded."""
return _steal_forward_status[0]
def steal_append(name: str, obj: Any):
"""
    Even outside a forward method, it is still possible to add a python object
    containing tensors and have it dumped after the execution of the model.
.. code-block:: python
steal_append("quantize", [t1, t2])
    The same code can be executed multiple times; in that case,
    the name is extended with a number.
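
    For example (``t1``, ``t2`` are hypothetical tensors computed inside the model):

    .. code-block:: python

        steal_append("quantize", [t1, t2])
        steal_append("quantize", [t1, t2])  # stored under "quantize_1"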
"""
if is_stealing():
if name in _additional_stolen_objects:
i = 1
n = f"{name}_{i}"
while n in _additional_stolen_objects:
i += 1
n = f"{name}_{i}"
print(f"-- stolen {name!r} renamed in {n!r}: {string_type(obj, with_shape=True)}")
_additional_stolen_objects[n] = obj
else:
print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
_additional_stolen_objects[name] = obj
@contextlib.contextmanager
def steal_forward(
model: Union[
Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
],
fprint: Callable = string_type,
dump_file: Optional[str] = None,
submodules: bool = False,
verbose: int = 0,
storage_limit: int = 2**27,
**kwargs,
):
"""
    Steals the forward method of a model (or a list of models) to print out inputs
    and outputs using :func:`onnx_diagnostic.helpers.string_type`.
See example :ref:`l-plot-tiny-llm-export`.
    :param model: a model or a list of models to monitor,
        every model can also be a tuple ``(name, model)``; the name is displayed as well
:param fprint: function used to print out (or dump), by default, it is
:func:`onnx_diagnostic.helpers.string_type`
:param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
or any other function defined by ``fprint``
:param dump_file: dumps stolen inputs and outputs in an onnx model,
they can be restored with :func:`create_input_tensors_from_onnx_model
<onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
    :param submodules: if True and model is a module, the list is extended with all
        the submodules the module contains
:param verbose: verbosity
    :param storage_limit: do not store objects bigger than this (in bytes)
    The following example shows how to steal and dump all the inputs / outputs
    of a module and its submodules, then restore them.
.. runpython::
:showcode:
import torch
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import (
create_input_tensors_from_onnx_model,
)
class SubModel(torch.nn.Module):
def forward(self, x):
return x * x
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.s1 = SubModel()
self.s2 = SubModel()
def forward(self, x, y):
return self.s1(x) + self.s2(y)
inputs = torch.rand(2, 1), torch.rand(2, 1)
model = Model()
dump_file = "dump_steal_forward_submodules.onnx"
with steal_forward(model, submodules=True, dump_file=dump_file):
model(*inputs)
# Let's restore the stolen data.
restored = create_input_tensors_from_onnx_model(dump_file)
for k, v in sorted(restored.items()):
if isinstance(v, tuple):
args, kwargs = v
print("input", k, args, kwargs)
else:
print("output", k, v)
Function :func:`steal_append` can be used to dump more tensors.
    When inside the context, :func:`is_stealing` returns True, and False otherwise.
"""
assert not is_stealing(), "steal_forward was already called."
# We clear the cache.
_steal_forward_status[0] = True
_additional_stolen_objects.clear()
assert not submodules or isinstance(
model, torch.nn.Module
), f"submodules can only be True if model is a module but is is {type(model)}."
context = dict(iteration=0, **kwargs)
if "with_shape" not in context and fprint == string_type:
context["with_shape"] = True
if not isinstance(model, list):
assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
if submodules:
models = []
for idx, m in model.named_modules():
level = str(idx).split(".")
ll = len(level)
try:
_, start_line = inspect.getsourcelines(m.forward)
except OSError:
# The code is not available.
start_line = 0
name = f"{idx}-{m.__class__.__name__}-{start_line}"
models.append((f"{' ' * ll}{name}", m))
model = models
else:
model = [model]
keep_model_forward = {}
storage: Optional[Dict[Any, Any]] = {} if dump_file else None
for mt in model:
name, m = mt if isinstance(mt, tuple) else ("", mt)
keep_model_forward[id(m)] = (m, m.forward)
c = context.copy()
c["class_name"] = m.__class__.__name__
m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_( # noqa: E501
*args,
_f=_f,
_fprint=_fp,
_context=_c,
_prefix=_p,
_storage=_s,
_verbose=_v,
_storage_limit=_sl,
**kws,
)
try:
yield
finally:
_steal_forward_status[0] = False
for f in keep_model_forward.values():
f[0].forward = f[1]
if dump_file:
# Let's add the cached tensor
assert storage is not None, "storage cannot be None but mypy is confused here."
storage.update(_additional_stolen_objects)
# We clear the cache.
_additional_stolen_objects.clear()
if verbose:
size = torch_tensor_size(storage)
print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
proto = create_onnx_model_from_input_tensors(storage)
if verbose:
print("-- dumps stored objects")
onnx.save(
proto,
dump_file,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"{os.path.split(dump_file)[-1]}.data",
)
if verbose:
print("-- done dump stored objects")
@contextlib.contextmanager
def fake_torchdynamo_exporting():
"""
Sets ``torch.compiler._is_exporting_flag`` to True to trigger
pieces of code only enabled during export.
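
    A small sketch of the expected effect (assuming torch>=2.7, where
    ``torch.compiler.is_exporting()`` reads this flag):

    .. code-block:: python

        from onnx_diagnostic.helpers.torch_helper import (
            fake_torchdynamo_exporting,
            is_torchdynamo_exporting,
        )

        with fake_torchdynamo_exporting():
            print(is_torchdynamo_exporting())  # expected True
        print(is_torchdynamo_exporting())  # back to False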
"""
memorize = torch.compiler._is_exporting_flag
torch.compiler._is_exporting_flag = True
try:
yield
finally:
torch.compiler._is_exporting_flag = memorize
def is_torchdynamo_exporting() -> bool:
"""
Tells if :epkg:`torch` is exporting a model.
Relies on ``torch.compiler.is_exporting()``.
"""
import torch
if not hasattr(torch.compiler, "is_exporting"):
# torch.compiler.is_exporting requires torch>=2.7
return False
try:
return torch.compiler.is_exporting()
except Exception:
try:
import torch._dynamo as dynamo
return dynamo.is_exporting() # type: ignore
except Exception:
return False
def to_numpy(tensor: "torch.Tensor"): # noqa: F821
"""Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`."""
try:
return tensor.numpy()
except TypeError:
# We try with ml_dtypes
pass
import ml_dtypes
conv = {torch.bfloat16: ml_dtypes.bfloat16}
assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
"""Replaces strings by ``torch.export.Dim.DYNAMIC``."""
import torch
if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
return dynamic_shapes
if isinstance(dynamic_shapes, str):
return torch.export.Dim.DYNAMIC
if not dynamic_shapes:
return dynamic_shapes
if isinstance(dynamic_shapes, (tuple, list)):
return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
if isinstance(dynamic_shapes, dict):
return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
def dummy_llm(
cls_name: Optional[str] = None,
dynamic_shapes: bool = False,
) -> Union[
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
]:
"""
Creates a dummy LLM for test purposes.
    :param cls_name: None (or "LLM") for the whole model, or the name of one of its
        pieces ("DecoderLayer", "MultiAttentionBlock", "AttentionBlock")
    :param dynamic_shapes: if True, returns the dynamic shapes as well
.. runpython::
:showcode:
from onnx_diagnostic.helpers.torch_helper import dummy_llm
print(dummy_llm())
"""
class Embedding(torch.nn.Module):
def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
def forward(self, x):
word_emb = self.embedding(x)
word_pe = self.pe(x)
return word_emb + word_pe
class AttentionBlock(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, context_size: int = 256):
super().__init__()
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            # torch.nn.Buffer is not fully handled by symbolic tracing,
            # Buffer(...)[:Proxy()] is not working
self.mask = torch.nn.Parameter(
torch.tril(
input=torch.ones(size=[context_size, context_size], dtype=torch.float)
)
)
def forward(self, x):
B, T, C = x.shape
query = self.query(x)
key = self.key(x)
value = self.value(x)
qk = query @ key.transpose(-2, -1) * C**-0.5
attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
attention = torch.nn.functional.softmax(input=attention, dim=-1)
out = attention @ value
return out
class MultiAttentionBlock(torch.nn.Module):
def __init__(
self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
):
super().__init__()
self.attention = torch.nn.ModuleList(
modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
)
self.linear = torch.nn.Linear(
in_features=embedding_dim * num_heads, out_features=embedding_dim
)
def forward(self, x):
out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
x = self.linear(out)
return x
class FeedForward(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
super().__init__()
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
self.relu = torch.nn.ReLU()
self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
def forward(self, x):
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class DecoderLayer(torch.nn.Module):
def __init__(
self,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
self.feed_forward = FeedForward(embedding_dim, ff_dim)
self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
def forward(self, x):
x_norm = self.norm_1(x)
attention = self.attention(x_norm)
attention = attention + x
attention_norm = self.norm_2(attention)
ff = self.feed_forward(attention_norm)
ff = ff + attention
return ff
class LLM(torch.nn.Module):
def __init__(
self,
vocab_size: int = 1024,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.embedding = Embedding(vocab_size, embedding_dim)
self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
def forward(self, input_ids):
x = self.embedding(input_ids)
y = self.decoder(x)
return y
if cls_name in (None, "LLM"):
dec: torch.nn.Module = LLM()
x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
dec(x)
if dynamic_shapes:
dyn = {
"input_ids": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "DecoderLayer":
LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
dec = DecoderLayer()
x = Embedding()(
torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "MultiAttentionBlock":
dec = MultiAttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "AttentionBlock":
dec = AttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
raise NotImplementedError(f"cls_name={cls_name}")
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
"""Applies torch.to if applicable. Goes recursively."""
if isinstance(value, (torch.nn.Module, torch.Tensor)):
return value.to(to_value)
if isinstance(value, list):
return [to_any(t, to_value) for t in value]
if isinstance(value, tuple):
return tuple(to_any(t, to_value) for t in value)
if isinstance(value, set):
return {to_any(t, to_value) for t in value}
if isinstance(value, dict):
return {k: to_any(t, to_value) for k, t in value.items()}
if hasattr(value, "to"):
return value.to(to_value)
if value.__class__.__name__ == "DynamicCache":
return make_dynamic_cache(
list(
zip(
[t.to(to_value) for t in value.key_cache],
[t.to(to_value) for t in value.value_cache],
)
)
)
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
args, spec = torch.utils._pytree.tree_flatten(value)
new_args = to_any(args, to_value)
return torch.utils._pytree.tree_unflatten(new_args, spec)
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
return value
def torch_deepcopy(value: Any) -> Any:
"""Makes a deepcopy."""
if value is None:
return None
if isinstance(value, (int, float, str)):
return value
if isinstance(value, tuple):
return tuple(torch_deepcopy(v) for v in value)
if isinstance(value, list):
return [torch_deepcopy(v) for v in value]
if isinstance(value, set):
return {torch_deepcopy(v) for v in value}
if isinstance(value, dict):
if type(value) is dict:
return {k: torch_deepcopy(v) for k, v in value.items()}
# for BaseModelOutput
return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()})
if isinstance(value, np.ndarray):
return value.copy()
if hasattr(value, "clone"):
return value.clone()
if value.__class__.__name__ == "DynamicCache":
return make_dynamic_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "SlidingWindowCache":
return make_sliding_window_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "EncoderDecoderCache":
return make_encoder_decoder_cache(
torch_deepcopy(value.self_attention_cache),
torch_deepcopy(value.cross_attention_cache),
)
if value.__class__.__name__ == "MambaCache":
return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
args, spec = torch.utils._pytree.tree_flatten(value)
new_args = torch_deepcopy(args)
return torch.utils._pytree.tree_unflatten(new_args, spec)
    # We should have code relying on serialization / deserialization, assuming a model
    # cannot be exported without them.
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
def torch_tensor_size(value: Any) -> Any:
"""Returns the number of bytes stored in tensors."""
if value is None:
return 0
if isinstance(value, (int, float, str)):
return 0
if isinstance(value, (tuple, list, set)):
return sum(torch_tensor_size(v) for v in value)
if isinstance(value, dict):
return sum(torch_tensor_size(v) for v in value.values())
if isinstance(value, np.ndarray):
        return value.nbytes
if hasattr(value, "clone"):
return value.numel() * size_type(value.dtype)
if value.__class__.__name__ in {"DynamicCache", "SlidingWindowCache"}:
return torch_tensor_size(value.key_cache) + torch_tensor_size(value.value_cache)
if value.__class__.__name__ == "EncoderDecoderCache":
return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
value.cross_attention_cache
)
if value.__class__.__name__ == "MambaCache":
return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
args, spec = torch.utils._pytree.tree_flatten(value)
return sum(torch_tensor_size(a) for a in args)
    # We should have code relying on serialization / deserialization, assuming a model
    # cannot be exported without them.
raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
def model_statistics(model: torch.nn.Module):
"""Returns statistics on a model in a dictionary."""
n_subs = len(list(model.modules()))
sizes = {}
param_size = 0
for param in model.parameters():
size = param.nelement() * param.element_size()
param_size += size
name = str(param.dtype).replace("torch.", "")
if name not in sizes:
sizes[name] = 0
sizes[name] += size
buffer_size = 0
for buffer in model.buffers():
size = buffer.nelement() * buffer.element_size()
buffer_size += size
name = str(buffer.dtype).replace("torch.", "")
if name not in sizes:
sizes[name] = 0
sizes[name] += size
res = dict(
type=model.__class__.__name__,
n_modules=n_subs,
param_size=param_size,
buffer_size=buffer_size,
size_mb=(param_size + buffer_size) // 2**20,
)
res.update(sizes)
return res
def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
"""
    Converts a TensorProto to a torch Tensor.
:param tensor: a TensorProto object.
    :param base_dir: if the tensor uses external data, base_dir helps locating the file it refers to
:return: the converted tensor
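
    A short round-trip sketch with :func:`proto_from_tensor`:

    .. code-block:: python

        import torch
        from onnx_diagnostic.helpers.torch_helper import proto_from_tensor, to_tensor

        t = torch.rand(2, 3)
        assert torch.equal(to_tensor(proto_from_tensor(t)), t)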
"""
assert not tensor.HasField("segment"), "Currently not supporting loading segments."
assert (
tensor.data_type != onnx.TensorProto.UNDEFINED
), "The element type in the input tensor is not defined."
assert tensor.data_type != onnx.TensorProto.STRING, "to_tensor not implemented for strings"
tensor_dtype = tensor.data_type
torch_dtype = onnx_dtype_to_torch_dtype(tensor_dtype)
dims = tuple(tensor.dims)
if uses_external_data(tensor):
# Load raw data from external tensor if it exists
load_external_data_for_tensor(tensor, base_dir)
if tensor.HasField("raw_data"):
raw_data = tensor.raw_data
if sys.byteorder == "big":
# Convert endian from little to big
raw_data = torch.frombuffer(raw_data, dtype=torch_dtype).byteswap().tobytes()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)
    # Other cases: it should be a small tensor. We use numpy.
np_tensor = to_array_extended(tensor)
return torch.from_numpy(np_tensor)