Mercurial > hg > CbC > CbC_llvm
diff llvm/lib/IR/ConstantRange.cpp @ 236:c4bab56944e8 llvm-original
LLVM 16
author | kono |
---|---|
date | Wed, 09 Nov 2022 17:45:10 +0900 |
parents | 5f17cb93ff66 |
children | 1f2b6ac9f198 |
line wrap: on
line diff
--- a/llvm/lib/IR/ConstantRange.cpp Wed Jul 21 10:27:27 2021 +0900 +++ b/llvm/lib/IR/ConstantRange.cpp Wed Nov 09 17:45:10 2022 +0900 @@ -75,6 +75,24 @@ return ConstantRange(Lower, Upper + 1); } +KnownBits ConstantRange::toKnownBits() const { + // TODO: We could return conflicting known bits here, but consumers are + // likely not prepared for that. + if (isEmptySet()) + return KnownBits(getBitWidth()); + + // We can only retain the top bits that are the same between min and max. + APInt Min = getUnsignedMin(); + APInt Max = getUnsignedMax(); + KnownBits Known = KnownBits::makeConstant(Min); + if (Optional<unsigned> DifferentBit = + APIntOps::GetMostSignificantDifferentBit(Min, Max)) { + Known.Zero.clearLowBits(*DifferentBit + 1); + Known.One.clearLowBits(*DifferentBit + 1); + } + return Known; +} + ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, const ConstantRange &CR) { if (CR.isEmptySet()) @@ -110,7 +128,7 @@ APInt UMin(CR.getUnsignedMin()); if (UMin.isMaxValue()) return getEmpty(W); - return ConstantRange(std::move(UMin) + 1, APInt::getNullValue(W)); + return ConstantRange(std::move(UMin) + 1, APInt::getZero(W)); } case CmpInst::ICMP_SGT: { APInt SMin(CR.getSignedMin()); @@ -119,7 +137,7 @@ return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W)); } case CmpInst::ICMP_UGE: - return getNonEmpty(CR.getUnsignedMin(), APInt::getNullValue(W)); + return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W)); case CmpInst::ICMP_SGE: return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)); } @@ -147,38 +165,77 @@ return makeAllowedICmpRegion(Pred, C); } -bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred, - APInt &RHS) const { - bool Success = false; +bool ConstantRange::areInsensitiveToSignednessOfICmpPredicate( + const ConstantRange &CR1, const ConstantRange &CR2) { + if (CR1.isEmptySet() || CR2.isEmptySet()) + return true; + + return (CR1.isAllNonNegative() && CR2.isAllNonNegative()) || + (CR1.isAllNegative() && CR2.isAllNegative()); +} + +bool ConstantRange::areInsensitiveToSignednessOfInvertedICmpPredicate( + const ConstantRange &CR1, const ConstantRange &CR2) { + if (CR1.isEmptySet() || CR2.isEmptySet()) + return true; + + return (CR1.isAllNonNegative() && CR2.isAllNegative()) || + (CR1.isAllNegative() && CR2.isAllNonNegative()); +} +CmpInst::Predicate ConstantRange::getEquivalentPredWithFlippedSignedness( + CmpInst::Predicate Pred, const ConstantRange &CR1, + const ConstantRange &CR2) { + assert(CmpInst::isIntPredicate(Pred) && CmpInst::isRelational(Pred) && + "Only for relational integer predicates!"); + + CmpInst::Predicate FlippedSignednessPred = + CmpInst::getFlippedSignednessPredicate(Pred); + + if (areInsensitiveToSignednessOfICmpPredicate(CR1, CR2)) + return FlippedSignednessPred; + + if (areInsensitiveToSignednessOfInvertedICmpPredicate(CR1, CR2)) + return CmpInst::getInversePredicate(FlippedSignednessPred); + + return CmpInst::Predicate::BAD_ICMP_PREDICATE; +} + +void ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred, + APInt &RHS, APInt &Offset) const { + Offset = APInt(getBitWidth(), 0); if (isFullSet() || isEmptySet()) { Pred = isEmptySet() ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; RHS = APInt(getBitWidth(), 0); - Success = true; } else if (auto *OnlyElt = getSingleElement()) { Pred = CmpInst::ICMP_EQ; RHS = *OnlyElt; - Success = true; } else if (auto *OnlyMissingElt = getSingleMissingElement()) { Pred = CmpInst::ICMP_NE; RHS = *OnlyMissingElt; - Success = true; } else if (getLower().isMinSignedValue() || getLower().isMinValue()) { Pred = getLower().isMinSignedValue() ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; RHS = getUpper(); - Success = true; } else if (getUpper().isMinSignedValue() || getUpper().isMinValue()) { Pred = getUpper().isMinSignedValue() ? CmpInst::ICMP_SGE : CmpInst::ICMP_UGE; RHS = getLower(); - Success = true; + } else { + Pred = CmpInst::ICMP_ULT; + RHS = getUpper() - getLower(); + Offset = -getLower(); } - assert((!Success || ConstantRange::makeExactICmpRegion(Pred, RHS) == *this) && + assert(ConstantRange::makeExactICmpRegion(Pred, RHS) == add(Offset) && "Bad result!"); +} - return Success; +bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred, + APInt &RHS) const { + APInt Offset; + getEquivalentICmp(Pred, RHS, Offset); + return Offset.isZero(); } bool ConstantRange::icmp(CmpInst::Predicate Pred, @@ -204,13 +261,13 @@ // Handle special case for 0, -1 and 1. See the last for reason why we // specialize -1 and 1. unsigned BitWidth = V.getBitWidth(); - if (V == 0 || V.isOneValue()) + if (V == 0 || V.isOne()) return ConstantRange::getFull(BitWidth); APInt MinValue = APInt::getSignedMinValue(BitWidth); APInt MaxValue = APInt::getSignedMaxValue(BitWidth); // e.g. Returning [-127, 127], represented as [-127, -128). - if (V.isAllOnesValue()) + if (V.isAllOnes()) return ConstantRange(-MaxValue, MinValue); APInt Lower, Upper; @@ -248,8 +305,7 @@ case Instruction::Add: { if (Unsigned) - return getNonEmpty(APInt::getNullValue(BitWidth), - -Other.getUnsignedMax()); + return getNonEmpty(APInt::getZero(BitWidth), -Other.getUnsignedMax()); APInt SignedMinVal = APInt::getSignedMinValue(BitWidth); APInt SMin = Other.getSignedMin(), SMax = Other.getSignedMax(); @@ -291,7 +347,7 @@ // to be at most bitwidth-1, which results in most conservative range. APInt ShAmtUMax = ShAmt.getUnsignedMax(); if (Unsigned) - return getNonEmpty(APInt::getNullValue(BitWidth), + return getNonEmpty(APInt::getZero(BitWidth), APInt::getMaxValue(BitWidth).lshr(ShAmtUMax) + 1); return getNonEmpty(APInt::getSignedMinValue(BitWidth).ashr(ShAmtUMax), APInt::getSignedMaxValue(BitWidth).ashr(ShAmtUMax) + 1); @@ -316,7 +372,7 @@ } bool ConstantRange::isWrappedSet() const { - return Lower.ugt(Upper) && !Upper.isNullValue(); + return Lower.ugt(Upper) && !Upper.isZero(); } bool ConstantRange::isUpperWrapped() const { @@ -343,11 +399,10 @@ bool ConstantRange::isSizeLargerThan(uint64_t MaxSize) const { - assert(MaxSize && "MaxSize can't be 0."); // If this a full set, we need special handling to avoid needing an extra bit // to represent the size. if (isFullSet()) - return APInt::getMaxValue(getBitWidth()).ugt(MaxSize - 1); + return MaxSize == 0 || APInt::getMaxValue(getBitWidth()).ugt(MaxSize - 1); return (Upper - Lower).ugt(MaxSize); } @@ -595,7 +650,7 @@ APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower; APInt U = (CR.Upper - 1).ugt(Upper - 1) ? CR.Upper : Upper; - if (L.isNullValue() && U.isNullValue()) + if (L.isZero() && U.isZero()) return getFull(); return ConstantRange(std::move(L), std::move(U)); @@ -644,6 +699,24 @@ return ConstantRange(std::move(L), std::move(U)); } +Optional<ConstantRange> +ConstantRange::exactIntersectWith(const ConstantRange &CR) const { + // TODO: This can be implemented more efficiently. + ConstantRange Result = intersectWith(CR); + if (Result == inverse().unionWith(CR.inverse()).inverse()) + return Result; + return None; +} + +Optional<ConstantRange> +ConstantRange::exactUnionWith(const ConstantRange &CR) const { + // TODO: This can be implemented more efficiently. + ConstantRange Result = unionWith(CR); + if (Result == inverse().intersectWith(CR.inverse()).inverse()) + return Result; + return None; +} + ConstantRange ConstantRange::castOp(Instruction::CastOps CastOp, uint32_t ResultBitWidth) const { switch (CastOp) { @@ -666,15 +739,23 @@ case Instruction::UIToFP: { // TODO: use input range if available auto BW = getBitWidth(); - APInt Min = APInt::getMinValue(BW).zextOrSelf(ResultBitWidth); - APInt Max = APInt::getMaxValue(BW).zextOrSelf(ResultBitWidth); + APInt Min = APInt::getMinValue(BW); + APInt Max = APInt::getMaxValue(BW); + if (ResultBitWidth > BW) { + Min = Min.zext(ResultBitWidth); + Max = Max.zext(ResultBitWidth); + } return ConstantRange(std::move(Min), std::move(Max)); } case Instruction::SIToFP: { // TODO: use input range if available auto BW = getBitWidth(); - APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(ResultBitWidth); - APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(ResultBitWidth); + APInt SMin = APInt::getSignedMinValue(BW); + APInt SMax = APInt::getSignedMaxValue(BW); + if (ResultBitWidth > BW) { + SMin = SMin.sext(ResultBitWidth); + SMax = SMax.sext(ResultBitWidth); + } return ConstantRange(std::move(SMin), std::move(SMax)); } case Instruction::FPTrunc: @@ -1055,6 +1136,25 @@ return UR.isSizeStrictlySmallerThan(SR) ? UR : SR; } +ConstantRange ConstantRange::smul_fast(const ConstantRange &Other) const { + if (isEmptySet() || Other.isEmptySet()) + return getEmpty(); + + APInt Min = getSignedMin(); + APInt Max = getSignedMax(); + APInt OtherMin = Other.getSignedMin(); + APInt OtherMax = Other.getSignedMax(); + + bool O1, O2, O3, O4; + auto Muls = {Min.smul_ov(OtherMin, O1), Min.smul_ov(OtherMax, O2), + Max.smul_ov(OtherMin, O3), Max.smul_ov(OtherMax, O4)}; + if (O1 || O2 || O3 || O4) + return getFull(); + + auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); }; + return getNonEmpty(std::min(Muls, Compare), std::max(Muls, Compare) + 1); +} + ConstantRange ConstantRange::smax(const ConstantRange &Other) const { // X smax Y is: range(smax(X_smin, Y_smin), @@ -1113,13 +1213,13 @@ ConstantRange ConstantRange::udiv(const ConstantRange &RHS) const { - if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isNullValue()) + if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero()) return getEmpty(); APInt Lower = getUnsignedMin().udiv(RHS.getUnsignedMax()); APInt RHS_umin = RHS.getUnsignedMin(); - if (RHS_umin.isNullValue()) { + if (RHS_umin.isZero()) { // We want the lowest value in RHS excluding zero. Usually that would be 1 // except for a range in the form of [X, 1) in which case it would be X. if (RHS.getUpper() == 1) @@ -1136,9 +1236,12 @@ // We split up the LHS and RHS into positive and negative components // and then also compute the positive and negative components of the result // separately by combining division results with the appropriate signs. - APInt Zero = APInt::getNullValue(getBitWidth()); + APInt Zero = APInt::getZero(getBitWidth()); APInt SignedMin = APInt::getSignedMinValue(getBitWidth()); - ConstantRange PosFilter(APInt(getBitWidth(), 1), SignedMin); + // There are no positive 1-bit values. The 1 would get interpreted as -1. + ConstantRange PosFilter = + getBitWidth() == 1 ? getEmpty() + : ConstantRange(APInt(getBitWidth(), 1), SignedMin); ConstantRange NegFilter(SignedMin, Zero); ConstantRange PosL = intersectWith(PosFilter); ConstantRange NegL = intersectWith(NegFilter); @@ -1159,12 +1262,12 @@ // (For APInts the operation is well-defined and yields SignedMin.) We // handle this by dropping either SignedMin from the LHS or -1 from the RHS. APInt Lo = (NegL.Upper - 1).sdiv(NegR.Lower); - if (NegL.Lower.isMinSignedValue() && NegR.Upper.isNullValue()) { + if (NegL.Lower.isMinSignedValue() && NegR.Upper.isZero()) { // Remove -1 from the LHS. Skip if it's the only element, as this would // leave us with an empty set. - if (!NegR.Lower.isAllOnesValue()) { + if (!NegR.Lower.isAllOnes()) { APInt AdjNegRUpper; - if (RHS.Lower.isAllOnesValue()) + if (RHS.Lower.isAllOnes()) // Negative part of [-1, X] without -1 is [SignedMin, X]. AdjNegRUpper = RHS.Upper; else @@ -1218,12 +1321,12 @@ } ConstantRange ConstantRange::urem(const ConstantRange &RHS) const { - if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isNullValue()) + if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero()) return getEmpty(); if (const APInt *RHSInt = RHS.getSingleElement()) { // UREM by null is UB. - if (RHSInt->isNullValue()) + if (RHSInt->isZero()) return getEmpty(); // Use APInt's implementation of UREM for single element ranges. if (const APInt *LHSInt = getSingleElement()) @@ -1236,7 +1339,7 @@ // L % R is <= L and < R. APInt Upper = APIntOps::umin(getUnsignedMax(), RHS.getUnsignedMax() - 1) + 1; - return getNonEmpty(APInt::getNullValue(getBitWidth()), std::move(Upper)); + return getNonEmpty(APInt::getZero(getBitWidth()), std::move(Upper)); } ConstantRange ConstantRange::srem(const ConstantRange &RHS) const { @@ -1245,7 +1348,7 @@ if (const APInt *RHSInt = RHS.getSingleElement()) { // SREM by null is UB. - if (RHSInt->isNullValue()) + if (RHSInt->isZero()) return getEmpty(); // Use APInt's implementation of SREM for single element ranges. if (const APInt *LHSInt = getSingleElement()) @@ -1257,10 +1360,10 @@ APInt MaxAbsRHS = AbsRHS.getUnsignedMax(); // Modulus by zero is UB. - if (MaxAbsRHS.isNullValue()) + if (MaxAbsRHS.isZero()) return getEmpty(); - if (MinAbsRHS.isNullValue()) + if (MinAbsRHS.isZero()) ++MinAbsRHS; APInt MinLHS = getSignedMin(), MaxLHS = getSignedMax(); @@ -1272,7 +1375,7 @@ // L % R is <= L and < R. APInt Upper = APIntOps::umin(MaxLHS, MaxAbsRHS - 1) + 1; - return ConstantRange(APInt::getNullValue(getBitWidth()), std::move(Upper)); + return ConstantRange(APInt::getZero(getBitWidth()), std::move(Upper)); } // Same basic logic as above, but the result is negative. @@ -1291,37 +1394,32 @@ } ConstantRange ConstantRange::binaryNot() const { - return ConstantRange(APInt::getAllOnesValue(getBitWidth())).sub(*this); + return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this); } -ConstantRange -ConstantRange::binaryAnd(const ConstantRange &Other) const { +ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const { if (isEmptySet() || Other.isEmptySet()) return getEmpty(); - // Use APInt's implementation of AND for single element ranges. - if (isSingleElement() && Other.isSingleElement()) - return {*getSingleElement() & *Other.getSingleElement()}; - - // TODO: replace this with something less conservative - - APInt umin = APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()); - return getNonEmpty(APInt::getNullValue(getBitWidth()), std::move(umin) + 1); + ConstantRange KnownBitsRange = + fromKnownBits(toKnownBits() & Other.toKnownBits(), false); + ConstantRange UMinUMaxRange = + getNonEmpty(APInt::getZero(getBitWidth()), + APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1); + return KnownBitsRange.intersectWith(UMinUMaxRange); } -ConstantRange -ConstantRange::binaryOr(const ConstantRange &Other) const { +ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const { if (isEmptySet() || Other.isEmptySet()) return getEmpty(); - // Use APInt's implementation of OR for single element ranges. - if (isSingleElement() && Other.isSingleElement()) - return {*getSingleElement() | *Other.getSingleElement()}; - - // TODO: replace this with something less conservative - - APInt umax = APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()); - return getNonEmpty(std::move(umax), APInt::getNullValue(getBitWidth())); + ConstantRange KnownBitsRange = + fromKnownBits(toKnownBits() | Other.toKnownBits(), false); + // Upper wrapped range. + ConstantRange UMaxUMinRange = + getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), + APInt::getZero(getBitWidth())); + return KnownBitsRange.intersectWith(UMaxUMinRange); } ConstantRange ConstantRange::binaryXor(const ConstantRange &Other) const { @@ -1333,13 +1431,12 @@ return {*getSingleElement() ^ *Other.getSingleElement()}; // Special-case binary complement, since we can give a precise answer. - if (Other.isSingleElement() && Other.getSingleElement()->isAllOnesValue()) + if (Other.isSingleElement() && Other.getSingleElement()->isAllOnes()) return binaryNot(); - if (isSingleElement() && getSingleElement()->isAllOnesValue()) + if (isSingleElement() && getSingleElement()->isAllOnes()) return Other.binaryNot(); - // TODO: replace this with something less conservative - return getFull(); + return fromKnownBits(toKnownBits() ^ Other.toKnownBits(), /*IsSigned*/false); } ConstantRange @@ -1347,24 +1444,33 @@ if (isEmptySet() || Other.isEmptySet()) return getEmpty(); - APInt max = getUnsignedMax(); - APInt Other_umax = Other.getUnsignedMax(); + APInt Min = getUnsignedMin(); + APInt Max = getUnsignedMax(); + if (const APInt *RHS = Other.getSingleElement()) { + unsigned BW = getBitWidth(); + if (RHS->uge(BW)) + return getEmpty(); - // If we are shifting by maximum amount of - // zero return return the original range. - if (Other_umax.isNullValue()) - return *this; - // there's overflow! - if (Other_umax.ugt(max.countLeadingZeros())) + unsigned EqualLeadingBits = (Min ^ Max).countLeadingZeros(); + if (RHS->ule(EqualLeadingBits)) + return getNonEmpty(Min << *RHS, (Max << *RHS) + 1); + + return getNonEmpty(APInt::getZero(BW), + APInt::getBitsSetFrom(BW, RHS->getZExtValue()) + 1); + } + + APInt OtherMax = Other.getUnsignedMax(); + + // There's overflow! + if (OtherMax.ugt(Max.countLeadingZeros())) return getFull(); // FIXME: implement the other tricky cases - APInt min = getUnsignedMin(); - min <<= Other.getUnsignedMin(); - max <<= Other_umax; + Min <<= Other.getUnsignedMin(); + Max <<= OtherMax; - return ConstantRange(std::move(min), std::move(max) + 1); + return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1); } ConstantRange @@ -1483,20 +1589,15 @@ // [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6. // Similarly for the upper bound, swapping min for max. - APInt this_min = getSignedMin().sext(getBitWidth() * 2); - APInt this_max = getSignedMax().sext(getBitWidth() * 2); - APInt Other_min = Other.getSignedMin().sext(getBitWidth() * 2); - APInt Other_max = Other.getSignedMax().sext(getBitWidth() * 2); + APInt Min = getSignedMin(); + APInt Max = getSignedMax(); + APInt OtherMin = Other.getSignedMin(); + APInt OtherMax = Other.getSignedMax(); - auto L = {this_min * Other_min, this_min * Other_max, this_max * Other_min, - this_max * Other_max}; + auto L = {Min.smul_sat(OtherMin), Min.smul_sat(OtherMax), + Max.smul_sat(OtherMin), Max.smul_sat(OtherMax)}; auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); }; - - // Note that we wanted to perform signed saturating multiplication, - // so since we performed plain multiplication in twice the bitwidth, - // we need to perform signed saturating truncation. - return getNonEmpty(std::min(L, Compare).truncSSat(getBitWidth()), - std::max(L, Compare).truncSSat(getBitWidth()) + 1); + return getNonEmpty(std::min(L, Compare), std::max(L, Compare) + 1); } ConstantRange ConstantRange::ushl_sat(const ConstantRange &Other) const { @@ -1535,7 +1636,7 @@ APInt Lo; // Check whether the range crosses zero. if (Upper.isStrictlyPositive() || !Lower.isStrictlyPositive()) - Lo = APInt::getNullValue(getBitWidth()); + Lo = APInt::getZero(getBitWidth()); else Lo = APIntOps::umin(Lower, -Upper + 1); @@ -1565,7 +1666,7 @@ return ConstantRange(-SMax, -SMin + 1); // Range crosses zero. - return ConstantRange(APInt::getNullValue(getBitWidth()), + return ConstantRange(APInt::getZero(getBitWidth()), APIntOps::umax(-SMin, SMax) + 1); }