import os
from typing import Dict, List, Optional, Tuple
import onnx
import onnx.helper as oh
import torch
from ..reference.torch_ops import OpRunKernel, OpRunTensor
from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
from .ort_session import InferenceSessionForTorch
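
# Note (added comment): setting the environment variable DUMP_ONNX to a non-zero
# integer makes every onnxruntime session created below write its optimized model
# to disk; the generated file names are accumulated in _SAVED.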
_SAVED: List[str] = []
_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))

def _get_model_name(op_name: str, provider: str) -> Optional[str]:
    if _SAVE_OPTIMIZED_MODEL_:
        name = f"dump_doc_{op_name}_{provider}_{len(_SAVED)}.onnx"
        _SAVED.append(name)
        return name
    return None

class LayerNormalizationOrt(OpRunKernel):
"LayerNormalization with onnxruntime"
    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.axis = self.get_attribute_int(node, "axis", -1)
        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
        self.device = device
        self.stash_type = onnx_dtype_to_torch_dtype(
            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
        )
        self.compute_std = len(node.output) > 1
        assert not self.compute_std, (
            f"This kernel implementation only works when a single output "
            f"is required but {node.output} were requested."
        )
        # one onnxruntime session per (element type, rank) pair
        self._cache: Dict[Tuple[int, int], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(self, itype: int, rank: int, has_bias: bool) -> InferenceSessionForTorch:
        # dynamic shape: all leading dimensions stay free, the last one is normalized
        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
        layer_model = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "LayerNormalization",
                        ["X", "W", "B"] if has_bias else ["X", "W"],
                        ["Z"],
                        axis=self.axis,
                        epsilon=self.epsilon,
                    )
                ],
                "dummy",
                (
                    [
                        oh.make_tensor_value_info("X", itype, shape),
                        oh.make_tensor_value_info("W", itype, ["last"]),
                        oh.make_tensor_value_info("B", itype, ["last"]),
                    ]
                    if has_bias
                    else [
                        oh.make_tensor_value_info("X", itype, shape),
                        oh.make_tensor_value_info("W", itype, ["last"]),
                    ]
                ),
                [oh.make_tensor_value_info("Z", itype, shape)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            layer_model,
            optimized_model_filepath=_get_model_name("layer_norm", provider),
            providers=[provider],
        )

    def run(self, x, scale, bias=None):
        itype = torch_dtype_to_onnx_dtype(x.dtype)
        rank = len(x.shape)
        key = itype, rank
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, rank, bias is not None)
        sess = self._cache[key]
        if self.verbose:
            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
        feeds = dict(X=x.tensor, W=scale.tensor)
        if bias is not None:
            feeds["B"] = bias.tensor
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got)
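

# Illustrative sketch (not part of the original module): how LayerNormalizationOrt
# could be exercised on its own. It assumes OpRunTensor can be constructed directly
# from a plain torch tensor; that constructor is an assumption made for this example,
# not a documented API.
def _example_layer_normalization_ort() -> torch.Tensor:
    node = oh.make_node("LayerNormalization", ["X", "W"], ["Z"], axis=-1, epsilon=1e-5)
    kernel = LayerNormalizationOrt(node, device=torch.device("cpu"))
    x = OpRunTensor(torch.randn(2, 4, 8, dtype=torch.float32))
    scale = OpRunTensor(torch.randn(8, dtype=torch.float32))
    # the first call builds and caches an onnxruntime session for (float32, rank 3)
    return kernel.run(x, scale).tensor
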
class MatMulOrt(OpRunKernel):
"MatMul with onnxruntime"
    @classmethod
    def device_dependent(cls) -> bool:
        "Needs device."
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version=None,
        device: Optional[torch.device] = None,
        verbose: int = 0,
    ):
        super().__init__(node, version, verbose=verbose)
        self.device = device
        # one onnxruntime session per (element type, rank of A, rank of B)
        self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
        self.is_cpu = torch.device("cpu") == self.device

    def _make_model(self, itype: int, ranka: int, rankb: int) -> InferenceSessionForTorch:
        # fully dynamic shapes, only the ranks are fixed
        shapea = [f"a{i}" for i in range(ranka)]
        shapeb = [f"b{i}" for i in range(rankb)]
        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
        model = oh.make_model(
            oh.make_graph(
                [oh.make_node("MatMul", ["A", "B"], ["C"])],
                "dummy",
                [
                    oh.make_tensor_value_info("A", itype, shapea),
                    oh.make_tensor_value_info("B", itype, shapeb),
                ],
                [oh.make_tensor_value_info("C", itype, shapec)],
            ),
            ir_version=9,
            opset_imports=[oh.make_opsetid("", 18)],
        )
        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
        self._provider = provider
        return InferenceSessionForTorch(
            model,
            optimized_model_filepath=_get_model_name("matmul", provider),
            providers=[provider],
        )

    def run(self, a, b):
        itype = torch_dtype_to_onnx_dtype(a.dtype)
        ranka, rankb = len(a.shape), len(b.shape)
        key = itype, ranka, rankb
        if key not in self._cache:
            self._cache[key] = self._make_model(itype, ranka, rankb)
        sess = self._cache[key]
        if self.verbose:
            print(f"[MatMulOrt] running on {self._provider!r}")
        feeds = dict(A=a.tensor, B=b.tensor)
        got = sess.run(None, feeds)[0]
        return OpRunTensor(got)
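

# Illustrative sketch (not part of the original module): running MatMulOrt directly.
# As in the example above, wrapping torch tensors with OpRunTensor is assumed rather
# than taken from a documented API.
def _example_matmul_ort() -> torch.Tensor:
    node = oh.make_node("MatMul", ["A", "B"], ["C"])
    kernel = MatMulOrt(node, device=torch.device("cpu"))
    a = OpRunTensor(torch.randn(3, 4, dtype=torch.float32))
    b = OpRunTensor(torch.randn(4, 5, dtype=torch.float32))
    # the session is created on the first call and cached for (float32, rank 2, rank 2)
    return kernel.run(a, b).tensor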