llvm.org GIT mirror llvm / de0f238
Fix PR2088. Use a modular linear equation solver to compute the loop iteration count. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@53810 91177308-0d34-0410-b5e6-96231b3b80d8 Wojciech Matyjewicz 11 years ago
5 changed file(s) with 104 addition(s) and 26 deletion(s). Raw diff Collapse all Expand all
7777 #include "llvm/Support/MathExtras.h"
7878 #include "llvm/Support/Streams.h"
7979 #include "llvm/ADT/Statistic.h"
80 //TMP:
81 #include "llvm/Support/Debug.h"
8082 #include
8183 #include
8284 #include
24602462 return UnknownValue;
24612463 }
24622464
2465 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
2466 /// following equation:
2467 ///
2468 /// A * X = B (mod N)
2469 ///
2470 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
2471 /// A and B isn't important.
2472 ///
2473 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
2474 static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
2475 ScalarEvolution &SE) {
2476 uint32_t BW = A.getBitWidth();
2477 assert(BW == B.getBitWidth() && "Bit widths must be the same.");
2478 assert(A != 0 && "A must be non-zero.");
2479
2480 // 1. D = gcd(A, N)
2481 //
2482 // The gcd of A and N may have only one prime factor: 2. The number of
2483 // trailing zeros in A is its multiplicity
2484 uint32_t Mult2 = A.countTrailingZeros();
2485 // D = 2^Mult2
2486
2487 // 2. Check if B is divisible by D.
2488 //
2489 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
2490 // is not less than multiplicity of this prime factor for D.
2491 if (B.countTrailingZeros() < Mult2)
2492 return new SCEVCouldNotCompute();
2493
2494 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
2495 // modulo (N / D).
2496 //
2497 // (N / D) may need BW+1 bits in its representation. Hence, we'll use this
2498 // bit width during computations.
2499 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
2500 APInt Mod(BW + 1, 0);
2501 Mod.set(BW - Mult2); // Mod = N / D
2502 APInt I = AD.multiplicativeInverse(Mod);
2503
2504 // 4. Compute the minimum unsigned root of the equation:
2505 // I * (B / D) mod (N / D)
2506 APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
2507
2508 // The result is guaranteed to be less than 2^BW so we may truncate it to BW
2509 // bits.
2510 return SE.getConstant(Result.trunc(BW));
2511 }
24632512
24642513 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
24652514 /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which
25322581 return UnknownValue;
25332582
25342583 if (AddRec->isAffine()) {
2535 // If this is an affine expression the execution count of this branch is
2536 // equal to:
2584 // If this is an affine expression, the execution count of this branch is
2585 // the minimum unsigned root of the following equation:
25372586 //
2538 // (0 - Start/Step) iff Start % Step == 0
2587 // Start + Step*N = 0 (mod 2^BW)
25392588 //
2589 // equivalent to:
2590 //
2591 // Step*N = -Start (mod 2^BW)
2592 //
2593 // where BW is the common bit width of Start and Step.
2594
25402595 // Get the initial value for the loop.
25412596 SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
25422597 if (isa(Start)) return UnknownValue;
2543 SCEVHandle Step = AddRec->getOperand(1);
2544
2545 Step = getSCEVAtScope(Step, L->getParentLoop());
2546
2547 // Figure out if Start % Step == 0.
2548 // FIXME: We should add DivExpr and RemExpr operations to our AST.
2598
2599 SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
2600
25492601 if (SCEVConstant *StepC = dyn_cast(Step)) {
2550 if (StepC->getValue()->equalsInt(1)) // N % 1 == 0
2551 return SE.getNegativeSCEV(Start); // 0 - Start/1 == -Start
2552 if (StepC->getValue()->isAllOnesValue()) // N % -1 == 0
2553 return Start; // 0 - Start/-1 == Start
2554
2555 // Check to see if Start is divisible by SC with no remainder.
2556 if (SCEVConstant *StartC = dyn_cast(Start)) {
2557 ConstantInt *StartCC = StartC->getValue();
2558 Constant *StartNegC = ConstantExpr::getNeg(StartCC);
2559 Constant *Rem = ConstantExpr::getURem(StartNegC, StepC->getValue());
2560 if (Rem->isNullValue()) {
2561 Constant *Result = ConstantExpr::getUDiv(StartNegC,StepC->getValue());
2562 return SE.getUnknown(Result);
2563 }
2564 }
2602 // For now we handle only constant steps.
2603
2604 // First, handle unitary steps.
2605 if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so:
2606 return SE.getNegativeSCEV(Start); // N = -Start (as unsigned)
2607 if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so:
2608 return Start; // N = Start (as unsigned)
2609
2610 // Then, try to solve the above equation provided that Start is constant.
2611 if (SCEVConstant *StartC = dyn_cast(Start))
2612 return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
2613 -StartC->getValue()->getValue(),SE);
25652614 }
25662615 } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
25672616 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
14651465 // The next-to-last t is the multiplicative inverse. However, we are
14661466 // interested in a positive inverse. Calculate a positive one from a negative
14671467 // one if necessary. A simple addition of the modulo suffices because
1468 // abs(t[i]) is known to less than *this/2 (see the link above).
1468 // abs(t[i]) is known to be less than *this/2 (see the link above).
14691469 return t[i].isNegative() ? t[i] + modulo : t[i];
14701470 }
14711471
0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \
11 ; RUN: -scalar-evolution-max-iterations=0 | grep {61 iterations}
22 ; PR2364
3 ; XFAIL: *
43
54 define i32 @func_6() nounwind {
65 entry:
0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \
1 ; RUN: -scalar-evolution-max-iterations=0 | grep Unpredictable
2 ; PR2088
3
; Test body for PR2088: a single self-looping block with an i8 counter that
; starts at 0 and is advanced by 4 on every iteration.  The loop exits only
; when the incremented value equals 6.  Every value the counter can take is a
; multiple of 4 (also after wrapping modulo 2^8), so 4*N = 6 (mod 2^8) has no
; solution and the exit branch is never taken; the RUN line above accordingly
; expects the analysis to print "Unpredictable".
4 define void @fun() {
5 entry:
6 br label %loop
7 loop:
; %i is 0 on entry from %entry and %i.next on every trip around the backedge.
8 %i = phi i8 [ 0, %entry ], [ %i.next, %loop ]
9 %i.next = add i8 %i, 4
; Keep looping while the incremented counter differs from 6.
10 %cond = icmp ne i8 %i.next, 6
11 br i1 %cond, label %loop, label %exit
12 exit:
13 ret void
14 }
0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \
1 ; RUN: -scalar-evolution-max-iterations=0 | grep {113 iterations}
2 ; PR2088
3
; Test body for PR2088: a single self-looping block with an i8 counter that
; starts at 0 and is advanced by 18 on every iteration.  The loop exits when
; the incremented value equals 4, which is reachable only through wrap-around:
; the smallest M with 18*M = 4 (mod 2^8) is 114 (18*114 = 2052 = 8*256 + 4).
; The exit condition therefore first fails on its 114th evaluation, after the
; backedge has been taken 113 times, matching the "113 iterations" the RUN
; line above greps for.
4 define void @fun() {
5 entry:
6 br label %loop
7 loop:
; %i is 0 on entry from %entry and %i.next on every trip around the backedge.
8 %i = phi i8 [ 0, %entry ], [ %i.next, %loop ]
9 %i.next = add i8 %i, 18
; Keep looping while the incremented counter differs from 4.
10 %cond = icmp ne i8 %i.next, 4
11 br i1 %cond, label %loop, label %exit
12 exit:
13 ret void
14 }