Project
Loading...
Searching...
No Matches
renorm.h
Go to the documentation of this file.
// Copyright 2019-2023 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
11
15
16#ifndef RANS_INTERNAL_TRANSFORM_RENORM_H_
17#define RANS_INTERNAL_TRANSFORM_RENORM_H_
18
19#include <fairlogger/Logger.h>
20
28
29namespace o2::rans
30{
31
/// Selects how renorming treats the incompressible (escape) symbol.
enum class RenormingPolicy { Auto,                  // make data-driven decision if a symbol will be marked as incompressible
                             ForceIncompressible }; // add a default incompressible symbol even if data does not require it
34
35namespace renormImpl
36{
37
38template <typename source_T>
40{
41 if constexpr (sizeof(source_T) <= 2) {
42 const size_t nUsedAlphabetSymbols = f.empty() ? 0 : f.size();
43 return nUsedAlphabetSymbols;
44 } else {
46 }
47}
48
49template <typename source_T>
54
55template <typename source_T>
57{
58 return f.size();
59}
60
61template <typename histogram_T>
62decltype(auto) renorm(histogram_T histogram, Metrics<typename histogram_T::source_type>& metrics, RenormingPolicy renormingPolicy, size_t lowProbabilityCutoffBits = 0)
63{
64 using namespace o2::rans;
65 using namespace o2::rans::internal;
66
67 if (histogram.empty()) {
68 LOG(warning) << "rescaling Frequency Table for empty message";
69 }
70 using histogram_type = histogram_T;
71 using source_type = typename histogram_type::source_type;
72 using count_type = typename histogram_type::value_type;
73 using difference_type = typename histogram_type::difference_type;
74 using container_type = typename histogram_type::container_type;
75 using iterator_type = typename container_type::iterator;
76
77 const double_t nSamples = histogram.getNumSamples();
78 const size_t renormingPrecisionBits = *metrics.getCoderProperties().renormingPrecisionBits;
79 const size_t nUsedAlphabetSymbols = metrics.getDatasetProperties().nUsedAlphabetSymbols;
80
81 const count_type nSamplesRescaled = utils::pow2(renormingPrecisionBits);
82 const double_t probabilityCutOffThreshold = 1.0 / static_cast<double_t>(utils::pow2(renormingPrecisionBits + lowProbabilityCutoffBits));
83
84 // scaling
85 double_t incompressibleSymbolProbability = 0;
86 count_type nIncompressibleSamples = 0;
87 count_type nIncompressibleSymbols = 0;
88 count_type nSamplesRescaledUncorrected = 0;
89 std::vector<std::pair<source_type, std::reference_wrapper<count_type>>> correctableIndices;
90 correctableIndices.reserve(nUsedAlphabetSymbols);
91
92 auto scaleFrequency = [nSamplesRescaled](double_t symbolProbability) -> double_t { return symbolProbability * nSamplesRescaled; };
93
94 container_type rescaledHistogram = std::move(histogram).release();
95
96 forEachIndexValue(rescaledHistogram, [&](const source_type& index, count_t& frequency) {
97 if (frequency > 0) {
98 const double_t symbolProbability = static_cast<double_t>(frequency) / nSamples;
99 if (symbolProbability < probabilityCutOffThreshold) {
100 nIncompressibleSamples += frequency;
101 ++nIncompressibleSymbols;
102 incompressibleSymbolProbability += symbolProbability;
103 frequency = 0;
104 } else {
105 const double_t scaledFrequencyD = scaleFrequency(symbolProbability);
106 count_type rescaledFrequency = internal::roundSymbolFrequency(scaledFrequencyD);
107 assert(rescaledFrequency > 0);
108 frequency = rescaledFrequency;
109 nSamplesRescaledUncorrected += rescaledFrequency;
110 if (rescaledFrequency > 1) {
111 correctableIndices.emplace_back(std::make_pair(index, std::ref(frequency)));
112 }
113 }
114 }
115 });
116
117 // treat incompressible symbol:
118 const count_type incompressibleSymbolFrequency = [&]() -> count_type {
119 // The Escape symbol for incompressible data is required
120 const bool requireIncompressible = incompressibleSymbolProbability > 0. // if the algorithm eliminates infrequent symbols
121 || nSamples == 0 // if the message we built the histogram from was empty
122 || (renormingPolicy == RenormingPolicy::ForceIncompressible); // or we want to reuse the symbol table later with different data
123
124 // if requireIncompressible == false it casts into 0, else it casts into 1 which is exactly our lower bound for each case, and we avoid branching.
125 return std::max(static_cast<count_type>(requireIncompressible), static_cast<count_type>(incompressibleSymbolProbability * nSamplesRescaled));
126 }();
127
128 nSamplesRescaledUncorrected += incompressibleSymbolFrequency;
129
130 // correction
131 const auto nSorted = [&]() {
132 const auto& datasetProperties = metrics.getDatasetProperties();
133 float_t cumulProbability{};
134 size_t nSymbols{};
135 for (size_t i = 0; i < datasetProperties.weightedSymbolLengthDistribution.size(); ++i) {
136 cumulProbability += datasetProperties.weightedSymbolLengthDistribution[i];
137 nSymbols += datasetProperties.symbolLengthDistribution[i];
138 if (cumulProbability > 0.99) {
139 break;
140 }
141 }
142 return nSymbols;
143 }();
144
145 if ((nSorted < correctableIndices.size()) && (renormingPolicy != RenormingPolicy::ForceIncompressible)) {
146 std::partial_sort(correctableIndices.begin(), correctableIndices.begin() + nSorted, correctableIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
147 } else {
148 std::stable_sort(correctableIndices.begin(), correctableIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
149 }
150
151 difference_type nCorrections = static_cast<difference_type>(nSamplesRescaled) - static_cast<difference_type>(nSamplesRescaledUncorrected);
152 const double_t rescalingFactor = static_cast<double_t>(nSamplesRescaled) / static_cast<double_t>(nSamplesRescaledUncorrected);
153
154 for (auto& [index, value] : correctableIndices) {
155 if (std::abs(nCorrections) > 0) {
156 const difference_type uncorrectedFrequency = value;
157 difference_type correction = uncorrectedFrequency - roundSymbolFrequency(uncorrectedFrequency * rescalingFactor);
158
159 if (nCorrections < 0) {
160 // overshoot - correct downwards by subtracting correction in [1,|nCorrections|]
161 correction = std::max(1l, std::min(correction, std::abs(nCorrections)));
162 } else {
163 // correct upwards by subtracting correction in [-1, -nCorrections]
164 correction = std::min(-1l, std::max(correction, -nCorrections));
165 }
166
167 // the corrected frequency must be at least 1 though
168 const count_type correctedFrequency = std::max(1l, uncorrectedFrequency - correction);
169 nCorrections += uncorrectedFrequency - correctedFrequency;
170 static_cast<count_type&>(value) = correctedFrequency;
171 } else {
172 break;
173 }
174 }
175
176 if (std::abs(nCorrections) > 0) {
177 throw HistogramError(fmt::format("rANS rescaling incomplete: {} corrections Remaining", nCorrections));
178 }
179
180 auto& coderProperties = metrics.getCoderProperties();
181 *coderProperties.renormingPrecisionBits = renormingPrecisionBits;
182 *coderProperties.nIncompressibleSymbols = nIncompressibleSymbols;
183 *coderProperties.nIncompressibleSamples = nIncompressibleSamples;
184
185 if constexpr (isDenseContainer_v<histogram_type>) {
186 RenormedDenseHistogram<source_type> ret{std::move(rescaledHistogram), renormingPrecisionBits, incompressibleSymbolFrequency};
187 std::tie(*coderProperties.min, *coderProperties.max) = getMinMax(ret);
188 return ret;
189 } else if constexpr (isAdaptiveContainer_v<histogram_type>) {
190 RenormedAdaptiveHistogram<source_type> ret{std::move(rescaledHistogram), renormingPrecisionBits, incompressibleSymbolFrequency};
191 std::tie(*coderProperties.min, *coderProperties.max) = getMinMax(ret);
192 return ret;
193 } else {
194 static_assert(isSetContainer_v<histogram_type>);
195 RenormedSparseHistogram<source_type> ret{std::move(rescaledHistogram), renormingPrecisionBits, incompressibleSymbolFrequency};
196 std::tie(*coderProperties.min, *coderProperties.max) = getMinMax(ret);
197 return ret;
198 }
199};
200} // namespace renormImpl
201
202template <typename histogram_T>
203decltype(auto) renorm(histogram_T histogram, size_t newPrecision, RenormingPolicy renormingPolicy = RenormingPolicy::Auto, size_t lowProbabilityCutoffBits = 0)
204{
205 using source_type = typename histogram_T::source_type;
206 const size_t nUsedAlphabetSymbols = renormImpl::getNUsedAlphabetSymbols(histogram);
209 metrics.getDatasetProperties().nUsedAlphabetSymbols = nUsedAlphabetSymbols;
210 return renormImpl::renorm(std::move(histogram), metrics, renormingPolicy, lowProbabilityCutoffBits);
211};
212
213template <typename histogram_T>
214decltype(auto) renorm(histogram_T histogram, Metrics<typename histogram_T::source_type>& metrics, RenormingPolicy renormingPolicy = RenormingPolicy::Auto, size_t lowProbabilityCutoffBits = 0)
215{
216 return renormImpl::renorm(std::move(histogram), metrics, renormingPolicy, lowProbabilityCutoffBits);
217};
218
219template <typename histogram_T>
220decltype(auto) renorm(histogram_T histogram, RenormingPolicy renormingPolicy = RenormingPolicy::Auto)
221{
222 using source_type = typename histogram_T::source_type;
223 Metrics<source_type> metrics{histogram};
224 return renorm(std::move(histogram), metrics, renormingPolicy);
225};
226
227} // namespace o2::rans
228
229#endif /* RANS_INTERNAL_TRANSFORM_RENORM_H_ */
Histogram for source symbols used to estimate symbol probabilities for entropy coding.
int32_t i
Computes and provides essential metrics on the dataset used for parameter and size estimates by other...
Histogram renormed to sum of frequencies being 2^P for use in fast rans coding.
Histogram to depict frequencies of source symbols for rANS compression, based on an ordered set.
common helper classes and functions
helper functionalities useful for packing operations
uint32_t source_type
const CoderProperties< source_type > & getCoderProperties() const noexcept
Definition Metrics.h:53
GLuint index
Definition glcorearb.h:781
GLsizei GLenum const void GLuint GLsizei GLfloat * metrics
Definition glcorearb.h:5500
GLdouble f
Definition glcorearb.h:310
GLboolean GLboolean GLboolean b
Definition glcorearb.h:1233
GLsizei const GLfloat * value
Definition glcorearb.h:819
GLboolean GLboolean GLboolean GLboolean a
Definition glcorearb.h:1233
count_t roundSymbolFrequency(double_t rescaledFrequency)
Definition utils.h:109
size_t getNUsedAlphabetSymbols(const DenseHistogram< source_T > &f)
Definition renorm.h:39
decltype(auto) renorm(histogram_T histogram, Metrics< typename histogram_T::source_type > &metrics, RenormingPolicy renormingPolicy, size_t lowProbabilityCutoffBits=0)
Definition renorm.h:62
constexpr size_t pow2(size_t n) noexcept
Definition utils.h:165
size_t countNUsedAlphabetSymbols(const AdaptiveHistogram< source_T > &histogram)
std::pair< source_T, source_T > getMinMax(const AdaptiveSymbolTable< source_T, symbol_T > &symbolTable)
decltype(auto) renorm(histogram_T histogram, size_t newPrecision, RenormingPolicy renormingPolicy=RenormingPolicy::Auto, size_t lowProbabilityCutoffBits=0)
Definition renorm.h:203
uint32_t count_t
Definition defaults.h:34
RenormingPolicy
Definition renorm.h:32
std::optional< size_t > renormingPrecisionBits
Definition properties.h:35
LOG(info)<< "Compressed in "<< sw.CpuTime()<< " s"