Project
Loading...
Searching...
No Matches
GPUORTFloat16.h
Go to the documentation of this file.
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// This code was created from:
5// - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_float16.h
6// - https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_cxx_api.h
7
8#include <stdint.h>
9#include <cmath>
10#include <cstring>
11#include <limits>
12
13namespace o2
14{
15
16namespace OrtDataType
17{
18
19namespace detail
20{
21
// Compile-time description of the target byte order. `native` is compared against
// `little` by the bfloat16 conversion routines to decide which half of a float to copy.
22enum class endian {
23#if defined(_WIN32)
   // _WIN32 builds are treated as little-endian unconditionally.
24 little = 0,
25 big = 1,
26 native = little,
27#elif defined(__GNUC__) || defined(__clang__)
   // GCC/Clang: take the values from the compiler's predefined byte-order macros.
28 little = __ORDER_LITTLE_ENDIAN__,
29 big = __ORDER_BIG_ENDIAN__,
30 native = __BYTE_ORDER__,
31#else
   // Any other toolchain: stop the build rather than guess.
32#error OrtDataType::detail::endian is not implemented in this environment.
33#endif
34};
35
// Rejects byte orders that are neither little- nor big-endian (e.g. mixed-endian targets).
36static_assert(
37 endian::native == endian::little || endian::native == endian::big,
38 "Only little-endian or big-endian native byte orders are supported.");
39
40} // namespace detail
41
// Shared CRTP base for the half-precision type: stores the 16-bit pattern in `val` and provides
// the classification predicates and comparison operators. `Derived` must expose a static
// FromBits(uint16_t), which Abs() and Negate() call to build their results.
// NOTE(review): this is a documentation-browser extract; the class-head line for Float16Impl
// (between the template header and `protected:`) is not present in the listing.
45template <class Derived>
47 protected:
   // Converts a float to the 16-bit half-precision pattern (defined out of class).
53 constexpr static uint16_t ToUint16Impl(float v) noexcept;
54
   // Converts the stored half-precision pattern to float (defined out of class).
59 float ToFloatImpl() const noexcept;
60
   // Returns the stored bits with the sign bit cleared.
65 uint16_t AbsImpl() const noexcept
66 {
67 return static_cast<uint16_t>(val & ~kSignMask);
68 }
69
   // Returns the stored bits with the sign bit toggled; a NaN is returned unchanged.
74 uint16_t NegateImpl() const noexcept
75 {
76 return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
77 }
78
79 public:
80 // uint16_t special values
81 static constexpr uint16_t kSignMask = 0x8000U;
82 static constexpr uint16_t kBiasedExponentMask = 0x7C00U;
83 static constexpr uint16_t kPositiveInfinityBits = 0x7C00U;
84 static constexpr uint16_t kNegativeInfinityBits = 0xFC00U;
85 static constexpr uint16_t kPositiveQNaNBits = 0x7E00U;
86 static constexpr uint16_t kNegativeQNaNBits = 0xFE00U;
   // NOTE(review): 0x4170 decodes to 2.71875 in binary16 (exponent field 16, mantissa 0x170),
   // not the machine epsilon 2^-10 (0x1400) - confirm what this constant is meant to represent.
87 static constexpr uint16_t kEpsilonBits = 0x4170U;
88 static constexpr uint16_t kMinValueBits = 0xFBFFU; // Lowest (most negative) finite value, -65504
89 static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number
90 static constexpr uint16_t kOneBits = 0x3C00U;
91 static constexpr uint16_t kMinusOneBits = 0xBC00U;
92
   // Raw 16-bit pattern; zero-initialised, i.e. +0.0.
93 uint16_t val{0};
94
95 Float16Impl() = default;
96
   // True whenever the sign bit (bit 15) is set, so negative zero also reports true.
101 bool IsNegative() const noexcept
102 {
103 return static_cast<int16_t>(val) < 0;
104 }
105
   // Tests if the value is NaN.
   // NOTE(review): the return statement (listing line 112) is missing from this extract.
110 bool IsNaN() const noexcept
111 {
113 }
114
   // Tests if the value is finite.
   // NOTE(review): the return statement (listing line 121) is missing from this extract.
119 bool IsFinite() const noexcept
120 {
122 }
123
   // Exact bit match against +Inf.
128 bool IsPositiveInfinity() const noexcept
129 {
130 return val == kPositiveInfinityBits;
131 }
132
   // Exact bit match against -Inf.
137 bool IsNegativeInfinity() const noexcept
138 {
139 return val == kNegativeInfinityBits;
140 }
141
   // True for either infinity: the sign is stripped before comparing.
146 bool IsInfinity() const noexcept
147 {
148 return AbsImpl() == kPositiveInfinityBits;
149 }
150
   // True for +/-0 and for any pattern whose magnitude bits exceed the infinity pattern (NaN).
155 bool IsNaNOrZero() const noexcept
156 {
157 auto abs = AbsImpl();
158 return (abs == 0 || abs > kPositiveInfinityBits);
159 }
160
   // True for finite, non-zero values whose exponent field is non-zero.
165 bool IsNormal() const noexcept
166 {
167 auto abs = AbsImpl();
168 return (abs < kPositiveInfinityBits) // is finite
169 && (abs != 0) // is not zero
170 && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
171 }
172
   // True for finite, non-zero values whose exponent field is zero.
177 bool IsSubnormal() const noexcept
178 {
179 auto abs = AbsImpl();
180 return (abs < kPositiveInfinityBits) // is finite
181 && (abs != 0) // is not zero
182 && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
183 }
184
   // Returns a Derived built from the sign-cleared bits.
189 Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
190
   // Returns a Derived built from the sign-toggled bits (NaN is passed through unchanged).
195 Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
196
   // True only if both operands are zero of either sign: OR the patterns together, drop the
   // sign bit, and check that nothing remains.
205 static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept
206 {
207 return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
208 }
209
   // Bitwise equality, except that NaN never compares equal.
   // Because the comparison is on raw bits, +0 and -0 compare unequal here.
210 bool operator==(const Float16Impl& rhs) const noexcept
211 {
212 if (IsNaN() || rhs.IsNaN()) {
213 // IEEE defines that NaN is not equal to anything, including itself.
214 return false;
215 }
216 return val == rhs.val;
217 }
218
219 bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); }
220
   // Ordering on the bit patterns: false if either side is NaN; otherwise decided by sign first,
   // then by magnitude bits (with the comparison inverted when both operands are negative).
221 bool operator<(const Float16Impl& rhs) const noexcept
222 {
223 if (IsNaN() || rhs.IsNaN()) {
224 // IEEE defines that NaN is unordered with respect to everything, including itself.
225 return false;
226 }
227
228 const bool left_is_negative = IsNegative();
229 if (left_is_negative != rhs.IsNegative()) {
230 // When the signs of left and right differ, we know that left is less than right if it is
231 // the negative value. The exception to this is if both values are zero, in which case IEEE
232 // says they should be equal, even if the signs differ.
233 return left_is_negative && !AreZero(*this, rhs);
234 }
   // Same sign: a smaller unsigned pattern means a smaller value for positives and a larger
   // value for negatives, hence the XOR with the sign flag.
235 return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
236 }
237};
238
239// The following Float16_t conversions are based on the code from
240// Eigen library.
241
242// The conversion routines are Copyright (c) Fabian Giesen, 2016.
243// The original license follows:
244//
245// Copyright (c) Fabian Giesen, 2016
246// All rights reserved.
247// Redistribution and use in source and binary forms, with or without
248// modification, are permitted.
249// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
250// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
251// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
252// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
253// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
254// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
255// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
256// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
257// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
258// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
259// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
260
// Helper type used by the half-precision conversion routines as detail::float32_bits:
// `u` exposes the raw 32-bit pattern and `f` the float value.
// NOTE(review): the declaration line (listing line 263) is missing from this extract;
// presumably it is `union float32_bits {` - confirm against the original header.
261namespace detail
262{
264 unsigned int u;
265 float f;
266};
267}; // namespace detail
268
// Converts a float to the uint16_t half-precision representation.
// The sign is stripped first and re-attached at the end; the magnitude is classified into
// Inf/NaN, subnormal-or-zero, or normal and converted accordingly.
// NOTE(review): the declaration of the local `f` (listing line 272) is missing from this
// extract; it is used below as a detail::float32_bits-style value with `.f` and `.u` views.
269template <class Derived>
270inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept
271{
273 f.f = v;
274
   // Bit patterns of float +Inf, of 2^16 (first magnitude past the half range), and of the
   // constant used to align subnormal results.
275 constexpr detail::float32_bits f32infty = {255 << 23};
276 constexpr detail::float32_bits f16max = {(127 + 16) << 23};
277 constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
278 constexpr unsigned int sign_mask = 0x80000000u;
279 uint16_t val = static_cast<uint16_t>(0x0u);
280
   // Remember the sign and continue with the magnitude only.
281 unsigned int sign = f.u & sign_mask;
282 f.u ^= sign;
283
284 // NOTE all the integer compares in this function can be safely
285 // compiled into signed compares since all operands are below
286 // 0x80000000. Important if you want fast straight SSE2 code
287 // (since there's no unsigned PCMPGTD).
288
289 if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
290 val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
291 } else { // (De)normalized number or zero
   // 113 << 23 is the float pattern of 2^-14, the smallest normal half value.
292 if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
293 // use a magic value to align our 10 mantissa bits at the bottom of
294 // the float. as long as FP addition is round-to-nearest-even this
295 // just works.
296 f.f += denorm_magic.f;
297
298 // and one integer subtract of the bias later, we have our final float!
299 val = static_cast<uint16_t>(f.u - denorm_magic.u);
300 } else {
   // Lowest mantissa bit that survives the 13-bit shift; used to break rounding ties.
301 unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
302
303 // update exponent, rounding bias part 1
304 // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
305 // without arithmetic overflow.
306 f.u += 0xc8000fffU;
307 // rounding bias part 2
308 f.u += mant_odd;
309 // take the bits!
310 val = static_cast<uint16_t>(f.u >> 13);
311 }
312 }
313
   // Move the float sign bit (bit 31) down to the half sign bit (bit 15).
314 val |= static_cast<uint16_t>(sign >> 16);
315 return val;
316}
317
// Converts the stored half-precision pattern to float.
// The exponent/mantissa bits are shifted into float position and the exponent is re-biased;
// Inf/NaN and zero/subnormal inputs get an extra adjustment, then the sign is applied.
// NOTE(review): the declaration of the local `o` (listing line 323) is missing from this
// extract; it is used below with `.u` (raw bits) and `.f` (float) views.
318template <class Derived>
319inline float Float16Impl<Derived>::ToFloatImpl() const noexcept
320{
   // magic holds the float pattern of 2^-14, subtracted below to renormalise subnormals.
321 constexpr detail::float32_bits magic = {113 << 23};
322 constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
324
325 o.u = (val & 0x7fff) << 13; // exponent/mantissa bits
326 unsigned int exp = shifted_exp & o.u; // just the exponent
327 o.u += (127 - 15) << 23; // exponent adjust
328
329 // handle exponent special cases
330 if (exp == shifted_exp) { // Inf/NaN?
331 o.u += (128 - 16) << 23; // extra exp adjust
332 } else if (exp == 0) { // Zero/Denormal?
333 o.u += 1 << 23; // extra exp adjust
334 o.f -= magic.f; // re-normalize
335 }
336
337 // Attempt to workaround the Internal Compiler Error on ARM64
338 // for bitwise | operator, including std::bitset
339#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
   // MSVC on ARM: apply the sign by arithmetic negation instead of OR-ing in the sign bit.
340 if (IsNegative()) {
341 return -o.f;
342 }
343#else
344 // original code:
345 o.u |= (val & 0x8000U) << 16U; // sign bit
346#endif
347 return o.f;
348}
349
// Shared CRTP base for the bfloat16 type: stores the 16-bit pattern in `val` and provides the
// classification predicates. `Derived` must expose a static FromBits(uint16_t), which Abs() and
// Negate() call to build their results. Comparison operators are not defined at this level.
// NOTE(review): this is a documentation-browser extract; the class-head line for BFloat16Impl
// (between the template header and `protected:`) is not present in the listing.
351template <class Derived>
353 protected:
   // Converts a float to the 16-bit bfloat16 pattern (defined out of class).
359 static uint16_t ToUint16Impl(float v) noexcept;
360
   // Converts the stored bfloat16 pattern to float (defined out of class).
365 float ToFloatImpl() const noexcept;
366
   // Returns the stored bits with the sign bit cleared.
371 uint16_t AbsImpl() const noexcept
372 {
373 return static_cast<uint16_t>(val & ~kSignMask);
374 }
375
   // Returns the stored bits with the sign bit toggled; a NaN is returned unchanged.
380 uint16_t NegateImpl() const noexcept
381 {
382 return IsNaN() ? val : static_cast<uint16_t>(val ^ kSignMask);
383 }
384
385 public:
386 // uint16_t special values
387 static constexpr uint16_t kSignMask = 0x8000U;
388 static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
389 static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
390 static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
391 static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
392 static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
   // NOTE(review): 0x7F80 is the same pattern as kPositiveInfinityBits (zero mantissa), so this
   // value is classified as infinity, not NaN, by the predicates below - confirm intent.
393 static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
   // NOTE(review): 0x0080 has exponent field 1 and zero mantissa, i.e. the smallest positive
   // normal value, not the spacing between 1.0 and the next value - confirm intended meaning.
394 static constexpr uint16_t kEpsilonBits = 0x0080U;
   // Lowest (most negative) and largest finite patterns.
395 static constexpr uint16_t kMinValueBits = 0xFF7FU;
396 static constexpr uint16_t kMaxValueBits = 0x7F7FU;
   // Rounding bias added to the 32-bit float pattern in ToUint16Impl before truncation.
397 static constexpr uint16_t kRoundToNearest = 0x7FFFU;
398 static constexpr uint16_t kOneBits = 0x3F80U;
399 static constexpr uint16_t kMinusOneBits = 0xBF80U;
400
   // Raw 16-bit pattern; zero-initialised, i.e. +0.0.
401 uint16_t val{0};
402
403 BFloat16Impl() = default;
404
   // True whenever the sign bit (bit 15) is set, so negative zero also reports true.
409 bool IsNegative() const noexcept
410 {
411 return static_cast<int16_t>(val) < 0;
412 }
413
   // Tests if the value is NaN.
   // NOTE(review): the return statement (listing line 420) is missing from this extract.
418 bool IsNaN() const noexcept
419 {
421 }
422
   // Tests if the value is finite.
   // NOTE(review): the return statement (listing line 429) is missing from this extract.
427 bool IsFinite() const noexcept
428 {
430 }
431
   // Exact bit match against +Inf.
436 bool IsPositiveInfinity() const noexcept
437 {
438 return val == kPositiveInfinityBits;
439 }
440
   // Exact bit match against -Inf.
445 bool IsNegativeInfinity() const noexcept
446 {
447 return val == kNegativeInfinityBits;
448 }
449
   // True for either infinity: the sign is stripped before comparing.
454 bool IsInfinity() const noexcept
455 {
456 return AbsImpl() == kPositiveInfinityBits;
457 }
458
   // True for +/-0 and for any pattern whose magnitude bits exceed the infinity pattern (NaN).
463 bool IsNaNOrZero() const noexcept
464 {
465 auto abs = AbsImpl();
466 return (abs == 0 || abs > kPositiveInfinityBits);
467 }
468
   // True for finite, non-zero values whose exponent field is non-zero.
473 bool IsNormal() const noexcept
474 {
475 auto abs = AbsImpl();
476 return (abs < kPositiveInfinityBits) // is finite
477 && (abs != 0) // is not zero
478 && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent)
479 }
480
   // True for finite, non-zero values whose exponent field is zero.
485 bool IsSubnormal() const noexcept
486 {
487 auto abs = AbsImpl();
488 return (abs < kPositiveInfinityBits) // is finite
489 && (abs != 0) // is not zero
490 && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent)
491 }
492
   // Returns a Derived built from the sign-cleared bits.
497 Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); }
498
   // Returns a Derived built from the sign-toggled bits (NaN is passed through unchanged).
503 Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); }
504
   // True only if both operands are zero of either sign.
513 static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept
514 {
515 // IEEE defines that positive and negative zero are equal, this gives us a quick equality check
516 // for two values by or'ing the private bits together and stripping the sign. They are both zero,
517 // and therefore equivalent, if the resulting value is still zero.
518 return static_cast<uint16_t>((lhs.val | rhs.val) & ~kSignMask) == 0;
519 }
520};
521
// Converts a float to the uint16_t bfloat16 representation.
// Any NaN input becomes kPositiveQNaNBits (the input sign and payload are not kept). Every other
// input has a rounding bias added to its 32-bit pattern and the most significant 16 bits of the
// result are returned.
522template <class Derived>
523inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept
524{
525 uint16_t result;
526 if (std::isnan(v)) {
527 result = kPositiveQNaNBits;
528 } else {
   // Copies out the most significant 16 bits of a float: on little-endian targets they are
   // the second pair of bytes in memory, otherwise the first pair.
529 auto get_msb_half = [](float fl) {
530 uint16_t result;
531#ifdef __cpp_if_constexpr
532 if constexpr (detail::endian::native == detail::endian::little)
533#else
534 if (detail::endian::native == detail::endian::little)
535#endif
536 {
537 std::memcpy(&result, reinterpret_cast<char*>(&fl) + sizeof(uint16_t), sizeof(uint16_t));
538 } else {
539 std::memcpy(&result, &fl, sizeof(uint16_t));
540 }
541 return result;
542 };
543
544 uint16_t upper_bits = get_msb_half(v);
   // NOTE(review): U32 is modified after F32 was the member last written; this relies on the
   // compiler supporting type punning through a union - consider memcpy if portability matters.
545 union {
546 uint32_t U32;
547 float F32;
548 };
549 F32 = v;
   // Bias of 0x7FFF plus the lowest kept bit: ties in the discarded half round towards the
   // pattern whose lowest kept bit is 0.
550 U32 += (upper_bits & 1) + kRoundToNearest;
551 result = get_msb_half(F32);
552 }
553 return result;
554}
555
// Converts the stored bfloat16 pattern to float.
// A NaN is returned as std::numeric_limits<float>::quiet_NaN() (the stored sign and payload are
// not carried over). Otherwise `val` becomes the most significant 16 bits of the float and the
// least significant 16 bits are zeroed, so the conversion is exact.
556template <class Derived>
557inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept
558{
559 if (IsNaN()) {
560 return std::numeric_limits<float>::quiet_NaN();
561 }
562 float result;
   // Byte views of the two 16-bit halves of `result`, in memory order.
563 char* const first = reinterpret_cast<char*>(&result);
564 char* const second = first + sizeof(uint16_t);
565#ifdef __cpp_if_constexpr
566 if constexpr (detail::endian::native == detail::endian::little)
567#else
568 if (detail::endian::native == detail::endian::little)
569#endif
570 {
   // Little-endian: low half comes first in memory, so zero it and store val in the second half.
571 std::memset(first, 0, sizeof(uint16_t));
572 std::memcpy(second, &val, sizeof(uint16_t));
573 } else {
   // Big-endian: high half comes first in memory.
574 std::memcpy(first, &val, sizeof(uint16_t));
575 std::memset(second, 0, sizeof(uint16_t));
576 }
577 return result;
578}
579
// Public half-precision value type: construction from float or raw bits, conversion back to
// float, and re-exported predicates/operators from `Base`.
// NOTE(review): this is a documentation-browser extract; the class-head line and the `Base`
// alias declaration (listing lines 598 and 608) are not present. `Base` is presumably the
// Float16Impl<Float16_t> base - confirm against the original header.
599 private:
   // Wraps an existing bit pattern without conversion; reachable only through FromBits().
605 constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
606
607 public:
609
   // Default constructor.
613 Float16_t() = default;
614
   // Builds a Float16_t whose stored pattern is exactly `v`.
620 constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
621
   // Converts `v` to half precision via Base::ToUint16Impl.
626 explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
627
   // Converts the stored value to float via Base::ToFloatImpl.
632 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
633
   // Re-export the base-class predicates and helpers under this type.
638 using Base::IsNegative;
639
644 using Base::IsNaN;
645
650 using Base::IsFinite;
651
657
663
668 using Base::IsInfinity;
669
674 using Base::IsNaNOrZero;
675
680 using Base::IsNormal;
681
686 using Base::IsSubnormal;
687
692 using Base::Abs;
693
698 using Base::Negate;
699
708 using Base::AreZero;
709
   // Explicit conversion to float; same result as ToFloat().
713 explicit operator float() const noexcept { return ToFloat(); }
714
   // Comparison operators are taken unchanged from the base class.
715 using Base::operator==;
716 using Base::operator!=;
717 using Base::operator<;
718};
719
720static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
721
// Public bfloat16 value type: construction from float or raw bits, conversion back to float,
// re-exported predicates from `Base`, and its own comparison operators.
// NOTE(review): this is a documentation-browser extract; the class-head line and the `Base`
// alias declaration (listing lines 740 and 752) are not present. `Base` is presumably the
// BFloat16Impl<BFloat16_t> base - confirm against the original header.
741 private:
   // Wraps an existing bit pattern without conversion; reachable only through FromBits().
749 constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
750
751 public:
753
754 BFloat16_t() = default;
755
   // Builds a BFloat16_t whose stored pattern is exactly `v`.
761 static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
762
   // Converts `v` to bfloat16 via Base::ToUint16Impl.
767 explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
768
   // Converts the stored value to float via Base::ToFloatImpl.
773 float ToFloat() const noexcept { return Base::ToFloatImpl(); }
774
   // Re-export the base-class predicates and helpers under this type.
779 using Base::IsNegative;
780
785 using Base::IsNaN;
786
791 using Base::IsFinite;
792
798
804
809 using Base::IsInfinity;
810
815 using Base::IsNaNOrZero;
816
821 using Base::IsNormal;
822
827 using Base::IsSubnormal;
828
833 using Base::Abs;
834
839 using Base::Negate;
840
849 using Base::AreZero;
850
   // Explicit conversion to float; same result as ToFloat().
854 explicit operator float() const noexcept { return ToFloat(); }
855
856 // We do not have an inherited impl for the below operators
857 // as the internal class implements them a little differently
   // operator== and operator< are only declared here; their definitions are not in this listing.
858 bool operator==(const BFloat16_t& rhs) const noexcept;
859 bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
860 bool operator<(const BFloat16_t& rhs) const noexcept;
861};
862
863static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
864
865} // namespace OrtDataType
866
867} // namespace o2
uint64_t exp(uint64_t base, uint8_t exp) noexcept
bool o
GLuint64EXT * result
Definition glcorearb.h:5662
const GLdouble * v
Definition glcorearb.h:832
GLdouble f
Definition glcorearb.h:310
GLuint GLfloat * val
Definition glcorearb.h:1582
a couple of static helper functions to create timestamp values for CCDB queries or override obsolete ...
Shared implementation between public and internal classes. CRTP pattern.
static constexpr uint16_t kMaxValueBits
static constexpr uint16_t kNegativeQNaNBits
static constexpr uint16_t kRoundToNearest
uint16_t NegateImpl() const noexcept
Creates a new instance with the sign flipped.
static constexpr uint16_t kSignaling_NaNBits
bool IsFinite() const noexcept
Tests if the value is finite.
Derived Negate() const noexcept
Creates a new instance with the sign flipped.
Derived Abs() const noexcept
Creates an instance that represents absolute value.
static constexpr uint16_t kSignMask
uint16_t AbsImpl() const noexcept
Creates an instance that represents absolute value.
static constexpr uint16_t kPositiveInfinityBits
static constexpr uint16_t kBiasedExponentMask
static constexpr uint16_t kNegativeInfinityBits
static constexpr uint16_t kMinusOneBits
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
static constexpr uint16_t kPositiveQNaNBits
bool IsNaN() const noexcept
Tests if the value is NaN.
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
static uint16_t ToUint16Impl(float v) noexcept
Converts from float to the uint16_t bfloat16 representation.
static constexpr uint16_t kEpsilonBits
static bool AreZero(const BFloat16Impl &lhs, const BFloat16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity.
float ToFloatImpl() const noexcept
Converts bfloat16 to float.
static constexpr uint16_t kOneBits
bool IsNegative() const noexcept
Checks if the value is negative.
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
static constexpr uint16_t kMinValueBits
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
bfloat16 (Brain Floating Point) data type
static constexpr BFloat16_t FromBits(uint16_t v) noexcept
Creates a BFloat16_t from its uint16_t bit representation.
BFloat16_t(float v) noexcept
Constructor from float. The float is converted into the 16-bit bfloat16 representation.
float ToFloat() const noexcept
Converts bfloat16 to float.
bool operator!=(const BFloat16_t &rhs) const noexcept
bool operator==(const BFloat16_t &rhs) const noexcept
bool operator<(const BFloat16_t &rhs) const noexcept
Shared implementation between public and internal classes. CRTP pattern.
uint16_t NegateImpl() const noexcept
Creates a new instance with the sign flipped.
bool IsFinite() const noexcept
Tests if the value is finite.
static constexpr uint16_t kSignMask
static constexpr uint16_t kNegativeInfinityBits
float ToFloatImpl() const noexcept
Converts float16 to float.
bool IsPositiveInfinity() const noexcept
Tests if the value represents positive infinity.
static constexpr uint16_t kNegativeQNaNBits
static constexpr uint16_t kPositiveInfinityBits
bool IsNormal() const noexcept
Tests if the value is normal (not zero, subnormal, infinite, or NaN).
static constexpr uint16_t kMinusOneBits
bool operator==(const Float16Impl &rhs) const noexcept
bool IsNegative() const noexcept
Checks if the value is negative.
static constexpr uint16_t kBiasedExponentMask
static constexpr uint16_t kOneBits
static constexpr uint16_t kEpsilonBits
bool IsInfinity() const noexcept
Tests if the value is either positive or negative infinity.
bool operator!=(const Float16Impl &rhs) const noexcept
static constexpr uint16_t kMaxValueBits
static constexpr uint16_t ToUint16Impl(float v) noexcept
Converts from float to uint16_t float16 representation.
static constexpr uint16_t kMinValueBits
bool IsNaN() const noexcept
Tests if the value is NaN.
bool operator<(const Float16Impl &rhs) const noexcept
bool IsNegativeInfinity() const noexcept
Tests if the value represents negative infinity.
static bool AreZero(const Float16Impl &lhs, const Float16Impl &rhs) noexcept
IEEE defines that positive and negative zero are equal, this gives us a quick equality check for two ...
Derived Negate() const noexcept
Creates a new instance with the sign flipped.
bool IsSubnormal() const noexcept
Tests if the value is subnormal (denormal).
Derived Abs() const noexcept
Creates an instance that represents absolute value.
uint16_t AbsImpl() const noexcept
Creates an instance that represents absolute value.
static constexpr uint16_t kPositiveQNaNBits
bool IsNaNOrZero() const noexcept
Tests if the value is NaN or zero. Useful for comparisons.
IEEE 754 half-precision floating point data type.
float ToFloat() const noexcept
Converts float16 to float.
Float16_t()=default
Default constructor.
static constexpr Float16_t FromBits(uint16_t v) noexcept
Creates a Float16_t from its uint16_t bit representation.
Float16_t(float v) noexcept
Constructor from float. The float is converted into the 16-bit float16 representation.