150
|
1
|
|
2 #include "polly/Support/SCEVValidator.h"
|
|
3 #include "polly/ScopDetection.h"
|
|
4 #include "llvm/Analysis/RegionInfo.h"
|
|
5 #include "llvm/Analysis/ScalarEvolution.h"
|
|
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
|
|
7 #include "llvm/Support/Debug.h"
|
|
8
|
|
9 using namespace llvm;
|
|
10 using namespace polly;
|
|
11
|
|
12 #define DEBUG_TYPE "polly-scev-validator"
|
|
13
|
|
namespace SCEVType {
/// Classification assigned to a SCEV while it is being validated.
///
/// Every SCEV is given one of the types INT, PARAM, IV or INVALID. The
/// enumerators are declared in ascending order on purpose: the subexpressions
/// of a SCEV with type X may only have a type that is smaller than or equal
/// to X, so the relative order of the enumerators must not be changed.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the SCoP, but
  /// that may depend on parameters unknown at compile time.
  PARAM,

  /// An expression that may change during the execution of the SCoP.
  IV,

  /// An invalid expression.
  INVALID
};
} // namespace SCEVType
|
|
36
|
|
37 /// The result the validator returns for a SCEV expression.
|
|
38 class ValidatorResult {
|
|
39 /// The type of the expression
|
|
40 SCEVType::TYPE Type;
|
|
41
|
|
42 /// The set of Parameters in the expression.
|
|
43 ParameterSetTy Parameters;
|
|
44
|
|
45 public:
|
|
46 /// The copy constructor
|
|
47 ValidatorResult(const ValidatorResult &Source) {
|
|
48 Type = Source.Type;
|
|
49 Parameters = Source.Parameters;
|
|
50 }
|
|
51
|
|
52 /// Construct a result with a certain type and no parameters.
|
|
53 ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
|
|
54 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
|
|
55 }
|
|
56
|
|
57 /// Construct a result with a certain type and a single parameter.
|
|
58 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
|
|
59 Parameters.insert(Expr);
|
|
60 }
|
|
61
|
|
62 /// Get the type of the ValidatorResult.
|
|
63 SCEVType::TYPE getType() { return Type; }
|
|
64
|
|
65 /// Is the analyzed SCEV constant during the execution of the SCoP.
|
|
66 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
|
|
67
|
|
68 /// Is the analyzed SCEV valid.
|
|
69 bool isValid() { return Type != SCEVType::INVALID; }
|
|
70
|
|
71 /// Is the analyzed SCEV of Type IV.
|
|
72 bool isIV() { return Type == SCEVType::IV; }
|
|
73
|
|
74 /// Is the analyzed SCEV of Type INT.
|
|
75 bool isINT() { return Type == SCEVType::INT; }
|
|
76
|
|
77 /// Is the analyzed SCEV of Type PARAM.
|
|
78 bool isPARAM() { return Type == SCEVType::PARAM; }
|
|
79
|
|
80 /// Get the parameters of this validator result.
|
|
81 const ParameterSetTy &getParameters() { return Parameters; }
|
|
82
|
|
83 /// Add the parameters of Source to this result.
|
|
84 void addParamsFrom(const ValidatorResult &Source) {
|
|
85 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
|
|
86 }
|
|
87
|
|
88 /// Merge a result.
|
|
89 ///
|
|
90 /// This means to merge the parameters and to set the Type to the most
|
|
91 /// specific Type that matches both.
|
|
92 void merge(const ValidatorResult &ToMerge) {
|
|
93 Type = std::max(Type, ToMerge.Type);
|
|
94 addParamsFrom(ToMerge);
|
|
95 }
|
|
96
|
|
97 void print(raw_ostream &OS) {
|
|
98 switch (Type) {
|
|
99 case SCEVType::INT:
|
|
100 OS << "SCEVType::INT";
|
|
101 break;
|
|
102 case SCEVType::PARAM:
|
|
103 OS << "SCEVType::PARAM";
|
|
104 break;
|
|
105 case SCEVType::IV:
|
|
106 OS << "SCEVType::IV";
|
|
107 break;
|
|
108 case SCEVType::INVALID:
|
|
109 OS << "SCEVType::INVALID";
|
|
110 break;
|
|
111 }
|
|
112 }
|
|
113 };
|
|
114
|
|
115 raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
|
|
116 VR.print(OS);
|
|
117 return OS;
|
|
118 }
|
|
119
|
|
120 bool polly::isConstCall(llvm::CallInst *Call) {
|
|
121 if (Call->mayReadOrWriteMemory())
|
|
122 return false;
|
|
123
|
|
124 for (auto &Operand : Call->arg_operands())
|
|
125 if (!isa<ConstantInt>(&Operand))
|
|
126 return false;
|
|
127
|
|
128 return true;
|
|
129 }
|
|
130
|
|
131 /// Check if a SCEV is valid in a SCoP.
|
|
132 struct SCEVValidator
|
|
133 : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
|
|
134 private:
|
|
135 const Region *R;
|
|
136 Loop *Scope;
|
|
137 ScalarEvolution &SE;
|
|
138 InvariantLoadsSetTy *ILS;
|
|
139
|
|
140 public:
|
|
141 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
|
|
142 InvariantLoadsSetTy *ILS)
|
|
143 : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
|
|
144
|
|
145 class ValidatorResult visitConstant(const SCEVConstant *Constant) {
|
|
146 return ValidatorResult(SCEVType::INT);
|
|
147 }
|
|
148
|
|
149 class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
|
|
150 const SCEV *Operand) {
|
|
151 ValidatorResult Op = visit(Operand);
|
|
152 auto Type = Op.getType();
|
|
153
|
|
154 // If unsigned operations are allowed return the operand, otherwise
|
|
155 // check if we can model the expression without unsigned assumptions.
|
|
156 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
|
|
157 return Op;
|
|
158
|
|
159 if (Type == SCEVType::IV)
|
|
160 return ValidatorResult(SCEVType::INVALID);
|
|
161 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
162 }
|
|
163
|
207
|
164 class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
|
|
165 return visit(Expr->getOperand());
|
|
166 }
|
|
167
|
150
|
168 class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
|
|
169 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
|
|
170 }
|
|
171
|
|
172 class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
|
|
173 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
|
|
174 }
|
|
175
|
|
176 class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
|
|
177 return visit(Expr->getOperand());
|
|
178 }
|
|
179
|
|
180 class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
|
|
181 ValidatorResult Return(SCEVType::INT);
|
|
182
|
|
183 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
184 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
185 Return.merge(Op);
|
|
186
|
|
187 // Early exit.
|
|
188 if (!Return.isValid())
|
|
189 break;
|
|
190 }
|
|
191
|
|
192 return Return;
|
|
193 }
|
|
194
|
|
195 class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
|
|
196 ValidatorResult Return(SCEVType::INT);
|
|
197
|
|
198 bool HasMultipleParams = false;
|
|
199
|
|
200 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
201 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
202
|
|
203 if (Op.isINT())
|
|
204 continue;
|
|
205
|
|
206 if (Op.isPARAM() && Return.isPARAM()) {
|
|
207 HasMultipleParams = true;
|
|
208 continue;
|
|
209 }
|
|
210
|
|
211 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
|
|
212 LLVM_DEBUG(
|
|
213 dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
|
|
214 << "\tExpr: " << *Expr << "\n"
|
|
215 << "\tPrevious expression type: " << Return << "\n"
|
|
216 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
|
|
217 << "\n");
|
|
218
|
|
219 return ValidatorResult(SCEVType::INVALID);
|
|
220 }
|
|
221
|
|
222 Return.merge(Op);
|
|
223 }
|
|
224
|
|
225 if (HasMultipleParams && Return.isValid())
|
|
226 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
227
|
|
228 return Return;
|
|
229 }
|
|
230
|
|
231 class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
|
|
232 if (!Expr->isAffine()) {
|
|
233 LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
|
|
234 return ValidatorResult(SCEVType::INVALID);
|
|
235 }
|
|
236
|
|
237 ValidatorResult Start = visit(Expr->getStart());
|
|
238 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
|
|
239
|
|
240 if (!Start.isValid())
|
|
241 return Start;
|
|
242
|
|
243 if (!Recurrence.isValid())
|
|
244 return Recurrence;
|
|
245
|
|
246 auto *L = Expr->getLoop();
|
|
247 if (R->contains(L) && (!Scope || !L->contains(Scope))) {
|
|
248 LLVM_DEBUG(
|
|
249 dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
|
|
250 "non-affine subregion or has a non-synthesizable exit "
|
|
251 "value.");
|
|
252 return ValidatorResult(SCEVType::INVALID);
|
|
253 }
|
|
254
|
|
255 if (R->contains(L)) {
|
|
256 if (Recurrence.isINT()) {
|
|
257 ValidatorResult Result(SCEVType::IV);
|
|
258 Result.addParamsFrom(Start);
|
|
259 return Result;
|
|
260 }
|
|
261
|
|
262 LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
|
|
263 "recurrence part");
|
|
264 return ValidatorResult(SCEVType::INVALID);
|
|
265 }
|
|
266
|
|
267 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
|
|
268
|
|
269 // Directly generate ValidatorResult for Expr if 'start' is zero.
|
|
270 if (Expr->getStart()->isZero())
|
|
271 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
272
|
|
273 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
|
|
274 // if 'start' is not zero.
|
|
275 const SCEV *ZeroStartExpr = SE.getAddRecExpr(
|
|
276 SE.getConstant(Expr->getStart()->getType(), 0),
|
|
277 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
|
|
278
|
|
279 ValidatorResult ZeroStartResult =
|
|
280 ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
|
|
281 ZeroStartResult.addParamsFrom(Start);
|
|
282
|
|
283 return ZeroStartResult;
|
|
284 }
|
|
285
|
|
286 class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
|
|
287 ValidatorResult Return(SCEVType::INT);
|
|
288
|
|
289 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
290 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
291
|
|
292 if (!Op.isValid())
|
|
293 return Op;
|
|
294
|
|
295 Return.merge(Op);
|
|
296 }
|
|
297
|
|
298 return Return;
|
|
299 }
|
|
300
|
|
301 class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
|
|
302 ValidatorResult Return(SCEVType::INT);
|
|
303
|
|
304 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
305 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
306
|
|
307 if (!Op.isValid())
|
|
308 return Op;
|
|
309
|
|
310 Return.merge(Op);
|
|
311 }
|
|
312
|
|
313 return Return;
|
|
314 }
|
|
315
|
|
316 class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
|
|
317 // We do not support unsigned max operations. If 'Expr' is constant during
|
|
318 // Scop execution we treat this as a parameter, otherwise we bail out.
|
|
319 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
320 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
321
|
|
322 if (!Op.isConstant()) {
|
|
323 LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
|
|
324 return ValidatorResult(SCEVType::INVALID);
|
|
325 }
|
|
326 }
|
|
327
|
|
328 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
329 }
|
|
330
|
|
331 class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
|
|
332 // We do not support unsigned min operations. If 'Expr' is constant during
|
|
333 // Scop execution we treat this as a parameter, otherwise we bail out.
|
|
334 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
|
|
335 ValidatorResult Op = visit(Expr->getOperand(i));
|
|
336
|
|
337 if (!Op.isConstant()) {
|
|
338 LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
|
|
339 return ValidatorResult(SCEVType::INVALID);
|
|
340 }
|
|
341 }
|
|
342
|
|
343 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
344 }
|
|
345
|
|
346 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
|
|
347 if (R->contains(I)) {
|
|
348 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
|
|
349 "within the region\n");
|
|
350 return ValidatorResult(SCEVType::INVALID);
|
|
351 }
|
|
352
|
|
353 return ValidatorResult(SCEVType::PARAM, S);
|
|
354 }
|
|
355
|
|
356 ValidatorResult visitCallInstruction(Instruction *I, const SCEV *S) {
|
|
357 assert(I->getOpcode() == Instruction::Call && "Call instruction expected");
|
|
358
|
|
359 if (R->contains(I)) {
|
|
360 auto Call = cast<CallInst>(I);
|
|
361
|
|
362 if (!isConstCall(Call))
|
|
363 return ValidatorResult(SCEVType::INVALID, S);
|
|
364 }
|
|
365 return ValidatorResult(SCEVType::PARAM, S);
|
|
366 }
|
|
367
|
|
368 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
|
|
369 if (R->contains(I) && ILS) {
|
|
370 ILS->insert(cast<LoadInst>(I));
|
|
371 return ValidatorResult(SCEVType::PARAM, S);
|
|
372 }
|
|
373
|
|
374 return visitGenericInst(I, S);
|
|
375 }
|
|
376
|
|
377 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
|
|
378 const SCEV *DivExpr,
|
|
379 Instruction *SDiv = nullptr) {
|
|
380
|
|
381 // First check if we might be able to model the division, thus if the
|
|
382 // divisor is constant. If so, check the dividend, otherwise check if
|
|
383 // the whole division can be seen as a parameter.
|
|
384 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
|
|
385 return visit(Dividend);
|
|
386
|
|
387 // For signed divisions use the SDiv instruction to check for a parameter
|
|
388 // division, for unsigned divisions check the operands.
|
|
389 if (SDiv)
|
|
390 return visitGenericInst(SDiv, DivExpr);
|
|
391
|
|
392 ValidatorResult LHS = visit(Dividend);
|
|
393 ValidatorResult RHS = visit(Divisor);
|
|
394 if (LHS.isConstant() && RHS.isConstant())
|
|
395 return ValidatorResult(SCEVType::PARAM, DivExpr);
|
|
396
|
|
397 LLVM_DEBUG(
|
|
398 dbgs() << "INVALID: unsigned division of non-constant expressions");
|
|
399 return ValidatorResult(SCEVType::INVALID);
|
|
400 }
|
|
401
|
|
402 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
|
|
403 if (!PollyAllowUnsignedOperations)
|
|
404 return ValidatorResult(SCEVType::INVALID);
|
|
405
|
|
406 auto *Dividend = Expr->getLHS();
|
|
407 auto *Divisor = Expr->getRHS();
|
|
408 return visitDivision(Dividend, Divisor, Expr);
|
|
409 }
|
|
410
|
|
411 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
|
|
412 assert(SDiv->getOpcode() == Instruction::SDiv &&
|
|
413 "Assumed SDiv instruction!");
|
|
414
|
|
415 auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
|
|
416 auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
|
|
417 return visitDivision(Dividend, Divisor, Expr, SDiv);
|
|
418 }
|
|
419
|
|
420 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
|
|
421 assert(SRem->getOpcode() == Instruction::SRem &&
|
|
422 "Assumed SRem instruction!");
|
|
423
|
|
424 auto *Divisor = SRem->getOperand(1);
|
|
425 auto *CI = dyn_cast<ConstantInt>(Divisor);
|
|
426 if (!CI || CI->isZeroValue())
|
|
427 return visitGenericInst(SRem, S);
|
|
428
|
|
429 auto *Dividend = SRem->getOperand(0);
|
|
430 auto *DividendSCEV = SE.getSCEV(Dividend);
|
|
431 return visit(DividendSCEV);
|
|
432 }
|
|
433
|
|
434 ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
|
|
435 Value *V = Expr->getValue();
|
|
436
|
|
437 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
|
|
438 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
|
|
439 return ValidatorResult(SCEVType::INVALID);
|
|
440 }
|
|
441
|
|
442 if (isa<UndefValue>(V)) {
|
|
443 LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
|
|
444 return ValidatorResult(SCEVType::INVALID);
|
|
445 }
|
|
446
|
|
447 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
|
|
448 switch (I->getOpcode()) {
|
|
449 case Instruction::IntToPtr:
|
|
450 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
|
|
451 case Instruction::Load:
|
|
452 return visitLoadInstruction(I, Expr);
|
|
453 case Instruction::SDiv:
|
|
454 return visitSDivInstruction(I, Expr);
|
|
455 case Instruction::SRem:
|
|
456 return visitSRemInstruction(I, Expr);
|
|
457 case Instruction::Call:
|
|
458 return visitCallInstruction(I, Expr);
|
|
459 default:
|
|
460 return visitGenericInst(I, Expr);
|
|
461 }
|
|
462 }
|
|
463
|
207
|
464 if (Expr->getType()->isPointerTy()) {
|
|
465 if (isa<ConstantPointerNull>(V))
|
|
466 return ValidatorResult(SCEVType::INT); // "int"
|
|
467 }
|
|
468
|
150
|
469 return ValidatorResult(SCEVType::PARAM, Expr);
|
|
470 }
|
|
471 };
|
|
472
|
|
473 class SCEVHasIVParams {
|
|
474 bool HasIVParams = false;
|
|
475
|
|
476 public:
|
|
477 SCEVHasIVParams() {}
|
|
478
|
|
479 bool follow(const SCEV *S) {
|
|
480 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
|
|
481 if (!Unknown)
|
|
482 return true;
|
|
483
|
|
484 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
|
|
485
|
|
486 if (!Call)
|
|
487 return true;
|
|
488
|
|
489 if (isConstCall(Call)) {
|
|
490 HasIVParams = true;
|
|
491 return false;
|
|
492 }
|
|
493
|
|
494 return true;
|
|
495 }
|
|
496
|
|
497 bool isDone() { return HasIVParams; }
|
|
498 bool hasIVParams() { return HasIVParams; }
|
|
499 };
|
|
500
|
|
501 /// Check whether a SCEV refers to an SSA name defined inside a region.
|
|
502 class SCEVInRegionDependences {
|
|
503 const Region *R;
|
|
504 Loop *Scope;
|
|
505 const InvariantLoadsSetTy &ILS;
|
|
506 bool AllowLoops;
|
|
507 bool HasInRegionDeps = false;
|
|
508
|
|
509 public:
|
|
510 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
|
|
511 const InvariantLoadsSetTy &ILS)
|
|
512 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
|
|
513
|
|
514 bool follow(const SCEV *S) {
|
|
515 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
|
|
516 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
|
|
517
|
|
518 CallInst *Call = dyn_cast<CallInst>(Unknown->getValue());
|
|
519
|
|
520 if (Call && isConstCall(Call))
|
|
521 return false;
|
|
522
|
|
523 if (Inst) {
|
|
524 // When we invariant load hoist a load, we first make sure that there
|
|
525 // can be no dependences created by it in the Scop region. So, we should
|
|
526 // not consider scalar dependences to `LoadInst`s that are invariant
|
|
527 // load hoisted.
|
|
528 //
|
|
529 // If this check is not present, then we create data dependences which
|
|
530 // are strictly not necessary by tracking the invariant load as a
|
|
531 // scalar.
|
|
532 LoadInst *LI = dyn_cast<LoadInst>(Inst);
|
|
533 if (LI && ILS.count(LI) > 0)
|
|
534 return false;
|
|
535 }
|
|
536
|
|
537 // Return true when Inst is defined inside the region R.
|
|
538 if (!Inst || !R->contains(Inst))
|
|
539 return true;
|
|
540
|
|
541 HasInRegionDeps = true;
|
|
542 return false;
|
|
543 }
|
|
544
|
|
545 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
|
|
546 if (AllowLoops)
|
|
547 return true;
|
|
548
|
|
549 auto *L = AddRec->getLoop();
|
|
550 if (R->contains(L) && !L->contains(Scope)) {
|
|
551 HasInRegionDeps = true;
|
|
552 return false;
|
|
553 }
|
|
554 }
|
|
555
|
|
556 return true;
|
|
557 }
|
|
558 bool isDone() { return false; }
|
|
559 bool hasDependences() { return HasInRegionDeps; }
|
|
560 };
|
|
561
|
|
562 namespace polly {
|
|
563 /// Find all loops referenced in SCEVAddRecExprs.
|
|
564 class SCEVFindLoops {
|
|
565 SetVector<const Loop *> &Loops;
|
|
566
|
|
567 public:
|
|
568 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
|
|
569
|
|
570 bool follow(const SCEV *S) {
|
|
571 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
|
|
572 Loops.insert(AddRec->getLoop());
|
|
573 return true;
|
|
574 }
|
|
575 bool isDone() { return false; }
|
|
576 };
|
|
577
|
|
578 void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
|
|
579 SCEVFindLoops FindLoops(Loops);
|
|
580 SCEVTraversal<SCEVFindLoops> ST(FindLoops);
|
|
581 ST.visitAll(Expr);
|
|
582 }
|
|
583
|
|
584 /// Find all values referenced in SCEVUnknowns.
|
|
585 class SCEVFindValues {
|
|
586 ScalarEvolution &SE;
|
|
587 SetVector<Value *> &Values;
|
|
588
|
|
589 public:
|
|
590 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
|
|
591 : SE(SE), Values(Values) {}
|
|
592
|
|
593 bool follow(const SCEV *S) {
|
|
594 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
|
|
595 if (!Unknown)
|
|
596 return true;
|
|
597
|
|
598 Values.insert(Unknown->getValue());
|
|
599 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
|
|
600 if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
|
|
601 Inst->getOpcode() != Instruction::SDiv))
|
|
602 return false;
|
|
603
|
|
604 auto *Dividend = SE.getSCEV(Inst->getOperand(1));
|
|
605 if (!isa<SCEVConstant>(Dividend))
|
|
606 return false;
|
|
607
|
|
608 auto *Divisor = SE.getSCEV(Inst->getOperand(0));
|
|
609 SCEVFindValues FindValues(SE, Values);
|
|
610 SCEVTraversal<SCEVFindValues> ST(FindValues);
|
|
611 ST.visitAll(Dividend);
|
|
612 ST.visitAll(Divisor);
|
|
613
|
|
614 return false;
|
|
615 }
|
|
616 bool isDone() { return false; }
|
|
617 };
|
|
618
|
|
619 void findValues(const SCEV *Expr, ScalarEvolution &SE,
|
|
620 SetVector<Value *> &Values) {
|
|
621 SCEVFindValues FindValues(SE, Values);
|
|
622 SCEVTraversal<SCEVFindValues> ST(FindValues);
|
|
623 ST.visitAll(Expr);
|
|
624 }
|
|
625
|
|
626 bool hasIVParams(const SCEV *Expr) {
|
|
627 SCEVHasIVParams HasIVParams;
|
|
628 SCEVTraversal<SCEVHasIVParams> ST(HasIVParams);
|
|
629 ST.visitAll(Expr);
|
|
630 return HasIVParams.hasIVParams();
|
|
631 }
|
|
632
|
|
633 bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
|
|
634 llvm::Loop *Scope, bool AllowLoops,
|
|
635 const InvariantLoadsSetTy &ILS) {
|
|
636 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
|
|
637 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
|
|
638 ST.visitAll(Expr);
|
|
639 return InRegionDeps.hasDependences();
|
|
640 }
|
|
641
|
|
642 bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
|
|
643 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
|
|
644 if (isa<SCEVCouldNotCompute>(Expr))
|
|
645 return false;
|
|
646
|
|
647 SCEVValidator Validator(R, Scope, SE, ILS);
|
|
648 LLVM_DEBUG({
|
|
649 dbgs() << "\n";
|
|
650 dbgs() << "Expr: " << *Expr << "\n";
|
|
651 dbgs() << "Region: " << R->getNameStr() << "\n";
|
|
652 dbgs() << " -> ";
|
|
653 });
|
|
654
|
|
655 ValidatorResult Result = Validator.visit(Expr);
|
|
656
|
|
657 LLVM_DEBUG({
|
|
658 if (Result.isValid())
|
|
659 dbgs() << "VALID\n";
|
|
660 dbgs() << "\n";
|
|
661 });
|
|
662
|
|
663 return Result.isValid();
|
|
664 }
|
|
665
|
|
666 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
|
|
667 ScalarEvolution &SE, ParameterSetTy &Params) {
|
|
668 auto *E = SE.getSCEV(V);
|
|
669 if (isa<SCEVCouldNotCompute>(E))
|
|
670 return false;
|
|
671
|
|
672 SCEVValidator Validator(R, Scope, SE, nullptr);
|
|
673 ValidatorResult Result = Validator.visit(E);
|
|
674 if (!Result.isValid())
|
|
675 return false;
|
|
676
|
|
677 auto ResultParams = Result.getParameters();
|
|
678 Params.insert(ResultParams.begin(), ResultParams.end());
|
|
679
|
|
680 return true;
|
|
681 }
|
|
682
|
|
683 bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
|
|
684 ScalarEvolution &SE, ParameterSetTy &Params,
|
|
685 bool OrExpr) {
|
|
686 if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
|
|
687 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
|
|
688 true) &&
|
|
689 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
|
|
690 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
|
|
691 auto Opcode = BinOp->getOpcode();
|
|
692 if (Opcode == Instruction::And || Opcode == Instruction::Or)
|
|
693 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
|
|
694 false) &&
|
|
695 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
|
|
696 false);
|
|
697 /* Fall through */
|
|
698 }
|
|
699
|
|
700 if (!OrExpr)
|
|
701 return false;
|
|
702
|
|
703 return isAffineExpr(V, R, Scope, SE, Params);
|
|
704 }
|
|
705
|
|
706 ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
|
|
707 const SCEV *Expr, ScalarEvolution &SE) {
|
|
708 if (isa<SCEVCouldNotCompute>(Expr))
|
|
709 return ParameterSetTy();
|
|
710
|
|
711 InvariantLoadsSetTy ILS;
|
|
712 SCEVValidator Validator(R, Scope, SE, &ILS);
|
|
713 ValidatorResult Result = Validator.visit(Expr);
|
|
714 assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
|
|
715
|
|
716 return Result.getParameters();
|
|
717 }
|
|
718
|
|
719 std::pair<const SCEVConstant *, const SCEV *>
|
|
720 extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
|
|
721 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
|
|
722
|
|
723 if (auto *Constant = dyn_cast<SCEVConstant>(S))
|
|
724 return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
|
|
725
|
|
726 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
|
|
727 if (AddRec) {
|
|
728 auto *StartExpr = AddRec->getStart();
|
|
729 if (StartExpr->isZero()) {
|
|
730 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
|
|
731 auto *LeftOverAddRec =
|
|
732 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
|
|
733 AddRec->getNoWrapFlags());
|
|
734 return std::make_pair(StepPair.first, LeftOverAddRec);
|
|
735 }
|
|
736 return std::make_pair(ConstPart, S);
|
|
737 }
|
|
738
|
|
739 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
|
|
740 SmallVector<const SCEV *, 4> LeftOvers;
|
|
741 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
|
|
742 auto *Factor = Op0Pair.first;
|
|
743 if (SE.isKnownNegative(Factor)) {
|
|
744 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
|
|
745 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
|
|
746 } else {
|
|
747 LeftOvers.push_back(Op0Pair.second);
|
|
748 }
|
|
749
|
|
750 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
|
|
751 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
|
|
752 // TODO: Use something smarter than equality here, e.g., gcd.
|
|
753 if (Factor == OpUPair.first)
|
|
754 LeftOvers.push_back(OpUPair.second);
|
|
755 else if (Factor == SE.getNegativeSCEV(OpUPair.first))
|
|
756 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
|
|
757 else
|
|
758 return std::make_pair(ConstPart, S);
|
|
759 }
|
|
760
|
|
761 auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
|
|
762 return std::make_pair(Factor, NewAdd);
|
|
763 }
|
|
764
|
|
765 auto *Mul = dyn_cast<SCEVMulExpr>(S);
|
|
766 if (!Mul)
|
|
767 return std::make_pair(ConstPart, S);
|
|
768
|
|
769 SmallVector<const SCEV *, 4> LeftOvers;
|
|
770 for (auto *Op : Mul->operands())
|
|
771 if (isa<SCEVConstant>(Op))
|
|
772 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
|
|
773 else
|
|
774 LeftOvers.push_back(Op);
|
|
775
|
|
776 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
|
|
777 }
|
|
778
|
|
779 const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
|
|
780 ScalarEvolution &SE, LoopInfo &LI,
|
|
781 const DominatorTree &DT) {
|
|
782 if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
|
|
783 Value *V = Unknown->getValue();
|
|
784 auto *PHI = dyn_cast<PHINode>(V);
|
|
785 if (!PHI)
|
|
786 return Expr;
|
|
787
|
|
788 Value *Final = nullptr;
|
|
789
|
|
790 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
|
|
791 BasicBlock *Incoming = PHI->getIncomingBlock(i);
|
|
792 if (isErrorBlock(*Incoming, R, LI, DT) && R.contains(Incoming))
|
|
793 continue;
|
|
794 if (Final)
|
|
795 return Expr;
|
|
796 Final = PHI->getIncomingValue(i);
|
|
797 }
|
|
798
|
|
799 if (Final)
|
|
800 return SE.getSCEV(Final);
|
|
801 }
|
|
802 return Expr;
|
|
803 }
|
|
804
|
|
805 Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, LoopInfo &LI,
|
|
806 const DominatorTree &DT) {
|
|
807 Value *V = nullptr;
|
|
808 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
|
|
809 BasicBlock *BB = PHI->getIncomingBlock(i);
|
|
810 if (!isErrorBlock(*BB, *R, LI, DT)) {
|
|
811 if (V)
|
|
812 return nullptr;
|
|
813 V = PHI->getIncomingValue(i);
|
|
814 }
|
|
815 }
|
|
816
|
|
817 return V;
|
|
818 }
|
|
819 } // namespace polly
|