"""
More complex helpers used in unit tests.
"""
import contextlib
import io
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from onnx import ModelProto, save
from .helpers import pretty_onnx
[docs]
def check_model_ort(
    onx: Union[str, ModelProto],
    providers: Optional[Union[str, List[str]]] = None,
    dump_file: Optional[str] = None,
) -> "onnxruntime.InferenceSession":  # noqa: F821
    """
    Loads a model with onnxruntime.

    :param onx: ModelProto or the path to a model saved on disk
    :param providers: list of providers, None or ``"cpu"`` for CPU,
        ``"cuda"`` or ``"cuda:<device_id>"`` for CUDA (with CPU as a fallback),
        any other value is given to onnxruntime unchanged
    :param dump_file: if not empty, dumps the model into this file if
        an error happened
    :return: InferenceSession
    :raises AssertionError: if onnxruntime fails to load the model,
        the message contains the original error and a rendering of the model
    """
    from onnxruntime import InferenceSession

    if providers is None or providers == "cpu":
        providers = ["CPUExecutionProvider"]
    elif isinstance(providers, str) and providers.startswith("cuda"):
        # "cuda" selects device 0, "cuda:<n>" selects device n.
        device_id = 0 if ":" not in providers else int(providers.split(":")[1])
        providers = [
            ("CUDAExecutionProvider", {"device_id": device_id}),
            ("CPUExecutionProvider", {}),
        ]

    if isinstance(onx, str):
        try:
            return InferenceSession(onx, providers=providers)
        except Exception as e:
            import onnx

            # ``onx`` is a path here: the model must be loaded before it can be
            # dumped or rendered, ``save`` does not accept a file name as a model.
            proto = onnx.load(onx)
            if dump_file:
                save(proto, dump_file)
            raise AssertionError(
                f"onnxruntime cannot load the model "
                f"due to {e}\n{pretty_onnx(proto)}"
            ) from e

    try:
        return InferenceSession(onx.SerializeToString(), providers=providers)
    except Exception as e:
        if dump_file:
            save(onx, dump_file)
        raise AssertionError(
            f"onnxruntime cannot load the model due to {e}\n{pretty_onnx(onx)}"
        ) from e
[docs]
def export_to_onnx(
    model: Any,
    *args: Any,
    verbose: int = 0,
    return_builder: bool = False,
    torch_script: bool = True,
    target_opset: int = 18,
    prefix: Optional[str] = None,
    rename_inputs: bool = False,
    optimize: Union[str, bool] = True,
    folder: Optional[str] = "dump_test",
    export_options: Optional["ExportOptions"] = None,  # noqa: F821
) -> Dict[str, Union[str, ModelProto, "GraphBuilder"]]:  # noqa: F821
    """
    Exports a model to ONNX.

    :param model: model to export
    :param args: arguments
    :param verbose: verbosity
    :param return_builder: returns the builder
    :param torch_script: export with torch.script as well,
        only done if *prefix* is specified
    :param target_opset: opset to export into
    :param prefix: prefix to choose to export into,
        nothing is written on disk if it is None
    :param rename_inputs: rename the inputs into ``input{i}``
    :param optimize: enable, disable optimizations of pattern to test,
        a string is given to ``OptimizationOptions`` as *patterns*
    :param folder: where to dump the model produced by the custom exporter,
        creates it if it does not exist
    :param export_options: see :class:`ExportOptions
        <experimental_experiment.torch_interpreter.ExportOptions>`
    :return: dictionary with ModelProto, builder, filenames,
        key ``"proto"`` holds what ``to_onnx`` returned,
        ``"torch.script"`` and ``"custom"`` hold the written filenames
    """
    from .xbuilder import OptimizationOptions
    from .torch_interpreter import to_onnx

    ret = {}
    if torch_script and prefix is not None:
        import torch

        filename = f"{prefix}.onnx"
        # The torch exporter is verbose, its output and warnings are silenced.
        with contextlib.redirect_stdout(io.StringIO()), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            torch.onnx.export(model, args, filename, input_names=["input"])
        ret["torch.script"] = filename

    if isinstance(optimize, str):
        options = OptimizationOptions(verbose=verbose, patterns=optimize)
    else:
        options = OptimizationOptions(verbose=verbose)

    onx = to_onnx(
        model,
        tuple(args),
        input_names=[f"input{i}" for i in range(len(args))] if rename_inputs else None,
        options=options,
        verbose=verbose,
        return_builder=return_builder,
        optimize=optimize,
        export_options=export_options,
    )
    ret["proto"] = onx

    if prefix is not None:
        filename = f"{prefix}.custom.onnx"
        if folder:
            # exist_ok avoids the race between checking and creating the folder.
            os.makedirs(folder, exist_ok=True)
            filename = os.path.join(folder, filename)
        # With return_builder=True, the ModelProto is read from the first item.
        proto = onx[0] if return_builder else onx
        with open(filename, "wb") as f:
            f.write(proto.SerializeToString())
        ret["custom"] = filename
    return ret
[docs]
def dummy_llm(
cls_name: Optional[str] = None,
dynamic_shapes: bool = False,
) -> Tuple["torch.nn.Module", Tuple["torch.Tensor", ...]]: # noqa: F821
"""
Creates a dummy LLM for test purposes.
:param cls_name: None for whole model or a piece of it
:param dynamic_shapes: returns dynamic shapes as well
.. runpython::
:showcode:
from experimental_experiment.torch_test_helper import dummy_llm
print(dummy_llm())
"""
import torch
class Embedding(torch.nn.Module):
def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
super().__init__()
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
def forward(self, x):
word_emb = self.embedding(x)
word_pe = self.pe(x)
return word_emb + word_pe
class AttentionBlock(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, context_size: int = 256):
super().__init__()
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
# torch.nn.Buffer are not fully handled by symbolic tracing
# Buffer(...)[:Prowy()] is not working
self.mask = torch.nn.Parameter(
torch.tril(
input=torch.ones(size=[context_size, context_size], dtype=torch.float)
)
)
def forward(self, x):
B, T, C = x.shape
query = self.query(x)
key = self.key(x)
value = self.value(x)
qk = query @ key.transpose(-2, -1) * C**-0.5
attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
attention = torch.nn.functional.softmax(input=attention, dim=-1)
out = attention @ value
return out
class MultiAttentionBlock(torch.nn.Module):
def __init__(
self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
):
super().__init__()
self.attention = torch.nn.ModuleList(
modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
)
self.linear = torch.nn.Linear(
in_features=embedding_dim * num_heads, out_features=embedding_dim
)
def forward(self, x):
out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
x = self.linear(out)
return x
class FeedForward(torch.nn.Module):
def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
super().__init__()
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
self.relu = torch.nn.ReLU()
self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
def forward(self, x):
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class DecoderLayer(torch.nn.Module):
def __init__(
self,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
self.feed_forward = FeedForward(embedding_dim, ff_dim)
self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
def forward(self, x):
x_norm = self.norm_1(x)
attention = self.attention(x_norm)
attention = attention + x
attention_norm = self.norm_2(attention)
ff = self.feed_forward(attention_norm)
ff = ff + attention
return ff
class LLM(torch.nn.Module):
def __init__(
self,
vocab_size: int = 1024,
embedding_dim: int = 16,
num_heads: int = 2,
context_size: int = 256,
ff_dim: int = 128,
):
super().__init__()
self.embedding = Embedding(vocab_size, embedding_dim)
self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
def forward(self, input_ids):
x = self.embedding(input_ids)
y = self.decoder(x)
return y
if cls_name in (None, "LLM"):
dec = LLM()
x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
dec(x)
if dynamic_shapes:
dyn = {
"input_ids": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "DecoderLayer":
LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
dec = DecoderLayer()
x = Embedding()(
torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "MultiAttentionBlock":
dec = MultiAttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
if cls_name == "AttentionBlock":
dec = AttentionBlock()
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
dec(x)
if dynamic_shapes:
dyn = {
"x": {
0: torch.export.Dim("batch", min=1, max=1024),
1: torch.export.Dim("length", min=1, max=255),
}
}
return dec, (x,), dyn
return dec, (x,)
raise NotImplementedError(f"cls_name={cls_name}")