Project
Loading...
Searching...
No Matches
SIMDEncoderImpl.h
Go to the documentation of this file.
1// Copyright 2019-2023 CERN and copyright holders of ALICE O2.
2// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3// All rights not expressly granted are reserved.
4//
5// This software is distributed under the terms of the GNU General Public
6// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7//
8// In applying this license CERN does not waive the privileges and immunities
9// granted to it by virtue of its status as an Intergovernmental Organization
10// or submit itself to any jurisdiction.
11
15
16#ifndef RANS_INTERNAL_ENCODE_SIMDENCODERIMPL_H_
17#define RANS_INTERNAL_ENCODE_SIMDENCODERIMPL_H_
18
20
21#ifdef RANS_SIMD
22
23#include <cassert>
24#include <cstdint>
25#include <tuple>
26
31
32namespace o2::rans::internal
33{
34
35template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
36class SIMDEncoderImpl : public EncoderImpl<simd::UnrolledSymbols,
37 SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>>
38{
39 using base_type = EncoderImpl<simd::UnrolledSymbols, SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>>;
40
41 public:
42 using stream_type = typename base_type::stream_type;
43 using state_type = typename base_type::state_type;
44 using symbol_type = typename base_type::symbol_type;
45 using size_type = typename base_type::size_type;
46 using difference_type = typename base_type::difference_type;
47
48 static_assert(streamingLowerBound_V <= 20, "SIMD coders are limited to 20 BIT precision because of their used of FP arithmeric");
49
50 [[nodiscard]] inline static constexpr size_type getNstreams() noexcept { return 2 * simd::getElementCount<state_type>(simdWidth_V); };
51
52 SIMDEncoderImpl(size_t symbolTablePrecision);
53 SIMDEncoderImpl() : SIMDEncoderImpl{0} {};
54
55 // Flushes the rANS encoder.
56 template <typename Stream_IT>
57 Stream_IT flush(Stream_IT outputIter);
58
59 template <typename Stream_IT>
60 Stream_IT putSymbols(Stream_IT outputIter, const symbol_type& encodeSymbols);
61
62 template <typename Stream_IT>
63 Stream_IT putSymbols(Stream_IT outputIter, const symbol_type& encodeSymbols, size_t nActiveStreams);
64
65 [[nodiscard]] inline static constexpr state_type getStreamingLowerBound() noexcept { return static_cast<state_type>(utils::pow2(streamingLowerBound_V)); };
66
67 private:
68 size_t mSymbolTablePrecision{};
69 simd::simdI_t<simdWidth_V> mStates[2]{};
70 simd::simdD_t<simdWidth_V> mNSamples{};
71
72 template <typename Stream_IT>
73 Stream_IT putSymbol(Stream_IT outputIter, const Symbol& symbol, state_type& state);
74
75 template <typename Stream_IT>
76 Stream_IT flushState(state_type& state, Stream_IT outputIter);
77
78 // Renormalize the encoder.
79 template <typename Stream_IT>
80 std::tuple<state_type, Stream_IT> renorm(state_type state, Stream_IT outputIter, uint32_t frequency);
81
82 inline static constexpr state_type LowerBound = utils::pow2(streamingLowerBound_V); // lower bound of our normalization interval
83
84 inline static constexpr state_type StreamBits = utils::toBits<stream_type>(); // lower bound of our normalization interval
85};
86
87template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
88SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::SIMDEncoderImpl(size_t symbolTablePrecision) : mSymbolTablePrecision{symbolTablePrecision}, mStates{}, mNSamples{}
89{
90 if (mSymbolTablePrecision > LowerBound) {
91 throw HistogramError(fmt::format("SymbolTable Precision of {} Bits is larger than allowed by the rANS Encoder (max {} Bits)", mSymbolTablePrecision, LowerBound));
92 }
93
94 mStates[0] = simd::setAll<simdWidth_V>(LowerBound);
95 mStates[1] = simd::setAll<simdWidth_V>(LowerBound);
96
97 mNSamples = simd::setAll<simdWidth_V>(static_cast<double>(utils::pow2(mSymbolTablePrecision)));
98};
99
100template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
101template <typename Stream_IT>
102Stream_IT SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::flush(Stream_IT iter)
103{
104 using namespace simd;
105 epi64_t<simdWidth_V, 2> states;
106 store(mStates[0], states[0]);
107 store(mStates[1], states[1]);
108
109 Stream_IT streamPos = iter;
110 for (size_t stateIdx = states.nElements(); stateIdx-- > 0;) {
111 streamPos = flushState(*(states.data() + stateIdx), streamPos);
112 }
113
114 mStates[0] = load(states[0]);
115 mStates[1] = load(states[1]);
116
117 return streamPos;
118};
119
120template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
121template <typename Stream_IT>
122inline Stream_IT SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::putSymbols(Stream_IT outputIter, const symbol_type& symbols)
123{
124 using namespace simd;
125
126#if !defined(NDEBUG)
127 // for (const auto& symbol : symbols) {
128// // assert(symbol->getFrequency() != 0);
129// }
130#endif
131 simd::simdI_t<simdWidth_V> renormedStates[2];
132 auto streamPosition = ransRenorm<Stream_IT, LowerBound, StreamBits>(mStates,
133 symbols.frequencies,
134 static_cast<uint8_t>(mSymbolTablePrecision),
135 outputIter,
136 renormedStates);
137 mStates[0] = ransEncode(renormedStates[0], int32ToDouble<simdWidth_V>(symbols.frequencies[0]), int32ToDouble<simdWidth_V>(symbols.cumulativeFrequencies[0]), mNSamples);
138 mStates[1] = ransEncode(renormedStates[1], int32ToDouble<simdWidth_V>(symbols.frequencies[1]), int32ToDouble<simdWidth_V>(symbols.cumulativeFrequencies[1]), mNSamples);
139
140 return streamPosition;
141}
142
143template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
144template <typename Stream_IT>
145Stream_IT SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::putSymbols(Stream_IT outputIter, const symbol_type& symbols, size_t nActiveStreams)
146{
147 using namespace simd;
148
149 Stream_IT streamPos = outputIter;
150
151 epi64_t<simdWidth_V, 2> states;
152 store(mStates[0], states[0]);
153 store(mStates[1], states[1]);
154
155 epi32_t<SIMDWidth::SSE, 2> frequencies;
156 epi32_t<SIMDWidth::SSE, 2> cumulativeFrequencies;
157
158 store<uint32_t>(symbols.frequencies[0], frequencies[0]);
159 store<uint32_t>(symbols.frequencies[1], frequencies[1]);
160 store<uint32_t>(symbols.cumulativeFrequencies[0], cumulativeFrequencies[0]);
161 store<uint32_t>(symbols.cumulativeFrequencies[1], cumulativeFrequencies[1]);
162
163 for (size_t i = nActiveStreams; i-- > 0;) {
164 Symbol encodeSymbol{frequencies(i), cumulativeFrequencies(i)};
165 streamPos = putSymbol(streamPos, encodeSymbol, states(i));
166 }
167
168 mStates[0] = load(states[0]);
169 mStates[1] = load(states[1]);
170
171 return streamPos;
172};
173
174template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
175template <typename Stream_IT>
176Stream_IT SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::putSymbol(Stream_IT outputIter, const Symbol& symbol, state_type& state)
177{
178 assert(symbol.getFrequency() != 0); // can't encode symbol with freq=0
179 // renormalize
180 const auto [x, streamPos] = renorm(state, outputIter, symbol.getFrequency());
181
182 // x = C(s,x)
183 state = ((x / symbol.getFrequency()) << mSymbolTablePrecision) + (x % symbol.getFrequency()) + symbol.getCumulative();
184 return streamPos;
185}
186
187template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
188template <typename Stream_IT>
189Stream_IT SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::flushState(state_type& state, Stream_IT streamPosition)
190{
191 *streamPosition = static_cast<stream_type>(state >> 32);
192 ++streamPosition;
193 *streamPosition = static_cast<stream_type>(state >> 0);
194 ++streamPosition;
195
196 state = 0;
197 return streamPosition;
198}
199
200template <size_t streamingLowerBound_V, simd::SIMDWidth simdWidth_V>
201template <typename Stream_IT>
202inline auto SIMDEncoderImpl<streamingLowerBound_V, simdWidth_V>::renorm(state_type state, Stream_IT outputIter, uint32_t frequency) -> std::tuple<state_type, Stream_IT>
203{
204 state_type maxState = ((LowerBound >> mSymbolTablePrecision) << StreamBits) * frequency; // this turns into a shift.
205 if (state >= maxState) {
206 *outputIter = static_cast<stream_type>(state);
207 ++outputIter;
208 state >>= StreamBits;
209 assert(state < maxState);
210 }
211 return std::make_tuple(state, outputIter);
212};
213
// Convenience alias: SIMDEncoderImpl instantiated with simd::SIMDWidth::SSE.
214template <size_t streamingLowerBound_V>
215using SSEEncoderImpl = SIMDEncoderImpl<streamingLowerBound_V, simd::SIMDWidth::SSE>;
// Convenience alias: SIMDEncoderImpl instantiated with simd::SIMDWidth::AVX.
216template <size_t streamingLowerBound_V>
217using AVXEncoderImpl = SIMDEncoderImpl<streamingLowerBound_V, simd::SIMDWidth::AVX>;
218
219} // namespace o2::rans::internal
220
221#endif /* RANS_SIMD */
222
223#endif /* RANS_INTERNAL_ENCODE_SIMDENCODERIMPL_H_ */
benchmark::State & state
Defines the common operations for encoding data onto an rANS stream.
int32_t i
Contains statistical information for one source symbol, required for encoding/decoding.
common helper classes and functions
constexpr size_t StreamBits
constexpr size_t LowerBound
std::tuple< ransState_t, stream_IT > renorm(ransState_t state, stream_IT outputIter, count_t frequency, size_t symbolTablePrecision)
uint32_t stream_type
preprocessor defines to enable features based on CPU architecture
GLint GLenum GLint x
Definition glcorearb.h:403
GLuint * states
Definition glcorearb.h:4932
uint8_t itsSharedClusterMap uint8_t