comparison llvm/lib/IR/ConstantRange.cpp @ 239:173fe712db74

merge LLVM16
author kono
date Wed, 09 Nov 2022 18:03:41 +0900
parents c4bab56944e8
children 1f2b6ac9f198
comparison
equal deleted inserted replaced
238:8222e65f95b1 239:173fe712db74
73 Lower.setSignBit(); 73 Lower.setSignBit();
74 Upper.clearSignBit(); 74 Upper.clearSignBit();
75 return ConstantRange(Lower, Upper + 1); 75 return ConstantRange(Lower, Upper + 1);
76 } 76 }
77 77
78 KnownBits ConstantRange::toKnownBits() const {
79 // TODO: We could return conflicting known bits here, but consumers are
80 // likely not prepared for that.
81 if (isEmptySet())
82 return KnownBits(getBitWidth());
83
84 // We can only retain the top bits that are the same between min and max.
85 APInt Min = getUnsignedMin();
86 APInt Max = getUnsignedMax();
87 KnownBits Known = KnownBits::makeConstant(Min);
88 if (Optional<unsigned> DifferentBit =
89 APIntOps::GetMostSignificantDifferentBit(Min, Max)) {
90 Known.Zero.clearLowBits(*DifferentBit + 1);
91 Known.One.clearLowBits(*DifferentBit + 1);
92 }
93 return Known;
94 }
95
78 ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, 96 ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
79 const ConstantRange &CR) { 97 const ConstantRange &CR) {
80 if (CR.isEmptySet()) 98 if (CR.isEmptySet())
81 return CR; 99 return CR;
82 100
108 return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1); 126 return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
109 case CmpInst::ICMP_UGT: { 127 case CmpInst::ICMP_UGT: {
110 APInt UMin(CR.getUnsignedMin()); 128 APInt UMin(CR.getUnsignedMin());
111 if (UMin.isMaxValue()) 129 if (UMin.isMaxValue())
112 return getEmpty(W); 130 return getEmpty(W);
113 return ConstantRange(std::move(UMin) + 1, APInt::getNullValue(W)); 131 return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
114 } 132 }
115 case CmpInst::ICMP_SGT: { 133 case CmpInst::ICMP_SGT: {
116 APInt SMin(CR.getSignedMin()); 134 APInt SMin(CR.getSignedMin());
117 if (SMin.isMaxSignedValue()) 135 if (SMin.isMaxSignedValue())
118 return getEmpty(W); 136 return getEmpty(W);
119 return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W)); 137 return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
120 } 138 }
121 case CmpInst::ICMP_UGE: 139 case CmpInst::ICMP_UGE:
122 return getNonEmpty(CR.getUnsignedMin(), APInt::getNullValue(W)); 140 return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
123 case CmpInst::ICMP_SGE: 141 case CmpInst::ICMP_SGE:
124 return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)); 142 return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
125 } 143 }
126 } 144 }
127 145
145 // 163 //
146 assert(makeAllowedICmpRegion(Pred, C) == makeSatisfyingICmpRegion(Pred, C)); 164 assert(makeAllowedICmpRegion(Pred, C) == makeSatisfyingICmpRegion(Pred, C));
147 return makeAllowedICmpRegion(Pred, C); 165 return makeAllowedICmpRegion(Pred, C);
148 } 166 }
149 167
150 bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred, 168 bool ConstantRange::areInsensitiveToSignednessOfICmpPredicate(
151 APInt &RHS) const { 169 const ConstantRange &CR1, const ConstantRange &CR2) {
152 bool Success = false; 170 if (CR1.isEmptySet() || CR2.isEmptySet())
153 171 return true;
172
173 return (CR1.isAllNonNegative() && CR2.isAllNonNegative()) ||
174 (CR1.isAllNegative() && CR2.isAllNegative());
175 }
176
177 bool ConstantRange::areInsensitiveToSignednessOfInvertedICmpPredicate(
178 const ConstantRange &CR1, const ConstantRange &CR2) {
179 if (CR1.isEmptySet() || CR2.isEmptySet())
180 return true;
181
182 return (CR1.isAllNonNegative() && CR2.isAllNegative()) ||
183 (CR1.isAllNegative() && CR2.isAllNonNegative());
184 }
185
186 CmpInst::Predicate ConstantRange::getEquivalentPredWithFlippedSignedness(
187 CmpInst::Predicate Pred, const ConstantRange &CR1,
188 const ConstantRange &CR2) {
189 assert(CmpInst::isIntPredicate(Pred) && CmpInst::isRelational(Pred) &&
190 "Only for relational integer predicates!");
191
192 CmpInst::Predicate FlippedSignednessPred =
193 CmpInst::getFlippedSignednessPredicate(Pred);
194
195 if (areInsensitiveToSignednessOfICmpPredicate(CR1, CR2))
196 return FlippedSignednessPred;
197
198 if (areInsensitiveToSignednessOfInvertedICmpPredicate(CR1, CR2))
199 return CmpInst::getInversePredicate(FlippedSignednessPred);
200
201 return CmpInst::Predicate::BAD_ICMP_PREDICATE;
202 }
203
204 void ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
205 APInt &RHS, APInt &Offset) const {
206 Offset = APInt(getBitWidth(), 0);
154 if (isFullSet() || isEmptySet()) { 207 if (isFullSet() || isEmptySet()) {
155 Pred = isEmptySet() ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; 208 Pred = isEmptySet() ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
156 RHS = APInt(getBitWidth(), 0); 209 RHS = APInt(getBitWidth(), 0);
157 Success = true;
158 } else if (auto *OnlyElt = getSingleElement()) { 210 } else if (auto *OnlyElt = getSingleElement()) {
159 Pred = CmpInst::ICMP_EQ; 211 Pred = CmpInst::ICMP_EQ;
160 RHS = *OnlyElt; 212 RHS = *OnlyElt;
161 Success = true;
162 } else if (auto *OnlyMissingElt = getSingleMissingElement()) { 213 } else if (auto *OnlyMissingElt = getSingleMissingElement()) {
163 Pred = CmpInst::ICMP_NE; 214 Pred = CmpInst::ICMP_NE;
164 RHS = *OnlyMissingElt; 215 RHS = *OnlyMissingElt;
165 Success = true;
166 } else if (getLower().isMinSignedValue() || getLower().isMinValue()) { 216 } else if (getLower().isMinSignedValue() || getLower().isMinValue()) {
167 Pred = 217 Pred =
168 getLower().isMinSignedValue() ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; 218 getLower().isMinSignedValue() ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
169 RHS = getUpper(); 219 RHS = getUpper();
170 Success = true;
171 } else if (getUpper().isMinSignedValue() || getUpper().isMinValue()) { 220 } else if (getUpper().isMinSignedValue() || getUpper().isMinValue()) {
172 Pred = 221 Pred =
173 getUpper().isMinSignedValue() ? CmpInst::ICMP_SGE : CmpInst::ICMP_UGE; 222 getUpper().isMinSignedValue() ? CmpInst::ICMP_SGE : CmpInst::ICMP_UGE;
174 RHS = getLower(); 223 RHS = getLower();
175 Success = true; 224 } else {
176 } 225 Pred = CmpInst::ICMP_ULT;
177 226 RHS = getUpper() - getLower();
178 assert((!Success || ConstantRange::makeExactICmpRegion(Pred, RHS) == *this) && 227 Offset = -getLower();
228 }
229
230 assert(ConstantRange::makeExactICmpRegion(Pred, RHS) == add(Offset) &&
179 "Bad result!"); 231 "Bad result!");
180 232 }
181 return Success; 233
234 bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
235 APInt &RHS) const {
236 APInt Offset;
237 getEquivalentICmp(Pred, RHS, Offset);
238 return Offset.isZero();
182 } 239 }
183 240
184 bool ConstantRange::icmp(CmpInst::Predicate Pred, 241 bool ConstantRange::icmp(CmpInst::Predicate Pred,
185 const ConstantRange &Other) const { 242 const ConstantRange &Other) const {
186 return makeSatisfyingICmpRegion(Pred, Other).contains(*this); 243 return makeSatisfyingICmpRegion(Pred, Other).contains(*this);
202 /// Exact mul nsw region for single element RHS. 259 /// Exact mul nsw region for single element RHS.
203 static ConstantRange makeExactMulNSWRegion(const APInt &V) { 260 static ConstantRange makeExactMulNSWRegion(const APInt &V) {
204 // Handle special case for 0, -1 and 1. See the last for reason why we 261 // Handle special case for 0, -1 and 1. See the last for reason why we
205 // specialize -1 and 1. 262 // specialize -1 and 1.
206 unsigned BitWidth = V.getBitWidth(); 263 unsigned BitWidth = V.getBitWidth();
207 if (V == 0 || V.isOneValue()) 264 if (V == 0 || V.isOne())
208 return ConstantRange::getFull(BitWidth); 265 return ConstantRange::getFull(BitWidth);
209 266
210 APInt MinValue = APInt::getSignedMinValue(BitWidth); 267 APInt MinValue = APInt::getSignedMinValue(BitWidth);
211 APInt MaxValue = APInt::getSignedMaxValue(BitWidth); 268 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
212 // e.g. Returning [-127, 127], represented as [-127, -128). 269 // e.g. Returning [-127, 127], represented as [-127, -128).
213 if (V.isAllOnesValue()) 270 if (V.isAllOnes())
214 return ConstantRange(-MaxValue, MinValue); 271 return ConstantRange(-MaxValue, MinValue);
215 272
216 APInt Lower, Upper; 273 APInt Lower, Upper;
217 if (V.isNegative()) { 274 if (V.isNegative()) {
218 Lower = APIntOps::RoundingSDiv(MaxValue, V, APInt::Rounding::UP); 275 Lower = APIntOps::RoundingSDiv(MaxValue, V, APInt::Rounding::UP);
246 default: 303 default:
247 llvm_unreachable("Unsupported binary op"); 304 llvm_unreachable("Unsupported binary op");
248 305
249 case Instruction::Add: { 306 case Instruction::Add: {
250 if (Unsigned) 307 if (Unsigned)
251 return getNonEmpty(APInt::getNullValue(BitWidth), 308 return getNonEmpty(APInt::getZero(BitWidth), -Other.getUnsignedMax());
252 -Other.getUnsignedMax());
253 309
254 APInt SignedMinVal = APInt::getSignedMinValue(BitWidth); 310 APInt SignedMinVal = APInt::getSignedMinValue(BitWidth);
255 APInt SMin = Other.getSignedMin(), SMax = Other.getSignedMax(); 311 APInt SMin = Other.getSignedMin(), SMax = Other.getSignedMax();
256 return getNonEmpty( 312 return getNonEmpty(
257 SMin.isNegative() ? SignedMinVal - SMin : SignedMinVal, 313 SMin.isNegative() ? SignedMinVal - SMin : SignedMinVal,
289 // There are some legal shift amounts, we can compute conservatively-correct 345 // There are some legal shift amounts, we can compute conservatively-correct
290 // range of no-wrap inputs. Note that by now we have clamped the ShAmtUMax 346 // range of no-wrap inputs. Note that by now we have clamped the ShAmtUMax
291 // to be at most bitwidth-1, which results in most conservative range. 347 // to be at most bitwidth-1, which results in most conservative range.
292 APInt ShAmtUMax = ShAmt.getUnsignedMax(); 348 APInt ShAmtUMax = ShAmt.getUnsignedMax();
293 if (Unsigned) 349 if (Unsigned)
294 return getNonEmpty(APInt::getNullValue(BitWidth), 350 return getNonEmpty(APInt::getZero(BitWidth),
295 APInt::getMaxValue(BitWidth).lshr(ShAmtUMax) + 1); 351 APInt::getMaxValue(BitWidth).lshr(ShAmtUMax) + 1);
296 return getNonEmpty(APInt::getSignedMinValue(BitWidth).ashr(ShAmtUMax), 352 return getNonEmpty(APInt::getSignedMinValue(BitWidth).ashr(ShAmtUMax),
297 APInt::getSignedMaxValue(BitWidth).ashr(ShAmtUMax) + 1); 353 APInt::getSignedMaxValue(BitWidth).ashr(ShAmtUMax) + 1);
298 } 354 }
299 } 355 }
314 bool ConstantRange::isEmptySet() const { 370 bool ConstantRange::isEmptySet() const {
315 return Lower == Upper && Lower.isMinValue(); 371 return Lower == Upper && Lower.isMinValue();
316 } 372 }
317 373
318 bool ConstantRange::isWrappedSet() const { 374 bool ConstantRange::isWrappedSet() const {
319 return Lower.ugt(Upper) && !Upper.isNullValue(); 375 return Lower.ugt(Upper) && !Upper.isZero();
320 } 376 }
321 377
322 bool ConstantRange::isUpperWrapped() const { 378 bool ConstantRange::isUpperWrapped() const {
323 return Lower.ugt(Upper); 379 return Lower.ugt(Upper);
324 } 380 }
341 return (Upper - Lower).ult(Other.Upper - Other.Lower); 397 return (Upper - Lower).ult(Other.Upper - Other.Lower);
342 } 398 }
343 399
344 bool 400 bool
345 ConstantRange::isSizeLargerThan(uint64_t MaxSize) const { 401 ConstantRange::isSizeLargerThan(uint64_t MaxSize) const {
346 assert(MaxSize && "MaxSize can't be 0.");
347 // If this a full set, we need special handling to avoid needing an extra bit 402 // If this a full set, we need special handling to avoid needing an extra bit
348 // to represent the size. 403 // to represent the size.
349 if (isFullSet()) 404 if (isFullSet())
350 return APInt::getMaxValue(getBitWidth()).ugt(MaxSize - 1); 405 return MaxSize == 0 || APInt::getMaxValue(getBitWidth()).ugt(MaxSize - 1);
351 406
352 return (Upper - Lower).ugt(MaxSize); 407 return (Upper - Lower).ugt(MaxSize);
353 } 408 }
354 409
355 bool ConstantRange::isAllNegative() const { 410 bool ConstantRange::isAllNegative() const {
593 ConstantRange(Lower, CR.Upper), ConstantRange(CR.Lower, Upper), Type); 648 ConstantRange(Lower, CR.Upper), ConstantRange(CR.Lower, Upper), Type);
594 649
595 APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower; 650 APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower;
596 APInt U = (CR.Upper - 1).ugt(Upper - 1) ? CR.Upper : Upper; 651 APInt U = (CR.Upper - 1).ugt(Upper - 1) ? CR.Upper : Upper;
597 652
598 if (L.isNullValue() && U.isNullValue()) 653 if (L.isZero() && U.isZero())
599 return getFull(); 654 return getFull();
600 655
601 return ConstantRange(std::move(L), std::move(U)); 656 return ConstantRange(std::move(L), std::move(U));
602 } 657 }
603 658
640 695
641 APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower; 696 APInt L = CR.Lower.ult(Lower) ? CR.Lower : Lower;
642 APInt U = CR.Upper.ugt(Upper) ? CR.Upper : Upper; 697 APInt U = CR.Upper.ugt(Upper) ? CR.Upper : Upper;
643 698
644 return ConstantRange(std::move(L), std::move(U)); 699 return ConstantRange(std::move(L), std::move(U));
700 }
701
702 Optional<ConstantRange>
703 ConstantRange::exactIntersectWith(const ConstantRange &CR) const {
704 // TODO: This can be implemented more efficiently.
705 ConstantRange Result = intersectWith(CR);
706 if (Result == inverse().unionWith(CR.inverse()).inverse())
707 return Result;
708 return None;
709 }
710
711 Optional<ConstantRange>
712 ConstantRange::exactUnionWith(const ConstantRange &CR) const {
713 // TODO: This can be implemented more efficiently.
714 ConstantRange Result = unionWith(CR);
715 if (Result == inverse().intersectWith(CR.inverse()).inverse())
716 return Result;
717 return None;
645 } 718 }
646 719
647 ConstantRange ConstantRange::castOp(Instruction::CastOps CastOp, 720 ConstantRange ConstantRange::castOp(Instruction::CastOps CastOp,
648 uint32_t ResultBitWidth) const { 721 uint32_t ResultBitWidth) const {
649 switch (CastOp) { 722 switch (CastOp) {
664 else 737 else
665 return getFull(ResultBitWidth); 738 return getFull(ResultBitWidth);
666 case Instruction::UIToFP: { 739 case Instruction::UIToFP: {
667 // TODO: use input range if available 740 // TODO: use input range if available
668 auto BW = getBitWidth(); 741 auto BW = getBitWidth();
669 APInt Min = APInt::getMinValue(BW).zextOrSelf(ResultBitWidth); 742 APInt Min = APInt::getMinValue(BW);
670 APInt Max = APInt::getMaxValue(BW).zextOrSelf(ResultBitWidth); 743 APInt Max = APInt::getMaxValue(BW);
744 if (ResultBitWidth > BW) {
745 Min = Min.zext(ResultBitWidth);
746 Max = Max.zext(ResultBitWidth);
747 }
671 return ConstantRange(std::move(Min), std::move(Max)); 748 return ConstantRange(std::move(Min), std::move(Max));
672 } 749 }
673 case Instruction::SIToFP: { 750 case Instruction::SIToFP: {
674 // TODO: use input range if available 751 // TODO: use input range if available
675 auto BW = getBitWidth(); 752 auto BW = getBitWidth();
676 APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(ResultBitWidth); 753 APInt SMin = APInt::getSignedMinValue(BW);
677 APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(ResultBitWidth); 754 APInt SMax = APInt::getSignedMaxValue(BW);
755 if (ResultBitWidth > BW) {
756 SMin = SMin.sext(ResultBitWidth);
757 SMax = SMax.sext(ResultBitWidth);
758 }
678 return ConstantRange(std::move(SMin), std::move(SMax)); 759 return ConstantRange(std::move(SMin), std::move(SMax));
679 } 760 }
680 case Instruction::FPTrunc: 761 case Instruction::FPTrunc:
681 case Instruction::FPExt: 762 case Instruction::FPExt:
682 case Instruction::IntToPtr: 763 case Instruction::IntToPtr:
1053 ConstantRange SR = Result_sext.truncate(getBitWidth()); 1134 ConstantRange SR = Result_sext.truncate(getBitWidth());
1054 1135
1055 return UR.isSizeStrictlySmallerThan(SR) ? UR : SR; 1136 return UR.isSizeStrictlySmallerThan(SR) ? UR : SR;
1056 } 1137 }
1057 1138
1139 ConstantRange ConstantRange::smul_fast(const ConstantRange &Other) const {
1140 if (isEmptySet() || Other.isEmptySet())
1141 return getEmpty();
1142
1143 APInt Min = getSignedMin();
1144 APInt Max = getSignedMax();
1145 APInt OtherMin = Other.getSignedMin();
1146 APInt OtherMax = Other.getSignedMax();
1147
1148 bool O1, O2, O3, O4;
1149 auto Muls = {Min.smul_ov(OtherMin, O1), Min.smul_ov(OtherMax, O2),
1150 Max.smul_ov(OtherMin, O3), Max.smul_ov(OtherMax, O4)};
1151 if (O1 || O2 || O3 || O4)
1152 return getFull();
1153
1154 auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); };
1155 return getNonEmpty(std::min(Muls, Compare), std::max(Muls, Compare) + 1);
1156 }
1157
1058 ConstantRange 1158 ConstantRange
1059 ConstantRange::smax(const ConstantRange &Other) const { 1159 ConstantRange::smax(const ConstantRange &Other) const {
1060 // X smax Y is: range(smax(X_smin, Y_smin), 1160 // X smax Y is: range(smax(X_smin, Y_smin),
1061 // smax(X_smax, Y_smax)) 1161 // smax(X_smax, Y_smax))
1062 if (isEmptySet() || Other.isEmptySet()) 1162 if (isEmptySet() || Other.isEmptySet())
1111 return Res; 1211 return Res;
1112 } 1212 }
1113 1213
1114 ConstantRange 1214 ConstantRange
1115 ConstantRange::udiv(const ConstantRange &RHS) const { 1215 ConstantRange::udiv(const ConstantRange &RHS) const {
1116 if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isNullValue()) 1216 if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero())
1117 return getEmpty(); 1217 return getEmpty();
1118 1218
1119 APInt Lower = getUnsignedMin().udiv(RHS.getUnsignedMax()); 1219 APInt Lower = getUnsignedMin().udiv(RHS.getUnsignedMax());
1120 1220
1121 APInt RHS_umin = RHS.getUnsignedMin(); 1221 APInt RHS_umin = RHS.getUnsignedMin();
1122 if (RHS_umin.isNullValue()) { 1222 if (RHS_umin.isZero()) {
1123 // We want the lowest value in RHS excluding zero. Usually that would be 1 1223 // We want the lowest value in RHS excluding zero. Usually that would be 1
1124 // except for a range in the form of [X, 1) in which case it would be X. 1224 // except for a range in the form of [X, 1) in which case it would be X.
1125 if (RHS.getUpper() == 1) 1225 if (RHS.getUpper() == 1)
1126 RHS_umin = RHS.getLower(); 1226 RHS_umin = RHS.getLower();
1127 else 1227 else
1134 1234
1135 ConstantRange ConstantRange::sdiv(const ConstantRange &RHS) const { 1235 ConstantRange ConstantRange::sdiv(const ConstantRange &RHS) const {
1136 // We split up the LHS and RHS into positive and negative components 1236 // We split up the LHS and RHS into positive and negative components
1137 // and then also compute the positive and negative components of the result 1237 // and then also compute the positive and negative components of the result
1138 // separately by combining division results with the appropriate signs. 1238 // separately by combining division results with the appropriate signs.
1139 APInt Zero = APInt::getNullValue(getBitWidth()); 1239 APInt Zero = APInt::getZero(getBitWidth());
1140 APInt SignedMin = APInt::getSignedMinValue(getBitWidth()); 1240 APInt SignedMin = APInt::getSignedMinValue(getBitWidth());
1141 ConstantRange PosFilter(APInt(getBitWidth(), 1), SignedMin); 1241 // There are no positive 1-bit values. The 1 would get interpreted as -1.
1242 ConstantRange PosFilter =
1243 getBitWidth() == 1 ? getEmpty()
1244 : ConstantRange(APInt(getBitWidth(), 1), SignedMin);
1142 ConstantRange NegFilter(SignedMin, Zero); 1245 ConstantRange NegFilter(SignedMin, Zero);
1143 ConstantRange PosL = intersectWith(PosFilter); 1246 ConstantRange PosL = intersectWith(PosFilter);
1144 ConstantRange NegL = intersectWith(NegFilter); 1247 ConstantRange NegL = intersectWith(NegFilter);
1145 ConstantRange PosR = RHS.intersectWith(PosFilter); 1248 ConstantRange PosR = RHS.intersectWith(PosFilter);
1146 ConstantRange NegR = RHS.intersectWith(NegFilter); 1249 ConstantRange NegR = RHS.intersectWith(NegFilter);
1157 // We need to deal with one tricky case here: SignedMin / -1 is UB on the 1260 // We need to deal with one tricky case here: SignedMin / -1 is UB on the
1158 // IR level, so we'll want to exclude this case when calculating bounds. 1261 // IR level, so we'll want to exclude this case when calculating bounds.
1159 // (For APInts the operation is well-defined and yields SignedMin.) We 1262 // (For APInts the operation is well-defined and yields SignedMin.) We
1160 // handle this by dropping either SignedMin from the LHS or -1 from the RHS. 1263 // handle this by dropping either SignedMin from the LHS or -1 from the RHS.
1161 APInt Lo = (NegL.Upper - 1).sdiv(NegR.Lower); 1264 APInt Lo = (NegL.Upper - 1).sdiv(NegR.Lower);
1162 if (NegL.Lower.isMinSignedValue() && NegR.Upper.isNullValue()) { 1265 if (NegL.Lower.isMinSignedValue() && NegR.Upper.isZero()) {
1163 // Remove -1 from the LHS. Skip if it's the only element, as this would 1266 // Remove -1 from the LHS. Skip if it's the only element, as this would
1164 // leave us with an empty set. 1267 // leave us with an empty set.
1165 if (!NegR.Lower.isAllOnesValue()) { 1268 if (!NegR.Lower.isAllOnes()) {
1166 APInt AdjNegRUpper; 1269 APInt AdjNegRUpper;
1167 if (RHS.Lower.isAllOnesValue()) 1270 if (RHS.Lower.isAllOnes())
1168 // Negative part of [-1, X] without -1 is [SignedMin, X]. 1271 // Negative part of [-1, X] without -1 is [SignedMin, X].
1169 AdjNegRUpper = RHS.Upper; 1272 AdjNegRUpper = RHS.Upper;
1170 else 1273 else
1171 // [X, -1] without -1 is [X, -2]. 1274 // [X, -1] without -1 is [X, -2].
1172 AdjNegRUpper = NegR.Upper - 1; 1275 AdjNegRUpper = NegR.Upper - 1;
1216 Res = Res.unionWith(ConstantRange(Zero)); 1319 Res = Res.unionWith(ConstantRange(Zero));
1217 return Res; 1320 return Res;
1218 } 1321 }
1219 1322
1220 ConstantRange ConstantRange::urem(const ConstantRange &RHS) const { 1323 ConstantRange ConstantRange::urem(const ConstantRange &RHS) const {
1221 if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isNullValue()) 1324 if (isEmptySet() || RHS.isEmptySet() || RHS.getUnsignedMax().isZero())
1222 return getEmpty(); 1325 return getEmpty();
1223 1326
1224 if (const APInt *RHSInt = RHS.getSingleElement()) { 1327 if (const APInt *RHSInt = RHS.getSingleElement()) {
1225 // UREM by null is UB. 1328 // UREM by null is UB.
1226 if (RHSInt->isNullValue()) 1329 if (RHSInt->isZero())
1227 return getEmpty(); 1330 return getEmpty();
1228 // Use APInt's implementation of UREM for single element ranges. 1331 // Use APInt's implementation of UREM for single element ranges.
1229 if (const APInt *LHSInt = getSingleElement()) 1332 if (const APInt *LHSInt = getSingleElement())
1230 return {LHSInt->urem(*RHSInt)}; 1333 return {LHSInt->urem(*RHSInt)};
1231 } 1334 }
1234 if (getUnsignedMax().ult(RHS.getUnsignedMin())) 1337 if (getUnsignedMax().ult(RHS.getUnsignedMin()))
1235 return *this; 1338 return *this;
1236 1339
1237 // L % R is <= L and < R. 1340 // L % R is <= L and < R.
1238 APInt Upper = APIntOps::umin(getUnsignedMax(), RHS.getUnsignedMax() - 1) + 1; 1341 APInt Upper = APIntOps::umin(getUnsignedMax(), RHS.getUnsignedMax() - 1) + 1;
1239 return getNonEmpty(APInt::getNullValue(getBitWidth()), std::move(Upper)); 1342 return getNonEmpty(APInt::getZero(getBitWidth()), std::move(Upper));
1240 } 1343 }
1241 1344
1242 ConstantRange ConstantRange::srem(const ConstantRange &RHS) const { 1345 ConstantRange ConstantRange::srem(const ConstantRange &RHS) const {
1243 if (isEmptySet() || RHS.isEmptySet()) 1346 if (isEmptySet() || RHS.isEmptySet())
1244 return getEmpty(); 1347 return getEmpty();
1245 1348
1246 if (const APInt *RHSInt = RHS.getSingleElement()) { 1349 if (const APInt *RHSInt = RHS.getSingleElement()) {
1247 // SREM by null is UB. 1350 // SREM by null is UB.
1248 if (RHSInt->isNullValue()) 1351 if (RHSInt->isZero())
1249 return getEmpty(); 1352 return getEmpty();
1250 // Use APInt's implementation of SREM for single element ranges. 1353 // Use APInt's implementation of SREM for single element ranges.
1251 if (const APInt *LHSInt = getSingleElement()) 1354 if (const APInt *LHSInt = getSingleElement())
1252 return {LHSInt->srem(*RHSInt)}; 1355 return {LHSInt->srem(*RHSInt)};
1253 } 1356 }
1255 ConstantRange AbsRHS = RHS.abs(); 1358 ConstantRange AbsRHS = RHS.abs();
1256 APInt MinAbsRHS = AbsRHS.getUnsignedMin(); 1359 APInt MinAbsRHS = AbsRHS.getUnsignedMin();
1257 APInt MaxAbsRHS = AbsRHS.getUnsignedMax(); 1360 APInt MaxAbsRHS = AbsRHS.getUnsignedMax();
1258 1361
1259 // Modulus by zero is UB. 1362 // Modulus by zero is UB.
1260 if (MaxAbsRHS.isNullValue()) 1363 if (MaxAbsRHS.isZero())
1261 return getEmpty(); 1364 return getEmpty();
1262 1365
1263 if (MinAbsRHS.isNullValue()) 1366 if (MinAbsRHS.isZero())
1264 ++MinAbsRHS; 1367 ++MinAbsRHS;
1265 1368
1266 APInt MinLHS = getSignedMin(), MaxLHS = getSignedMax(); 1369 APInt MinLHS = getSignedMin(), MaxLHS = getSignedMax();
1267 1370
1268 if (MinLHS.isNonNegative()) { 1371 if (MinLHS.isNonNegative()) {
1270 if (MaxLHS.ult(MinAbsRHS)) 1373 if (MaxLHS.ult(MinAbsRHS))
1271 return *this; 1374 return *this;
1272 1375
1273 // L % R is <= L and < R. 1376 // L % R is <= L and < R.
1274 APInt Upper = APIntOps::umin(MaxLHS, MaxAbsRHS - 1) + 1; 1377 APInt Upper = APIntOps::umin(MaxLHS, MaxAbsRHS - 1) + 1;
1275 return ConstantRange(APInt::getNullValue(getBitWidth()), std::move(Upper)); 1378 return ConstantRange(APInt::getZero(getBitWidth()), std::move(Upper));
1276 } 1379 }
1277 1380
1278 // Same basic logic as above, but the result is negative. 1381 // Same basic logic as above, but the result is negative.
1279 if (MaxLHS.isNegative()) { 1382 if (MaxLHS.isNegative()) {
1280 if (MinLHS.ugt(-MinAbsRHS)) 1383 if (MinLHS.ugt(-MinAbsRHS))
1289 APInt Upper = APIntOps::umin(MaxLHS, MaxAbsRHS - 1) + 1; 1392 APInt Upper = APIntOps::umin(MaxLHS, MaxAbsRHS - 1) + 1;
1290 return ConstantRange(std::move(Lower), std::move(Upper)); 1393 return ConstantRange(std::move(Lower), std::move(Upper));
1291 } 1394 }
1292 1395
1293 ConstantRange ConstantRange::binaryNot() const { 1396 ConstantRange ConstantRange::binaryNot() const {
1294 return ConstantRange(APInt::getAllOnesValue(getBitWidth())).sub(*this); 1397 return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
1295 } 1398 }
1296 1399
1297 ConstantRange 1400 ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const {
1298 ConstantRange::binaryAnd(const ConstantRange &Other) const { 1401 if (isEmptySet() || Other.isEmptySet())
1299 if (isEmptySet() || Other.isEmptySet()) 1402 return getEmpty();
1300 return getEmpty(); 1403
1301 1404 ConstantRange KnownBitsRange =
1302 // Use APInt's implementation of AND for single element ranges. 1405 fromKnownBits(toKnownBits() & Other.toKnownBits(), false);
1303 if (isSingleElement() && Other.isSingleElement()) 1406 ConstantRange UMinUMaxRange =
1304 return {*getSingleElement() & *Other.getSingleElement()}; 1407 getNonEmpty(APInt::getZero(getBitWidth()),
1305 1408 APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
1306 // TODO: replace this with something less conservative 1409 return KnownBitsRange.intersectWith(UMinUMaxRange);
1307 1410 }
1308 APInt umin = APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()); 1411
1309 return getNonEmpty(APInt::getNullValue(getBitWidth()), std::move(umin) + 1); 1412 ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const {
1310 } 1413 if (isEmptySet() || Other.isEmptySet())
1311 1414 return getEmpty();
1312 ConstantRange 1415
1313 ConstantRange::binaryOr(const ConstantRange &Other) const { 1416 ConstantRange KnownBitsRange =
1314 if (isEmptySet() || Other.isEmptySet()) 1417 fromKnownBits(toKnownBits() | Other.toKnownBits(), false);
1315 return getEmpty(); 1418 // Upper wrapped range.
1316 1419 ConstantRange UMaxUMinRange =
1317 // Use APInt's implementation of OR for single element ranges. 1420 getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()),
1318 if (isSingleElement() && Other.isSingleElement()) 1421 APInt::getZero(getBitWidth()));
1319 return {*getSingleElement() | *Other.getSingleElement()}; 1422 return KnownBitsRange.intersectWith(UMaxUMinRange);
1320
1321 // TODO: replace this with something less conservative
1322
1323 APInt umax = APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin());
1324 return getNonEmpty(std::move(umax), APInt::getNullValue(getBitWidth()));
1325 } 1423 }
1326 1424
1327 ConstantRange ConstantRange::binaryXor(const ConstantRange &Other) const { 1425 ConstantRange ConstantRange::binaryXor(const ConstantRange &Other) const {
1328 if (isEmptySet() || Other.isEmptySet()) 1426 if (isEmptySet() || Other.isEmptySet())
1329 return getEmpty(); 1427 return getEmpty();
1331 // Use APInt's implementation of XOR for single element ranges. 1429 // Use APInt's implementation of XOR for single element ranges.
1332 if (isSingleElement() && Other.isSingleElement()) 1430 if (isSingleElement() && Other.isSingleElement())
1333 return {*getSingleElement() ^ *Other.getSingleElement()}; 1431 return {*getSingleElement() ^ *Other.getSingleElement()};
1334 1432
1335 // Special-case binary complement, since we can give a precise answer. 1433 // Special-case binary complement, since we can give a precise answer.
1336 if (Other.isSingleElement() && Other.getSingleElement()->isAllOnesValue()) 1434 if (Other.isSingleElement() && Other.getSingleElement()->isAllOnes())
1337 return binaryNot(); 1435 return binaryNot();
1338 if (isSingleElement() && getSingleElement()->isAllOnesValue()) 1436 if (isSingleElement() && getSingleElement()->isAllOnes())
1339 return Other.binaryNot(); 1437 return Other.binaryNot();
1340 1438
1341 // TODO: replace this with something less conservative 1439 return fromKnownBits(toKnownBits() ^ Other.toKnownBits(), /*IsSigned*/false);
1342 return getFull();
1343 } 1440 }
1344 1441
1345 ConstantRange 1442 ConstantRange
1346 ConstantRange::shl(const ConstantRange &Other) const { 1443 ConstantRange::shl(const ConstantRange &Other) const {
1347 if (isEmptySet() || Other.isEmptySet()) 1444 if (isEmptySet() || Other.isEmptySet())
1348 return getEmpty(); 1445 return getEmpty();
1349 1446
1350 APInt max = getUnsignedMax(); 1447 APInt Min = getUnsignedMin();
1351 APInt Other_umax = Other.getUnsignedMax(); 1448 APInt Max = getUnsignedMax();
1352 1449 if (const APInt *RHS = Other.getSingleElement()) {
1353 // If we are shifting by maximum amount of 1450 unsigned BW = getBitWidth();
1354 // zero return return the original range. 1451 if (RHS->uge(BW))
1355 if (Other_umax.isNullValue()) 1452 return getEmpty();
1356 return *this; 1453
1357 // there's overflow! 1454 unsigned EqualLeadingBits = (Min ^ Max).countLeadingZeros();
1358 if (Other_umax.ugt(max.countLeadingZeros())) 1455 if (RHS->ule(EqualLeadingBits))
1456 return getNonEmpty(Min << *RHS, (Max << *RHS) + 1);
1457
1458 return getNonEmpty(APInt::getZero(BW),
1459 APInt::getBitsSetFrom(BW, RHS->getZExtValue()) + 1);
1460 }
1461
1462 APInt OtherMax = Other.getUnsignedMax();
1463
1464 // There's overflow!
1465 if (OtherMax.ugt(Max.countLeadingZeros()))
1359 return getFull(); 1466 return getFull();
1360 1467
1361 // FIXME: implement the other tricky cases 1468 // FIXME: implement the other tricky cases
1362 1469
1363 APInt min = getUnsignedMin(); 1470 Min <<= Other.getUnsignedMin();
1364 min <<= Other.getUnsignedMin(); 1471 Max <<= OtherMax;
1365 max <<= Other_umax; 1472
1366 1473 return ConstantRange::getNonEmpty(std::move(Min), std::move(Max) + 1);
1367 return ConstantRange(std::move(min), std::move(max) + 1);
1368 } 1474 }
1369 1475
1370 ConstantRange 1476 ConstantRange
1371 ConstantRange::lshr(const ConstantRange &Other) const { 1477 ConstantRange::lshr(const ConstantRange &Other) const {
1372 if (isEmptySet() || Other.isEmptySet()) 1478 if (isEmptySet() || Other.isEmptySet())
1481 // the smallest of the cartesian product of the lower and upper ranges; 1587 // the smallest of the cartesian product of the lower and upper ranges;
1482 // for example: 1588 // for example:
1483 // [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6. 1589 // [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6.
1484 // Similarly for the upper bound, swapping min for max. 1590 // Similarly for the upper bound, swapping min for max.
1485 1591
1486 APInt this_min = getSignedMin().sext(getBitWidth() * 2); 1592 APInt Min = getSignedMin();
1487 APInt this_max = getSignedMax().sext(getBitWidth() * 2); 1593 APInt Max = getSignedMax();
1488 APInt Other_min = Other.getSignedMin().sext(getBitWidth() * 2); 1594 APInt OtherMin = Other.getSignedMin();
1489 APInt Other_max = Other.getSignedMax().sext(getBitWidth() * 2); 1595 APInt OtherMax = Other.getSignedMax();
1490 1596
1491 auto L = {this_min * Other_min, this_min * Other_max, this_max * Other_min, 1597 auto L = {Min.smul_sat(OtherMin), Min.smul_sat(OtherMax),
1492 this_max * Other_max}; 1598 Max.smul_sat(OtherMin), Max.smul_sat(OtherMax)};
1493 auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); }; 1599 auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); };
1494 1600 return getNonEmpty(std::min(L, Compare), std::max(L, Compare) + 1);
1495 // Note that we wanted to perform signed saturating multiplication,
1496 // so since we performed plain multiplication in twice the bitwidth,
1497 // we need to perform signed saturating truncation.
1498 return getNonEmpty(std::min(L, Compare).truncSSat(getBitWidth()),
1499 std::max(L, Compare).truncSSat(getBitWidth()) + 1);
1500 } 1601 }
1501 1602
1502 ConstantRange ConstantRange::ushl_sat(const ConstantRange &Other) const { 1603 ConstantRange ConstantRange::ushl_sat(const ConstantRange &Other) const {
1503 if (isEmptySet() || Other.isEmptySet()) 1604 if (isEmptySet() || Other.isEmptySet())
1504 return getEmpty(); 1605 return getEmpty();
1533 1634
1534 if (isSignWrappedSet()) { 1635 if (isSignWrappedSet()) {
1535 APInt Lo; 1636 APInt Lo;
1536 // Check whether the range crosses zero. 1637 // Check whether the range crosses zero.
1537 if (Upper.isStrictlyPositive() || !Lower.isStrictlyPositive()) 1638 if (Upper.isStrictlyPositive() || !Lower.isStrictlyPositive())
1538 Lo = APInt::getNullValue(getBitWidth()); 1639 Lo = APInt::getZero(getBitWidth());
1539 else 1640 else
1540 Lo = APIntOps::umin(Lower, -Upper + 1); 1641 Lo = APIntOps::umin(Lower, -Upper + 1);
1541 1642
1542 // If SignedMin is not poison, then it is included in the result range. 1643 // If SignedMin is not poison, then it is included in the result range.
1543 if (IntMinIsPoison) 1644 if (IntMinIsPoison)
1563 // All negative. 1664 // All negative.
1564 if (SMax.isNegative()) 1665 if (SMax.isNegative())
1565 return ConstantRange(-SMax, -SMin + 1); 1666 return ConstantRange(-SMax, -SMin + 1);
1566 1667
1567 // Range crosses zero. 1668 // Range crosses zero.
1568 return ConstantRange(APInt::getNullValue(getBitWidth()), 1669 return ConstantRange(APInt::getZero(getBitWidth()),
1569 APIntOps::umax(-SMin, SMax) + 1); 1670 APIntOps::umax(-SMin, SMax) + 1);
1570 } 1671 }
1571 1672
1572 ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow( 1673 ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
1573 const ConstantRange &Other) const { 1674 const ConstantRange &Other) const {