llvm.org GIT mirror llvm / 6c00c6a
Gracefully degrade precision in branch probability numbers. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@148946 91177308-0d34-0410-b5e6-96231b3b80d8 Nick Lewycky 7 years ago
1 changed file(s) with 75 addition(s) and 20 deletion(s). Raw diff Collapse all Expand all
14651465 return true;
14661466 }
14671467
1468 /// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
1469 /// the event of overflow, logically-shifts all four inputs right until the
1470 /// multiply fits.
1471 static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
1472 unsigned &BitsLost) {
1473 BitsLost = 0;
1474 bool Overflow = false;
1475 APInt Result = A.umul_ov(B, Overflow);
1476 if (Overflow) {
1477 APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
1478 do {
1479 B = B.lshr(1);
1480 ++BitsLost;
1481 } while (B.ugt(MaxB));
1482 A = A.lshr(BitsLost);
1483 C = C.lshr(BitsLost);
1484 D = D.lshr(BitsLost);
1485 Result = A * B;
1486 }
1487 return Result;
1488 }
1489
1490
14681491 /// FoldBranchToCommonDest - If this basic block is simple enough, and if a
14691492 /// predecessor branches to us and one of our successors, fold the block into
14701493 /// the predecessor and use logical operations to pick the right destination.
16641687 // we get:
16651688 // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
16661689
1667 bool Overflow1 = false, Overflow2 = false, Overflow3 = false;
1668 bool Overflow4 = false, Overflow5 = false, Overflow6 = false;
1669 APInt ProbTrue = A.umul_ov(C, Overflow1);
1670
1671 APInt Tmp1 = A.umul_ov(D, Overflow2);
1672 APInt Tmp2 = B.umul_ov(C, Overflow3);
1673 APInt Tmp3 = B.umul_ov(D, Overflow4);
1674 APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5);
1675 APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6);
1676
1677 APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
1678 ProbTrue = ProbTrue.udiv(GCD);
1679 ProbFalse = ProbFalse.udiv(GCD);
1680
1681 if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 ||
1682 Overflow6) {
1683 DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI
1684 << "when merging with: " << *BI);
1685 PBI->setMetadata(LLVMContext::MD_prof, NULL);
1686 } else {
1690 // In the event of overflow, we want to drop the LSB of the input
1691 // probabilities.
1692 unsigned BitsLost;
1693
1694 // Ignore overflow result on ProbTrue.
1695 APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
1696
1697 APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
1698 if (BitsLost) {
1699 ProbTrue = ProbTrue.lshr(BitsLost*2);
1700 }
1701
1702 APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
1703 if (BitsLost) {
1704 ProbTrue = ProbTrue.lshr(BitsLost*2);
1705 Tmp1 = Tmp1.lshr(BitsLost*2);
1706 }
1707
1708 APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
1709 if (BitsLost) {
1710 ProbTrue = ProbTrue.lshr(BitsLost*2);
1711 Tmp1 = Tmp1.lshr(BitsLost*2);
1712 Tmp2 = Tmp2.lshr(BitsLost*2);
1713 }
1714
1715 bool Overflow1 = false, Overflow2 = false;
1716 APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
1717 APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
1718
1719 if (Overflow1 || Overflow2) {
1720 ProbTrue = ProbTrue.lshr(1);
1721 Tmp1 = Tmp1.lshr(1);
1722 Tmp2 = Tmp2.lshr(1);
1723 Tmp3 = Tmp3.lshr(1);
1724 Tmp4 = Tmp2 + Tmp3;
1725 ProbFalse = Tmp4 + Tmp1;
1726 }
1727
1728 // The sum of branch weights must fit in 32-bits.
1729 if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
1730 ProbTrue = ProbTrue.lshr(1);
1731 ProbFalse = ProbFalse.lshr(1);
1732 }
1733
1734 if (ProbTrue != ProbFalse) {
1735 // Normalize the result.
1736 APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
1737 ProbTrue = ProbTrue.udiv(GCD);
1738 ProbFalse = ProbFalse.udiv(GCD);
1739
16871740 LLVMContext &Context = BI->getContext();
16881741 Value *Ops[3];
16891742 Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0);
16901743 Ops[1] = ConstantInt::get(Context, ProbTrue);
16911744 Ops[2] = ConstantInt::get(Context, ProbFalse);
16921745 PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops));
1746 } else {
1747 PBI->setMetadata(LLVMContext::MD_prof, NULL);
16931748 }
16941749 } else {
16951750 PBI->setMetadata(LLVMContext::MD_prof, NULL);