comparison lib/Transforms/InstCombine/InstCombineShifts.cpp @ 121:803732b1fca8

LLVM 5.0
author kono
date Fri, 27 Oct 2017 17:07:41 +0900
parents 1172e4bd9c6f
children 3a76565eade5
comparing parent 120:1172e4bd9c6f (left column) with 121:803732b1fca8 (right column)
20 using namespace PatternMatch; 20 using namespace PatternMatch;
21 21
22 #define DEBUG_TYPE "instcombine" 22 #define DEBUG_TYPE "instcombine"
23 23
24 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { 24 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
25 assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
26 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 25 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
26 assert(Op0->getType() == Op1->getType());
27 27
28 // See if we can fold away this shift. 28 // See if we can fold away this shift.
29 if (SimplifyDemandedInstructionBits(I)) 29 if (SimplifyDemandedInstructionBits(I))
30 return &I; 30 return &I;
31 31
42 // (C1 shift (A add C2)) -> ((C1 shift C2) shift A) 42 // (C1 shift (A add C2)) -> ((C1 shift C2) shift A)
43 // iff A and C2 are both positive. 43 // iff A and C2 are both positive.
44 Value *A; 44 Value *A;
45 Constant *C; 45 Constant *C;
46 if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) 46 if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C))))
47 if (isKnownNonNegative(A, DL) && isKnownNonNegative(C, DL)) 47 if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) &&
48 isKnownNonNegative(C, DL, 0, &AC, &I, &DT))
48 return BinaryOperator::Create( 49 return BinaryOperator::Create(
49 I.getOpcode(), Builder->CreateBinOp(I.getOpcode(), Op0, C), A); 50 I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);
50 51
51 // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. 52 // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
52 // Because shifts by negative values (which could occur if A were negative) 53 // Because shifts by negative values (which could occur if A were negative)
53 // are undefined. 54 // are undefined.
54 const APInt *B; 55 const APInt *B;
55 if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { 56 if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
56 // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't 57 // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
57 // demand the sign bit (and many others) here?? 58 // demand the sign bit (and many others) here??
58 Value *Rem = Builder->CreateAnd(A, ConstantInt::get(I.getType(), *B-1), 59 Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
59 Op1->getName()); 60 Op1->getName());
60 I.setOperand(1, Rem); 61 I.setOperand(1, Rem);
61 return &I; 62 return &I;
62 } 63 }
63 64
64 return nullptr; 65 return nullptr;
65 } 66 }
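
A minimal IR sketch of the two folds performed above (function and value names are illustrative, not from the source; the rewritten forms shown in comments are what opt -instcombine is expected to produce for code of this vintage, and may differ slightly between versions):

  ; (C1 shift (A add C2)) --> ((C1 shift C2) shift A), when A and C2 are non-negative
  define i32 @shift_amount_is_add(i32 %a) {
    %nonneg = and i32 %a, 15         ; makes the added value provably non-negative
    %amt = add i32 %nonneg, 3
    %r = shl i32 4, %amt             ; expected: shl i32 32, %nonneg
    ret i32 %r
  }

  ; X shift (A srem B) --> X shift (A and B-1), when B is a power of 2
  define i32 @shift_amount_is_srem(i32 %x, i32 %a) {
    %rem = srem i32 %a, 8
    %r = lshr i32 %x, %rem           ; expected shift amount: and i32 %a, 7
    ret i32 %r
  }
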
66 67
67 /// Return true if we can simplify two logical (either left or right) shifts 68 /// Return true if we can simplify two logical (either left or right) shifts
68 /// that have constant shift amounts. 69 /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
69 static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, 70 static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
70 bool IsFirstShiftLeft, 71 Instruction *InnerShift, InstCombiner &IC,
71 Instruction *SecondShift, InstCombiner &IC,
72 Instruction *CxtI) { 72 Instruction *CxtI) {
73 assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); 73 assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
74 74
75 // We need constant shifts. 75 // We need constant scalar or constant splat shifts.
76 auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); 76 const APInt *InnerShiftConst;
77 if (!SecondShiftConst) 77 if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
78 return false; 78 return false;
79 79
80 unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); 80 // Two logical shifts in the same direction:
81 bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; 81 // shl (shl X, C1), C2 --> shl X, C1 + C2
82 82 // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
83 // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). 83 bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
84 // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). 84 if (IsInnerShl == IsOuterShl)
85 if (IsFirstShiftLeft == IsSecondShiftLeft)
86 return true; 85 return true;
87 86
88 // We can always fold lshr(c) + shl(c) -> and(c2). 87 // Equal shift amounts in opposite directions become bitwise 'and':
89 // We can always fold shl(c) + lshr(c) -> and(c2). 88 // lshr (shl X, C), C --> and X, C'
90 if (FirstShiftAmt == SecondShiftAmt) 89 // shl (lshr X, C), C --> and X, C'
90 unsigned InnerShAmt = InnerShiftConst->getZExtValue();
91 if (InnerShAmt == OuterShAmt)
91 return true; 92 return true;
92 93
93 unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
94
95 // If the 2nd shift is bigger than the 1st, we can fold: 94 // If the 2nd shift is bigger than the 1st, we can fold:
96 // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or 95 // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
97 // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), 96 // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
98 // but it isn't profitable unless we know the and'd out bits are already zero. 97 // but it isn't profitable unless we know the and'd out bits are already zero.
99 // Also check that the 2nd shift is valid (less than the type width) or we'll 98 // Also, check that the inner shift is valid (less than the type width) or
100 // crash trying to produce the bit mask for the 'and'. 99 // we'll crash trying to produce the bit mask for the 'and'.
101 if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { 100 unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
102 unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt 101 if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
103 : SecondShiftAmt - FirstShiftAmt; 102 unsigned MaskShift =
104 APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; 103 IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
105 if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) 104 APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
105 if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
106 return true; 106 return true;
107 } 107 }
108 108
109 return false; 109 return false;
110 } 110 }
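
A small worked example of the last case above (illustrative names; the zext is only there to make the masked-off bits provably zero): with a 32-bit type, an inner shl of 8 and an outer lshr of 3, the mask is the low 3 bits shifted up by 32 - 8 = 24, so bits 24..26 of the inner shift's operand must be known zero.

  define i32 @shl_then_lshr(i16 %x) {
    %z = zext i16 %x to i32          ; bits 16..31 of %z are zero, so bits 24..26 are zero
    %s = shl i32 %z, 8
    %r = lshr i32 %s, 3
    ret i32 %r                       ; can be computed as: shl i32 %z, 5 (no 'and' needed)
  }
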
111 111
112 /// See if we can compute the specified value, but shifted 112 /// See if we can compute the specified value, but shifted logically to the left
113 /// logically to the left or right by some number of bits. This should return 113 /// or right by some number of bits. This should return true if the expression
114 /// true if the expression can be computed for the same cost as the current 114 /// can be computed for the same cost as the current expression tree. This is
115 /// expression tree. This is used to eliminate extraneous shifting from things 115 /// used to eliminate extraneous shifting from things like:
116 /// like:
117 /// %C = shl i128 %A, 64 116 /// %C = shl i128 %A, 64
118 /// %D = shl i128 %B, 96 117 /// %D = shl i128 %B, 96
119 /// %E = or i128 %C, %D 118 /// %E = or i128 %C, %D
120 /// %F = lshr i128 %E, 64 119 /// %F = lshr i128 %E, 64
121 /// where the client will ask if E can be computed shifted right by 64-bits. If 120 /// where the client will ask if E can be computed shifted right by 64-bits. If
122 /// this succeeds, the GetShiftedValue function will be called to produce the 121 /// this succeeds, getShiftedValue() will be called to produce the value.
123 /// value. 122 static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
124 static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
125 InstCombiner &IC, Instruction *CxtI) { 123 InstCombiner &IC, Instruction *CxtI) {
126 // We can always evaluate constants shifted. 124 // We can always evaluate constants shifted.
127 if (isa<Constant>(V)) 125 if (isa<Constant>(V))
128 return true; 126 return true;
129 127
163 default: return false; 161 default: return false;
164 case Instruction::And: 162 case Instruction::And:
165 case Instruction::Or: 163 case Instruction::Or:
166 case Instruction::Xor: 164 case Instruction::Xor:
167 // Bitwise operators can all be arbitrarily evaluated shifted. 165 // Bitwise operators can all be arbitrarily evaluated shifted.
168 return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) && 166 return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
169 CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I); 167 canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
170 168
171 case Instruction::Shl: 169 case Instruction::Shl:
172 case Instruction::LShr: 170 case Instruction::LShr:
173 return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI); 171 return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
174 172
175 case Instruction::Select: { 173 case Instruction::Select: {
176 SelectInst *SI = cast<SelectInst>(I); 174 SelectInst *SI = cast<SelectInst>(I);
177 Value *TrueVal = SI->getTrueValue(); 175 Value *TrueVal = SI->getTrueValue();
178 Value *FalseVal = SI->getFalseValue(); 176 Value *FalseVal = SI->getFalseValue();
179 return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) && 177 return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
180 CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI); 178 canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
181 } 179 }
182 case Instruction::PHI: { 180 case Instruction::PHI: {
183 // We can change a phi if we can change all operands. Note that we never 181 // We can change a phi if we can change all operands. Note that we never
184 // get into trouble with cyclic PHIs here because we only consider 182 // get into trouble with cyclic PHIs here because we only consider
185 // instructions with a single use. 183 // instructions with a single use.
186 PHINode *PN = cast<PHINode>(I); 184 PHINode *PN = cast<PHINode>(I);
187 for (Value *IncValue : PN->incoming_values()) 185 for (Value *IncValue : PN->incoming_values())
188 if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN)) 186 if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
189 return false; 187 return false;
190 return true; 188 return true;
191 } 189 }
192 } 190 }
193 } 191 }
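
A sketch of the select case (illustrative names): the shift can be pushed into both arms of a single-use select, because each arm is itself evaluable shifted.

  define i32 @shift_through_select(i1 %c, i32 %x) {
    %s1 = shl i32 %x, 8
    %sel = select i1 %c, i32 %s1, i32 256
    %r = lshr i32 %sel, 8
    ; expected: the lshr disappears and %sel becomes a select between
    ; (and i32 %x, 16777215) and i32 1
    ret i32 %r
  }
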
194 192
195 /// When CanEvaluateShifted returned true for an expression, 193 /// Fold OuterShift (InnerShift X, C1), C2.
196 /// this value inserts the new computation that produces the shifted value. 194 /// See canEvaluateShiftedShift() for the constraints on these instructions.
197 static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, 195 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
196 bool IsOuterShl,
197 InstCombiner::BuilderTy &Builder) {
198 bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
199 Type *ShType = InnerShift->getType();
200 unsigned TypeWidth = ShType->getScalarSizeInBits();
201
202 // We only accept shifts-by-a-constant in canEvaluateShifted().
203 const APInt *C1;
204 match(InnerShift->getOperand(1), m_APInt(C1));
205 unsigned InnerShAmt = C1->getZExtValue();
206
207 // Change the shift amount and clear the appropriate IR flags.
208 auto NewInnerShift = [&](unsigned ShAmt) {
209 InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
210 if (IsInnerShl) {
211 InnerShift->setHasNoUnsignedWrap(false);
212 InnerShift->setHasNoSignedWrap(false);
213 } else {
214 InnerShift->setIsExact(false);
215 }
216 return InnerShift;
217 };
218
219 // Two logical shifts in the same direction:
220 // shl (shl X, C1), C2 --> shl X, C1 + C2
221 // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
222 if (IsInnerShl == IsOuterShl) {
223 // If this is an oversized composite shift, then unsigned shifts get 0.
224 if (InnerShAmt + OuterShAmt >= TypeWidth)
225 return Constant::getNullValue(ShType);
226
227 return NewInnerShift(InnerShAmt + OuterShAmt);
228 }
229
230 // Equal shift amounts in opposite directions become bitwise 'and':
231 // lshr (shl X, C), C --> and X, C'
232 // shl (lshr X, C), C --> and X, C'
233 if (InnerShAmt == OuterShAmt) {
234 APInt Mask = IsInnerShl
235 ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
236 : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
237 Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
238 ConstantInt::get(ShType, Mask));
239 if (auto *AndI = dyn_cast<Instruction>(And)) {
240 AndI->moveBefore(InnerShift);
241 AndI->takeName(InnerShift);
242 }
243 return And;
244 }
245
246 assert(InnerShAmt > OuterShAmt &&
247 "Unexpected opposite direction logical shift pair");
248
249 // In general, we would need an 'and' for this transform, but
250 // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
251 // lshr (shl X, C1), C2 --> shl X, C1 - C2
252 // shl (lshr X, C1), C2 --> lshr X, C1 - C2
253 return NewInnerShift(InnerShAmt - OuterShAmt);
254 }
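
Illustrative IR for the equal-amounts case handled above (names are mine, not from the source; expected results shown in the leading comments):

  ; lshr (shl X, 5), 5 --> and X, 134217727   (keeps the low 27 bits)
  define i32 @mask_low_bits(i32 %x) {
    %s = shl i32 %x, 5
    %r = lshr i32 %s, 5
    ret i32 %r
  }

  ; shl (lshr X, 5), 5 --> and X, -32          (clears the low 5 bits)
  define i32 @mask_high_bits(i32 %x) {
    %s = lshr i32 %x, 5
    %r = shl i32 %s, 5
    ret i32 %r
  }
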
255
256 /// When canEvaluateShifted() returns true for an expression, this function
257 /// inserts the new computation that produces the shifted value.
258 static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
198 InstCombiner &IC, const DataLayout &DL) { 259 InstCombiner &IC, const DataLayout &DL) {
199 // We can always evaluate constants shifted. 260 // We can always evaluate constants shifted.
200 if (Constant *C = dyn_cast<Constant>(V)) { 261 if (Constant *C = dyn_cast<Constant>(V)) {
201 if (isLeftShift) 262 if (isLeftShift)
202 V = IC.Builder->CreateShl(C, NumBits); 263 V = IC.Builder.CreateShl(C, NumBits);
203 else 264 else
204 V = IC.Builder->CreateLShr(C, NumBits); 265 V = IC.Builder.CreateLShr(C, NumBits);
205 // If we got a constantexpr back, try to simplify it with TD info. 266 // If we got a constantexpr back, try to simplify it with TD info.
206 if (auto *C = dyn_cast<Constant>(V)) 267 if (auto *C = dyn_cast<Constant>(V))
207 if (auto *FoldedC = 268 if (auto *FoldedC =
208 ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) 269 ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo()))
209 V = FoldedC; 270 V = FoldedC;
218 case Instruction::And: 279 case Instruction::And:
219 case Instruction::Or: 280 case Instruction::Or:
220 case Instruction::Xor: 281 case Instruction::Xor:
221 // Bitwise operators can all be arbitrarily evaluated shifted. 282 // Bitwise operators can all be arbitrarily evaluated shifted.
222 I->setOperand( 283 I->setOperand(
223 0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL)); 284 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
224 I->setOperand( 285 I->setOperand(
225 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); 286 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
226 return I; 287 return I;
227 288
228 case Instruction::Shl: { 289 case Instruction::Shl:
229 BinaryOperator *BO = cast<BinaryOperator>(I); 290 case Instruction::LShr:
230 unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); 291 return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
231 292 IC.Builder);
232 // We only accept shifts-by-a-constant in CanEvaluateShifted.
233 ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
234
235 // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
236 if (isLeftShift) {
237 // If this is oversized composite shift, then unsigned shifts get 0.
238 unsigned NewShAmt = NumBits+CI->getZExtValue();
239 if (NewShAmt >= TypeWidth)
240 return Constant::getNullValue(I->getType());
241
242 BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
243 BO->setHasNoUnsignedWrap(false);
244 BO->setHasNoSignedWrap(false);
245 return I;
246 }
247
248 // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
249 // zeros.
250 if (CI->getValue() == NumBits) {
251 APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
252 V = IC.Builder->CreateAnd(BO->getOperand(0),
253 ConstantInt::get(BO->getContext(), Mask));
254 if (Instruction *VI = dyn_cast<Instruction>(V)) {
255 VI->moveBefore(BO);
256 VI->takeName(BO);
257 }
258 return V;
259 }
260
261 // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
262 // the and won't be needed.
263 assert(CI->getZExtValue() > NumBits);
264 BO->setOperand(1, ConstantInt::get(BO->getType(),
265 CI->getZExtValue() - NumBits));
266 BO->setHasNoUnsignedWrap(false);
267 BO->setHasNoSignedWrap(false);
268 return BO;
269 }
270 // FIXME: This is almost identical to the SHL case. Refactor both cases into
271 // a helper function.
272 case Instruction::LShr: {
273 BinaryOperator *BO = cast<BinaryOperator>(I);
274 unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
275 // We only accept shifts-by-a-constant in CanEvaluateShifted.
276 ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
277
278 // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
279 if (!isLeftShift) {
280 // If this is oversized composite shift, then unsigned shifts get 0.
281 unsigned NewShAmt = NumBits+CI->getZExtValue();
282 if (NewShAmt >= TypeWidth)
283 return Constant::getNullValue(BO->getType());
284
285 BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
286 BO->setIsExact(false);
287 return I;
288 }
289
290 // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
291 // zeros.
292 if (CI->getValue() == NumBits) {
293 APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
294 V = IC.Builder->CreateAnd(I->getOperand(0),
295 ConstantInt::get(BO->getContext(), Mask));
296 if (Instruction *VI = dyn_cast<Instruction>(V)) {
297 VI->moveBefore(I);
298 VI->takeName(I);
299 }
300 return V;
301 }
302
303 // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
304 // the and won't be needed.
305 assert(CI->getZExtValue() > NumBits);
306 BO->setOperand(1, ConstantInt::get(BO->getType(),
307 CI->getZExtValue() - NumBits));
308 BO->setIsExact(false);
309 return BO;
310 }
311 293
312 case Instruction::Select: 294 case Instruction::Select:
313 I->setOperand( 295 I->setOperand(
314 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); 296 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
315 I->setOperand( 297 I->setOperand(
316 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); 298 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
317 return I; 299 return I;
318 case Instruction::PHI: { 300 case Instruction::PHI: {
319 // We can change a phi if we can change all operands. Note that we never 301 // We can change a phi if we can change all operands. Note that we never
320 // get into trouble with cyclic PHIs here because we only consider 302 // get into trouble with cyclic PHIs here because we only consider
321 // instructions with a single use. 303 // instructions with a single use.
322 PHINode *PN = cast<PHINode>(I); 304 PHINode *PN = cast<PHINode>(I);
323 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 305 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
324 PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits, 306 PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
325 isLeftShift, IC, DL)); 307 isLeftShift, IC, DL));
326 return PN; 308 return PN;
327 } 309 }
328 } 310 }
329 } 311 }
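
Putting canEvaluateShifted() and getShiftedValue() together, a narrower variant of the i128 example from the doc comment above (illustrative names; the zext keeps the second operand's masked-off bits known zero so the whole expression qualifies):

  define i32 @or_of_shifts(i32 %a, i8 %b) {
    %bz = zext i8 %b to i32
    %c = shl i32 %a, 16
    %d = shl i32 %bz, 24
    %e = or i32 %c, %d
    ; expected: the lshr below is eliminated and %e becomes
    ; or i32 (and i32 %a, 65535), (shl i32 %bz, 8)
    %f = lshr i32 %e, 16
    ret i32 %f
  }
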
330
331
332 312
333 Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, 313 Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
334 BinaryOperator &I) { 314 BinaryOperator &I) {
335 bool isLeftShift = I.getOpcode() == Instruction::Shl; 315 bool isLeftShift = I.getOpcode() == Instruction::Shl;
336 316
337 ConstantInt *COp1 = nullptr; 317 const APInt *Op1C;
338 if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1)) 318 if (!match(Op1, m_APInt(Op1C)))
339 COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
340 else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
341 COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
342 else
343 COp1 = dyn_cast<ConstantInt>(Op1);
344
345 if (!COp1)
346 return nullptr; 319 return nullptr;
347 320
348 // See if we can propagate this shift into the input, this covers the trivial 321 // See if we can propagate this shift into the input, this covers the trivial
349 // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 322 // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
350 if (I.getOpcode() != Instruction::AShr && 323 if (I.getOpcode() != Instruction::AShr &&
351 CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { 324 canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
352 DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" 325 DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
353 " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); 326 " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
354 327
355 return replaceInstUsesWith( 328 return replaceInstUsesWith(
356 I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); 329 I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
357 } 330 }
358 331
359 // See if we can simplify any instructions used by the instruction whose sole 332 // See if we can simplify any instructions used by the instruction whose sole
360 // purpose is to compute bits we don't care about. 333 // purpose is to compute bits we don't care about.
361 uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); 334 unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
362 335
363 assert(!COp1->uge(TypeBits) && 336 assert(!Op1C->uge(TypeBits) &&
364 "Shift over the type width should have been removed already"); 337 "Shift over the type width should have been removed already");
365 338
366 // ((X*C1) << C2) == (X * (C1 << C2)) 339 if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))
367 if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0)) 340 return FoldedShift;
368 if (BO->getOpcode() == Instruction::Mul && isLeftShift)
369 if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
370 return BinaryOperator::CreateMul(BO->getOperand(0),
371 ConstantExpr::getShl(BOOp, Op1));
372
373 // Try to fold constant and into select arguments.
374 if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
375 if (Instruction *R = FoldOpIntoSelect(I, SI))
376 return R;
377 if (isa<PHINode>(Op0))
378 if (Instruction *NV = FoldOpIntoPhi(I))
379 return NV;
380 341
381 // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) 342 // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
382 if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) { 343 if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) {
383 Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0)); 344 Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0));
384 // If 'shift2' is an ashr, we would have to get the sign bit into a funny 345 // If 'shift2' is an ashr, we would have to get the sign bit into a funny
387 // confidence that the shifts will get folded together. We could do this 348 // confidence that the shifts will get folded together. We could do this
388 // xform in more cases, but it is unlikely to be profitable. 349 // xform in more cases, but it is unlikely to be profitable.
389 if (TrOp && I.isLogicalShift() && TrOp->isShift() && 350 if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
390 isa<ConstantInt>(TrOp->getOperand(1))) { 351 isa<ConstantInt>(TrOp->getOperand(1))) {
391 // Okay, we'll do this xform. Make the shift of shift. 352 // Okay, we'll do this xform. Make the shift of shift.
392 Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); 353 Constant *ShAmt =
354 ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
393 // (shift2 (shift1 & 0x00FF), c2) 355 // (shift2 (shift1 & 0x00FF), c2)
394 Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); 356 Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
395 357
396 // For logical shifts, the truncation has the effect of making the high 358 // For logical shifts, the truncation has the effect of making the high
397 // part of the register be zeros. Emulate this by inserting an AND to 359 // part of the register be zeros. Emulate this by inserting an AND to
398 // clear the top bits as needed. This 'and' will usually be zapped by 360 // clear the top bits as needed. This 'and' will usually be zapped by
399 // other xforms later if dead. 361 // other xforms later if dead.
404 // The mask we constructed says what the trunc would do if occurring 366 // The mask we constructed says what the trunc would do if occurring
405 // between the shifts. We want to know the effect *after* the second 367 // between the shifts. We want to know the effect *after* the second
406 // shift. We know that it is a logical shift by a constant, so adjust the 368 // shift. We know that it is a logical shift by a constant, so adjust the
407 // mask as appropriate. 369 // mask as appropriate.
408 if (I.getOpcode() == Instruction::Shl) 370 if (I.getOpcode() == Instruction::Shl)
409 MaskV <<= COp1->getZExtValue(); 371 MaskV <<= Op1C->getZExtValue();
410 else { 372 else {
411 assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); 373 assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
412 MaskV = MaskV.lshr(COp1->getZExtValue()); 374 MaskV.lshrInPlace(Op1C->getZExtValue());
413 } 375 }
414 376
415 // shift1 & 0x00FF 377 // shift1 & 0x00FF
416 Value *And = Builder->CreateAnd(NSh, 378 Value *And = Builder.CreateAnd(NSh,
417 ConstantInt::get(I.getContext(), MaskV), 379 ConstantInt::get(I.getContext(), MaskV),
418 TI->getName()); 380 TI->getName());
419 381
420 // Return the value truncated to the interesting size. 382 // Return the value truncated to the interesting size.
421 return new TruncInst(And, I.getType()); 383 return new TruncInst(And, I.getType());
422 } 384 }
423 } 385 }
437 // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) 399 // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C)
438 if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && 400 if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
439 match(Op0BO->getOperand(1), m_Shr(m_Value(V1), 401 match(Op0BO->getOperand(1), m_Shr(m_Value(V1),
440 m_Specific(Op1)))) { 402 m_Specific(Op1)))) {
441 Value *YS = // (Y << C) 403 Value *YS = // (Y << C)
442 Builder->CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); 404 Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
443 // (X + (Y << C)) 405 // (X + (Y << C))
444 Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, 406 Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
445 Op0BO->getOperand(1)->getName()); 407 Op0BO->getOperand(1)->getName());
446 uint32_t Op1Val = COp1->getLimitedValue(TypeBits); 408 unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
447 409
448 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); 410 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
449 Constant *Mask = ConstantInt::get(I.getContext(), Bits); 411 Constant *Mask = ConstantInt::get(I.getContext(), Bits);
450 if (VectorType *VT = dyn_cast<VectorType>(X->getType())) 412 if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
451 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); 413 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
457 if (isLeftShift && Op0BOOp1->hasOneUse() && 419 if (isLeftShift && Op0BOOp1->hasOneUse() &&
458 match(Op0BOOp1, 420 match(Op0BOOp1,
459 m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), 421 m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
460 m_ConstantInt(CC)))) { 422 m_ConstantInt(CC)))) {
461 Value *YS = // (Y << C) 423 Value *YS = // (Y << C)
462 Builder->CreateShl(Op0BO->getOperand(0), Op1, 424 Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
463 Op0BO->getName());
464 // X & (CC << C) 425 // X & (CC << C)
465 Value *XM = Builder->CreateAnd(V1, ConstantExpr::getShl(CC, Op1), 426 Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
466 V1->getName()+".mask"); 427 V1->getName()+".mask");
467 return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); 428 return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
468 } 429 }
469 LLVM_FALLTHROUGH; 430 LLVM_FALLTHROUGH;
470 } 431 }
471 432
473 // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) 434 // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C)
474 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && 435 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
475 match(Op0BO->getOperand(0), m_Shr(m_Value(V1), 436 match(Op0BO->getOperand(0), m_Shr(m_Value(V1),
476 m_Specific(Op1)))) { 437 m_Specific(Op1)))) {
477 Value *YS = // (Y << C) 438 Value *YS = // (Y << C)
478 Builder->CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); 439 Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
479 // (X + (Y << C)) 440 // (X + (Y << C))
480 Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, 441 Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
481 Op0BO->getOperand(0)->getName()); 442 Op0BO->getOperand(0)->getName());
482 uint32_t Op1Val = COp1->getLimitedValue(TypeBits); 443 unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
483 444
484 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); 445 APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
485 Constant *Mask = ConstantInt::get(I.getContext(), Bits); 446 Constant *Mask = ConstantInt::get(I.getContext(), Bits);
486 if (VectorType *VT = dyn_cast<VectorType>(X->getType())) 447 if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
487 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); 448 Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
492 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && 453 if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
493 match(Op0BO->getOperand(0), 454 match(Op0BO->getOperand(0),
494 m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))), 455 m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))),
495 m_ConstantInt(CC))) && V2 == Op1) { 456 m_ConstantInt(CC))) && V2 == Op1) {
496 Value *YS = // (Y << C) 457 Value *YS = // (Y << C)
497 Builder->CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); 458 Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
498 // X & (CC << C) 459 // X & (CC << C)
499 Value *XM = Builder->CreateAnd(V1, ConstantExpr::getShl(CC, Op1), 460 Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
500 V1->getName()+".mask"); 461 V1->getName()+".mask");
501 462
502 return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); 463 return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
503 } 464 }
504 465
505 break; 466 break;
507 } 468 }
508 469
509 470
510 // If the operand is a bitwise operator with a constant RHS, and the 471 // If the operand is a bitwise operator with a constant RHS, and the
511 // shift is the only use, we can pull it out of the shift. 472 // shift is the only use, we can pull it out of the shift.
512 if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { 473 const APInt *Op0C;
474 if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
513 bool isValid = true; // Valid only for And, Or, Xor 475 bool isValid = true; // Valid only for And, Or, Xor
514 bool highBitSet = false; // Transform if high bit of constant set? 476 bool highBitSet = false; // Transform if high bit of constant set?
515 477
516 switch (Op0BO->getOpcode()) { 478 switch (Op0BO->getOpcode()) {
517 default: isValid = false; break; // Do not perform transform! 479 default: isValid = false; break; // Do not perform transform!
532 // The highBitSet boolean indicates the value of the high bit of 494 // The highBitSet boolean indicates the value of the high bit of
533 // the constant which would cause it to be modified for this 495 // the constant which would cause it to be modified for this
534 // operation. 496 // operation.
535 // 497 //
536 if (isValid && I.getOpcode() == Instruction::AShr) 498 if (isValid && I.getOpcode() == Instruction::AShr)
537 isValid = Op0C->getValue()[TypeBits-1] == highBitSet; 499 isValid = Op0C->isNegative() == highBitSet;
538 500
539 if (isValid) { 501 if (isValid) {
540 Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); 502 Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
503 cast<Constant>(Op0BO->getOperand(1)), Op1);
541 504
542 Value *NewShift = 505 Value *NewShift =
543 Builder->CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); 506 Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1);
544 NewShift->takeName(Op0BO); 507 NewShift->takeName(Op0BO);
545 508
546 return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, 509 return BinaryOperator::Create(Op0BO->getOpcode(), NewShift,
547 NewRHS); 510 NewRHS);
548 } 511 }
549 } 512 }
550 } 513
551 } 514 // If the operand is a subtract with a constant LHS, and the shift
552 515 // is the only use, we can pull it out of the shift.
553 // Find out if this is a shift of a shift by a constant. 516 // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
554 BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); 517 if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
555 if (ShiftOp && !ShiftOp->isShift()) 518 match(Op0BO->getOperand(0), m_APInt(Op0C))) {
556 ShiftOp = nullptr; 519 Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
557 520 cast<Constant>(Op0BO->getOperand(0)), Op1);
558 if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { 521
559 522 Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
560 // This is a constant shift of a constant shift. Be careful about hiding 523 NewShift->takeName(Op0BO);
561 // shl instructions behind bit masks. They are used to represent multiplies 524
562 // by a constant, and it is important that simple arithmetic expressions 525 return BinaryOperator::CreateSub(NewRHS, NewShift);
563 // are still recognizable by scalar evolution. 526 }
564 // 527 }
565 // The transforms applied to shl are very similar to the transforms applied 528 }
566 // to mul by constant. We can be more aggressive about optimizing right 529
567 // shifts. 530 return nullptr;
568 // 531 }
569 // Combinations of right and left shifts will still be optimized in 532
570 // DAGCombine where scalar evolution no longer applies. 533 Instruction *InstCombiner::visitShl(BinaryOperator &I) {
571 534 if (Value *V = SimplifyVectorOp(I))
572 ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); 535 return replaceInstUsesWith(I, V);
573 uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); 536
574 uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); 537 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
575 assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); 538 if (Value *V =
576 if (ShiftAmt1 == 0) return nullptr; // Will be simplified in the future. 539 SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
577 Value *X = ShiftOp->getOperand(0); 540 SQ.getWithInstruction(&I)))
578 541 return replaceInstUsesWith(I, V);
579 IntegerType *Ty = cast<IntegerType>(I.getType()); 542
580 543 if (Instruction *V = commonShiftTransforms(I))
581 // Check for (X << c1) << c2 and (X >> c1) >> c2 544 return V;
582 if (I.getOpcode() == ShiftOp->getOpcode()) { 545
583 uint32_t AmtSum = ShiftAmt1+ShiftAmt2; // Fold into one big shift. 546 const APInt *ShAmtAPInt;
584 // If this is oversized composite shift, then unsigned shifts get 0, ashr 547 if (match(Op1, m_APInt(ShAmtAPInt))) {
585 // saturates. 548 unsigned ShAmt = ShAmtAPInt->getZExtValue();
586 if (AmtSum >= TypeBits) { 549 unsigned BitWidth = I.getType()->getScalarSizeInBits();
587 if (I.getOpcode() != Instruction::AShr) 550 Type *Ty = I.getType();
588 return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); 551
589 AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. 552 // shl (zext X), ShAmt --> zext (shl X, ShAmt)
590 } 553 // This is only valid if X would have zeros shifted out.
591 554 Value *X;
592 return BinaryOperator::Create(I.getOpcode(), X, 555 if (match(Op0, m_ZExt(m_Value(X)))) {
593 ConstantInt::get(Ty, AmtSum)); 556 unsigned SrcWidth = X->getType()->getScalarSizeInBits();
594 } 557 if (ShAmt < SrcWidth &&
595 558 MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
596 if (ShiftAmt1 == ShiftAmt2) { 559 return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty);
597 // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). 560 }
598 if (I.getOpcode() == Instruction::LShr && 561
599 ShiftOp->getOpcode() == Instruction::Shl) { 562 // (X >> C) << C --> X & (-1 << C)
600 APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); 563 if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
601 return BinaryOperator::CreateAnd(X, 564 APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
602 ConstantInt::get(I.getContext(), Mask)); 565 return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
603 } 566 }
604 } else if (ShiftAmt1 < ShiftAmt2) { 567
605 uint32_t ShiftDiff = ShiftAmt2-ShiftAmt1; 568 // Be careful about hiding shl instructions behind bit masks. They are used
606 569 // to represent multiplies by a constant, and it is important that simple
607 // (X >>?,exact C1) << C2 --> X << (C2-C1) 570 // arithmetic expressions are still recognizable by scalar evolution.
608 // The inexact version is deferred to DAGCombine so we don't hide shl 571 // The inexact versions are deferred to DAGCombine, so we don't hide shl
609 // behind a bit mask. 572 // behind a bit mask.
610 if (I.getOpcode() == Instruction::Shl && 573 const APInt *ShOp1;
611 ShiftOp->getOpcode() != Instruction::Shl && 574 if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
612 ShiftOp->isExact()) { 575 unsigned ShrAmt = ShOp1->getZExtValue();
613 assert(ShiftOp->getOpcode() == Instruction::LShr || 576 if (ShrAmt < ShAmt) {
614 ShiftOp->getOpcode() == Instruction::AShr); 577 // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
615 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); 578 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
616 BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, 579 auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
617 X, ShiftDiffCst);
618 NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); 580 NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
619 NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); 581 NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
620 return NewShl; 582 return NewShl;
621 } 583 }
622 584 if (ShrAmt > ShAmt) {
623 // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) 585 // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
624 if (I.getOpcode() == Instruction::LShr && 586 Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
625 ShiftOp->getOpcode() == Instruction::Shl) { 587 auto *NewShr = BinaryOperator::Create(
626 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); 588 cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
627 // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) 589 NewShr->setIsExact(true);
628 if (ShiftOp->hasNoUnsignedWrap()) { 590 return NewShr;
629 BinaryOperator *NewLShr = BinaryOperator::Create(Instruction::LShr, 591 }
630 X, ShiftDiffCst); 592 }
593
594 if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
595 unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
596 // Oversized shifts are simplified to zero in InstSimplify.
597 if (AmtSum < BitWidth)
598 // (X << C1) << C2 --> X << (C1 + C2)
599 return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
600 }
601
602 // If the shifted-out value is known-zero, then this is a NUW shift.
603 if (!I.hasNoUnsignedWrap() &&
604 MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
605 I.setHasNoUnsignedWrap();
606 return &I;
607 }
608
609 // If the shifted-out value is all signbits, then this is a NSW shift.
610 if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
611 I.setHasNoSignedWrap();
612 return &I;
613 }
614 }
615
616 Constant *C1;
617 if (match(Op1, m_Constant(C1))) {
618 Constant *C2;
619 Value *X;
620 // (C2 << X) << C1 --> (C2 << C1) << X
621 if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
622 return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);
623
624 // (X * C2) << C1 --> X * (C2 << C1)
625 if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
626 return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
627 }
628
629 return nullptr;
630 }
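
Two small examples of the constant-shift folds in the new visitShl() (illustrative names; results in comments are the expected instcombine output):

  ; (X * C2) << C1 --> X * (C2 << C1)
  define i32 @mul_then_shl(i32 %x) {
    %m = mul i32 %x, 3
    %r = shl i32 %m, 4               ; expected: mul i32 %x, 48
    ret i32 %r
  }

  ; (X >> C) << C --> X & (-1 << C)
  define i32 @shr_then_shl(i32 %x) {
    %s = lshr i32 %x, 6
    %r = shl i32 %s, 6               ; expected: and i32 %x, -64
    ret i32 %r
  }
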
631
632 Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
633 if (Value *V = SimplifyVectorOp(I))
634 return replaceInstUsesWith(I, V);
635
636 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
637 if (Value *V =
638 SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
639 return replaceInstUsesWith(I, V);
640
641 if (Instruction *R = commonShiftTransforms(I))
642 return R;
643
644 Type *Ty = I.getType();
645 const APInt *ShAmtAPInt;
646 if (match(Op1, m_APInt(ShAmtAPInt))) {
647 unsigned ShAmt = ShAmtAPInt->getZExtValue();
648 unsigned BitWidth = Ty->getScalarSizeInBits();
649 auto *II = dyn_cast<IntrinsicInst>(Op0);
650 if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
651 (II->getIntrinsicID() == Intrinsic::ctlz ||
652 II->getIntrinsicID() == Intrinsic::cttz ||
653 II->getIntrinsicID() == Intrinsic::ctpop)) {
654 // ctlz.i32(x)>>5 --> zext(x == 0)
655 // cttz.i32(x)>>5 --> zext(x == 0)
656 // ctpop.i32(x)>>5 --> zext(x == -1)
657 bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
658 Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
659 Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
660 return new ZExtInst(Cmp, Ty);
661 }
662
663 Value *X;
664 const APInt *ShOp1;
665 if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
666 unsigned ShlAmt = ShOp1->getZExtValue();
667 if (ShlAmt < ShAmt) {
668 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
669 if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
670 // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
671 auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
631 NewLShr->setIsExact(I.isExact()); 672 NewLShr->setIsExact(I.isExact());
632 return NewLShr; 673 return NewLShr;
633 } 674 }
634 Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); 675 // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2)
635 676 Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
636 APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); 677 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
637 return BinaryOperator::CreateAnd(Shift, 678 return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
638 ConstantInt::get(I.getContext(),Mask)); 679 }
639 } 680 if (ShlAmt > ShAmt) {
640 681 Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
641 // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, 682 if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
642 // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. 683 // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
643 if (I.getOpcode() == Instruction::AShr && 684 auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
644 ShiftOp->getOpcode() == Instruction::Shl) {
645 if (ShiftOp->hasNoSignedWrap()) {
646 // (X <<nsw C1) >>s C2 --> X >>s (C2-C1)
647 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
648 BinaryOperator *NewAShr = BinaryOperator::Create(Instruction::AShr,
649 X, ShiftDiffCst);
650 NewAShr->setIsExact(I.isExact());
651 return NewAShr;
652 }
653 }
654 } else {
655 assert(ShiftAmt2 < ShiftAmt1);
656 uint32_t ShiftDiff = ShiftAmt1-ShiftAmt2;
657
658 // (X >>?exact C1) << C2 --> X >>?exact (C1-C2)
659 // The inexact version is deferred to DAGCombine so we don't hide shl
660 // behind a bit mask.
661 if (I.getOpcode() == Instruction::Shl &&
662 ShiftOp->getOpcode() != Instruction::Shl &&
663 ShiftOp->isExact()) {
664 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
665 BinaryOperator *NewShr = BinaryOperator::Create(ShiftOp->getOpcode(),
666 X, ShiftDiffCst);
667 NewShr->setIsExact(true);
668 return NewShr;
669 }
670
671 // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2)
672 if (I.getOpcode() == Instruction::LShr &&
673 ShiftOp->getOpcode() == Instruction::Shl) {
674 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
675 if (ShiftOp->hasNoUnsignedWrap()) {
676 // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2)
677 BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl,
678 X, ShiftDiffCst);
679 NewShl->setHasNoUnsignedWrap(true); 685 NewShl->setHasNoUnsignedWrap(true);
680 return NewShl; 686 return NewShl;
681 } 687 }
682 Value *Shift = Builder->CreateShl(X, ShiftDiffCst); 688 // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2)
683 689 Value *NewShl = Builder.CreateShl(X, ShiftDiff);
684 APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); 690 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
685 return BinaryOperator::CreateAnd(Shift, 691 return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
686 ConstantInt::get(I.getContext(),Mask)); 692 }
687 } 693 assert(ShlAmt == ShAmt);
688 694 // (X << C) >>u C --> X & (-1 >>u C)
689 // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, 695 APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
690 // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. 696 return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
691 if (I.getOpcode() == Instruction::AShr && 697 }
692 ShiftOp->getOpcode() == Instruction::Shl) { 698
693 if (ShiftOp->hasNoSignedWrap()) { 699 if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
694 // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) 700 (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
695 ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); 701 assert(ShAmt < X->getType()->getScalarSizeInBits() &&
696 BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, 702 "Big shift not simplified to zero?");
697 X, ShiftDiffCst); 703 // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
698 NewShl->setHasNoSignedWrap(true); 704 Value *NewLShr = Builder.CreateLShr(X, ShAmt);
699 return NewShl; 705 return new ZExtInst(NewLShr, Ty);
706 }
707
708 if (match(Op0, m_SExt(m_Value(X))) &&
709 (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
710 // Are we moving the sign bit to the low bit and widening with high zeros?
711 unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
712 if (ShAmt == BitWidth - 1) {
713 // lshr (sext i1 X to iN), N-1 --> zext X to iN
714 if (SrcTyBitWidth == 1)
715 return new ZExtInst(X, Ty);
716
717 // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
718 if (Op0->hasOneUse()) {
719 Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
720 return new ZExtInst(NewLShr, Ty);
700 } 721 }
701 } 722 }
702 } 723
703 } 724 // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
704 return nullptr; 725 if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
705 } 726 // The new shift amount can't be more than the narrow source type.
706 727 unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
707 Instruction *InstCombiner::visitShl(BinaryOperator &I) { 728 Value *AShr = Builder.CreateAShr(X, NewShAmt);
708 if (Value *V = SimplifyVectorOp(I)) 729 return new ZExtInst(AShr, Ty);
709 return replaceInstUsesWith(I, V); 730 }
710 731 }
711 if (Value *V = 732
712 SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), 733 if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
713 I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) 734 unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
714 return replaceInstUsesWith(I, V); 735 // Oversized shifts are simplified to zero in InstSimplify.
715 736 if (AmtSum < BitWidth)
716 if (Instruction *V = commonShiftTransforms(I)) 737 // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
717 return V; 738 return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
718
719 if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) {
720 unsigned ShAmt = Op1C->getZExtValue();
721
722 // If the shifted-out value is known-zero, then this is a NUW shift.
723 if (!I.hasNoUnsignedWrap() &&
724 MaskedValueIsZero(I.getOperand(0),
725 APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0,
726 &I)) {
727 I.setHasNoUnsignedWrap();
728 return &I;
729 }
730
731 // If the shifted out value is all signbits, this is a NSW shift.
732 if (!I.hasNoSignedWrap() &&
733 ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) {
734 I.setHasNoSignedWrap();
735 return &I;
736 }
737 }
738
739 // (C1 << A) << C2 -> (C1 << C2) << A
740 Constant *C1, *C2;
741 Value *A;
742 if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) &&
743 match(I.getOperand(1), m_Constant(C2)))
744 return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
745
746 return nullptr;
747 }
748
749 Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
750 if (Value *V = SimplifyVectorOp(I))
751 return replaceInstUsesWith(I, V);
752
753 if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
754 DL, &TLI, &DT, &AC))
755 return replaceInstUsesWith(I, V);
756
757 if (Instruction *R = commonShiftTransforms(I))
758 return R;
759
760 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
761
762 if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
763 unsigned ShAmt = Op1C->getZExtValue();
764
765 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
766 unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
767 // ctlz.i32(x)>>5 --> zext(x == 0)
768 // cttz.i32(x)>>5 --> zext(x == 0)
769 // ctpop.i32(x)>>5 --> zext(x == -1)
770 if ((II->getIntrinsicID() == Intrinsic::ctlz ||
771 II->getIntrinsicID() == Intrinsic::cttz ||
772 II->getIntrinsicID() == Intrinsic::ctpop) &&
773 isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
774 bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
775 Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
776 Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
777 return new ZExtInst(Cmp, II->getType());
778 }
779 } 739 }
780 740
781 // If the shifted-out value is known-zero, then this is an exact shift. 741 // If the shifted-out value is known-zero, then this is an exact shift.
782 if (!I.isExact() && 742 if (!I.isExact() &&
783 MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), 743 MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
784 0, &I)){
785 I.setIsExact(); 744 I.setIsExact();
786 return &I; 745 return &I;
787 } 746 }
788 } 747 }
789
790 return nullptr; 748 return nullptr;
791 } 749 }
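
A sketch of the bit-counting fold in the new visitLShr() (illustrative function name; the intrinsic declaration is the standard one):

  declare i32 @llvm.ctlz.i32(i32, i1)

  ; ctlz.i32(x) >> 5 --> zext(x == 0)
  define i32 @is_zero(i32 %x) {
    %c = call i32 @llvm.ctlz.i32(i32 %x, i1 false)
    %r = lshr i32 %c, 5              ; expected: zext (icmp eq i32 %x, 0) to i32
    ret i32 %r
  }
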
792 750
793 Instruction *InstCombiner::visitAShr(BinaryOperator &I) { 751 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
794 if (Value *V = SimplifyVectorOp(I)) 752 if (Value *V = SimplifyVectorOp(I))
795 return replaceInstUsesWith(I, V); 753 return replaceInstUsesWith(I, V);
796 754
797 if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), 755 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
798 DL, &TLI, &DT, &AC)) 756 if (Value *V =
757 SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
799 return replaceInstUsesWith(I, V); 758 return replaceInstUsesWith(I, V);
800 759
801 if (Instruction *R = commonShiftTransforms(I)) 760 if (Instruction *R = commonShiftTransforms(I))
802 return R; 761 return R;
803 762
804 Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); 763 Type *Ty = I.getType();
805 764 unsigned BitWidth = Ty->getScalarSizeInBits();
806 if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { 765 const APInt *ShAmtAPInt;
807 unsigned ShAmt = Op1C->getZExtValue(); 766 if (match(Op1, m_APInt(ShAmtAPInt))) {
808 767 unsigned ShAmt = ShAmtAPInt->getZExtValue();
809 // If the input is a SHL by the same constant (ashr (shl X, C), C), then we 768
810 // have a sign-extend idiom. 769 // If the shift amount equals the difference in width of the destination
770 // and source scalar types:
771 // ashr (shl (zext X), C), C --> sext X
811 Value *X; 772 Value *X;
812 if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) { 773 if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
813 // If the input is an extension from the shifted amount value, e.g. 774 ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
814 // %x = zext i8 %A to i32 775 return new SExtInst(X, Ty);
815 // %y = shl i32 %x, 24 776
816 // %z = ashr %y, 24 777 // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
817 // then turn this into "z = sext i8 A to i32". 778 // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
818 if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) { 779 const APInt *ShOp1;
819 uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits(); 780 if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
820 uint32_t DestBits = ZI->getType()->getScalarSizeInBits(); 781 unsigned ShlAmt = ShOp1->getZExtValue();
821 if (Op1C->getZExtValue() == DestBits-SrcBits) 782 if (ShlAmt < ShAmt) {
822 return new SExtInst(ZI->getOperand(0), ZI->getType()); 783 // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
823 } 784 Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
785 auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
786 NewAShr->setIsExact(I.isExact());
787 return NewAShr;
788 }
789 if (ShlAmt > ShAmt) {
790 // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
791 Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
792 auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
793 NewShl->setHasNoSignedWrap(true);
794 return NewShl;
795 }
796 }
797
798 if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
799 unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
800 // Oversized arithmetic shifts replicate the sign bit.
801 AmtSum = std::min(AmtSum, BitWidth - 1);
802 // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
803 return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
804 }
805
806 if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
807 (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
808 // ashr (sext X), C --> sext (ashr X, C')
809 Type *SrcTy = X->getType();
810 ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
811 Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
812 return new SExtInst(NewSh, Ty);
824 } 813 }
825 814
826 // If the shifted-out value is known-zero, then this is an exact shift. 815 // If the shifted-out value is known-zero, then this is an exact shift.
827 if (!I.isExact() && 816 if (!I.isExact() &&
828 MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), 817 MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
829 0, &I)) {
830 I.setIsExact(); 818 I.setIsExact();
831 return &I; 819 return &I;
832 } 820 }
833 } 821 }
834 822
835 // See if we can turn a signed shr into an unsigned shr. 823 // See if we can turn a signed shr into an unsigned shr.
836 if (MaskedValueIsZero(Op0, 824 if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
837 APInt::getSignBit(I.getType()->getScalarSizeInBits()),
838 0, &I))
839 return BinaryOperator::CreateLShr(Op0, Op1); 825 return BinaryOperator::CreateLShr(Op0, Op1);
840 826
841 return nullptr; 827 return nullptr;
842 } 828 }
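
Two closing examples for visitAShr() (illustrative names; rewrites shown in comments are the expected results):

  ; ashr (shl (zext X), C), C --> sext X, when C equals the width difference
  define i32 @sext_idiom(i8 %x) {
    %z = zext i8 %x to i32
    %s = shl i32 %z, 24
    %r = ashr i32 %s, 24             ; expected: sext i8 %x to i32
    ret i32 %r
  }

  ; an ashr of a value whose sign bit is known zero becomes an lshr
  define i32 @ashr_to_lshr(i32 %x) {
    %p = lshr i32 %x, 1              ; sign bit of %p is known zero
    %r = ashr i32 %p, 3              ; expected: lshr, which then folds to lshr i32 %x, 4
    ret i32 %r
  }
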