Source code for yobx.helpers.torch_helper

import numpy as np
import onnx
import torch
from .onnx_helper import onnx_dtype_name

_TYPENAME = dict(
    FLOAT=onnx.TensorProto.FLOAT,
    INT64=onnx.TensorProto.INT64,
    INT32=onnx.TensorProto.INT32,
    FLOAT16=onnx.TensorProto.FLOAT16,
    BFLOAT16=onnx.TensorProto.BFLOAT16,
)


[docs] def onnx_dtype_to_torch_dtype(itype: int) -> torch.dtype: """ Converts an onnx type into a torch dtype. :param to: onnx dtype :return: torch dtype """ if itype == onnx.TensorProto.FLOAT: return torch.float32 if itype == onnx.TensorProto.FLOAT16: return torch.float16 if itype == onnx.TensorProto.BFLOAT16: return torch.bfloat16 if itype == onnx.TensorProto.DOUBLE: return torch.float64 if itype == onnx.TensorProto.INT32: return torch.int32 if itype == onnx.TensorProto.INT64: return torch.int64 if itype == onnx.TensorProto.UINT32: return torch.uint32 if itype == onnx.TensorProto.UINT64: return torch.uint64 if itype == onnx.TensorProto.BOOL: return torch.bool if itype == onnx.TensorProto.INT16: return torch.int16 if itype == onnx.TensorProto.UINT16: return torch.uint16 if itype == onnx.TensorProto.INT8: return torch.int8 if itype == onnx.TensorProto.UINT8: return torch.uint8 if itype == onnx.TensorProto.COMPLEX64: return torch.complex64 if itype == onnx.TensorProto.COMPLEX128: return torch.complex128 raise NotImplementedError( f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type." )
[docs] def torch_dtype_to_onnx_dtype(to: torch.dtype) -> int: """ Converts a torch dtype into a onnx element type. :param to: torch dtype :return: onnx type """ import torch if to == torch.float32: return onnx.TensorProto.FLOAT if to == torch.float16: return onnx.TensorProto.FLOAT16 if to == torch.bfloat16: return onnx.TensorProto.BFLOAT16 if to == torch.float64: return onnx.TensorProto.DOUBLE if to == torch.int64: return onnx.TensorProto.INT64 if to == torch.int32: return onnx.TensorProto.INT32 if to == torch.uint64: return onnx.TensorProto.UINT64 if to == torch.uint32: return onnx.TensorProto.UINT32 if to == torch.bool: return onnx.TensorProto.BOOL if to == torch.SymInt: return onnx.TensorProto.INT64 if to == torch.int16: return onnx.TensorProto.INT16 if to == torch.uint16: return onnx.TensorProto.UINT16 if to == torch.int8: return onnx.TensorProto.INT8 if to == torch.uint8: return onnx.TensorProto.UINT8 if to == torch.SymFloat: return onnx.TensorProto.FLOAT if to == torch.complex64: return onnx.TensorProto.COMPLEX64 if to == torch.complex128: return onnx.TensorProto.COMPLEX128 # SymbolicTensor sto = str(to) if sto in _TYPENAME: return _TYPENAME[sto] raise NotImplementedError(f"Unable to convert torch dtype {to!r} ({type(to)}) to onnx dtype.")
[docs] def to_numpy(tensor: torch.Tensor) -> np.ndarray: """Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`.""" try: return tensor.detach().cpu().numpy() except TypeError: # We try with ml_dtypes pass import ml_dtypes conv = {torch.bfloat16: ml_dtypes.bfloat16} assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}" return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])