llvm.org GIT mirror: llvm, commit 894ab0e
[InstCombine] Dropping redundant masking before left-shift [2/5] (PR42563)

Summary:
If we have a pattern that leaves only some low bits set and then left-shifts
those bits, and none of the bits that survive the final shift are modified by
the mask, we can omit the mask.

There are many variants of this pattern; here we handle:
  c. `(x & (-1 >> MaskShAmt)) << ShiftShAmt`
All these patterns can be simplified to just:
  `x << ShiftShAmt`
iff:
  c. `(ShiftShAmt - MaskShAmt) s>= 0` (i.e. `ShiftShAmt u>= MaskShAmt`)

Alive proofs:
  c: https://rise4fun.com/Alive/RgJh

For now, start with patterns where both shift amounts are variable and there is
only a trivial constant "offset" between them, since I believe that is both the
simplest to handle and the most common case. There are likely other variants
where ValueTracking/ConstantRange could be used to handle more cases.

https://bugs.llvm.org/show_bug.cgi?id=42563

Differential Revision: https://reviews.llvm.org/D64517

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@366537 91177308-0d34-0410-b5e6-96231b3b80d8

Roman Lebedev
2 changed file(s) with 42 addition(s) and 26 deletion(s).
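As a quick cross-check of the new "c" condition outside of the Alive proof, here is a small standalone C++ brute-force sweep. It is not part of the patch; it simply checks the identity exhaustively on 8-bit values, which covers every mask/shift combination at that width:

  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  // Exhaustively verify pattern c on i8: whenever ShiftShAmt u>= MaskShAmt,
  //   (x & (-1 >> MaskShAmt)) << ShiftShAmt  ==  x << ShiftShAmt
  // because the bits cleared by the mask are exactly the bits that the final
  // left-shift pushes out of the 8-bit result anyway.
  int main() {
    for (unsigned X = 0; X < 256; ++X)
      for (unsigned MaskShAmt = 0; MaskShAmt < 8; ++MaskShAmt)
        for (unsigned ShiftShAmt = MaskShAmt; ShiftShAmt < 8; ++ShiftShAmt) {
          uint8_t Mask = uint8_t(0xFFu >> MaskShAmt);         // (-1 >> MaskShAmt)
          uint8_t Masked = uint8_t((X & Mask) << ShiftShAmt); // masked form
          uint8_t Unmasked = uint8_t(X << ShiftShAmt);        // proposed replacement
          assert(Masked == Unmasked && "fold would miscompile");
        }
    std::puts("pattern c holds for every i8 input with ShiftShAmt u>= MaskShAmt");
    return 0;
  }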
 // There are many variants to this pattern:
 //   a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
 //   b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+//   c) (x & (-1 >> MaskShAmt)) << ShiftShAmt
 // All these patterns can be simplified to just:
 //   x << ShiftShAmt
 // iff:
 //   a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+//   c) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
 static Instruction *
 dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
                                      const SimplifyQuery &SQ) {
[...]
   auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
   // (~(-1 << maskNbits))
   auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+  // (-1 >> MaskShAmt)
+  auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));

   Value *X;
-  if (!match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X))))
-    return nullptr;
-
-  // Can we simplify (MaskShAmt+ShiftShAmt) ?
-  Value *SumOfShAmts =
-      SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
-                      SQ.getWithInstruction(OuterShift));
-  if (!SumOfShAmts)
-    return nullptr; // Did not simplify.
-  // Is the total shift amount *not* smaller than the bit width?
-  // FIXME: could also rely on ConstantRange.
-  unsigned BitWidth = X->getType()->getScalarSizeInBits();
-  if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
-                                             APInt(BitWidth, BitWidth))))
-    return nullptr;
-  // All good, we can do this fold.
+  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+    // Can we simplify (MaskShAmt+ShiftShAmt) ?
+    Value *SumOfShAmts =
+        SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+                        SQ.getWithInstruction(OuterShift));
+    if (!SumOfShAmts)
+      return nullptr; // Did not simplify.
+    // Is the total shift amount *not* smaller than the bit width?
+    // FIXME: could also rely on ConstantRange.
+    unsigned BitWidth = X->getType()->getScalarSizeInBits();
+    if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
+                                               APInt(BitWidth, BitWidth))))
+      return nullptr;
+    // All good, we can do this fold.
+  } else if (match(Masked, m_c_And(MaskC, m_Value(X)))) {
+    // Can we simplify (ShiftShAmt-MaskShAmt) ?
+    Value *ShAmtsDiff =
+        SimplifySubInst(ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+                        SQ.getWithInstruction(OuterShift));
+    if (!ShAmtsDiff)
+      return nullptr; // Did not simplify.
+    // Is the difference non-negative? (is ShiftShAmt u>= MaskShAmt ?)
+    // FIXME: could also rely on ConstantRange.
+    if (!match(ShAmtsDiff, m_NonNegative()))
+      return nullptr;
+    // All good, we can do this fold.
+  } else
+    return nullptr; // Don't know anything about this pattern.

   // No 'NUW'/'NSW'!
   // We no longer know that we won't shift-out non-0 bits.
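One subtlety before the test diffs: the "No 'NUW'/'NSW'" comment above is load-bearing. The masked shl provably never shifted out set bits, but the unmasked replacement may, so any nuw/nsw flags on the original shift must not be carried over; the shl nuw / shl nsw test updates below verify exactly that. A small standalone C++ illustration (again not part of the patch, i8 for brevity):

  #include <cstdint>
  #include <cstdio>

  // With X = 0xAB and MaskShAmt = ShiftShAmt = 4 the two forms produce the
  // same i8 value (so the fold itself is fine), but only the masked form
  // avoids shifting set bits out; the unmasked form loses the top nibble,
  // so copying nuw/nsw onto the rewritten shl would assert a fact that no
  // longer holds.
  int main() {
    const uint8_t X = 0xAB;
    const unsigned MaskShAmt = 4, ShiftShAmt = 4;
    uint8_t Mask = uint8_t(0xFFu >> MaskShAmt);          // 0x0F
    uint8_t Masked = uint8_t((X & Mask) << ShiftShAmt);  // 0x0B << 4 = 0xB0, nothing shifted out
    uint8_t Unmasked = uint8_t(X << ShiftShAmt);         // 0xAB0 truncates to 0xB0, 0xA shifted out
    std::printf("masked = 0x%02X, unmasked = 0x%02X\n", Masked, Unmasked);
    return 0;
  }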
 ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT: ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
[...]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
 ; CHECK-NEXT: call void @use32(i32 [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl i32 [[X]], [[T2]]
 ; CHECK-NEXT: ret i32 [[T3]]
 ;
   %t0 = lshr i32 -1, %nbits
[...]
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT: ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> , %nbits
[...]
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT: ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> , %nbits
[...]
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT: ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> , %nbits
[...]
 ; CHECK-NEXT: [[T1:%.*]] = and i32 [[X]], [[T0]]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT: ret i32 [[T2]]
 ;
   %x = call i32 @gen32()
[...]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
 ; CHECK-NEXT: call void @use32(i32 [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T2]], [[NBITS0]]
+; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[NBITS0]]
 ; CHECK-NEXT: ret i32 [[T3]]
 ;
   %t0 = lshr i32 -1, %nbits0
[...]
 ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nuw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT: ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
[...]
 ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT: ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
[...]
 ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT: call void @use32(i32 [[T0]])
 ; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nuw nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT: ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits