llvm.org GIT mirror llvm / 0f7cbe1
CodeExtractor : Add ability to preserve profile data. Added ability to estimate the entry count of the extracted function and the branch probabilities of the exit branches. Patch by River Riddle! Differential Revision: https://reviews.llvm.org/D22744 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@277313 91177308-0d34-0410-b5e6-96231b3b80d8 Sean Silva 4 years ago
9 changed file(s) with 195 addition(s) and 27 deletion(s). Raw diff Collapse all Expand all
5959 /// This computes the relative block frequency of \p BB and multiplies it by
6060 /// the enclosing function's count (if available) and returns the value.
6161 Optional getBlockProfileCount(const BasicBlock *BB) const;
62
63 /// \brief Returns the estimated profile count of \p Freq.
64 /// This uses the frequency \p Freq and multiplies it by
65 /// the enclosing function's count (if available) and returns the value.
66 Optional getProfileCountFromFreq(uint64_t Freq) const;
6267
6368 // Set the frequency of the given basic block.
6469 void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
481481 BlockFrequency getBlockFreq(const BlockNode &Node) const;
482482 Optional getBlockProfileCount(const Function &F,
483483 const BlockNode &Node) const;
484 Optional getProfileCountFromFreq(const Function &F,
485 uint64_t Freq) const;
484486
485487 void setBlockFreq(const BlockNode &Node, uint64_t Freq);
486488
923925 Optional getBlockProfileCount(const Function &F,
924926 const BlockT *BB) const {
925927 return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
928 }
929 Optional getProfileCountFromFreq(const Function &F,
930 uint64_t Freq) const {
931 return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
926932 }
927933 void setBlockFreq(const BlockT *BB, uint64_t Freq);
928934 Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
5151 BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
5252
5353 Optional getBlockProfileCount(const MachineBasicBlock *MBB) const;
54 Optional getProfileCountFromFreq(uint64_t Freq) const;
5455
5556 const MachineFunction *getFunction() const;
5657 const MachineBranchProbabilityInfo *getMBPI() const;
1919 namespace llvm {
2020 template class ArrayRef;
2121 class BasicBlock;
22 class BlockFrequency;
23 class BlockFrequencyInfo;
24 class BranchProbabilityInfo;
2225 class DominatorTree;
2326 class Function;
2427 class Loop;
4649 // Various bits of state computed on construction.
4750 DominatorTree *const DT;
4851 const bool AggregateArgs;
52 BlockFrequencyInfo *BFI;
53 BranchProbabilityInfo *BPI;
4954
5055 // Bits of intermediate state computed at various phases of extraction.
5156 SetVector Blocks;
6368 ///
6469 /// In this formation, we don't require a dominator tree. The given basic
6570 /// block is set up for extraction.
66 CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
71 CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
72 BlockFrequencyInfo *BFI = nullptr,
73 BranchProbabilityInfo *BPI = nullptr);
6774
6875 /// \brief Create a code extractor for a sequence of blocks.
6976 ///
7279 /// sequence out into its new function. When a DominatorTree is also given,
7380 /// extra checking and transformations are enabled.
7481 CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr,
75 bool AggregateArgs = false);
82 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
83 BranchProbabilityInfo *BPI = nullptr);
7684
7785 /// \brief Create a code extractor for a loop body.
7886 ///
7987 /// Behaves just like the generic code sequence constructor, but uses the
8088 /// block sequence of the loop.
81 CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
89 CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
90 BlockFrequencyInfo *BFI = nullptr,
91 BranchProbabilityInfo *BPI = nullptr);
8292
8393 /// \brief Create a code extractor for a region node.
8494 ///
8595 /// Behaves just like the generic code sequence constructor, but uses the
8696 /// block sequence of the region node passed in.
8797 CodeExtractor(DominatorTree &DT, const RegionNode &RN,
88 bool AggregateArgs = false);
98 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
99 BranchProbabilityInfo *BPI = nullptr);
89100
90101 /// \brief Perform the extraction, returning the new function.
91102 ///
121132
122133 void moveCodeToFunction(Function *newFunction);
123134
135 void calculateNewCallTerminatorWeights(
136 BasicBlock *CodeReplacer,
137 DenseMap &ExitWeights,
138 BranchProbabilityInfo *BPI);
139
124140 void emitCallAndSwitchStatement(Function *newFunction,
125141 BasicBlock *newHeader,
126142 ValueSet &inputs,
161161 return BFI->getBlockProfileCount(*getFunction(), BB);
162162 }
163163
164 Optional
165 BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
166 if (!BFI)
167 return None;
168 return BFI->getProfileCountFromFreq(*getFunction(), Freq);
169 }
170
164171 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
165172 assert(BFI && "Expected analysis to be available");
166173 BFI->setBlockFreq(BB, Freq);
532532 Optional
533533 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
534534 const BlockNode &Node) const {
535 return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
536 }
537
538 Optional
539 BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
540 uint64_t Freq) const {
535541 auto EntryCount = F.getEntryCount();
536542 if (!EntryCount)
537543 return None;
538544 // Use 128 bit APInt to do the arithmetic to avoid overflow.
539545 APInt BlockCount(128, EntryCount.getValue());
540 APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
546 APInt BlockFreq(128, Freq);
541547 APInt EntryFreq(128, getEntryFreq());
542548 BlockCount *= BlockFreq;
543549 BlockCount = BlockCount.udiv(EntryFreq);
174174 return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
175175 }
176176
177 Optional
178 MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
179 const Function *F = MBFI->getFunction()->getFunction();
180 return MBFI ? MBFI->getProfileCountFromFreq(*F, Freq) : None;
181 }
182
177183 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
178184 return MBFI ? MBFI->getFunction() : nullptr;
179185 }
1313
1414 #include "llvm/Transforms/IPO/PartialInlining.h"
1515 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/BlockFrequencyInfo.h"
17 #include "llvm/Analysis/BranchProbabilityInfo.h"
1618 #include "llvm/IR/CFG.h"
1719 #include "llvm/IR/Dominators.h"
1820 #include "llvm/IR/Instructions.h"
2830 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
2931
3032 namespace {
33 typedef std::function(
34 Function &)>
35 GetProfileDataFn;
3136 struct PartialInlinerImpl {
32 PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {}
37 PartialInlinerImpl(InlineFunctionInfo IFI, GetProfileDataFn GetProfileInfo)
38 : IFI(IFI), GetProfileInfo(GetProfileInfo) {}
3339 bool run(Module &M);
3440 Function *unswitchFunction(Function *F);
3541
3642 private:
3743 InlineFunctionInfo IFI;
44 GetProfileDataFn GetProfileInfo;
3845 };
3946 struct PartialInlinerLegacyPass : public ModulePass {
4047 static char ID; // Pass identification, replacement for typeid
4451
4552 void getAnalysisUsage(AnalysisUsage &AU) const override {
4653 AU.addRequired();
54 AU.addRequired();
55 AU.addRequired();
4756 }
4857 bool runOnModule(Module &M) override {
4958 if (skipModule(M))
5463 [&ACT](Function &F) -> AssumptionCache & {
5564 return ACT->getAssumptionCache(F);
5665 };
66 GetProfileDataFn GetProfileData = [this](Function &F)
67 -> std::pair {
68 auto *BFI = &getAnalysis(F).getBFI();
69 auto *BPI = &getAnalysis(F).getBPI();
70 return std::make_pair(BFI, BPI);
71 };
5772 InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
58 return PartialInlinerImpl(IFI).run(M);
73 return PartialInlinerImpl(IFI, GetProfileData).run(M);
5974 }
6075 };
6176 }
132147 DominatorTree DT;
133148 DT.recalculate(*DuplicateFunction);
134149
150 auto ProfileInfo = GetProfileInfo(*DuplicateFunction);
151
135152 // Extract the body of the if.
136153 Function *ExtractedFunction =
137 CodeExtractor(ToExtract, &DT).extractCodeRegion();
154 CodeExtractor(ToExtract, &DT, /*AggregateArgs*/false, ProfileInfo.first,
155 ProfileInfo.second)
156 .extractCodeRegion();
138157
139158 // Inline the top-level if test into all callers.
140159 std::vector Users(DuplicateFunction->user_begin(),
180199 if (Recursive)
181200 continue;
182201
183 if (Function *newFunc = unswitchFunction(CurrFunc)) {
184 Worklist.push_back(newFunc);
202 if (Function *NewFunc = unswitchFunction(CurrFunc)) {
203 Worklist.push_back(NewFunc);
185204 Changed = true;
186205 }
187206 }
193212 INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
194213 "Partial Inliner", false, false)
195214 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
215 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
216 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
196217 INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
197218 "Partial Inliner", false, false)
198219
207228 [&FAM](Function &F) -> AssumptionCache & {
208229 return FAM.getResult(F);
209230 };
231 GetProfileDataFn GetProfileData = [&FAM](
232 Function &F) -> std::pair {
233 auto *BFI = &FAM.getResult(F);
234 auto *BPI = &FAM.getResult(F);
235 return std::make_pair(BFI, BPI);
236 };
210237 InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
211 if (PartialInlinerImpl(IFI).run(M))
238 if (PartialInlinerImpl(IFI, GetProfileData).run(M))
212239 return PreservedAnalyses::none();
213240 return PreservedAnalyses::all();
214241 }
1616 #include "llvm/ADT/STLExtras.h"
1717 #include "llvm/ADT/SetVector.h"
1818 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Analysis/BlockFrequencyInfo.h"
20 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
21 #include "llvm/Analysis/BranchProbabilityInfo.h"
1922 #include "llvm/Analysis/LoopInfo.h"
2023 #include "llvm/Analysis/RegionInfo.h"
2124 #include "llvm/Analysis/RegionIterator.h"
2528 #include "llvm/IR/Instructions.h"
2629 #include "llvm/IR/Intrinsics.h"
2730 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/MDBuilder.h"
2832 #include "llvm/IR/Module.h"
2933 #include "llvm/IR/Verifier.h"
3034 #include "llvm/Pass.h"
35 #include "llvm/Support/BlockFrequency.h"
3136 #include "llvm/Support/CommandLine.h"
3237 #include "llvm/Support/Debug.h"
3338 #include "llvm/Support/ErrorHandling.h"
118123 return buildExtractionBlockSet(R.block_begin(), R.block_end());
119124 }
120125
121 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
122 : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
123 Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
126 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
127 BlockFrequencyInfo *BFI,
128 BranchProbabilityInfo *BPI)
129 : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
130 BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
124131
125132 CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT,
126 bool AggregateArgs)
127 : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
128 Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
129
130 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
131 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
132 Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
133 bool AggregateArgs, BlockFrequencyInfo *BFI,
134 BranchProbabilityInfo *BPI)
135 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
136 BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
137
138 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
139 BlockFrequencyInfo *BFI,
140 BranchProbabilityInfo *BPI)
141 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
142 BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
143 NumExitBlocks(~0U) {}
133144
134145 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
135 bool AggregateArgs)
136 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
137 Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
146 bool AggregateArgs, BlockFrequencyInfo *BFI,
147 BranchProbabilityInfo *BPI)
148 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
149 BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
138150
139151 /// definedInRegion - Return true if the specified value is defined in the
140152 /// extracted region.
671683 }
672684 }
673685
686 void CodeExtractor::calculateNewCallTerminatorWeights(
687 BasicBlock *CodeReplacer,
688 DenseMap &ExitWeights,
689 BranchProbabilityInfo *BPI) {
690 typedef BlockFrequencyInfoImplBase::Distribution Distribution;
691 typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
692
693 // Update the branch weights for the exit block.
694 TerminatorInst *TI = CodeReplacer->getTerminator();
695 SmallVector BranchWeights(TI->getNumSuccessors(), 0);
696
697 // Block Frequency distribution with dummy node.
698 Distribution BranchDist;
699
700 // Add each of the frequencies of the successors.
701 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
702 BlockNode ExitNode(i);
703 uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
704 if (ExitFreq != 0)
705 BranchDist.addExit(ExitNode, ExitFreq);
706 else
707 BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
708 }
709
710 // Check for no total weight.
711 if (BranchDist.Total == 0)
712 return;
713
714 // Normalize the distribution so that they can fit in unsigned.
715 BranchDist.normalize();
716
717 // Create normalized branch weights and set the metadata.
718 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
719 const auto &Weight = BranchDist.Weights[I];
720
721 // Get the weight and update the current BFI.
722 BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
723 BranchProbability BP(Weight.Amount, BranchDist.Total);
724 BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
725 }
726 TI->setMetadata(
727 LLVMContext::MD_prof,
728 MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
729 }
730
674731 Function *CodeExtractor::extractCodeRegion() {
675732 if (!isEligible())
676733 return nullptr;
680737 // Assumption: this is a single-entry code region, and the header is the first
681738 // block in the region.
682739 BasicBlock *header = *Blocks.begin();
740
741 // Calculate the entry frequency of the new function before we change the root
742 // block.
743 BlockFrequency EntryFreq;
744 if (BFI) {
745 assert(BPI && "Both BPI and BFI are required to preserve profile info");
746 for (BasicBlock *Pred : predecessors(header)) {
747 if (Blocks.count(Pred))
748 continue;
749 EntryFreq +=
750 BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
751 }
752 }
683753
684754 // If we have to split PHI nodes or the entry block, do so now.
685755 severSplitPHINodes(header);
704774 // Find inputs to, outputs from the code region.
705775 findInputsOutputs(inputs, outputs);
706776
777 // Calculate the exit blocks for the extracted region and the total exit
778 // weights for each of those blocks.
779 DenseMap ExitWeights;
707780 SmallPtrSet ExitBlocks;
708 for (BasicBlock *Block : Blocks)
781 for (BasicBlock *Block : Blocks) {
709782 for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
710 ++SI)
711 if (!Blocks.count(*SI))
783 ++SI) {
784 if (!Blocks.count(*SI)) {
785 // Update the branch weight for this successor.
786 if (BFI) {
787 BlockFrequency &BF = ExitWeights[*SI];
788 BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
789 }
712790 ExitBlocks.insert(*SI);
791 }
792 }
793 }
713794 NumExitBlocks = ExitBlocks.size();
714795
715796 // Construct new function based on inputs/outputs & add allocas for all defs.
718799 codeReplacer, oldFunction,
719800 oldFunction->getParent());
720801
802 // Update the entry count of the function.
803 if (BFI) {
804 Optional EntryCount =
805 BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
806 if (EntryCount.hasValue())
807 newFunction->setEntryCount(EntryCount.getValue());
808 BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
809 }
810
721811 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
722812
723813 moveCodeToFunction(newFunction);
814
815 // Update the branch weights for the exit block.
816 if (BFI && NumExitBlocks > 1)
817 calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
724818
725819 // Loop over all of the PHI nodes in the header block, and change any
726820 // references to the old incoming edge to be the new incoming edge.