diff options
Diffstat (limited to 'contrib/lua-torch/torch7/lib/TH/THHalf.c')
-rw-r--r-- | contrib/lua-torch/torch7/lib/TH/THHalf.c | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/contrib/lua-torch/torch7/lib/TH/THHalf.c b/contrib/lua-torch/torch7/lib/TH/THHalf.c new file mode 100644 index 000000000..d7468ac3d --- /dev/null +++ b/contrib/lua-torch/torch7/lib/TH/THHalf.c @@ -0,0 +1,100 @@ +#include "THHalf.h" + +/* Copyright 1993-2014 NVIDIA Corporation. All rights reserved. */ + +THHalf TH_float2half(float f) +{ + THHalf h; + TH_float2halfbits(&f, &h.x); + return h; +} + +TH_API float TH_half2float(THHalf h) +{ + float f; + TH_halfbits2float(&h.x, &f); + return f; +} + +// Host functions for converting between FP32 and FP16 formats + +void TH_halfbits2float(unsigned short* src, float* res) +{ + unsigned h = *src; + unsigned sign = ((h >> 15) & 1); + unsigned exponent = ((h >> 10) & 0x1f); + unsigned mantissa = ((h & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + *(unsigned*)res = ((sign << 31) | (exponent << 23) | mantissa); +} + +void TH_float2halfbits(float* src, unsigned short* dest) +{ + unsigned x = *(unsigned*)src; + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + *dest = 0x7fffU; + return ; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + *dest = sign | 0x7c00U; + return; + } + if (u < 0x33000001) { + *dest = (sign | 0x0000); + return; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + *dest = (sign | (exponent << 10) | mantissa); +} |