Source code for onnx_array_api.validation.f8

import struct
import numpy

# display functions


class UndefinedCastError(FloatingPointError):
    """
    Unable to cast a number.
    """
def display_int(ival, sign=1, exponent=8, mantissa=23):
    """
    Renders the binary digits of an integer as three groups
    (sign, exponent, mantissa) separated by dots.

    :param ival: integer holding the bits to display
    :param sign: number of bits for the sign
    :param exponent: number of bits for the exponent
    :param mantissa: number of bits for the mantissa
    :return: string ``"<sign>.<exponent>.<mantissa>"``
    """
    width = sign + exponent + mantissa
    # Left-pad with zeros up to the full width; longer strings are kept as is.
    bits = bin(ival)[2:].rjust(width, "0")
    exponent_end = sign + exponent
    return f"{bits[:sign]}.{bits[sign:exponent_end]}.{bits[exponent_end:]}"
def display_float32(value, sign=1, exponent=8, mantissa=23):
    """
    Displays the bits of a float32 as ``sign.exponent.mantissa``.

    :param value: value to display (converted to float32)
    :param sign: number of bits for the sign
    :param exponent: number of bits for the exponent
    :param mantissa: number of bits for the mantissa
    :return: string
    """
    # Serialise as little-endian float32, then read the same bytes as an int.
    raw = struct.pack("<f", numpy.float32(value))
    as_int = int.from_bytes(raw, "little")
    return display_int(as_int, sign=sign, exponent=exponent, mantissa=mantissa)
def display_float16(value, sign=1, exponent=5, mantissa=10):
    """
    Displays the bits of a float16 as ``sign.exponent.mantissa``.

    :param value: value to display (converted to float16)
    :param sign: number of bits for the sign
    :param exponent: number of bits for the exponent
    :param mantissa: number of bits for the mantissa
    :return: string
    """
    # Reinterpret the half-precision value as an unsigned 16-bit integer.
    as_int = numpy.float16(value).view("H")
    return display_int(as_int, sign=sign, exponent=exponent, mantissa=mantissa)
def display_fexmx(value, sign, exponent, mantissa):
    """
    Displays any float encoded as an integer with *sign* bits for the sign,
    *exponent* bits for the exponent and *mantissa* bits for the mantissa.

    :param value: value to display (int)
    :param sign: number of bits for the sign
    :param exponent: number of bits for the exponent
    :param mantissa: number of bits for the mantissa
    :return: string
    """
    layout = dict(sign=sign, exponent=exponent, mantissa=mantissa)
    return display_int(value, **layout)
def display_fe4m3(value, sign=1, exponent=4, mantissa=3):
    """
    Displays a float 8 E4M3 in binary format.

    The layout is always 1 sign bit, 4 exponent bits and 3 mantissa bits:
    the *sign*, *exponent* and *mantissa* arguments are accepted but not used.

    :param value: value to display (int)
    :param sign: number of bits for the sign (ignored)
    :param exponent: number of bits for the exponent (ignored)
    :param mantissa: number of bits for the mantissa (ignored)
    :return: string
    """
    return display_fexmx(value, 1, 4, 3)
def display_fe5m2(value, sign=1, exponent=5, mantissa=2):
    """
    Displays a float 8 E5M2 in binary format.

    The layout is always 1 sign bit, 5 exponent bits and 2 mantissa bits:
    the *sign*, *exponent* and *mantissa* arguments are accepted but not used.
    Their defaults now describe the E5M2 layout (they previously advertised
    the E4M3 one, which the function never applied).

    :param value: value to display (int)
    :param sign: number of bits for the sign (ignored)
    :param exponent: number of bits for the exponent (ignored)
    :param mantissa: number of bits for the mantissa (ignored)
    :return: string
    """
    return display_fexmx(value, sign=1, exponent=5, mantissa=2)
# cast from float 8 to float 32
def fe4m3_to_float32_float(ival: int, fn: bool = True, uz: bool = False) -> float:
    """
    Casts a float 8 E4M3 encoded as an integer into a float, using
    floating point arithmetic (``mantissa * 2 ** exponent``).

    :param ival: byte
    :param fn: no infinite values
    :param uz: no negative zero
    :return: float (float 32)
    :raises NotImplementedError: if *fn* is False
    :raises ValueError: if *ival* is not in ``[0, 255]``
    """
    if not fn:
        raise NotImplementedError("fn=False is not implemented.")
    if ival < 0 or ival > 255:
        raise ValueError(f"{ival} is not a float8.")
    if ival == 0:
        return numpy.float32(0)

    # Special codes and exponent bias depend on the flavour.
    if uz:
        if ival == 0x80:
            return numpy.float32(numpy.nan)
        exponent_bias = 8
    else:
        if ival == 255:
            return numpy.float32(-numpy.nan)
        if ival == 127:
            return numpy.float32(numpy.nan)
        exponent_bias = 7

    sign = ival & 0x80
    expo = (ival & 0x78) >> 3
    mant = ival & 0x07
    if expo == 0:
        # Subnormal: no implicit leading one, exponent fixed to 1 - bias.
        powe = 1 - exponent_bias
        fraction = 0
    else:
        powe = expo - exponent_bias
        fraction = 1
    fval = float(mant / 8 + fraction) * 2.0**powe
    if sign:
        # Also turns code 0x80 (uz=False) into negative zero.
        fval = -fval
    return numpy.float32(fval)
def fe5m2_to_float32_float(ival: int, fn: bool = False, uz: bool = False) -> float:
    """
    Casts a float 8 E5M2 encoded as an integer into a float, using
    floating point arithmetic (``mantissa * 2 ** exponent``).

    Code ``0x80`` decodes to negative zero when *fn* and *uz* are False,
    and to NaN when both are True.

    :param ival: byte
    :param fn: no infinite values
    :param uz: no negative zero
    :return: float (float 32)
    :raises ValueError: if *ival* is not in ``[0, 255]``
    :raises NotImplementedError: if *fn* and *uz* differ
    """
    if ival < 0 or ival > 255:
        raise ValueError(f"{ival} is not a float8.")
    if fn and uz:
        if ival == 0x80:
            return numpy.float32(numpy.nan)
        exponent_bias = 16
    elif not fn and not uz:
        if ival in (255, 254, 253):
            return numpy.float32(-numpy.nan)
        if ival in (127, 126, 125):
            return numpy.float32(numpy.nan)
        if ival == 252:
            return -numpy.float32(numpy.inf)
        if ival == 124:
            return numpy.float32(numpy.inf)
        exponent_bias = 15
    else:
        raise NotImplementedError("fn and uz must be both True or False.")

    sign = ival & 0x80
    expo = (ival & 0x7C) >> 2
    mant = ival & 0x03
    if expo == 0:
        # Subnormal (or zero): no implicit leading one, exponent is 1 - bias.
        powe = 1 - exponent_bias
        fraction = 0
    else:
        powe = expo - exponent_bias
        fraction = 1
    fval = float(mant / 4 + fraction) * 2.0**powe
    if sign:
        # Zeros go through the same path so that the sign bit is preserved.
        fval = -fval
    return numpy.float32(fval)
def fe4m3_to_float32(ival: int, fn: bool = True, uz: bool = False) -> float:
    """
    Casts a float E4M3 encoded as an integer into a float by building
    the float32 bit pattern directly.

    :param ival: byte
    :param fn: no infinite values
    :param uz: no negative zero
    :return: float (float 32), always a ``numpy.float32``
    :raises NotImplementedError: if *fn* is False
    :raises ValueError: if *ival* is not in ``[0, 255]``
    """
    if not fn:
        raise NotImplementedError("fn=False is not implemented.")
    if ival < 0 or ival > 255:
        raise ValueError(f"{ival} is not a float8.")
    if uz:
        exponent_bias = 8
        if ival == 0x80:
            # Same type as every other return path (was a Python float).
            return numpy.float32(numpy.nan)
    else:
        exponent_bias = 7
        if ival == 255:
            return numpy.float32(-numpy.nan)
        if ival == 127:
            return numpy.float32(numpy.nan)

    expo = (ival & 0x78) >> 3
    mant = ival & 0x07
    sign = ival & 0x80
    res = sign << 24  # the sign moves from bit 7 to bit 31
    if expo == 0:
        if mant > 0:
            # Subnormal float8: renormalise by shifting the mantissa left
            # until its top bit (0x4) is set, which becomes the implicit
            # leading one of the float32. mant is in [1, 7], so this takes
            # at most two steps.
            expo = 0x7F - exponent_bias
            while mant & 0x4 == 0:
                mant = (mant & 0x3) << 1
                expo -= 1
            res |= (mant & 0x3) << 21
            res |= expo << 23
        # expo == 0 and mant == 0: signed zero, only the sign bit is set.
    else:
        res |= mant << 20
        expo += 0x7F - exponent_bias  # rebias to the float32 exponent
        res |= expo << 23
    f = numpy.uint32(res).view(numpy.float32)  # pylint: disable=E1121
    return f
def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float:
    """
    Casts a float E5M2 encoded as an integer into a float by building
    the float32 bit pattern directly.

    :param ival: byte
    :param fn: no infinite values
    :param uz: no negative zero
    :return: float (float 32)
    """
    if not 0 <= ival <= 255:
        raise ValueError(f"{ival} is not a float8.")

    if fn and uz:
        if ival == 0x80:
            return numpy.float32(numpy.nan)
        exponent_bias = 16
    elif not fn and not uz:
        if ival in (253, 254, 255):
            return numpy.float32(-numpy.nan)
        if ival in (125, 126, 127):
            return numpy.float32(numpy.nan)
        if ival == 252:
            return numpy.float32(-numpy.inf)
        if ival == 124:
            return numpy.float32(numpy.inf)
        exponent_bias = 15
    else:
        raise NotImplementedError("fn and uz must be both False or True.")

    # Offset converting a float8 biased exponent into a float32 one.
    rebias = 0x7F - exponent_bias
    bits = (ival & 0x80) << 24  # the sign moves from bit 7 to bit 31
    exponent_field = (ival & 0x7C) >> 2
    mantissa_field = ival & 0x03

    if exponent_field != 0:
        # Normal number: copy the mantissa to the top of the float32 one.
        bits |= mantissa_field << 21
        bits |= (exponent_field + rebias) << 23
    elif mantissa_field > 0:
        # Subnormal: shift once if needed so that the top bit (0x2)
        # becomes the implicit leading one.
        exponent_field = rebias
        if mantissa_field & 0x2 == 0:
            mantissa_field = (mantissa_field & 0x1) << 1
            exponent_field -= 1
        bits |= (mantissa_field & 0x1) << 22
        bits |= exponent_field << 23

    return numpy.uint32(bits).view(numpy.float32)  # pylint: disable=E1121
# cast from float32 to float 8


class CastFloat8Sets:
    # Each table is a list of ``(float32 value, byte)`` pairs sorted by value,
    # built by decoding every byte except the NaN codes of the format.

    # E4M3FN: 127 and 255 are NaN.
    values_e4m3fn = sorted(
        (fe4m3_to_float32_float(code), code)
        for code in range(256)
        if code not in {127, 255}
    )

    # E4M3FNUZ: 0x80 is NaN.
    values_e4m3fnuz = sorted(
        (fe4m3_to_float32_float(code, uz=True), code)
        for code in range(256)
        if code != 0x80
    )

    # E5M2: 125-127 and 253-255 are NaN (infinities are kept).
    values_e5m2 = sorted(
        (fe5m2_to_float32_float(code), code)
        for code in range(256)
        if code not in {125, 126, 127, 253, 254, 255}
    )

    # E5M2FNUZ: 0x80 is NaN.
    values_e5m2fnuz = sorted(
        (fe5m2_to_float32_float(code, fn=True, uz=True), code)
        for code in range(256)
        if code != 0x80
    )
class CastFloat8(CastFloat8Sets):
    """
    Helpers to cast float8 into float32 or the other way around.
    """

    # Largest finite ``(float32 value, byte)`` pair of every table.
    values_e4m3fn_max_value = max(
        pair for pair in CastFloat8Sets.values_e4m3fn if numpy.isfinite(pair[0])
    )
    values_e4m3fnuz_max_value = max(
        pair for pair in CastFloat8Sets.values_e4m3fnuz if numpy.isfinite(pair[0])
    )
    values_e5m2_max_value = max(
        pair for pair in CastFloat8Sets.values_e5m2 if numpy.isfinite(pair[0])
    )
    values_e5m2fnuz_max_value = max(
        pair for pair in CastFloat8Sets.values_e5m2fnuz if numpy.isfinite(pair[0])
    )

    @staticmethod
    def find_closest_value(value, sorted_values):
        """
        Searches a value into a sorted array of values.

        :param value: float32 value to search
        :param sorted_values: list of tuple `[(float 32, byte)]`
        :return: byte

        The function looks in the first column for the closest value
        and returns the matching entry of the second column. When *value*
        is exactly halfway between two entries, the entry whose byte
        is even is returned.
        """
        size = len(sorted_values)
        lo, hi = 0, size
        # Bisection: on exit, *lo* and *hi* bracket the value.
        while lo < hi:
            mid = (lo + hi) // 2
            pivot = sorted_values[mid][0]
            if value == pivot:
                return sorted_values[mid][1]
            if value < pivot:
                hi = mid
                continue
            if lo == mid:
                # The interval cannot shrink any further.
                break
            lo = mid

        # finds the closest one
        if hi >= size:
            return sorted_values[lo][1]

        gap_below = value - sorted_values[lo][0]
        gap_above = sorted_values[hi][0] - value
        if gap_below < gap_above:
            return sorted_values[lo][1]
        if gap_below == gap_above:
            # Applies rule tie to even
            code_below = sorted_values[lo][1]
            code_above = sorted_values[hi][1]
            if code_below & 1 == 1:
                return code_above
            return code_below
        return sorted_values[hi][1]
def search_float32_into_fe4m3(
    value: float, fn: bool = True, uz: bool = False, saturate: bool = True
) -> int:
    """
    Casts a float 32 into a float E4M3 by searching the closest value
    in the table of all representable values.

    :param value: float
    :param fn: no infinite values
    :param uz: no negative zero
    :param saturate: to convert out of range and infinities to max value if True
    :return: byte
    :raises NotImplementedError: if *fn* is False
    """
    if not fn:
        raise NotImplementedError("fn=False is not implemented.")
    b = int.from_bytes(struct.pack("<f", numpy.float32(value)), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if uz:
        # E4M3FNUZ has a single NaN code and it carries no sign.
        nan_code = 0x80
        set_values = CastFloat8.values_e4m3fnuz
        max_value = CastFloat8.values_e4m3fnuz_max_value
    else:
        nan_code = 0x7F | ret
        set_values = CastFloat8.values_e4m3fn
        max_value = CastFloat8.values_e4m3fn_max_value

    if numpy.isnan(value):
        return nan_code
    if numpy.isinf(value) and not saturate:
        # No infinity in E4M3: without saturation it becomes NaN.
        return nan_code
    # Out of range values, infinities included, saturate to the largest
    # finite value when *saturate* is True (same as float32_to_fe4m3).
    if value > max_value[0]:
        return max_value[1] if saturate else nan_code
    if value < -max_value[0]:
        return (max_value[1] | ret) if saturate else nan_code

    f = numpy.float32(value)
    i = CastFloat8.find_closest_value(f, set_values)
    # The table may hold the value under either sign (zeros): force the sign
    # of the input.
    return (i & 0x7F) | ret
def search_float32_into_fe5m2(
    value: float, fn: bool = False, uz: bool = False, saturate: bool = True
) -> int:
    """
    Casts a float 32 into a float E5M2 by searching the closest value
    in the table of all representable values.

    :param value: float
    :param fn: no infinite values
    :param uz: no negative zero
    :param saturate: to convert out of range and infinities to max value if True
    :return: byte
    """
    packed = struct.pack("<f", numpy.float32(value))
    ret = (int.from_bytes(packed, "little") & 0x80000000) >> 24  # sign

    if fn and uz:
        # E5M2FNUZ: a single unsigned code stands for NaN and overflow.
        nan_code = overflow_code = 0x80
        set_values = CastFloat8.values_e5m2fnuz
        max_value = CastFloat8.values_e5m2fnuz_max_value
    elif not fn and not uz:
        nan_code = 0x7F | ret
        overflow_code = 0x7C | ret  # signed infinity
        set_values = CastFloat8.values_e5m2
        max_value = CastFloat8.values_e5m2_max_value
    else:
        raise NotImplementedError("fn and uz must both True or False.")

    if numpy.isnan(value):
        return nan_code
    limit = max_value[0]
    if value > limit or value < -limit:
        # Out of range, infinities included.
        return (max_value[1] | ret) if saturate else overflow_code

    closest = CastFloat8.find_closest_value(numpy.float32(value), set_values)
    return (closest & 0x7F) | ret
def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True):
    """
    Converts a float32 into a float E4M3 by manipulating the bits
    of the float32 representation.

    :param x: numpy.float32
    :param fn: no infinite values
    :param uz: no negative zero
    :param saturate: to convert out of range and infinities to max value if True
    :return: byte
    :raises NotImplementedError: if *fn* is False
    """
    if not fn:
        raise NotImplementedError("fn=False is not implemented.")
    if not isinstance(x, numpy.float32):
        x = numpy.float32(x)
    b = int.from_bytes(struct.pack("<f", x), "little")
    ret = (b & 0x80000000) >> 24  # sign
    if uz:
        if (b & 0x7FFFFFFF) == 0x7F800000:
            # infinity
            if saturate:
                return ret | 127
            return 0x80
        if (b & 0x7F800000) == 0x7F800000:
            # nan
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 116:
            # Underflow (zeros and float32 subnormals included).
            # 0x80 is NaN in this format, so the sign must be dropped.
            ret = 0
        elif e < 120:
            # denormalized number
            ex = e - 119
            if ex >= -2:
                ret |= 1 << (2 + ex)
                ret |= m >> (21 - ex)
            elif m > 0:
                ret |= 1
            else:
                # Exactly half of the smallest subnormal: ties to even,
                # that is zero, which has no sign in this format.
                ret = 0
            mask = 1 << (20 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 135:
            # normalized number
            ex = e - 119  # 127 - 8
            if ex == 0:
                ret |= 0x4
                ret |= m >> 21
            else:
                ret |= ex << 3
                ret |= m >> 20
            if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
                if (ret & 0x7F) < 0x7F:
                    # rounding
                    ret += 1
                elif not saturate:
                    return 0x80
        elif saturate:
            ret |= 0x7F  # 01111111
        else:
            ret = 0x80
        return int(ret)
    else:
        if (b & 0x7FFFFFFF) == 0x7F800000:
            # infinity
            if saturate:
                return ret | 126
            return 0x7F | ret
        if (b & 0x7F800000) == 0x7F800000:
            # nan
            return 0x7F | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 117:
                # underflow: only the sign is kept (signed zero)
                pass
            elif e < 121:
                # denormalized number
                ex = e - 120
                if ex >= -2:
                    ret |= 1 << (2 + ex)
                    ret |= m >> (21 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (20 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 136:
                # normalized number
                ex = e - 120
                if ex == 0:
                    ret |= 0x4
                    ret |= m >> 21
                else:
                    ret |= ex << 3
                    ret |= m >> 20
                    if (ret & 0x7F) == 0x7F:
                        # 0x7F is NaN: fall back to the largest value
                        ret &= 0xFE
                if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
                    if (ret & 0x7F) < 0x7E:
                        # rounding
                        ret += 1
                    elif not saturate:
                        ret |= 0x7F
            elif saturate:
                ret |= 126  # 01111110
            else:
                ret |= 0x7F
        return int(ret)
def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = True):
    """
    Converts a float32 into a float E5M2 by manipulating the bits
    of the float32 representation.

    :param x: numpy.float32
    :param fn: no infinite values
    :param uz: no negative zero
    :param saturate: to convert out of range and infinities to max value if True
    :return: byte
    :raises NotImplementedError: if *fn* and *uz* differ
    """
    b = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
    ret = (b & 0x80000000) >> 24  # sign

    if fn and uz:
        if (b & 0x7FFFFFFF) == 0x7F800000:
            # inf
            if saturate:
                return ret | 0x7F
            return 0x80
        if (b & 0x7F800000) == 0x7F800000:
            # nan
            return 0x80
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e < 109:
            # Underflow (zeros and float32 subnormals included).
            # 0x80 is NaN in this format, so the sign must be dropped.
            ret = 0
        elif e < 112:
            # denormalized number
            ex = e - 111
            if ex >= -1:
                ret |= 1 << (1 + ex)
                ret |= m >> (22 - ex)
            elif m > 0:
                ret |= 1
            else:
                # Exactly half of the smallest subnormal: ties to even,
                # that is zero, which has no sign in this format.
                ret = 0
            mask = 1 << (21 - ex)
            if m & mask and (
                ret & 1
                or m & (mask - 1) > 0
                or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
            ):
                # rounding
                ret += 1
        elif e < 143:
            # normalized number
            ex = e - 111
            ret |= ex << 2
            ret |= m >> 21
            if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                if (ret & 0x7F) < 0x7F:
                    # rounding
                    ret += 1
                elif not saturate:
                    ret = 0x80
        elif saturate:
            # inf and nan were returned above, only finite overflow remains
            ret |= 0x7F  # last possible number
        else:
            ret = 0x80
        return int(ret)
    elif not fn and not uz:
        if (b & 0x7FFFFFFF) == 0x7F800000:
            # inf
            if saturate:
                return 0x7B | ret
            return 0x7C | ret
        if (b & 0x7F800000) == 0x7F800000:
            # nan
            return 0x7F | ret
        e = (b & 0x7F800000) >> 23  # exponent
        m = b & 0x007FFFFF  # mantissa

        if e != 0:
            if e < 110:
                # underflow: only the sign is kept (signed zero)
                pass
            elif e < 113:
                # denormalized number
                ex = e - 112
                if ex >= -1:
                    ret |= 1 << (1 + ex)
                    ret |= m >> (22 - ex)
                elif m > 0:
                    ret |= 1
                mask = 1 << (21 - ex)
                if m & mask and (
                    ret & 1
                    or m & (mask - 1) > 0
                    or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
                ):
                    # rounding
                    ret += 1
            elif e < 143:
                # normalized number
                ex = e - 112
                ret |= ex << 2
                ret |= m >> 21
                if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
                    if (ret & 0x7F) < 0x7B:
                        # rounding
                        ret += 1
                    elif saturate:
                        ret |= 0x7B
                    else:
                        # NOTE(review): ret already holds 0x7B here, so this
                        # yields 0x7F (NaN) rather than 0x7C (inf) - confirm
                        # this is intended.
                        ret |= 0x7C
            elif saturate:
                ret |= 0x7B
            else:
                ret |= 0x7C
        return int(ret)
    else:
        raise NotImplementedError("fn and uz must be both False or True.")