Mercurial > hg > CbC > CbC_llvm
diff include/llvm/Support/BranchProbability.h @ 100:7d135dc70f03 LLVM 3.9
LLVM 3.9
author | Miyagi Mitsuki <e135756@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 26 Jan 2016 22:53:40 +0900 |
parents | afa8332a0e37 |
children | 1172e4bd9c6f |
line wrap: on
line diff
--- a/include/llvm/Support/BranchProbability.h Tue Oct 13 17:49:56 2015 +0900 +++ b/include/llvm/Support/BranchProbability.h Tue Jan 26 22:53:40 2016 +0900 @@ -15,7 +15,10 @@ #define LLVM_SUPPORT_BRANCHPROBABILITY_H #include "llvm/Support/DataTypes.h" +#include <algorithm> #include <cassert> +#include <climits> +#include <numeric> namespace llvm { @@ -31,27 +34,34 @@ // Denominator, which is a constant value. static const uint32_t D = 1u << 31; + static const uint32_t UnknownN = UINT32_MAX; // Construct a BranchProbability with only numerator assuming the denominator // is 1<<31. For internal use only. explicit BranchProbability(uint32_t n) : N(n) {} public: - BranchProbability() : N(0) {} + BranchProbability() : N(UnknownN) {} BranchProbability(uint32_t Numerator, uint32_t Denominator); bool isZero() const { return N == 0; } + bool isUnknown() const { return N == UnknownN; } static BranchProbability getZero() { return BranchProbability(0); } static BranchProbability getOne() { return BranchProbability(D); } + static BranchProbability getUnknown() { return BranchProbability(UnknownN); } // Create a BranchProbability object with the given numerator and 1<<31 // as denominator. static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); } + // Create a BranchProbability object from 64-bit integers. + static BranchProbability getBranchProbability(uint64_t Numerator, + uint64_t Denominator); // Normalize given probabilties so that the sum of them becomes approximate // one. - template <class ProbabilityList> - static void normalizeProbabilities(ProbabilityList &Probs); + template <class ProbabilityIter> + static void normalizeProbabilities(ProbabilityIter Begin, + ProbabilityIter End); uint32_t getNumerator() const { return N; } static uint32_t getDenominator() { return D; } @@ -80,24 +90,36 @@ uint64_t scaleByInverse(uint64_t Num) const; BranchProbability &operator+=(BranchProbability RHS) { - assert(N <= D - RHS.N && - "The sum of branch probabilities should not exceed one!"); - N += RHS.N; + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in arithmetics."); + // Saturate the result in case of overflow. + N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N; return *this; } BranchProbability &operator-=(BranchProbability RHS) { - assert(N >= RHS.N && - "Can only subtract a smaller probability from a larger one!"); - N -= RHS.N; + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in arithmetics."); + // Saturate the result in case of underflow. + N = N < RHS.N ? 0 : N - RHS.N; return *this; } BranchProbability &operator*=(BranchProbability RHS) { + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in arithmetics."); N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D; return *this; } + BranchProbability &operator/=(uint32_t RHS) { + assert(N != UnknownN && + "Unknown probability cannot participate in arithmetics."); + assert(RHS > 0 && "The divider cannot be zero."); + N /= RHS; + return *this; + } + BranchProbability operator+(BranchProbability RHS) const { BranchProbability Prob(*this); return Prob += RHS; @@ -113,26 +135,83 @@ return Prob *= RHS; } + BranchProbability operator/(uint32_t RHS) const { + BranchProbability Prob(*this); + return Prob /= RHS; + } + bool operator==(BranchProbability RHS) const { return N == RHS.N; } bool operator!=(BranchProbability RHS) const { return !(*this == RHS); } - bool operator<(BranchProbability RHS) const { return N < RHS.N; } - bool operator>(BranchProbability RHS) const { return RHS < *this; } - bool operator<=(BranchProbability RHS) const { return !(RHS < *this); } - bool operator>=(BranchProbability RHS) const { return !(*this < RHS); } + + bool operator<(BranchProbability RHS) const { + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in comparisons."); + return N < RHS.N; + } + + bool operator>(BranchProbability RHS) const { + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in comparisons."); + return RHS < *this; + } + + bool operator<=(BranchProbability RHS) const { + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in comparisons."); + return !(RHS < *this); + } + + bool operator>=(BranchProbability RHS) const { + assert(N != UnknownN && RHS.N != UnknownN && + "Unknown probability cannot participate in comparisons."); + return !(*this < RHS); + } }; inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { return Prob.print(OS); } -template <class ProbabilityList> -void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) { - uint64_t Sum = 0; - for (auto Prob : Probs) - Sum += Prob.N; - assert(Sum > 0); - for (auto &Prob : Probs) - Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum; +template <class ProbabilityIter> +void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, + ProbabilityIter End) { + if (Begin == End) + return; + + unsigned UnknownProbCount = 0; + uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), + [&](uint64_t S, const BranchProbability &BP) { + if (!BP.isUnknown()) + return S + BP.N; + UnknownProbCount++; + return S; + }); + + if (UnknownProbCount > 0) { + BranchProbability ProbForUnknown = BranchProbability::getZero(); + // If the sum of all known probabilities is less than one, evenly distribute + // the complement of sum to unknown probabilities. Otherwise, set unknown + // probabilities to zeros and continue to normalize known probabilities. + if (Sum < BranchProbability::getDenominator()) + ProbForUnknown = BranchProbability::getRaw( + (BranchProbability::getDenominator() - Sum) / UnknownProbCount); + + std::replace_if(Begin, End, + [](const BranchProbability &BP) { return BP.isUnknown(); }, + ProbForUnknown); + + if (Sum <= BranchProbability::getDenominator()) + return; + } + + if (Sum == 0) { + BranchProbability BP(1, std::distance(Begin, End)); + std::fill(Begin, End, BP); + return; + } + + for (auto I = Begin; I != End; ++I) + I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; } }