llvm.org GIT mirror llvm / a99158b
Update the branch weight metadata in JumpThreading pass. Currently in JumpThreading pass, the branch weight metadata is not updated after CFG modification. Consider the jump threading on PredBB, BB, and SuccBB. After jump threading, the weight on BB->SuccBB should be adjusted as some of it is contributed by the edge PredBB->BB, which doesn't exist anymore. This patch tries to update the edge weight in metadata on BB->SuccBB by scaling it by 1 - Freq(PredBB->BB) / Freq(BB->SuccBB). This is the third attempt to submit this patch, while the first two led to failures in some FDO tests. After investigation, it is the edge weight normalization that caused those failures. In this patch the edge weight normalization is fixed so that there is no zero weight in the output and the sum of all weights can fit in 32-bit integer. Several unit tests are added. Differential revision: http://reviews.llvm.org/D10979 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@250345 91177308-0d34-0410-b5e6-96231b3b80d8 Cong Hou 4 years ago
8 changed file(s) with 286 addition(s) and 6 deletion(s). Raw diff Collapse all Expand all
4444 /// floating points.
4545 BlockFrequency getBlockFreq(const BasicBlock *BB) const;
4646
47 // Set the frequency of the given basic block.
48 void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
49
4750 /// calculate - compute block frequency info for the given function.
4851 void calculate(const Function &F, const BranchProbabilityInfo &BPI,
4952 const LoopInfo &LI);
475475 Scaled64 getFloatingBlockFreq(const BlockNode &Node) const;
476476
477477 BlockFrequency getBlockFreq(const BlockNode &Node) const;
478
479 void setBlockFreq(const BlockNode &Node, uint64_t Freq);
478480
479481 raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const;
480482 raw_ostream &printBlockFreq(raw_ostream &OS,
912914 BlockFrequency getBlockFreq(const BlockT *BB) const {
913915 return BlockFrequencyInfoImplBase::getBlockFreq(getNode(BB));
914916 }
917 void setBlockFreq(const BlockT *BB, uint64_t Freq);
915918 Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
916919 return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
917920 }
962965 computeMassInFunction();
963966 unwrapLoops();
964967 finalizeMetrics();
968 }
969
970 template
971 void BlockFrequencyInfoImpl::setBlockFreq(const BlockT *BB, uint64_t Freq) {
972 if (Nodes.count(BB))
973 BlockFrequencyInfoImplBase::setBlockFreq(getNode(BB), Freq);
974 else {
975 // If BB is a newly added block after BFI is done, we need to create a new
976 // BlockNode for it assigned with a new index. The index can be determined
977 // by the size of Freqs.
978 BlockNode NewNode(Freqs.size());
979 Nodes[BB] = NewNode;
980 Freqs.emplace_back();
981 BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq);
982 }
965983 }
966984
967985 template void BlockFrequencyInfoImpl::initializeRPOT() {
1414 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
1515
1616 #include "llvm/Support/DataTypes.h"
17 #include
1718 #include
19 #include
20 #include
1821
1922 namespace llvm {
2023
5154 // one.
5255 template
5356 static void normalizeProbabilities(ProbabilityList &Probs);
57
58 // Normalize a list of weights by scaling them down so that the sum of them
59 // doesn't exceed UINT32_MAX.
60 template
61 static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
5462
5563 uint32_t getNumerator() const { return N; }
5664 static uint32_t getDenominator() { return D; }
134142 Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum;
135143 }
136144
145 template
146 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
147 WeightListIter End) {
148 // First we compute the sum with 64-bits of precision.
149 uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
150
151 if (Sum > UINT32_MAX) {
152 // Compute the scale necessary to cause the weights to fit, and re-sum with
153 // that scale applied.
154 assert(Sum / UINT32_MAX < UINT32_MAX &&
155 "The sum of weights exceeds UINT32_MAX^2!");
156 uint32_t Scale = Sum / UINT32_MAX + 1;
157 for (auto I = Begin; I != End; ++I)
158 *I /= Scale;
159 Sum = std::accumulate(Begin, End, uint64_t(0));
160 }
161
162 // Eliminate zero weights.
163 auto ZeroWeightNum = std::count(Begin, End, 0u);
164 if (ZeroWeightNum > 0) {
165 // If all weights are zeros, replace them by 1.
166 if (Sum == 0)
167 std::fill(Begin, End, 1u);
168 else {
169 // We are converting zeros into ones, and here we need to make sure that
170 // after this the sum won't exceed UINT32_MAX.
171 if (Sum + ZeroWeightNum > UINT32_MAX) {
172 for (auto I = Begin; I != End; ++I)
173 *I /= 2;
174 ZeroWeightNum = std::count(Begin, End, 0u);
175 Sum = std::accumulate(Begin, End, uint64_t(0));
176 }
177 // Scale up non-zero weights and turn zero weights into ones.
178 uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
179 assert(ScalingFactor >= 1);
180 if (ScalingFactor > 1)
181 for (auto I = Begin; I != End; ++I)
182 *I *= ScalingFactor;
183 std::replace(Begin, End, 0u, 1u);
184 }
185 }
186 }
187
137188 }
138189
139190 #endif
128128 return BFI ? BFI->getBlockFreq(BB) : 0;
129129 }
130130
131 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB,
132 uint64_t Freq) {
133 assert(BFI && "Expected analysis to be available");
134 BFI->setBlockFreq(BB, Freq);
135 }
136
131137 /// Pop up a ghostview window with the current block frequency propagation
132138 /// rendered using dot.
133139 void BlockFrequencyInfo::view() const {
529529 return Freqs[Node.Index].Scaled;
530530 }
531531
532 void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
533 uint64_t Freq) {
534 assert(Node.isValid() && "Expected valid node");
535 assert(Node.Index < Freqs.size() && "Expected legal index");
536 Freqs[Node.Index].Integer = Freq;
537 }
538
532539 std::string
533540 BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const {
534541 return std::string();
1919 #include "llvm/ADT/Statistic.h"
2020 #include "llvm/Analysis/GlobalsModRef.h"
2121 #include "llvm/Analysis/CFG.h"
22 #include "llvm/Analysis/BlockFrequencyInfo.h"
23 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
24 #include "llvm/Analysis/BranchProbabilityInfo.h"
2225 #include "llvm/Analysis/ConstantFolding.h"
2326 #include "llvm/Analysis/InstructionSimplify.h"
2427 #include "llvm/Analysis/LazyValueInfo.h"
2528 #include "llvm/Analysis/Loads.h"
29 #include "llvm/Analysis/LoopInfo.h"
2630 #include "llvm/Analysis/TargetLibraryInfo.h"
2731 #include "llvm/IR/DataLayout.h"
2832 #include "llvm/IR/IntrinsicInst.h"
2933 #include "llvm/IR/LLVMContext.h"
34 #include "llvm/IR/MDBuilder.h"
3035 #include "llvm/IR/Metadata.h"
3136 #include "llvm/IR/ValueHandle.h"
3237 #include "llvm/Pass.h"
3641 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
3742 #include "llvm/Transforms/Utils/Local.h"
3843 #include "llvm/Transforms/Utils/SSAUpdater.h"
44 #include
45 #include
3946 using namespace llvm;
4047
4148 #define DEBUG_TYPE "jump-threading"
8087 class JumpThreading : public FunctionPass {
8188 TargetLibraryInfo *TLI;
8289 LazyValueInfo *LVI;
90 std::unique_ptr BFI;
91 std::unique_ptr BPI;
92 bool HasProfileData;
8393 #ifdef NDEBUG
8494 SmallPtrSet LoopHeaders;
8595 #else
118128 AU.addRequired();
119129 }
120130
131 void releaseMemory() override {
132 BFI.reset();
133 BPI.reset();
134 }
135
121136 void FindLoopHeaders(Function &F);
122137 bool ProcessBlock(BasicBlock *BB);
123138 bool ThreadEdge(BasicBlock *BB, const SmallVectorImpl &PredBBs,
138153
139154 bool SimplifyPartiallyRedundantLoad(LoadInst *LI);
140155 bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB);
156
157 private:
158 BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef Preds,
159 const char *Suffix);
160 void UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BasicBlock *BB,
161 BasicBlock *NewBB, BasicBlock *SuccBB);
141162 };
142163 }
143164
161182 DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");
162183 TLI = &getAnalysis().getTLI();
163184 LVI = &getAnalysis();
185 BFI.reset();
186 BPI.reset();
187 // When profile data is available, we need to update edge weights after
188 // successful jump threading, which requires both BPI and BFI being available.
189 HasProfileData = F.getEntryCount().hasValue();
190 if (HasProfileData) {
191 LoopInfo LI{DominatorTree(F)};
192 BPI.reset(new BranchProbabilityInfo(F, LI));
193 BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
194 }
164195
165196 // Remove unreachable blocks from function as they may result in infinite
166197 // loop. We do threading if we found something profitable. Jump threading a
9761007 }
9771008
9781009 // Split them out to their own block.
979 UnavailablePred =
980 SplitBlockPredecessors(LoadBB, PredsToSplit, "thread-pre-split");
1010 UnavailablePred = SplitBlockPreds(LoadBB, PredsToSplit, "thread-pre-split");
9811011 }
9821012
9831013 // If the value isn't available in all predecessors, then there will be
14021432 else {
14031433 DEBUG(dbgs() << " Factoring out " << PredBBs.size()
14041434 << " common predecessors.\n");
1405 PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm");
1435 PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm");
14061436 }
14071437
14081438 // And finally, do it!
14221452 BB->getName()+".thread",
14231453 BB->getParent(), BB);
14241454 NewBB->moveAfter(PredBB);
1455
1456 // Set the block frequency of NewBB.
1457 if (HasProfileData) {
1458 auto NewBBFreq =
1459 BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB);
1460 BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
1461 }
14251462
14261463 BasicBlock::iterator BI = BB->begin();
14271464 for (; PHINode *PN = dyn_cast(BI); ++BI)
14461483
14471484 // We didn't copy the terminator from BB over to NewBB, because there is now
14481485 // an unconditional jump to SuccBB. Insert the unconditional jump.
1449 BranchInst *NewBI =BranchInst::Create(SuccBB, NewBB);
1486 BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB);
14501487 NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc());
14511488
14521489 // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the
15071544 // frequently happens because of phi translation.
15081545 SimplifyInstructionsInBlock(NewBB, TLI);
15091546
1547 // Update the edge weight from BB to SuccBB, which should be less than before.
1548 UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB);
1549
15101550 // Threaded an edge!
15111551 ++NumThreads;
15121552 return true;
1553 }
1554
1555 /// Create a new basic block that will be the predecessor of BB and successor of
1556 /// all blocks in Preds. When profile data is availble, update the frequency of
1557 /// this new block.
1558 BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB,
1559 ArrayRef Preds,
1560 const char *Suffix) {
1561 // Collect the frequencies of all predecessors of BB, which will be used to
1562 // update the edge weight on BB->SuccBB.
1563 BlockFrequency PredBBFreq(0);
1564 if (HasProfileData)
1565 for (auto Pred : Preds)
1566 PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB);
1567
1568 BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix);
1569
1570 // Set the block frequency of the newly created PredBB, which is the sum of
1571 // frequencies of Preds.
1572 if (HasProfileData)
1573 BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency());
1574 return PredBB;
1575 }
1576
1577 /// Update the block frequency of BB and branch weight and the metadata on the
1578 /// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 -
1579 /// Freq(PredBB->BB) / Freq(BB->SuccBB).
1580 void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
1581 BasicBlock *BB,
1582 BasicBlock *NewBB,
1583 BasicBlock *SuccBB) {
1584 if (!HasProfileData)
1585 return;
1586
1587 assert(BFI && BPI && "BFI & BPI should have been created here");
1588
1589 // As the edge from PredBB to BB is deleted, we have to update the block
1590 // frequency of BB.
1591 auto BBOrigFreq = BFI->getBlockFreq(BB);
1592 auto NewBBFreq = BFI->getBlockFreq(NewBB);
1593 auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB);
1594 auto BBNewFreq = BBOrigFreq - NewBBFreq;
1595 BFI->setBlockFreq(BB, BBNewFreq.getFrequency());
1596
1597 // Collect updated outgoing edges' frequencies from BB and use them to update
1598 // edge weights.
1599 SmallVector BBSuccFreq;
1600 for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
1601 auto SuccFreq = (*I == SuccBB)
1602 ? BB2SuccBBFreq - NewBBFreq
1603 : BBOrigFreq * BPI->getEdgeProbability(BB, *I);
1604 BBSuccFreq.push_back(SuccFreq.getFrequency());
1605 }
1606
1607 // Normalize edge weights in Weights64 so that the sum of them can fit in
1608 BranchProbability::normalizeEdgeWeights(BBSuccFreq.begin(), BBSuccFreq.end());
1609
1610 SmallVector Weights;
1611 for (auto Freq : BBSuccFreq)
1612 Weights.push_back(static_cast(Freq));
1613
1614 // Update edge weights in BPI.
1615 for (int I = 0, E = Weights.size(); I < E; I++)
1616 BPI->setEdgeWeight(BB, I, Weights[I]);
1617
1618 if (Weights.size() >= 2) {
1619 auto TI = BB->getTerminator();
1620 TI->setMetadata(
1621 LLVMContext::MD_prof,
1622 MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
1623 }
15131624 }
15141625
15151626 /// DuplicateCondBranchOnPHIIntoPred - PredBB contains an unconditional branch
15451656 else {
15461657 DEBUG(dbgs() << " Factoring out " << PredBBs.size()
15471658 << " common predecessors.\n");
1548 PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm");
1659 PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm");
15491660 }
15501661
15511662 // Okay, we decided to do this! Clone all the instructions in BB onto the end
0 ; RUN: opt -S -jump-threading %s | FileCheck %s
1
2 ; Test if edge weights are properly updated after jump threading.
3
4 ; CHECK: !2 = !{!"branch_weights", i32 22, i32 7}
5
6 define void @foo(i32 %n) !prof !0 {
7 entry:
8 %cmp = icmp sgt i32 %n, 10
9 br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1
10
11 if.then.1:
12 tail call void @a()
13 br label %if.cond
14
15 if.else.1:
16 tail call void @b()
17 br label %if.cond
18
19 if.cond:
20 %cmp1 = icmp sgt i32 %n, 5
21 br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2
22
23 if.then.2:
24 tail call void @c()
25 br label %if.end
26
27 if.else.2:
28 tail call void @d()
29 br label %if.end
30
31 if.end:
32 ret void
33 }
34
35 declare void @a()
36 declare void @b()
37 declare void @c()
38 declare void @d()
39
40 !0 = !{!"function_entry_count", i64 1}
41 !1 = !{!"branch_weights", i32 10, i32 5}
42 !2 = !{!"branch_weights", i32 10, i32 1}
286286 }
287287 }
288288
289 }
289 TEST(BranchProbabilityTest, NormalizeEdgeWeights) {
290 {
291 SmallVector Weights{0, 0};
292 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
293 EXPECT_EQ(1u, Weights[0]);
294 EXPECT_EQ(1u, Weights[1]);
295 }
296 {
297 SmallVector Weights{0, UINT32_MAX};
298 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
299 EXPECT_EQ(1u, Weights[0]);
300 EXPECT_EQ(UINT32_MAX - 1u, Weights[1]);
301 }
302 {
303 SmallVector Weights{1, UINT32_MAX};
304 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
305 EXPECT_EQ(1u, Weights[0]);
306 EXPECT_EQ(UINT32_MAX - 1u, Weights[1]);
307 }
308 {
309 SmallVector Weights{0, 0, UINT32_MAX};
310 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
311 EXPECT_EQ(1u, Weights[0]);
312 EXPECT_EQ(1u, Weights[1]);
313 EXPECT_EQ(UINT32_MAX / 2u, Weights[2]);
314 }
315 {
316 SmallVector Weights{UINT32_MAX, UINT32_MAX};
317 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
318 EXPECT_EQ(UINT32_MAX / 3u, Weights[0]);
319 EXPECT_EQ(UINT32_MAX / 3u, Weights[1]);
320 }
321 {
322 SmallVector Weights{UINT32_MAX, UINT32_MAX, UINT32_MAX};
323 BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
324 EXPECT_EQ(UINT32_MAX / 4u, Weights[0]);
325 EXPECT_EQ(UINT32_MAX / 4u, Weights[1]);
326 EXPECT_EQ(UINT32_MAX / 4u, Weights[2]);
327 }
328 }
329
330 }