1//===--- Float16bits.cpp - supports 2-byte floats ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements f16 and bf16 to support the compilation and execution
10// of programs using these types.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/ExecutionEngine/Float16bits.h"
15
16#ifdef MLIR_FLOAT16_DEFINE_FUNCTIONS // We are building this library
17
18#include <cmath>
19#include <cstring>
20
namespace {

// Bit-casting helpers. Copying the object representation with std::memcpy is
// the portable C++17 way to reinterpret a float as its IEEE-754 bit pattern and
// back; reading the inactive member of a union is undefined behavior in ISO
// C++ (compilers merely tolerate it as an extension). Fixed-size copies like
// these compile down to plain register moves.
static_assert(sizeof(float) == sizeof(uint32_t),
              "float must be 32 bits wide for the f16/bf16 conversions");

// Returns the 32-bit pattern that encodes `value`.
uint32_t floatToBits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Returns the float encoded by the 32-bit pattern `bits`.
float bitsToFloat(uint32_t bits) {
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Explicit mantissa bits of a binary32 float.
constexpr uint32_t kF32MantiBits = 23;
// Mantissa width difference between binary32 (23 bits) and binary16 (10 bits).
constexpr uint32_t kF32HalfMantiBitDiff = 13;
// Distance between the binary32 sign bit (31) and the binary16 sign bit (15).
constexpr uint32_t kF32HalfBitDiff = 16;
// binary32 encoding of 2^-14 (biased exponent 113 = 127 - 14), the smallest
// normal binary16 magnitude.
constexpr uint32_t kF32MagicBits = 113u << kF32MantiBits;
// Rebiases an exponent from the binary16 bias (15) to the binary32 bias (127).
constexpr uint32_t kF32HalfExpAdjust = (127u - 15u) << kF32MantiBits;

// Constructs the 16 bit representation for a half precision value from a float
// value, rounding to nearest with ties to even. Overflow and Inf map to Inf,
// any NaN maps to the quiet NaN 0x7e00; the sign is always preserved. This
// implementation is adapted from Eigen.
uint16_t float2half(float floatValue) {
  // binary32 +Inf.
  constexpr uint32_t kInfBits = 255u << kF32MantiBits;
  // binary32 encoding of 2^16. Magnitudes at or above it (and Inf/NaN) are
  // handled by the overflow branch; smaller values that round up past the
  // largest binary16 value reach Inf through the rounding carry below.
  constexpr uint32_t kF16MaxBits = (127u + 16u) << kF32MantiBits;
  // binary32 encoding of 2^-1. Its ulp is 2^-24, which is exactly the spacing
  // of binary16 subnormals.
  constexpr uint32_t kDenormMagicBits =
      ((127u - 15u) + (kF32MantiBits - 10u) + 1u) << kF32MantiBits;
  constexpr uint32_t kSignMask = 0x80000000u;
  constexpr uint16_t kHalfQNaN = 0x7e00;
  constexpr uint16_t kHalfInf = 0x7c00;

  uint32_t bits = floatToBits(floatValue);
  const uint32_t sign = bits & kSignMask;
  bits ^= sign; // Work on the magnitude only; the sign is re-applied at the end.

  uint16_t halfValue;
  if (bits >= kF16MaxBits) {
    // Inf or NaN (all exponent bits set). Anything above +Inf is a NaN.
    halfValue = (bits > kInfBits) ? kHalfQNaN : kHalfInf;
  } else if (bits < kF32MagicBits) {
    // The resulting FP16 is subnormal or zero.
    //
    // Adding 2^-1 aligns our 10 mantissa bits at the bottom of the float, so
    // the low bits of the sum count multiples of 2^-24. As long as FP addition
    // is round-to-nearest-even, this rounds correctly.
    const float aligned = bitsToFloat(bits) + bitsToFloat(kDenormMagicBits);
    halfValue = static_cast<uint16_t>(floatToBits(aligned) - kDenormMagicBits);
  } else {
    // Normal FP16 result. Remember whether the kept mantissa is odd so that
    // ties round to even.
    const uint32_t mantOdd = (bits >> kF32HalfMantiBitDiff) & 1u;
    // Update exponent, rounding bias part 1. Equivalent to
    // `bits += ((unsigned int)(15 - 127) << kF32MantiBits) + 0xfff`, but
    // without arithmetic overflow.
    bits += 0xc8000fffU;
    // Rounding bias part 2.
    bits += mantOdd;
    halfValue = static_cast<uint16_t>(bits >> kF32HalfMantiBitDiff);
  }

  halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
  return halfValue;
}

// Converts the 16 bit representation of a half precision value to a float
// value. The conversion is exact. This implementation is adapted from Eigen.
float half2float(uint16_t halfValue) {
  // binary16 exponent mask (0x7c00) moved to its binary32 bit position.
  constexpr uint32_t kShiftedExp = 0x7c00u << kF32HalfMantiBitDiff;

  // Place the binary16 exponent and mantissa bits at their binary32 positions.
  uint32_t bits = static_cast<uint32_t>(halfValue & 0x7fffu)
                  << kF32HalfMantiBitDiff;
  const uint32_t exp = bits & kShiftedExp;
  bits += kF32HalfExpAdjust; // Rebias the exponent from 15 to 127.

  if (exp == kShiftedExp) {
    // Inf/NaN: a second adjustment pushes the exponent field to 255.
    bits += kF32HalfExpAdjust;
  } else if (exp == 0) {
    // Zero/subnormal: raise the exponent so the value is 2^-14 * (1 + m/1024),
    // then subtract 2^-14 so FP subtraction renormalizes it to m * 2^-24.
    bits += 1u << kF32MantiBits;
    bits = floatToBits(bitsToFloat(bits) - bitsToFloat(kF32MagicBits));
  }

  // Sign bit; shifted as an unsigned value so it lands in bit 31 cleanly.
  bits |= static_cast<uint32_t>(halfValue & 0x8000u) << kF32HalfBitDiff;
  return bitsToFloat(bits);
}

// bfloat16 keeps the upper 16 bits of a binary32 value.
constexpr uint32_t kF32BfMantiBitDiff = 16;

// Constructs the 16 bit representation for a bfloat value from a float value,
// rounding to nearest with ties to even. NaN inputs become the quiet NaN
// 0x7FC0 with the input's sign. This implementation is adapted from Eigen.
uint16_t float2bfloat(float floatValue) {
  // Handled first: the rounding add below could carry a NaN with a small
  // payload into the Inf encoding.
  if (std::isnan(floatValue))
    return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;

  const uint32_t bits = floatToBits(floatValue);
  // Least significant bit of the resulting bfloat; adding it to 0x7fff makes
  // exact ties round to the even neighbour.
  const uint32_t lsb = (bits >> kF32BfMantiBitDiff) & 1u;
  const uint32_t rounded = bits + 0x7fffu + lsb;
  return static_cast<uint16_t>(rounded >> kF32BfMantiBitDiff);
}

// Converts the 16 bit representation of a bfloat value to a float value. The
// conversion is exact: the bfloat bits become the upper half of the float.
// This implementation is adapted from Eigen.
float bfloat2float(uint16_t bfloatBits) {
  return bitsToFloat(static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff);
}

} // namespace
138
139f16::f16(float f) : bits(float2half(floatValue: f)) {}
140
141bf16::bf16(float f) : bits(float2bfloat(floatValue: f)) {}
142
143std::ostream &operator<<(std::ostream &os, const f16 &f) {
144 os << half2float(halfValue: f.bits);
145 return os;
146}
147
148std::ostream &operator<<(std::ostream &os, const bf16 &d) {
149 os << bfloat2float(bfloatBits: d.bits);
150 return os;
151}
152
153bool operator==(const f16 &f1, const f16 &f2) { return f1.bits == f2.bits; }
154
155bool operator==(const bf16 &f1, const bf16 &f2) { return f1.bits == f2.bits; }
156
// Mark these symbols as weak so they don't conflict when compiler-rt also
// defines them.
//
// ATTR_WEAK expands to nothing by default. It becomes __attribute__((__weak__))
// only when the compiler provides __has_attribute, reports support for `weak`,
// and none of __MINGW32__, __CYGWIN__ or _WIN32 is defined.
#define ATTR_WEAK
#ifdef __has_attribute
// Kept as a separate, nested #if: a preprocessor without __has_attribute could
// not evaluate the __has_attribute(weak) expression below.
#if __has_attribute(weak) && !defined(__MINGW32__) && !defined(__CYGWIN__) && \
    !defined(_WIN32)
#undef ATTR_WEAK
#define ATTR_WEAK __attribute__((__weak__))
#endif
#endif
167
#if !defined(__x86_64__) && !defined(_M_X64)
// Generic targets: with no more specific convention available, the bfloat16
// bit pattern is exchanged as a plain 16-bit integer.
using BF16ABIType = uint16_t;
#else
// x86-64: bfloat16 values travel in SSE registers, the same registers used for
// float. Returning a float whose low-addressed bytes hold the bfloat16 bits
// therefore matches the x86_64 psABI, under the assumption of little-endian
// values passed in wider registers. Using __bf16 directly would be cleaner, but
// not every compiler supports that type.
using BF16ABIType = float;
#endif
180
181// Provide a float->bfloat conversion routine in case the runtime doesn't have
182// one.
183extern "C" BF16ABIType ATTR_WEAK __truncsfbf2(float f) {
184 uint16_t bf = float2bfloat(floatValue: f);
185 // The output can be a float type, bitcast it from uint16_t.
186 BF16ABIType ret = 0;
187 std::memcpy(dest: &ret, src: &bf, n: sizeof(bf));
188 return ret;
189}
190
191// Provide a double->bfloat conversion routine in case the runtime doesn't have
192// one.
193extern "C" BF16ABIType ATTR_WEAK __truncdfbf2(double d) {
194 // This does a double rounding step, but it's precise enough for our use
195 // cases.
196 return __truncsfbf2(f: static_cast<float>(d));
197}
198
199// Provide these to the CRunner with the local float16 knowledge.
200extern "C" void printF16(uint16_t bits) {
201 f16 f;
202 std::memcpy(dest: &f, src: &bits, n: sizeof(f16));
203 std::cout << f;
204}
205extern "C" void printBF16(uint16_t bits) {
206 bf16 f;
207 std::memcpy(dest: &f, src: &bits, n: sizeof(bf16));
208 std::cout << f;
209}
210
211#endif // MLIR_FLOAT16_DEFINE_FUNCTIONS
212

source code of mlir/lib/ExecutionEngine/Float16bits.cpp