Source code for onnx_extended.ortops.tutorial.cuda
import os
import textwrap
from typing import List
from ... import _get_ort_ext_libs
def get_ort_ext_libs() -> List[str]:
    """
    Lists the libraries found next to this file which implement
    simple :epkg:`onnxruntime` kernels for the
    :epkg:`CUDAExecutionProvider`.

    The candidates are whatever ``_get_ort_ext_libs`` returns for the
    folder holding this file; every entry containing the substring
    ``cuda_cuda`` is left out of the result.
    """
    this_folder = os.path.dirname(__file__)
    return [
        name for name in _get_ort_ext_libs(this_folder) if "cuda_cuda" not in name
    ]
def documentation() -> List[str]:
    """
    Returns a list of rst strings, one for every kernel implemented
    in this subfolder.

    Every entry goes through :func:`textwrap.dedent`, which lets the
    literals below stay indented with the code. Blank lines separate
    the rst blocks: without them the bold labels and the bullet items
    are folded into a single paragraph instead of being rendered
    as lists.
    """
    # NOTE(review): the inputs section lists D and E although the formula
    # only involves A, B and C -- confirm against the kernel implementation.
    kernels = [
        """
        onnx_extended.ortops.tutorial.cuda.CustomGemm
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        It calls CUDA library for Gemm :math:`\\alpha A B + \\beta C`.

        **Provider**

        CUDAExecutionProvider

        **Inputs**

        * A (T): tensor of type T
        * B (T): tensor of type T
        * C (T): tensor of type T
        * D (T): tensor of type T
        * E (T): tensor of type T

        **Outputs**

        * Z (T): :math:`\\alpha A B + \\beta C`

        **Constraints**

        * T: float, float16, bfloat16
        """
    ]
    return [textwrap.dedent(text) for text in kernels]