llvm.org GIT mirror llvm / e97728e
The product of two chrec's can always be represented as a chrec. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@141066 91177308-0d34-0410-b5e6-96231b3b80d8 Nick Lewycky 9 years ago
3 changed file(s) with 241 addition(s) and 41 deletion(s). Raw diff Collapse all Expand all
587587 Ops.push_back(RHS);
588588 return getMulExpr(Ops, Flags);
589589 }
590 const SCEV *getMulExpr(const SCEV *Op0, const SCEV *Op1, const SCEV *Op2,
591 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap) {
592 SmallVector Ops;
593 Ops.push_back(Op0);
594 Ops.push_back(Op1);
595 Ops.push_back(Op2);
596 return getMulExpr(Ops, Flags);
597 }
590598 const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
591599 const SCEV *getAddRecExpr(const SCEV *Start, const SCEV *Step,
592600 const Loop *L, SCEV::NoWrapFlags Flags);
18111811 return S;
18121812 }
18131813
1814 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
1815 uint64_t k = i*j;
1816 if (j > 1 && k / j != i) Overflow = true;
1817 return k;
1818 }
1819
1820 /// Compute the result of "n choose k", the binomial coefficient. If an
1821 /// intermediate computation overflows, Overflow will be set and the return will
1822 /// be garbage. Overflow is not cleared on absense of overflow.
1823 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
1824 // We use the multiplicative formula:
1825 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
1826 // At each iteration, we take the n-th term of the numeral and divide by the
1827 // (k-n)th term of the denominator. This division will always produce an
1828 // integral result, and helps reduce the chance of overflow in the
1829 // intermediate computations. However, we can still overflow even when the
1830 // final result would fit.
1831
1832 if (n == 0 || n == k) return 1;
1833 if (k > n) return 0;
1834
1835 if (k > n/2)
1836 k = n-k;
1837
1838 uint64_t r = 1;
1839 for (uint64_t i = 1; i <= k; ++i) {
1840 r = umul_ov(r, n-(i-1), Overflow);
1841 r /= i;
1842 }
1843 return r;
1844 }
1845
18141846 /// getMulExpr - Get a canonical multiply expression, or something simpler if
18151847 /// possible.
18161848 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops,
19862018 for (unsigned OtherIdx = Idx+1;
19872019 OtherIdx < Ops.size() && isa(Ops[OtherIdx]);
19882020 ++OtherIdx) {
1989 bool Retry = false;
19902021 if (AddRecLoop == cast(Ops[OtherIdx])->getLoop()) {
1991 // {A,+,B} * {C,+,D} --> {A*C,+,A*D + B*C + B*D,+,2*B*D}
2022 // {A1,+,A2,+,...,+,An} * {B1,+,B2,+,...,+,Bn}
2023 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2024 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
2025 // ]]],+,...up to x=2n}.
2026 // Note that the arguments to choose() are always integers with values
2027 // known at compile time, never SCEV objects.
19922028 //
1993 // {A,+,B} * {C,+,D} = A+It*B * C+It*D = A*C + (A*D + B*C)*It + B*D*It^2
1994 // Given an equation of the form x + y*It + z*It^2 (above), we want to
1995 // express it in terms of {X,+,Y,+,Z}.
1996 // {X,+,Y,+,Z} = X + Y*It + Z*(It^2 - It)/2.
1997 // Rearranging, X = x, Y = y+z, Z = 2z.
1998 //
1999 // x = A*C, y = (A*D + B*C), z = B*D.
2000 // Therefore X = A*C, Y = A*D + B*C + B*D and Z = 2*B*D.
2029 // The implementation avoids pointless extra computations when the two
2030 // addrec's are of different length (mathematically, it's equivalent to
2031 // an infinite stream of zeros on the right).
2032 bool OpsModified = false;
20012033 for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]);
20022034 ++OtherIdx)
20032035 if (const SCEVAddRecExpr *OtherAddRec =
20042036 dyn_cast(Ops[OtherIdx]))
20052037 if (OtherAddRec->getLoop() == AddRecLoop) {
2006 const SCEV *A = AddRec->getStart();
2007 const SCEV *B = AddRec->getStepRecurrence(*this);
2008 const SCEV *C = OtherAddRec->getStart();
2009 const SCEV *D = OtherAddRec->getStepRecurrence(*this);
2010 const SCEV *NewStart = getMulExpr(A, C);
2011 const SCEV *BD = getMulExpr(B, D);
2012 const SCEV *NewStep = getAddExpr(getMulExpr(A, D),
2013 getMulExpr(B, C), BD);
2014 const SCEV *NewSecondOrderStep =
2015 getMulExpr(BD, getConstant(BD->getType(), 2));
2016
2017 // This can happen when AddRec or OtherAddRec have >3 operands.
2018 // TODO: support these add-recs.
2019 if (isLoopInvariant(NewStart, AddRecLoop) &&
2020 isLoopInvariant(NewStep, AddRecLoop) &&
2021 isLoopInvariant(NewSecondOrderStep, AddRecLoop)) {
2022 SmallVector AddRecOps;
2023 AddRecOps.push_back(NewStart);
2024 AddRecOps.push_back(NewStep);
2025 AddRecOps.push_back(NewSecondOrderStep);
2038 bool Overflow = false;
2039 Type *Ty = AddRec->getType();
2040 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2041 SmallVector AddRecOps;
2042 for (int x = 0, xe = AddRec->getNumOperands() +
2043 OtherAddRec->getNumOperands() - 1;
2044 x != xe && !Overflow; ++x) {
2045 const SCEV *Term = getConstant(Ty, 0);
2046 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2047 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2048 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2049 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2050 z < ze && !Overflow; ++z) {
2051 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2052 uint64_t Coeff;
2053 if (LargerThan64Bits)
2054 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2055 else
2056 Coeff = Coeff1*Coeff2;
2057 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2058 const SCEV *Term1 = AddRec->getOperand(y-z);
2059 const SCEV *Term2 = OtherAddRec->getOperand(z);
2060 Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2061 }
2062 }
2063 AddRecOps.push_back(Term);
2064 }
2065 if (!Overflow) {
20262066 const SCEV *NewAddRec = getAddRecExpr(AddRecOps,
20272067 AddRec->getLoop(),
20282068 SCEV::FlagAnyWrap);
20292069 if (Ops.size() == 2) return NewAddRec;
20302070 Ops[Idx] = AddRec = cast(NewAddRec);
20312071 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2032 Retry = true;
2072 OpsModified = true;
20332073 }
20342074 }
2035 if (Retry)
2075 if (OpsModified)
20362076 return getMulExpr(Ops);
20372077 }
20382078 }
77 //===----------------------------------------------------------------------===//
88
99 #include
10 #include
1011 #include
1112 #include
1213 #include
1314 #include
1415 #include
16 #include
1517 #include "gtest/gtest.h"
1618
1719 namespace llvm {
1820 namespace {
1921
20 TEST(ScalarEvolutionsTest, SCEVUnknownRAUW) {
22 // We use this fixture to ensure that we clean up ScalarEvolution before
23 // deleting the PassManager.
24 class ScalarEvolutionsTest : public testing::Test {
25 protected:
26 ScalarEvolutionsTest() : M("", Context), SE(*new ScalarEvolution) {}
27 ~ScalarEvolutionsTest() {
28 // Manually clean up, since we allocated new SCEV objects after the
29 // pass was finished.
30 SE.releaseMemory();
31 }
2132 LLVMContext Context;
22 Module M("world", Context);
23
33 Module M;
34 PassManager PM;
35 ScalarEvolution &SE;
36 };
37
38 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
2439 FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context),
2540 std::vector(), false);
2641 Function *F = cast(M.getOrInsertFunction("f", FTy));
3449 Value *V2 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V2");
3550
3651 // Create a ScalarEvolution and "run" it so that it gets initialized.
37 PassManager PM;
38 ScalarEvolution &SE = *new ScalarEvolution();
3952 PM.add(&SE);
4053 PM.run(M);
4154
7184 EXPECT_EQ(cast(M0->getOperand(1))->getValue(), V0);
7285 EXPECT_EQ(cast(M1->getOperand(1))->getValue(), V0);
7386 EXPECT_EQ(cast(M2->getOperand(1))->getValue(), V0);
74
75 // Manually clean up, since we allocated new SCEV objects after the
76 // pass was finished.
77 SE.releaseMemory();
87 }
88
89 TEST_F(ScalarEvolutionsTest, SCEVMultiplyAddRecs) {
90 Type *Ty = Type::getInt32Ty(Context);
91 SmallVector Types;
92 Types.append(10, Ty);
93 FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
94 Function *F = cast(M.getOrInsertFunction("f", FTy));
95 BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
96 ReturnInst::Create(Context, 0, BB);
97
98 // Create a ScalarEvolution and "run" it so that it gets initialized.
99 PM.add(&SE);
100 PM.run(M);
101
102 // It's possible to produce an empty loop through the default constructor,
103 // but you can't add any blocks to it without a LoopInfo pass.
104 Loop L;
105 const_cast&>(L.getBlocks()).push_back(BB);
106
107 Function::arg_iterator AI = F->arg_begin();
108 SmallVector A;
109 A.push_back(SE.getSCEV(&*AI++));
110 A.push_back(SE.getSCEV(&*AI++));
111 A.push_back(SE.getSCEV(&*AI++));
112 A.push_back(SE.getSCEV(&*AI++));
113 A.push_back(SE.getSCEV(&*AI++));
114 const SCEV *A_rec = SE.getAddRecExpr(A, &L, SCEV::FlagAnyWrap);
115
116 SmallVector B;
117 B.push_back(SE.getSCEV(&*AI++));
118 B.push_back(SE.getSCEV(&*AI++));
119 B.push_back(SE.getSCEV(&*AI++));
120 B.push_back(SE.getSCEV(&*AI++));
121 B.push_back(SE.getSCEV(&*AI++));
122 const SCEV *B_rec = SE.getAddRecExpr(B, &L, SCEV::FlagAnyWrap);
123
124 /* Spot check that we perform this transformation:
125 {A0,+,A1,+,A2,+,A3,+,A4} * {B0,+,B1,+,B2,+,B3,+,B4} =
126 {A0*B0,+,
127 A1*B0 + A0*B1 + A1*B1,+,
128 A2*B0 + 2A1*B1 + A0*B2 + 2A2*B1 + 2A1*B2 + A2*B2,+,
129 A3*B0 + 3A2*B1 + 3A1*B2 + A0*B3 + 3A3*B1 + 6A2*B2 + 3A1*B3 + 3A3*B2 +
130 3A2*B3 + A3*B3,+,
131 A4*B0 + 4A3*B1 + 6A2*B2 + 4A1*B3 + A0*B4 + 4A4*B1 + 12A3*B2 + 12A2*B3 +
132 4A1*B4 + 6A4*B2 + 12A3*B3 + 6A2*B4 + 4A4*B3 + 4A3*B4 + A4*B4,+,
133 5A4*B1 + 10A3*B2 + 10A2*B3 + 5A1*B4 + 20A4*B2 + 30A3*B3 + 20A2*B4 +
134 30A4*B3 + 30A3*B4 + 20A4*B4,+,
135 15A4*B2 + 20A3*B3 + 15A2*B4 + 60A4*B3 + 60A3*B4 + 90A4*B4,+,
136 35A4*B3 + 35A3*B4 + 140A4*B4,+,
137 70A4*B4}
138 */
139
140 const SCEVAddRecExpr *Product =
141 dyn_cast(SE.getMulExpr(A_rec, B_rec));
142 ASSERT_TRUE(Product);
143 ASSERT_EQ(Product->getNumOperands(), 9u);
144
145 SmallVector Sum;
146 Sum.push_back(SE.getMulExpr(A[0], B[0]));
147 EXPECT_EQ(Product->getOperand(0), SE.getAddExpr(Sum));
148 Sum.clear();
149
150 // SCEV produces different an equal but different expression for these.
151 // Re-enable when PR11052 is fixed.
152 #if 0
153 Sum.push_back(SE.getMulExpr(A[1], B[0]));
154 Sum.push_back(SE.getMulExpr(A[0], B[1]));
155 Sum.push_back(SE.getMulExpr(A[1], B[1]));
156 EXPECT_EQ(Product->getOperand(1), SE.getAddExpr(Sum));
157 Sum.clear();
158
159 Sum.push_back(SE.getMulExpr(A[2], B[0]));
160 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[1]));
161 Sum.push_back(SE.getMulExpr(A[0], B[2]));
162 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[2], B[1]));
163 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[2]));
164 Sum.push_back(SE.getMulExpr(A[2], B[2]));
165 EXPECT_EQ(Product->getOperand(2), SE.getAddExpr(Sum));
166 Sum.clear();
167
168 Sum.push_back(SE.getMulExpr(A[3], B[0]));
169 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[1]));
170 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[2]));
171 Sum.push_back(SE.getMulExpr(A[0], B[3]));
172 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[1]));
173 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
174 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[3]));
175 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[2]));
176 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[3]));
177 Sum.push_back(SE.getMulExpr(A[3], B[3]));
178 EXPECT_EQ(Product->getOperand(3), SE.getAddExpr(Sum));
179 Sum.clear();
180
181 Sum.push_back(SE.getMulExpr(A[4], B[0]));
182 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[1]));
183 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
184 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[3]));
185 Sum.push_back(SE.getMulExpr(A[0], B[4]));
186 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[1]));
187 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[2]));
188 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[2], B[3]));
189 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[4]));
190 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[4], B[2]));
191 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[3]));
192 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[4]));
193 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[3]));
194 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[4]));
195 Sum.push_back(SE.getMulExpr(A[4], B[4]));
196 EXPECT_EQ(Product->getOperand(4), SE.getAddExpr(Sum));
197 Sum.clear();
198
199 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[4], B[1]));
200 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[3], B[2]));
201 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[2], B[3]));
202 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[1], B[4]));
203 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[2]));
204 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[3]));
205 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[2], B[4]));
206 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[4], B[3]));
207 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[4]));
208 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[4]));
209 EXPECT_EQ(Product->getOperand(5), SE.getAddExpr(Sum));
210 Sum.clear();
211
212 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[4], B[2]));
213 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[3], B[3]));
214 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[2], B[4]));
215 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[4], B[3]));
216 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[3], B[4]));
217 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 90), A[4], B[4]));
218 EXPECT_EQ(Product->getOperand(6), SE.getAddExpr(Sum));
219 Sum.clear();
220
221 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[4], B[3]));
222 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[3], B[4]));
223 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 140), A[4], B[4]));
224 EXPECT_EQ(Product->getOperand(7), SE.getAddExpr(Sum));
225 Sum.clear();
226 #endif
227
228 Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 70), A[4], B[4]));
229 EXPECT_EQ(Product->getOperand(8), SE.getAddExpr(Sum));
78230 }
79231
80232 } // end anonymous namespace