llvm.org GIT mirror llvm / 627d346
[NVPTX] Fix the codegen for llvm.round. Summary: Previously, we translated llvm.round to PTX cvt.rni, which rounds to the even integer when the source is equidistant between two integers. This is not correct, as llvm.round should round away from zero. This change replaces llvm.round with a round-away-from-zero implementation through target-specific custom lowering. Modify a few affected tests to not check for cvt.rni. Instead, we check for the use of a few constants used in implementing round. We are also adding CUDA runnable tests to check for the values produced by llvm.round to test-suites/External/CUDA. Reviewers: tra Subscribers: jholewinski, sanjoy, jlebar, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D59947 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@357407 91177308-0d34-0410-b5e6-96231b3b80d8 Bixia Zheng 1 year, 8 months ago
6 changed file(s) with 128 addition(s) and 22 deletion(s). Raw diff Collapse all Expand all
545545
546546 // These map to conversion instructions for scalar FP types.
547547 for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
548 ISD::FROUND, ISD::FTRUNC}) {
548 ISD::FTRUNC}) {
549549 setOperationAction(Op, MVT::f16, Legal);
550550 setOperationAction(Op, MVT::f32, Legal);
551551 setOperationAction(Op, MVT::f64, Legal);
552552 setOperationAction(Op, MVT::v2f16, Expand);
553553 }
554
555 setOperationAction(ISD::FROUND, MVT::f16, Promote);
556 setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
557 setOperationAction(ISD::FROUND, MVT::f32, Custom);
558 setOperationAction(ISD::FROUND, MVT::f64, Custom);
559
554560
555561 // 'Expand' implements FCOPYSIGN without calling an external library.
556562 setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
20672073 }
20682074 }
20692075
2076 SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
2077 EVT VT = Op.getValueType();
2078
2079 if (VT == MVT::f32)
2080 return LowerFROUND32(Op, DAG);
2081
2082 if (VT == MVT::f64)
2083 return LowerFROUND64(Op, DAG);
2084
2085 llvm_unreachable("unhandled type");
2086 }
2087
2088 // This is the the rounding method used in CUDA libdevice in C like code:
2089 // float roundf(float A)
2090 // {
2091 // float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
2092 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2093 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2094 // }
2095 SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
2096 SelectionDAG &DAG) const {
2097 SDLoc SL(Op);
2098 SDValue A = Op.getOperand(0);
2099 EVT VT = Op.getValueType();
2100
2101 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2102
2103 // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
2104 SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
2105 const int SignBitMask = 0x80000000;
2106 SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
2107 DAG.getConstant(SignBitMask, SL, MVT::i32));
2108 const int PointFiveInBits = 0x3F000000;
2109 SDValue PointFiveWithSignRaw =
2110 DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
2111 DAG.getConstant(PointFiveInBits, SL, MVT::i32));
2112 SDValue PointFiveWithSign =
2113 DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
2114 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
2115 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2116
2117 // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
2118 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2119 SDValue IsLarge =
2120 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
2121 ISD::SETOGT);
2122 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2123
2124 // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
2125 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2126 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2127 SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
2128 return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
2129 }
2130
2131 // The implementation of round(double) is similar to that of round(float) in
2132 // that they both separate the value range into three regions and use a method
2133 // specific to the region to round the values. However, round(double) first
2134 // calculates the round of the absolute value and then adds the sign back while
2135 // round(float) directly rounds the value with sign.
2136 SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
2137 SelectionDAG &DAG) const {
2138 SDLoc SL(Op);
2139 SDValue A = Op.getOperand(0);
2140 EVT VT = Op.getValueType();
2141
2142 SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
2143
2144 // double RoundedA = (double) (int) (abs(A) + 0.5f);
2145 SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
2146 DAG.getConstantFP(0.5, SL, VT));
2147 SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
2148
2149 // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
2150 EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
2151 SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
2152 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
2153 RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
2154 DAG.getConstantFP(0, SL, VT),
2155 RoundedA);
2156
2157 // Add sign to rounded_A
2158 RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
2159 DAG.getNode(ISD::FTRUNC, SL, VT, A);
2160
2161 // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
2162 SDValue IsLarge =
2163 DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
2164 ISD::SETOGT);
2165 return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
2166 }
2167
2168
2169
20702170 SDValue
20712171 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
20722172 switch (Op.getOpcode()) {
20972197 return LowerShiftRightParts(Op, DAG);
20982198 case ISD::SELECT:
20992199 return LowerSelect(Op, DAG);
2200 case ISD::FROUND:
2201 return LowerFROUND(Op, DAG);
21002202 default:
21012203 llvm_unreachable("Custom lowering not defined for operation");
21022204 }
555555 SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
556556 SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
557557
558 SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
559 SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
560 SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
561
558562 SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
559563 SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
560564
30013001 def : Pat<(ffloor Float64Regs:$a),
30023002 (CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
30033003
3004 def : Pat<(f16 (fround Float16Regs:$a)),
3005 (CVT_f16_f16 Float16Regs:$a, CvtRNI)>;
3006 def : Pat<(fround Float32Regs:$a),
3007 (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
3008 def : Pat<(f32 (fround Float32Regs:$a)),
3009 (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
3010 def : Pat<(f64 (fround Float64Regs:$a)),
3011 (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
3012
30133004 def : Pat<(ftrunc Float16Regs:$a),
30143005 (CVT_f16_f16 Float16Regs:$a, CvtRZI)>;
30153006 def : Pat<(ftrunc Float32Regs:$a),
11061106 }
11071107
11081108 ; CHECK-LABEL: test_round(
1109 ; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0];
1110 ; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];
1111 ; CHECK: st.param.b16 [func_retval0+0], [[R]];
1109 ; CHECK: ld.param.b16 {{.*}}, [test_round_param_0];
1110 ; check the use of sign mask and 0.5 to implement round
1111 ; CHECK: and.b32 [[R:%r[0-9]+]], {{.*}}, -2147483648;
1112 ; CHECK: or.b32 {{.*}}, [[R]], 1056964608;
1113 ; CHECK: st.param.b16 [func_retval0+0], {{.*}};
11121114 ; CHECK: ret;
11131115 define half @test_round(half %a) #0 {
11141116 %r = call half @llvm.round.f16(half %a)
13771377 }
13781378
13791379 ; CHECK-LABEL: test_round(
1380 ; CHECK: ld.param.b32 [[A:%hh[0-9]+]], [test_round_param_0];
1381 ; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];
1382 ; CHECK-DAG: cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];
1383 ; CHECK-DAG: cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];
1384 ; CHECK: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
1385 ; CHECK: st.param.b32 [func_retval0+0], [[R]];
1380 ; CHECK: ld.param.b32 {{.*}}, [test_round_param_0];
1381 ; check the use of sign mask and 0.5 to implement round
1382 ; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
1383 ; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
1384 ; CHECK: and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648;
1385 ; CHECK: or.b32 {{.*}}, [[R2]], 1056964608;
1386 ; CHECK: st.param.b32 [func_retval0+0], {{.*}};
13861387 ; CHECK: ret;
13871388 define <2 x half> @test_round(<2 x half> %a) #0 {
13881389 %r = call <2 x half> @llvm.round.f16(<2 x half> %a)
7373
7474 ; CHECK-LABEL: round_float
7575 define float @round_float(float %a) {
76 ; CHECK: cvt.rni.f32.f32
76 ; check the use of sign mask and 0.5 to implement round
77 ; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
78 ; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
7779 %b = call float @llvm.round.f32(float %a)
7880 ret float %b
7981 }
8082
8183 ; CHECK-LABEL: round_float_ftz
8284 define float @round_float_ftz(float %a) #1 {
83 ; CHECK: cvt.rni.ftz.f32.f32
85 ; check the use of sign mask and 0.5 to implement round
86 ; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
87 ; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
8488 %b = call float @llvm.round.f32(float %a)
8589 ret float %b
8690 }
8791
8892 ; CHECK-LABEL: round_double
8993 define double @round_double(double %a) {
90 ; CHECK: cvt.rni.f64.f64
94 ; check the use of 0.5 to implement round
95 ; CHECK: setp.lt.f64 {{.*}}, [[R:%fd[0-9]+]], 0d3FE0000000000000;
96 ; CHECK: add.rn.f64 {{.*}}, [[R]], 0d3FE0000000000000;
9197 %b = call double @llvm.round.f64(double %a)
9298 ret double %b
9399 }