llvm.org GIT mirror llvm / 863b0d4
[NVPTX] Add support for [SHL,SRA,SRL]_PARTS git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211936 91177308-0d34-0410-b5e6-96231b3b80d8 Justin Holewinski 6 years ago
4 changed file(s) with 208 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
150150 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
151151 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
152152 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
153
154 setOperationAction(ISD::SHL_PARTS, MVT::i32 , Custom);
155 setOperationAction(ISD::SRA_PARTS, MVT::i32 , Custom);
156 setOperationAction(ISD::SRL_PARTS, MVT::i32 , Custom);
157 setOperationAction(ISD::SHL_PARTS, MVT::i64 , Custom);
158 setOperationAction(ISD::SRA_PARTS, MVT::i64 , Custom);
159 setOperationAction(ISD::SRL_PARTS, MVT::i64 , Custom);
153160
154161 if (nvptxSubtarget.hasROT64()) {
155162 setOperationAction(ISD::ROTL, MVT::i64, Legal);
344351 return "NVPTXISD::FUN_SHFL_CLAMP";
345352 case NVPTXISD::FUN_SHFR_CLAMP:
346353 return "NVPTXISD::FUN_SHFR_CLAMP";
354 case NVPTXISD::IMAD:
355 return "NVPTXISD::IMAD";
356 case NVPTXISD::MUL_WIDE_SIGNED:
357 return "NVPTXISD::MUL_WIDE_SIGNED";
358 case NVPTXISD::MUL_WIDE_UNSIGNED:
359 return "NVPTXISD::MUL_WIDE_UNSIGNED";
347360 case NVPTXISD::Tex1DFloatI32: return "NVPTXISD::Tex1DFloatI32";
348361 case NVPTXISD::Tex1DFloatFloat: return "NVPTXISD::Tex1DFloatFloat";
349362 case NVPTXISD::Tex1DFloatFloatLevel:
12781291 return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), Ops);
12791292 }
12801293
1294 /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
1295 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
1296 /// amount, or
1297 /// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
1298 /// amount.
1299 SDValue NVPTXTargetLowering::LowerShiftRightParts(SDValue Op,
1300 SelectionDAG &DAG) const {
1301 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
1302 assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
1303
1304 EVT VT = Op.getValueType();
1305 unsigned VTBits = VT.getSizeInBits();
1306 SDLoc dl(Op);
1307 SDValue ShOpLo = Op.getOperand(0);
1308 SDValue ShOpHi = Op.getOperand(1);
1309 SDValue ShAmt = Op.getOperand(2);
1310 unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
1311
1312 if (VTBits == 32 && nvptxSubtarget.getSmVersion() >= 35) {
1313
1314 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
1315 // {dHi, dLo} = {aHi, aLo} >> Amt
1316 // dHi = aHi >> Amt
1317 // dLo = shf.r.clamp aLo, aHi, Amt
1318
1319 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
1320 SDValue Lo = DAG.getNode(NVPTXISD::FUN_SHFR_CLAMP, dl, VT, ShOpLo, ShOpHi,
1321 ShAmt);
1322
1323 SDValue Ops[2] = { Lo, Hi };
1324 return DAG.getMergeValues(Ops, dl);
1325 }
1326 else {
1327
1328 // {dHi, dLo} = {aHi, aLo} >> Amt
1329 // - if (Amt>=size) then
1330 // dLo = aHi >> (Amt-size)
1331 // dHi = aHi >> Amt (this is either all 0 or all 1)
1332 // else
1333 // dLo = (aLo >>logic Amt) | (aHi << (size-Amt))
1334 // dHi = aHi >> Amt
1335
1336 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
1337 DAG.getConstant(VTBits, MVT::i32), ShAmt);
1338 SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
1339 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
1340 DAG.getConstant(VTBits, MVT::i32));
1341 SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
1342 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
1343 SDValue TrueVal = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
1344
1345 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
1346 DAG.getConstant(VTBits, MVT::i32), ISD::SETGE);
1347 SDValue Hi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
1348 SDValue Lo = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
1349
1350 SDValue Ops[2] = { Lo, Hi };
1351 return DAG.getMergeValues(Ops, dl);
1352 }
1353 }
1354
1355 /// LowerShiftLeftParts - Lower SHL_PARTS, which
1356 /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
1357 /// amount, or
1358 /// 2) returns two i64 values and take a 2 x i64 value to shift plus a shift
1359 /// amount.
1360 SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
1361 SelectionDAG &DAG) const {
1362 assert(Op.getNumOperands() == 3 && "Not a double-shift!");
1363 assert(Op.getOpcode() == ISD::SHL_PARTS);
1364
1365 EVT VT = Op.getValueType();
1366 unsigned VTBits = VT.getSizeInBits();
1367 SDLoc dl(Op);
1368 SDValue ShOpLo = Op.getOperand(0);
1369 SDValue ShOpHi = Op.getOperand(1);
1370 SDValue ShAmt = Op.getOperand(2);
1371
1372 if (VTBits == 32 && nvptxSubtarget.getSmVersion() >= 35) {
1373
1374 // For 32bit and sm35, we can use the funnel shift 'shf' instruction.
1375 // {dHi, dLo} = {aHi, aLo} << Amt
1376 // dHi = shf.l.clamp aLo, aHi, Amt
1377 // dLo = aLo << Amt
1378
1379 SDValue Hi = DAG.getNode(NVPTXISD::FUN_SHFL_CLAMP, dl, VT, ShOpLo, ShOpHi,
1380 ShAmt);
1381 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
1382
1383 SDValue Ops[2] = { Lo, Hi };
1384 return DAG.getMergeValues(Ops, dl);
1385 }
1386 else {
1387
1388 // {dHi, dLo} = {aHi, aLo} << Amt
1389 // - if (Amt>=size) then
1390 // dLo = aLo << Amt (all 0)
1391 // dLo = aLo << (Amt-size)
1392 // else
1393 // dLo = aLo << Amt
1394 // dHi = (aHi << Amt) | (aLo >> (size-Amt))
1395
1396 SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32,
1397 DAG.getConstant(VTBits, MVT::i32), ShAmt);
1398 SDValue Tmp1 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
1399 SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i32, ShAmt,
1400 DAG.getConstant(VTBits, MVT::i32));
1401 SDValue Tmp2 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
1402 SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
1403 SDValue TrueVal = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
1404
1405 SDValue Cmp = DAG.getSetCC(dl, MVT::i1, ShAmt,
1406 DAG.getConstant(VTBits, MVT::i32), ISD::SETGE);
1407 SDValue Lo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
1408 SDValue Hi = DAG.getNode(ISD::SELECT, dl, VT, Cmp, TrueVal, FalseVal);
1409
1410 SDValue Ops[2] = { Lo, Hi };
1411 return DAG.getMergeValues(Ops, dl);
1412 }
1413 }
1414
12811415 SDValue
12821416 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
12831417 switch (Op.getOpcode()) {
12981432 return LowerSTORE(Op, DAG);
12991433 case ISD::LOAD:
13001434 return LowerLOAD(Op, DAG);
1435 case ISD::SHL_PARTS:
1436 return LowerShiftLeftParts(Op, DAG);
1437 case ISD::SRA_PARTS:
1438 case ISD::SRL_PARTS:
1439 return LowerShiftRightParts(Op, DAG);
13011440 default:
13021441 llvm_unreachable("Custom lowering not defined for operation");
13031442 }
4848 CallSeqBegin,
4949 CallSeqEnd,
5050 CallPrototype,
51 FUN_SHFL_CLAMP,
52 FUN_SHFR_CLAMP,
5153 MUL_WIDE_SIGNED,
5254 MUL_WIDE_UNSIGNED,
5355 IMAD,
258260 SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
259261 SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;
260262
263 SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
264 SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
265
261266 void ReplaceNodeResults(SDNode *N, SmallVectorImpl &Results,
262267 SelectionDAG &DAG) const override;
263268 SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
13621362 def : Pat<(i1 (select Int1Regs:$p, Int1Regs:$a, Int1Regs:$b)),
13631363 (ORb1rr (ANDb1rr Int1Regs:$p, Int1Regs:$a),
13641364 (ANDb1rr (NOT1 Int1Regs:$p), Int1Regs:$b))>;
1365
1366 //
1367 // Funnnel shift in clamp mode
1368 //
1369 // - SDNodes are created so they can be used in the DAG code,
1370 // e.g. NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts)
1371 //
1372 def SDTIntShiftDOp: SDTypeProfile<1, 3,
1373 [SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>,
1374 SDTCisInt<0>, SDTCisInt<3>]>;
1375 def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>;
1376 def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>;
1377
1378 def FUNSHFLCLAMP : NVPTXInst<(outs Int32Regs:$dst),
1379 (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
1380 "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;",
1381 [(set Int32Regs:$dst,
1382 (FUN_SHFL_CLAMP Int32Regs:$lo,
1383 Int32Regs:$hi, Int32Regs:$amt))]>;
1384
1385 def FUNSHFRCLAMP : NVPTXInst<(outs Int32Regs:$dst),
1386 (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt),
1387 "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;",
1388 [(set Int32Regs:$dst,
1389 (FUN_SHFR_CLAMP Int32Regs:$lo,
1390 Int32Regs:$hi, Int32Regs:$amt))]>;
13651391
13661392 //-----------------------------------
13671393 // Data Movement (Load / Store, Move)
0 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
1
2 ; CHECK: shift_parts_left_128
3 define void @shift_parts_left_128(i128* %val, i128* %amtptr) {
4 ; CHECK: shl.b64
5 ; CHECK: mov.u32
6 ; CHECK: sub.s32
7 ; CHECK: shr.u64
8 ; CHECK: or.b64
9 ; CHECK: add.s32
10 ; CHECK: shl.b64
11 ; CHECK: setp.gt.s32
12 ; CHECK: selp.b64
13 ; CHECK: shl.b64
14 %amt = load i128* %amtptr
15 %a = load i128* %val
16 %val0 = shl i128 %a, %amt
17 store i128 %val0, i128* %val
18 ret void
19 }
20
21 ; CHECK: shift_parts_right_128
22 define void @shift_parts_right_128(i128* %val, i128* %amtptr) {
23 ; CHECK: shr.u64
24 ; CHECK: sub.s32
25 ; CHECK: shl.b64
26 ; CHECK: or.b64
27 ; CHECK: add.s32
28 ; CHECK: shr.s64
29 ; CHECK: setp.gt.s32
30 ; CHECK: selp.b64
31 ; CHECK: shr.s64
32 %amt = load i128* %amtptr
33 %a = load i128* %val
34 %val0 = ashr i128 %a, %amt
35 store i128 %val0, i128* %val
36 ret void
37 }