67 if (histogram.empty()) {
68 LOG(warning) <<
"rescaling Frequency Table for empty message";
70 using histogram_type = histogram_T;
71 using source_type =
typename histogram_type::source_type;
72 using count_type =
typename histogram_type::value_type;
73 using difference_type =
typename histogram_type::difference_type;
74 using container_type =
typename histogram_type::container_type;
75 using iterator_type =
typename container_type::iterator;
77 const double_t nSamples = histogram.getNumSamples();
78 const size_t renormingPrecisionBits = *
metrics.getCoderProperties().renormingPrecisionBits;
79 const size_t nUsedAlphabetSymbols =
metrics.getDatasetProperties().nUsedAlphabetSymbols;
81 const count_type nSamplesRescaled =
utils::pow2(renormingPrecisionBits);
82 const double_t probabilityCutOffThreshold = 1.0 /
static_cast<double_t
>(
utils::pow2(renormingPrecisionBits + lowProbabilityCutoffBits));
85 double_t incompressibleSymbolProbability = 0;
86 count_type nIncompressibleSamples = 0;
87 count_type nIncompressibleSymbols = 0;
88 count_type nSamplesRescaledUncorrected = 0;
89 std::vector<std::pair<source_type, std::reference_wrapper<count_type>>> correctableIndices;
90 correctableIndices.reserve(nUsedAlphabetSymbols);
92 auto scaleFrequency = [nSamplesRescaled](double_t symbolProbability) -> double_t {
return symbolProbability * nSamplesRescaled; };
94 container_type rescaledHistogram = std::move(histogram).release();
98 const double_t symbolProbability =
static_cast<double_t
>(frequency) / nSamples;
99 if (symbolProbability < probabilityCutOffThreshold) {
100 nIncompressibleSamples += frequency;
101 ++nIncompressibleSymbols;
102 incompressibleSymbolProbability += symbolProbability;
105 const double_t scaledFrequencyD = scaleFrequency(symbolProbability);
107 assert(rescaledFrequency > 0);
108 frequency = rescaledFrequency;
109 nSamplesRescaledUncorrected += rescaledFrequency;
110 if (rescaledFrequency > 1) {
111 correctableIndices.emplace_back(std::make_pair(
index, std::ref(frequency)));
118 const count_type incompressibleSymbolFrequency = [&]() -> count_type {
120 const bool requireIncompressible = incompressibleSymbolProbability > 0.
125 return std::max(
static_cast<count_type
>(requireIncompressible),
static_cast<count_type
>(incompressibleSymbolProbability * nSamplesRescaled));
128 nSamplesRescaledUncorrected += incompressibleSymbolFrequency;
131 const auto nSorted = [&]() {
132 const auto& datasetProperties =
metrics.getDatasetProperties();
135 for (
size_t i = 0;
i < datasetProperties.weightedSymbolLengthDistribution.size(); ++
i) {
136 cumulProbability += datasetProperties.weightedSymbolLengthDistribution[
i];
137 nSymbols += datasetProperties.symbolLengthDistribution[
i];
138 if (cumulProbability > 0.99) {
146 std::partial_sort(correctableIndices.begin(), correctableIndices.begin() + nSorted, correctableIndices.end(), [](
const auto&
a,
const auto&
b) { return a.second < b.second; });
148 std::stable_sort(correctableIndices.begin(), correctableIndices.end(), [](
const auto&
a,
const auto&
b) { return a.second < b.second; });
151 difference_type nCorrections =
static_cast<difference_type
>(nSamplesRescaled) -
static_cast<difference_type
>(nSamplesRescaledUncorrected);
152 const double_t rescalingFactor =
static_cast<double_t
>(nSamplesRescaled) /
static_cast<double_t
>(nSamplesRescaledUncorrected);
154 for (
auto& [
index,
value] : correctableIndices) {
155 if (std::abs(nCorrections) > 0) {
156 const difference_type uncorrectedFrequency =
value;
157 difference_type correction = uncorrectedFrequency - roundSymbolFrequency(uncorrectedFrequency * rescalingFactor);
159 if (nCorrections < 0) {
161 correction = std::max(1l, std::min(correction, std::abs(nCorrections)));
164 correction = std::min(-1l, std::max(correction, -nCorrections));
168 const count_type correctedFrequency = std::max(1l, uncorrectedFrequency - correction);
169 nCorrections += uncorrectedFrequency - correctedFrequency;
170 static_cast<count_type&
>(
value) = correctedFrequency;
176 if (std::abs(nCorrections) > 0) {
177 throw HistogramError(fmt::format(
"rANS rescaling incomplete: {} corrections Remaining", nCorrections));
180 auto& coderProperties =
metrics.getCoderProperties();
181 *coderProperties.renormingPrecisionBits = renormingPrecisionBits;
182 *coderProperties.nIncompressibleSymbols = nIncompressibleSymbols;
183 *coderProperties.nIncompressibleSamples = nIncompressibleSamples;
185 if constexpr (isDenseContainer_v<histogram_type>) {
187 std::tie(*coderProperties.min, *coderProperties.max) =
getMinMax(ret);
189 }
else if constexpr (isAdaptiveContainer_v<histogram_type>) {
191 std::tie(*coderProperties.min, *coderProperties.max) =
getMinMax(ret);
194 static_assert(isSetContainer_v<histogram_type>);
196 std::tie(*coderProperties.min, *coderProperties.max) =
getMinMax(ret);