Mercurial > hg > CbC > CbC_llvm
diff include/llvm/Support/BranchProbability.h @ 95:afa8332a0e37 LLVM3.8
LLVM 3.8
author | Kaito Tokumori <e105711@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 13 Oct 2015 17:48:58 +0900 |
parents | 54457678186b |
children | 7d135dc70f03 |
line wrap: on
line diff
--- a/include/llvm/Support/BranchProbability.h Wed Feb 18 14:56:07 2015 +0900 +++ b/include/llvm/Support/BranchProbability.h Tue Oct 13 17:48:58 2015 +0900 @@ -21,30 +21,43 @@ class raw_ostream; -// This class represents Branch Probability as a non-negative fraction. +// This class represents Branch Probability as a non-negative fraction that is +// no greater than 1. It uses a fixed-point-like implementation, in which the +// denominator is always a constant value (here we use 1<<31 for maximum +// precision). class BranchProbability { // Numerator uint32_t N; - // Denominator - uint32_t D; + // Denominator, which is a constant value. + static const uint32_t D = 1u << 31; + + // Construct a BranchProbability with only numerator assuming the denominator + // is 1<<31. For internal use only. + explicit BranchProbability(uint32_t n) : N(n) {} public: - BranchProbability(uint32_t n, uint32_t d) : N(n), D(d) { - assert(d > 0 && "Denomiator cannot be 0!"); - assert(n <= d && "Probability cannot be bigger than 1!"); - } + BranchProbability() : N(0) {} + BranchProbability(uint32_t Numerator, uint32_t Denominator); + + bool isZero() const { return N == 0; } - static BranchProbability getZero() { return BranchProbability(0, 1); } - static BranchProbability getOne() { return BranchProbability(1, 1); } + static BranchProbability getZero() { return BranchProbability(0); } + static BranchProbability getOne() { return BranchProbability(D); } + // Create a BranchProbability object with the given numerator and 1<<31 + // as denominator. + static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); } + + // Normalize given probabilties so that the sum of them becomes approximate + // one. + template <class ProbabilityList> + static void normalizeProbabilities(ProbabilityList &Probs); uint32_t getNumerator() const { return N; } - uint32_t getDenominator() const { return D; } + static uint32_t getDenominator() { return D; } // Return (1 - Probability). - BranchProbability getCompl() const { - return BranchProbability(D - N, D); - } + BranchProbability getCompl() const { return BranchProbability(D - N); } raw_ostream &print(raw_ostream &OS) const; @@ -66,24 +79,62 @@ /// \return \c Num divided by \c this. uint64_t scaleByInverse(uint64_t Num) const; - bool operator==(BranchProbability RHS) const { - return (uint64_t)N * RHS.D == (uint64_t)D * RHS.N; + BranchProbability &operator+=(BranchProbability RHS) { + assert(N <= D - RHS.N && + "The sum of branch probabilities should not exceed one!"); + N += RHS.N; + return *this; + } + + BranchProbability &operator-=(BranchProbability RHS) { + assert(N >= RHS.N && + "Can only subtract a smaller probability from a larger one!"); + N -= RHS.N; + return *this; + } + + BranchProbability &operator*=(BranchProbability RHS) { + N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D; + return *this; } - bool operator!=(BranchProbability RHS) const { - return !(*this == RHS); + + BranchProbability operator+(BranchProbability RHS) const { + BranchProbability Prob(*this); + return Prob += RHS; + } + + BranchProbability operator-(BranchProbability RHS) const { + BranchProbability Prob(*this); + return Prob -= RHS; } - bool operator<(BranchProbability RHS) const { - return (uint64_t)N * RHS.D < (uint64_t)D * RHS.N; + + BranchProbability operator*(BranchProbability RHS) const { + BranchProbability Prob(*this); + return Prob *= RHS; } + + bool operator==(BranchProbability RHS) const { return N == RHS.N; } + bool operator!=(BranchProbability RHS) const { return !(*this == RHS); } + bool operator<(BranchProbability RHS) const { return N < RHS.N; } bool operator>(BranchProbability RHS) const { return RHS < *this; } bool operator<=(BranchProbability RHS) const { return !(RHS < *this); } bool operator>=(BranchProbability RHS) const { return !(*this < RHS); } }; -inline raw_ostream &operator<<(raw_ostream &OS, const BranchProbability &Prob) { +inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { return Prob.print(OS); } +template <class ProbabilityList> +void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) { + uint64_t Sum = 0; + for (auto Prob : Probs) + Sum += Prob.N; + assert(Sum > 0); + for (auto &Prob : Probs) + Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum; +} + } #endif