// Compiler-specific byte-order detection for the OrtDataType::detail::endian
// enum (fragment: the MSVC branch and the enum header are outside this view).
35#elif defined(__GNUC__) || defined(__clang__)
// GCC/Clang expose the byte order as predefined macros; map them directly
// onto the enum so `endian::native` compares equal to `little` or `big`.
36 little = __ORDER_LITTLE_ENDIAN__,
37 big = __ORDER_BIG_ENDIAN__,
38 native = __BYTE_ORDER__,
// Unknown toolchain: fail the build rather than silently guess the byte order.
40#error OrtDataType::detail::endian is not implemented in this environment.
// Compile-time rejection of mixed/PDP endianness (static_assert fragment;
// the `static_assert(` opener is on an elided line).
45 endian::native == endian::little || endian::native == endian::big,
46 "Only little-endian or big-endian native byte orders are supported.");
// Float16Impl<Derived>: CRTP base holding the shared IEEE-754 binary16
// (half) implementation; `val` is the raw 16-bit pattern (declared on elided
// lines). Many member declarations between the visible lines are missing
// from this view.
53template <
class Derived>
// AbsImpl(): clear the sign bit to obtain the magnitude bits.
75 return static_cast<uint16_t
>(
val & ~kSignMask);
// IsNegative(): the sign bit is the MSB, so reinterpreting the bits as a
// signed 16-bit value is negative exactly when the sign bit is set.
111 return static_cast<int16_t
>(
val) < 0;
// Three predicate bodies (their signatures are on elided lines) each start
// from the magnitude bits — presumably IsNaNOrZero/IsNormal/IsSubnormal
// style checks as in the BFloat16Impl twin below; TODO confirm.
165 auto abs = AbsImpl();
175 auto abs = AbsImpl();
187 auto abs = AbsImpl();
// Negate(): flip (or, for NaN, preserve) the sign via NegateImpl().
203 GPUd() Derived Negate() const
noexcept {
return Derived::FromBits(NegateImpl()); }
// AreZero(): +0 and -0 differ only in the sign bit, so OR-ing both bit
// patterns and masking the sign yields 0 iff both operands are zeros.
215 return static_cast<uint16_t
>((lhs.val | rhs.val) & ~
kSignMask) == 0;
// operator==: IEEE semantics — any NaN compares unequal (the branch body is
// elided); otherwise compare raw bit patterns.
// NOTE(review): raw-bit equality makes +0 != -0; check whether elided lines
// special-case zeros via AreZero.
220 if (IsNaN() || rhs.IsNaN()) {
224 return val == rhs.val;
// operator!=: defined as the negation of operator==.
227 GPUd() bool operator!=(const Float16Impl& rhs) const
noexcept {
return !(*
this == rhs); }
229 GPUd() bool operator<(const Float16Impl& rhs) const
noexcept
// (NaN handling for operator< sits on elided lines — presumably it returns
// false when either side is NaN; TODO confirm.)
236 const bool left_is_negative = IsNegative();
237 if (left_is_negative !=
rhs.IsNegative()) {
// Different signs: the negative operand is smaller, except that -0 < +0
// must be false — hence the AreZero exclusion.
241 return left_is_negative && !AreZero(*
this, rhs);
// Same sign: unsigned bit-pattern order matches numeric order for positive
// values and is reversed for negative ones, hence the XOR with the sign.
243 return (
val !=
rhs.val) && ((
val <
rhs.val) ^ left_is_negative);
// Float16Impl<Derived>::ToUint16Impl — float -> IEEE binary16 bit pattern
// (fragment: the function signature and the initialization of the
// `detail::float32_bits f` union from the argument are on elided lines).
277template <
class Derived>
// All-exponent-bits-set float (infinity) and the smallest float whose
// rounded half would overflow; used below to detect Inf/NaN results.
283 constexpr detail::float32_bits f32infty = {255 << 23};
284 constexpr detail::float32_bits f16max = {(127 + 16) << 23};
// Adding this magic float renormalizes a would-be-subnormal half so its
// bits can be extracted afterwards by plain integer subtraction.
285 constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23};
286 constexpr unsigned int sign_mask = 0x80000000u;
287 uint16_t
val =
static_cast<uint16_t
>(0x0u);
// Peel the sign off the input; it is re-attached to the result at the end
// (the `f.u ^= sign;` strip presumably sits on an elided line — verify).
289 unsigned int sign =
f.u & sign_mask;
// Magnitude >= f16max: result is +/-Inf (0x7c00) or a quiet NaN (0x7e00).
297 if (
f.u >= f16max.u) {
298 val = (
f.u > f32infty.u) ? 0x7e00 : 0x7c00;
// Magnitude below the smallest normal half (2^-14, biased exponent 113):
// the result is subnormal or zero.
300 if (
f.u < (113 << 23)) {
// Float addition shifts the mantissa into place and rounds correctly;
// subtracting the magic bit pattern afterwards leaves the half bits.
304 f.f += denorm_magic.f;
307 val =
static_cast<uint16_t
>(
f.u - denorm_magic.u);
// Normal case: the exponent rebias and rounding-bias additions are on
// elided lines; mant_odd implements round-to-nearest-even tie breaking.
309 unsigned int mant_odd = (
f.u >> 13) & 1;
318 val =
static_cast<uint16_t
>(
f.u >> 13);
// Re-attach the original sign (bit 31 of the float -> bit 15 of the half).
322 val |=
static_cast<uint16_t
>(sign >> 16);
// Float16Impl<Derived>::ToFloatImpl — IEEE binary16 bits -> float
// (fragment: the subnormal branch body and the final return are elided).
326template <
class Derived>
327GPUdi() float Float16Impl<Derived>::ToFloatImpl() const noexcept
// `magic` rescales rebased subnormals; shifted_exp is the half exponent
// mask after it has been moved into float position (shifted left by 13).
329 constexpr detail::float32_bits magic = {113 << 23};
330 constexpr unsigned int shifted_exp = 0x7c00 << 13;
331 detail::float32_bits
o{};
// Move exponent+mantissa into float position, then rebias the exponent
// (half bias 15 -> float bias 127).
333 o.u = (
val & 0x7fff) << 13;
334 unsigned int exp = shifted_exp &
o.u;
335 o.u += (127 - 15) << 23;
// Inf/NaN input: push the exponent the rest of the way to all-ones.
338 if (
exp == shifted_exp) {
339 o.u += (128 - 16) << 23;
340 }
// Zero/subnormal input: renormalization via `magic` is on elided lines
// (presumably o.u += 1 << 23; o.f -= magic.f — TODO confirm).
else if (
exp == 0) {
347#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC)
// Re-attach the sign bit (bit 15 of the half -> bit 31 of the float).
353 o.u |= (
val & 0x8000U) << 16U;
// BFloat16Impl<Derived>: CRTP base for bfloat16 (the upper 16 bits of an
// IEEE binary32: 1 sign, 8 exponent, 7 mantissa bits). `val` holds the raw
// bit pattern. Doc lines, braces, and some expressions are elided here.
359template <
class Derived>
// Conversions are defined out-of-line further below.
367 GPUd() static uint16_t ToUint16Impl(
float v) noexcept;
373 GPUd()
float ToFloatImpl() const noexcept;
// AbsImpl(): clear the sign bit to obtain the magnitude bits.
379 GPUd() uint16_t AbsImpl() const noexcept
381 return static_cast<uint16_t
>(
val & ~kSignMask);
// NegateImpl(): flip the sign bit; NaNs are returned unchanged.
388 GPUd() uint16_t NegateImpl() const noexcept
390 return IsNaN() ?
val :
static_cast<uint16_t
>(
val ^ kSignMask);
// Bit-pattern constants for the bfloat16 layout.
395 static constexpr uint16_t kSignMask = 0x8000U;
396 static constexpr uint16_t kBiasedExponentMask = 0x7F80U;
397 static constexpr uint16_t kPositiveInfinityBits = 0x7F80U;
398 static constexpr uint16_t kNegativeInfinityBits = 0xFF80U;
399 static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U;
400 static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U;
// NOTE(review): 0x7F80 equals kPositiveInfinityBits; a signaling NaN needs
// a nonzero mantissa bit (e.g. 0x7F81) — verify this constant is intended.
401 static constexpr uint16_t kSignaling_NaNBits = 0x7F80U;
402 static constexpr uint16_t kEpsilonBits = 0x0080U;
403 static constexpr uint16_t kMinValueBits = 0xFF7FU;
404 static constexpr uint16_t kMaxValueBits = 0x7F7FU;
// Rounding bias used by ToUint16Impl for round-to-nearest-even.
405 static constexpr uint16_t kRoundToNearest = 0x7FFFU;
406 static constexpr uint16_t kOneBits = 0x3F80U;
407 static constexpr uint16_t kMinusOneBits = 0xBF80U;
// Sign bit is the MSB, so a signed reinterpretation is negative iff set.
417 GPUd()
bool IsNegative() const noexcept
419 return static_cast<int16_t
>(
val) < 0;
// NaN: exponent all ones with nonzero mantissa, i.e. magnitude > +Inf bits.
426 GPUd() bool IsNaN() const noexcept
428 return AbsImpl() > kPositiveInfinityBits;
// Finite: magnitude strictly below the infinity bit pattern.
435 GPUd() bool IsFinite() const noexcept
437 return AbsImpl() < kPositiveInfinityBits;
444 GPUd() bool IsPositiveInfinity() const noexcept
446 return val == kPositiveInfinityBits;
453 GPUd() bool IsNegativeInfinity() const noexcept
455 return val == kNegativeInfinityBits;
462 GPUd() bool IsInfinity() const noexcept
464 return AbsImpl() == kPositiveInfinityBits;
// True for +/-0 and any NaN.
471 GPUd() bool IsNaNOrZero() const noexcept
473 auto abs = AbsImpl();
474 return (abs == 0 || abs > kPositiveInfinityBits);
// Normal: finite with a nonzero biased exponent (an `abs != 0` conjunct
// likely sits on the elided line 485 — TODO confirm).
481 GPUd() bool IsNormal() const noexcept
483 auto abs = AbsImpl();
484 return (abs < kPositiveInfinityBits)
486 && ((abs & kBiasedExponentMask) != 0);
// Subnormal: finite with a zero biased exponent (an `abs != 0` conjunct to
// exclude zeros likely sits on the elided line 497 — TODO confirm).
493 GPUd() bool IsSubnormal() const noexcept
495 auto abs = AbsImpl();
496 return (abs < kPositiveInfinityBits)
498 && ((abs & kBiasedExponentMask) == 0);
// Value-returning wrappers over the *Impl bit helpers.
505 GPUd() Derived Abs() const noexcept {
return Derived::FromBits(AbsImpl()); }
511 GPUd() Derived Negate() const noexcept {
return Derived::FromBits(NegateImpl()); }
// +0 and -0 differ only in the sign bit: OR both patterns, mask the sign;
// zero means both operands are (signed) zeros.
521 GPUd() static
bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept
526 return static_cast<uint16_t
>((
lhs.val |
rhs.val) & ~kSignMask) == 0;
// BFloat16Impl<Derived>::ToUint16Impl — float -> bfloat16: keep the high 16
// bits of the float, with round-to-nearest-even via the kRoundToNearest
// bias (fragment: `result`, `U32`, `F32` declarations and the return are on
// elided lines — presumably U32 is a uint32_t view of the input; verify).
530template <
class Derived>
531GPUdi() uint16_t BFloat16Impl<Derived>::ToUint16Impl(
float v) noexcept
// NaN inputs collapse to the canonical positive quiet NaN pattern.
534 if (o2::gpu::CAMath::IsNaN(
v)) {
535 result = kPositiveQNaNBits;
// Helper: copy the most-significant half of a float's bytes into `result`.
// On little-endian that is the second uint16_t in memory; on big-endian,
// the first.
537 auto get_msb_half = [](
float fl) {
// GPU path (guard on an elided line): device memcpy of the high half.
540 o2::gpu::CAMath::memcpy(&
result,
reinterpret_cast<char*
>(&fl) +
sizeof(uint16_t),
sizeof(uint16_t));
// Host path: prefer a compile-time branch where `if constexpr` exists.
542#ifdef __cpp_if_constexpr
543 if constexpr (detail::endian::native == detail::endian::little)
545 if (detail::endian::native == detail::endian::little)
548 std::memcpy(&
result,
reinterpret_cast<char*
>(&fl) +
sizeof(uint16_t),
sizeof(uint16_t));
550 std::memcpy(&
result, &fl,
sizeof(uint16_t));
556 uint16_t upper_bits = get_msb_half(
v);
// Round to nearest even: add 0x7FFF plus the LSB of the truncated result
// to the float's bit pattern, then take the high half again.
562 U32 += (upper_bits & 1) + kRoundToNearest;
563 result = get_msb_half(F32);
// BFloat16Impl<Derived>::ToFloatImpl — bfloat16 bits -> float by widening:
// the 16 stored bits become the float's high half and the low half is
// zeroed (fragment: `result` declaration, guards, and return are elided).
568template <
class Derived>
569GPUdi() float Float16Impl<Derived>::ToFloatImpl() const noexcept
// NaN bit patterns map to the platform quiet NaN (guard line elided).
572 return o2::gpu::CAMath::QuietNaN();
// first/second: byte views of the low and high halves of `result`.
575 char*
const first =
reinterpret_cast<char*
>(&
result);
576 char*
const second =
first +
sizeof(uint16_t);
// GPU path (guard elided): copy the stored bits into the second half.
579 o2::gpu::CAMath::memcpy(second, &
val,
sizeof(uint16_t));
// Host path: compile-time endianness branch where available.
581#ifdef __cpp_if_constexpr
582 if constexpr (detail::endian::native == detail::endian::little)
584 if (detail::endian::native == detail::endian::little)
// Little-endian: zero the low half, put the bits in the high half.
587 std::memset(
first, 0,
sizeof(uint16_t));
588 std::memcpy(second, &
val,
sizeof(uint16_t));
// Big-endian: the stored bits are the first half in memory; zero the rest.
590 std::memcpy(
first, &
val,
sizeof(uint16_t));
591 std::memset(second, 0,
sizeof(uint16_t));
// Float16_t: public IEEE binary16 value type layered on Float16Impl.
// (Doc-comment lines between members are elided from this view.)
615struct Float16_t : OrtDataType::Float16Impl<Float16_t> {
// Bit-pattern constructor: stores `v` verbatim — no numeric conversion.
622 constexpr explicit Float16_t(uint16_t
v)
noexcept {
val =
v; }
625 using Base = OrtDataType::Float16Impl<Float16_t>;
// Named factory for raw bits; mirrors the uint16_t constructor.
637 GPUd() constexpr static Float16_t FromBits(uint16_t
v) noexcept {
return Float16_t(
v); }
// Numeric constructor: rounds the float to the nearest representable half.
643 GPUd() explicit Float16_t(
float v) noexcept {
val = Base::ToUint16Impl(
v); }
// Widen back to float.
649 GPUd() float ToFloat() const noexcept {
return Base::ToFloatImpl(); }
// Re-export the Base classification predicates as public API.
655 using Base::IsNegative;
667 using Base::IsFinite;
673 using Base::IsPositiveInfinity;
679 using Base::IsNegativeInfinity;
685 using Base::IsInfinity;
691 using Base::IsNaNOrZero;
697 using Base::IsNormal;
703 using Base::IsSubnormal;
// Explicit conversion keeps half -> float widening intentional at call sites.
730 GPUdi() explicit operator
float() const noexcept {
return ToFloat(); }
// Comparisons come from the CRTP base.
732 using Base::operator==;
733 using Base::operator!=;
734 using Base::operator<;
// Layout guarantee: the wrapper is exactly its 16 stored bits.
737static_assert(
sizeof(Float16_t) ==
sizeof(uint16_t),
"Sizes must match");
// BFloat16_t member fragment: the struct header and constructors are
// outside this view; these lines re-export the BFloat16Impl API.
769 using Base = OrtDataType::BFloat16Impl<BFloat16_t>;
// Widen back to float (exact — bfloat16 is a truncated float).
790 GPUd() float ToFloat() const noexcept {
return Base::ToFloatImpl(); }
// Re-export the Base classification predicates as public API.
796 using Base::IsNegative;
808 using Base::IsFinite;
814 using Base::IsPositiveInfinity;
820 using Base::IsNegativeInfinity;
826 using Base::IsInfinity;
832 using Base::IsNaNOrZero;
838 using Base::IsNormal;
844 using Base::IsSubnormal;
// Explicit conversion keeps bfloat16 -> float widening intentional.
871 GPUdi() explicit operator
float() const noexcept {
return ToFloat(); }
// Layout guarantee: the wrapper is exactly its 16 stored bits.
880static_assert(
sizeof(
BFloat16_t) ==
sizeof(uint16_t),
"Sizes must match");