import contextlib
from collections.abc import Iterable
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
from .helper import string_type
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache
def _forward_(*args, _f=None, _context=None, **kwargs):
assert _f is not None, "_f cannot be None"
assert _context is not None, "_context cannot be None"
print(
f"---- stolen forward for class {_context['class_name']} "
f"-- iteration {_context['iteration']}"
)
kws = dict(
with_shape=_context.get("with_shape", False),
with_min_max=_context.get("with_min_max", False),
)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
# torch.compiler.is_exporting requires torch>=2.7
print(f" <- args={string_type(args, **kws)} --- kwargs={string_type(kwargs, **kws)}")
res = _f(*args, **kwargs)
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
print(" --")
print(f" -> {string_type(res, **kws)}")
print(".")
_context["iteration"] += 1
return res
@contextlib.contextmanager
def steal_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
    """
    Replaces the forward method of a model with a wrapper printing out
    its inputs and outputs. See example :ref:`l-plot-tiny-llm-export`.
    """
context = dict(
iteration=0,
class_name=model.__class__.__name__,
with_shape=with_shape,
with_min_max=with_min_max,
)
keep_model_forward = model.forward
model.forward = lambda *args, _f=keep_model_forward, _context=context, **kwargs: _forward_(
*args, _f=_f, _context=_context, **kwargs
)
try:
yield
finally:
model.forward = keep_model_forward
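# Usage sketch for ``steal_forward`` -- ``MyModel`` and ``inputs`` are
# hypothetical placeholders, not defined in this module:
#
#   model = MyModel()
#   with steal_forward(model):
#       model(*inputs)
#
# Every call issued inside the ``with`` block prints the inputs and outputs
# of ``model.forward``; the original method is restored on exit.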
def is_torchdynamo_exporting() -> bool:
"""Tells if torch is exporting a model."""
if not hasattr(torch.compiler, "is_exporting"):
# torch.compiler.is_exporting requires torch>=2.7
return False
try:
return torch.compiler.is_exporting()
except Exception:
try:
import torch._dynamo as dynamo
return dynamo.is_exporting() # type: ignore
except Exception:
return False
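# Usage sketch: guard verbose code so that it does not run while
# ``torch.export`` is tracing the model:
#
#   if not is_torchdynamo_exporting():
#       print("running eagerly, safe to print tensors")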
def to_numpy(tensor: "torch.Tensor"): # noqa: F821
"""Converts a torch tensor to numy."""
try:
return tensor.numpy()
except TypeError:
# We try with ml_dtypes
pass
import ml_dtypes
conv = {torch.bfloat16: ml_dtypes.bfloat16}
assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
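# Usage sketch -- the fallback requires ``ml_dtypes`` since numpy has no
# native bfloat16:
#
#   t = torch.tensor([1.5, 2.5], dtype=torch.bfloat16)
#   a = to_numpy(t)  # numpy array with dtype ml_dtypes.bfloat16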
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
"""Replaces strings by ``torch.export.Dim.DYNAMIC``."""
if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
return dynamic_shapes
if isinstance(dynamic_shapes, str):
return torch.export.Dim.DYNAMIC
if not dynamic_shapes:
return dynamic_shapes
if isinstance(dynamic_shapes, (tuple, list)):
return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
if isinstance(dynamic_shapes, dict):
return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
def dummy_llm(
cls_name: Optional[str] = None,
dynamic_shapes: bool = False,
) -> Union[
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
]:
"""
Creates a dummy LLM for test purposes.
:param cls_name: None for whole model or a piece of it
:param dynamic_shapes: returns dynamic shapes as well
.. runpython::
:showcode:
from onnx_diagnostic.helpers.torch_test_helper import dummy_llm
print(dummy_llm())
"""
class Embedding(torch.nn.Module):
def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
def forward(self, x):
word_emb = self.embedding(x)
word_pe = self.pe(x)
return word_emb + word_pe
class AttentionBlock(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, context_size: int = 256):
super().__init__()
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            # torch.nn.Buffer is not fully handled by symbolic tracing,
            # Buffer(...)[:Proxy()] does not work.
self.mask = torch.nn.Parameter(
torch.tril(
input=torch.ones(size=[context_size, context_size], dtype=torch.float)
)
)
def forward(self, x):
B, T, C = x.shape
query = self.query(x)
key = self.key(x)
value = self.value(x)
qk = query @ key.transpose(-2, -1) * C**-0.5
attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
attention = torch.nn.functional.softmax(input=attention, dim=-1)
out = attention @ value
return out
class MultiAttentionBlock(torch.nn.Module):
def __init__(
self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
):
super().__init__()
self.attention = torch.nn.ModuleList(
modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
)
self.linear = torch.nn.Linear(
in_features=embedding_dim * num_heads, out_features=embedding_dim
)
def forward(self, x):
out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
x = self.linear(out)
return x
class FeedForward(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
super().__init__()
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
self.relu = torch.nn.ReLU()
self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
def forward(self, x):
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class DecoderLayer(torch.nn.Module):
def __init__(
self,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
self.feed_forward = FeedForward(embedding_dim, ff_dim)
self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
def forward(self, x):
x_norm = self.norm_1(x)
attention = self.attention(x_norm)
attention = attention + x
attention_norm = self.norm_2(attention)
ff = self.feed_forward(attention_norm)
ff = ff + attention
return ff
class LLM(torch.nn.Module):
def __init__(
self,
vocab_size: int = 1024,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.embedding = Embedding(vocab_size, embedding_dim)
self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
def forward(self, input_ids):
x = self.embedding(input_ids)
y = self.decoder(x)
return y
if cls_name in (None, "LLM"):
dec: torch.nn.Module = LLM()
x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
dec(x)
if dynamic_shapes:
dyn = {
"input_ids": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "DecoderLayer":
LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
dec = DecoderLayer()
x = Embedding()(
torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "MultiAttentionBlock":
dec = MultiAttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "AttentionBlock":
dec = AttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
raise NotImplementedError(f"cls_name={cls_name}")
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
"""
    Applies ``.to(to_value)`` wherever applicable,
    going through nested structures recursively.
"""
if isinstance(value, (torch.nn.Module, torch.Tensor)):
return value.to(to_value)
if isinstance(value, list):
return [to_any(t, to_value) for t in value]
if isinstance(value, tuple):
return tuple(to_any(t, to_value) for t in value)
if isinstance(value, set):
return {to_any(t, to_value) for t in value}
if isinstance(value, dict):
return {k: to_any(t, to_value) for k, t in value.items()}
if hasattr(value, "to"):
return value.to(to_value)
if value.__class__.__name__ == "DynamicCache":
return make_dynamic_cache(
list(
zip(
[t.to(to_value) for t in value.key_cache],
[t.to(to_value) for t in value.value_cache],
)
)
)
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
return value
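# Usage sketch: move every tensor of a nested structure in one call:
#
#   moved = to_any({"a": torch.rand(2), "b": [torch.rand(3)]}, torch.float16)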
def torch_deepcopy(value: Any) -> Any:
"""
    Makes a deep copy, cloning tensors and rebuilding known cache classes.
"""
if isinstance(value, (int, float, str)):
return value
if isinstance(value, tuple):
return tuple(torch_deepcopy(v) for v in value)
if isinstance(value, list):
return [torch_deepcopy(v) for v in value]
if isinstance(value, set):
return {torch_deepcopy(v) for v in value}
if isinstance(value, dict):
return {k: torch_deepcopy(v) for k, v in value.items()}
if isinstance(value, np.ndarray):
return value.copy()
if hasattr(value, "clone"):
return value.clone()
if value.__class__.__name__ == "DynamicCache":
return make_dynamic_cache(
torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
)
if value.__class__.__name__ == "EncoderDecoderCache":
return make_encoder_decoder_cache(
torch_deepcopy(value.self_attention_cache),
torch_deepcopy(value.cross_attention_cache),
)
    # A generic fallback could rely on the registered serialization and
    # deserialization functions, assuming a model cannot be exported
    # without them.
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")