llvm.org GIT mirror, llvm / commit 9cc691f

AVX-512, X86: Added lowering of shift operations for SKX. The other changes in LowerShift() are not functional; they only make the code more convenient. So the functional changes are for SKX only.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@237129 91177308-0d34-0410-b5e6-96231b3b80d8
Elena Demikhovsky (5 years ago)

2 changed files with 95 additions and 102 deletions.
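For context, SKX is LLVM's name for the Skylake server AVX-512 feature set (AVX512F plus the VL, BW and DQ extensions). The practical payoff of this patch is that shifts which previously had no native instruction below 512 bits, most notably the 64-bit arithmetic right shift, now select directly. A small illustration, not part of the patch, compiled for an SKX-class target with -mavx512f -mavx512vl:

    #include <immintrin.h>

    // Before AVX-512VL there was no 64-bit arithmetic right shift on
    // XMM/YMM at all; with VLX this selects VPSRAQ $7, %ymm, %ymm.
    __m256i ashr_v4i64_by_7(__m256i v) {
      return _mm256_srai_epi64(v, 7);
    }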
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1506 +1506 @@
     setOperationAction(ISD::AND, MVT::v4i32, Legal);
     setOperationAction(ISD::OR,  MVT::v4i32, Legal);
     setOperationAction(ISD::XOR, MVT::v4i32, Legal);
+    setOperationAction(ISD::SRA, MVT::v2i64, Custom);
+    setOperationAction(ISD::SRA, MVT::v4i64, Custom);
   }
 
   // We want to custom lower some of our intrinsics.
@@ -16327 +16329 @@
   return DAG.getMergeValues(Ops, dl);
 }
 
+// Return true if the required (according to Opcode) shift-imm form is
+// natively supported by the Subtarget.
+static bool SupportedVectorShiftWithImm(MVT VT, const X86Subtarget *Subtarget,
+                                        unsigned Opcode) {
+  if (VT.getScalarSizeInBits() < 16)
+    return false;
+
+  if (VT.is512BitVector() &&
+      (VT.getScalarSizeInBits() > 16 || Subtarget->hasBWI()))
+    return true;
+
+  bool LShift = VT.is128BitVector() ||
+                (VT.is256BitVector() && Subtarget->hasInt256());
+
+  bool AShift = LShift && (Subtarget->hasVLX() ||
+                           (VT != MVT::v2i64 && VT != MVT::v4i64));
+  return (Opcode == ISD::SRA) ? AShift : LShift;
+}
+
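To make the predicate concrete, here is a hypothetical stand-alone mirror of its logic with the subtarget queries folded into plain flags (the names Features and shiftImmSupported are invented for illustration). A v4i64 SRA by immediate, for example, is rejected on plain AVX2 but accepted once VLX is present, which is exactly what the new setOperationAction(ISD::SRA, MVT::v4i64, Custom) line above feeds into:

    #include <cassert>

    // Hypothetical stand-alone mirror of SupportedVectorShiftWithImm.
    struct Features { bool BWI, VLX, Int256; };

    bool shiftImmSupported(unsigned EltBits, unsigned VecBits, bool IsSRA,
                           Features F) {
      if (EltBits < 16) return false;                    // no vXi8 shift-imm
      if (VecBits == 512 && (EltBits > 16 || F.BWI)) return true;
      bool LShift = VecBits == 128 || (VecBits == 256 && F.Int256);
      // 64-bit arithmetic shifts (PSRAQ) need VLX below 512 bits.
      bool AShift = LShift && (F.VLX || EltBits != 64);
      return IsSRA ? AShift : LShift;
    }

    int main() {
      Features avx2{false, false, true}, skx{true, true, true};
      assert(!shiftImmSupported(64, 256, /*IsSRA=*/true, avx2)); // no v4i64 SRA
      assert( shiftImmSupported(64, 256, /*IsSRA=*/true, skx));  // VPSRAQ ymm
    }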
+// The shift amount is a variable, but it is the same for all vector lanes.
+// These instructions are defined together with shift-immediate.
+static
+bool SupportedVectorShiftWithBaseAmnt(MVT VT, const X86Subtarget *Subtarget,
+                                      unsigned Opcode) {
+  return SupportedVectorShiftWithImm(VT, Subtarget, Opcode);
+}
+
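The "base amount" form the comment refers to is the classic shift where one run-time count, taken from the low bits of an XMM register, is applied to every lane; it is encoded alongside the immediate form, hence the shared predicate. An illustration, not part of the patch, using the AVX-512F+VL intrinsic:

    #include <immintrin.h>

    // One count in an XMM register shifts all four 64-bit lanes:
    // VPSRAQ %xmm, %ymm, %ymm. Illustration only.
    __m256i ashr_all_lanes(__m256i v, __m128i count) {
      return _mm256_sra_epi64(v, count);
    }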
+// Return true if the required (according to Opcode) variable-shift form is
+// natively supported by the Subtarget.
+static bool SupportedVectorVarShift(MVT VT, const X86Subtarget *Subtarget,
+                                    unsigned Opcode) {
+
+  if (!Subtarget->hasInt256() || VT.getScalarSizeInBits() < 16)
+    return false;
+
+  // vXi16 supported only on AVX-512, BWI
+  if (VT.getScalarSizeInBits() == 16 && !Subtarget->hasBWI())
+    return false;
+
+  if (VT.is512BitVector() || Subtarget->hasVLX())
+    return true;
+
+  bool LShift = VT.is128BitVector() || VT.is256BitVector();
+  bool AShift = LShift && VT != MVT::v2i64 && VT != MVT::v4i64;
+  return (Opcode == ISD::SRA) ? AShift : LShift;
+}
+
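This predicate gates the per-lane variable shifts. AVX2 already covers the 32/64-bit logical cases and 32-bit arithmetic; SKX fills the two gaps the test file below exercises: 64-bit arithmetic (VPSRAVQ, which needs VLX outside 512-bit vectors) and 16-bit lanes (VPSRAVW, which needs BWI). An illustration with intrinsics, not from the patch:

    #include <immintrin.h>

    // Per-lane variable shifts newly reachable on SKX. Illustration only.
    __m256i var_ashr_q(__m256i v, __m256i amt) {  // VPSRAVQ: AVX-512F+VL
      return _mm256_srav_epi64(v, amt);
    }
    __m128i var_ashr_w(__m128i v, __m128i amt) {  // VPSRAVW: AVX-512BW+VL
      return _mm_srav_epi16(v, amt);
    }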
@@ -16330 +16379 @@
 static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
                                          const X86Subtarget *Subtarget) {
   MVT VT = Op.getSimpleValueType();
   SDLoc dl(Op);
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
 
+  unsigned X86Opc = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+
   // Optimize shl/srl/sra with constant shift amount.
   if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
     if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
       uint64_t ShiftAmt = ShiftConst->getZExtValue();
 
-      if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
-          (Subtarget->hasInt256() &&
-           (VT == MVT::v4i64 || VT == MVT::v8i32 || VT == MVT::v16i16)) ||
-          (Subtarget->hasAVX512() &&
-           (VT == MVT::v8i64 || VT == MVT::v16i32))) {
-        if (Op.getOpcode() == ISD::SHL)
-          return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
-                                            DAG);
-        if (Op.getOpcode() == ISD::SRL)
-          return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
-                                            DAG);
-        if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
-          return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
-                                            DAG);
-      }
+      if (SupportedVectorShiftWithImm(VT, Subtarget, Op.getOpcode()))
+        return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
 
       if (VT == MVT::v16i8 || (Subtarget->hasInt256() && VT == MVT::v32i8)) {
         unsigned NumElts = VT.getVectorNumElements();
@@ -16434 +16473 @@ LowerScalarImmediateShift
         if (ShAmt != ShiftAmt)
           return SDValue();
       }
-      switch (Op.getOpcode()) {
-      default:
-        llvm_unreachable("Unknown shift opcode!");
-      case ISD::SHL:
-        return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
-                                          DAG);
-      case ISD::SRL:
-        return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
-                                          DAG);
-      case ISD::SRA:
-        return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
-                                          DAG);
-      }
+      return getTargetVShiftByConstNode(X86Opc, dl, VT, R, ShiftAmt, DAG);
     }
 
   return SDValue();
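When the amount operand is a constant splat, all three opcodes now funnel through one getTargetVShiftByConstNode call guarded by SupportedVectorShiftWithImm. At the instruction level that is the imm8 shift form; for example (illustration only, not from the patch), a uniform arithmetic shift of v8i64 on AVX-512F:

    #include <immintrin.h>

    // Constant-splat shift amount: the whole vector is shifted by one
    // imm8 (VPSRAQ $3, %zmm, %zmm). AVX-512F; illustration only.
    __m512i ashr_v8i64_by_3(__m512i v) {
      return _mm512_srai_epi64(v, 3);
    }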
@@ -16459 +16486 @@ LowerScalarVariableShift
   SDValue R = Op.getOperand(0);
   SDValue Amt = Op.getOperand(1);
 
-  if ((VT == MVT::v2i64 && Op.getOpcode() != ISD::SRA) ||
-      VT == MVT::v4i32 || VT == MVT::v8i16 ||
-      (Subtarget->hasInt256() &&
-       ((VT == MVT::v4i64 && Op.getOpcode() != ISD::SRA) ||
-        VT == MVT::v8i32 || VT == MVT::v16i16)) ||
-      (Subtarget->hasAVX512() && (VT == MVT::v8i64 || VT == MVT::v16i32))) {
+  unsigned X86OpcI = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHLI :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRLI : X86ISD::VSRAI;
+
+  unsigned X86OpcV = (Op.getOpcode() == ISD::SHL) ? X86ISD::VSHL :
+    (Op.getOpcode() == ISD::SRL) ? X86ISD::VSRL : X86ISD::VSRA;
+
+  if (SupportedVectorShiftWithBaseAmnt(VT, Subtarget, Op.getOpcode())) {
     SDValue BaseShAmt;
     EVT EltVT = VT.getVectorElementType();
 
@@ -16508 +16536 @@ LowerScalarVariableShift
       else if (EltVT.bitsLT(MVT::i32))
         BaseShAmt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, BaseShAmt);
 
-      switch (Op.getOpcode()) {
-      default:
-        llvm_unreachable("Unknown shift opcode!");
-      case ISD::SHL:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v2i64:
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v4i64:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSHLI, dl, VT, R, BaseShAmt, DAG);
-        }
-      case ISD::SRA:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, R, BaseShAmt, DAG);
-        }
-      case ISD::SRL:
-        switch (VT.SimpleTy) {
-        default: return SDValue();
-        case MVT::v2i64:
-        case MVT::v4i32:
-        case MVT::v8i16:
-        case MVT::v4i64:
-        case MVT::v8i32:
-        case MVT::v16i16:
-        case MVT::v16i32:
-        case MVT::v8i64:
-          return getTargetVShiftNode(X86ISD::VSRLI, dl, VT, R, BaseShAmt, DAG);
-        }
-      }
+      return getTargetVShiftNode(X86OpcI, dl, VT, R, BaseShAmt, DAG);
     }
   }
 
@@ -16567 +16555 @@ LowerScalarVariableShift
       if (Vals[j] != Amt.getOperand(i + j))
         return SDValue();
   }
-    switch (Op.getOpcode()) {
-    default:
-      llvm_unreachable("Unknown shift opcode!");
-    case ISD::SHL:
-      return DAG.getNode(X86ISD::VSHL, dl, VT, R, Op.getOperand(1));
-    case ISD::SRL:
-      return DAG.getNode(X86ISD::VSRL, dl, VT, R, Op.getOperand(1));
-    case ISD::SRA:
-      return DAG.getNode(X86ISD::VSRA, dl, VT, R, Op.getOperand(1));
-    }
-  }
-
+    return DAG.getNode(X86OpcV, dl, VT, R, Op.getOperand(1));
+  }
   return SDValue();
 }
 
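LowerScalarVariableShift handles amounts that are variable but uniform across lanes: it extracts a base scalar amount and emits the base-amount node (the X86OpcI path above), or, when the splat survives as a build vector, the vector-count node (X86OpcV). In source terms the matched pattern looks like this (illustration only, plain SSE2):

    #include <immintrin.h>

    // All lanes shifted by the same run-time count n: the count is
    // moved into an XMM register and one PSRAD covers every lane.
    __m128i ashr_v4i32_all(__m128i v, int n) {
      return _mm_sra_epi32(v, _mm_cvtsi32_si128(n));
    }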
@@ -16598 +16576 @@ LowerShift
   if (SDValue V = LowerScalarVariableShift(Op, DAG, Subtarget))
     return V;
 
-  if (Subtarget->hasAVX512() && (VT == MVT::v16i32 || VT == MVT::v8i64))
+  if (SupportedVectorVarShift(VT, Subtarget, Op.getOpcode()))
     return Op;
-
-  // AVX2 has VPSLLV/VPSRAV/VPSRLV.
-  if (Subtarget->hasInt256()) {
-    if (Op.getOpcode() == ISD::SRL &&
-        (VT == MVT::v2i64 || VT == MVT::v4i32 ||
-         VT == MVT::v4i64 || VT == MVT::v8i32))
-      return Op;
-    if (Op.getOpcode() == ISD::SHL &&
-        (VT == MVT::v2i64 || VT == MVT::v4i32 ||
-         VT == MVT::v4i64 || VT == MVT::v8i32))
-      return Op;
-    if (Op.getOpcode() == ISD::SRA && (VT == MVT::v4i32 || VT == MVT::v8i32))
-      return Op;
-  }
 
   // 2i64 vector logical shifts can efficiently avoid scalarization - do the
   // shifts per-lane and then shuffle the partial results back together.
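A sketch of that strategy outside the backend (illustration only, SSE2 intrinsics): for a v2i64 shift by two different constants, do both whole-vector shifts and recombine one lane from each.

    #include <immintrin.h>

    // v2i64 SHL by <1, 5> without scalarizing: shift the whole vector
    // by each amount, then keep lane 0 of the first result and lane 1
    // of the second (per-lane shifts plus a shuffle).
    __m128i shl_by_1_and_5(__m128i v) {
      __m128i lo = _mm_slli_epi64(v, 1);
      __m128i hi = _mm_slli_epi64(v, 5);
      return _mm_castpd_si128(
          _mm_shuffle_pd(_mm_castsi128_pd(lo), _mm_castsi128_pd(hi), 2));
    }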
--- a/test/CodeGen/X86/avx512-shift.ll
+++ b/test/CodeGen/X86/avx512-shift.ll
@@ -0 +0 @@
+;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=knl | FileCheck %s
+;RUN: llc < %s -mtriple=x86_64-apple-darwin -mcpu=skx | FileCheck --check-prefix=SKX %s
 
 ;CHECK-LABEL: shift_16_i32
 ;CHECK: vpsrld
@@ -21 +22 @@
   %c = shl <8 x i64> %b,
   %d = ashr <8 x i64> %c,
   ret <8 x i64> %d
+}
+
+;SKX-LABEL: shift_4_i64
+;SKX: vpsrlq
+;SKX: vpsllq
+;SKX: vpsraq
+;SKX: ret
+define <4 x i64> @shift_4_i64(<4 x i64> %a) {
+  %b = lshr <4 x i64> %a,
+  %c = shl <4 x i64> %b,
+  %d = ashr <4 x i64> %c,
+  ret <4 x i64> %d
 }
 
 ; CHECK-LABEL: variable_shl4
@@ -71 +84 @@
   ret <8 x i64> %k
 }
 
+; SKX-LABEL: variable_sra3
+; SKX: vpsravq %ymm
+; SKX: ret
+define <4 x i64> @variable_sra3(<4 x i64> %x, <4 x i64> %y) {
+  %k = ashr <4 x i64> %x, %y
+  ret <4 x i64> %k
+}
+
+; SKX-LABEL: variable_sra4
+; SKX: vpsravw %xmm
+; SKX: ret
+define <8 x i16> @variable_sra4(<8 x i16> %x, <8 x i16> %y) {
+  %k = ashr <8 x i16> %x, %y
+  ret <8 x i16> %k
+}
+
 ; CHECK-LABEL: variable_sra01_load
 ; CHECK: vpsravd (%
 ; CHECK: ret