llvm.org GIT mirror
Fix PR2088. Use modulo linear equation solver to compute loop iteration count. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@53810 91177308-0d34-0410-b5e6-96231b3b80d8 Wojciech Matyjewicz 11 years ago
5 changed file(s) with 104 addition(s) and 26 deletion(s).
 77 77 #include "llvm/Support/MathExtras.h" 78 78 #include "llvm/Support/Streams.h" 79 79 #include "llvm/ADT/Statistic.h" 80 //TMP: 81 #include "llvm/Support/Debug.h" 80 82 #include 81 83 #include 82 84 #include 2460 2462 return UnknownValue; 2461 2463 } 2462 2464 2465 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the 2466 /// following equation: 2467 /// 2468 /// A * X = B (mod N) 2469 /// 2470 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of 2471 /// A and B isn't important. 2472 /// 2473 /// If the equation does not have a solution, SCEVCouldNotCompute is returned. 2474 static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, 2475 ScalarEvolution &SE) { 2476 uint32_t BW = A.getBitWidth(); 2477 assert(BW == B.getBitWidth() && "Bit widths must be the same."); 2478 assert(A != 0 && "A must be non-zero."); 2479 2480 // 1. D = gcd(A, N) 2481 // 2482 // The gcd of A and N may have only one prime factor: 2. The number of 2483 // trailing zeros in A is its multiplicity 2484 uint32_t Mult2 = A.countTrailingZeros(); 2485 // D = 2^Mult2 2486 2487 // 2. Check if B is divisible by D. 2488 // 2489 // B is divisible by D if and only if the multiplicity of prime factor 2 for B 2490 // is not less than multiplicity of this prime factor for D. 2491 if (B.countTrailingZeros() < Mult2) 2492 return new SCEVCouldNotCompute(); 2493 2494 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic 2495 // modulo (N / D). 2496 // 2497 // (N / D) may need BW+1 bits in its representation. Hence, we'll use this 2498 // bit width during computations. 2499 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D 2500 APInt Mod(BW + 1, 0); 2501 Mod.set(BW - Mult2); // Mod = N / D 2502 APInt I = AD.multiplicativeInverse(Mod); 2503 2504 // 4. 
Compute the minimum unsigned root of the equation: 2505 // I * (B / D) mod (N / D) 2506 APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); 2507 2508 // The result is guaranteed to be less than 2^BW so we may truncate it to BW 2509 // bits. 2510 return SE.getConstant(Result.trunc(BW)); 2511 } 2463 2512 2464 2513 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the 2465 2514 /// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which 2532 2581 return UnknownValue; 2533 2582 2534 2583 if (AddRec->isAffine()) { 2535 // If this is an affine expression the execution count of this branch is 2536 // equal to:⏎ 2584 // If this is an affine expression, the execution count of this branch is⏎ 2585 // the minimum unsigned root of the following equation: 2537 2586 // 2538 // (0 - Start/Step) iff Start % Step == 0⏎ 2587 // Start + Step*N = 0 (mod 2^BW)⏎ 2539 2588 // 2589 // equivalent to: 2590 // 2591 // Step*N = -Start (mod 2^BW) 2592 // 2593 // where BW is the common bit width of Start and Step. 2594 2540 2595 // Get the initial value for the loop. 2541 2596 SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); 2542 2597 if (isa<SCEVCouldNotCompute>(Start)) return UnknownValue; 2543 SCEVHandle Step = AddRec->getOperand(1); 2544 2545 Step = getSCEVAtScope(Step, L->getParentLoop()); 2546 2547 // Figure out if Start % Step == 0. 2548 // FIXME: We should add DivExpr and RemExpr operations to our AST.⏎ 2598 ⏎ 2599 SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); 2600 2549 2601 if (SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) { 2550 if (StepC->getValue()->equalsInt(1)) // N % 1 == 0 2551 return SE.getNegativeSCEV(Start); // 0 - Start/1 == -Start 2552 if (StepC->getValue()->isAllOnesValue()) // N % -1 == 0 2553 return Start; // 0 - Start/-1 == Start 2554 2555 // Check to see if Start is divisible by SC with no remainder. 
2556 if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { 2557 ConstantInt *StartCC = StartC->getValue(); 2558 Constant *StartNegC = ConstantExpr::getNeg(StartCC); 2559 Constant *Rem = ConstantExpr::getURem(StartNegC, StepC->getValue()); 2560 if (Rem->isNullValue()) { 2561 Constant *Result = ConstantExpr::getUDiv(StartNegC,StepC->getValue()); 2562 return SE.getUnknown(Result); 2563 } 2564 }⏎ 2602 // For now we handle only constant steps.⏎ 2603 2604 // First, handle unitary steps. 2605 if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: 2606 return SE.getNegativeSCEV(Start); // N = -Start (as unsigned) 2607 if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: 2608 return Start; // N = Start (as unsigned) 2609 2610 // Then, try to solve the above equation provided that Start is constant. 2611 if (SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) 2612 return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), 2613 -StartC->getValue()->getValue(),SE); 2565 2614 } 2566 2615 } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { 2567 2616 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
 1465 1465 // The next-to-last t is the multiplicative inverse. However, we are 1466 1466 // interested in a positive inverse. Calculate a positive one from a negative 1467 1467 // one if necessary. A simple addition of the modulo suffices because 1468 // abs(t[i]) is known to less than *this/2 (see the link above).⏎ 1468 // abs(t[i]) is known to be less than *this/2 (see the link above).⏎ 1469 1469 return t[i].isNegative() ? t[i] + modulo : t[i]; 1470 1470 } 1471 1471
 0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \ 1 1 ; RUN: -scalar-evolution-max-iterations=0 | grep {61 iterations} 2 2 ; PR2364 3 ; XFAIL: * 4 3 5 4 define i32 @func_6() nounwind { 6 5 entry:
 0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \ 1 ; RUN: -scalar-evolution-max-iterations=0 | grep Unpredictable 2 ; PR2088 3 4 define void @fun() { 5 entry: 6 br label %loop 7 loop: 8 %i = phi i8 [ 0, %entry ], [ %i.next, %loop ] 9 %i.next = add i8 %i, 4 10 %cond = icmp ne i8 %i.next, 6 11 br i1 %cond, label %loop, label %exit 12 exit: 13 ret void 14 }
 0 ; RUN: llvm-as < %s | opt -analyze -scalar-evolution \ 1 ; RUN: -scalar-evolution-max-iterations=0 | grep {113 iterations} 2 ; PR2088 3 4 define void @fun() { 5 entry: 6 br label %loop 7 loop: 8 %i = phi i8 [ 0, %entry ], [ %i.next, %loop ] 9 %i.next = add i8 %i, 18 10 %cond = icmp ne i8 %i.next, 4 11 br i1 %cond, label %loop, label %exit 12 exit: 13 ret void 14 }