llvm.org GIT mirror llvm / 9517cbd
Allow LLE/LD and the loop versioning infrastructure to use SCEV predicates Summary: LAA currently generates a set of SCEV predicates that must be checked by users. In the case of Loop Distribute/Loop Load Elimination, no such predicates could have been emitted, since we don't allow stride versioning. However, in the future there could be SCEV predicates that will need to be checked. This change adds support for SCEV predicate versioning in the Loop Distribute, Loop Load Eliminate and the loop versioning infrastructure. Reviewers: anemet Subscribers: mssimpso, sanjoy, llvm-commits Differential Revision: http://reviews.llvm.org/D14240 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@252467 91177308-0d34-0410-b5e6-96231b3b80d8 Silviu Baranga 4 years ago
8 changed file(s) with 116 addition(s) and 44 deletion(s). Raw diff Collapse all Expand all
192192
193193 /// \brief Returns the estimated complexity of this predicate.
194194 /// This is roughly measured in the number of run-time checks required.
195 virtual unsigned getComplexity() { return 1; }
195 virtual unsigned getComplexity() const { return 1; }
196196
197197 /// \brief Returns true if the predicate is always true. This means that no
198198 /// assumptions were made and nothing needs to be checked at run-time.
302302
303303 /// \brief We estimate the complexity of a union predicate as the
304304 /// number of predicates in the union.
305 unsigned getComplexity() override { return Preds.size(); }
305 unsigned getComplexity() const override { return Preds.size(); }
306306
307307 /// Methods for support type inquiry through isa, cast, and dyn_cast:
308308 static inline bool classof(const SCEVPredicate *P) {
1616 #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H
1717
1818 #include "llvm/Analysis/LoopAccessAnalysis.h"
19 #include "llvm/Analysis/ScalarEvolution.h"
1920 #include "llvm/Transforms/Utils/ValueMapper.h"
2021 #include "llvm/Transforms/Utils/LoopUtils.h"
2122
2425 class Loop;
2526 class LoopAccessInfo;
2627 class LoopInfo;
28 class ScalarEvolution;
2729
2830 /// \brief This class emits a version of the loop where run-time checks ensure
2931 /// that may-alias pointers can't overlap.
3234 /// already has a preheader.
3335 class LoopVersioning {
3436 public:
35 /// \brief Expects MemCheck, LoopAccessInfo, Loop, LoopInfo, DominatorTree
36 /// as input. It uses runtime check provided by user.
37 LoopVersioning(SmallVector Checks,
38 const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
39 DominatorTree *DT);
40
4137 /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input.
42 /// It uses default runtime check provided by LoopAccessInfo.
43 LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI,
44 DominatorTree *DT);
38 /// The run-time checks come from LAI or from the user: if \p UseLAIChecks
39 /// is true, we retain the default checks made by LAI; otherwise we construct
40 /// an object having no checks and expect the user to add them.
41 LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
42 DominatorTree *DT, ScalarEvolution *SE,
43 bool UseLAIChecks = true);
4544
4645 /// \brief Performs the CFG manipulation part of versioning the loop including
4746 /// the DominatorTree and LoopInfo updates.
7170 /// loop may alias (i.e. one of the memchecks failed).
7271 Loop *getNonVersionedLoop() { return NonVersionedLoop; }
7372
73 /// \brief Sets the runtime alias checks for versioning the loop.
74 void setAliasChecks(
75 const SmallVector Checks);
76
77 /// \brief Sets the runtime SCEV checks for versioning the loop.
78 void setSCEVChecks(SCEVUnionPredicate Check);
79
7480 private:
7581 /// \brief Adds the necessary PHI nodes for the versioned loops based on the
7682 /// loop-defined values used outside of the loop.
9096 /// in NonVersionedLoop.
9197 ValueToValueMapTy VMap;
9298
93 /// \brief The set of checks that we are versioning for.
94 SmallVector Checks;
99 /// \brief The set of alias checks that we are versioning for.
100 SmallVector AliasChecks;
101
102 /// \brief The set of SCEV checks that we are versioning for.
103 SCEVUnionPredicate Preds;
95104
96105 /// \brief Analyses used.
97106 const LoopAccessInfo &LAI;
98107 LoopInfo *LI;
99108 DominatorTree *DT;
109 ScalarEvolution *SE;
100110 };
101111 }
102112
5454 "if-convertible by the loop vectorizer"),
5555 cl::init(false));
5656
57 static cl::opt DistributeSCEVCheckThreshold(
58 "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden,
59 cl::desc("The maximum number of SCEV checks allowed for Loop "
60 "Distribution"));
61
5762 STATISTIC(NumLoopsDistributed, "Number of loops distributed");
5863
5964 namespace {
576581 LI = &getAnalysis().getLoopInfo();
577582 LAA = &getAnalysis();
578583 DT = &getAnalysis().getDomTree();
584 SE = &getAnalysis().getSE();
579585
580586 // Build up a worklist of inner-loops to vectorize. This is necessary as the
581587 // act of distributing a loop creates new loops and can invalidate iterators
598604 }
599605
600606 void getAnalysisUsage(AnalysisUsage &AU) const override {
607 AU.addRequired();
601608 AU.addRequired();
602609 AU.addPreserved();
603610 AU.addRequired();
752759 return false;
753760 }
754761
762 // Don't distribute the loop if we need too many SCEV run-time checks.
763 const SCEVUnionPredicate &Pred = LAI.Preds;
764 if (Pred.getComplexity() > DistributeSCEVCheckThreshold) {
765 DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
766 return false;
767 }
768
755769 DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n");
756770 // We're done forming the partitions set up the reverse mapping from
757771 // instructions to partitions.
763777 if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator())
764778 SplitBlock(PH, PH->getTerminator(), DT, LI);
765779
766 // If we need run-time checks to disambiguate pointers are run-time, version
767 // the loop now.
780 // If we need run-time checks, version the loop now.
768781 auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI);
769782 const auto *RtPtrChecking = LAI.getRuntimePointerChecking();
770783 const auto &AllChecks = RtPtrChecking->getChecks();
771784 auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition,
772785 RtPtrChecking);
773 if (!Checks.empty()) {
786
787 if (!Pred.isAlwaysTrue() || !Checks.empty()) {
774788 DEBUG(dbgs() << "\nPointers:\n");
775789 DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks));
776 LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT);
790 LoopVersioning LVer(LAI, L, LI, DT, SE, false);
791 LVer.setAliasChecks(std::move(Checks));
792 LVer.setSCEVChecks(LAI.Preds);
777793 LVer.versionLoop(DefsUsedOutside);
778794 }
779795
800816 LoopInfo *LI;
801817 LoopAccessAnalysis *LAA;
802818 DominatorTree *DT;
819 ScalarEvolution *SE;
803820 };
804821 } // anonymous namespace
805822
810827 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
811828 INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
812829 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
830 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
813831 INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false)
814832
815833 namespace llvm {
3939 "runtime-check-per-loop-load-elim", cl::Hidden,
4040 cl::desc("Max number of memchecks allowed per eliminated load on average"),
4141 cl::init(1));
42
43 static cl::opt LoadElimSCEVCheckThreshold(
44 "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,
45 cl::desc("The maximum number of SCEV checks allowed for Loop "
46 "Load Elimination"));
47
4248
4349 STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");
4450
452458 return false;
453459 }
454460
461 if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) {
462 DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n");
463 return false;
464 }
465
455466 // Point of no-return, start the transformation. First, version the loop if
456467 // necessary.
457 if (!Checks.empty()) {
458 LoopVersioning LV(std::move(Checks), LAI, L, LI, DT);
468 if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) {
469 LoopVersioning LV(LAI, L, LI, DT, SE, false);
470 LV.setAliasChecks(std::move(Checks));
471 LV.setSCEVChecks(LAI.Preds);
459472 LV.versionLoop();
460473 }
461474
1616
1717 #include "llvm/Analysis/LoopAccessAnalysis.h"
1818 #include "llvm/Analysis/LoopInfo.h"
19 #include "llvm/Analysis/ScalarEvolutionExpander.h"
1920 #include "llvm/IR/Dominators.h"
2021 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
2122 #include "llvm/Transforms/Utils/Cloning.h"
2223
2324 using namespace llvm;
2425
25 LoopVersioning::LoopVersioning(
26 SmallVector Checks,
27 const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT)
28 : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)),
29 LAI(LAI), LI(LI), DT(DT) {
26 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI,
27 DominatorTree *DT, ScalarEvolution *SE,
28 bool UseLAIChecks)
29 : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT),
30 SE(SE) {
3031 assert(L->getExitBlock() && "No single exit block");
3132 assert(L->getLoopPreheader() && "No preheader");
33 if (UseLAIChecks) {
34 setAliasChecks(LAI.getRuntimePointerChecking()->getChecks());
35 setSCEVChecks(LAI.Preds);
36 }
3237 }
3338
34 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L,
35 LoopInfo *LI, DominatorTree *DT)
36 : VersionedLoop(L), NonVersionedLoop(nullptr),
37 Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo),
38 LI(LI), DT(DT) {
39 assert(L->getExitBlock() && "No single exit block");
40 assert(L->getLoopPreheader() && "No preheader");
39 void LoopVersioning::setAliasChecks(
40 const SmallVector Checks) {
41 AliasChecks = std::move(Checks);
42 }
43
44 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) {
45 Preds = std::move(Check);
4146 }
4247
4348 void LoopVersioning::versionLoop(
4449 const SmallVectorImpl &DefsUsedOutside) {
4550 Instruction *FirstCheckInst;
4651 Instruction *MemRuntimeCheck;
52 Value *SCEVRuntimeCheck;
53 Value *RuntimeCheck = nullptr;
54
4755 // Add the memcheck in the original preheader (this is empty initially).
48 BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader();
56 BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader();
4957 std::tie(FirstCheckInst, MemRuntimeCheck) =
50 LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks);
58 LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks);
5159 assert(MemRuntimeCheck && "called even though needsAnyChecking = false");
5260
61 const SCEVUnionPredicate &Pred = LAI.Preds;
62 SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(),
63 "scev.check");
64 SCEVRuntimeCheck =
65 Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator());
66 auto *CI = dyn_cast(SCEVRuntimeCheck);
67
68 // Discard the SCEV runtime check if it is always true.
69 if (CI && CI->isZero())
70 SCEVRuntimeCheck = nullptr;
71
72 if (MemRuntimeCheck && SCEVRuntimeCheck) {
73 RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck,
74 SCEVRuntimeCheck, "ldist.safe");
75 if (auto *I = dyn_cast(RuntimeCheck))
76 I->insertBefore(RuntimeCheckBB->getTerminator());
77 } else
78 RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck;
79
80 assert(RuntimeCheck && "called even though we don't need "
81 "any runtime checks");
82
5383 // Rename the block to make the IR more readable.
54 MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck");
84 RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() +
85 ".lver.check");
5586
5687 // Create empty preheader for the loop (and after cloning for the
5788 // non-versioned loop).
58 BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI);
89 BasicBlock *PH =
90 SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI);
5991 PH->setName(VersionedLoop->getHeader()->getName() + ".ph");
6092
6193 // Clone the loop including the preheader.
6496 // block is a join between the two loops.
6597 SmallVector NonVersionedLoopBlocks;
6698 NonVersionedLoop =
67 cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig",
68 LI, DT, NonVersionedLoopBlocks);
99 cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap,
100 ".lver.orig", LI, DT, NonVersionedLoopBlocks);
69101 remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap);
70102
71103 // Insert the conditional branch based on the result of the run-time checks.
72 Instruction *OrigTerm = MemCheckBB->getTerminator();
104 Instruction *OrigTerm = RuntimeCheckBB->getTerminator();
73105 BranchInst::Create(NonVersionedLoop->getLoopPreheader(),
74 VersionedLoop->getLoopPreheader(), MemRuntimeCheck,
75 OrigTerm);
106 VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm);
76107 OrigTerm->eraseFromParent();
77108
78109 // The loops merge in the original exit block. This is now dominated by the
79110 // run-time checking block.
80 DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB);
111 DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB);
81112
82113 // Adds the necessary PHI nodes for the versioned loops based on the
83114 // loop-defined values used outside of the loop.
3535 ; Since the checks to A and A + 4 get merged, this will give us a
3636 ; total of 8 compares.
3737 ;
38 ; CHECK: for.body.lver.memcheck:
38 ; CHECK: for.body.lver.check:
3939 ; CHECK: = icmp
4040 ; CHECK: = icmp
4141
1010
1111 define void @f(i32* %A, i32* %B, i32* %C, i64 %N) {
1212
13 ; CHECK: for.body.lver.memcheck:
13 ; CHECK: for.body.lver.check:
1414 ; CHECK: %found.conflict{{.*}} =
1515 ; CHECK-NOT: %found.conflict{{.*}} =
1616
1515 entry:
1616 br label %for.body
1717
18 ; AGGRESSIVE: for.body.lver.memcheck:
18 ; AGGRESSIVE: for.body.lver.check:
1919 ; AGGRESSIVE: %found.conflict{{.*}} =
2020 ; AGGRESSIVE: %found.conflict{{.*}} =
2121 ; AGGRESSIVE-NOT: %found.conflict{{.*}} =