[NVPTX] Implement fma and imad contraction as target DAGCombiner patterns (commit 10da165, llvm.org GIT mirror)

This also introduces DAGCombiner patterns for mul.wide to multiply two smaller integers and produce a larger integer.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211935 91177308-0d34-0410-b5e6-96231b3b80d8
Author: Justin Holewinski, 6 years ago
6 changed files with 595 additions and 126 deletions.
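At the source level, the new combines target the canonical multiply-add shape. As a rough illustration (not part of the diff; function names are hypothetical), with the combines enabled each body below lowers to a single mad.lo.s32 or fma.rn.f32 instead of a separate mul and add:

    // Illustrative C++ only; the diff itself operates on SelectionDAG nodes.
    int imadExample(int a, int b, int c) { return a * b + c; }    // -> mad.lo.s32
    float fmaExample(float a, float b, float c) { return a * b + c; } // -> fma.rn.f32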
2323
2424 #define DEBUG_TYPE "nvptx-isel"
2525
26 static cl::opt
27 FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden,
28 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
29 " 1: do it 2: do it aggressively"),
30 cl::init(2));
26 unsigned FMAContractLevel = 0;
27
28 static cl::opt<unsigned, true>
29 FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden,
30 cl::desc("NVPTX Specific: FMA contraction (0: don't do it"
31 " 1: do it 2: do it aggressively"),
32 cl::location(FMAContractLevel),
33 cl::init(2));
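The rewritten option uses external storage: the `true` template argument plus cl::location(...) bind the parsed value to the plain global FMAContractLevel, which the DAG combiner code below reads through an extern declaration. A minimal sketch of this cl::opt pattern, with hypothetical option and variable names:

    #include "llvm/Support/CommandLine.h"
    using namespace llvm;

    unsigned MyLevel = 0; // external storage; elsewhere: extern unsigned MyLevel;

    // The second template argument (true) selects external storage,
    // which is what makes cl::location(MyLevel) legal here.
    static cl::opt<unsigned, true>
    MyLevelOpt("my-level", cl::Hidden,
               cl::desc("Example level option"),
               cl::location(MyLevel), cl::init(2));

With the real option, contraction stays tunable from the command line, e.g. `llc -march=nvptx -nvptx-fma-level=0` to disable it.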
3134
3235 static cl::opt<int> UsePrecDivF32(
3336 "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden,
241241 setOperationAction(ISD::CTPOP, MVT::i16, Legal);
242242 setOperationAction(ISD::CTPOP, MVT::i32, Legal);
243243 setOperationAction(ISD::CTPOP, MVT::i64, Legal);
244
245 // We have some custom DAG combine patterns for these nodes
246 setTargetDAGCombine(ISD::ADD);
247 setTargetDAGCombine(ISD::AND);
248 setTargetDAGCombine(ISD::FADD);
249 setTargetDAGCombine(ISD::MUL);
250 setTargetDAGCombine(ISD::SHL);
244251
245252 // Now deduce the information based on the above-mentioned
246253 // actions.
333340 return "NVPTXISD::StoreV2";
334341 case NVPTXISD::StoreV4:
335342 return "NVPTXISD::StoreV4";
343 case NVPTXISD::FUN_SHFL_CLAMP:
344 return "NVPTXISD::FUN_SHFL_CLAMP";
345 case NVPTXISD::FUN_SHFR_CLAMP:
346 return "NVPTXISD::FUN_SHFR_CLAMP";
336347 case NVPTXISD::Tex1DFloatI32: return "NVPTXISD::Tex1DFloatI32";
337348 case NVPTXISD::Tex1DFloatFloat: return "NVPTXISD::Tex1DFloatFloat";
338349 case NVPTXISD::Tex1DFloatFloatLevel:
24742485 return 4;
24752486 }
24762487
2488 //===----------------------------------------------------------------------===//
2489 // NVPTX DAG Combining
2490 //===----------------------------------------------------------------------===//
2491
2492 extern unsigned FMAContractLevel;
2493
2494 /// PerformADDCombineWithOperands - Try DAG combinations for an ADD with
2495 /// operands N0 and N1. This is a helper for PerformADDCombine that is
2496 /// called with the default operands, and if that fails, with commuted
2497 /// operands.
2498 static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
2499 TargetLowering::DAGCombinerInfo &DCI,
2500 const NVPTXSubtarget &Subtarget,
2501 CodeGenOpt::Level OptLevel) {
2502 SelectionDAG &DAG = DCI.DAG;
2503 // Skip the vector case; only scalar values are combined here.
2504 EVT VT = N0.getValueType();
2505 if (VT.isVector())
2506 return SDValue();
2507
2508 // fold (add (mul a, b), c) -> (mad a, b, c)
2509 //
2510 if (N0.getOpcode() == ISD::MUL) {
2511 assert (VT.isInteger());
2512 // For integer:
2513 // Since integer multiply-add costs the same as integer multiply
2514 // but is more costly than integer add, do the fusion only when
2515 // the mul is only used in the add.
2516 if (OptLevel==CodeGenOpt::None || VT != MVT::i32 ||
2517 !N0.getNode()->hasOneUse())
2518 return SDValue();
2519
2520 // Do the folding
2521 return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
2522 N0.getOperand(0), N0.getOperand(1), N1);
2523 }
2524 else if (N0.getOpcode() == ISD::FMUL) {
2525 if (VT == MVT::f32 || VT == MVT::f64) {
2526 if (FMAContractLevel == 0)
2527 return SDValue();
2528
2529 // For floating point:
2530 // Do the fusion only when the mul has fewer than 5 uses and all
2531 // of them are adds.
2532 // The heuristic is that if a use is not an add, then that use
2533 // cannot be fused into an fma, so the mul is still needed anyway.
2534 // If there are more than 4 uses, even if they are all adds, fusing
2535 // them will increase register pressure.
2536 //
2537 int numUses = 0;
2538 int nonAddCount = 0;
2539 for (SDNode::use_iterator UI = N0.getNode()->use_begin(),
2540 UE = N0.getNode()->use_end();
2541 UI != UE; ++UI) {
2542 numUses++;
2543 SDNode *User = *UI;
2544 if (User->getOpcode() != ISD::FADD)
2545 ++nonAddCount;
2546 }
2547 if (numUses >= 5)
2548 return SDValue();
2549 if (nonAddCount) {
2550 int orderNo = N->getIROrder();
2551 int orderNo2 = N0.getNode()->getIROrder();
2552 // A simple heuristic for estimating register pressure: the IR-order
2553 // difference approximates the distance between the def and this use,
2554 // and the longer the distance, the more likely the fusion is to
2555 // increase register pressure.
2556 if (orderNo - orderNo2 < 500)
2557 return SDValue();
2558
2559 // Now, check if at least one of the FMUL's operands is live beyond the node N,
2560 // which guarantees that the FMA will not increase register pressure at node N.
2561 bool opIsLive = false;
2562 const SDNode *left = N0.getOperand(0).getNode();
2563 const SDNode *right = N0.getOperand(1).getNode();
2564
2565 if (dyn_cast<ConstantSDNode>(left) || dyn_cast<ConstantSDNode>(right))
2566 opIsLive = true;
2567
2568 if (!opIsLive)
2569 for (SDNode::use_iterator UI = left->use_begin(), UE = left->use_end(); UI != UE; ++UI) {
2570 SDNode *User = *UI;
2571 int orderNo3 = User->getIROrder();
2572 if (orderNo3 > orderNo) {
2573 opIsLive = true;
2574 break;
2575 }
2576 }
2577
2578 if (!opIsLive)
2579 for (SDNode::use_iterator UI = right->use_begin(), UE = right->use_end(); UI != UE; ++UI) {
2580 SDNode *User = *UI;
2581 int orderNo3 = User->getIROrder();
2582 if (orderNo3 > orderNo) {
2583 opIsLive = true;
2584 break;
2585 }
2586 }
2587
2588 if (!opIsLive)
2589 return SDValue();
2590 }
2591
2592 return DAG.getNode(ISD::FMA, SDLoc(N), VT,
2593 N0.getOperand(0), N0.getOperand(1), N1);
2594 }
2595 }
2596
2597 return SDValue();
2598 }
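To make the FMUL use-count heuristic concrete, here is an illustrative source-level sketch (not from the diff): the multiply below has four uses, one of which is not an add, so the combine falls through to the IR-order distance and operand-liveness checks before fusing.

    // Illustrative only. t is an FMUL with 4 uses (numUses < 5), one of
    // which is another FMUL (nonAddCount == 1): fusing the three adds into
    // FMAs would not eliminate the mul, so the extra checks must pass first.
    float heuristicExample(float a, float b, float c0, float c1, float c2) {
      float t = a * b;
      return (t + c0) + (t + c1) + (t + c2) + (t * c2);
    }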
2599
2600 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
2601 ///
2602 static SDValue PerformADDCombine(SDNode *N,
2603 TargetLowering::DAGCombinerInfo &DCI,
2604 const NVPTXSubtarget &Subtarget,
2605 CodeGenOpt::Level OptLevel) {
2606 SDValue N0 = N->getOperand(0);
2607 SDValue N1 = N->getOperand(1);
2608
2609 // First try with the default operand order.
2610 SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget,
2611 OptLevel);
2612 if (Result.getNode())
2613 return Result;
2614
2615 // If that didn't work, try again with the operands commuted.
2616 return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel);
2617 }
2618
2619 static SDValue PerformANDCombine(SDNode *N,
2620 TargetLowering::DAGCombinerInfo &DCI) {
2621 // The type legalizer turns a vector load of i8 values into a zextload to i16
2622 // registers, optionally ANY_EXTENDs it (if target type is integer),
2623 // and ANDs off the high 8 bits. Since we turn this load into a
2624 // target-specific DAG node, the DAG combiner fails to eliminate these AND
2625 // nodes. Do that here.
2626 SDValue Val = N->getOperand(0);
2627 SDValue Mask = N->getOperand(1);
2628
2629 if (isa<ConstantSDNode>(Val)) {
2630 std::swap(Val, Mask);
2631 }
2632
2633 SDValue AExt;
2634 // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
2635 if (Val.getOpcode() == ISD::ANY_EXTEND) {
2636 AExt = Val;
2637 Val = Val->getOperand(0);
2638 }
2639
2640 if (Val->isMachineOpcode() && Val->getMachineOpcode() == NVPTX::IMOV16rr) {
2641 Val = Val->getOperand(0);
2642 }
2643
2644 if (Val->getOpcode() == NVPTXISD::LoadV2 ||
2645 Val->getOpcode() == NVPTXISD::LoadV4) {
2646 ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
2647 if (!MaskCnst) {
2648 // Not an AND with a constant
2649 return SDValue();
2650 }
2651
2652 uint64_t MaskVal = MaskCnst->getZExtValue();
2653 if (MaskVal != 0xff) {
2654 // Not an AND that chops off top 8 bits
2655 return SDValue();
2656 }
2657
2658 MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
2659 if (!Mem) {
2660 // Not a MemSDNode?!?
2661 return SDValue();
2662 }
2663
2664 EVT MemVT = Mem->getMemoryVT();
2665 if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
2666 // We only handle the i8 case
2667 return SDValue();
2668 }
2669
2670 unsigned ExtType =
2671 cast<ConstantSDNode>(Val->getOperand(Val->getNumOperands()-1))->
2672 getZExtValue();
2673 if (ExtType == ISD::SEXTLOAD) {
2674 // If for some reason the load is a sextload, the and is needed to zero
2675 // out the high 8 bits
2676 return SDValue();
2677 }
2678
2679 bool AddTo = false;
2680 if (AExt.getNode() != 0) {
2681 // Re-insert the ext as a zext.
2682 Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
2683 AExt.getValueType(), Val);
2684 AddTo = true;
2685 }
2686
2687 // If we get here, the AND is unnecessary. Just replace it with the load
2688 DCI.CombineTo(N, Val, AddTo);
2689 }
2690
2691 return SDValue();
2692 }
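In source terms, the AND being removed here is equivalent to masking a value that is already zero-extended from 8 bits (illustrative only):

    // The mask below is a no-op on a value zero-extended from i8, which is
    // exactly the redundancy the combine folds away at the DAG level.
    unsigned short alreadyZeroExtended(unsigned char v) {
      unsigned short w = v; // corresponds to the zextload of an i8 element
      return w & 0xff;      // redundant AND; the combine replaces it with w
    }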
2693
2694 enum OperandSignedness {
2695 Signed = 0,
2696 Unsigned,
2697 Unknown
2698 };
2699
2700 /// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand
2701 /// that can be demoted to \p OptSize bits without loss of information. The
2702 /// signedness of the operand, if determinable, is placed in \p S.
2703 static bool IsMulWideOperandDemotable(SDValue Op,
2704 unsigned OptSize,
2705 OperandSignedness &S) {
2706 S = Unknown;
2707
2708 if (Op.getOpcode() == ISD::SIGN_EXTEND ||
2709 Op.getOpcode() == ISD::SIGN_EXTEND_INREG) {
2710 EVT OrigVT = Op.getOperand(0).getValueType();
2711 if (OrigVT.getSizeInBits() == OptSize) {
2712 S = Signed;
2713 return true;
2714 }
2715 } else if (Op.getOpcode() == ISD::ZERO_EXTEND) {
2716 EVT OrigVT = Op.getOperand(0).getValueType();
2717 if (OrigVT.getSizeInBits() == OptSize) {
2718 S = Unsigned;
2719 return true;
2720 }
2721 }
2722
2723 return false;
2724 }
2725
2726 /// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can
2727 /// be demoted to \p OptSize bits without loss of information. If the operands
2728 /// contain a constant, it should appear as the RHS operand. The signedness of
2729 /// the operands is placed in \p IsSigned.
2730 static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS,
2731 unsigned OptSize,
2732 bool &IsSigned) {
2733
2734 OperandSignedness LHSSign;
2735
2736 // The LHS operand must be a demotable op
2737 if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign))
2738 return false;
2739
2740 // We should have been able to determine the signedness from the LHS
2741 if (LHSSign == Unknown)
2742 return false;
2743
2744 IsSigned = (LHSSign == Signed);
2745
2746 // The RHS can be a demotable op or a constant
2747 if (ConstantSDNode *CI = dyn_cast<ConstantSDNode>(RHS)) {
2748 APInt Val = CI->getAPIntValue();
2749 if (LHSSign == Unsigned) {
2750 if (Val.isIntN(OptSize)) {
2751 return true;
2752 }
2753 return false;
2754 } else {
2755 if (Val.isSignedIntN(OptSize)) {
2756 return true;
2757 }
2758 return false;
2759 }
2760 } else {
2761 OperandSignedness RHSSign;
2762 if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign))
2763 return false;
2764
2765 if (LHSSign != RHSSign)
2766 return false;
2767
2768 return true;
2769 }
2770 }
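A small worked example of the constant-RHS case (illustrative, not part of the diff): for an i32 multiply, OptSize is 16, so an unsigned constant must fit in 16 bits (isIntN) and a signed constant in 16 signed bits (isSignedIntN).

    #include "llvm/ADT/APInt.h"
    using llvm::APInt;

    // Returns true: 65535 fits in 16 unsigned bits, so it can pair with a
    // zero-extended LHS; 40000 does not fit in 16 signed bits, so it cannot
    // pair with a sign-extended LHS.
    bool demotableDemo() {
      APInt U(32, 65535);
      APInt S(32, 40000);
      return U.isIntN(16) && !S.isSignedIntN(16);
    }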
2771
2772 /// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply
2773 /// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform
2774 /// works on both multiply DAG nodes and SHL DAG nodes with a constant shift
2775 /// amount.
2776 static SDValue TryMULWIDECombine(SDNode *N,
2777 TargetLowering::DAGCombinerInfo &DCI) {
2778 EVT MulType = N->getValueType(0);
2779 if (MulType != MVT::i32 && MulType != MVT::i64) {
2780 return SDValue();
2781 }
2782
2783 unsigned OptSize = MulType.getSizeInBits() >> 1;
2784 SDValue LHS = N->getOperand(0);
2785 SDValue RHS = N->getOperand(1);
2786
2787 // Canonicalize the multiply so the constant (if any) is on the right
2788 if (N->getOpcode() == ISD::MUL) {
2789 if (isa<ConstantSDNode>(LHS)) {
2790 std::swap(LHS, RHS);
2791 }
2792 }
2793
2794 // If we have a SHL, determine the actual multiply amount
2795 if (N->getOpcode() == ISD::SHL) {
2796 ConstantSDNode *ShlRHS = dyn_cast<ConstantSDNode>(RHS);
2797 if (!ShlRHS) {
2798 return SDValue();
2799 }
2800
2801 APInt ShiftAmt = ShlRHS->getAPIntValue();
2802 unsigned BitWidth = MulType.getSizeInBits();
2803 if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) {
2804 APInt MulVal = APInt(BitWidth, 1) << ShiftAmt;
2805 RHS = DCI.DAG.getConstant(MulVal, MulType);
2806 } else {
2807 return SDValue();
2808 }
2809 }
2810
2811 bool Signed;
2812 // Verify that our operands are demotable
2813 if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) {
2814 return SDValue();
2815 }
2816
2817 EVT DemotedVT;
2818 if (MulType == MVT::i32) {
2819 DemotedVT = MVT::i16;
2820 } else {
2821 DemotedVT = MVT::i32;
2822 }
2823
2824 // Truncate the operands to the correct size. Note that these are just for
2825 // type consistency and will (likely) be eliminated in later phases.
2826 SDValue TruncLHS =
2827 DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, LHS);
2828 SDValue TruncRHS =
2829 DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, RHS);
2830
2831 unsigned Opc;
2832 if (Signed) {
2833 Opc = NVPTXISD::MUL_WIDE_SIGNED;
2834 } else {
2835 Opc = NVPTXISD::MUL_WIDE_UNSIGNED;
2836 }
2837
2838 return DCI.DAG.getNode(Opc, SDLoc(N), MulType, TruncLHS, TruncRHS);
2839 }
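For the SHL path, a constant shift is first rewritten as the equivalent multiply constant and then run through the same demotability check. Worked example (illustrative): shl (sext i16 %a to i32), 3 becomes a multiply by APInt(32, 1) << 3 == 8, and 8 fits in 16 signed bits, so the node is selected as mul.wide.s16.

    #include "llvm/ADT/APInt.h"

    // Mirrors the shift-to-multiply conversion performed by the combine above.
    llvm::APInt shlToMulConstant(unsigned BitWidth, unsigned ShAmt) {
      return llvm::APInt(BitWidth, 1) << ShAmt; // e.g. 1 << 3 == 8
    }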
2840
2841 /// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes.
2842 static SDValue PerformMULCombine(SDNode *N,
2843 TargetLowering::DAGCombinerInfo &DCI,
2844 CodeGenOpt::Level OptLevel) {
2845 if (OptLevel > 0) {
2846 // Try mul.wide combining at OptLevel > 0
2847 SDValue Ret = TryMULWIDECombine(N, DCI);
2848 if (Ret.getNode())
2849 return Ret;
2850 }
2851
2852 return SDValue();
2853 }
2854
2855 /// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes.
2856 static SDValue PerformSHLCombine(SDNode *N,
2857 TargetLowering::DAGCombinerInfo &DCI,
2858 CodeGenOpt::Level OptLevel) {
2859 if (OptLevel > 0) {
2860 // Try mul.wide combining at OptLevel > 0
2861 SDValue Ret = TryMULWIDECombine(N, DCI);
2862 if (Ret.getNode())
2863 return Ret;
2864 }
2865
2866 return SDValue();
2867 }
2868
2869 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
2870 DAGCombinerInfo &DCI) const {
2871 // FIXME: Get this from the DAG somehow
2872 CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
2873 switch (N->getOpcode()) {
2874 default: break;
2875 case ISD::ADD:
2876 case ISD::FADD:
2877 return PerformADDCombine(N, DCI, nvptxSubtarget, OptLevel);
2878 case ISD::MUL:
2879 return PerformMULCombine(N, DCI, OptLevel);
2880 case ISD::SHL:
2881 return PerformSHLCombine(N, DCI, OptLevel);
2882 case ISD::AND:
2883 return PerformANDCombine(N, DCI);
2884 }
2885 return SDValue();
2886 }
2887
24772888 /// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
24782889 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
24792890 SmallVectorImpl<SDValue> &Results) {
4848 CallSeqBegin,
4949 CallSeqEnd,
5050 CallPrototype,
51 MUL_WIDE_SIGNED,
52 MUL_WIDE_UNSIGNED,
53 IMAD,
5154 Dummy,
5255
5356 LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
257260
258261 void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
259262 SelectionDAG &DAG) const override;
263 SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
260264
261265 unsigned getArgumentAlignment(SDValue Callee, const ImmutableCallSite *CS,
262266 Type *Ty, unsigned Idx) const;
463463 return CurDAG->getTargetConstant(temp.shl(v), MVT::i16);
464464 }]>;
465465
466 def MULWIDES64 : NVPTXInst<(outs Int64Regs:$dst),
467 (ins Int32Regs:$a, Int32Regs:$b),
466 def MULWIDES64
467 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
468 "mul.wide.s32 \t$dst, $a, $b;", []>;
469 def MULWIDES64Imm
470 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
468471 "mul.wide.s32 \t$dst, $a, $b;", []>;
469 def MULWIDES64Imm : NVPTXInst<(outs Int64Regs:$dst),
470 (ins Int32Regs:$a, i64imm:$b),
472 def MULWIDES64Imm64
473 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
471474 "mul.wide.s32 \t$dst, $a, $b;", []>;
472475
473 def MULWIDEU64 : NVPTXInst<(outs Int64Regs:$dst),
474 (ins Int32Regs:$a, Int32Regs:$b),
476 def MULWIDEU64
477 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b),
478 "mul.wide.u32 \t$dst, $a, $b;", []>;
479 def MULWIDEU64Imm
480 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b),
475481 "mul.wide.u32 \t$dst, $a, $b;", []>;
476 def MULWIDEU64Imm : NVPTXInst<(outs Int64Regs:$dst),
477 (ins Int32Regs:$a, i64imm:$b),
482 def MULWIDEU64Imm64
483 : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b),
478484 "mul.wide.u32 \t$dst, $a, $b;", []>;
479485
480 def MULWIDES32 : NVPTXInst<(outs Int32Regs:$dst),
481 (ins Int16Regs:$a, Int16Regs:$b),
486 def MULWIDES32
487 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
482488 "mul.wide.s16 \t$dst, $a, $b;", []>;
483 def MULWIDES32Imm : NVPTXInst<(outs Int32Regs:$dst),
484 (ins Int16Regs:$a, i32imm:$b),
489 def MULWIDES32Imm
490 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
491 "mul.wide.s16 \t$dst, $a, $b;", []>;
492 def MULWIDES32Imm32
493 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
485494 "mul.wide.s16 \t$dst, $a, $b;", []>;
486495
487 def MULWIDEU32 : NVPTXInst<(outs Int32Regs:$dst),
488 (ins Int16Regs:$a, Int16Regs:$b),
496 def MULWIDEU32
497 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b),
498 "mul.wide.u16 \t$dst, $a, $b;", []>;
499 def MULWIDEU32Imm
500 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b),
489501 "mul.wide.u16 \t$dst, $a, $b;", []>;
490 def MULWIDEU32Imm : NVPTXInst<(outs Int32Regs:$dst),
491 (ins Int16Regs:$a, i32imm:$b),
492 "mul.wide.u16 \t$dst, $a, $b;", []>;
502 def MULWIDEU32Imm32
503 : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b),
504 "mul.wide.u16 \t$dst, $a, $b;", []>;
493505
494506 def : Pat<(shl (sext Int32Regs:$a), (i32 Int5Const:$b)),
495507 (MULWIDES64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>,
509521 (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>,
510522 Requires<[doMulWide]>;
511523 def : Pat<(mul (sext Int32Regs:$a), (i64 SInt32Const:$b)),
512 (MULWIDES64Imm Int32Regs:$a, (i64 SInt32Const:$b))>,
524 (MULWIDES64Imm64 Int32Regs:$a, (i64 SInt32Const:$b))>,
513525 Requires<[doMulWide]>;
514526
515527 def : Pat<(mul (zext Int32Regs:$a), (zext Int32Regs:$b)),
516 (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>;
528 (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>,
529 Requires<[doMulWide]>;
517530 def : Pat<(mul (zext Int32Regs:$a), (i64 UInt32Const:$b)),
518 (MULWIDEU64Imm Int32Regs:$a, (i64 UInt32Const:$b))>,
531 (MULWIDEU64Imm64 Int32Regs:$a, (i64 UInt32Const:$b))>,
519532 Requires<[doMulWide]>;
520533
521534 def : Pat<(mul (sext Int16Regs:$a), (sext Int16Regs:$b)),
522 (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
535 (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>,
536 Requires<[doMulWide]>;
523537 def : Pat<(mul (sext Int16Regs:$a), (i32 SInt16Const:$b)),
524 (MULWIDES32Imm Int16Regs:$a, (i32 SInt16Const:$b))>,
538 (MULWIDES32Imm32 Int16Regs:$a, (i32 SInt16Const:$b))>,
525539 Requires<[doMulWide]>;
526540
527541 def : Pat<(mul (zext Int16Regs:$a), (zext Int16Regs:$b)),
528 (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>;
542 (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>,
543 Requires<[doMulWide]>;
529544 def : Pat<(mul (zext Int16Regs:$a), (i32 UInt16Const:$b)),
530 (MULWIDEU32Imm Int16Regs:$a, (i32 UInt16Const:$b))>,
545 (MULWIDEU32Imm32 Int16Regs:$a, (i32 UInt16Const:$b))>,
546 Requires<[doMulWide]>;
547
548
549 def SDTMulWide
550 : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>;
551 def mul_wide_signed
552 : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>;
553 def mul_wide_unsigned
554 : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>;
555
556 def : Pat<(i32 (mul_wide_signed Int16Regs:$a, Int16Regs:$b)),
557 (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>,
558 Requires<[doMulWide]>;
559 def : Pat<(i32 (mul_wide_signed Int16Regs:$a, imm:$b)),
560 (MULWIDES32Imm Int16Regs:$a, imm:$b)>,
561 Requires<[doMulWide]>;
562 def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, Int16Regs:$b)),
563 (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>,
564 Requires<[doMulWide]>;
565 def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, imm:$b)),
566 (MULWIDEU32Imm Int16Regs:$a, imm:$b)>,
567 Requires<[doMulWide]>;
568
569
570 def : Pat<(i64 (mul_wide_signed Int32Regs:$a, Int32Regs:$b)),
571 (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>,
572 Requires<[doMulWide]>;
573 def : Pat<(i64 (mul_wide_signed Int32Regs:$a, imm:$b)),
574 (MULWIDES64Imm Int32Regs:$a, imm:$b)>,
575 Requires<[doMulWide]>;
576 def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, Int32Regs:$b)),
577 (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>,
578 Requires<[doMulWide]>;
579 def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, imm:$b)),
580 (MULWIDEU64Imm Int32Regs:$a, imm:$b)>,
531581 Requires<[doMulWide]>;
532582
533583 defm MULT : I3<"mul.lo.s", mul>;
543593 defm UREM : I3<"rem.u", urem>;
544594 // The ri version will not be selected as DAGCombiner::visitUREM will lower it.
545595
596 def SDTIMAD
597 : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>,
598 SDTCisInt<2>, SDTCisSameAs<0, 2>,
599 SDTCisSameAs<0, 3>]>;
600 def imad
601 : SDNode<"NVPTXISD::IMAD", SDTIMAD>;
602
546603 def MAD16rrr : NVPTXInst<(outs Int16Regs:$dst),
547604 (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
548605 "mad.lo.s16 \t$dst, $a, $b, $c;",
549 [(set Int16Regs:$dst, (add
550 (mul Int16Regs:$a, Int16Regs:$b), Int16Regs:$c))]>;
606 [(set Int16Regs:$dst,
607 (imad Int16Regs:$a, Int16Regs:$b, Int16Regs:$c))]>;
551608 def MAD16rri : NVPTXInst<(outs Int16Regs:$dst),
552609 (ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
553610 "mad.lo.s16 \t$dst, $a, $b, $c;",
554 [(set Int16Regs:$dst, (add
555 (mul Int16Regs:$a, Int16Regs:$b), imm:$c))]>;
611 [(set Int16Regs:$dst,
612 (imad Int16Regs:$a, Int16Regs:$b, imm:$c))]>;
556613 def MAD16rir : NVPTXInst<(outs Int16Regs:$dst),
557614 (ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
558615 "mad.lo.s16 \t$dst, $a, $b, $c;",
559 [(set Int16Regs:$dst, (add
560 (mul Int16Regs:$a, imm:$b), Int16Regs:$c))]>;
616 [(set Int16Regs:$dst,
617 (imad Int16Regs:$a, imm:$b, Int16Regs:$c))]>;
561618 def MAD16rii : NVPTXInst<(outs Int16Regs:$dst),
562619 (ins Int16Regs:$a, i16imm:$b, i16imm:$c),
563620 "mad.lo.s16 \t$dst, $a, $b, $c;",
564 [(set Int16Regs:$dst, (add (mul Int16Regs:$a, imm:$b),
565 imm:$c))]>;
621 [(set Int16Regs:$dst,
622 (imad Int16Regs:$a, imm:$b, imm:$c))]>;
566623
567624 def MAD32rrr : NVPTXInst<(outs Int32Regs:$dst),
568625 (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
569626 "mad.lo.s32 \t$dst, $a, $b, $c;",
570 [(set Int32Regs:$dst, (add
571 (mul Int32Regs:$a, Int32Regs:$b), Int32Regs:$c))]>;
627 [(set Int32Regs:$dst,
628 (imad Int32Regs:$a, Int32Regs:$b, Int32Regs:$c))]>;
572629 def MAD32rri : NVPTXInst<(outs Int32Regs:$dst),
573630 (ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
574631 "mad.lo.s32 \t$dst, $a, $b, $c;",
575 [(set Int32Regs:$dst, (add
576 (mul Int32Regs:$a, Int32Regs:$b), imm:$c))]>;
632 [(set Int32Regs:$dst,
633 (imad Int32Regs:$a, Int32Regs:$b, imm:$c))]>;
577634 def MAD32rir : NVPTXInst<(outs Int32Regs:$dst),
578635 (ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
579636 "mad.lo.s32 \t$dst, $a, $b, $c;",
580 [(set Int32Regs:$dst, (add
581 (mul Int32Regs:$a, imm:$b), Int32Regs:$c))]>;
637 [(set Int32Regs:$dst,
638 (imad Int32Regs:$a, imm:$b, Int32Regs:$c))]>;
582639 def MAD32rii : NVPTXInst<(outs Int32Regs:$dst),
583640 (ins Int32Regs:$a, i32imm:$b, i32imm:$c),
584641 "mad.lo.s32 \t$dst, $a, $b, $c;",
585 [(set Int32Regs:$dst, (add
586 (mul Int32Regs:$a, imm:$b), imm:$c))]>;
642 [(set Int32Regs:$dst,
643 (imad Int32Regs:$a, imm:$b, imm:$c))]>;
587644
588645 def MAD64rrr : NVPTXInst<(outs Int64Regs:$dst),
589646 (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
590647 "mad.lo.s64 \t$dst, $a, $b, $c;",
591 [(set Int64Regs:$dst, (add
592 (mul Int64Regs:$a, Int64Regs:$b), Int64Regs:$c))]>;
648 [(set Int64Regs:$dst,
649 (imad Int64Regs:$a, Int64Regs:$b, Int64Regs:$c))]>;
593650 def MAD64rri : NVPTXInst<(outs Int64Regs:$dst),
594651 (ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
595652 "mad.lo.s64 \t$dst, $a, $b, $c;",
596 [(set Int64Regs:$dst, (add
597 (mul Int64Regs:$a, Int64Regs:$b), imm:$c))]>;
653 [(set Int64Regs:$dst,
654 (imad Int64Regs:$a, Int64Regs:$b, imm:$c))]>;
598655 def MAD64rir : NVPTXInst<(outs Int64Regs:$dst),
599656 (ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
600657 "mad.lo.s64 \t$dst, $a, $b, $c;",
601 [(set Int64Regs:$dst, (add
602 (mul Int64Regs:$a, imm:$b), Int64Regs:$c))]>;
658 [(set Int64Regs:$dst,
659 (imad Int64Regs:$a, imm:$b, Int64Regs:$c))]>;
603660 def MAD64rii : NVPTXInst<(outs Int64Regs:$dst),
604661 (ins Int64Regs:$a, i64imm:$b, i64imm:$c),
605662 "mad.lo.s64 \t$dst, $a, $b, $c;",
606 [(set Int64Regs:$dst, (add
607 (mul Int64Regs:$a, imm:$b), imm:$c))]>;
608
663 [(set Int64Regs:$dst,
664 (imad Int64Regs:$a, imm:$b, imm:$c))]>;
609665
610666 def INEG16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
611667 "neg.s16 \t$dst, $src;",
811867 def rrr : NVPTXInst<(outs Float32Regs:$dst),
812868 (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
813869 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
814 [(set Float32Regs:$dst, (fadd
815 (fmul Float32Regs:$a, Float32Regs:$b),
816 Float32Regs:$c))]>, Requires<[Pred]>;
817 // This is to work around (WAR) a weird bug in Tablegen that does not
818 // automatically generate the following permuted rule rrr2 from the above rrr.
819 // So we explicitly add it here. This happens to FMA32 only.
820 // See the comments at FMAD32 and FMA32 for more information.
821 def rrr2 : NVPTXInst<(outs Float32Regs:$dst),
822 (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
823 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
824 [(set Float32Regs:$dst, (fadd Float32Regs:$c,
825 (fmul Float32Regs:$a, Float32Regs:$b)))]>,
870 [(set Float32Regs:$dst,
871 (fma Float32Regs:$a, Float32Regs:$b, Float32Regs:$c))]>,
826872 Requires<[Pred]>;
827873 def rri : NVPTXInst<(outs Float32Regs:$dst),
828874 (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
829875 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
830 [(set Float32Regs:$dst, (fadd
831 (fmul Float32Regs:$a, Float32Regs:$b), fpimm:$c))]>,
876 [(set Float32Regs:$dst,
877 (fma Float32Regs:$a, Float32Regs:$b, fpimm:$c))]>,
832878 Requires<[Pred]>;
833879 def rir : NVPTXInst<(outs Float32Regs:$dst),
834880 (ins Float32Regs:$a, f32imm:$b, Float32Regs:$c),
835881 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
836 [(set Float32Regs:$dst, (fadd
837 (fmul Float32Regs:$a, fpimm:$b), Float32Regs:$c))]>,
882 [(set Float32Regs:$dst,
883 (fma Float32Regs:$a, fpimm:$b, Float32Regs:$c))]>,
838884 Requires<[Pred]>;
839885 def rii : NVPTXInst<(outs Float32Regs:$dst),
840886 (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
841887 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
842 [(set Float32Regs:$dst, (fadd
843 (fmul Float32Regs:$a, fpimm:$b), fpimm:$c))]>,
888 [(set Float32Regs:$dst,
889 (fma Float32Regs:$a, fpimm:$b, fpimm:$c))]>,
844890 Requires<[Pred]>;
845891 }
846892
848894 def rrr : NVPTXInst<(outs Float64Regs:$dst),
849895 (ins Float64Regs:$a, Float64Regs:$b, Float64Regs:$c),
850896 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
851 [(set Float64Regs:$dst, (fadd
852 (fmul Float64Regs:$a, Float64Regs:$b),
853 Float64Regs:$c))]>, Requires<[Pred]>;
897 [(set Float64Regs:$dst,
898 (fma Float64Regs:$a, Float64Regs:$b, Float64Regs:$c))]>,
899 Requires<[Pred]>;
854900 def rri : NVPTXInst<(outs Float64Regs:$dst),
855901 (ins Float64Regs:$a, Float64Regs:$b, f64imm:$c),
856902 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
857 [(set Float64Regs:$dst, (fadd (fmul Float64Regs:$a,
858 Float64Regs:$b), fpimm:$c))]>, Requires<[Pred]>;
903 [(set Float64Regs:$dst,
904 (fma Float64Regs:$a, Float64Regs:$b, fpimm:$c))]>,
905 Requires<[Pred]>;
859906 def rir : NVPTXInst<(outs Float64Regs:$dst),
860907 (ins Float64Regs:$a, f64imm:$b, Float64Regs:$c),
861908 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
862 [(set Float64Regs:$dst, (fadd
863 (fmul Float64Regs:$a, fpimm:$b), Float64Regs:$c))]>,
909 [(set Float64Regs:$dst,
910 (fma Float64Regs:$a, fpimm:$b, Float64Regs:$c))]>,
864911 Requires<[Pred]>;
865912 def rii : NVPTXInst<(outs Float64Regs:$dst),
866913 (ins Float64Regs:$a, f64imm:$b, f64imm:$c),
867914 !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
868 [(set Float64Regs:$dst, (fadd
869 (fmul Float64Regs:$a, fpimm:$b), fpimm:$c))]>,
915 [(set Float64Regs:$dst,
916 (fma Float64Regs:$a, fpimm:$b, fpimm:$c))]>,
870917 Requires<[Pred]>;
871918 }
872919
873 // Due to an unknown reason (most likely a bug in tablegen), tablegen does not
874 // automatically generate the rrr2 rule from
875 // the rrr rule (see FPCONTRACT32) for FMA32, though it does for FMAD32.
876 // If we reverse the order of the following two lines, then rrr2 rule will be
877 // generated for FMA32, but not for rrr.
878 // Therefore, we manually write the rrr2 rule in FPCONTRACT32.
879 defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doFMAF32_ftz>;
880 defm FMA32 : FPCONTRACT32<"fma.rn.f32", doFMAF32>;
881 defm FMA64 : FPCONTRACT64<"fma.rn.f64", doFMAF64>;
882
883 // b*c-a => fmad(b, c, -a)
884 multiclass FPCONTRACT32_SUB_PAT_MAD<NVPTXInst Inst, Predicate Pred> {
885 def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
886 (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
887 Requires<[Pred]>;
888 }
889
890 // a-b*c => fmad(-b,c, a)
891 // - legal because a-b*c <=> a+(-b*c) <=> a+(-b)*c
892 // b*c-a => fmad(b, c, -a)
893 // - legal because b*c-a <=> b*c+(-a)
894 multiclass FPCONTRACT32_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
895 def : Pat<(fsub Float32Regs:$a, (fmul Float32Regs:$b, Float32Regs:$c)),
896 (Inst (FNEGf32 Float32Regs:$b), Float32Regs:$c, Float32Regs:$a)>,
897 Requires<[Pred]>;
898 def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a),
899 (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>,
900 Requires<[Pred]>;
901 }
902
903 // a-b*c => fmad(-b,c, a)
904 // b*c-a => fmad(b, c, -a)
905 multiclass FPCONTRACT64_SUB_PAT<NVPTXInst Inst, Predicate Pred> {
906 def : Pat<(fsub Float64Regs:$a, (fmul Float64Regs:$b, Float64Regs:$c)),
907 (Inst (FNEGf64 Float64Regs:$b), Float64Regs:$c, Float64Regs:$a)>,
908 Requires<[Pred]>;
909
910 def : Pat<(fsub (fmul Float64Regs:$b, Float64Regs:$c), Float64Regs:$a),
911 (Inst Float64Regs:$b, Float64Regs:$c, (FNEGf64 Float64Regs:$a))>,
912 Requires<[Pred]>;
913 }
914
915 defm FMAF32ext_ftz : FPCONTRACT32_SUB_PAT<FMA32_ftz, doFMAF32_ftz>;
916 defm FMAF32ext : FPCONTRACT32_SUB_PAT<FMA32, doFMAF32>;
917 defm FMAF64ext : FPCONTRACT64_SUB_PAT<FMA64, doFMAF64>;
920 defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doF32FTZ>;
921 defm FMA32 : FPCONTRACT32<"fma.rn.f32", doNoF32FTZ>;
922 defm FMA64 : FPCONTRACT64<"fma.rn.f64", doNoF32FTZ>;
918923
919924 def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
920925 "sin.approx.f32 \t$dst, $src;",
0 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
1
2 ; CHECK: imad
3 define i32 @imad(i32 %a, i32 %b, i32 %c) {
4 ; CHECK: mad.lo.s32
5 %val0 = mul i32 %a, %b
6 %val1 = add i32 %val0, %c
7 ret i32 %val1
8 }
0 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
1
2 ; CHECK: mulwide16
3 define i32 @mulwide16(i16 %a, i16 %b) {
4 ; CHECK: mul.wide.s16
5 %val0 = sext i16 %a to i32
6 %val1 = sext i16 %b to i32
7 %val2 = mul i32 %val0, %val1
8 ret i32 %val2
9 }
10
11 ; CHECK: mulwideu16
12 define i32 @mulwideu16(i16 %a, i16 %b) {
13 ; CHECK: mul.wide.u16
14 %val0 = zext i16 %a to i32
15 %val1 = zext i16 %b to i32
16 %val2 = mul i32 %val0, %val1
17 ret i32 %val2
18 }
19
20 ; CHECK: mulwide32
21 define i64 @mulwide32(i32 %a, i32 %b) {
22 ; CHECK: mul.wide.s32
23 %val0 = sext i32 %a to i64
24 %val1 = sext i32 %b to i64
25 %val2 = mul i64 %val0, %val1
26 ret i64 %val2
27 }
28
29 ; CHECK: mulwideu32
30 define i64 @mulwideu32(i32 %a, i32 %b) {
31 ; CHECK: mul.wide.u32
32 %val0 = zext i32 %a to i64
33 %val1 = zext i32 %b to i64
34 %val2 = mul i64 %val0, %val1
35 ret i64 %val2
36 }