Source code for onnx_diagnostic.helpers.torch_test_helper

import contextlib
from collections.abc import Iterable
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
from .helper import string_type
from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache


def _forward_(*args, _f=None, _context=None, **kwargs):
    assert _f is not None, "_f cannot be None"
    assert _context is not None, "_context cannot be None"
    print(
        f"---- stolen forward for class {_context['class_name']} "
        f"-- iteration {_context['iteration']}"
    )
    kws = dict(
        with_shape=_context.get("with_shape", False),
        with_min_max=_context.get("with_min_max", False),
    )
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        # torch.compiler.is_exporting requires torch>=2.7
        print(f"  <- args={string_type(args, **kws)} --- kwargs={string_type(kwargs, **kws)}")
    res = _f(*args, **kwargs)
    if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
        print("  --")
        print(f"  -> {string_type(res, **kws)}")
        print(".")
    _context["iteration"] += 1
    return res


@contextlib.contextmanager
def steel_forward(model: torch.nn.Module, with_shape: bool = True, with_min_max: bool = False):
    """
    Temporarily replaces (steals) the forward method of a model
    to print out its inputs and outputs.
    See example :ref:`l-plot-tiny-llm-export`.
    """
    context = dict(
        iteration=0,
        class_name=model.__class__.__name__,
        with_shape=with_shape,
        with_min_max=with_min_max,
    )
    keep_model_forward = model.forward
    model.forward = lambda *args, _f=keep_model_forward, _context=context, **kwargs: _forward_(
        *args, _f=_f, _context=_context, **kwargs
    )
    try:
        yield
    finally:
        model.forward = keep_model_forward
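A minimal usage sketch, not part of the module: it assumes the module is importable
under the path shown at the top of this page and uses a small ``torch.nn.Linear``
as an illustrative stand-in for a real model.

import torch

from onnx_diagnostic.helpers.torch_test_helper import steel_forward

model = torch.nn.Linear(4, 2)
with steel_forward(model, with_shape=True):
    # each call prints the stolen inputs and outputs
    model(torch.randn(3, 4))
# the original forward is restored when the context manager exits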
def is_torchdynamo_exporting() -> bool:
    """Tells if torch is exporting a model."""
    import torch

    if not hasattr(torch.compiler, "is_exporting"):
        # torch.compiler.is_exporting requires torch>=2.7
        return False

    try:
        return torch.compiler.is_exporting()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_exporting()  # type: ignore
        except Exception:
            return False
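A short, hedged sketch of a typical use: guarding debug prints inside a ``forward``
so they run in eager mode but are skipped while ``torch.export`` traces the model
(the ``Debugged`` class below is illustrative only).

import torch

from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting

class Debugged(torch.nn.Module):
    def forward(self, x):
        if not is_torchdynamo_exporting():
            # only printed in eager mode, skipped during export
            print("eager call with shape", tuple(x.shape))
        return x * 2

ep = torch.export.export(Debugged(), (torch.randn(2, 3),))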
def to_numpy(tensor: "torch.Tensor"):  # noqa: F821
    """Converts a torch tensor to numpy."""
    try:
        return tensor.numpy()
    except TypeError:
        # We try with ml_dtypes.
        pass

    import ml_dtypes

    conv = {torch.bfloat16: ml_dtypes.bfloat16}
    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
    return tensor.to(torch.float32).numpy().astype(conv[tensor.dtype])
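A hedged example of the bfloat16 fallback path; it assumes the optional
``ml_dtypes`` package is installed, since numpy has no native bfloat16.

import torch

from onnx_diagnostic.helpers.torch_test_helper import to_numpy

t = torch.tensor([1.0, 2.5], dtype=torch.bfloat16)
a = to_numpy(t)  # numpy array with dtype ml_dtypes.bfloat16
print(a.dtype)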
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
    """Replaces strings by ``torch.export.Dim.DYNAMIC``."""
    import torch

    if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
        return dynamic_shapes
    if isinstance(dynamic_shapes, str):
        return torch.export.Dim.DYNAMIC
    if not dynamic_shapes:
        return dynamic_shapes
    if isinstance(dynamic_shapes, (tuple, list)):
        return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
    if isinstance(dynamic_shapes, dict):
        return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
    raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
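An illustrative sketch: every string in a dynamic shapes specification becomes
``torch.export.Dim.DYNAMIC`` while the container structure is preserved
(the dimension names and argument names used here are arbitrary).

from onnx_diagnostic.helpers.torch_test_helper import replace_string_by_dynamic

spec = {"input_ids": {0: "batch", 1: "seq_length"}, "attention_mask": None}
print(replace_string_by_dynamic(spec))
# strings are replaced by torch.export.Dim.DYNAMIC, None is kept as-is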
def dummy_llm(
    cls_name: Optional[str] = None,
    dynamic_shapes: bool = False,
) -> Union[
    Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
    Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
]:
    """
    Creates a dummy LLM for test purposes.

    :param cls_name: None for the whole model or the name of one of its pieces
    :param dynamic_shapes: returns dynamic shapes as well

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers.torch_test_helper import dummy_llm

        print(dummy_llm())
    """

    class Embedding(torch.nn.Module):
        def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
            super().__init__()
            self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
            self.pe = torch.nn.Embedding(vocab_size, embedding_dim)

        def forward(self, x):
            word_emb = self.embedding(x)
            word_pe = self.pe(x)
            return word_emb + word_pe

    class AttentionBlock(torch.nn.Module):
        def __init__(self, embedding_dim: int = 16, context_size: int = 256):
            super().__init__()
            self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
            # torch.nn.Buffer is not fully handled by symbolic tracing,
            # Buffer(...)[:Proxy()] is not working.
            self.mask = torch.nn.Parameter(
                torch.tril(
                    input=torch.ones(size=[context_size, context_size], dtype=torch.float)
                )
            )

        def forward(self, x):
            B, T, C = x.shape
            query = self.query(x)
            key = self.key(x)
            value = self.value(x)
            qk = query @ key.transpose(-2, -1) * C**-0.5
            attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
            attention = torch.nn.functional.softmax(input=attention, dim=-1)
            out = attention @ value
            return out

    class MultiAttentionBlock(torch.nn.Module):
        def __init__(
            self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
        ):
            super().__init__()
            self.attention = torch.nn.ModuleList(
                modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
            )
            self.linear = torch.nn.Linear(
                in_features=embedding_dim * num_heads, out_features=embedding_dim
            )

        def forward(self, x):
            out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
            x = self.linear(out)
            return x

    class FeedForward(torch.nn.Module):
        def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
            super().__init__()
            self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
            self.relu = torch.nn.ReLU()
            self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)

        def forward(self, x):
            x = self.linear_1(x)
            x = self.relu(x)
            x = self.linear_2(x)
            return x

    class DecoderLayer(torch.nn.Module):
        def __init__(
            self,
            embedding_dim: int = 16,
            num_heads: int = 2,
            context_size: int = 256,
            ff_dim: int = 128,
        ):
            super().__init__()
            self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
            self.feed_forward = FeedForward(embedding_dim, ff_dim)
            self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
            self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)

        def forward(self, x):
            x_norm = self.norm_1(x)
            attention = self.attention(x_norm)
            attention = attention + x
            attention_norm = self.norm_2(attention)
            ff = self.feed_forward(attention_norm)
            ff = ff + attention
            return ff

    class LLM(torch.nn.Module):
        def __init__(
            self,
            vocab_size: int = 1024,
            embedding_dim: int = 16,
            num_heads: int = 2,
            context_size: int = 256,
            ff_dim: int = 128,
        ):
            super().__init__()
            self.embedding = Embedding(vocab_size, embedding_dim)
            self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)

        def forward(self, input_ids):
            x = self.embedding(input_ids)
            y = self.decoder(x)
            return y

    if cls_name in (None, "LLM"):
        dec: torch.nn.Module = LLM()
        x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "input_ids": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "DecoderLayer":
        LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
        dec = DecoderLayer()
        x = Embedding()(
            torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
        )
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "MultiAttentionBlock":
        dec = MultiAttentionBlock()
        x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    if cls_name == "AttentionBlock":
        dec = AttentionBlock()
        x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
        dec(x)
        if dynamic_shapes:
            dyn = {
                "x": {
                    0: torch.export.Dim("batch", min=1, max=1024),
                    1: torch.export.Dim("length", min=1, max=255),
                }
            }
            return dec, (x,), dyn
        return dec, (x,)

    raise NotImplementedError(f"cls_name={cls_name}")
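A hedged sketch of how the returned triple might feed ``torch.export.export``
when ``dynamic_shapes=True``; the dynamic shapes are keyed by the forward
argument name, which ``torch.export`` accepts.

import torch

from onnx_diagnostic.helpers.torch_test_helper import dummy_llm

model, inputs, ds = dummy_llm(dynamic_shapes=True)
ep = torch.export.export(model, inputs, dynamic_shapes=ds)
print(ep.graph)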
def to_any(value: Any, to_value: Union[torch.dtype, torch.device]) -> Any:
    """Applies ``torch.Tensor.to`` if applicable. Goes recursively."""
    if isinstance(value, (torch.nn.Module, torch.Tensor)):
        return value.to(to_value)
    if isinstance(value, list):
        return [to_any(t, to_value) for t in value]
    if isinstance(value, tuple):
        return tuple(to_any(t, to_value) for t in value)
    if isinstance(value, set):
        return {to_any(t, to_value) for t in value}
    if isinstance(value, dict):
        return {k: to_any(t, to_value) for k, t in value.items()}
    if hasattr(value, "to"):
        return value.to(to_value)
    if value.__class__.__name__ == "DynamicCache":
        return make_dynamic_cache(
            list(
                zip(
                    [t.to(to_value) for t in value.key_cache],
                    [t.to(to_value) for t in value.value_cache],
                )
            )
        )
    assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
    return value
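A hedged sketch: ``to_any`` walks nested containers and applies ``.to(...)`` to
anything that supports it, leaving other values (like the integer below) untouched.
The dictionary keys are illustrative only.

import torch

from onnx_diagnostic.helpers.torch_test_helper import to_any

batch = {"x": torch.randn(2, 3), "masks": [torch.ones(2, 3)], "n": 4}
batch_fp16 = to_any(batch, torch.float16)
print(batch_fp16["x"].dtype, batch_fp16["masks"][0].dtype, batch_fp16["n"])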
def torch_deepcopy(value: Any) -> Any:
    """Makes a deep copy."""
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, tuple):
        return tuple(torch_deepcopy(v) for v in value)
    if isinstance(value, list):
        return [torch_deepcopy(v) for v in value]
    if isinstance(value, set):
        return {torch_deepcopy(v) for v in value}
    if isinstance(value, dict):
        return {k: torch_deepcopy(v) for k, v in value.items()}
    if isinstance(value, np.ndarray):
        return value.copy()
    if hasattr(value, "clone"):
        return value.clone()
    if value.__class__.__name__ == "DynamicCache":
        return make_dynamic_cache(
            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
        )
    if value.__class__.__name__ == "EncoderDecoderCache":
        return make_encoder_decoder_cache(
            torch_deepcopy(value.self_attention_cache),
            torch_deepcopy(value.cross_attention_cache),
        )
    # Ideally this should rely on serialization/deserialization, assuming a model
    # cannot be exported without them.
    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
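A hedged sketch: unlike ``copy.deepcopy``, ``torch_deepcopy`` clones tensors
explicitly, so mutating the copy leaves the original untouched. The data below
is illustrative only.

import torch

from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy

data = {"weights": torch.zeros(3), "ids": [1, 2, 3]}
copied = torch_deepcopy(data)
copied["weights"] += 1
print(data["weights"], copied["weights"])  # the original stays at zero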