//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Instrumentation-based profile-guided optimization
//
//===----------------------------------------------------------------------===//

#include "CodeGenPGO.h"
#include "CodeGenFunction.h"
#include "CoverageMappingGen.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"

static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));

using namespace clang;
using namespace CodeGen;

void CodeGenPGO::setFuncName(StringRef Name,
                             llvm::GlobalValue::LinkageTypes Linkage) {
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  FuncName = llvm::getPGOFuncName(
      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);

  // If we're generating a profile, create a variable for the name.
  if (CGM.getCodeGenOpts().hasProfileClangInstr())
    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
}

void CodeGenPGO::setFuncName(llvm::Function *Fn) {
  setFuncName(Fn->getName(), Fn->getLinkage());
  // Create PGOFuncName metadata.
  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
}

/// The version of the PGO hash algorithm.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V2
};

namespace {
/// Stable hasher for PGO region counters.
///
/// PGOHash produces a stable hash of a given function's control flow.
///
/// Changing the output of this hash will invalidate all previously generated
/// profiles -- i.e., don't do it.
///
/// \note When this hash does eventually change (years?), we still need to
/// support old hashes. We'll need to pull in the version number from the
/// profile data format and use the matching hash function.
class PGOHash {
  uint64_t Working;
  unsigned Count;
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

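  // Each AST node kind is encoded in 6 bits of the 64-bit working word; once
  // ten kinds have been packed in, combine() flushes the word into the MD5
  // stream.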
  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable. All new members must be added at the end,
  /// and no members should be removed. Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available with PGO_HASH_V2.

    // Keep this last. It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  void combine(HashType Type);
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;

/// Get the PGO hash version used in the given indexed profile.
static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
                                        CodeGenModule &CGM) {
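  // Indexed profile format versions 4 and earlier predate PGO_HASH_V2, so any
  // function record they contain was hashed with the V1 algorithm.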
  if (PGOReader->getVersion() <= 4)
    return PGO_HASH_V1;
  return PGO_HASH_V2;
}

/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;

  MapRegionCounters(PGOHashVersion HashVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

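  // Assign a counter to the body of each function-like declaration we visit.
  // The declaration passed to TraverseDecl is visited before anything nested
  // in its body, so its body is assigned counter zero, which tracks entry to
  // the function.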
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

  // If the statement type \p N is nestable, and its nesting impacts profile
  // stability, define a custom traversal which tracks the end of the statement
  // in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      if (HashVersion == PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    if (HashVersion == PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};

/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
} // end anonymous namespace

void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up.
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}

uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working)
    MD5.update(Working);

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  using namespace llvm::support;
  return Result.low();
}

void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in
  // IR. If so, instrument only the base variant; the other variants are
  // implemented by delegation to the base one and would otherwise be counted
  // twice.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  CGM.ClearUnusedCoverageMapping(D);
  setFuncName(Fn);

  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}

void CodeGenPGO::mapRegionCounters(const Decl *D) {
  // Use the latest hash version when inserting instrumentation, but use the
  // version in the indexed profile if we're reading PGO data.
  PGOHashVersion HashVersion = PGO_HASH_LATEST;
  if (auto *PGOReader = CGM.getPGOReader())
    HashVersion = getPGOHashVersion(PGOReader, CGM);

  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
  MapRegionCounters Walker(HashVersion, *RegionCounterMap);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
  NumRegionCounters = Walker.NextCounter;
  FunctionHash = Walker.Hash.finalize();
}

bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
  if (!D->getBody())
    return true;

  // Don't map the functions in system headers.
  const auto &SM = CGM.getContext().getSourceManager();
  auto Loc = D->getBody()->getBeginLoc();
  return SM.isInSystemHeader(Loc);
}

void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts(), RegionCounterMap.get());
  MappingGen.emitCounterMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
}

void
CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
                                    llvm::GlobalValue::LinkageTypes Linkage) {
  if (skipRegionMappingForDecl(D))
    return;

  std::string CoverageMapping;
  llvm::raw_string_ostream OS(CoverageMapping);
  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
                                CGM.getContext().getSourceManager(),
                                CGM.getLangOpts());
  MappingGen.emitEmptyMapping(D, OS);
  OS.flush();

  if (CoverageMapping.empty())
    return;

  setFuncName(Name, Linkage);
  CGM.getCoverageMapping()->addFunctionMappingRecord(
      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
}

void CodeGenPGO::computeRegionCounts(const Decl *D) {
  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
  ComputeRegionCounts Walker(*StmtCountMap, *this);
  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
    Walker.VisitFunctionDecl(FD);
  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
    Walker.VisitObjCMethodDecl(MD);
  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
    Walker.VisitBlockDecl(BD);
  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
}

void
CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
                                    llvm::Function *Fn) {
  if (!haveRegionCounts())
    return;

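  // Region counter zero corresponds to the function body, so its count is the
  // number of times the function was entered.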
  uint64_t FunctionCount = getRegionCount(nullptr);
  Fn->setEntryCount(FunctionCount);
}

void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
                                      llvm::Value *StepV) {
  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
    return;
  if (!Builder.GetInsertBlock())
    return;

  unsigned Counter = (*RegionCounterMap)[S];
  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());

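  // Arguments to llvm.instrprof.increment: the function name variable, the
  // function hash, the total number of counters, and the index of the counter
  // to bump. The optional fifth argument is only used by the *_step variant.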
  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
                         Builder.getInt64(FunctionHash),
                         Builder.getInt32(NumRegionCounters),
                         Builder.getInt32(Counter), StepV};
  if (!StepV)
    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
                       makeArrayRef(Args, 4));
  else
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
        makeArrayRef(Args));
}

// This method either inserts a call to the profile run-time during
// instrumentation or puts profile data into metadata for PGO use.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
                              llvm::Instruction *ValueSite,
                              llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
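    // Arguments to llvm.instrprof.value.profile: the function name variable,
    // the function hash, the profiled value widened to i64, the value kind,
    // and the index of this value site within the function.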
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // We record the three most frequently called functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}

void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}

/// Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  return MaxWeight < UINT32_MAX ? 1 : MaxWeight / UINT32_MAX + 1;
}

/// Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// greater than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return Scaled;
}

llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
                                                    uint64_t FalseCount) {
  // Check for empty weights.
  if (!TrueCount && !FalseCount)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
                                      scaleBranchWeight(FalseCount, Scale));
}

llvm::MDNode *
CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
  // We need at least two elements to create meaningful weights.
  if (Weights.size() < 2)
    return nullptr;

  // Check for empty weights.
  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
  if (MaxWeight == 0)
    return nullptr;

  // Calculate how to scale down to 32-bits.
  uint64_t Scale = calculateWeightScale(MaxWeight);

  SmallVector<uint32_t, 16> ScaledWeights;
  ScaledWeights.reserve(Weights.size());
  for (uint64_t W : Weights)
    ScaledWeights.push_back(scaleBranchWeight(W, Scale));

  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
  return MDHelper.createBranchWeights(ScaledWeights);
}

llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
                                                           uint64_t LoopCount) {
  if (!PGO.haveRegionCounts())
    return nullptr;
  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
  assert(CondCount.hasValue() && "missing expected loop condition count");
  if (*CondCount == 0)
    return nullptr;
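  // Weight the backedge (LoopCount) against the loop exit; the exit count is
  // the number of times the condition was evaluated minus the number of times
  // the body was entered.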
  return createProfileWeights(LoopCount,
                              std::max(*CondCount, LoopCount) - LoopCount);
}