llvm.org GIT mirror llvm / d59827d
Model ashr(shl(x, n), m) as mul(x, 2^(n-m)) when n > m. Given the case below: %y = shl %x, n; %z = ashr %y, m. When n = m, SCEV models it as sext(trunc(x)). This patch tries to handle the case where n > m by using sext(mul(trunc(x), 2^(n-m))) as the SCEV expression. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@298631 91177308-0d34-0410-b5e6-96231b3b80d8 Zhaoshi Zheng 3 years ago
3 changed file(s) with 175 addition(s) and 20 deletion(s). Raw diff Collapse all Expand all
53555355 break;
53565356
53575357 case Instruction::AShr:
5358 // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
5359 if (ConstantInt *CI = dyn_cast(BO->RHS))
5360 if (Operator *L = dyn_cast(BO->LHS))
5361 if (L->getOpcode() == Instruction::Shl &&
5362 L->getOperand(1) == BO->RHS) {
5363 uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
5364
5365 // If the shift count is not less than the bitwidth, the result of
5366 // the shift is undefined. Don't try to analyze it, because the
5367 // resolution chosen here may differ from the resolution chosen in
5368 // other parts of the compiler.
5369 if (CI->getValue().uge(BitWidth))
5370 break;
5371
5372 uint64_t Amt = BitWidth - CI->getZExtValue();
5373 if (Amt == BitWidth)
5374 return getSCEV(L->getOperand(0)); // shift by zero --> noop
5358 // AShr X, C, where C is a constant.
5359 ConstantInt *CI = dyn_cast(BO->RHS);
5360 if (!CI)
5361 break;
5362
5363 Type *OuterTy = BO->LHS->getType();
5364 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
5365 // If the shift count is not less than the bitwidth, the result of
5366 // the shift is undefined. Don't try to analyze it, because the
5367 // resolution chosen here may differ from the resolution chosen in
5368 // other parts of the compiler.
5369 if (CI->getValue().uge(BitWidth))
5370 break;
5371
5372 if (CI->isNullValue())
5373 return getSCEV(BO->LHS); // shift by zero --> noop
5374
5375 uint64_t AShrAmt = CI->getZExtValue();
5376 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
5377
5378 Operator *L = dyn_cast(BO->LHS);
5379 if (L && L->getOpcode() == Instruction::Shl) {
5380 // X = Shl A, n
5381 // Y = AShr X, m
5382 // Both n and m are constant.
5383
5384 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
5385 if (L->getOperand(1) == BO->RHS)
5386 // For a two-shift sext-inreg, i.e. n = m,
5387 // use sext(trunc(x)) as the SCEV expression.
5388 return getSignExtendExpr(
5389 getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
5390
5391 ConstantInt *ShlAmtCI = dyn_cast(L->getOperand(1));
5392 if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
5393 uint64_t ShlAmt = ShlAmtCI->getZExtValue();
5394 if (ShlAmt > AShrAmt) {
5395 // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
5396 // expression. We already checked that ShlAmt < BitWidth, so
5397 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
5398 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
5399 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
5400 ShlAmt - AShrAmt);
53755401 return getSignExtendExpr(
5376 getTruncateExpr(getSCEV(L->getOperand(0)),
5377 IntegerType::get(getContext(), Amt)),
5378 BO->LHS->getType());
5402 getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
5403 getConstant(Mul)), OuterTy);
53795404 }
5405 }
5406 }
53805407 break;
53815408 }
53825409 }
0 ; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
1
2 ; CHECK: %tmp9 = shl i64 %tmp8, 33
3 ; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
4 ; CHECK: %tmp10 = ashr exact i64 %tmp9, 32
5 ; CHECK-NEXT: --> {{.*}} Exits: (sext i32 (-2 + (2 * %arg2)) to i64)
6 ; CHECK: %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
7 ; CHECK-NEXT: --> {{.*}} Exits: ((4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg)
8 ; CHECK: %tmp14 = or i64 %tmp10, 1
9 ; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i32 (-2 + (2 * %arg2)) to i64))
10 ; CHECK: %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
11 ; CHECK-NEXT: --> {{.*}} Exits: (4 + (4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg)
12 ; CHECK:Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg2 to i64))
13 ; CHECK-NEXT:Loop %bb7: max backedge-taken count is -1
14 ; CHECK-NEXT:Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg2 to i64))
15
; Test input for the shl-amount > ashr-amount case on i64.
; The loop in %bb7 is entered only when %arg2 > 0 (signed compare) and runs an
; i64 induction variable from 0 until it equals zext(%arg2). Each iteration
; forms the index ashr(shl(iv, 33), 32), then does arg[idx] -= %arg1 and
; arg[idx | 1] *= %arg1.
16 define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) {
17 bb:
18 %tmp = icmp sgt i32 %arg2, 0
19 br i1 %tmp, label %bb3, label %bb6
20
21 bb3: ; preds = %bb
22 %tmp4 = zext i32 %arg2 to i64
23 br label %bb7
24
25 bb5: ; preds = %bb7
26 br label %bb6
27
28 bb6: ; preds = %bb5, %bb
29 ret void
30
31 bb7: ; preds = %bb7, %bb3
32 %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ]
; Shift pair under test: shl by 33, then ashr by 32 (shl amount is larger).
33 %tmp9 = shl i64 %tmp8, 33
34 %tmp10 = ashr exact i64 %tmp9, 32
35 %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
36 %tmp12 = load i32, i32* %tmp11, align 4
37 %tmp13 = sub nsw i32 %tmp12, %arg1
38 store i32 %tmp13, i32* %tmp11, align 4
; Second access uses the same shifted index with its low bit set.
39 %tmp14 = or i64 %tmp10, 1
40 %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
41 %tmp16 = load i32, i32* %tmp15, align 4
42 %tmp17 = mul nsw i32 %tmp16, %arg1
43 store i32 %tmp17, i32* %tmp15, align 4
; Induction step and exit test against zext(%arg2).
44 %tmp18 = add nuw nsw i64 %tmp8, 1
45 %tmp19 = icmp eq i64 %tmp18, %tmp4
46 br i1 %tmp19, label %bb5, label %bb7
47 }
48
49 ; CHECK: %t10 = ashr exact i128 %t9, 1
50 ; CHECK-NEXT: --> {{.*}} Exits: (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128)
51 ; CHECK: %t14 = or i128 %t10, 1
52 ; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128))
53 ; CHECK: Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg5 to i128))
54 ; CHECK-NEXT: Loop %bb7: max backedge-taken count is -1
55 ; CHECK-NEXT: Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg5 to i128))
56
; Same loop shape as @foo, but on i128 with a much larger gap between the shift
; amounts: the index is ashr(shl(iv, 100), 1).
; The loop in %bb7 is entered only when %arg5 > 0 (signed compare) and runs an
; i128 induction variable from 0 until it equals zext(%arg5). Each iteration
; does arg3[idx] -= %arg4 and arg3[idx | 1] *= %arg4.
57 define void @goo(i32* nocapture %arg3, i32 %arg4, i32 %arg5) {
58 bb:
59 %t = icmp sgt i32 %arg5, 0
60 br i1 %t, label %bb3, label %bb6
61
62 bb3: ; preds = %bb
63 %t4 = zext i32 %arg5 to i128
64 br label %bb7
65
66 bb5: ; preds = %bb7
67 br label %bb6
68
69 bb6: ; preds = %bb5, %bb
70 ret void
71
72 bb7: ; preds = %bb7, %bb3
73 %t8 = phi i128 [ %t18, %bb7 ], [ 0, %bb3 ]
; Shift pair under test: shl by 100, then ashr by 1 (shl amount is larger).
74 %t9 = shl i128 %t8, 100
75 %t10 = ashr exact i128 %t9, 1
76 %t11 = getelementptr inbounds i32, i32* %arg3, i128 %t10
77 %t12 = load i32, i32* %t11, align 4
78 %t13 = sub nsw i32 %t12, %arg4
79 store i32 %t13, i32* %t11, align 4
; Second access uses the same shifted index with its low bit set.
80 %t14 = or i128 %t10, 1
81 %t15 = getelementptr inbounds i32, i32* %arg3, i128 %t14
82 %t16 = load i32, i32* %t15, align 4
83 %t17 = mul nsw i32 %t16, %arg4
84 store i32 %t17, i32* %t15, align 4
; Induction step and exit test against zext(%arg5).
85 %t18 = add nuw nsw i128 %t8, 1
86 %t19 = icmp eq i128 %t18, %t4
87 br i1 %t19, label %bb5, label %bb7
88 }
0 ; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
1
2 ; CHECK: %tmp9 = shl i64 %tmp8, 33
3 ; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
4 ; CHECK-NEXT: %tmp10 = ashr exact i64 %tmp9, 0
5 ; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
6
; Test input for an ashr by zero applied to a shl result.
; The loop in %bb7 is entered only when %arg2 > 0 (signed compare) and runs an
; i64 induction variable from 0 until it equals zext(%arg2). The CHECK lines
; of this file expect %tmp10 to get the same exit expression as %tmp9.
7 define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) {
8 bb:
9 %tmp = icmp sgt i32 %arg2, 0
10 br i1 %tmp, label %bb3, label %bb6
11
12 bb3: ; preds = %bb
13 %tmp4 = zext i32 %arg2 to i64
14 br label %bb7
15
16 bb5: ; preds = %bb7
17 br label %bb6
18
19 bb6: ; preds = %bb5, %bb
20 ret void
21
22 bb7: ; preds = %bb7, %bb3
23 %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ]
; Shift pair under test: shl by 33, then ashr by 0.
24 %tmp9 = shl i64 %tmp8, 33
25 %tmp10 = ashr exact i64 %tmp9, 0
26 %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
27 %tmp12 = load i32, i32* %tmp11, align 4
28 %tmp13 = sub nsw i32 %tmp12, %arg1
29 store i32 %tmp13, i32* %tmp11, align 4
; Second access uses the same index with its low bit set.
30 %tmp14 = or i64 %tmp10, 1
31 %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
32 %tmp16 = load i32, i32* %tmp15, align 4
33 %tmp17 = mul nsw i32 %tmp16, %arg1
34 store i32 %tmp17, i32* %tmp15, align 4
; Induction step and exit test against zext(%arg2).
35 %tmp18 = add nuw nsw i64 %tmp8, 1
36 %tmp19 = icmp eq i64 %tmp18, %tmp4
37 br i1 %tmp19, label %bb5, label %bb7
38 }