llvm.org GIT mirror llvm / e9e0746
CodeExtractor : Add ability to preserve profile data. Added ability to estimate the entry count of the extracted function and the branch probabilities of the exit branches. Patch by River Riddle! Differential Revision: https://reviews.llvm.org/D22744 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@277411 91177308-0d34-0410-b5e6-96231b3b80d8 Sean Silva 4 years ago
11 changed file(s) with 241 addition(s) and 24 deletion(s). Raw diff Collapse all Expand all
5959 /// This computes the relative block frequency of \p BB and multiplies it by
6060 /// the enclosing function's count (if available) and returns the value.
6161 Optional getBlockProfileCount(const BasicBlock *BB) const;
62
63 /// \brief Returns the estimated profile count of \p Freq.
64 /// This uses the frequency \p Freq and multiplies it by
65 /// the enclosing function's count (if available) and returns the value.
66 Optional getProfileCountFromFreq(uint64_t Freq) const;
6267
6368 // Set the frequency of the given basic block.
6469 void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
481481 BlockFrequency getBlockFreq(const BlockNode &Node) const;
482482 Optional getBlockProfileCount(const Function &F,
483483 const BlockNode &Node) const;
484 Optional getProfileCountFromFreq(const Function &F,
485 uint64_t Freq) const;
484486
485487 void setBlockFreq(const BlockNode &Node, uint64_t Freq);
486488
923925 Optional getBlockProfileCount(const Function &F,
924926 const BlockT *BB) const {
925927 return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
928 }
929 Optional getProfileCountFromFreq(const Function &F,
930 uint64_t Freq) const {
931 return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
926932 }
927933 void setBlockFreq(const BlockT *BB, uint64_t Freq);
928934 Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
5151 BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
5252
5353 Optional getBlockProfileCount(const MachineBasicBlock *MBB) const;
54 Optional getProfileCountFromFreq(uint64_t Freq) const;
5455
5556 const MachineFunction *getFunction() const;
5657 const MachineBranchProbabilityInfo *getMBPI() const;
1919 namespace llvm {
2020 template class ArrayRef;
2121 class BasicBlock;
22 class BlockFrequency;
23 class BlockFrequencyInfo;
24 class BranchProbabilityInfo;
2225 class DominatorTree;
2326 class Function;
2427 class Loop;
4649 // Various bits of state computed on construction.
4750 DominatorTree *const DT;
4851 const bool AggregateArgs;
52 BlockFrequencyInfo *BFI;
53 BranchProbabilityInfo *BPI;
4954
5055 // Bits of intermediate state computed at various phases of extraction.
5156 SetVector Blocks;
6368 ///
6469 /// In this formation, we don't require a dominator tree. The given basic
6570 /// block is set up for extraction.
66 CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
71 CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
72 BlockFrequencyInfo *BFI = nullptr,
73 BranchProbabilityInfo *BPI = nullptr);
6774
6875 /// \brief Create a code extractor for a sequence of blocks.
6976 ///
7279 /// sequence out into its new function. When a DominatorTree is also given,
7380 /// extra checking and transformations are enabled.
7481 CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr,
75 bool AggregateArgs = false);
82 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
83 BranchProbabilityInfo *BPI = nullptr);
7684
7785 /// \brief Create a code extractor for a loop body.
7886 ///
7987 /// Behaves just like the generic code sequence constructor, but uses the
8088 /// block sequence of the loop.
81 CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
89 CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
90 BlockFrequencyInfo *BFI = nullptr,
91 BranchProbabilityInfo *BPI = nullptr);
8292
8393 /// \brief Create a code extractor for a region node.
8494 ///
8595 /// Behaves just like the generic code sequence constructor, but uses the
8696 /// block sequence of the region node passed in.
8797 CodeExtractor(DominatorTree &DT, const RegionNode &RN,
88 bool AggregateArgs = false);
98 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
99 BranchProbabilityInfo *BPI = nullptr);
89100
90101 /// \brief Perform the extraction, returning the new function.
91102 ///
121132
122133 void moveCodeToFunction(Function *newFunction);
123134
135 void calculateNewCallTerminatorWeights(
136 BasicBlock *CodeReplacer,
137 DenseMap &ExitWeights,
138 BranchProbabilityInfo *BPI);
139
124140 void emitCallAndSwitchStatement(Function *newFunction,
125141 BasicBlock *newHeader,
126142 ValueSet &inputs,
161161 return BFI->getBlockProfileCount(*getFunction(), BB);
162162 }
163163
164 Optional
165 BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
166 if (!BFI)
167 return None;
168 return BFI->getProfileCountFromFreq(*getFunction(), Freq);
169 }
170
164171 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
165172 assert(BFI && "Expected analysis to be available");
166173 BFI->setBlockFreq(BB, Freq);
532532 Optional
533533 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
534534 const BlockNode &Node) const {
535 return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
536 }
537
538 Optional
539 BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
540 uint64_t Freq) const {
535541 auto EntryCount = F.getEntryCount();
536542 if (!EntryCount)
537543 return None;
538544 // Use 128 bit APInt to do the arithmetic to avoid overflow.
539545 APInt BlockCount(128, EntryCount.getValue());
540 APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
546 APInt BlockFreq(128, Freq);
541547 APInt EntryFreq(128, getEntryFreq());
542548 BlockCount *= BlockFreq;
543549 BlockCount = BlockCount.udiv(EntryFreq);
174174 return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
175175 }
176176
177 Optional
178 MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
179 const Function *F = MBFI->getFunction()->getFunction();
180 return MBFI ? MBFI->getProfileCountFromFreq(*F, Freq) : None;
181 }
182
177183 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
178184 return MBFI ? MBFI->getFunction() : nullptr;
179185 }
1313
1414 #include "llvm/Transforms/IPO/PartialInlining.h"
1515 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/BlockFrequencyInfo.h"
17 #include "llvm/Analysis/BranchProbabilityInfo.h"
18 #include "llvm/Analysis/LoopInfo.h"
1619 #include "llvm/IR/CFG.h"
1720 #include "llvm/IR/Dominators.h"
1821 #include "llvm/IR/Instructions.h"
132135 DominatorTree DT;
133136 DT.recalculate(*DuplicateFunction);
134137
138 // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo.
139 LoopInfo LI(DT);
140 BranchProbabilityInfo BPI(*DuplicateFunction, LI);
141 BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI);
142
135143 // Extract the body of the if.
136144 Function *ExtractedFunction =
137 CodeExtractor(ToExtract, &DT).extractCodeRegion();
145 CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI)
146 .extractCodeRegion();
138147
139148 // Inline the top-level if test into all callers.
140149 std::vector Users(DuplicateFunction->user_begin(),
180189 if (Recursive)
181190 continue;
182191
183 if (Function *newFunc = unswitchFunction(CurrFunc)) {
184 Worklist.push_back(newFunc);
192 if (Function *NewFunc = unswitchFunction(CurrFunc)) {
193 Worklist.push_back(NewFunc);
185194 Changed = true;
186195 }
187196 }
1616 #include "llvm/ADT/STLExtras.h"
1717 #include "llvm/ADT/SetVector.h"
1818 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
21 #include "llvm/Analysis/BranchProbabilityInfo.h"
1922 #include "llvm/Analysis/LoopInfo.h"
2023 #include "llvm/Analysis/RegionInfo.h"
2124 #include "llvm/Analysis/RegionIterator.h"
2528 #include "llvm/IR/Instructions.h"
2629 #include "llvm/IR/Intrinsics.h"
2730 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/MDBuilder.h"
2832 #include "llvm/IR/Module.h"
2933 #include "llvm/IR/Verifier.h"
3034 #include "llvm/Pass.h"
35 #include "llvm/Support/BlockFrequency.h"
3136 #include "llvm/Support/CommandLine.h"
3237 #include "llvm/Support/Debug.h"
3338 #include "llvm/Support/ErrorHandling.h"
118123 return buildExtractionBlockSet(R.block_begin(), R.block_end());
119124 }
120125
121 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
122 : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
123 Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
126 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
127 BlockFrequencyInfo *BFI,
128 BranchProbabilityInfo *BPI)
129 : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
130 BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
124131
125132 CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT,
126 bool AggregateArgs)
127 : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
128 Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
129
130 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
131 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
132 Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
133 bool AggregateArgs, BlockFrequencyInfo *BFI,
134 BranchProbabilityInfo *BPI)
135 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
136 BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
137
138 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
139 BlockFrequencyInfo *BFI,
140 BranchProbabilityInfo *BPI)
141 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
142 BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
143 NumExitBlocks(~0U) {}
133144
134145 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
135 bool AggregateArgs)
136 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
137 Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
146 bool AggregateArgs, BlockFrequencyInfo *BFI,
147 BranchProbabilityInfo *BPI)
148 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
149 BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
138150
139151 /// definedInRegion - Return true if the specified value is defined in the
140152 /// extracted region.
686698 }
687699 }
688700
701 void CodeExtractor::calculateNewCallTerminatorWeights(
702 BasicBlock *CodeReplacer,
703 DenseMap &ExitWeights,
704 BranchProbabilityInfo *BPI) {
705 typedef BlockFrequencyInfoImplBase::Distribution Distribution;
706 typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
707
708 // Update the branch weights for the exit block.
709 TerminatorInst *TI = CodeReplacer->getTerminator();
710 SmallVector BranchWeights(TI->getNumSuccessors(), 0);
711
712 // Block Frequency distribution with dummy node.
713 Distribution BranchDist;
714
715 // Add each of the frequencies of the successors.
716 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
717 BlockNode ExitNode(i);
718 uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
719 if (ExitFreq != 0)
720 BranchDist.addExit(ExitNode, ExitFreq);
721 else
722 BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
723 }
724
725 // Check for no total weight.
726 if (BranchDist.Total == 0)
727 return;
728
729 // Normalize the distribution so that they can fit in unsigned.
730 BranchDist.normalize();
731
732 // Create normalized branch weights and set the metadata.
733 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
734 const auto &Weight = BranchDist.Weights[I];
735
736 // Get the weight and update the current BFI.
737 BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
738 BranchProbability BP(Weight.Amount, BranchDist.Total);
739 BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
740 }
741 TI->setMetadata(
742 LLVMContext::MD_prof,
743 MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
744 }
745
689746 Function *CodeExtractor::extractCodeRegion() {
690747 if (!isEligible())
691748 return nullptr;
695752 // Assumption: this is a single-entry code region, and the header is the first
696753 // block in the region.
697754 BasicBlock *header = *Blocks.begin();
755
756 // Calculate the entry frequency of the new function before we change the root
757 // block.
758 BlockFrequency EntryFreq;
759 if (BFI) {
760 assert(BPI && "Both BPI and BFI are required to preserve profile info");
761 for (BasicBlock *Pred : predecessors(header)) {
762 if (Blocks.count(Pred))
763 continue;
764 EntryFreq +=
765 BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
766 }
767 }
698768
699769 // If we have to split PHI nodes or the entry block, do so now.
700770 severSplitPHINodes(header);
719789 // Find inputs to, outputs from the code region.
720790 findInputsOutputs(inputs, outputs);
721791
792 // Calculate the exit blocks for the extracted region and the total exit
793 // weights for each of those blocks.
794 DenseMap ExitWeights;
722795 SmallPtrSet ExitBlocks;
723 for (BasicBlock *Block : Blocks)
796 for (BasicBlock *Block : Blocks) {
724797 for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
725 ++SI)
726 if (!Blocks.count(*SI))
798 ++SI) {
799 if (!Blocks.count(*SI)) {
800 // Update the branch weight for this successor.
801 if (BFI) {
802 BlockFrequency &BF = ExitWeights[*SI];
803 BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
804 }
727805 ExitBlocks.insert(*SI);
806 }
807 }
808 }
728809 NumExitBlocks = ExitBlocks.size();
729810
730811 // Construct new function based on inputs/outputs & add allocas for all defs.
733814 codeReplacer, oldFunction,
734815 oldFunction->getParent());
735816
817 // Update the entry count of the function.
818 if (BFI) {
819 Optional EntryCount =
820 BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
821 if (EntryCount.hasValue())
822 newFunction->setEntryCount(EntryCount.getValue());
823 BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
824 }
825
736826 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
737827
738828 moveCodeToFunction(newFunction);
829
830 // Update the branch weights for the exit block.
831 if (BFI && NumExitBlocks > 1)
832 calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
739833
740834 // Loop over all of the PHI nodes in the header block, and change any
741835 // references to the old incoming edge to be the new incoming edge.
0 ; RUN: opt < %s -partial-inliner -S | FileCheck %s
1
2 ; This test checks to make sure that the CodeExtractor
3 ; properly sets the entry count for the function that is
4 ; extracted based on the root block being extracted and also
5 ; takes into consideration if the block has edges coming from
6 ; a block that is also being extracted.
7
8 define i32 @inlinedFunc(i1 %cond) !prof !1 {
9 entry:
10 br i1 %cond, label %if.then, label %return, !prof !2
11 if.then:
12 br i1 %cond, label %if.then, label %return, !prof !3
13 return: ; preds = %entry
14 ret i32 0
15 }
16
17
18 define internal i32 @dummyCaller(i1 %cond) !prof !1 {
19 entry:
20 %val = call i32 @inlinedFunc(i1 %cond)
21 ret i32 %val
22 }
23
24 ; CHECK: @inlinedFunc.1_if.then(i1 %cond) !prof [[COUNT1:![0-9]+]]
25
26
27 !llvm.module.flags = !{!0}
28 ; CHECK: [[COUNT1]] = !{!"function_entry_count", i64 250}
29 !0 = !{i32 1, !"MaxFunctionCount", i32 1000}
30 !1 = !{!"function_entry_count", i64 1000}
31 !2 = !{!"branch_weights", i32 250, i32 750}
32 !3 = !{!"branch_weights", i32 125, i32 125}
0 ; RUN: opt < %s -partial-inliner -S | FileCheck %s
1
2 ; This test checks to make sure that CodeExtractor updates
3 ; the exit branch probabilities for multiple exit blocks.
4
5 define i32 @inlinedFunc(i1 %cond) !prof !1 {
6 entry:
7 br i1 %cond, label %if.then, label %return, !prof !2
8 if.then:
9 br i1 %cond, label %return, label %return.2, !prof !3
10 return.2:
11 ret i32 10
12 return: ; preds = %entry
13 ret i32 0
14 }
15
16
17 define internal i32 @dummyCaller(i1 %cond) !prof !1 {
18 entry:
19 %val = call i32 @inlinedFunc(i1 %cond)
20 ret i32 %val
21
22 ; CHECK-LABEL: @dummyCaller
23 ; CHECK: call
24 ; CHECK-NEXT: br i1 {{.*}}!prof [[COUNT1:![0-9]+]]
25 }
26
27 !llvm.module.flags = !{!0}
28 !0 = !{i32 1, !"MaxFunctionCount", i32 10000}
29 !1 = !{!"function_entry_count", i64 10000}
30 !2 = !{!"branch_weights", i32 5, i32 5}
31 !3 = !{!"branch_weights", i32 4, i32 1}
32
33 ; CHECK: [[COUNT1]] = !{!"branch_weights", i32 8, i32 31}