llvm.org GIT mirror llvm / 6fcf601
[Loop Peeling] Fix the handling of branch weights of peeled off branches.

The current algorithm for updating the branch weights of the latch block and its copies is based on the assumption that the number of peeled iterations is approximately equal to the trip count. However, this is not correct. According to the profitability check, in one case we can decide to peel because it helps to reduce the number of phi nodes. In this case the number of peeled iterations can be less than the estimated trip count.

This patch introduces another way to set the branch weights of peeled-off branches.
Let F be the weight of the edge from latch to header.
Let E be the weight of the edge from latch to exit.
F/(F+E) is the probability of going to the loop and E/(F+E) is the probability of going to the exit.
Then, Estimated TripCount = F / E.
For the I-th (counting from 0) peeled-off iteration we set the weights for the peeled latch to (TC - I, 1). This gives a reasonable distribution: the probability of going to the exit, 1/(TC-I), increases. At the same time the estimated trip count of the remaining loop is reduced by I.

As a result, after peeling off N iterations the weights will be (F - N * E, E) and the trip count of the loop becomes F / E - N, or TC - N.

The idea is taken from the review of the patch D63918 proposed by Philip.

Reviewers: reames, mkuper, iajbar, fhahn
Reviewed By: reames
Subscribers: hiraditya, zzheng, llvm-commits
Differential Revision: https://reviews.llvm.org/D64235

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@366665 91177308-0d34-0410-b5e6-96231b3b80d8

Serguei Katkov 30 days ago
3 changed file(s) with 57 addition(s) and 70 deletion(s).
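To make the update rule from the commit message concrete, here is a minimal standalone sketch of the arithmetic only. `simulatePeelWeights` is a hypothetical helper, not part of LLVM; it assumes the fall-through weight F and exit weight E have already been read from the latch's profile metadata, and it ignores everything else the pass does (cloning, metadata attachment).

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper, not part of LLVM. Returns the (fall-through, exit)
// weight pair given to each peeled latch copy, followed by the pair that is
// left on the latch of the remaining loop.
std::vector<std::pair<uint64_t, uint64_t>>
simulatePeelWeights(uint64_t FallThroughWeight, uint64_t ExitWeight,
                    unsigned PeelCount) {
  std::vector<std::pair<uint64_t, uint64_t>> Weights;
  // A zero fall-through weight stands for "no profile data or zero estimated
  // trip count"; in that case no weights are assigned at all.
  if (!FallThroughWeight)
    return Weights;

  for (unsigned Iter = 0; Iter < PeelCount; ++Iter) {
    // The I-th peeled latch gets (F - I * E, E).
    Weights.emplace_back(FallThroughWeight, ExitWeight);
    // Each peeled iteration removes E from the fall-through weight; clamp to
    // 1 so the weight never becomes zero or wraps around.
    FallThroughWeight = FallThroughWeight > ExitWeight
                            ? FallThroughWeight - ExitWeight
                            : 1;
  }

  // Whatever remains is placed on the latch of the original loop.
  Weights.emplace_back(FallThroughWeight, ExitWeight);
  return Weights;
}
```

Calling `simulatePeelWeights(3001, 1001, 3)` yields (3001, 1001), (2000, 1001) and (999, 1001) for the three peeled latches and (1, 1001) for the remaining loop, which are the values the updated CHECK lines in the tests expect.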
@@ -363,88 +363,77 @@
 /// iteration.
 /// This sets the branch weights for the latch of the recently peeled off loop
 /// iteration correctly.
-/// Our goal is to make sure that:
-/// a) The total weight of all the copies of the loop body is preserved.
-/// b) The total weight of the loop exit is preserved.
-/// c) The body weight is reasonably distributed between the peeled iterations.
+/// Let F be the weight of the edge from latch to header.
+/// Let E be the weight of the edge from latch to exit.
+/// F/(F+E) is the probability of going to the loop and E/(F+E) is the
+/// probability of going to the exit.
+/// Then, Estimated TripCount = F / E.
+/// For the I-th (counting from 0) peeled-off iteration we set the weights for
+/// the peeled latch to (TC - I, 1). This gives a reasonable distribution:
+/// the probability of going to the exit, 1/(TC-I), increases. At the same time
+/// the estimated trip count of the remaining loop is reduced by I.
+/// To avoid dealing with division rounding we can just multiply both parts
+/// of the weights by E and use (F - I * E, E) as the weights.
 ///
 /// \param Header The copy of the header block that belongs to next iteration.
 /// \param LatchBR The copy of the latch branch that belongs to this iteration.
-/// \param IterNumber The serial number of the iteration that was just
-/// peeled off.
-/// \param AvgIters The average number of iterations we expect the loop to have.
-/// \param[in,out] PeeledHeaderWeight The total number of dynamic loop
-/// iterations that are unaccounted for. As an input, it represents the number
-/// of times we expect to enter the header of the iteration currently being
-/// peeled off. The output is the number of times we expect to enter the
-/// header of the next iteration.
+/// \param[in,out] FallThroughWeight The weight of the edge from latch to
+/// header before peeling (in) and after peeling off one iteration (out).
 static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
-                                unsigned IterNumber, unsigned AvgIters,
-                                uint64_t &PeeledHeaderWeight) {
-  if (!PeeledHeaderWeight)
+                                uint64_t ExitWeight,
+                                uint64_t &FallThroughWeight) {
+  // A FallThroughWeight of 0 means that there are no branch weights on the
+  // original latch block or that the estimated trip count is zero.
+  if (!FallThroughWeight)
     return;
-  // FIXME: Pick a more realistic distribution.
-  // Currently the proportion of weight we assign to the fall-through
-  // side of the branch drops linearly with the iteration number, and we use
-  // a 0.9 fudge factor to make the drop-off less sharp...
-  uint64_t FallThruWeight =
-      PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9);
-  uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight;
-  PeeledHeaderWeight -= ExitWeight;
 
   unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1);
   MDBuilder MDB(LatchBR->getContext());
   MDNode *WeightNode =
-      HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight)
-                : MDB.createBranchWeights(FallThruWeight, ExitWeight);
+      HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+                : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
   LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
+  FallThroughWeight =
+      FallThroughWeight > ExitWeight ? FallThroughWeight - ExitWeight : 1;
 }
 
 /// Initialize the weights.
 ///
 /// \param Header The header block.
 /// \param LatchBR The latch branch.
-/// \param AvgIters The average number of iterations we expect the loop to have.
-/// \param[out] ExitWeight The # of times the edge from Latch to Exit is taken.
-/// \param[out] CurHeaderWeight The # of times the header is executed.
+/// \param[out] ExitWeight The weight of the edge from Latch to Exit.
+/// \param[out] FallThroughWeight The weight of the edge from Latch to Header.
 static void initBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
-                              unsigned AvgIters, uint64_t &ExitWeight,
-                              uint64_t &CurHeaderWeight) {
+                              uint64_t &ExitWeight,
+                              uint64_t &FallThroughWeight) {
   uint64_t TrueWeight, FalseWeight;
   if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight))
     return;
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   ExitWeight = HeaderIdx ? TrueWeight : FalseWeight;
-  // The # of times the loop body executes is the sum of the exit block
-  // is taken and the # of times the backedges are taken.
-  CurHeaderWeight = TrueWeight + FalseWeight;
+  FallThroughWeight = HeaderIdx ? FalseWeight : TrueWeight;
 }
 
 /// Update the weights of original Latch block after peeling off all iterations.
 ///
 /// \param Header The header block.
 /// \param LatchBR The latch branch.
-/// \param ExitWeight The weight of the edge from Latch to Exit block.
-/// \param CurHeaderWeight The # of time the header is executed.
+/// \param ExitWeight The weight of the edge from Latch to Exit.
+/// \param FallThroughWeight The weight of the edge from Latch to Header.
 static void fixupBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
-                               uint64_t ExitWeight, uint64_t CurHeaderWeight) {
-  // Adjust the branch weights on the loop exit.
-  if (!ExitWeight)
+                               uint64_t ExitWeight,
+                               uint64_t FallThroughWeight) {
+  // A FallThroughWeight of 0 means that there are no branch weights on the
+  // original latch block or that the estimated trip count is zero.
+  if (!FallThroughWeight)
     return;
 
-  // The backedge count is the difference of current header weight and
-  // current loop exit weight. If the current header weight is smaller than
-  // the current loop exit weight, we mark the loop backedge weight as 1.
-  uint64_t BackEdgeWeight = 0;
-  if (ExitWeight < CurHeaderWeight)
-    BackEdgeWeight = CurHeaderWeight - ExitWeight;
-  else
-    BackEdgeWeight = 1;
+  // Set the branch weights on the loop exit.
   MDBuilder MDB(LatchBR->getContext());
   unsigned HeaderIdx = LatchBR->getSuccessor(0) == Header ? 0 : 1;
   MDNode *WeightNode =
-      HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight)
-                : MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
+      HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThroughWeight)
+                : MDB.createBranchWeights(FallThroughWeight, ExitWeight);
   LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode);
 }
 
@@ -658,22 +647,13 @@
   // newly created branches.
   BranchInst *LatchBR =
       cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator());
-  uint64_t ExitWeight = 0, CurHeaderWeight = 0;
-  initBranchWeights(Header, LatchBR, PeelCount, ExitWeight, CurHeaderWeight);
+  uint64_t ExitWeight = 0, FallThroughWeight = 0;
+  initBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
 
   // For each peeled-off iteration, make a copy of the loop.
   for (unsigned Iter = 0; Iter < PeelCount; ++Iter) {
     SmallVector<BasicBlock *, 8> NewBlocks;
     ValueToValueMapTy VMap;
-
-    // Subtract the exit weight from the current header weight -- the exit
-    // weight is exactly the weight of the previous iteration's header.
-    // FIXME: due to the way the distribution is constructed, we need a
-    // guard here to make sure we don't end up with non-positive weights.
-    if (ExitWeight < CurHeaderWeight)
-      CurHeaderWeight -= ExitWeight;
-    else
-      CurHeaderWeight = 1;
 
     cloneLoopBlocks(L, Iter, InsertTop, InsertBot, ExitEdges, NewBlocks,
                     LoopBlocks, VMap, LVMap, DT, LI);
@@ -696,8 +676,7 @@
     }
 
     auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]);
-    updateBranchWeights(InsertBot, LatchBRCopy, Iter,
-                        PeelCount, ExitWeight);
+    updateBranchWeights(InsertBot, LatchBRCopy, ExitWeight, FallThroughWeight);
     // Remove Loop metadata from the latch branch instruction
     // because it is not the Loop's latch branch anymore.
     LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr);
@@ -723,7 +702,7 @@
     PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
   }
 
-  fixupBranchWeights(Header, LatchBR, ExitWeight, CurHeaderWeight);
+  fixupBranchWeights(Header, LatchBR, ExitWeight, FallThroughWeight);
 
   if (Loop *ParentLoop = L->getParentLoop())
     L = ParentLoop;
@@ -3,17 +3,22 @@
 
 ; Make sure we use the profile information correctly to peel-off 3 iterations
 ; from the loop, and update the branch weights for the peeled loop properly.
+; Side exits to deopt do not change the weights.
 
 ; CHECK: Loop Unroll: F[basic]
 ; CHECK: PEELING loop %for.body with iteration count 3!
 
 ; CHECK-LABEL: @basic
+; CHECK: br i1 %c, label %{{.*}}, label %side_exit, !prof !15
 ; CHECK: br i1 %{{.*}}, label %[[NEXT0:.*]], label %for.cond.for.end_crit_edge, !prof !16
 ; CHECK: [[NEXT0]]:
+; CHECK: br i1 %c, label %{{.*}}, label %side_exit, !prof !15
 ; CHECK: br i1 %{{.*}}, label %[[NEXT1:.*]], label %for.cond.for.end_crit_edge, !prof !17
 ; CHECK: [[NEXT1]]:
+; CHECK: br i1 %c, label %{{.*}}, label %side_exit, !prof !15
 ; CHECK: br i1 %{{.*}}, label %[[NEXT2:.*]], label %for.cond.for.end_crit_edge, !prof !18
 ; CHECK: [[NEXT2]]:
+; CHECK: br i1 %c, label %{{.*}}, label %side_exit.loopexit, !prof !15
 ; CHECK: br i1 %{{.*}}, label %for.body, label %{{.*}}, !prof !19
 
 define i32 @basic(i32* %p, i32 %k, i1 %c) #0 !prof !15 {
@@ -73,8 +78,11 @@
 !16 = !{!"branch_weights", i32 3001, i32 1001}
 !17 = !{!"branch_weights", i32 1, i32 0}
 
-;CHECK: !16 = !{!"branch_weights", i32 900, i32 101}
-;CHECK: !17 = !{!"branch_weights", i32 540, i32 360}
-;CHECK: !18 = !{!"branch_weights", i32 162, i32 378}
-;CHECK: !19 = !{!"branch_weights", i32 1399, i32 162}
+; These are the weights of the deopt side exit.
+;CHECK: !15 = !{!"branch_weights", i32 1, i32 0}
+; These are the weights of the latch and its copies.
+;CHECK: !16 = !{!"branch_weights", i32 3001, i32 1001}
+;CHECK: !17 = !{!"branch_weights", i32 2000, i32 1001}
+;CHECK: !18 = !{!"branch_weights", i32 999, i32 1001}
+;CHECK: !19 = !{!"branch_weights", i32 1, i32 1001}
 
@@ -102,8 +102,8 @@
 !15 = !{!"function_entry_count", i64 1}
 !16 = !{!"branch_weights", i32 3001, i32 1001}
 
-;CHECK: !15 = !{!"branch_weights", i32 900, i32 101}
-;CHECK: !16 = !{!"branch_weights", i32 540, i32 360}
-;CHECK: !17 = !{!"branch_weights", i32 162, i32 378}
-;CHECK: !18 = !{!"branch_weights", i32 1399, i32 162}
+;CHECK: !15 = !{!"branch_weights", i32 3001, i32 1001}
+;CHECK: !16 = !{!"branch_weights", i32 2000, i32 1001}
+;CHECK: !17 = !{!"branch_weights", i32 999, i32 1001}
+;CHECK: !18 = !{!"branch_weights", i32 1, i32 1001}
 
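For comparison, the CHECK values removed from the tests can be reproduced by replaying the arithmetic of the deleted code. This is only a sketch under assumptions: `simulateOldPeelWeights` is a hypothetical helper, not part of LLVM, and it assumes that profile metadata was present. Note that the old call site passed `ExitWeight` as the `PeeledHeaderWeight` argument and `PeelCount` as `AvgIters`, which the sketch mirrors.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper, not part of LLVM. Replays the weight arithmetic that
// this commit removes and returns the (fall-through, exit) pair for each
// peeled latch copy, followed by the pair for the remaining loop latch.
std::vector<std::pair<uint64_t, uint64_t>>
simulateOldPeelWeights(uint64_t FallThroughWeight, uint64_t ExitWeight,
                       unsigned PeelCount) {
  std::vector<std::pair<uint64_t, uint64_t>> Weights;
  // Old initBranchWeights: the header weight is the sum of both latch edges.
  uint64_t CurHeaderWeight = FallThroughWeight + ExitWeight;

  for (unsigned Iter = 0; Iter < PeelCount; ++Iter) {
    // Guard from the old peeling loop against non-positive header weights.
    if (ExitWeight < CurHeaderWeight)
      CurHeaderWeight -= ExitWeight;
    else
      CurHeaderWeight = 1;

    // Old updateBranchWeights, with ExitWeight playing the role of
    // PeeledHeaderWeight and PeelCount the role of AvgIters.
    if (!ExitWeight)
      continue;
    // Linear drop-off with the 0.9 fudge factor, truncated to an integer.
    uint64_t FallThru = static_cast<uint64_t>(
        ExitWeight * (static_cast<float>(PeelCount - Iter) / PeelCount * 0.9));
    uint64_t Exit = ExitWeight - FallThru;
    ExitWeight -= Exit;
    Weights.emplace_back(FallThru, Exit);
  }

  // Old fixupBranchWeights: nothing is set when the exit weight is zero.
  if (!ExitWeight)
    return Weights;
  uint64_t BackEdgeWeight =
      ExitWeight < CurHeaderWeight ? CurHeaderWeight - ExitWeight : 1;
  Weights.emplace_back(BackEdgeWeight, ExitWeight);
  return Weights;
}
```

For the test input (3001, 1001) with three peeled iterations this gives (900, 101), (540, 360), (162, 378) and finally (1399, 162). The exit weight on the remaining loop is no longer the 1001 that the profile recorded, whereas the new scheme keeps E fixed at 1001 on every latch.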