16#ifndef RANS_INTERNAL_ENCODE_SIMDKERNEL_H_
17#define RANS_INTERNAL_ENCODE_SIMDKERNEL_H_
40inline __m128i ransEncode(__m128i
state, __m128d frequency, __m128d cumulative, __m128d normalization)
noexcept
49 auto [div, mod] = divMod(uint64ToDouble(
state), frequency);
51 auto newState = _mm_fmadd_pd(normalization, div, cumulative);
53 auto newState = _mm_mul_pd(normalization, div);
54 newState = _mm_add_pd(newState, cumulative);
56 newState = _mm_add_pd(newState, mod);
58 return doubleToUint64(newState);
65inline __m256i ransEncode(__m256i
state, __m256d frequency, __m256d cumulative, __m256d normalization)
noexcept
74 auto [div, mod] = divMod(uint64ToDouble(
state), frequency);
75 auto newState = _mm256_fmadd_pd(normalization, div, cumulative);
76 newState = _mm256_add_pd(newState, mod);
78 return doubleToUint64(newState);
83inline void aosToSoa(gsl::span<const Symbol*, 2> in, __m128i* __restrict__ frequency, __m128i* __restrict__ cumulatedFrequency)
noexcept
85 __m128i in0Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[0]->
data()));
86 __m128i in1Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[1]->
data()));
88 *frequency = _mm_unpacklo_epi32(in0Reg, in1Reg);
89 *cumulatedFrequency = _mm_shuffle_epi32(*frequency, _MM_SHUFFLE(0, 0, 3, 2));
92inline void aosToSoa(gsl::span<const Symbol*, 4> in, __m128i* __restrict__ frequency, __m128i* __restrict__ cumulatedFrequency)
noexcept
94 __m128i in0Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[0]->
data()));
95 __m128i in1Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[1]->
data()));
96 __m128i in2Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[2]->
data()));
97 __m128i in3Reg = _mm_loadu_si128(
reinterpret_cast<__m128i const*
>(in[3]->
data()));
99 __m128i merged0Reg = _mm_unpacklo_epi32(in0Reg, in1Reg);
100 __m128i merged1Reg = _mm_unpacklo_epi32(in2Reg, in3Reg);
101 *frequency = _mm_unpacklo_epi64(merged0Reg, merged1Reg);
102 *cumulatedFrequency = _mm_unpackhi_epi64(merged0Reg, merged1Reg);
105template <SIMDW
idth w
idth_V, u
int64_t lowerBound_V, u
int8_t streamBits_V>
106inline auto computeMaxState(__m128i frequencyVec, uint8_t symbolTablePrecisionBits)
noexcept
108 const uint64_t xmax = (lowerBound_V >> symbolTablePrecisionBits) << streamBits_V;
110 if constexpr (width_V == SIMDWidth::SSE) {
111 __m128i frequencyVecEpi64 = _mm_cvtepi32_epi64(frequencyVec);
112 return _mm_slli_epi64(frequencyVecEpi64, shift);
114 if constexpr (width_V == SIMDWidth::AVX) {
116 __m256i frequencyVecEpi64 = _mm256_cvtepi32_epi64(frequencyVec);
117 return _mm256_slli_epi64(frequencyVecEpi64, shift);
/// Conditionally renormalizes two packed 64-bit states (SSE).
/// Per lane: newState = cmp ? (state >> streamBits_V) : state.
/// \tparam streamBits_V number of bits shifted out per renormalization step
/// \param stateVec packed rANS states
/// \param cmpVec all-ones lanes mark states that exceeded their maximum
/// \return the (possibly shifted) states
template <uint8_t streamBits_V>
inline __m128i computeNewState(__m128i stateVec, __m128i cmpVec) noexcept
{
  __m128i newStateVec = _mm_srli_epi64(stateVec, streamBits_V);
  // blendv selects the shifted value where cmpVec lanes are set
  newStateVec = _mm_blendv_epi8(stateVec, newStateVec, cmpVec);
  return newStateVec;
};
/// Conditionally renormalizes four packed 64-bit states (AVX).
/// Per lane: newState = cmp ? (state >> streamBits_V) : state.
/// \tparam streamBits_V number of bits shifted out per renormalization step
/// \param stateVec packed rANS states
/// \param cmpVec all-ones lanes mark states that exceeded their maximum
/// \return the (possibly shifted) states
template <uint8_t streamBits_V>
inline __m256i computeNewState(__m256i stateVec, __m256i cmpVec) noexcept
{
  __m256i newStateVec = _mm256_srli_epi64(stateVec, streamBits_V);
  // blendv selects the shifted value where cmpVec lanes are set
  newStateVec = _mm256_blendv_epi8(stateVec, newStateVec, cmpVec);
  return newStateVec;
};
143inline constexpr std::array<epi8_t<SIMDWidth::SSE>, 16>
145 {0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
146 {0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
147 {0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
148 {0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
149 {0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
150 {0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
151 {0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
152 {0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
153 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
154 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
155 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
156 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
157 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
158 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
159 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8, 0xFF_u8},
160 {0x0C_u8, 0x0D_u8, 0x0E_u8, 0x0F_u8, 0x04_u8, 0x05_u8, 0x06_u8, 0x07_u8, 0x08_u8, 0x09_u8, 0x0A_u8, 0x0B_u8, 0x00_u8, 0x01_u8, 0x02_u8, 0x03_u8}
163inline constexpr std::array<uint32_t, 256> AVXStreamOutLUT{
423struct StreamOutResult;
426struct StreamOutResult<SIMDWidth::
SSE> {
428 __m128i streamOutVec;
431inline StreamOutResult<SIMDWidth::SSE> streamOut(
const __m128i* __restrict__ stateVec,
const __m128i* __restrict__ cmpVec)
noexcept
433 auto shifted1 = _mm_slli_epi64(stateVec[1], 32);
435 __m128i statesFused = _mm_blend_epi16(stateVec[0], shifted1, 0b11001100);
436 __m128i cmpFused = _mm_blend_epi16(cmpVec[0], cmpVec[1], 0b11001100);
437 const uint32_t
id = _mm_movemask_ps(_mm_castsi128_ps(cmpFused));
439 __m128i permutationMask = load(SSEStreamOutLUT[
id]);
440 __m128i streamOutVec = _mm_shuffle_epi8(statesFused, permutationMask);
442 return {
static_cast<uint32_t
>(_mm_popcnt_u32(
id)), streamOutVec};
447struct StreamOutResult<SIMDWidth::AVX> {
449 __m256i streamOutVec;
452inline StreamOutResult<SIMDWidth::AVX> streamOut(
const __m256i* __restrict__ stateVec,
const __m256i* __restrict__ cmpVec)
noexcept
454 auto shifted1 = _mm256_slli_epi64(stateVec[1], 32);
456 __m256i statesFused = _mm256_blend_epi32(stateVec[0], shifted1, 0b10101010);
457 __m256i cmpFused = _mm256_blend_epi32(cmpVec[0], cmpVec[1], 0b10101010);
458 statesFused = _mm256_and_si256(statesFused, cmpFused);
459 const uint32_t
id = _mm256_movemask_ps(_mm256_castsi256_ps(cmpFused));
461 __m256i permutationMask = _mm256_set1_epi32(AVXStreamOutLUT[
id]);
462 constexpr epi32_t<SIMDWidth::AVX>
mask{0xF0000000u, 0x0F000000u, 0x00F00000u, 0x000F0000u, 0x0000F000u, 0x00000F00u, 0x000000F0u, 0x0000000Fu};
463 permutationMask = _mm256_and_si256(permutationMask, load(
mask));
464 constexpr epi32_t<SIMDWidth::AVX> shift{28u, 24u, 20u, 16u, 12u, 8u, 4u, 0u};
465 permutationMask = _mm256_srlv_epi32(permutationMask, load(shift));
466 __m256i streamOutVec = _mm256_permutevar8x32_epi32(statesFused, permutationMask);
468 return {
static_cast<uint32_t
>(_mm_popcnt_u32(
id)), streamOutVec};
473template <SIMDW
idth,
typename output_IT>
476template <
typename output_IT>
477struct RenormResult<SIMDWidth::
SSE, output_IT> {
478 output_IT outputIter;
483template <
typename output_IT>
484struct RenormResult<SIMDWidth::AVX, output_IT> {
485 output_IT outputIter;
490template <
typename output_IT, u
int64_t lowerBound_V, u
int8_t streamBits_V>
491inline output_IT ransRenorm(
const __m128i* __restrict__
state,
const __m128i* __restrict__ frequency, uint8_t symbolTablePrecisionBits, output_IT outputIter, __m128i* __restrict__ newState)
noexcept
497 maxState[0] = computeMaxState<SIMDWidth::SSE, lowerBound_V, streamBits_V>(frequency[0], symbolTablePrecisionBits);
498 maxState[1] = computeMaxState<SIMDWidth::SSE, lowerBound_V, streamBits_V>(frequency[1], symbolTablePrecisionBits);
500 cmp[0] = cmpgeq_epi64(
state[0], maxState[0]);
501 cmp[1] = cmpgeq_epi64(
state[1], maxState[1]);
503 newState[0] = computeNewState<streamBits_V>(
state[0],
cmp[0]);
504 newState[1] = computeNewState<streamBits_V>(
state[1],
cmp[1]);
506 auto [nStreamOutWords, streamOutResult] = streamOut(
state,
cmp);
507 if constexpr (std::is_pointer_v<output_IT>) {
508 _mm_storeu_si128(
reinterpret_cast<__m128i*
>(outputIter), streamOutResult);
509 outputIter += nStreamOutWords;
511 auto result = store<uint32_t>(streamOutResult);
512 for (
size_t i = 0;
i < nStreamOutWords; ++
i) {
522template <
typename output_IT, u
int64_t lowerBound_V, u
int8_t streamBits_V>
523inline output_IT ransRenorm(
const __m256i*
state,
const __m128i* __restrict__ frequency, uint8_t symbolTablePrecisionBits, output_IT outputIter, __m256i* __restrict__ newState)
noexcept
529 maxState[0] = computeMaxState<SIMDWidth::AVX, lowerBound_V, streamBits_V>(frequency[0], symbolTablePrecisionBits);
530 maxState[1] = computeMaxState<SIMDWidth::AVX, lowerBound_V, streamBits_V>(frequency[1], symbolTablePrecisionBits);
532 cmp[0] = cmpgeq_epi64(
state[0], maxState[0]);
533 cmp[1] = cmpgeq_epi64(
state[1], maxState[1]);
535 newState[0] = computeNewState<streamBits_V>(
state[0],
cmp[0]);
536 newState[1] = computeNewState<streamBits_V>(
state[1],
cmp[1]);
538 auto [nStreamOutWords, streamOutResult] = streamOut(
state,
cmp);
539 if constexpr (std::is_pointer_v<output_IT>) {
540 _mm256_storeu_si256(
reinterpret_cast<__m256i*
>(outputIter), streamOutResult);
541 outputIter += nStreamOutWords;
543 auto result = store<uint32_t>(streamOutResult);
544 for (
size_t i = 0;
i < nStreamOutWords; ++
i) {
/// Symbol statistics for one unrolled encoder iteration in SoA layout:
/// two registers of packed frequencies and two of cumulative frequencies.
struct UnrolledSymbols {
  __m128i frequencies[2];
  __m128i cumulativeFrequencies[2];
};
// --- Documentation-index residue (preserved from extraction, converted to comments) ---
// AlignedArray: memory-aligned array used for SIMD operations.
// Symbol: contains statistical information for one source symbol, required for encoding/decoding.
// Common helper classes and functions.
// Preprocessor defines to enable features based on CPU architecture.
// auto make_span(const o2::rans::internal::simd::AlignedArray<T, width_V, size_V>& array)
// uint8_t itsSharedClusterMap, uint8_t
// constexpr T log2UIntNZ(T x) noexcept
// constexpr size_t pow2(size_t n) noexcept
// Wrapper around basic SIMD operations.
// Basic SIMD datatypes and traits.
// std::vector<o2::ctf::BufferType> vec
// char const* restrict const cmp