llvm.org GIT mirror llvm / 9a22bfd
[CodeExtractor] Update function's assumption cache after extracting blocks from it Summary: Assumption cache's self-updating mechanism does not correctly handle the case when blocks are extracted from the function by the CodeExtractor. As a result function's assumption cache may have stale references to the llvm.assume calls that were moved to the outlined function. This patch fixes this problem by removing extracted llvm.assume calls from the function’s assumption cache. Reviewers: hfinkel, vsk, fhahn, davidxl, sanjoy Reviewed By: hfinkel, vsk Subscribers: llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D57215 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@353500 91177308-0d34-0410-b5e6-96231b3b80d8 Sergey Dmitriev 7 months ago
8 changed file(s) with 145 addition(s) and 34 deletion(s). Raw diff Collapse all Expand all
101101 /// The call passed in must be an instruction within this function and must
102102 /// not already be in the cache.
103103 void registerAssumption(CallInst *CI);
104
105 /// Remove an \@llvm.assume intrinsic from this function's cache if it has
106 /// been added to the cache earlier.
107 void unregisterAssumption(CallInst *CI);
104108
105109 /// Update the cache of values being affected by this assumption (i.e.
106110 /// the values about which this assumption provides information).
207211 /// existing cache will be returned.
208212 AssumptionCache &getAssumptionCache(Function &F);
209213
214 /// Return the cached assumptions for a function if it has already been
215 /// scanned. Otherwise return nullptr.
216 AssumptionCache *lookupAssumptionCache(Function &F);
217
210218 AssumptionCacheTracker();
211219 ~AssumptionCacheTracker() override;
212220
2525 class BlockFrequency;
2626 class BlockFrequencyInfo;
2727 class BranchProbabilityInfo;
28 class AssumptionCache;
2829 class CallInst;
2930 class DominatorTree;
3031 class Function;
5556 const bool AggregateArgs;
5657 BlockFrequencyInfo *BFI;
5758 BranchProbabilityInfo *BPI;
59 AssumptionCache *AC;
5860
5961 // If true, varargs functions can be extracted.
6062 bool AllowVarArgs;
8385 CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr,
8486 bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
8587 BranchProbabilityInfo *BPI = nullptr,
88 AssumptionCache *AC = nullptr,
8689 bool AllowVarArgs = false, bool AllowAlloca = false,
8790 std::string Suffix = "");
8891
9396 CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
9497 BlockFrequencyInfo *BFI = nullptr,
9598 BranchProbabilityInfo *BPI = nullptr,
99 AssumptionCache *AC = nullptr,
96100 std::string Suffix = "");
97101
98102 /// Perform the extraction, returning the new function.
5252 return AVIP.first->second;
5353 }
5454
55 void AssumptionCache::updateAffectedValues(CallInst *CI) {
55 static void findAffectedValues(CallInst *CI,
56 SmallVectorImpl &Affected) {
5657 // Note: This code must be kept in-sync with the code in
5758 // computeKnownBitsFromAssume in ValueTracking.
5859
59 SmallVector Affected;
6060 auto AddAffected = [&Affected](Value *V) {
6161 if (isa(V)) {
6262 Affected.push_back(V);
107107 AddAffectedFromEq(B);
108108 }
109109 }
110 }
111
112 void AssumptionCache::updateAffectedValues(CallInst *CI) {
113 SmallVector Affected;
114 findAffectedValues(CI, Affected);
110115
111116 for (auto &AV : Affected) {
112117 auto &AVV = getOrInsertAffectedValues(AV);
113118 if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
114119 AVV.push_back(CI);
115120 }
121 }
122
123 void AssumptionCache::unregisterAssumption(CallInst *CI) {
124 SmallVector Affected;
125 findAffectedValues(CI, Affected);
126
127 for (auto &AV : Affected) {
128 auto AVI = AffectedValues.find_as(AV);
129 if (AVI != AffectedValues.end())
130 AffectedValues.erase(AVI);
131 }
132 remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; });
116133 }
117134
118135 void AssumptionCache::AffectedValueCallbackVH::deleted() {
239256 return *IP.first->second;
240257 }
241258
259 AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) {
260 auto I = AssumptionCaches.find_as(&F);
261 if (I != AssumptionCaches.end())
262 return I->second.get();
263 return nullptr;
264 }
265
242266 void AssumptionCacheTracker::verifyAnalysis() const {
243267 // FIXME: In the long term the verifier should not be controllable with a
244268 // flag. We should either fix all passes to correctly update the assumption
172172 HotColdSplitting(ProfileSummaryInfo *ProfSI,
173173 function_ref GBFI,
174174 function_ref GTTI,
175 std::function *GORE)
176 : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {}
175 std::function *GORE,
176 function_ref LAC)
177 : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {}
177178 bool run(Module &M);
178179
179180 private:
182183 bool outlineColdRegions(Function &F, bool HasProfileSummary);
183184 Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT,
184185 BlockFrequencyInfo *BFI, TargetTransformInfo &TTI,
185 OptimizationRemarkEmitter &ORE, unsigned Count);
186 OptimizationRemarkEmitter &ORE,
187 AssumptionCache *AC, unsigned Count);
186188 ProfileSummaryInfo *PSI;
187189 function_ref GetBFI;
188190 function_ref GetTTI;
189191 std::function *GetORE;
192 function_ref LookupAC;
190193 };
191194
192195 class HotColdSplittingLegacyPass : public ModulePass {
197200 }
198201
199202 void getAnalysisUsage(AnalysisUsage &AU) const override {
200 AU.addRequired();
201203 AU.addRequired();
202204 AU.addRequired();
203205 AU.addRequired();
206 AU.addUsedIfAvailable();
204207 }
205208
206209 bool runOnModule(Module &M) override;
315318 BlockFrequencyInfo *BFI,
316319 TargetTransformInfo &TTI,
317320 OptimizationRemarkEmitter &ORE,
321 AssumptionCache *AC,
318322 unsigned Count) {
319323 assert(!Region.empty());
320324
321325 // TODO: Pass BFI and BPI to update profile information.
322326 CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr,
323 /* BPI */ nullptr, /* AllowVarArgs */ false,
327 /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
324328 /* AllowAlloca */ false,
325329 /* Suffix */ "cold." + std::to_string(Count));
326330
576580
577581 TargetTransformInfo &TTI = GetTTI(F);
578582 OptimizationRemarkEmitter &ORE = (*GetORE)(F);
583 AssumptionCache *AC = LookupAC(F);
579584
580585 // Find all cold regions.
581586 for (BasicBlock *BB : RPOT) {
637642 BB->dump();
638643 });
639644
640 Function *Outlined =
641 extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, OutlinedFunctionID);
645 Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC,
646 OutlinedFunctionID);
642647 if (Outlined) {
643648 ++OutlinedFunctionID;
644649 Changed = true;
697702 ORE.reset(new OptimizationRemarkEmitter(&F));
698703 return *ORE.get();
699704 };
700
701 return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M);
705 auto LookupAC = [this](Function &F) -> AssumptionCache * {
706 if (auto *ACT = getAnalysisIfAvailable())
707 return ACT->lookupAssumptionCache(F);
708 return nullptr;
709 };
710
711 return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M);
702712 }
703713
704714 PreservedAnalyses
705715 HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {
706716 auto &FAM = AM.getResult(M).getManager();
707717
708 std::function GetAssumptionCache =
709 [&FAM](Function &F) -> AssumptionCache & {
710 return FAM.getResult(F);
718 auto LookupAC = [&FAM](Function &F) -> AssumptionCache * {
719 return FAM.getCachedResult(F);
711720 };
712721
713722 auto GBFI = [&FAM](Function &F) {
728737
729738 ProfileSummaryInfo *PSI = &AM.getResult(M);
730739
731 if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M))
740 if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M))
732741 return PreservedAnalyses::none();
733742 return PreservedAnalyses::all();
734743 }
1313 //===----------------------------------------------------------------------===//
1414
1515 #include "llvm/ADT/Statistic.h"
16 #include "llvm/Analysis/AssumptionCache.h"
1617 #include "llvm/Analysis/LoopPass.h"
1718 #include "llvm/IR/Dominators.h"
1819 #include "llvm/IR/Instructions.h"
4950 AU.addRequiredID(LoopSimplifyID);
5051 AU.addRequired();
5152 AU.addRequired();
53 AU.addUsedIfAvailable();
5254 }
5355 };
5456 }
137139 if (ShouldExtractLoop) {
138140 if (NumLoops == 0) return Changed;
139141 --NumLoops;
140 CodeExtractor Extractor(DT, *L);
142 AssumptionCache *AC = nullptr;
143 if (auto *ACT = getAnalysisIfAvailable())
144 AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent());
145 CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
141146 if (Extractor.extractCodeRegion() != nullptr) {
142147 Changed = true;
143148 // After extraction, the loop is replaced by a function call, so
198198
199199 PartialInlinerImpl(
200200 std::function *GetAC,
201 function_ref LookupAC,
201202 std::function *GTTI,
202203 Optional> GBFI,
203204 ProfileSummaryInfo *ProfSI)
204 : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
205 : GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC),
206 GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
205207
206208 bool run(Module &M);
207209 // Main part of the transformation that calls helper functions to find
221223 // Two constructors, one for single region outlining, the other for
222224 // multi-region outlining.
223225 FunctionCloner(Function *F, FunctionOutliningInfo *OI,
224 OptimizationRemarkEmitter &ORE);
226 OptimizationRemarkEmitter &ORE,
227 function_ref LookupAC);
225228 FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI,
226 OptimizationRemarkEmitter &ORE);
229 OptimizationRemarkEmitter &ORE,
230 function_ref LookupAC);
227231 ~FunctionCloner();
228232
229233 // Prepare for function outlining: making sure there is only
259263 std::unique_ptr ClonedOMRI = nullptr;
260264 std::unique_ptr ClonedFuncBFI = nullptr;
261265 OptimizationRemarkEmitter &ORE;
266 function_ref LookupAC;
262267 };
263268
264269 private:
265270 int NumPartialInlining = 0;
266271 std::function *GetAssumptionCache;
272 function_ref LookupAssumptionCache;
267273 std::function *GetTTI;
268274 Optional> GetBFI;
269275 ProfileSummaryInfo *PSI;
364370 return ACT->getAssumptionCache(F);
365371 };
366372
373 auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * {
374 return ACT->lookupAssumptionCache(F);
375 };
376
367377 std::function GetTTI =
368378 [&TTIWP](Function &F) -> TargetTransformInfo & {
369379 return TTIWP->getTTI(F);
370380 };
371381
372 return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI)
382 return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache,
383 &GetTTI, NoneType::None, PSI)
373384 .run(M);
374385 }
375386 };
947958 }
948959
949960 PartialInlinerImpl::FunctionCloner::FunctionCloner(
950 Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE)
951 : OrigFunc(F), ORE(ORE) {
961 Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE,
962 function_ref LookupAC)
963 : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
952964 ClonedOI = llvm::make_unique();
953965
954966 // Clone the function, so that we can hack away on it.
971983
972984 PartialInlinerImpl::FunctionCloner::FunctionCloner(
973985 Function *F, FunctionOutliningMultiRegionInfo *OI,
974 OptimizationRemarkEmitter &ORE)
975 : OrigFunc(F), ORE(ORE) {
986 OptimizationRemarkEmitter &ORE,
987 function_ref LookupAC)
988 : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
976989 ClonedOMRI = llvm::make_unique();
977990
978991 // Clone the function, so that we can hack away on it.
11101123 int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region);
11111124
11121125 CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false,
1113 ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false);
1126 ClonedFuncBFI.get(), &BPI,
1127 LookupAC(*RegionInfo.EntryBlock->getParent()),
1128 /* AllowVarargs */ false);
11141129
11151130 CE.findInputsOutputs(Inputs, Outputs, Sinks);
11161131
11921207 // Extract the body of the if.
11931208 Function *OutlinedFunc =
11941209 CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
1195 ClonedFuncBFI.get(), &BPI,
1210 ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
11961211 /* AllowVarargs */ true)
11971212 .extractCodeRegion();
11981213
12561271 std::unique_ptr OMRI =
12571272 computeOutliningColdRegionsInfo(F, ORE);
12581273 if (OMRI) {
1259 FunctionCloner Cloner(F, OMRI.get(), ORE);
1274 FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache);
12601275
12611276 #ifndef NDEBUG
12621277 if (TracePartialInlining) {
12891304 if (!OI)
12901305 return {false, nullptr};
12911306
1292 FunctionCloner Cloner(F, OI.get(), ORE);
1307 FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache);
12931308 Cloner.NormalizeReturnBlock();
12941309
12951310 Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining();
14831498 return FAM.getResult(F);
14841499 };
14851500
1501 auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
1502 return FAM.getCachedResult(F);
1503 };
1504
14861505 std::function GetBFI =
14871506 [&FAM](Function &F) -> BlockFrequencyInfo & {
14881507 return FAM.getResult(F);
14951514
14961515 ProfileSummaryInfo *PSI = &AM.getResult(M);
14971516
1498 if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI)
1517 if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI,
1518 {GetBFI}, PSI)
14991519 .run(M))
15001520 return PreservedAnalyses::none();
15011521 return PreservedAnalyses::all();
1919 #include "llvm/ADT/SetVector.h"
2020 #include "llvm/ADT/SmallPtrSet.h"
2121 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Analysis/AssumptionCache.h"
2223 #include "llvm/Analysis/BlockFrequencyInfo.h"
2324 #include "llvm/Analysis/BlockFrequencyInfoImpl.h"
2425 #include "llvm/Analysis/BranchProbabilityInfo.h"
4243 #include "llvm/IR/LLVMContext.h"
4344 #include "llvm/IR/MDBuilder.h"
4445 #include "llvm/IR/Module.h"
46 #include "llvm/IR/PatternMatch.h"
4547 #include "llvm/IR/Type.h"
4648 #include "llvm/IR/User.h"
4749 #include "llvm/IR/Value.h"
6567 #include
6668
6769 using namespace llvm;
70 using namespace llvm::PatternMatch;
6871 using ProfileCount = Function::ProfileCount;
6972
7073 #define DEBUG_TYPE "code-extractor"
234237
235238 CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT,
236239 bool AggregateArgs, BlockFrequencyInfo *BFI,
237 BranchProbabilityInfo *BPI, bool AllowVarArgs,
238 bool AllowAlloca, std::string Suffix)
240 BranchProbabilityInfo *BPI, AssumptionCache *AC,
241 bool AllowVarArgs, bool AllowAlloca,
242 std::string Suffix)
239243 : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
240 BPI(BPI), AllowVarArgs(AllowVarArgs),
244 BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs),
241245 Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
242246 Suffix(Suffix) {}
243247
244248 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
245249 BlockFrequencyInfo *BFI,
246 BranchProbabilityInfo *BPI, std::string Suffix)
250 BranchProbabilityInfo *BPI, AssumptionCache *AC,
251 std::string Suffix)
247252 : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
248 BPI(BPI), AllowVarArgs(false),
253 BPI(BPI), AC(AC), AllowVarArgs(false),
249254 Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
250255 /* AllowVarArgs */ false,
251256 /* AllowAlloca */ false)),
12151220
12161221 // Insert this basic block into the new function
12171222 newBlocks.push_back(Block);
1223
1224 // Remove @llvm.assume calls that were moved to the new function from the
1225 // old function's assumption cache.
1226 if (AC)
1227 for (auto &I : *Block)
1228 if (match(&I, m_Intrinsic()))
1229 AC->unregisterAssumption(cast(&I));
12181230 }
12191231 }
12201232
0 ; RUN: opt -passes="function(slp-vectorizer),module(hotcoldsplit),function(slp-vectorizer,print)" -disable-output %s 2>&1 | FileCheck %s
1 ;
2 ; Make sure this compiles. Check that function assumption cache is refreshed
3 ; after extracting blocks with assume calls from the function.
4
5 ; CHECK: Cached assumptions for function: fun
6 ; CHECK-NEXT: Cached assumptions for function: fun.cold
7 ; CHECK-NEXT: %cmp = icmp uge i32 %x, 64
8
9 declare void @fun2(i32) #0
10
11 define void @fun(i32 %x) {
12 entry:
13 br i1 undef, label %if.then, label %if.else
14
15 if.then:
16 ret void
17
18 if.else:
19 %cmp = icmp uge i32 %x, 64
20 call void @llvm.assume(i1 %cmp)
21 call void @fun2(i32 %x)
22 unreachable
23 }
24
25 declare void @llvm.assume(i1) #1
26
27 attributes #0 = { alwaysinline }
28 attributes #1 = { nounwind }