llvm.org GIT mirror llvm / c69190a
[DAGCombiner] Support (shl (ext (shl x, c1)), c2) -> 0 non-uniform folds.

Use matchBinaryPredicate instead of isConstOrConstSplat to let us handle non-uniform shift cases.

This requires us to tweak matchBinaryPredicate to allow it to (optionally) handle constants with different type widths.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@363792 91177308-0d34-0410-b5e6-96231b3b80d8

Simon Pilgrim, 26 days ago
4 changed files with 35 additions and 35 deletions.
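As a rough illustration of the fold this commit enables (a standalone sketch using plain integers rather than APInt; the helper name and the example shift amounts are invented for illustration): a lane of (shl (ext (shl x, c1)), c2) is known to be zero when the outer shift covers the bits added by the extension and the two shifts together push every original bit out, which is what the new MatchOutOfRange predicate tests per element:

#include <cstdint>
#include <cstdio>

// Per-lane test mirroring MatchOutOfRange: true when the lane of
// (shl (ext (shl x, c1)), c2) must be zero.
static bool laneFoldsToZero(uint64_t C1, uint64_t C2, unsigned InnerBits,
                            unsigned OuterBits) {
  return C2 >= OuterBits - InnerBits && // outer shift >= bits added by the ext
         C1 + C2 >= OuterBits;          // combined shifts discard every bit of x
}

int main() {
  const unsigned InnerBits = 16, OuterBits = 32; // e.g. <8 x i16> extended to <8 x i32>
  const uint64_t C1[4] = {4, 5, 6, 7};           // non-uniform inner shift amounts
  const uint64_t C2[4] = {28, 27, 26, 25};       // non-uniform outer shift amounts
  for (int I = 0; I != 4; ++I)
    printf("lane %d: %s\n", I,
           laneFoldsToZero(C1[I], C2[I], InnerBits, OuterBits) ? "zero" : "unknown");
  // Every lane prints "zero", so the whole vector shl could be replaced by a
  // constant zero even though the shift amounts differ per lane - the
  // non-uniform case that previously required a uniform (splat) constant.
  return 0;
}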
 /// Attempt to match a binary predicate against a pair of scalar/splat
 /// constants or every element of a pair of constant BUILD_VECTORs.
 /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
+/// If AllowTypeMismatch is true then RetType + ArgTypes don't need to match.
 bool matchBinaryPredicate(
     SDValue LHS, SDValue RHS,
     std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
-    bool AllowUndefs = false);
+    bool AllowUndefs = false, bool AllowTypeMismatch = false);
 } // end namespace ISD

 } // end namespace llvm
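The new AllowTypeMismatch flag only relaxes the type checks; the per-element matching itself is unchanged. A toy analogue of the contract (not the SelectionDAG API; Konst and matchEachPair are made-up stand-ins):

#include <cstdint>
#include <functional>
#include <vector>

struct Konst { uint64_t Value; unsigned Bits; }; // stand-in for a ConstantSDNode element

static bool matchEachPair(const std::vector<Konst> &LHS,
                          const std::vector<Konst> &RHS,
                          std::function<bool(const Konst &, const Konst &)> Match,
                          bool AllowUndefs = false,
                          bool AllowTypeMismatch = false) {
  (void)AllowUndefs; // undef handling omitted in this sketch
  if (LHS.size() != RHS.size())
    return false;
  for (size_t I = 0; I != LHS.size(); ++I) {
    // Previously a width (type) mismatch always failed; now the caller may
    // opt in to running the predicate anyway.
    if (!AllowTypeMismatch && LHS[I].Bits != RHS[I].Bits)
      return false;
    if (!Match(LHS[I], RHS[I]))
      return false;
  }
  return true;
}

This matters for the DAGCombiner change below: the inner shift amounts are constants of the narrow pre-extension type while N1's elements are of the wide result type, so the old equal-type requirement would have rejected the pair outright.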
   // that are shifted out by the inner shift in the first form. This means
   // the outer shift size must be >= the number of bits added by the ext.
   // As a corollary, we don't care what kind of ext it is.
-  if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND ||
-              N0.getOpcode() == ISD::ANY_EXTEND ||
-              N0.getOpcode() == ISD::SIGN_EXTEND) &&
+  if ((N0.getOpcode() == ISD::ZERO_EXTEND ||
+       N0.getOpcode() == ISD::ANY_EXTEND ||
+       N0.getOpcode() == ISD::SIGN_EXTEND) &&
       N0.getOperand(0).getOpcode() == ISD::SHL) {
     SDValue N0Op0 = N0.getOperand(0);
-    if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) {
+    SDValue InnerShiftAmt = N0Op0.getOperand(1);
+    EVT InnerVT = N0Op0.getValueType();
+    uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits();
+
+    auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS,
+                                                         ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return c2.uge(OpSizeInBits - InnerBitwidth) &&
+             (c1 + c2).uge(OpSizeInBits);
+    };
+    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange,
+                                  /*AllowUndefs*/ false,
+                                  /*AllowTypeMismatch*/ true))
+      return DAG.getConstant(0, SDLoc(N), VT);
+
+    ConstantSDNode *N0Op0C1 = isConstOrConstSplat(InnerShiftAmt);
+    if (N1C && N0Op0C1) {
       APInt c1 = N0Op0C1->getAPIntValue();
       APInt c2 = N1C->getAPIntValue();
       zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);

-      EVT InnerShiftVT = N0Op0.getValueType();
-      uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits();
-      if (c2.uge(OpSizeInBits - InnerShiftSize)) {
+      if (c2.uge(OpSizeInBits - InnerBitwidth)) {
         SDLoc DL(N0);
         APInt Sum = c1 + c2;
         if (Sum.uge(OpSizeInBits))
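One detail worth noting in MatchOutOfRange is the 1 /* Overflow Bit */ passed to zeroExtendToMatch: assuming that helper zero-extends both amounts to a common width plus the extra bit (as its comment suggests), the sum c1 + c2 cannot wrap before it is compared against OpSizeInBits. A plain-integer illustration of the wrap it guards against (the 8-bit values are contrived):

#include <cstdint>
#include <cstdio>

int main() {
  uint8_t C1 = 250, C2 = 10;              // 8-bit shift-amount constants (bogus but representable)
  uint8_t NarrowSum = (uint8_t)(C1 + C2); // wraps to 4 in 8 bits
  unsigned WideSum = (unsigned)C1 + C2;   // 260, the true sum after widening
  printf("narrow sum: %u, widened sum: %u\n", (unsigned)NarrowSum, WideSum);
  // A comparison like NarrowSum >= 32 would wrongly conclude the combined
  // shift keeps some bits in range; comparing the widened sum keeps the
  // uge(OpSizeInBits) test exact.
  return 0;
}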
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,
     std::function<bool(ConstantSDNode *, ConstantSDNode *)> Match,
-    bool AllowUndefs) {
-  if (LHS.getValueType() != RHS.getValueType())
+    bool AllowUndefs, bool AllowTypeMismatch) {
+  if (!AllowTypeMismatch && LHS.getValueType() != RHS.getValueType())
     return false;

   // TODO: Add support for scalar UNDEF cases?
...
     auto *RHSCst = dyn_cast<ConstantSDNode>(RHSOp);
     if ((!LHSCst && !LHSUndef) || (!RHSCst && !RHSUndef))
       return false;
-    if (LHSOp.getValueType() != SVT ||
-        LHSOp.getValueType() != RHSOp.getValueType())
+    if (!AllowTypeMismatch && (LHSOp.getValueType() != SVT ||
+                               LHSOp.getValueType() != RHSOp.getValueType()))
       return false;
     if (!Match(LHSCst, RHSCst))
       return false;
   ret <8 x i32> %3
 }

-; TODO - this should fold to ZERO.
 define <8 x i32> @combine_vec_shl_ext_shl1(<8 x i16> %x) {
-; SSE2-LABEL: combine_vec_shl_ext_shl1:
-; SSE2: # %bb.0:
-; SSE2-NEXT: pmullw {{.*}}(%rip), %xmm0
-; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3]
-; SSE2-NEXT: pslld $30, %xmm0
-; SSE2-NEXT: xorpd %xmm1, %xmm1
-; SSE2-NEXT: movsd {{.*#+}} xmm0 = xmm1[0],xmm0[1]
-; SSE2-NEXT: movsd {{.*#+}} xmm1 = xmm1[0,1]
-; SSE2-NEXT: retq
-;
-; SSE41-LABEL: combine_vec_shl_ext_shl1:
-; SSE41: # %bb.0:
-; SSE41-NEXT: pmullw {{.*}}(%rip), %xmm0
-; SSE41-NEXT: pmovsxwd %xmm0, %xmm0
-; SSE41-NEXT: pslld $30, %xmm0
-; SSE41-NEXT: pxor %xmm1, %xmm1
-; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3],xmm0[4,5,6,7]
-; SSE41-NEXT: pxor %xmm1, %xmm1
-; SSE41-NEXT: retq
+; SSE-LABEL: combine_vec_shl_ext_shl1:
+; SSE: # %bb.0:
+; SSE-NEXT: xorps %xmm0, %xmm0
+; SSE-NEXT: xorps %xmm1, %xmm1
+; SSE-NEXT: retq
 ;
 ; AVX-LABEL: combine_vec_shl_ext_shl1:
 ; AVX: # %bb.0:
-; AVX-NEXT: vpmullw {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT: vpmovsxwd %xmm0, %ymm0
-; AVX-NEXT: vpsllvd {{.*}}(%rip), %ymm0, %ymm0
+; AVX-NEXT: vxorps %xmm0, %xmm0, %xmm0
 ; AVX-NEXT: retq
   %1 = shl <8 x i16> %x,
   %2 = sext <8 x i16> %1 to <8 x i32>